Kernels
TaehyunKim committed (unverified)
Commit 53deea3 · Parents: 4f71bc9, de5bead

Merge pull request #11 from MotifTechnologies/ca1207-patch-1

Files changed (30)
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} +1 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +8 -9
  4. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +8 -9
  7. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +8 -9
  10. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} +1 -1
  12. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +8 -9
  13. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
  15. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +8 -9
  16. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
  18. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +8 -9
  19. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} +1 -1
  21. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +8 -9
  22. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so +0 -3
  24. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +3 -0
  25. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +8 -9
  26. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  27. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so +0 -3
  28. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +3 -0
  29. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +8 -9
  30. torch-ext/optimizer/muon.py +8 -9
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_3dafb3e_dirty
-ops = torch.ops._optimizer_3dafb3e_dirty
+from . import _optimizer_b0230e7_dirty
+ops = torch.ops._optimizer_b0230e7_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_3dafb3e_dirty::{op_name}"
+    return f"_optimizer_b0230e7_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ac95d5ee6d89fa59a5832429cf1583bd5d7d4eaea2e3eda3424363d79289733c
+oid sha256:236bb0d67cbb2718b076637569923cf240de1c7a074790623ecb9c049fca9732
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -597,7 +597,7 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
-    def get_shard_mesh(self, p, rank):
+    def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
@@ -609,8 +609,13 @@ class Muon(torch.optim.Optimizer):
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
+            process_group = p.device_mesh.get_group(mesh_dim=1)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             for i, shard_mesh in enumerate(p.device_mesh.mesh):
-                if rank in shard_mesh:
+                if self.rank in shard_mesh:
                     return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
         else:
             raise ValueError(f"Unsupported placements ({p.placements}).")
@@ -651,15 +656,9 @@ class Muon(torch.optim.Optimizer):
         for n, p in zip(ordered_names, ordered_params):
             if mesh is None:
                 mesh = p.device_mesh
-                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
-                local_rank = dist.get_rank(group=process_group)
-                if self.rank is None:
-                    self.rank = dist.get_rank(group=process_group)
-                else:
-                    assert self.rank == local_rank
+                shard_mesh, process_group = self.get_shard_mesh(p)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
-
             num_ranks = dist.get_world_size(group=process_group)
             param_to_state[id(p)] = _muon_state()
             param_to_state[id(
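Net effect: rank discovery moves from the caller into get_shard_mesh itself, which derives the rank from the HSDP process group and caches it on self.rank, so the rank argument disappears from the signature and the validation lives in one place. A self-contained sketch of the row lookup the HSDP branch now performs, with a plain tensor standing in for p.device_mesh.mesh (the 2x4 mesh shape is an assumption for illustration):

import torch

# Stand-in for p.device_mesh.mesh under HSDP:
# rows are replica groups, columns are shard ranks within a group.
mesh = torch.arange(8).reshape(2, 4)

def find_shard_row(mesh: torch.Tensor, rank: int) -> torch.Tensor:
    # Mirrors the loop in get_shard_mesh: return the mesh row that
    # contains this rank; that row identifies the rank's shard group.
    for shard_row in mesh:
        if rank in shard_row:
            return shard_row
    raise ValueError(f"rank {rank} not found in mesh")

print(find_shard_row(mesh, 5))  # tensor([4, 5, 6, 7])

Centralizing the dist.get_rank call this way keeps the self.rank cache and its consistency assert next to the lookup instead of duplicating them at every call site.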
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above: namespace renamed from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.)
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
+oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above: get_shard_mesh(self, p, rank) becomes get_shard_mesh(self, p), with rank resolution and validation moved inside the method.)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above: namespace renamed from _optimizer_3dafb3e_dirty to _optimizer_b0230e7_dirty.)
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
+oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95
 size 1883344
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_3dafb3e_dirty.abi3.so → _optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1b01d97695b1dbca8bfe33e1e106f3988963d5689cfd31b5766060ea00dc79e3
+oid sha256:21d5da3673206b979eaba9dd6d8918d7745ecd3bd3715e55105fd57c234a3a42
 size 1749776
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0cd7fd63421d33d894dcd8fee8cd0734eba1132d27d72ce56cbe1e316f146b4d
+oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0
 size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:048047c599b0ac03395691515edb9f1320b6c46d3437763859c8cfbbe763951a
+oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95
 size 1883344
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5c84fab3b40edda427ccb78c2473f6be130a815f303c4d6c14d277401a1853d2
+oid sha256:f6ba7ad9228edcce4bf49173562b0796f1657eb734ddd6e23ca773c153eefce2
 size 1883344
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:16e86e5507db995c4c1f0215f030f4153c5e240c294a4c089984f59c6bacf3c2
-size 1749936
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:649c9c1ca7360650167cc191e373b271a4138161ec40b1e881a87515f82a613f
+size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
(Identical to the _ops.py diff above.)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_3dafb3e_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:78ba936e881be0307ed7f13b01c299fb4054a0f60b47f0092adc9075a5752af0
-size 1750024
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b60753dab0948f4009893fcf3a8b080ad00e0436cbdaf0995dc29ae066c0c7
+size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
(Identical to the muon.py diff above.)
torch-ext/optimizer/muon.py CHANGED
(Identical to the muon.py diff above; this appears to be the source file that the build/ copies mirror.)