Kernels
ca1207 committed on
Commit
3dafb3e
·
1 Parent(s): 94799ac

fix assert in a2a gather scatter

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +15 -6
torch-ext/optimizer/muon.py CHANGED
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
128
  per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
- assert all(
132
- len(v) > 0
133
- for v in per_dst), "all params should be sharded to all devices"
 
 
 
 
134
 
135
- send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
  owned_params = [
137
  p for p in params if param_to_state[id(p)].worker_rank == rank
138
  ]
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
288
 
289
  assert offset == u_full.numel()
290
 
291
- if any(len(v) > 0 for v in per_dst):
292
- send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
 
 
 
 
 
 
293
  else:
294
  # all_to_all requires participation from all ranks
295
  # Even non-owner ranks must join the collective call
 
128
  per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
+ assert any(
132
+ len(v) > 0 for v in per_dst
133
+ ), "At least one destination rank must receive a sharded tensor"
134
+ # list[list[Tensor]] -> list[Tensor]
135
+ per_dst = [t for dst in per_dst for t in dst]
136
+
137
+ send_buf = torch.cat(per_dst, dim=0)
138
 
 
139
  owned_params = [
140
  p for p in params if param_to_state[id(p)].worker_rank == rank
141
  ]
 
291
 
292
  assert offset == u_full.numel()
293
 
294
+ lengths = [len(v) for v in per_dst]
295
+ if all(l > 0 for l in lengths):
296
+ assert all(
297
+ l == lengths[0] for l in lengths
298
+ ), "All destination ranks must have the same number of sharded tensor"
299
+ # list[list[Tensor]] -> list[Tensor]
300
+ per_dst = [t for dst in per_dst for t in dst]
301
+ send_buf = torch.cat(per_dst, dim=0)
302
  else:
303
  # all_to_all requires participation from all ranks
304
  # Even non-owner ranks must join the collective call