"""
Multi-Model Inference System for Helion-OSC

Supports 4 different model variants for specialized tasks.
"""

import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
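
# Quick-start sketch (illustrative only; assumes the DeepXR/Helion-OSC checkpoint can be
# downloaded from the Hugging Face Hub):
#
#     system = MultiModelInference(use_8bit=False)
#     print(system.code_generation("Write a function that reverses a list", language="python"))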


class ModelType(Enum):
    """Available model types"""
    BASE = "base"
    MATH = "math"
    ALGORITHM = "algorithm"
    DEBUG = "debug"


@dataclass
class ModelConfig:
    """Configuration for each model variant"""
    name: str
    model_path: str
    description: str
    default_temperature: float
    default_max_length: int
    default_top_p: float


class MultiModelInference:
    """
    Multi-model inference system with 4 specialized models
    """
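
    # Note: all four variants below point at the same checkpoint (DeepXR/Helion-OSC);
    # they differ only in their default sampling parameters and in how the task-specific
    # helper methods frame their prompts.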

    MODELS = {
        ModelType.BASE: ModelConfig(
            name="Helion-OSC Base",
            model_path="DeepXR/Helion-OSC",
            description="General purpose code generation and completion",
            default_temperature=0.7,
            default_max_length=2048,
            default_top_p=0.95
        ),
        ModelType.MATH: ModelConfig(
            name="Helion-OSC Math",
            model_path="DeepXR/Helion-OSC",
            description="Mathematical reasoning and theorem proving",
            default_temperature=0.3,
            default_max_length=2048,
            default_top_p=0.9
        ),
        ModelType.ALGORITHM: ModelConfig(
            name="Helion-OSC Algorithm",
            model_path="DeepXR/Helion-OSC",
            description="Algorithm design and optimization",
            default_temperature=0.5,
            default_max_length=3072,
            default_top_p=0.93
        ),
        ModelType.DEBUG: ModelConfig(
            name="Helion-OSC Debug",
            model_path="DeepXR/Helion-OSC",
            description="Code debugging and error fixing",
            default_temperature=0.4,
            default_max_length=2048,
            default_top_p=0.88
        )
    }

    def __init__(
        self,
        device: Optional[str] = None,
        load_all_models: bool = False,
        use_8bit: bool = False
    ):
        """
        Initialize multi-model inference system

        Args:
            device: Device to use (cuda/cpu)
            load_all_models: Load all models at startup (uses more memory)
            use_8bit: Use 8-bit quantization for memory efficiency
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.use_8bit = use_8bit
        self.loaded_models: Dict[ModelType, Any] = {}
        self.tokenizers: Dict[ModelType, Any] = {}

        logger.info(f"Initializing Multi-Model Inference System on {self.device}")

        if load_all_models:
            logger.info("Loading all models at startup...")
            for model_type in ModelType:
                self._load_model(model_type)
        else:
            logger.info("Models will be loaded on-demand")
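
    # Construction sketch: pin the device and opt into 8-bit weights. The 8-bit path
    # relies on the bitsandbytes integration in transformers being installed.
    #     system = MultiModelInference(device="cuda", use_8bit=True)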

    def _load_model(self, model_type: ModelType):
        """Load a specific model variant"""
        if model_type in self.loaded_models:
            logger.info(f"{model_type.value} model already loaded")
            return

        config = self.MODELS[model_type]
        logger.info(f"Loading {config.name}...")

        try:
            # Tokenizer: fall back to the EOS token when no pad token is defined
            tokenizer = AutoTokenizer.from_pretrained(
                config.model_path,
                trust_remote_code=True
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Model: choose precision and placement based on device and quantization settings
            model_kwargs = {
                "trust_remote_code": True,
                "low_cpu_mem_usage": True
            }

            if self.use_8bit:
                model_kwargs["load_in_8bit"] = True
            elif self.device == "cuda":
                model_kwargs["torch_dtype"] = torch.bfloat16
                model_kwargs["device_map"] = "auto"
            else:
                model_kwargs["torch_dtype"] = torch.float32

            model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                **model_kwargs
            )

            if self.device == "cpu" and not self.use_8bit:
                model = model.to(self.device)

            model.eval()

            self.loaded_models[model_type] = model
            self.tokenizers[model_type] = tokenizer

            logger.info(f"✓ {config.name} loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load {config.name}: {e}")
            raise

    def _ensure_model_loaded(self, model_type: ModelType):
        """Ensure a model is loaded before use"""
        if model_type not in self.loaded_models:
            self._load_model(model_type)

    def generate(
        self,
        prompt: str,
        model_type: ModelType = ModelType.BASE,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: int = 50,
        do_sample: Optional[bool] = None,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text using the specified model

        Args:
            prompt: Input prompt
            model_type: Which model to use
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Whether to use sampling
            num_return_sequences: Number of sequences to generate
            **kwargs: Additional generation parameters

        Returns:
            Generated text, or a list of texts if num_return_sequences > 1
        """
        self._ensure_model_loaded(model_type)

        config = self.MODELS[model_type]
        model = self.loaded_models[model_type]
        tokenizer = self.tokenizers[model_type]

        # Use explicit "is None" checks so a caller can pass 0 / 0.0 deliberately
        max_length = config.default_max_length if max_length is None else max_length
        temperature = config.default_temperature if temperature is None else temperature
        top_p = config.default_top_p if top_p is None else top_p
        do_sample = do_sample if do_sample is not None else (temperature > 0)

        logger.info(f"Generating with {config.name}...")

        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
                **kwargs
            )

        # Strip the prompt from the decoded output(s)
        if num_return_sequences == 1:
            generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return generated[len(prompt):].strip()
        else:
            results = []
            for output in outputs:
                generated = tokenizer.decode(output, skip_special_tokens=True)
                results.append(generated[len(prompt):].strip())
            return results
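
    # Usage sketch (outputs depend on the checkpoint; the return shapes are the point):
    #     single = system.generate("def add(a, b):")                           # -> str
    #     several = system.generate("def add(a, b):", num_return_sequences=3)  # -> List[str]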

    def code_generation(
        self,
        prompt: str,
        language: Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate code using base model"""
        if language:
            prompt = f"Language: {language}\n\n{prompt}"

        return self.generate(
            prompt,
            model_type=ModelType.BASE,
            **kwargs
        )

    def solve_math(
        self,
        problem: str,
        show_steps: bool = True,
        **kwargs
    ) -> str:
        """Solve mathematical problem using math model"""
        if show_steps:
            prompt = f"Solve the following problem step by step:\n\n{problem}\n\nSolution:"
        else:
            prompt = f"Solve: {problem}\n\nAnswer:"

        return self.generate(
            prompt,
            model_type=ModelType.MATH,
            **kwargs
        )

    def design_algorithm(
        self,
        problem: str,
        include_complexity: bool = True,
        **kwargs
    ) -> str:
        """Design algorithm using algorithm model"""
        prompt = f"Design an efficient algorithm for:\n\n{problem}"
        if include_complexity:
            prompt += "\n\nInclude time and space complexity analysis."

        return self.generate(
            prompt,
            model_type=ModelType.ALGORITHM,
            **kwargs
        )

    def debug_code(
        self,
        code: str,
        error_message: Optional[str] = None,
        language: str = "python",
        **kwargs
    ) -> str:
        """Debug code using debug model"""
        prompt = f"Debug the following {language} code:\n\n```{language}\n{code}\n```"
        if error_message:
            prompt += f"\n\nError: {error_message}"
        prompt += "\n\nProvide analysis and fixed code:"

        return self.generate(
            prompt,
            model_type=ModelType.DEBUG,
            **kwargs
        )

    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded models"""
        return [self.MODELS[mt].name for mt in self.loaded_models.keys()]

    def unload_model(self, model_type: ModelType):
        """Unload a model to free memory"""
        if model_type in self.loaded_models:
            del self.loaded_models[model_type]
            del self.tokenizers[model_type]
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Unloaded {self.MODELS[model_type].name}")

    def unload_all(self):
        """Unload all models"""
        for model_type in list(self.loaded_models.keys()):
            self.unload_model(model_type)
        logger.info("All models unloaded")
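
    # Memory-management sketch: variants can be swapped in and out to fit limited memory, e.g.
    #     system.solve_math("Integrate x^2 from 0 to 3")
    #     system.unload_model(ModelType.MATH)  # free memory before using another variant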


def demonstrate_all_models(load_all_models: bool = False, use_8bit: bool = False):
    """Demonstrate all 4 models"""
    print("="*80)
    print("HELION-OSC MULTI-MODEL INFERENCE DEMONSTRATION")
    print("="*80)

    system = MultiModelInference(load_all_models=load_all_models, use_8bit=use_8bit)

    print("\n" + "="*80)
    print("MODEL 1: BASE - General Code Generation")
    print("="*80)
    prompt1 = "Write a Python function to check if a string is a palindrome:"
    print(f"Prompt: {prompt1}")
    print("\nGenerating...")
    result1 = system.code_generation(prompt1, language="python", max_length=512)
    print(f"\nResult:\n{result1}\n")

    print("\n" + "="*80)
    print("MODEL 2: MATH - Mathematical Reasoning")
    print("="*80)
    prompt2 = "Find the derivative of f(x) = 3x^4 - 2x^3 + 5x - 7"
    print(f"Prompt: {prompt2}")
    print("\nGenerating...")
    result2 = system.solve_math(prompt2, show_steps=True, max_length=1024)
    print(f"\nResult:\n{result2}\n")

    print("\n" + "="*80)
    print("MODEL 3: ALGORITHM - Algorithm Design")
    print("="*80)
    prompt3 = "Find the longest common subsequence of two strings"
    print(f"Prompt: {prompt3}")
    print("\nGenerating...")
    result3 = system.design_algorithm(prompt3, include_complexity=True, max_length=2048)
    print(f"\nResult:\n{result3}\n")

    print("\n" + "="*80)
    print("MODEL 4: DEBUG - Code Debugging")
    print("="*80)
    # Intentionally buggy example: the recursive call should use n - 1
    buggy_code = """
def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n)
"""
    print(f"Buggy Code:\n{buggy_code}")
    print("\nGenerating debugging analysis...")
    result4 = system.debug_code(
        buggy_code,
        error_message="RecursionError: maximum recursion depth exceeded",
        max_length=1024
    )
    print(f"\nResult:\n{result4}\n")

    print("="*80)
    print("LOADED MODELS:")
    print("="*80)
    for model_name in system.get_loaded_models():
        print(f"✓ {model_name}")

    print("\n" + "="*80)
    print("DEMONSTRATION COMPLETE")
    print("="*80)


def interactive_mode(load_all_models: bool = False, use_8bit: bool = False):
    """Interactive mode for testing models"""
    system = MultiModelInference(load_all_models=load_all_models, use_8bit=use_8bit)

    print("\n" + "="*80)
    print("HELION-OSC INTERACTIVE MODE")
    print("="*80)
    print("\nAvailable commands:")
    print("  1 - Generate code (Base model)")
    print("  2 - Solve math (Math model)")
    print("  3 - Design algorithm (Algorithm model)")
    print("  4 - Debug code (Debug model)")
    print("  models - Show loaded models")
    print("  quit - Exit")
    print("="*80)

    while True:
        try:
            command = input("\nEnter command (1-4, models, or quit): ").strip().lower()

            if command == "quit":
                print("Exiting...")
                break

            elif command == "models":
                loaded = system.get_loaded_models()
                if loaded:
                    print("\nLoaded models:")
                    for model in loaded:
                        print(f"  ✓ {model}")
                else:
                    print("\nNo models loaded yet")

            elif command == "1":
                prompt = input("\nEnter code generation prompt: ")
                language = input("Programming language (or press Enter for Python): ").strip() or "python"
                print("\nGenerating...")
                result = system.code_generation(prompt, language=language)
                print(f"\n{result}\n")

            elif command == "2":
                problem = input("\nEnter math problem: ")
                print("\nSolving...")
                result = system.solve_math(problem)
                print(f"\n{result}\n")

            elif command == "3":
                problem = input("\nEnter algorithm problem: ")
                print("\nDesigning algorithm...")
                result = system.design_algorithm(problem)
                print(f"\n{result}\n")

            elif command == "4":
                print("\nEnter code to debug (type 'END' on a new line when done):")
                code_lines = []
                while True:
                    line = input()
                    if line == "END":
                        break
                    code_lines.append(line)
                code = "\n".join(code_lines)
                error = input("\nError message (optional): ").strip() or None
                print("\nDebugging...")
                result = system.debug_code(code, error_message=error)
                print(f"\n{result}\n")

            else:
                print("Invalid command. Please try again.")

        except KeyboardInterrupt:
            print("\n\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {e}")

    system.unload_all()


def main():
    """Main entry point"""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Multi-Model Inference")
    parser.add_argument(
        "--mode",
        choices=["demo", "interactive"],
        default="demo",
        help="Run mode: demo or interactive"
    )
    parser.add_argument(
        "--load-all",
        action="store_true",
        help="Load all models at startup"
    )
    parser.add_argument(
        "--use-8bit",
        action="store_true",
        help="Use 8-bit quantization"
    )

    args = parser.parse_args()

    # Forward the CLI flags to the selected mode
    if args.mode == "demo":
        demonstrate_all_models(load_all_models=args.load_all, use_8bit=args.use_8bit)
    else:
        interactive_mode(load_all_models=args.load_all, use_8bit=args.use_8bit)


if __name__ == "__main__":
    main()