Kernels
TaehyunKim committed
Commit 4f71bc9 · unverified · 2 parents: 94799ac aee4dc0

Merge pull request #10 from MotifTechnologies/fix_a2a_gs_assert

Files changed (28)
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +15 -6
  4. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -6
  7. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +15 -6
  10. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
  12. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +15 -6
  13. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
  15. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -6
  16. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
  18. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +15 -6
  19. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
  21. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +15 -6
  22. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
  24. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +15 -6
  25. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
  27. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +15 -6
  28. torch-ext/optimizer/muon.py +15 -6
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:94a28c3602d8c7a6b216976b1fb09cdd1e9f61bfc9359a80f41b5b628efdfc28
+oid sha256:ac95d5ee6d89fa59a5832429cf1583bd5d7d4eaea2e3eda3424363d79289733c
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
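The substantive change, repeated in every muon.py copy below, is in _all2all_gather: the old code required every destination rank to receive at least one shard and built the send buffer with nested torch.cat calls, while the new code only requires that some rank receives a shard, flattens per_dst once, and concatenates a single flat list. A minimal, self-contained sketch of the new send-buffer construction follows; the two-rank layout and tensor sizes are illustrative, not taken from the repository.

import torch

# Hypothetical per-destination shard lists for a 2-rank layout; in muon.py these
# are gradient shards grouped by destination rank before the all-to-all exchange.
per_dst = [[torch.randn(4), torch.randn(6)], [torch.randn(4), torch.randn(6)]]

assert any(len(v) > 0 for v in per_dst), \
    "At least one destination rank must receive a sharded tensor"

# list[list[Tensor]] -> list[Tensor], then a single concatenation into the
# flat send buffer handed to the collective.
flat = [t for dst in per_dst for t in dst]
send_buf = torch.cat(flat, dim=0)
send_counts = [sum(t.numel() for t in dst) for dst in per_dst]  # [10, 10]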
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8ca6ca8225dc9b7888566f5c7fd824234a3b4ac76718a5d18e6c75ca7acd488d
+oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e06baa32b0950126ee192654bd9f7adc79cc05d8ec39d2078c70d62ee81fdcd5
+oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
 size 1883344
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c7cf2f7b8519dbc3f20e9d151914b55e56d10c012e2232d550b7c8d262746d71
+oid sha256:1b01d97695b1dbca8bfe33e1e106f3988963d5689cfd31b5766060ea00dc79e3
 size 1749776
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
        # Even non-owner ranks must join the collective call
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8ca6ca8225dc9b7888566f5c7fd824234a3b4ac76718a5d18e6c75ca7acd488d
+oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
 size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6880c22f63ccd66e8ac62792a564d1ade58325b47369a1773c7753d4243893b9
+oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
 size 1883344
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e06baa32b0950126ee192654bd9f7adc79cc05d8ec39d2078c70d62ee81fdcd5
+oid sha256:048047c599b0ac03395691515edb9f1320b6c46d3437763859c8cfbbe763951a
 size 1883344
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ae22a3afdffd54435c6e5b145fc0b7772d03eb8c8bad0d388d9b2d1c8d2f60d5
+oid sha256:16e86e5507db995c4c1f0215f030f4153c5e240c294a4c089984f59c6bacf3c2
 size 1749936
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_15336dc_dirty
-ops = torch.ops._optimizer_15336dc_dirty
+from . import _optimizer_3dafb3e_dirty
+ops = torch.ops._optimizer_3dafb3e_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_15336dc_dirty::{op_name}"
+    return f"_optimizer_3dafb3e_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8092bc6ee3e353b2188f0874bc7f145e4eafd0366a40da9750c225732961f7c7
+oid sha256:78ba936e881be0307ed7f13b01c299fb4054a0f60b47f0092adc9075a5752af0
 size 1750024
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
    assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
torch-ext/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems
 
-    assert all(
-        len(v) > 0
-        for v in per_dst), "all params should be sharded to all devices"
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)
 
-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
 
     assert offset == u_full.numel()
 
-    if any(len(v) > 0 for v in per_dst):
-        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+    lengths = [len(v) for v in per_dst]
+    if all(l > 0 for l in lengths):
+        assert all(
+            l == lengths[0] for l in lengths
+        ), "All destination ranks must have the same number of sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+        send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
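On the scatter side, the new guard in _all2all_scatter only builds the send buffer when every destination has the same non-zero number of shards; ranks with nothing to send fall through to the else branch, since all_to_all requires participation from every rank. A hedged, stand-alone sketch of that guard follows; the empty-buffer fallback is an assumption about what the elided else branch does, not the repository's exact code, and build_send_buf is a hypothetical helper name.

import torch

def build_send_buf(per_dst):
    # per_dst: list of per-destination shard lists, as assembled in _all2all_scatter.
    lengths = [len(v) for v in per_dst]
    if all(l > 0 for l in lengths):
        assert all(l == lengths[0] for l in lengths), \
            "All destination ranks must have the same number of sharded tensors"
        flat = [t for dst in per_dst for t in dst]  # list[list[Tensor]] -> list[Tensor]
        return torch.cat(flat, dim=0)
    # Non-owner ranks still have to join the all_to_all collective, so they
    # contribute an empty buffer instead of skipping the call (assumed fallback).
    return torch.empty(0)

print(build_send_buf([[torch.randn(3)], [torch.randn(3)]]).shape)  # torch.Size([6])
print(build_send_buf([[], []]).shape)                              # torch.Size([0])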