"""Client utilities for interacting with the NexaSci Assistant LLM.""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence import torch import yaml from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig @dataclass(frozen=True) class Message: """Represents a single conversational turn.""" role: str content: str @dataclass(frozen=True) class ModelConfig: """Configuration for loading the NexaSci model.""" base_repo: str merged_path: Optional[str] adapter_path: Optional[str] backend: str torch_dtype: Optional[str] trust_remote_code: bool @dataclass(frozen=True) class GenerationSettings: """Settings governing text generation.""" max_new_tokens: int temperature: float top_p: float repetition_penalty: float tool_prefix: str tool_suffix: str final_prefix: str final_suffix: str stop_sequences: Sequence[str] system_prompt: str class NexaSciModelClient: """High-level client for loading and querying the NexaSci Assistant.""" def __init__(self, config_path: Path | str = "agent/config.yaml", lazy_load: bool = False) -> None: """Initialise the client by loading configuration and model weights. Parameters ---------- config_path: Path to the agent configuration YAML file. lazy_load: If True, delay model loading until first generation call. """ self._config_path = Path(config_path) if not self._config_path.exists(): raise FileNotFoundError(f"Configuration file not found at {self._config_path}") raw_cfg = _load_yaml(self._config_path) self.model_config = _parse_model_config(raw_cfg["model"]) self.generation_settings = _parse_generation_settings(raw_cfg["generation"]) self._tooling_config = raw_cfg.get("tooling", {}) self._tokenizer: Any | None = None self._model: AutoModelForCausalLM | None = None self._lazy_load = lazy_load if not lazy_load: print("Loading tokenizer and model...") self._tokenizer = self._load_tokenizer() self._model = self._load_model() print("✓ Model loaded") else: print("Model will be loaded on first generation call") @property def tokenizer(self) -> Any: """Lazy-load tokenizer if needed.""" if self._tokenizer is None: self._tokenizer = self._load_tokenizer() return self._tokenizer @property def model(self) -> AutoModelForCausalLM: """Lazy-load model if needed.""" if self._model is None: print("Loading model (this may take 30-60 seconds)...") self._model = self._load_model() print("✓ Model loaded") return self._model @property def available_tools(self) -> Sequence[str]: """Return the list of tool identifiers declared in configuration.""" return tuple(self._tooling_config.get("available_tools", [])) def _resolve_model_path(self, path: Optional[str]) -> str: """Resolve a model path, handling relative paths relative to config file.""" if path is None: return self.model_config.base_repo path_obj = Path(path) if path_obj.is_absolute(): return str(path_obj) # Resolve relative to config file's parent directory (project root) config_dir = self._config_path.parent resolved = (config_dir / path_obj).resolve() return str(resolved) def _load_tokenizer(self) -> Any: """Load the tokenizer for the configured model.""" source = self._resolve_model_path(self.model_config.merged_path) print(f" Loading tokenizer from: {source}") tokenizer = AutoTokenizer.from_pretrained( source, trust_remote_code=self.model_config.trust_remote_code, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print(" ✓ Tokenizer loaded") return tokenizer def _load_model(self) -> 
    def _load_model(self) -> AutoModelForCausalLM:
        """Load the base or merged model for inference."""
        source = self._resolve_model_path(self.model_config.merged_path)
        torch_dtype = _resolve_torch_dtype(self.model_config.torch_dtype)
        print(f" Loading model from: {source}")
        print(f" Using dtype: {torch_dtype}")

        # Check CUDA availability
        if torch.cuda.is_available():
            print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}")
            print(f" ✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
            device_map = "auto"
        else:
            print(" ⚠️ WARNING: CUDA not available! Model will load on CPU (very slow)")
            print(" Check: 1) NVIDIA drivers installed, 2) PyTorch with CUDA support, 3) GPU visible")
            device_map = None

        print(" This may take 30-60 seconds...")
        model = AutoModelForCausalLM.from_pretrained(
            source,
            device_map=device_map,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=self.model_config.trust_remote_code,
        )
        model.eval()

        # Verify device placement
        if torch.cuda.is_available():
            model_device = next(model.parameters()).device
            if model_device.type == "cuda":
                print(f" ✓ Model loaded on GPU: {model_device}")
            else:
                print(f" ⚠️ WARNING: Model loaded on {model_device}, not GPU!")
        else:
            print(" ⚠️ Model loaded on CPU (will be very slow)")
        return model

    def build_chat_messages(self, messages: Iterable[Message]) -> List[Dict[str, str]]:
        """Convert internal message objects to the tokenizer chat format."""
        messages = list(messages)  # materialise so a generator can be iterated twice
        formatted: List[Dict[str, str]] = []
        system_present = any(message.role == "system" for message in messages)
        if not system_present:
            formatted.append(
                {
                    "role": "system",
                    "content": self.generation_settings.system_prompt,
                }
            )
        for message in messages:
            formatted.append({"role": message.role, "content": message.content})
        return formatted

    def _format_prompt(self, messages: Sequence[Message]) -> str:
        """Format messages into a prompt string for Falcon-style models."""
        parts = []
        system_present = any(msg.role == "system" for msg in messages)
        if not system_present:
            parts.append(f"System: {self.generation_settings.system_prompt}")
        for message in messages:
            role = message.role.capitalize()
            if role == "System":
                parts.append(f"System: {message.content}")
            elif role == "User":
                parts.append(f"User: {message.content}")
            elif role == "Assistant":
                parts.append(f"Assistant: {message.content}")
            elif role == "Tool":
                parts.append(f"Tool: {message.content}")
        parts.append("Assistant:")
        return "\n\n".join(parts)

    def generate(
        self,
        messages: Sequence[Message],
        *,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a response from the model given a message history."""
        # Use the tokenizer's chat template when available, otherwise format manually
        if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
            chat_messages = self.build_chat_messages(messages)
            inputs = self.tokenizer.apply_chat_template(
                chat_messages,
                tokenize=True,
                return_tensors="pt",
                add_generation_prompt=True,
            )
            # apply_chat_template may return a tensor directly
            if isinstance(inputs, torch.Tensor):
                input_ids = inputs.to(self.model.device)
            else:
                input_ids = inputs["input_ids"].to(self.model.device)
        else:
            # Manual formatting for models without chat templates (e.g., Falcon)
            prompt_text = self._format_prompt(messages)
            tokenized = self.tokenizer(
                prompt_text,
                return_tensors="pt",
                add_special_tokens=True,
            )
            input_ids = tokenized["input_ids"].to(self.model.device)

        # Fall back to configured defaults only when no override is given
        # (a plain `or` would silently discard an explicit 0.0).
        temp = temperature if temperature is not None else self.generation_settings.temperature
        top_p_val = top_p if top_p is not None else self.generation_settings.top_p

        # Enable sampling when temperature/top_p are used
        do_sample = temp > 0.0 or top_p_val < 1.0

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens or self.generation_settings.max_new_tokens,
            temperature=temp,
            top_p=top_p_val,
            do_sample=do_sample,
            repetition_penalty=self.generation_settings.repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Ensure the model is loaded
        model = self.model
        model_device = next(model.parameters()).device

        if model_device.type == "cuda":
            torch.cuda.empty_cache()
            free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
            free_gb = free_mem / (1024**3)
            print(f" 💾 GPU memory: {free_gb:.1f} GB free")
            if free_gb < 5:
                print(f" ⚠️ Warning: Low GPU memory ({free_gb:.1f} GB)")
        else:
            print(f" ⚠️ WARNING: Model is on {model_device}, not GPU! Generation will be very slow.")
            print(" This is likely why it's freezing. Check CUDA installation.")

        try:
            print(f" 🚀 Starting generation (max {generation_config.max_new_tokens} tokens)...")
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids=input_ids,
                    generation_config=generation_config,
                    do_sample=do_sample,
                )
            print(" ✓ Generation complete")
        except torch.cuda.OutOfMemoryError as e:
            torch.cuda.empty_cache()
            raise RuntimeError(
                f"GPU out of memory. Try: 1) Reduce max_new_tokens, 2) Use CPU, "
                f"3) Close other GPU processes. Original error: {e}"
            ) from e
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                torch.cuda.empty_cache()
                raise RuntimeError(
                    "GPU out of memory. Try reducing max_new_tokens or closing other processes."
                ) from e
            raise
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

        generated_ids = output_ids[0, input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )
        return generated_text.strip()


def _parse_model_config(raw_config: Dict[str, Any]) -> ModelConfig:
    """Validate and coerce the raw model configuration."""
    return ModelConfig(
        base_repo=raw_config["base_repo"],
        merged_path=raw_config.get("merged_path"),
        adapter_path=raw_config.get("adapter_path"),
        backend=raw_config.get("backend", "transformers"),
        torch_dtype=raw_config.get("torch_dtype"),
        trust_remote_code=bool(raw_config.get("trust_remote_code", True)),
    )


def _parse_generation_settings(raw_config: Dict[str, Any]) -> GenerationSettings:
    """Validate and coerce generation settings."""
    return GenerationSettings(
        max_new_tokens=int(raw_config.get("max_new_tokens", 512)),
        temperature=float(raw_config.get("temperature", 0.3)),
        top_p=float(raw_config.get("top_p", 0.9)),
        repetition_penalty=float(raw_config.get("repetition_penalty", 1.05)),
        tool_prefix=str(raw_config.get("tool_prefix", "~~~toolcall")),
        tool_suffix=str(raw_config.get("tool_suffix", "~~~")),
        final_prefix=str(raw_config.get("final_prefix", "~~~final")),
        final_suffix=str(raw_config.get("final_suffix", "~~~")),
        stop_sequences=tuple(raw_config.get("stop_sequences", [])),
        system_prompt=str(
            raw_config.get(
                "system_prompt",
                "You are the NexaSci Assistant, a scientific research agent.",
            )
        ),
    )


def _resolve_torch_dtype(dtype_name: Optional[str]) -> Optional[torch.dtype]:
    """Map configuration dtype strings to torch dtypes."""
    if dtype_name is None:
        return None
    normalised = dtype_name.strip().lower()
    mapping = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    try:
        return mapping[normalised]
    except KeyError as exc:
        raise ValueError(f"Unsupported torch dtype: {dtype_name}") from exc
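
# For reference, a config layout compatible with the parsers above
# (_parse_model_config, _parse_generation_settings) and the "tooling" lookup in
# NexaSciModelClient.__init__. The values shown are illustrative assumptions,
# not the project's actual agent/config.yaml:
#
#   model:
#     base_repo: tiiuae/falcon-7b-instruct   # hypothetical repo id
#     merged_path: ../models/merged          # optional; resolved relative to the config file
#     adapter_path: null
#     backend: transformers
#     torch_dtype: bfloat16
#     trust_remote_code: true
#   generation:
#     max_new_tokens: 512
#     temperature: 0.3
#     top_p: 0.9
#     repetition_penalty: 1.05
#     system_prompt: "You are the NexaSci Assistant, a scientific research agent."
#   tooling:
#     available_tools: []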
raise ValueError(f"Unsupported torch dtype: {dtype_name}") from exc def _load_yaml(path: Path) -> Dict[str, Any]: """Load and parse YAML configuration from disk.""" with path.open("r", encoding="utf-8") as handle: return yaml.safe_load(handle) __all__ = [ "GenerationSettings", "Message", "ModelConfig", "NexaSciModelClient", ]