Merge pull request #10 from MotifTechnologies/fix_a2a_gs_assert
Browse files- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +15 -6
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -6
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +15 -6
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +15 -6
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -6
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +15 -6
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +15 -6
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +15 -6
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +15 -6
- torch-ext/optimizer/muon.py +15 -6
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1787368
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ac95d5ee6d89fa59a5832429cf1583bd5d7d4eaea2e3eda3424363d79289733c
|
| 3 |
size 1787368
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824256
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
|
| 3 |
size 1824256
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1883344
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
|
| 3 |
size 1883344
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1749776
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b01d97695b1dbca8bfe33e1e106f3988963d5689cfd31b5766060ea00dc79e3
|
| 3 |
size 1749776
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824256
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
|
| 3 |
size 1824256
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1883344
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
|
| 3 |
size 1883344
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1883344
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:048047c599b0ac03395691515edb9f1320b6c46d3437763859c8cfbbe763951a
|
| 3 |
size 1883344
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1749936
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:16e86e5507db995c4c1f0215f030f4153c5e240c294a4c089984f59c6bacf3c2
|
| 3 |
size 1749936
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_3dafb3e_dirty
|
| 3 |
+
ops = torch.ops._optimizer_3dafb3e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_3dafb3e_dirty::{op_name}"
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_15336dc_dirty.abi3.so → _optimizer_3dafb3e_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1750024
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78ba936e881be0307ed7f13b01c299fb4054a0f60b47f0092adc9075a5752af0
|
| 3 |
size 1750024
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
-
assert
|
| 132 |
-
len(v) > 0
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 136 |
owned_params = [
|
| 137 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 138 |
]
|
|
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
assert offset == u_full.numel()
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
else:
|
| 294 |
# all_to_all requires participation from all ranks
|
| 295 |
# Even non-owner ranks must join the collective call
|
|
|
|
| 128 |
per_dst[dst].append(g)
|
| 129 |
send_counts[dst] += shard_elems
|
| 130 |
|
| 131 |
+
assert any(
|
| 132 |
+
len(v) > 0 for v in per_dst
|
| 133 |
+
), "At least one destination rank must receive a sharded tensor"
|
| 134 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 135 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 136 |
+
|
| 137 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 138 |
|
|
|
|
| 139 |
owned_params = [
|
| 140 |
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 141 |
]
|
|
|
|
| 291 |
|
| 292 |
assert offset == u_full.numel()
|
| 293 |
|
| 294 |
+
lengths = [len(v) for v in per_dst]
|
| 295 |
+
if all(l > 0 for l in lengths):
|
| 296 |
+
assert all(
|
| 297 |
+
l == lengths[0] for l in lengths
|
| 298 |
+
), "All destination ranks must have the same number of sharded tensor"
|
| 299 |
+
# list[list[Tensor]] -> list[Tensor]
|
| 300 |
+
per_dst = [t for dst in per_dst for t in dst]
|
| 301 |
+
send_buf = torch.cat(per_dst, dim=0)
|
| 302 |
else:
|
| 303 |
# all_to_all requires participation from all ranks
|
| 304 |
# Even non-owner ranks must join the collective call
|