Merge pull request #1 from MotifTechnologies/pre-commit_test_and_apply_lint
Files changed:

- .github/workflows/pre-commit.yml +30 -0
- .pre-commit-config.yaml +37 -0
- README.md +47 -2
- optimizer/dummy.cu +1 -1
- torch-ext/optimizer/muon.py +31 -24
.github/workflows/pre-commit.yml
ADDED
@@ -0,0 +1,30 @@
+name: pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [ main, master ]
+
+jobs:
+  run-pre-commit:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      pull-requests: read
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Cache pre-commit
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pre-commit
+          key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
+          restore-keys: |
+            pre-commit-${{ runner.os }}-
+
+      - name: Run pre-commit
+        uses: pre-commit/[email protected]
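The job above checks out the repository, provisions Python 3.11, caches pre-commit's hook environments keyed on the config file, and then runs the hooks through the pre-commit GitHub Action. A roughly equivalent local invocation is sketched below (an approximation; the exact arguments the action passes are not shown in this diff):

```bash
# Install pre-commit, then run every configured hook over the whole tree,
# including hooks limited to the "manual" stage that the config reserves for CI.
pip install pre-commit
pre-commit run --all-files --hook-stage manual
```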
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,37 @@
+default_install_hook_types:
+  - pre-commit
+  - commit-msg
+default_stages:
+  - pre-commit  # Run locally
+  - manual  # Run in CI
+exclude: '(build|result)/.*'
+repos:
+  - repo: https://github.com/google/yapf
+    rev: v0.43.0
+    hooks:
+      - id: yapf
+        args: [--in-place, --verbose]
+  - repo: https://github.com/crate-ci/typos
+    rev: v1.34.0
+    hooks:
+      - id: typos
+        exclude: '.gitattributes'
+  - repo: https://github.com/PyCQA/isort
+    rev: 6.0.1
+    hooks:
+      - id: isort
+  - repo: https://github.com/pre-commit/mirrors-clang-format
+    rev: v20.1.3
+    hooks:
+      - id: clang-format
+        types_or: [c++, cuda]
+        args: [--style=file, --verbose]
+  - repo: https://github.com/jackdewinter/pymarkdown
+    rev: v0.9.29
+    hooks:
+      - id: pymarkdown
+        args: [fix]
+  - repo: https://github.com/rhysd/actionlint
+    rev: v1.7.7
+    hooks:
+      - id: actionlint
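Every `rev:` above pins a hook repository to a fixed tag, so local runs and CI use identical tool versions. Not part of this PR, but worth noting: the pins can later be refreshed with pre-commit's built-in updater, for example:

```bash
# Bump each pinned rev in .pre-commit-config.yaml to its latest tagged release,
# then re-run all hooks to surface any formatting changes the new versions introduce.
pre-commit autoupdate
pre-commit run --all-files
```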
README.md
CHANGED
@@ -10,7 +10,7 @@ Optimizer is a python package that provides:
 - PyTorch implementation of recent optimizer algorithms
 - with support for parallelism techniques for efficient large-scale training.
 
-
+## Currently implemented
 - [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
 
 ## Usage
@@ -31,4 +31,49 @@ optim = optimizer.Muon(
     momentum=0.9,
     weight_decay=1e-4,
 )
-```
+```
+
+## Pre-commit Hooks
+
+This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
+
+### Setup
+
+1. Install pre-commit:
+
+   ```bash
+   pip install pre-commit
+   ```
+
+2. Install the git hooks:
+
+   ```bash
+   pre-commit install
+   ```
+
+Once installed, the configured hooks will run automatically on each commit.
+
+### Included Hooks
+
+The following tools are run via pre-commit:
+
+- **[yapf](https://github.com/google/yapf)** – Python code formatter
+- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
+- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
+- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
+- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
+- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
+
+### Usage
+
+- Run all checks on the entire codebase:
+
+  ```bash
+  pre-commit run --all-files
+  ```
+
+- Run a specific hook (example: isort):
+
+  ```bash
+  pre-commit run isort --all-files
+  ```
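One detail the section above leaves implicit: the hooks installed by `pre-commit install` run on every local commit, but they can be bypassed in a pinch, and the CI workflow will still run the same checks on the push or pull request. For example:

```bash
# Skip the locally installed git hooks for a single commit;
# the pre-commit workflow in CI still validates the branch.
git commit --no-verify -m "WIP: defer lint fixes to a follow-up commit"
```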
optimizer/dummy.cu
CHANGED
@@ -3,4 +3,4 @@ namespace {
 __global__ void dummy() {
   // This kernel does nothing but serves as a placeholder
 }
-}
+}  // namespace
torch-ext/optimizer/muon.py
CHANGED
@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
 
         if rank == state.worker_rank:
             num_ranks = dist.get_world_size(group=state.process_group)
-            gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
+            gather_list = [
+                torch.empty_like(g.to_local()) for _ in range(num_ranks)
+            ]
         else:
             gather_list = None
 
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
         if rank == state.worker_rank:
             if state.gathered_grad is not None:
                 raise RuntimeError(
-                    "Gather event already exists, which should not happen."
-                )
+                    "Gather event already exists, which should not happen.")
             state.gathered_grad = torch.cat(gather_list, dim=0)
             state.gather_event = torch.cuda.Event()
             state.gather_event.record()
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
-        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."
 
-        if p.placements == (Shard(dim=0),):
+        if p.placements == (Shard(dim=0), ):
             # Case for FSDP
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
             total_flops += flops
 
         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                  flush=True)
 
-        ordered_params = sorted(
-            params, key=lambda p: param_to_flops[id(p)], reverse=True
-        )
+        ordered_params = sorted(params,
+                                key=lambda p: param_to_flops[id(p)],
+                                reverse=True)
 
         round_robin = 0
         mesh = None
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
                 p.grad = g
 
         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group
-        )
+            params, group)
 
         def enqueue_gathers(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+                _gather(p, state, self.rank, self.comm_stream,
+                        group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+                _compute_u(state, group["ns_steps"], self.rank,
+                           self.compute_stream)
 
         def enqueue_scatters(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(
-                    p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
-                )
+                _scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
+                         self.comm_stream)
 
-        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
+        chunk_size = dist.get_world_size(param_to_state[id(
+            params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
                     continue
                 if isinstance(p.data, DTensor):
                     if all(
-                        isinstance(placement, Replicate) for placement in p.placements
-                    ):
+                            isinstance(placement, Replicate)
+                            for placement in p.placements):
                         param_tensors.append(p)
                     else:
                         param_dtensors.append(p)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
                 else:
-                    raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+                    raise TypeError(
+                        f"Unsupported parameter type: {type(p.data)}")
 
         if self.debug:
             print(
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
             # AdamW backup #
             ############################
 
-            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            params = [
+                p for p in group["params"] if not self.state[p]["use_muon"]
+            ]
             lr = group["lr"]
             beta1, beta2 = group["adamw_betas"]
             eps = group["adamw_eps"]
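The `optimizer/dummy.cu` and `torch-ext/optimizer/muon.py` hunks above are formatting-only changes that match the output of the clang-format and yapf hooks introduced in this PR (an inference from the diff, not something it states). With the hooks installed, they can be re-applied locally using the per-hook form documented in the README:

```bash
# Re-run the individual formatting hooks over the whole repository.
pre-commit run clang-format --all-files   # C++/CUDA sources such as optimizer/dummy.cu
pre-commit run yapf --all-files           # Python sources such as torch-ext/optimizer/muon.py
```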