Merge pull request #11 from MotifTechnologies/ca1207-patch-1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +8 -9
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +8 -9
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +8 -9
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +8 -9
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +8 -9
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +8 -9
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +8 -9
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so +0 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +8 -9
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so +0 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +8 -9
- torch-ext/optimizer/muon.py +8 -9
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_3dafb3e_dirty
-ops = torch.ops._optimizer_3dafb3e_dirty
+from . import _optimizer_b0230e7_dirty
+ops = torch.ops._optimizer_b0230e7_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_3dafb3e_dirty::{op_name}"
+    return f"_optimizer_b0230e7_dirty::{op_name}"
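Note: the only change here is pinning the shim to the rebuilt extension's namespace. A minimal consumption sketch, assuming the built `optimizer` package is importable; the op name `gemm_example` is an illustrative assumption, not something this diff defines:

# Hypothetical usage of the _ops.py shim patched above.
from optimizer import _ops

# Qualify an (illustrative) op name with the pinned namespace.
qualified = _ops.add_op_namespace_prefix("gemm_example")
assert qualified == "_optimizer_b0230e7_dirty::gemm_example"

# `_ops.ops` is bound to torch.ops._optimizer_b0230e7_dirty, so registered
# kernels would dispatch as `_ops.ops.gemm_example(...)` into the .abi3.so
# built for this commit.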
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:236bb0d67cbb2718b076637569923cf240de1c7a074790623ecb9c049fca9732
 size 1787368
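The shared objects are tracked with Git LFS, so each .abi3.so diff in this PR only rewrites the three-line pointer file (the old oid is truncated in this capture). A small verification sketch, assuming the binary has been fetched with `git lfs pull` into the working directory:

# Check a Git-LFS-managed blob against the pointer above: `oid` is the
# SHA-256 of the blob's contents and `size` is its byte length.
import hashlib
from pathlib import Path

blob = Path("_optimizer_b0230e7_dirty.abi3.so")  # path is an assumption
assert blob.stat().st_size == 1787368
digest = hashlib.sha256(blob.read_bytes()).hexdigest()
assert digest == "236bb0d67cbb2718b076637569923cf240de1c7a074790623ecb9c049fca9732"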
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -597,7 +597,7 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr

-    def get_shard_mesh(self, p, rank):
+    def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
@@ -609,8 +609,13 @@ class Muon(torch.optim.Optimizer):
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
+            process_group = p.device_mesh.get_group(mesh_dim=1)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             for i, shard_mesh in enumerate(p.device_mesh.mesh):
-                if rank in shard_mesh:
+                if self.rank in shard_mesh:
                     return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
         else:
             raise ValueError(f"Unsupported placements ({p.placements}).")
@@ -651,15 +656,9 @@ class Muon(torch.optim.Optimizer):
         for n, p in zip(ordered_names, ordered_params):
             if mesh is None:
                 mesh = p.device_mesh
-                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
-                local_rank = dist.get_rank(group=process_group)
-                if self.rank is None:
-                    self.rank = dist.get_rank(group=process_group)
-                else:
-                    assert self.rank == local_rank
+                shard_mesh, process_group = self.get_shard_mesh(p)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
-
             num_ranks = dist.get_world_size(group=process_group)
             param_to_state[id(p)] = _muon_state()
             param_to_state[id(
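The substantive change: get_shard_mesh no longer takes a rank from the caller. In the HSDP branch it now resolves the rank from the mesh_dim=1 process group (caching it on self.rank and asserting consistency thereafter) before scanning the device mesh, whereas the old caller only derived the rank after get_shard_mesh had already run, so the value passed in could still be unset. A toy sketch of the membership scan the patched code performs; the 2x4 mesh and rank are made-up stand-ins for p.device_mesh.mesh and dist.get_rank(group=process_group):

# Toy model of the HSDP shard-mesh scan in the patched get_shard_mesh.
import torch

mesh = torch.arange(8).reshape(2, 4)  # each row is one shard group (mesh dim 1)
rank = 5                              # stand-in for dist.get_rank(group=process_group)

for shard_mesh in mesh:               # walk the replicate dim (mesh dim 0)
    if rank in shard_mesh:            # tensor membership test, as in `self.rank in shard_mesh`
        print(shard_mesh.tolist())    # -> [4, 5, 6, 7]: the shard group holding this rank
        break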
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95
 size 1883344
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:21d5da3673206b979eaba9dd6d8918d7745ecd3bd3715e55105fd57c234a3a42
 size 1749776
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0
 size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95
 size 1883344
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f6ba7ad9228edcce4bf49173562b0796f1657eb734ddd6e23ca773c153eefce2
 size 1883344
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:16e86e5507db995c4c1f0215f030f4153c5e240c294a4c089984f59c6bacf3c2
-size 1749936
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:649c9c1ca7360650167cc191e373b271a4138161ec40b1e881a87515f82a613f
+size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace moves from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:78ba936e881be0307ed7f13b01c299fb4054a0f60b47f0092adc9075a5752af0
-size 1750024
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b60753dab0948f4009893fcf3a8b080ad00e0436cbdaf0995dc29ae066c0c7
+size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
torch-ext/optimizer/muon.py
CHANGED
Same diff as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.