# agent/model_server.py
"""Simple HTTP server for the NexaSci model to enable sharing across processes."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from .client_llm import Message, NexaSciModelClient
# Add project root to path if running as module
if __name__ == "__main__" or "agent.model_server" in sys.modules:
    project_root = Path(__file__).resolve().parents[1]
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
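# Typical way to start the server in its own process (illustrative; the repo's actual
# launch command may differ):
#   python -m agent.model_server
# Other processes then share the single loaded model over HTTP on port 8001.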
app = FastAPI(title="NexaSci Model Server", version="0.1.0")

# Global model client (loaded once)
_model_client: NexaSciModelClient | None = None


class GenerateRequest(BaseModel):
    messages: List[Dict[str, str]]
    max_new_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None


class GenerateResponse(BaseModel):
    text: str
    model_loaded: bool
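# Illustrative /generate request body matching GenerateRequest (values are examples only):
#   {
#       "messages": [{"role": "user", "content": "Summarize this abstract."}],
#       "max_new_tokens": 256,
#       "temperature": 0.7,
#       "top_p": 0.9
#   }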
@app.on_event("startup")
async def load_model() -> None:
    """Load the model when the server starts."""
    global _model_client
    import time

    print("=" * 80)
    print("Loading NexaSci model (this may take 30-60 seconds)...")
    print("=" * 80)
    print("Step 1: Loading tokenizer...")
    start_time = time.time()
    try:
        # Set tokenizers parallelism to avoid warnings
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        _model_client = NexaSciModelClient()
        elapsed = time.time() - start_time
        print(f"✓ Model loaded successfully in {elapsed:.1f}s")
        if torch.cuda.is_available():
            print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            print(f"✓ GPU Memory: {allocated:.1f} GB / {total_mem:.1f} GB allocated")
        print("=" * 80)
        print("Model server ready! Listening on http://0.0.0.0:8001")
        print("=" * 80)
    except Exception as e:
        elapsed = time.time() - start_time
        print(f"✗ Failed to load model after {elapsed:.1f}s: {e}")
        import traceback

        traceback.print_exc()
        raise
@app.get("/health")
async def health_check() -> Dict[str, Any]:
    """Health check endpoint."""
    gpu_available = torch.cuda.is_available()
    result = {
        "status": "healthy",
        "model_loaded": _model_client is not None,
        "gpu_available": gpu_available,
    }
    if gpu_available and _model_client is not None:
        # Check if model is actually on GPU
        try:
            model_device = next(_model_client.model.parameters()).device
            result["model_device"] = str(model_device)
            result["gpu_name"] = torch.cuda.get_device_name(0)
            result["gpu_memory_allocated_gb"] = round(torch.cuda.memory_allocated(0) / (1024**3), 2)
            result["gpu_memory_total_gb"] = round(torch.cuda.get_device_properties(0).total_memory / (1024**3), 2)
        except Exception as e:
            result["model_device_check_error"] = str(e)
    return result
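# Illustrative health probe (assumes the server is reachable on localhost:8001):
#   curl http://localhost:8001/health
# returns JSON with "status", "model_loaded", "gpu_available", and, when the model
# sits on a GPU, the device name and memory figures populated above.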
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text from the model."""
    if _model_client is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        messages = [Message(role=msg["role"], content=msg["content"]) for msg in request.messages]
        text = _model_client.generate(
            messages,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return GenerateResponse(text=text, model_loaded=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e
@app.get("/tools")
async def list_tools() -> Dict[str, List[str]]:
    """List available tools."""
    if _model_client is None:
        return {"tools": []}
    return {"tools": list(_model_client.available_tools)}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
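# Example of calling the server from another process (a sketch using the `requests`
# library, which is an assumption and not a dependency declared in this file):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8001/generate",
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   print(resp.json()["text"])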