"""
Helion-OSC Training Script
Fine-tuning and training utilities for Helion-OSC model
"""
import os
import torch
import json
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
EarlyStoppingCallback
)
from datasets import load_dataset, Dataset, DatasetDict
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
TaskType
)
import wandb
from torch.utils.data import DataLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Arguments for model configuration."""

    model_name_or_path: str = field(
        default="DeepXR/Helion-OSC",
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for efficient fine-tuning"},
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"},
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"},
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"},
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load model in 8-bit precision"},
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load model in 4-bit precision"},
    )


@dataclass
class DataArguments:
    """Arguments for data processing."""

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the dataset to use"},
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to local dataset"},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to training data file"},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to validation data file"},
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"},
    )
    preprocessing_num_workers: int = field(
        default=4,
        metadata={"help": "Number of workers for preprocessing"},
    )


class HelionOSCTrainer:
    """Trainer class for the Helion-OSC model."""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: TrainingArguments,
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        # Initialize tokenizer and model, then load and preprocess data
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.datasets = self._load_datasets()

        logger.info("Trainer initialized successfully")

    def _load_tokenizer(self):
        """Load and configure tokenizer."""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
            padding_side="right",
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def _load_model(self):
        """Load and configure model."""
        logger.info("Loading model...")
        model_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }

        # Configure quantization via BitsAndBytesConfig
        if self.model_args.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        elif self.model_args.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_name_or_path,
            **model_kwargs,
        )

        # Apply LoRA if requested
        if self.model_args.use_lora:
            logger.info("Applying LoRA configuration...")
            if self.model_args.load_in_8bit or self.model_args.load_in_4bit:
                model = prepare_model_for_kbit_training(model)
            lora_config = LoraConfig(
                r=self.model_args.lora_r,
                lora_alpha=self.model_args.lora_alpha,
                # Target modules assume a LLaMA/Mistral-style decoder;
                # adjust for other architectures.
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ],
                lora_dropout=self.model_args.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
            )
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

        return model

    def _load_datasets(self) -> DatasetDict:
        """Load and preprocess datasets."""
        logger.info("Loading datasets...")
        if self.data_args.dataset_name:
            # Load from HuggingFace Hub
            datasets = load_dataset(self.data_args.dataset_name)
        elif self.data_args.train_file:
            # Load from local files
            data_files = {"train": self.data_args.train_file}
            if self.data_args.validation_file:
                data_files["validation"] = self.data_args.validation_file
            datasets = load_dataset("json", data_files=data_files)
        else:
            raise ValueError("Must provide either dataset_name or train_file")

        # Preprocess datasets
        logger.info("Preprocessing datasets...")
        datasets = datasets.map(
            self._preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            remove_columns=datasets["train"].column_names,
            desc="Preprocessing datasets",
        )
        return datasets

    def _preprocess_function(self, examples):
        """Preprocess examples for training."""
        # Build the raw text for each example
        if "prompt" in examples and "completion" in examples:
            # Instruction-following format
            texts = [
                f"{prompt}\n{completion}"
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        elif "text" in examples:
            # Raw text format
            texts = examples["text"]
        else:
            raise ValueError(
                "Dataset must contain 'text' or 'prompt'/'completion' columns"
            )

        # Tokenize
        tokenized = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.data_args.max_seq_length,
            padding="max_length",
            return_tensors=None,
        )
        # Create labels (same as input_ids for causal LM)
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def train(self):
        """Train the model."""
        logger.info("Starting training...")

        # Data collator (mlm=False -> causal language modeling)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        eval_dataset = self.datasets.get("validation")

        # Early stopping requires a validation set and load_best_model_at_end
        callbacks = (
            [EarlyStoppingCallback(early_stopping_patience=3)]
            if eval_dataset is not None
            else []
        )

        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.datasets["train"],
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            callbacks=callbacks,
        )

        # Train
        train_result = trainer.train()

        # Save model, metrics, and trainer state
        trainer.save_model()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        logger.info("Training completed successfully!")
        return trainer, metrics

    def evaluate(self, trainer: Optional[Trainer] = None):
        """Evaluate the model."""
        if trainer is None:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False,
            )
            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                eval_dataset=self.datasets.get("validation"),
                tokenizer=self.tokenizer,
                data_collator=data_collator,
            )

        logger.info("Evaluating model...")
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        return metrics


def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from code examples.

    Args:
        examples: List of dictionaries with 'prompt' and 'completion' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [ex["prompt"] for ex in examples],
        "completion": [ex["completion"] for ex in examples],
    })


def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from math examples.

    Args:
        examples: List of dictionaries with 'problem' and 'solution' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples],
        "completion": [ex["solution"] for ex in examples],
    })
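
# Example (hypothetical data) of building a small JSONL training file with the
# helpers above; the resulting 'prompt'/'completion' columns match what
# _preprocess_function expects:
#
#   code_ds = create_code_dataset([
#       {"prompt": "Write a function that reverses a string.",
#        "completion": "def reverse(s):\n    return s[::-1]"},
#   ])
#   code_ds.to_json("data/train.jsonl")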


def main():
    """Main training script."""
    import argparse

    parser = argparse.ArgumentParser(description="Train Helion-OSC model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC")
    # BooleanOptionalAction exposes both --use_lora and --no-use_lora,
    # so LoRA can actually be disabled from the command line
    parser.add_argument(
        "--use_lora", action=argparse.BooleanOptionalAction, default=True
    )
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")

    # Data arguments
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--validation_file", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)

    # Training arguments
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    # Create argument objects
    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )
    data_args = DataArguments(
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        train_file=args.train_file,
        validation_file=args.validation_file,
        max_seq_length=args.max_seq_length,
        preprocessing_num_workers=args.preprocessing_num_workers,
    )

    # Evaluation-dependent options only make sense when a validation file is given
    has_validation = args.validation_file is not None

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
        fp16=args.fp16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to="wandb" if args.use_wandb else "none",
        load_best_model_at_end=has_validation,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        evaluation_strategy="steps" if has_validation else "no",
        save_strategy="steps",
        logging_dir=f"{args.output_dir}/logs",
        remove_unused_columns=False,
    )

    # Initialize trainer
    helion_trainer = HelionOSCTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
    )

    # Train
    trainer, metrics = helion_trainer.train()

    # Evaluate
    if args.validation_file:
        eval_metrics = helion_trainer.evaluate(trainer)
        logger.info(f"Evaluation metrics: {eval_metrics}")

    logger.info("Training pipeline completed!")


if __name__ == "__main__":
    main()
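

# Example invocation (paths and flag values below are placeholders; adjust to
# your environment and data):
#
#   python train.py \
#     --train_file data/train.jsonl \
#     --validation_file data/val.jsonl \
#     --output_dir ./helion-osc-finetuned \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 4 \
#     --bf16 \
#     --gradient_checkpointing \
#     --load_in_4bit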