"""
Helion-OSC Inference Script
DeepXR/Helion-OSC - Mathematical Coding Language Model
This module provides comprehensive inference capabilities for the Helion-OSC model,
including specialized methods for different programming and mathematical tasks.
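
Quick start (an illustrative sketch; assumes this file is importable as `inference`
and that the model weights can be downloaded from the Hugging Face Hub):

    from inference import HelionOSCInference

    helion = HelionOSCInference(load_in_4bit=True)  # quantize to reduce memory usage
    print(helion.code_generation(
        "Write a Python function that checks whether a string is a palindrome",
        language="python"
    ))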
"""
import torch
import json
import logging
from typing import Optional, Dict, Any, List, Union
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList
)
from dataclasses import dataclass, replace
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class GenerationParameters:
"""Parameters for text generation"""
max_length: int = 2048
temperature: float = 0.7
top_p: float = 0.95
top_k: int = 50
repetition_penalty: float = 1.05
length_penalty: float = 1.0
do_sample: bool = True
num_return_sequences: int = 1
early_stopping: bool = False
class CodeStoppingCriteria(StoppingCriteria):
    """Stops generation once any stop sequence appears in the newly generated text"""
    def __init__(self, stop_sequences: List[str], tokenizer, prompt_length: int = 0):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only inspect tokens generated after the prompt so stop sequences that
        # already appear in the prompt do not halt generation immediately
        decoded = self.tokenizer.decode(input_ids[0][self.prompt_length:], skip_special_tokens=True)
        return any(seq in decoded for seq in self.stop_sequences)
class HelionOSCInference:
"""
    Comprehensive inference wrapper for the Helion-OSC model
Supports multiple generation modes:
- Code generation
- Mathematical reasoning
- Algorithm design
- Code debugging
- Documentation generation
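
    Example (illustrative; actual output depends on the loaded model):
        >>> helion = HelionOSCInference(load_in_4bit=True)
        >>> code = helion.code_generation("Implement quicksort", language="python")
        >>> proof = helion.mathematical_reasoning("Prove that there are infinitely many primes.")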
"""
def __init__(
self,
model_name: str = "DeepXR/Helion-OSC",
device: Optional[str] = None,
load_in_8bit: bool = False,
load_in_4bit: bool = False,
use_flash_attention: bool = True,
trust_remote_code: bool = True
):
"""
Initialize the Helion-OSC model
Args:
model_name: HuggingFace model identifier
device: Device to load model on (cuda/cpu/mps)
load_in_8bit: Load model in 8-bit precision
load_in_4bit: Load model in 4-bit precision
use_flash_attention: Use flash attention for faster inference
trust_remote_code: Trust remote code from model repository
"""
self.model_name = model_name
self.device = self._get_device(device)
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
logger.info(f"Initializing Helion-OSC on {self.device}...")
# Load tokenizer
self.tokenizer = self._load_tokenizer(trust_remote_code)
# Load model
self.model = self._load_model(
use_flash_attention=use_flash_attention,
trust_remote_code=trust_remote_code
)
# Load generation configs
self.generation_configs = self._load_generation_configs()
logger.info("Model loaded successfully!")
self._print_model_info()
def _get_device(self, device: Optional[str]) -> str:
"""Determine the best available device"""
if device:
return device
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
return "cpu"
def _load_tokenizer(self, trust_remote_code: bool):
"""Load and configure tokenizer"""
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=trust_remote_code,
padding_side="left"
)
# Ensure pad token is set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def _load_model(self, use_flash_attention: bool, trust_remote_code: bool):
"""Load and configure model"""
logger.info("Loading model...")
model_kwargs = {
"trust_remote_code": trust_remote_code,
"low_cpu_mem_usage": True
}
        # Configure precision and quantization
        if self.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("Loading in 8-bit precision")
        elif self.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            logger.info("Loading in 4-bit precision")
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16 if self.device == "cuda" else torch.float32
        # Use Flash Attention 2 on CUDA when requested (requires the flash-attn package)
        if use_flash_attention and self.device == "cuda":
            model_kwargs["attn_implementation"] = "flash_attention_2"
        # Let accelerate place the model on CUDA; quantized loads also rely on device_map
        if self.device == "cuda":
            model_kwargs["device_map"] = "auto"
# Load model
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
**model_kwargs
)
        # Move to the target device when device_map did not already place the model
        if self.device != "cuda" and not (self.load_in_8bit or self.load_in_4bit):
            model = model.to(self.device)
model.eval()
        # Gradient checkpointing is a training-time optimization; enabling it during
        # inference disables the KV cache and slows generation, so it is omitted here.
        return model
def _load_generation_configs(self) -> Dict[str, GenerationParameters]:
"""Load task-specific generation configurations"""
return {
"code_generation": GenerationParameters(
max_length=4096,
temperature=0.7,
top_p=0.95,
top_k=50,
repetition_penalty=1.05,
do_sample=True
),
"mathematical_reasoning": GenerationParameters(
max_length=2048,
temperature=0.3,
top_p=0.9,
top_k=40,
repetition_penalty=1.0,
do_sample=False
),
"code_completion": GenerationParameters(
max_length=1024,
temperature=0.6,
top_p=0.92,
top_k=45,
repetition_penalty=1.03,
do_sample=True
),
"algorithm_design": GenerationParameters(
max_length=3072,
temperature=0.5,
top_p=0.93,
top_k=50,
repetition_penalty=1.08,
do_sample=True
),
"debugging": GenerationParameters(
max_length=2048,
temperature=0.4,
top_p=0.88,
repetition_penalty=1.0,
do_sample=False
)
}
def _print_model_info(self):
"""Print model information"""
try:
num_params = sum(p.numel() for p in self.model.parameters())
logger.info(f"Model parameters: {num_params:,}")
logger.info(f"Model dtype: {next(self.model.parameters()).dtype}")
logger.info(f"Device: {self.device}")
except Exception as e:
logger.warning(f"Could not get model info: {e}")
def generate(
self,
prompt: Union[str, List[str]],
task_type: str = "code_generation",
custom_params: Optional[GenerationParameters] = None,
stop_sequences: Optional[List[str]] = None,
return_full_text: bool = False,
**kwargs
) -> Union[str, List[str]]:
"""
Generate text based on prompt
Args:
prompt: Input prompt or list of prompts
task_type: Type of task (code_generation, mathematical_reasoning, etc.)
custom_params: Custom generation parameters
stop_sequences: List of sequences to stop generation
return_full_text: Whether to return full text including prompt
**kwargs: Additional generation parameters
Returns:
Generated text or list of generated texts
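        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> text = helion.generate(
            ...     "Write a Python generator that yields prime numbers",
            ...     task_type="code_generation",
            ...     stop_sequences=["###"],
            ...     temperature=0.5,
            ... )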
"""
        # Resolve generation parameters
        if custom_params:
            params = custom_params
        elif task_type in self.generation_configs:
            params = self.generation_configs[task_type]
        else:
            logger.warning(f"Unknown task type '{task_type}', using default parameters")
            params = GenerationParameters()
        # Apply kwargs overrides to a per-call copy so the shared configs are not mutated
        overrides = {key: value for key, value in kwargs.items() if hasattr(params, key)}
        params = replace(params, **overrides)
# Tokenize input
is_batch = isinstance(prompt, list)
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.model.config.max_position_embeddings
).to(self.device)
        # Setup stopping criteria, scoped to tokens generated after the prompt
        stopping_criteria = None
        if stop_sequences:
            stopping_criteria = StoppingCriteriaList([
                CodeStoppingCriteria(stop_sequences, self.tokenizer, inputs["input_ids"].shape[1])
            ])
        # Generate (sampling parameters are only passed when sampling is enabled,
        # which avoids spurious warnings for the greedy task configs)
        generation_kwargs = {
            "max_length": params.max_length,
            "repetition_penalty": params.repetition_penalty,
            "length_penalty": params.length_penalty,
            "do_sample": params.do_sample,
            "num_return_sequences": params.num_return_sequences,
            "early_stopping": params.early_stopping,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "stopping_criteria": stopping_criteria,
        }
        if params.do_sample:
            generation_kwargs.update(
                temperature=params.temperature,
                top_p=params.top_p,
                top_k=params.top_k,
            )
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)
        # Decode outputs; with left padding, the generated tokens for every sequence
        # start right after the padded prompt length
        input_length = inputs["input_ids"].shape[1]
        generated_texts = []
        for output in outputs:
            if return_full_text:
                text = self.tokenizer.decode(output, skip_special_tokens=True)
            else:
                text = self.tokenizer.decode(output[input_length:], skip_special_tokens=True).strip()
            generated_texts.append(text)
        return generated_texts if is_batch or params.num_return_sequences > 1 else generated_texts[0]
def code_generation(
self,
prompt: str,
language: Optional[str] = None,
max_length: int = 4096,
**kwargs
) -> str:
"""
Generate code for a given prompt
Args:
prompt: Code generation prompt
language: Programming language (optional)
max_length: Maximum length of generated code
**kwargs: Additional generation parameters
Returns:
Generated code
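        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> code = helion.code_generation(
            ...     "Write a function that merges two sorted lists",
            ...     language="python"
            ... )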
"""
if language:
prompt = f"Language: {language}\n{prompt}"
return self.generate(
prompt,
task_type="code_generation",
max_length=max_length,
**kwargs
)
def mathematical_reasoning(
self,
prompt: str,
max_length: int = 2048,
**kwargs
) -> str:
"""
Solve mathematical problems with step-by-step reasoning
Args:
prompt: Mathematical problem
max_length: Maximum length of solution
**kwargs: Additional generation parameters
Returns:
Mathematical solution with reasoning
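        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> proof = helion.mathematical_reasoning(
            ...     "Show that the square root of 2 is irrational."
            ... )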
"""
return self.generate(
prompt,
task_type="mathematical_reasoning",
max_length=max_length,
**kwargs
)
def algorithm_design(
self,
prompt: str,
include_complexity: bool = True,
max_length: int = 3072,
**kwargs
) -> str:
"""
Design algorithms with complexity analysis
Args:
prompt: Algorithm design prompt
include_complexity: Whether to include complexity analysis
max_length: Maximum length of output
**kwargs: Additional generation parameters
Returns:
Algorithm design with analysis
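        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> design = helion.algorithm_design(
            ...     "Design an algorithm to detect a cycle in a directed graph",
            ...     include_complexity=True
            ... )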
"""
if include_complexity:
prompt += "\n\nPlease include time and space complexity analysis."
return self.generate(
prompt,
task_type="algorithm_design",
max_length=max_length,
**kwargs
)
def debug_code(
self,
code: str,
error_message: Optional[str] = None,
max_length: int = 2048,
**kwargs
) -> str:
"""
Debug code and provide fixes
Args:
code: Code to debug
error_message: Optional error message
max_length: Maximum length of output
**kwargs: Additional generation parameters
Returns:
Debugging analysis and fixes
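        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> analysis = helion.debug_code(
            ...     "def divide(a, b): return a / b",
            ...     error_message="ZeroDivisionError: division by zero"
            ... )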
"""
prompt = f"Debug the following code:\n\n```\n{code}\n```"
if error_message:
prompt += f"\n\nError message: {error_message}"
prompt += "\n\nProvide a detailed explanation and fixed code."
return self.generate(
prompt,
task_type="debugging",
max_length=max_length,
**kwargs
)
def complete_code(
self,
code_context: str,
max_length: int = 1024,
**kwargs
) -> str:
"""
Complete partial code
Args:
code_context: Partial code to complete
max_length: Maximum length of completion
**kwargs: Additional generation parameters
Returns:
Code completion
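        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> completion = helion.complete_code("def factorial(n):")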
"""
return self.generate(
code_context,
task_type="code_completion",
max_length=max_length,
stop_sequences=["\n\n", "```", "###"],
**kwargs
)
def batch_generate(
self,
prompts: List[str],
task_type: str = "code_generation",
batch_size: int = 4,
**kwargs
) -> List[str]:
"""
Generate responses for multiple prompts in batches
Args:
prompts: List of prompts
task_type: Type of task
batch_size: Batch size for processing
**kwargs: Additional generation parameters
Returns:
List of generated responses
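        Example (illustrative; assumes `helion` is an initialized HelionOSCInference):
            >>> answers = helion.batch_generate(
            ...     ["Reverse a string in Python", "Sum a list of numbers in Python"],
            ...     task_type="code_generation",
            ...     batch_size=2
            ... )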
"""
results = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i + batch_size]
batch_results = self.generate(batch, task_type=task_type, **kwargs)
if isinstance(batch_results, str):
batch_results = [batch_results]
results.extend(batch_results)
return results
def main():
"""Example usage and demonstrations"""
print("=" * 80)
print("Helion-OSC Inference Examples")
print("=" * 80)
# Initialize model
helion = HelionOSCInference(
load_in_8bit=False, # Set to True for lower memory usage
load_in_4bit=False # Set to True for even lower memory usage
)
# Example 1: Code Generation
print("\n" + "=" * 80)
print("Example 1: Code Generation")
print("=" * 80)
code_prompt = """Write a Python function to implement a binary search tree with the following methods:
- insert(value): Insert a new value
- search(value): Search for a value
- delete(value): Delete a value
- inorder_traversal(): Return inorder traversal
Include proper documentation and type hints."""
print(f"\nPrompt:\n{code_prompt}")
print("\nGenerating...")
result = helion.code_generation(code_prompt, language="python")
print(f"\nGenerated Code:\n{result}")
# Example 2: Mathematical Reasoning
print("\n" + "=" * 80)
print("Example 2: Mathematical Reasoning")
print("=" * 80)
math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction."""
print(f"\nPrompt:\n{math_prompt}")
print("\nGenerating...")
result = helion.mathematical_reasoning(math_prompt)
print(f"\nSolution:\n{result}")
# Example 3: Algorithm Design
print("\n" + "=" * 80)
print("Example 3: Algorithm Design")
print("=" * 80)
algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string."""
print(f"\nPrompt:\n{algo_prompt}")
print("\nGenerating...")
result = helion.algorithm_design(algo_prompt, include_complexity=True)
print(f"\nAlgorithm:\n{result}")
# Example 4: Code Debugging
print("\n" + "=" * 80)
print("Example 4: Code Debugging")
print("=" * 80)
buggy_code = """
def fibonacci(n):
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)
# This is too slow for large n
result = fibonacci(100)
"""
print(f"\nBuggy Code:\n{buggy_code}")
print("\nGenerating debugging analysis...")
result = helion.debug_code(buggy_code, error_message="Takes too long to compute")
print(f"\nDebug Analysis:\n{result}")
# Example 5: Batch Processing
print("\n" + "=" * 80)
print("Example 5: Batch Code Generation")
print("=" * 80)
batch_prompts = [
"Write a Python function to reverse a linked list",
"Write a JavaScript function to debounce API calls",
"Write a Rust function to parse JSON safely"
]
print("\nProcessing batch prompts...")
results = helion.batch_generate(batch_prompts, batch_size=2)
for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1):
print(f"\nPrompt {i}: {prompt}")
print(f"Result {i}:\n{result}\n")
print("=" * 80)
print("Examples completed!")
print("=" * 80)
if __name__ == "__main__":
main()