| """Simple HTTP server for the NexaSci model to enable sharing across processes.""" | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from .client_llm import Message, NexaSciModelClient | |
# Add project root to path if running as a module
if __name__ == "__main__" or "agent.model_server" in sys.modules:
    project_root = Path(__file__).resolve().parents[1]
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

app = FastAPI(title="NexaSci Model Server", version="0.1.0")

# Global model client (loaded once and shared by every request)
_model_client: NexaSciModelClient | None = None


class GenerateRequest(BaseModel):
    messages: List[Dict[str, str]]
    max_new_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None


class GenerateResponse(BaseModel):
    text: str
    model_loaded: bool

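# Illustrative payloads for the schemas above (values are examples, not defaults):
#   request:  {"messages": [{"role": "user", "content": "Hello"}],
#              "max_new_tokens": 128, "temperature": 0.7, "top_p": 0.9}
#   response: {"text": "...", "model_loaded": true}
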

@app.on_event("startup")
async def load_model() -> None:
    """Load the model when the server starts."""
    global _model_client
    import time

    print("=" * 80)
    print("Loading NexaSci model (this may take 30-60 seconds)...")
    print("=" * 80)
    print("Step 1: Loading tokenizer...")
    start_time = time.time()
    try:
        # Set tokenizers parallelism to avoid warnings
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        _model_client = NexaSciModelClient()
        elapsed = time.time() - start_time
        print(f"✓ Model loaded successfully in {elapsed:.1f}s")
        if torch.cuda.is_available():
            print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            print(f"✓ GPU Memory: {allocated:.1f} GB / {total_mem:.1f} GB allocated")
        print("=" * 80)
        print("Model server ready! Listening on http://0.0.0.0:8001")
        print("=" * 80)
    except Exception as e:
        elapsed = time.time() - start_time
        print(f"✗ Failed to load model after {elapsed:.1f}s: {e}")
        import traceback

        traceback.print_exc()
        raise


@app.get("/health")  # path assumed by convention; adjust if clients expect another route
async def health_check() -> Dict[str, Any]:
    """Health check endpoint."""
    gpu_available = torch.cuda.is_available()
    result = {
        "status": "healthy",
        "model_loaded": _model_client is not None,
        "gpu_available": gpu_available,
    }
    if gpu_available and _model_client is not None:
        # Check whether the model is actually on the GPU
        try:
            model_device = next(_model_client.model.parameters()).device
            result["model_device"] = str(model_device)
            result["gpu_name"] = torch.cuda.get_device_name(0)
            result["gpu_memory_allocated_gb"] = round(torch.cuda.memory_allocated(0) / (1024**3), 2)
            result["gpu_memory_total_gb"] = round(torch.cuda.get_device_properties(0).total_memory / (1024**3), 2)
        except Exception as e:
            result["model_device_check_error"] = str(e)
    return result

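# Example check from a shell (path and port as assumed above):
#   curl http://localhost:8001/health
#   → {"status": "healthy", "model_loaded": true, "gpu_available": true, ...}
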

@app.post("/generate")  # path assumed by convention; adjust if clients expect another route
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text from the model."""
    if _model_client is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        messages = [Message(role=msg["role"], content=msg["content"]) for msg in request.messages]
        text = _model_client.generate(
            messages,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return GenerateResponse(text=text, model_loaded=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e


@app.get("/tools")  # path assumed from the handler name
async def list_tools() -> Dict[str, List[str]]:
    """List available tools."""
    if _model_client is None:
        return {"tools": []}
    return {"tools": list(_model_client.available_tools)}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
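
# ---------------------------------------------------------------------------
# Client sketch: other processes share the single loaded model by calling the
# HTTP API instead of loading the weights themselves. The endpoint path and
# port follow the assumptions above; `requests` is an extra dependency that
# this module does not import.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8001/generate",
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   print(resp.json()["text"])
# ---------------------------------------------------------------------------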