#!/usr/bin/env python3
"""
AI Chat Application - Pure FastAPI Backend
Serves custom frontend with OpenAI compatible API
"""

import os
import sys
import json
import logging
import time
from typing import Optional, Dict, Any, Generator, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import threading
from threading import Thread
from pydantic import BaseModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Pydantic models for API requests/responses
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = "qwen-coder-3-30b"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]


# Global model variables
tokenizer = None
model = None


def load_model():
    """Load the Qwen model and tokenizer"""
    global tokenizer, model
    try:
        model_name = "Qwen/Qwen3-Coder-30B-A3B-Instruct"  # Adjust model name as needed
        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        # For development/testing, use a fallback
        logger.warning("Using fallback model response")


def generate_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048):
    """Generate response from the model"""
    try:
        if model is None or tokenizer is None:
            # Fallback response for development
            return "I'm a Qwen AI assistant. The model is currently loading, please try again in a moment."

        # Format messages for the model
        formatted_messages = []
        for msg in messages:
            formatted_messages.append({"role": msg.role, "content": msg.content})

        # Apply chat template
        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode response
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error while processing your request: {str(e)}"


def generate_streaming_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048):
    """Generate streaming response from the model"""
    try:
        if model is None or tokenizer is None:
            # Fallback streaming response
            response = "I'm a Qwen AI assistant. The model is currently loading, please try again in a moment."
            for char in response:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': char}}]})}\n\n"
                time.sleep(0.05)
            yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Format messages
        formatted_messages = []
        for msg in messages:
            formatted_messages.append({"role": msg.role, "content": msg.content})

        # Apply chat template
        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Setup streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        # Start generation in a thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the response
        for new_text in streamer:
            if new_text:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"

        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        error_msg = f"Error: {str(e)}"
        yield f"data: {json.dumps({'choices': [{'delta': {'content': error_msg}}]})}\n\n"
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"


# FastAPI app
app = FastAPI(title="AI Chat API", description="OpenAI compatible interface for Qwen model")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API endpoints
@app.get("/")
async def serve_index():
    """Serve the main HTML file"""
    return FileResponse("public/index.html")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/ping")
async def ping():
    """Simple ping endpoint"""
    return {"status": "pong"}


@app.get("/api/models")
async def list_models():
    """List available models"""
    return {
        "data": [
            {
                "id": "qwen-coder-3-30b",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "qwen"
            }
        ]
    }


@app.post("/api/chat")
async def chat_completion(request: ChatRequest):
    """OpenAI compatible chat completion endpoint"""
    try:
        if request.stream:
            return StreamingResponse(
                generate_streaming_response(
                    request.messages,
                    request.temperature or 0.7,
                    request.max_tokens or 2048
                ),
                media_type="text/event-stream"  # chunks are formatted as Server-Sent Events
            )
        else:
            response_content = generate_response(
                request.messages,
                request.temperature or 0.7,
                request.max_tokens or 2048
            )
            return ChatResponse(
                id=f"chatcmpl-{int(time.time())}",
                created=int(time.time()),
                model=request.model or "qwen-coder-3-30b",
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content
                    },
                    "finish_reason": "stop"
                }]
            )
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/chat/completions")
async def openai_chat_completion(request: ChatRequest):
    """OpenAI API compatible endpoint"""
    return await chat_completion(request)


# Mount static files AFTER API routes so the routes above take precedence
app.mount("/", StaticFiles(directory="public", html=True), name="static")


# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    # Load model in a background thread to avoid blocking startup
    thread = Thread(target=load_model)
    thread.daemon = True
    thread.start()


if __name__ == "__main__":
    import uvicorn

    # For Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        access_log=True
    )
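
# Example client call (a minimal sketch, not part of the server itself). It assumes the
# app is running locally on the default port 7860 and that the `requests` package is
# available in the client environment; the payload mirrors the ChatRequest model above.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/chat",
#       json={
#           "model": "qwen-coder-3-30b",
#           "messages": [{"role": "user", "content": "Write a FizzBuzz in Python"}],
#           "temperature": 0.7,
#           "stream": False,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])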