fix assert in a2a gather scatter

torch-ext/optimizer/muon.py (+15 -6)
@@ -128,11 +128,14 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
         per_dst[dst].append(g)
         send_counts[dst] += shard_elems

-    assert all(
-        len(v) > 0 for v in per_dst
-    )
+    assert any(
+        len(v) > 0 for v in per_dst
+    ), "At least one destination rank must receive a sharded tensor"
+    # list[list[Tensor]] -> list[Tensor]
+    per_dst = [t for dst in per_dst for t in dst]
+
+    send_buf = torch.cat(per_dst, dim=0)

-    send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
     owned_params = [
         p for p in params if param_to_state[id(p)].worker_rank == rank
     ]
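For context on what the gather-side change does: each rank packs the gradient shards it holds into one flat send buffer, laid out destination rank by destination rank, before the all_to_all exchange. With few parameters, or an uneven parameter-to-rank assignment, some destinations can legitimately end up with no shards at all, so the assertion only requires that at least one destination has something to send. The sketch below replays that packing logic in isolation; the names per_dst, send_counts, and shard_elems mirror the diff, while the world size and shard sizes are invented for illustration and none of this is the repo's actual function.

import torch

# Toy setup (assumption): 4 ranks, and this rank only holds shards
# destined for ranks 0 and 2. Shapes are arbitrary.
world_size = 4
per_dst = [[] for _ in range(world_size)]   # one shard list per destination rank
send_counts = [0] * world_size              # elements sent to each destination

for dst, shard_elems in [(0, 6), (2, 10)]:
    g = torch.randn(shard_elems)
    per_dst[dst].append(g)
    send_counts[dst] += shard_elems

# Some destinations may legitimately receive nothing, so the check is any(),
# not all(): the send buffer only has to be non-empty overall.
assert any(len(v) > 0 for v in per_dst), "At least one destination rank must receive a sharded tensor"

# list[list[Tensor]] -> list[Tensor], then one contiguous send buffer whose
# layout matches send_counts (destination 0's shards first, then 1's, ...).
flat = [t for dst in per_dst for t in dst]
send_buf = torch.cat(flat, dim=0)

assert send_buf.numel() == sum(send_counts)
print(send_counts, send_buf.shape)   # [6, 0, 10, 0] torch.Size([16])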
@@ -288,8 +291,14 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):

         assert offset == u_full.numel()

-
-
+        lengths = [len(v) for v in per_dst]
+        if all(l > 0 for l in lengths):
+            assert all(
+                l == lengths[0] for l in lengths
+            ), "All destination ranks must have the same number of sharded tensor"
+            # list[list[Tensor]] -> list[Tensor]
+            per_dst = [t for dst in per_dst for t in dst]
+            send_buf = torch.cat(per_dst, dim=0)
     else:
         # all_to_all requires participation from all ranks
         # Even non-owner ranks must join the collective call
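The scatter side gets the mirror-image guard: the flat send buffer is only assembled when every destination actually has shards, and then only if each destination has the same number of them, so the concatenated layout is unambiguous. The unchanged else branch in the hunk is where ranks that own nothing still join the call, since all_to_all is a collective and requires participation from every rank. Below is a minimal standalone sketch of that guard; u_full and per_dst echo the diff, but the two-rank split and tensor sizes are made up, and the surrounding control flow of _all2all_scatter is omitted.

import torch

# Toy setup (assumption): 2 destination ranks; the owner rank has a full
# update u_full, already split into one equally sized shard per destination.
world_size = 2
u_full = torch.randn(8)
per_dst = [[u_full[:4]], [u_full[4:]]]   # list[list[Tensor]], one inner list per rank
assert len(per_dst) == world_size

lengths = [len(v) for v in per_dst]
if all(l > 0 for l in lengths):
    # Every destination has shards, and the flat-buffer layout is only
    # well defined if each destination has the same number of them.
    assert all(l == lengths[0] for l in lengths), \
        "All destination ranks must have the same number of sharded tensors"
    flat = [t for dst in per_dst for t in dst]   # list[list[Tensor]] -> list[Tensor]
    send_buf = torch.cat(flat, dim=0)
    assert send_buf.numel() == u_full.numel()
    print(send_buf.shape)                        # torch.Size([8])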