"""
Helion-OSC Training Script

Fine-tuning and training utilities for the Helion-OSC model.
"""

import os
import json
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field

import torch
import wandb
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
)
from datasets import load_dataset, Dataset, DatasetDict
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Arguments for model configuration."""

    model_name_or_path: str = field(
        default="DeepXR/Helion-OSC",
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load model in 4-bit precision"}
    )


@dataclass
class DataArguments:
    """Arguments for data processing."""

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the dataset to use"}
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a local dataset"}
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data file"}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the validation data file"}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"}
    )
    preprocessing_num_workers: int = field(
        default=4,
        metadata={"help": "Number of workers for preprocessing"}
    )


class HelionOSCTrainer:
    """Trainer class for the Helion-OSC model."""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: TrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.datasets = self._load_datasets()

        logger.info("Trainer initialized successfully")

    def _load_tokenizer(self):
        """Load and configure the tokenizer."""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
            padding_side="right"
        )

        # Causal LM tokenizers often ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def _load_model(self):
        """Load and configure the model."""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True
        }

        # Quantization options are passed through a BitsAndBytesConfig so the
        # bnb_4bit_* settings are actually honored by from_pretrained.
        if self.model_args.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        elif self.model_args.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_name_or_path,
            **model_kwargs
        )

        if self.model_args.use_lora:
            logger.info("Applying LoRA configuration...")

            # Quantized models need to be prepared before LoRA adapters are attached.
            if self.model_args.load_in_8bit or self.model_args.load_in_4bit:
                model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=self.model_args.lora_r,
                lora_alpha=self.model_args.lora_alpha,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj"
                ],
                lora_dropout=self.model_args.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )

            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

        return model

    def _load_datasets(self) -> DatasetDict:
        """Load and preprocess the datasets."""
        logger.info("Loading datasets...")

        if self.data_args.dataset_name:
            datasets = load_dataset(self.data_args.dataset_name)
        elif self.data_args.train_file:
            # Local files are loaded with the JSON loader (JSON or JSON Lines).
            data_files = {"train": self.data_args.train_file}
            if self.data_args.validation_file:
                data_files["validation"] = self.data_args.validation_file
            datasets = load_dataset("json", data_files=data_files)
        else:
            raise ValueError("Must provide either dataset_name or train_file")

        logger.info("Preprocessing datasets...")
        datasets = datasets.map(
            self._preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            remove_columns=datasets["train"].column_names,
            desc="Preprocessing datasets"
        )

        return datasets

    def _preprocess_function(self, examples):
        """Preprocess examples for training."""
        if "prompt" in examples and "completion" in examples:
            texts = [
                f"{prompt}\n{completion}"
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        elif "text" in examples:
            texts = examples["text"]
        else:
            raise ValueError("Dataset must contain 'text' or 'prompt'/'completion' columns")

        tokenized = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.data_args.max_seq_length,
            padding="max_length",
            return_tensors=None
        )

        # Causal LM training uses the inputs as labels; the data collator later
        # replaces padding positions with -100 so they are ignored by the loss.
        tokenized["labels"] = tokenized["input_ids"].copy()

        return tokenized

    def train(self):
        """Train the model."""
        logger.info("Starting training...")

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        # Early stopping needs evaluation results, so only enable it when a
        # validation split is available.
        callbacks = []
        if self.datasets.get("validation") is not None:
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=3))

        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.datasets["train"],
            eval_dataset=self.datasets.get("validation"),
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            callbacks=callbacks
        )

        train_result = trainer.train()

        trainer.save_model()

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        logger.info("Training completed successfully!")

        return trainer, metrics

    def evaluate(self, trainer: Optional[Trainer] = None):
        """Evaluate the model."""
        if trainer is None:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False
            )

            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                eval_dataset=self.datasets.get("validation"),
                tokenizer=self.tokenizer,
                data_collator=data_collator
            )

        logger.info("Evaluating model...")
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        return metrics
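

# Minimal programmatic usage sketch (illustrative; the argument values below are
# placeholders, not defaults shipped with this script):
#
#   helion_trainer = HelionOSCTrainer(
#       model_args=ModelArguments(use_lora=True),
#       data_args=DataArguments(train_file="data/train.jsonl"),
#       training_args=TrainingArguments(output_dir="./helion-osc-finetuned"),
#   )
#   trainer, metrics = helion_trainer.train()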


def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from code examples.

    Args:
        examples: List of dictionaries with 'prompt' and 'completion' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [ex["prompt"] for ex in examples],
        "completion": [ex["completion"] for ex in examples]
    })
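
# Example (illustrative):
#   code_ds = create_code_dataset([
#       {"prompt": "Write a function that adds two numbers.",
#        "completion": "def add(a, b):\n    return a + b"},
#   ])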


def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from math examples.

    Args:
        examples: List of dictionaries with 'problem' and 'solution' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples],
        "completion": [ex["solution"] for ex in examples]
    })
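
# Example (illustrative):
#   math_ds = create_math_dataset([
#       {"problem": "What is 2 + 2?", "solution": "2 + 2 = 4"},
#   ])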


def main():
    """Main training script."""
    import argparse

    parser = argparse.ArgumentParser(description="Train Helion-OSC model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC")
    parser.add_argument("--use_lora", action="store_true", default=True)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")

    # Data arguments
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--validation_file", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)

    # Training arguments
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    data_args = DataArguments(
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        train_file=args.train_file,
        validation_file=args.validation_file,
        max_seq_length=args.max_seq_length,
        preprocessing_num_workers=args.preprocessing_num_workers
    )

    # Evaluation-dependent options are only enabled when a validation file is
    # provided; otherwise the Trainer would fail for lack of an eval dataset.
    has_validation = args.validation_file is not None

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
        fp16=args.fp16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to="wandb" if args.use_wandb else "none",
        load_best_model_at_end=has_validation,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        evaluation_strategy="steps" if has_validation else "no",
        save_strategy="steps",
        logging_dir=f"{args.output_dir}/logs",
        remove_unused_columns=False
    )

    helion_trainer = HelionOSCTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args
    )

    trainer, metrics = helion_trainer.train()

    if args.validation_file:
        eval_metrics = helion_trainer.evaluate(trainer)
        logger.info(f"Evaluation metrics: {eval_metrics}")

    logger.info("Training pipeline completed!")


if __name__ == "__main__":
    main()