""" Helion-OSC Inference Script DeepXR/Helion-OSC - Mathematical Coding Language Model This module provides comprehensive inference capabilities for the Helion-OSC model, including specialized methods for different programming and mathematical tasks. """ import torch import json import logging from typing import Optional, Dict, Any, List, Union from transformers import ( AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteria, StoppingCriteriaList ) from dataclasses import dataclass import warnings # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) @dataclass class GenerationParameters: """Parameters for text generation""" max_length: int = 2048 temperature: float = 0.7 top_p: float = 0.95 top_k: int = 50 repetition_penalty: float = 1.05 length_penalty: float = 1.0 do_sample: bool = True num_return_sequences: int = 1 early_stopping: bool = False class CodeStoppingCriteria(StoppingCriteria): """Custom stopping criteria for code generation""" def __init__(self, stop_sequences: List[str], tokenizer): self.stop_sequences = stop_sequences self.tokenizer = tokenizer def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) return any(seq in decoded for seq in self.stop_sequences) class HelionOSCInference: """ Comprehensive inference wrapper for Helion-OSC model Supports multiple generation modes: - Code generation - Mathematical reasoning - Algorithm design - Code debugging - Documentation generation """ def __init__( self, model_name: str = "DeepXR/Helion-OSC", device: Optional[str] = None, load_in_8bit: bool = False, load_in_4bit: bool = False, use_flash_attention: bool = True, trust_remote_code: bool = True ): """ Initialize the Helion-OSC model Args: model_name: HuggingFace model identifier device: Device to load model on (cuda/cpu/mps) load_in_8bit: Load model in 8-bit precision load_in_4bit: Load model in 4-bit precision use_flash_attention: Use flash attention for faster inference trust_remote_code: Trust remote code from model repository """ self.model_name = model_name self.device = self._get_device(device) self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit logger.info(f"Initializing Helion-OSC on {self.device}...") # Load tokenizer self.tokenizer = self._load_tokenizer(trust_remote_code) # Load model self.model = self._load_model( use_flash_attention=use_flash_attention, trust_remote_code=trust_remote_code ) # Load generation configs self.generation_configs = self._load_generation_configs() logger.info("Model loaded successfully!") self._print_model_info() def _get_device(self, device: Optional[str]) -> str: """Determine the best available device""" if device: return device if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): return "mps" return "cpu" def _load_tokenizer(self, trust_remote_code: bool): """Load and configure tokenizer""" logger.info("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained( self.model_name, trust_remote_code=trust_remote_code, padding_side="left" ) # Ensure pad token is set if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer def _load_model(self, use_flash_attention: bool, trust_remote_code: bool): """Load and configure model""" logger.info("Loading model...") model_kwargs = { "trust_remote_code": trust_remote_code, "low_cpu_mem_usage": True } # Configure precision and quantization if self.load_in_8bit: model_kwargs["load_in_8bit"] = True logger.info("Loading in 8-bit precision") elif self.load_in_4bit: model_kwargs["load_in_4bit"] = True model_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16 model_kwargs["bnb_4bit_use_double_quant"] = True model_kwargs["bnb_4bit_quant_type"] = "nf4" logger.info("Loading in 4-bit precision") else: if self.device == "cuda": model_kwargs["torch_dtype"] = torch.bfloat16 else: model_kwargs["torch_dtype"] = torch.float32 # Configure device mapping if self.device == "cuda" and not (self.load_in_8bit or self.load_in_4bit): model_kwargs["device_map"] = "auto" # Load model model = AutoModelForCausalLM.from_pretrained( self.model_name, **model_kwargs ) # Move to device if needed if self.device != "cuda" or (self.load_in_8bit or self.load_in_4bit): if not (self.load_in_8bit or self.load_in_4bit): model = model.to(self.device) model.eval() # Enable gradient checkpointing for memory efficiency if needed if hasattr(model, 'gradient_checkpointing_enable'): model.gradient_checkpointing_enable() return model def _load_generation_configs(self) -> Dict[str, GenerationParameters]: """Load task-specific generation configurations""" return { "code_generation": GenerationParameters( max_length=4096, temperature=0.7, top_p=0.95, top_k=50, repetition_penalty=1.05, do_sample=True ), "mathematical_reasoning": GenerationParameters( max_length=2048, temperature=0.3, top_p=0.9, top_k=40, repetition_penalty=1.0, do_sample=False ), "code_completion": GenerationParameters( max_length=1024, temperature=0.6, top_p=0.92, top_k=45, repetition_penalty=1.03, do_sample=True ), "algorithm_design": GenerationParameters( max_length=3072, temperature=0.5, top_p=0.93, top_k=50, repetition_penalty=1.08, do_sample=True ), "debugging": GenerationParameters( max_length=2048, temperature=0.4, top_p=0.88, repetition_penalty=1.0, do_sample=False ) } def _print_model_info(self): """Print model information""" try: num_params = sum(p.numel() for p in self.model.parameters()) logger.info(f"Model parameters: {num_params:,}") logger.info(f"Model dtype: {next(self.model.parameters()).dtype}") logger.info(f"Device: {self.device}") except Exception as e: logger.warning(f"Could not get model info: {e}") def generate( self, prompt: Union[str, List[str]], task_type: str = "code_generation", custom_params: Optional[GenerationParameters] = None, stop_sequences: Optional[List[str]] = None, return_full_text: bool = False, **kwargs ) -> Union[str, List[str]]: """ Generate text based on prompt Args: prompt: Input prompt or list of prompts task_type: Type of task (code_generation, mathematical_reasoning, etc.) custom_params: Custom generation parameters stop_sequences: List of sequences to stop generation return_full_text: Whether to return full text including prompt **kwargs: Additional generation parameters Returns: Generated text or list of generated texts """ # Get generation parameters if custom_params: params = custom_params elif task_type in self.generation_configs: params = self.generation_configs[task_type] else: logger.warning(f"Unknown task type '{task_type}', using default parameters") params = GenerationParameters() # Override with kwargs for key, value in kwargs.items(): if hasattr(params, key): setattr(params, key, value) # Tokenize input is_batch = isinstance(prompt, list) inputs = self.tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=self.model.config.max_position_embeddings ).to(self.device) # Setup stopping criteria stopping_criteria = None if stop_sequences: stopping_criteria = StoppingCriteriaList([ CodeStoppingCriteria(stop_sequences, self.tokenizer) ]) # Generate with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=params.max_length, temperature=params.temperature, top_p=params.top_p, top_k=params.top_k, repetition_penalty=params.repetition_penalty, length_penalty=params.length_penalty, do_sample=params.do_sample, num_return_sequences=params.num_return_sequences, early_stopping=params.early_stopping, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, stopping_criteria=stopping_criteria ) # Decode outputs generated_texts = [] for output in outputs: text = self.tokenizer.decode(output, skip_special_tokens=True) if not return_full_text and not is_batch: # Remove prompt from single generation if isinstance(prompt, str): text = text[len(prompt):].strip() generated_texts.append(text) return generated_texts if is_batch or params.num_return_sequences > 1 else generated_texts[0] def code_generation( self, prompt: str, language: Optional[str] = None, max_length: int = 4096, **kwargs ) -> str: """ Generate code for a given prompt Args: prompt: Code generation prompt language: Programming language (optional) max_length: Maximum length of generated code **kwargs: Additional generation parameters Returns: Generated code """ if language: prompt = f"Language: {language}\n{prompt}" return self.generate( prompt, task_type="code_generation", max_length=max_length, **kwargs ) def mathematical_reasoning( self, prompt: str, max_length: int = 2048, **kwargs ) -> str: """ Solve mathematical problems with step-by-step reasoning Args: prompt: Mathematical problem max_length: Maximum length of solution **kwargs: Additional generation parameters Returns: Mathematical solution with reasoning """ return self.generate( prompt, task_type="mathematical_reasoning", max_length=max_length, **kwargs ) def algorithm_design( self, prompt: str, include_complexity: bool = True, max_length: int = 3072, **kwargs ) -> str: """ Design algorithms with complexity analysis Args: prompt: Algorithm design prompt include_complexity: Whether to include complexity analysis max_length: Maximum length of output **kwargs: Additional generation parameters Returns: Algorithm design with analysis """ if include_complexity: prompt += "\n\nPlease include time and space complexity analysis." return self.generate( prompt, task_type="algorithm_design", max_length=max_length, **kwargs ) def debug_code( self, code: str, error_message: Optional[str] = None, max_length: int = 2048, **kwargs ) -> str: """ Debug code and provide fixes Args: code: Code to debug error_message: Optional error message max_length: Maximum length of output **kwargs: Additional generation parameters Returns: Debugging analysis and fixes """ prompt = f"Debug the following code:\n\n```\n{code}\n```" if error_message: prompt += f"\n\nError message: {error_message}" prompt += "\n\nProvide a detailed explanation and fixed code." return self.generate( prompt, task_type="debugging", max_length=max_length, **kwargs ) def complete_code( self, code_context: str, max_length: int = 1024, **kwargs ) -> str: """ Complete partial code Args: code_context: Partial code to complete max_length: Maximum length of completion **kwargs: Additional generation parameters Returns: Code completion """ return self.generate( code_context, task_type="code_completion", max_length=max_length, stop_sequences=["\n\n", "```", "###"], **kwargs ) def batch_generate( self, prompts: List[str], task_type: str = "code_generation", batch_size: int = 4, **kwargs ) -> List[str]: """ Generate responses for multiple prompts in batches Args: prompts: List of prompts task_type: Type of task batch_size: Batch size for processing **kwargs: Additional generation parameters Returns: List of generated responses """ results = [] for i in range(0, len(prompts), batch_size): batch = prompts[i:i + batch_size] batch_results = self.generate(batch, task_type=task_type, **kwargs) if isinstance(batch_results, str): batch_results = [batch_results] results.extend(batch_results) return results def main(): """Example usage and demonstrations""" print("=" * 80) print("Helion-OSC Inference Examples") print("=" * 80) # Initialize model helion = HelionOSCInference( load_in_8bit=False, # Set to True for lower memory usage load_in_4bit=False # Set to True for even lower memory usage ) # Example 1: Code Generation print("\n" + "=" * 80) print("Example 1: Code Generation") print("=" * 80) code_prompt = """Write a Python function to implement a binary search tree with the following methods: - insert(value): Insert a new value - search(value): Search for a value - delete(value): Delete a value - inorder_traversal(): Return inorder traversal Include proper documentation and type hints.""" print(f"\nPrompt:\n{code_prompt}") print("\nGenerating...") result = helion.code_generation(code_prompt, language="python") print(f"\nGenerated Code:\n{result}") # Example 2: Mathematical Reasoning print("\n" + "=" * 80) print("Example 2: Mathematical Reasoning") print("=" * 80) math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction.""" print(f"\nPrompt:\n{math_prompt}") print("\nGenerating...") result = helion.mathematical_reasoning(math_prompt) print(f"\nSolution:\n{result}") # Example 3: Algorithm Design print("\n" + "=" * 80) print("Example 3: Algorithm Design") print("=" * 80) algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string.""" print(f"\nPrompt:\n{algo_prompt}") print("\nGenerating...") result = helion.algorithm_design(algo_prompt, include_complexity=True) print(f"\nAlgorithm:\n{result}") # Example 4: Code Debugging print("\n" + "=" * 80) print("Example 4: Code Debugging") print("=" * 80) buggy_code = """ def fibonacci(n): if n <= 1: return n return fibonacci(n-1) + fibonacci(n-2) # This is too slow for large n result = fibonacci(100) """ print(f"\nBuggy Code:\n{buggy_code}") print("\nGenerating debugging analysis...") result = helion.debug_code(buggy_code, error_message="Takes too long to compute") print(f"\nDebug Analysis:\n{result}") # Example 5: Batch Processing print("\n" + "=" * 80) print("Example 5: Batch Code Generation") print("=" * 80) batch_prompts = [ "Write a Python function to reverse a linked list", "Write a JavaScript function to debounce API calls", "Write a Rust function to parse JSON safely" ] print("\nProcessing batch prompts...") results = helion.batch_generate(batch_prompts, batch_size=2) for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1): print(f"\nPrompt {i}: {prompt}") print(f"Result {i}:\n{result}\n") print("=" * 80) print("Examples completed!") print("=" * 80) if __name__ == "__main__": main()