#!/usr/bin/env python3
"""
Nanomind pretraining script for a decoder-only causal LM on JSONL(.gz) data.

- Expects an input file with one JSON object per line containing a `text` field.
- Streams, tokenizes, and packs sequences to a fixed length for efficient training.
- Uses a small LLaMA-style config by default (RMSNorm + SwiGLU + RoPE, MQA).

Usage example:
    python /workspace/nanomind/train.py \
        --data_path /workspace/nanomind_data/pretrain_1m.jsonl.gz \
        --out_dir /workspace/nanomind_runs/run1 \
        --tokenizer_name hf-internal-testing/llama-tokenizer \
        --seq_len 4096 --global_batch_size 256 \
        --lr 1e-3 --warmup_steps 2000 --max_steps 50000 --bf16
"""

import os
import gc
import gzip
import json
import math
import time
import random
import argparse
from pathlib import Path
from typing import Iterator, List, Dict, Optional

import torch
from torch.utils.data import IterableDataset, DataLoader
from transformers import (
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    get_cosine_schedule_with_warmup,
)


class JsonlPackedDataset(IterableDataset):
    """
    Streams a JSONL(.gz) file of objects with a `text` field, tokenizes, and
    packs tokens into fixed-length blocks of `seq_len`.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        seq_len: int,
        shuffle_lines: bool = False,
        add_bos_eos: bool = True,
        repeat: bool = True,
        buffer_tokens_limit: int = 4_000_000,
    ) -> None:
        super().__init__()
        self.data_path = str(data_path)
        self.tokenizer = tokenizer
        self.seq_len = int(seq_len)
        self.shuffle_lines = bool(shuffle_lines)  # accepted but not used yet
        self.add_bos_eos = bool(add_bos_eos)
        self.repeat = bool(repeat)
        self.buffer_tokens_limit = int(buffer_tokens_limit)
        # pack buffer
        self._token_buffer: List[int] = []

    def _line_iter(self) -> Iterator[str]:
        path = self.data_path
        is_gz = path.endswith(".gz")
        open_fn = gzip.open if is_gz else open
        mode = "rt"
        while True:
            with open_fn(path, mode, encoding="utf-8") as f:
                for line in f:
                    yield line
            if not self.repeat:
                break

    def _yield_blocks(self) -> Iterator[Dict[str, torch.Tensor]]:
        bos_id = getattr(self.tokenizer, "bos_token_id", None)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        # local references for speed
        token_buffer = self._token_buffer
        seq_len = self.seq_len
        for raw_line in self._line_iter():
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                obj = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            text = obj.get("text")
            if not text or len(text) < 10:
                continue
            if self.add_bos_eos and bos_id is not None and eos_id is not None:
                encoded = self.tokenizer.encode(text, add_special_tokens=False)
                # Guard against rare empty returns
                if not encoded:
                    continue
                token_buffer.append(bos_id)
                token_buffer.extend(encoded)
                token_buffer.append(eos_id)
            else:
                encoded = self.tokenizer.encode(text, add_special_tokens=True)
                if not encoded:
                    continue
                token_buffer.extend(encoded)
            # If the buffer grows too large, drop the oldest tokens to bound RAM
            if len(token_buffer) > self.buffer_tokens_limit:
                del token_buffer[: len(token_buffer) - self.buffer_tokens_limit]
            # Emit fixed-length blocks
            while len(token_buffer) >= seq_len:
                block = token_buffer[:seq_len]
                del token_buffer[:seq_len]
                input_ids = torch.tensor(block, dtype=torch.long)
                attention_mask = torch.ones_like(input_ids)
                # Causal LM uses labels equal to inputs (the shift happens inside the model)
                yield {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": input_ids.clone(),
                }

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Note on workers: each DataLoader worker gets its own copy of this
        # dataset and lines are not sharded per worker, so with num_workers > 1
        # every worker streams the same file and samples repeat across workers.
        # This is kept deliberately simple; see the sketch below for one way to
        # shard lines if that matters.
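        # A minimal sketch of per-worker sharding (not enabled here), assuming a
        # simple round-robin split by line index is acceptable:
        #
        #   info = torch.utils.data.get_worker_info()
        #   if info is not None and info.num_workers > 1:
        #       lines = (line for i, line in enumerate(self._line_iter())
        #                if i % info.num_workers == info.id)
        #       # ...then tokenize/pack `lines` instead of the full stream.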
        return self._yield_blocks()


def build_model_and_tokenizer(
    tokenizer_name: Optional[str],
    tokenizer_dir: Optional[str],
    model_name: Optional[str],
    vocab_size_override: Optional[int],
    hidden_size: int,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    rope_theta: float,
    max_position_embeddings: int,
) -> tuple:
    # Tokenizer
    if tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    elif tokenizer_dir:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
    else:
        raise ValueError("Provide --tokenizer_name or --tokenizer_dir")

    # Ensure a pad token for batching; map to eos if missing (common for causal LMs)
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # Fallback: add a [PAD] token
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    vocab_size = vocab_size_override or len(tokenizer)

    # Model
    if model_name:
        model = LlamaForCausalLM.from_pretrained(model_name)
        # Resize embeddings if the tokenizer changed
        if model.get_input_embeddings().weight.shape[0] != vocab_size:
            model.resize_token_embeddings(vocab_size)
    else:
        config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,  # d_model
            intermediate_size=int(hidden_size * 2.2),  # SwiGLU expansion, ~2.0-2.5x
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            num_key_value_heads=n_kv_heads,
            rms_norm_eps=1e-5,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            tie_word_embeddings=True,
        )
        model = LlamaForCausalLM(config)
    return model, tokenizer


def get_dataloader(
    data_path: str,
    tokenizer,
    seq_len: int,
    micro_batch_size: int,
    num_workers: int,
) -> DataLoader:
    dataset = JsonlPackedDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        shuffle_lines=False,
        add_bos_eos=True,
        repeat=True,
    )
    return DataLoader(
        dataset,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=_collate_batch,
    )


def _collate_batch(features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # All sequences are fixed-length; just stack
    input_ids = torch.stack([f["input_ids"] for f in features], dim=0)
    attention_mask = torch.stack([f["attention_mask"] for f in features], dim=0)
    labels = torch.stack([f["labels"] for f in features], dim=0)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def parse_args() -> argparse.Namespace:
    ap = argparse.ArgumentParser()
    # Data
    ap.add_argument("--data_path", required=True, help="Path to JSONL(.gz) with {text}")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--num_workers", type=int, default=2)
    # Tokenizer & Model
    ap.add_argument("--tokenizer_name", default=None, help="HF tokenizer name")
    ap.add_argument("--tokenizer_dir", default=None, help="Local dir of an HF tokenizer")
    ap.add_argument("--model_name", default=None, help="HF model name to continue from (CPT)")
    ap.add_argument("--vocab_size_override", type=int, default=None)
    # Small LLaMA-like config (used when --model_name is not provided)
    ap.add_argument("--hidden_size", type=int, default=768)
    ap.add_argument("--n_layers", type=int, default=24)
    ap.add_argument("--n_heads", type=int, default=12)
    ap.add_argument("--n_kv_heads", type=int, default=1)
    ap.add_argument("--rope_theta", type=float, default=1e6)
    ap.add_argument("--max_position_embeddings", type=int, default=4096)
    # Training
    ap.add_argument("--out_dir", required=True)
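    # Batch-size bookkeeping (see main()): global_batch_size is reached via
    # gradient accumulation, running grad_accum = global_batch_size // micro_batch_size
    # micro-batches per optimizer step.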
ap.add_argument("--global_batch_size", type=int, default=256) ap.add_argument("--micro_batch_size", type=int, default=None, help="Per-step batch size before grad accumulation") ap.add_argument("--lr", type=float, default=1e-3) ap.add_argument("--weight_decay", type=float, default=0.05) ap.add_argument("--warmup_steps", type=int, default=2000) ap.add_argument("--max_steps", type=int, default=50_000) ap.add_argument("--save_every", type=int, default=2000) ap.add_argument("--clip_grad", type=float, default=1.0) ap.add_argument("--bf16", action="store_true") ap.add_argument("--seed", type=int, default=42) return ap.parse_args() def set_seed(seed: int) -> None: random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def main() -> None: args = parse_args() set_seed(args.seed) out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True model, tokenizer = build_model_and_tokenizer( tokenizer_name=args.tokenizer_name, tokenizer_dir=args.tokenizer_dir, model_name=args.model_name, vocab_size_override=args.vocab_size_override, hidden_size=args.hidden_size, n_layers=args.n_layers, n_heads=args.n_heads, n_kv_heads=args.n_kv_heads, rope_theta=args.rope_theta, max_position_embeddings=args.max_position_embeddings, ) model = model.to(device) # Data micro_bs = args.micro_batch_size or min( max(1, args.global_batch_size // 8), args.global_batch_size) grad_accum = max(1, args.global_batch_size // micro_bs) train_loader = get_dataloader( data_path=args.data_path, tokenizer=tokenizer, seq_len=args.seq_len, micro_batch_size=micro_bs, num_workers=args.num_workers, ) # Optimizer & Scheduler optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95)) scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps, ) scaler = None use_bf16 = args.bf16 and torch.cuda.is_available() autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16 model.train() step = 0 running_loss = 0.0 tokens_per_step = args.global_batch_size * args.seq_len last_log = time.time() # Simple training loop over streaming dataloader data_iter = iter(train_loader) while step < args.max_steps: optimizer.zero_grad(set_to_none=True) for micro_step in range(grad_accum): try: batch = next(data_iter) except StopIteration: data_iter = iter(train_loader) batch = next(data_iter) input_ids = batch["input_ids"].to(device, non_blocking=True) attention_mask = batch["attention_mask"].to(device, non_blocking=True) labels = batch["labels"].to(device, non_blocking=True) with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=use_bf16): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss / grad_accum loss.backward() running_loss += loss.item() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad) optimizer.step() scheduler.step() step += 1 # Logging if step % 10 == 0: now = time.time() dt = now - last_log last_log = now avg_loss = running_loss / 10 running_loss = 0.0 ppl = math.exp(avg_loss) if avg_loss < 30 else float("inf") tokens_sec = tokens_per_step / dt if dt > 0 else 0.0 print( f"step {step:6d} | loss {avg_loss:.4f} | ppl {ppl:.2f} | tokens/s {tokens_sec:,.0f} | lr {scheduler.get_last_lr()[0]:.2e}", flush=True, ) # 
        # Checkpointing
        if step % args.save_every == 0 or step == args.max_steps:
            ckpt_dir = out_dir / f"step_{step:06d}"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

        # Small memory hygiene
        if step % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final save
    model.save_pretrained(out_dir / "final")
    tokenizer.save_pretrained(out_dir / "final")


if __name__ == "__main__":
    main()
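

# ---------------------------------------------------------------------------
# A minimal post-training sanity check (sketch, not run by this script): load a
# saved checkpoint and generate a few tokens. The path below assumes the run
# from the usage example above; adjust it to your own out_dir.
#
#   from transformers import AutoTokenizer, LlamaForCausalLM
#   ckpt = "/workspace/nanomind_runs/run1/final"
#   tok = AutoTokenizer.from_pretrained(ckpt)
#   lm = LlamaForCausalLM.from_pretrained(ckpt).eval()
#   ids = tok("Once upon a time", return_tensors="pt").input_ids
#   out = lm.generate(ids, max_new_tokens=32, do_sample=False)
#   print(tok.decode(out[0], skip_special_tokens=True))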