""" Helion-OSC Training Script Fine-tuning and training utilities for Helion-OSC model """ import os import torch import json import logging from typing import Optional, Dict, Any, List from dataclasses import dataclass, field from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling, EarlyStoppingCallback ) from datasets import load_dataset, Dataset, DatasetDict from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType ) import wandb from torch.utils.data import DataLoader logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @dataclass class ModelArguments: """Arguments for model configuration""" model_name_or_path: str = field( default="DeepXR/Helion-OSC", metadata={"help": "Path to pretrained model or model identifier"} ) use_lora: bool = field( default=True, metadata={"help": "Whether to use LoRA for efficient fine-tuning"} ) lora_r: int = field( default=16, metadata={"help": "LoRA attention dimension"} ) lora_alpha: int = field( default=32, metadata={"help": "LoRA alpha parameter"} ) lora_dropout: float = field( default=0.05, metadata={"help": "LoRA dropout probability"} ) load_in_8bit: bool = field( default=False, metadata={"help": "Load model in 8-bit precision"} ) load_in_4bit: bool = field( default=False, metadata={"help": "Load model in 4-bit precision"} ) @dataclass class DataArguments: """Arguments for data processing""" dataset_name: Optional[str] = field( default=None, metadata={"help": "Name of the dataset to use"} ) dataset_path: Optional[str] = field( default=None, metadata={"help": "Path to local dataset"} ) train_file: Optional[str] = field( default=None, metadata={"help": "Path to training data file"} ) validation_file: Optional[str] = field( default=None, metadata={"help": "Path to validation data file"} ) max_seq_length: int = field( default=2048, metadata={"help": "Maximum sequence length"} ) preprocessing_num_workers: int = field( default=4, metadata={"help": "Number of workers for preprocessing"} ) class HelionOSCTrainer: """Trainer class for Helion-OSC model""" def __init__( self, model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments ): self.model_args = model_args self.data_args = data_args self.training_args = training_args # Initialize tokenizer self.tokenizer = self._load_tokenizer() # Initialize model self.model = self._load_model() # Load and preprocess data self.datasets = self._load_datasets() logger.info("Trainer initialized successfully") def _load_tokenizer(self): """Load and configure tokenizer""" logger.info("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained( self.model_args.model_name_or_path, trust_remote_code=True, padding_side="right" ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer def _load_model(self): """Load and configure model""" logger.info("Loading model...") model_kwargs = { "trust_remote_code": True, "low_cpu_mem_usage": True } # Configure quantization if self.model_args.load_in_8bit: model_kwargs["load_in_8bit"] = True elif self.model_args.load_in_4bit: model_kwargs["load_in_4bit"] = True model_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16 model_kwargs["bnb_4bit_use_double_quant"] = True model_kwargs["bnb_4bit_quant_type"] = "nf4" else: model_kwargs["torch_dtype"] = torch.bfloat16 model = AutoModelForCausalLM.from_pretrained( self.model_args.model_name_or_path, **model_kwargs ) # Apply LoRA if requested if self.model_args.use_lora: logger.info("Applying LoRA configuration...") if self.model_args.load_in_8bit or self.model_args.load_in_4bit: model = prepare_model_for_kbit_training(model) lora_config = LoraConfig( r=self.model_args.lora_r, lora_alpha=self.model_args.lora_alpha, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj" ], lora_dropout=self.model_args.lora_dropout, bias="none", task_type=TaskType.CAUSAL_LM ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() return model def _load_datasets(self) -> DatasetDict: """Load and preprocess datasets""" logger.info("Loading datasets...") if self.data_args.dataset_name: # Load from HuggingFace Hub datasets = load_dataset(self.data_args.dataset_name) elif self.data_args.train_file: # Load from local files data_files = {"train": self.data_args.train_file} if self.data_args.validation_file: data_files["validation"] = self.data_args.validation_file datasets = load_dataset("json", data_files=data_files) else: raise ValueError("Must provide either dataset_name or train_file") # Preprocess datasets logger.info("Preprocessing datasets...") datasets = datasets.map( self._preprocess_function, batched=True, num_proc=self.data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names, desc="Preprocessing datasets" ) return datasets def _preprocess_function(self, examples): """Preprocess examples for training""" # Tokenize inputs if "prompt" in examples and "completion" in examples: # Instruction-following format texts = [ f"{prompt}\n{completion}" for prompt, completion in zip(examples["prompt"], examples["completion"]) ] elif "text" in examples: # Raw text format texts = examples["text"] else: raise ValueError("Dataset must contain 'text' or 'prompt'/'completion' columns") # Tokenize tokenized = self.tokenizer( texts, truncation=True, max_length=self.data_args.max_seq_length, padding="max_length", return_tensors=None ) # Create labels (same as input_ids for causal LM) tokenized["labels"] = tokenized["input_ids"].copy() return tokenized def train(self): """Train the model""" logger.info("Starting training...") # Data collator data_collator = DataCollatorForLanguageModeling( tokenizer=self.tokenizer, mlm=False ) # Initialize trainer trainer = Trainer( model=self.model, args=self.training_args, train_dataset=self.datasets["train"], eval_dataset=self.datasets.get("validation"), tokenizer=self.tokenizer, data_collator=data_collator, callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] ) # Train train_result = trainer.train() # Save model trainer.save_model() # Save metrics metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() logger.info("Training completed successfully!") return trainer, metrics def evaluate(self, trainer: Optional[Trainer] = None): """Evaluate the model""" if trainer is None: data_collator = DataCollatorForLanguageModeling( tokenizer=self.tokenizer, mlm=False ) trainer = Trainer( model=self.model, args=self.training_args, eval_dataset=self.datasets.get("validation"), tokenizer=self.tokenizer, data_collator=data_collator ) logger.info("Evaluating model...") metrics = trainer.evaluate() trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) return metrics def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset: """ Create a dataset from code examples Args: examples: List of dictionaries with 'prompt' and 'completion' keys Returns: Dataset object """ return Dataset.from_dict({ "prompt": [ex["prompt"] for ex in examples], "completion": [ex["completion"] for ex in examples] }) def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset: """ Create a dataset from math examples Args: examples: List of dictionaries with 'problem' and 'solution' keys Returns: Dataset object """ return Dataset.from_dict({ "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples], "completion": [ex["solution"] for ex in examples] }) def main(): """Main training script""" import argparse parser = argparse.ArgumentParser(description="Train Helion-OSC model") # Model arguments parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC") parser.add_argument("--use_lora", action="store_true", default=True) parser.add_argument("--lora_r", type=int, default=16) parser.add_argument("--lora_alpha", type=int, default=32) parser.add_argument("--lora_dropout", type=float, default=0.05) parser.add_argument("--load_in_8bit", action="store_true") parser.add_argument("--load_in_4bit", action="store_true") # Data arguments parser.add_argument("--dataset_name", type=str, default=None) parser.add_argument("--dataset_path", type=str, default=None) parser.add_argument("--train_file", type=str, required=True) parser.add_argument("--validation_file", type=str, default=None) parser.add_argument("--max_seq_length", type=int, default=2048) parser.add_argument("--preprocessing_num_workers", type=int, default=4) # Training arguments parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--num_train_epochs", type=int, default=3) parser.add_argument("--per_device_train_batch_size", type=int, default=4) parser.add_argument("--per_device_eval_batch_size", type=int, default=4) parser.add_argument("--gradient_accumulation_steps", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=2e-5) parser.add_argument("--warmup_steps", type=int, default=100) parser.add_argument("--logging_steps", type=int, default=10) parser.add_argument("--save_steps", type=int, default=500) parser.add_argument("--eval_steps", type=int, default=500) parser.add_argument("--save_total_limit", type=int, default=3) parser.add_argument("--fp16", action="store_true") parser.add_argument("--bf16", action="store_true") parser.add_argument("--gradient_checkpointing", action="store_true") parser.add_argument("--use_wandb", action="store_true") args = parser.parse_args() # Create argument objects model_args = ModelArguments( model_name_or_path=args.model_name_or_path, use_lora=args.use_lora, lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit ) data_args = DataArguments( dataset_name=args.dataset_name, dataset_path=args.dataset_path, train_file=args.train_file, validation_file=args.validation_file, max_seq_length=args.max_seq_length, preprocessing_num_workers=args.preprocessing_num_workers ) training_args = TrainingArguments( output_dir=args.output_dir, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, warmup_steps=args.warmup_steps, logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps, save_total_limit=args.save_total_limit, fp16=args.fp16, bf16=args.bf16, gradient_checkpointing=args.gradient_checkpointing, report_to="wandb" if args.use_wandb else "none", load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False, evaluation_strategy="steps", save_strategy="steps", logging_dir=f"{args.output_dir}/logs", remove_unused_columns=False ) # Initialize trainer helion_trainer = HelionOSCTrainer( model_args=model_args, data_args=data_args, training_args=training_args ) # Train trainer, metrics = helion_trainer.train() # Evaluate if args.validation_file: eval_metrics = helion_trainer.evaluate(trainer) logger.info(f"Evaluation metrics: {eval_metrics}") logger.info("Training pipeline completed!") if __name__ == "__main__": main()