"""
Helion-OSC Inference Script

DeepXR/Helion-OSC - Mathematical Coding Language Model

This module provides comprehensive inference capabilities for the Helion-OSC model,
including specialized methods for different programming and mathematical tasks.
"""

import torch
import json
import logging
from typing import Optional, Dict, Any, List, Union
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList
)
from dataclasses import dataclass, replace
import warnings

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class GenerationParameters:
    """Parameters for text generation."""
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    length_penalty: float = 1.0
    do_sample: bool = True
    num_return_sequences: int = 1
    early_stopping: bool = False
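
# Hypothetical per-call override (sketch): HelionOSCInference.generate() below accepts a
# GenerationParameters instance via `custom_params`, or keyword overrides applied on top
# of the selected task config; assuming `helion` is an initialized wrapper:
#
#     params = GenerationParameters(max_length=512, temperature=0.2, do_sample=False)
#     helion.generate("Implement a stack in Python", custom_params=params)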


class CodeStoppingCriteria(StoppingCriteria):
    """Custom stopping criteria for code generation."""

    def __init__(self, stop_sequences: List[str], tokenizer, prompt_length: int = 0):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only inspect newly generated tokens so stop sequences that already occur in the
        # prompt do not end generation immediately.
        decoded = self.tokenizer.decode(input_ids[0, self.prompt_length:], skip_special_tokens=True)
        return any(seq in decoded for seq in self.stop_sequences)
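
# Usage sketch (assumes a loaded `model` and `tokenizer`): wrap the criteria in a
# StoppingCriteriaList and pass it to `model.generate`, as HelionOSCInference.generate()
# does below; `prompt_length` marks where newly generated tokens begin:
#
#     criteria = StoppingCriteriaList([
#         CodeStoppingCriteria(["```", "###"], tokenizer, prompt_length=inputs["input_ids"].shape[1])
#     ])
#     model.generate(**inputs, stopping_criteria=criteria)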


class HelionOSCInference:
    """
    Comprehensive inference wrapper for the Helion-OSC model.

    Supports multiple generation modes:
    - Code generation
    - Mathematical reasoning
    - Algorithm design
    - Code debugging
    - Documentation generation
    """

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: Optional[str] = None,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        use_flash_attention: bool = True,
        trust_remote_code: bool = True
    ):
        """
        Initialize the Helion-OSC model.

        Args:
            model_name: HuggingFace model identifier
            device: Device to load the model on (cuda/cpu/mps)
            load_in_8bit: Load the model in 8-bit precision
            load_in_4bit: Load the model in 4-bit precision
            use_flash_attention: Use flash attention for faster inference
            trust_remote_code: Trust remote code from the model repository
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit

        logger.info(f"Initializing Helion-OSC on {self.device}...")

        self.tokenizer = self._load_tokenizer(trust_remote_code)
        self.model = self._load_model(
            use_flash_attention=use_flash_attention,
            trust_remote_code=trust_remote_code
        )
        self.generation_configs = self._load_generation_configs()

        logger.info("Model loaded successfully!")
        self._print_model_info()

    def _get_device(self, device: Optional[str]) -> str:
        """Determine the best available device."""
        if device:
            return device
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _load_tokenizer(self, trust_remote_code: bool):
        """Load and configure the tokenizer."""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=trust_remote_code,
            padding_side="left"
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def _load_model(self, use_flash_attention: bool, trust_remote_code: bool):
        """Load and configure the model."""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True
        }

        if self.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("Loading in 8-bit precision")
        elif self.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            logger.info("Loading in 4-bit precision")
        else:
            model_kwargs["torch_dtype"] = (
                torch.bfloat16 if self.device == "cuda" else torch.float32
            )

        if self.device == "cuda":
            # Let accelerate place the weights; bitsandbytes quantization expects a device map as well.
            model_kwargs["device_map"] = "auto"

        if use_flash_attention and self.device == "cuda":
            # Requires the flash-attn package; drop this kwarg if it is not installed.
            model_kwargs["attn_implementation"] = "flash_attention_2"

        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            **model_kwargs
        )

        # Models loaded with a device map (CUDA) or with quantization are already placed;
        # only move full-precision CPU/MPS models explicitly.
        if self.device != "cuda" and not (self.load_in_8bit or self.load_in_4bit):
            model = model.to(self.device)

        # Gradient checkpointing is a training-time memory optimization and is intentionally
        # not enabled here, since this wrapper only runs inference.
        model.eval()

        return model

    def _load_generation_configs(self) -> Dict[str, GenerationParameters]:
        """Load task-specific generation configurations."""
        return {
            "code_generation": GenerationParameters(
                max_length=4096,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.05,
                do_sample=True
            ),
            "mathematical_reasoning": GenerationParameters(
                max_length=2048,
                temperature=0.3,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.0,
                do_sample=False
            ),
            "code_completion": GenerationParameters(
                max_length=1024,
                temperature=0.6,
                top_p=0.92,
                top_k=45,
                repetition_penalty=1.03,
                do_sample=True
            ),
            "algorithm_design": GenerationParameters(
                max_length=3072,
                temperature=0.5,
                top_p=0.93,
                top_k=50,
                repetition_penalty=1.08,
                do_sample=True
            ),
            "debugging": GenerationParameters(
                max_length=2048,
                temperature=0.4,
                top_p=0.88,
                repetition_penalty=1.0,
                do_sample=False
            )
        }

    def _print_model_info(self):
        """Print model information."""
        try:
            num_params = sum(p.numel() for p in self.model.parameters())
            logger.info(f"Model parameters: {num_params:,}")
            logger.info(f"Model dtype: {next(self.model.parameters()).dtype}")
            logger.info(f"Device: {self.device}")
        except Exception as e:
            logger.warning(f"Could not get model info: {e}")

    def generate(
        self,
        prompt: Union[str, List[str]],
        task_type: str = "code_generation",
        custom_params: Optional[GenerationParameters] = None,
        stop_sequences: Optional[List[str]] = None,
        return_full_text: bool = False,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text based on a prompt.

        Args:
            prompt: Input prompt or list of prompts
            task_type: Type of task (code_generation, mathematical_reasoning, etc.)
            custom_params: Custom generation parameters
            stop_sequences: Sequences that stop generation when produced
            return_full_text: Whether to return the full text including the prompt
            **kwargs: Additional generation parameters

        Returns:
            Generated text, or a list of generated texts for batched prompts
        """
        if custom_params:
            params = custom_params
        elif task_type in self.generation_configs:
            params = self.generation_configs[task_type]
        else:
            logger.warning(f"Unknown task type '{task_type}', using default parameters")
            params = GenerationParameters()

        # Apply keyword overrides to a copy so the shared task configs are not mutated.
        overrides = {key: value for key, value in kwargs.items() if hasattr(params, key)}
        params = replace(params, **overrides)

        is_batch = isinstance(prompt, list)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        stopping_criteria = None
        if stop_sequences:
            stopping_criteria = StoppingCriteriaList([
                CodeStoppingCriteria(
                    stop_sequences, self.tokenizer, prompt_length=inputs["input_ids"].shape[1]
                )
            ])

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=params.max_length,
                temperature=params.temperature,
                top_p=params.top_p,
                top_k=params.top_k,
                repetition_penalty=params.repetition_penalty,
                length_penalty=params.length_penalty,
                do_sample=params.do_sample,
                num_return_sequences=params.num_return_sequences,
                early_stopping=params.early_stopping,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria
            )

        # With left padding every row shares the same prompt length, so the generated
        # continuation starts at the same token offset for each sequence.
        prompt_length = inputs["input_ids"].shape[1]

        generated_texts = []
        for output in outputs:
            if return_full_text:
                text = self.tokenizer.decode(output, skip_special_tokens=True)
            else:
                text = self.tokenizer.decode(
                    output[prompt_length:], skip_special_tokens=True
                ).strip()
            generated_texts.append(text)

        return generated_texts if is_batch or params.num_return_sequences > 1 else generated_texts[0]
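
    # Example call (sketch): per-call overrides and stop sequences combine with the task
    # defaults; with return_full_text=True the prompt is kept in the returned string:
    #
    #     helion.generate(
    #         "def quicksort(arr):",
    #         task_type="code_completion",
    #         stop_sequences=["\n\n\n"],
    #         return_full_text=True
    #     )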

    def code_generation(
        self,
        prompt: str,
        language: Optional[str] = None,
        max_length: int = 4096,
        **kwargs
    ) -> str:
        """
        Generate code for a given prompt.

        Args:
            prompt: Code generation prompt
            language: Programming language (optional)
            max_length: Maximum length of the generated code
            **kwargs: Additional generation parameters

        Returns:
            Generated code
        """
        if language:
            prompt = f"Language: {language}\n{prompt}"

        return self.generate(
            prompt,
            task_type="code_generation",
            max_length=max_length,
            **kwargs
        )

    def mathematical_reasoning(
        self,
        prompt: str,
        max_length: int = 2048,
        **kwargs
    ) -> str:
        """
        Solve mathematical problems with step-by-step reasoning.

        Args:
            prompt: Mathematical problem
            max_length: Maximum length of the solution
            **kwargs: Additional generation parameters

        Returns:
            Mathematical solution with reasoning
        """
        return self.generate(
            prompt,
            task_type="mathematical_reasoning",
            max_length=max_length,
            **kwargs
        )

    def algorithm_design(
        self,
        prompt: str,
        include_complexity: bool = True,
        max_length: int = 3072,
        **kwargs
    ) -> str:
        """
        Design algorithms with complexity analysis.

        Args:
            prompt: Algorithm design prompt
            include_complexity: Whether to include complexity analysis
            max_length: Maximum length of the output
            **kwargs: Additional generation parameters

        Returns:
            Algorithm design with analysis
        """
        if include_complexity:
            prompt += "\n\nPlease include time and space complexity analysis."

        return self.generate(
            prompt,
            task_type="algorithm_design",
            max_length=max_length,
            **kwargs
        )

    def debug_code(
        self,
        code: str,
        error_message: Optional[str] = None,
        max_length: int = 2048,
        **kwargs
    ) -> str:
        """
        Debug code and provide fixes.

        Args:
            code: Code to debug
            error_message: Optional error message
            max_length: Maximum length of the output
            **kwargs: Additional generation parameters

        Returns:
            Debugging analysis and fixes
        """
        prompt = f"Debug the following code:\n\n```\n{code}\n```"
        if error_message:
            prompt += f"\n\nError message: {error_message}"
        prompt += "\n\nProvide a detailed explanation and fixed code."

        return self.generate(
            prompt,
            task_type="debugging",
            max_length=max_length,
            **kwargs
        )

    def complete_code(
        self,
        code_context: str,
        max_length: int = 1024,
        **kwargs
    ) -> str:
        """
        Complete partial code.

        Args:
            code_context: Partial code to complete
            max_length: Maximum length of the completion
            **kwargs: Additional generation parameters

        Returns:
            Code completion
        """
        return self.generate(
            code_context,
            task_type="code_completion",
            max_length=max_length,
            stop_sequences=["\n\n", "```", "###"],
            **kwargs
        )

    def batch_generate(
        self,
        prompts: List[str],
        task_type: str = "code_generation",
        batch_size: int = 4,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in batches.

        Args:
            prompts: List of prompts
            task_type: Type of task
            batch_size: Batch size for processing
            **kwargs: Additional generation parameters

        Returns:
            List of generated responses
        """
        results = []
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            batch_results = self.generate(batch, task_type=task_type, **kwargs)
            if isinstance(batch_results, str):
                batch_results = [batch_results]
            results.extend(batch_results)
        return results


def main():
    """Example usage and demonstrations."""
    print("=" * 80)
    print("Helion-OSC Inference Examples")
    print("=" * 80)

    helion = HelionOSCInference(
        load_in_8bit=False,
        load_in_4bit=False
    )

    print("\n" + "=" * 80)
    print("Example 1: Code Generation")
    print("=" * 80)
    code_prompt = """Write a Python function to implement a binary search tree with the following methods:
- insert(value): Insert a new value
- search(value): Search for a value
- delete(value): Delete a value
- inorder_traversal(): Return inorder traversal

Include proper documentation and type hints."""

    print(f"\nPrompt:\n{code_prompt}")
    print("\nGenerating...")
    result = helion.code_generation(code_prompt, language="python")
    print(f"\nGenerated Code:\n{result}")

    print("\n" + "=" * 80)
    print("Example 2: Mathematical Reasoning")
    print("=" * 80)
    math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction."""

    print(f"\nPrompt:\n{math_prompt}")
    print("\nGenerating...")
    result = helion.mathematical_reasoning(math_prompt)
    print(f"\nSolution:\n{result}")

    print("\n" + "=" * 80)
    print("Example 3: Algorithm Design")
    print("=" * 80)
    algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string."""

    print(f"\nPrompt:\n{algo_prompt}")
    print("\nGenerating...")
    result = helion.algorithm_design(algo_prompt, include_complexity=True)
    print(f"\nAlgorithm:\n{result}")

    print("\n" + "=" * 80)
    print("Example 4: Code Debugging")
    print("=" * 80)
    buggy_code = """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

# This is too slow for large n
result = fibonacci(100)
"""

    print(f"\nBuggy Code:\n{buggy_code}")
    print("\nGenerating debugging analysis...")
    result = helion.debug_code(buggy_code, error_message="Takes too long to compute")
    print(f"\nDebug Analysis:\n{result}")

    print("\n" + "=" * 80)
    print("Example 5: Batch Code Generation")
    print("=" * 80)
    batch_prompts = [
        "Write a Python function to reverse a linked list",
        "Write a JavaScript function to debounce API calls",
        "Write a Rust function to parse JSON safely"
    ]

    print("\nProcessing batch prompts...")
    results = helion.batch_generate(batch_prompts, batch_size=2)
    for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1):
        print(f"\nPrompt {i}: {prompt}")
        print(f"Result {i}:\n{result}\n")

    print("=" * 80)
    print("Examples completed!")
    print("=" * 80)


if __name__ == "__main__":
    main()