"""
Nanomind pretraining script for decoder-only causal LM on JSONL.gz data.

- Expects input file with one JSON object per line containing a `text` field.
- Streams, tokenizes, and packs sequences to a fixed length for efficient training.
- Uses a small LLaMA-style config by default (RMSNorm + SwiGLU + RoPE, MQA).

Usage example:

python /workspace/nanomind/train.py \
    --data_path /workspace/nanomind_data/pretrain_1m.jsonl.gz \
    --out_dir /workspace/nanomind_runs/run1 \
    --tokenizer_name hf-internal-testing/llama-tokenizer \
    --seq_len 4096 --global_batch_size 256 \
    --lr 1e-3 --warmup_steps 2000 --max_steps 50000 --bf16
"""
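# Each input line is a standalone JSON object; only the `text` field is read.
# A hypothetical example record:
#   {"text": "The quick brown fox jumps over the lazy dog."}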
import os
import io
import gc
import gzip
import json
import math
import time
import random
import argparse
from pathlib import Path
from typing import Iterator, List, Dict, Optional

import torch
from torch import nn
from torch.utils.data import IterableDataset, DataLoader

from transformers import (
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    get_cosine_schedule_with_warmup,
)


class JsonlPackedDataset(IterableDataset):
    """
    Streams a JSONL(.gz) file of objects with a `text` field, tokenizes, and
    packs tokens into fixed-length blocks of `seq_len`.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        seq_len: int,
        shuffle_lines: bool = False,
        add_bos_eos: bool = True,
        repeat: bool = True,
        buffer_tokens_limit: int = 4_000_000,
    ) -> None:
        super().__init__()
        self.data_path = str(data_path)
        self.tokenizer = tokenizer
        self.seq_len = int(seq_len)
        self.shuffle_lines = bool(shuffle_lines)
        self.add_bos_eos = bool(add_bos_eos)
        self.repeat = bool(repeat)
        self.buffer_tokens_limit = int(buffer_tokens_limit)

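        # Rolling buffer of token ids, carried across documents until full
        # `seq_len` blocks can be cut from it.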
        self._token_buffer: List[int] = []

    def _line_iter(self) -> Iterator[str]:
        path = self.data_path
        is_gz = path.endswith(".gz")
        open_fn = gzip.open if is_gz else open
        mode = "rt"
        while True:
            with open_fn(path, mode, encoding="utf-8") as f:
                for line in f:
                    yield line
            if not self.repeat:
                break

    def _yield_blocks(self) -> Iterator[Dict[str, torch.Tensor]]:
        bos_id = getattr(self.tokenizer, "bos_token_id", None)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        token_buffer = self._token_buffer
        seq_len = self.seq_len

        for raw_line in self._line_iter():
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                obj = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            text = obj.get("text")
            if not text or len(text) < 10:
                continue

            if self.add_bos_eos and bos_id is not None and eos_id is not None:
                encoded = self.tokenizer.encode(text, add_special_tokens=False)
                if not encoded:
                    continue
                token_buffer.append(bos_id)
                token_buffer.extend(encoded)
                token_buffer.append(eos_id)
            else:
                encoded = self.tokenizer.encode(text, add_special_tokens=True)
                if not encoded:
                    continue
                token_buffer.extend(encoded)

            if len(token_buffer) > self.buffer_tokens_limit:
                del token_buffer[: len(token_buffer) - self.buffer_tokens_limit]

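            # Pack: cut as many full `seq_len` blocks as the buffer holds.
            # Documents are concatenated, so a block may span several documents;
            # the all-ones attention mask below lets tokens attend across
            # document boundaries within a block (a common packing simplification).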
            while len(token_buffer) >= seq_len:
                block = token_buffer[:seq_len]
                del token_buffer[:seq_len]

                input_ids = torch.tensor(block, dtype=torch.long)
                attention_mask = torch.ones_like(input_ids)

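                # Labels are a copy of the inputs; LlamaForCausalLM shifts them
                # internally to compute the next-token prediction loss.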
                yield {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": input_ids.clone(),
                }

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
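        # Note: with num_workers > 1 each DataLoader worker streams the whole
        # file independently, so blocks are duplicated across workers; shard
        # lines per worker (e.g. via torch.utils.data.get_worker_info()) if
        # that matters for your run.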
        return self._yield_blocks()


def build_model_and_tokenizer(
    tokenizer_name: Optional[str],
    tokenizer_dir: Optional[str],
    model_name: Optional[str],
    vocab_size_override: Optional[int],
    hidden_size: int,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    rope_theta: float,
    max_position_embeddings: int,
) -> tuple:
    if tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    elif tokenizer_dir:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
    else:
        raise ValueError("Provide --tokenizer_name or --tokenizer_dir")

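    # Ensure a pad token exists. Packed blocks are always full length, so it is
    # never used for padding here, but some tokenizer utilities expect one.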
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    vocab_size = vocab_size_override or len(tokenizer)

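    # Either continue pretraining from an existing HF checkpoint (--model_name)
    # or build a fresh small LLaMA-style model from the config arguments.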
    if model_name:
        model = LlamaForCausalLM.from_pretrained(model_name)
        if model.get_input_embeddings().weight.shape[0] != vocab_size:
            model.resize_token_embeddings(vocab_size)
    else:
        config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=int(hidden_size * 2.2),
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            num_key_value_heads=n_kv_heads,
            rms_norm_eps=1e-5,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            tie_word_embeddings=True,
        )
        model = LlamaForCausalLM(config)

    return model, tokenizer


def get_dataloader(
    data_path: str,
    tokenizer,
    seq_len: int,
    micro_batch_size: int,
    num_workers: int,
) -> DataLoader:
    dataset = JsonlPackedDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        shuffle_lines=False,
        add_bos_eos=True,
        repeat=True,
    )
    return DataLoader(
        dataset,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=_collate_batch,
    )

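# All packed blocks share the same fixed length, so collation is a plain stack
# with no padding or truncation.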
def _collate_batch(features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    input_ids = torch.stack([f["input_ids"] for f in features], dim=0)
    attention_mask = torch.stack([f["attention_mask"] for f in features], dim=0)
    labels = torch.stack([f["labels"] for f in features], dim=0)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def parse_args() -> argparse.Namespace:
    ap = argparse.ArgumentParser()

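    # Data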
    ap.add_argument("--data_path", required=True, help="Path to JSONL(.gz) with {text}")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--num_workers", type=int, default=2)

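    # Tokenizer / initial model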
    ap.add_argument("--tokenizer_name", default=None, help="HF tokenizer name")
    ap.add_argument("--tokenizer_dir", default=None, help="Local dir of HF tokenizer")
    ap.add_argument("--model_name", default=None, help="HF model name to continue from (CPT)")
    ap.add_argument("--vocab_size_override", type=int, default=None)

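    # Model architecture (used only when training from scratch)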
    ap.add_argument("--hidden_size", type=int, default=768)
    ap.add_argument("--n_layers", type=int, default=24)
    ap.add_argument("--n_heads", type=int, default=12)
    ap.add_argument("--n_kv_heads", type=int, default=1)
    ap.add_argument("--rope_theta", type=float, default=1e6)
    ap.add_argument("--max_position_embeddings", type=int, default=4096)

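    # Optimization, checkpointing, and output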
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--global_batch_size", type=int, default=256)
    ap.add_argument(
        "--micro_batch_size",
        type=int,
        default=None,
        help="Batch size per forward/backward pass; gradient accumulation makes up the rest of --global_batch_size",
    )
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--weight_decay", type=float, default=0.05)
    ap.add_argument("--warmup_steps", type=int, default=2000)
    ap.add_argument("--max_steps", type=int, default=50_000)
    ap.add_argument("--save_every", type=int, default=2000)
    ap.add_argument("--clip_grad", type=float, default=1.0)
    ap.add_argument("--bf16", action="store_true")
    ap.add_argument("--seed", type=int, default=42)

    return ap.parse_args()

def set_seed(seed: int) -> None:
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main() -> None:
    args = parse_args()
    set_seed(args.seed)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, tokenizer = build_model_and_tokenizer(
        tokenizer_name=args.tokenizer_name,
        tokenizer_dir=args.tokenizer_dir,
        model_name=args.model_name,
        vocab_size_override=args.vocab_size_override,
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_kv_heads=args.n_kv_heads,
        rope_theta=args.rope_theta,
        max_position_embeddings=args.max_position_embeddings,
    )
    model = model.to(device)

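    # Split the requested global batch into micro-batches and make up the
    # difference with gradient accumulation, e.g. global 256 with micro 32
    # gives 8 accumulation steps per optimizer update.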
    micro_bs = args.micro_batch_size or min(max(1, args.global_batch_size // 8), args.global_batch_size)
    grad_accum = max(1, args.global_batch_size // micro_bs)
    train_loader = get_dataloader(
        data_path=args.data_path,
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        micro_batch_size=micro_bs,
        num_workers=args.num_workers,
    )

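    # AdamW with betas (0.9, 0.95) plus linear warmup and cosine decay, a
    # common recipe for LLM pretraining.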
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95)
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

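    # Mixed precision: bf16 autocast needs no GradScaler; without --bf16 the
    # autocast context is disabled and training runs in fp32.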
    use_bf16 = args.bf16 and torch.cuda.is_available()
    autocast_dtype = torch.bfloat16 if use_bf16 else torch.float32

    model.train()
    step = 0
    running_loss = 0.0
    # Tokens actually processed per optimizer step (may fall slightly below
    # global_batch_size * seq_len when the global batch is not evenly divisible).
    tokens_per_step = micro_bs * grad_accum * args.seq_len
    last_log = time.time()

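    # Training loop: each optimizer step accumulates grad_accum micro-batches;
    # the data iterator is rebuilt if the stream is ever exhausted.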
    data_iter = iter(train_loader)
    while step < args.max_steps:
        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(grad_accum):
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)

            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            labels = batch["labels"].to(device, non_blocking=True)

            with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=use_bf16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss / grad_accum

            loss.backward()
            running_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad)
        optimizer.step()
        scheduler.step()
        step += 1

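        # Log the average loss over the last 10 steps, throughput, and current LR.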
        if step % 10 == 0:
            now = time.time()
            dt = now - last_log
            last_log = now
            avg_loss = running_loss / 10
            running_loss = 0.0
            ppl = math.exp(avg_loss) if avg_loss < 30 else float("inf")
            # dt covers the last 10 optimizer steps, so scale accordingly.
            tokens_sec = (10 * tokens_per_step) / dt if dt > 0 else 0.0
            print(
                f"step {step:6d} | loss {avg_loss:.4f} | ppl {ppl:.2f} | "
                f"tokens/s {tokens_sec:,.0f} | lr {scheduler.get_last_lr()[0]:.2e}",
                flush=True,
            )

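        # Periodically save HF-format checkpoints (model + tokenizer).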
        if step % args.save_every == 0 or step == args.max_steps:
            ckpt_dir = out_dir / f"step_{step:06d}"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

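        # Occasional cleanup to keep host/GPU memory fragmentation in check.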
        if step % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

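    # Final export of model and tokenizer in HF format under out_dir/final.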
    model.save_pretrained(out_dir / "final")
    tokenizer.save_pretrained(out_dir / "final")


if __name__ == "__main__":
    main()