Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import random | |
| import chromadb | |
| import math # β Add the math library for ceiling division | |
| from fastapi import FastAPI, HTTPException, Depends, Query | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| # --- Local Imports --- | |
| from encoder import SentenceEncoder | |
| from populate_chroma import populate_vector_db | |
| from llm_handler import ( | |
| initialize_llm, get_rag_response, create_chat_session, | |
| clear_chat_session, delete_chat_session, get_chat_history, | |
| get_chat_session_count, clear_all_chat_sessions | |
| ) | |
| import llm_handler | |
# --------------------------------------------------------------------
# Cache & Root Path Setup
# --------------------------------------------------------------------
# Point both the Hugging Face hub cache and the sentence-transformers cache
# at /data/cache (set in this order, before any model is loaded).
os.environ.update(dict.fromkeys(("HF_HOME", "SENTENCE_TRANSFORMERS_HOME"), "/data/cache"))
# Path prefix handed to FastAPI(root_path=...); empty string when the
# HF_SPACE_ROOT_PATH variable is not set.
root_path = os.getenv("HF_SPACE_ROOT_PATH", "")
# --------------------------------------------------------------------
# Pydantic Models
# --------------------------------------------------------------------
class UserProfile(BaseModel):
    """Request body for profile-based internship recommendations."""

    skills: List[str] = Field(..., example=["python", "data analysis"])
    sectors: List[str] = Field(..., example=["machine learning", "web development"])
    # NOTE(review): the example is a city name although the field is called
    # internshipType -- confirm which value clients are expected to send.
    internshipType: str = Field(..., example="Bengaluru")


class SearchQuery(BaseModel):
    """Request body for free-text internship search."""

    query: str = Field(..., example="marketing internship in mumbai")


class InternshipData(BaseModel):
    """An internship record; `id` is used as the Firestore document ID and ChromaDB record ID."""

    id: str = Field(..., example="int_021")
    title: str
    description: str
    skills: List[str]
    duration: int
    createdAt: str
    # The default is None, so the annotation must admit None; a bare `int`
    # makes Pydantic v2 reject an explicit `"stipend": null` in the request.
    stipend: Optional[int] = None


class SimpleRecommendation(BaseModel):
    """One search/recommendation hit: the internship ID and its similarity score."""

    internship_id: str
    score: float


class RecommendationResponse(BaseModel):
    """Response wrapper holding a list of recommendation hits."""

    recommendations: List[SimpleRecommendation]


class StatusResponse(BaseModel):
    """Status plus the ID of the internship that was acted on."""

    status: str
    internship_id: str
# --- Chat models ---

# Incoming chat request: the user's text plus an optional session ID.
class ChatMessage(BaseModel):
    query: str
    session_id: Optional[str] = Field(
        default=None,
        description="Chat session ID (optional - will be auto-created if not provided)",
    )


# Chat reply, echoing the session ID so the client can continue the conversation.
class ChatResponse(BaseModel):
    response: str
    session_id: str
    is_new_session: bool = Field(
        False,
        description="True if this was a new session created automatically",
    )


# Returned when a chat session is explicitly created.
class NewChatSessionResponse(BaseModel):
    session_id: str
    message: str


# Stored message history for one session (entries are plain dicts).
class ChatHistoryResponse(BaseModel):
    session_id: str
    history: List[dict]


# Confirmation that a single session's history was cleared.
class ClearChatResponse(BaseModel):
    session_id: str
    message: str


# Result of the admin-only "clear every session" operation.
class MasterClearResponse(BaseModel):
    message: str
    sessions_cleared: int
    timestamp: str
# --------------------------------------------------------------------
# FastAPI App
# --------------------------------------------------------------------
# root_path is read from HF_SPACE_ROOT_PATH at import time (empty when unset).
app = FastAPI(
    root_path=root_path,
    version="3.2.0",
    title="Internship Recommendation & Chatbot API",
    description=(
        "An API using Firestore for metadata, ChromaDB for vector search, "
        "and an LLM chatbot with memory."
    ),
)
# --------------------------------------------------------------------
# Firebase Initialization
# --------------------------------------------------------------------
def _init_firestore():
    """Build a Firestore client from the FIREBASE_CREDS_JSON secret.

    The default Firebase app is initialized only if none exists yet.

    Raises:
        Exception: if FIREBASE_CREDS_JSON is unset or empty, or if the JSON or
            credentials are invalid.
    """
    raw_creds = os.getenv("FIREBASE_CREDS_JSON")
    if not raw_creds:
        raise Exception("FIREBASE_CREDS_JSON not found")
    certificate = credentials.Certificate(json.loads(raw_creds))
    if not firebase_admin._apps:
        firebase_admin.initialize_app(certificate)
    client = firestore.client()
    print("β Firebase initialized with Hugging Face secret.")
    return client


# Module-level Firestore client; stays None if initialization fails, which
# get_db() turns into a 503 for endpoints that need it.
db = None
try:
    db = _init_firestore()
except Exception as e:
    print(f"β Could not initialize Firebase: {e}")
def get_db():
    """Return the module-level Firestore client for use with Depends().

    Raises:
        HTTPException: 503 when Firebase initialization did not produce a client.
    """
    if db is not None:
        return db
    raise HTTPException(status_code=503, detail="Firestore connection not available.")
# --------------------------------------------------------------------
# Global Variables (encoder + chroma)
# --------------------------------------------------------------------
# Both are populated by load_model_and_data(); endpoints answer 503 while
# either is still None.
encoder = None
chroma_collection = None


def load_model_and_data():
    """Load the sentence encoder, open the ChromaDB collection, and start the LLM.

    Opens a persistent ChromaDB client at /data/chroma_db and gets or creates
    the "internships" collection. Then it hands the encoder and collection to
    llm_handler and calls initialize_llm().

    NOTE(review): no caller is visible here -- presumably registered as a
    startup hook; verify.

    Raises:
        Exception: any ChromaDB/LLM initialization error is logged and re-raised.
    """
    global encoder, chroma_collection
    print("π Loading sentence encoder model...")
    encoder = SentenceEncoder()
    try:
        persistent_client = chromadb.PersistentClient(path="/data/chroma_db")
        chroma_collection = persistent_client.get_or_create_collection(name="internships")
        print("β ChromaDB client initialized and collection is ready.")
        print(f" - Internships in DB: {chroma_collection.count()}")
        # llm_handler reads these module attributes, so share them before
        # initializing the LLM.
        llm_handler.encoder, llm_handler.chroma_collection = encoder, chroma_collection
        initialize_llm()
    except Exception as exc:
        print(f"β Error initializing ChromaDB or LLM: {exc}")
        raise
# --------------------------------------------------------------------
# Existing Endpoints
# --------------------------------------------------------------------
def read_root():
    """Return the static welcome message served at the API root."""
    welcome = "Welcome to the Internship Recommendation API with Chat Memory!"
    return dict(message=welcome)
def run_initial_setup(secret_key: str = Query(..., example="your_secret_password")):
    """Run populate_vector_db(), guarded by the SETUP_SECRET_KEY secret.

    Args:
        secret_key: Must equal the SETUP_SECRET_KEY environment variable.

    Returns:
        ``{"status": "Setup completed successfully."}`` on success.

    Raises:
        HTTPException: 403 if SETUP_SECRET_KEY is unset or does not match;
            500 if populate_vector_db() raises.
    """
    import secrets  # function-scope import, like the datetime import elsewhere in this file

    correct_key = os.getenv("SETUP_SECRET_KEY")
    # compare_digest takes the same time wherever the inputs differ, so the key
    # cannot be guessed byte-by-byte from response timing. Compare bytes:
    # compare_digest raises TypeError for non-ASCII str, which would surface as a 500.
    if not correct_key or not secrets.compare_digest(
        secret_key.encode("utf-8"), correct_key.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid secret key.")
    try:
        print("--- RUNNING DATABASE POPULATION SCRIPT ---")
        populate_vector_db()
        print("--- SETUP COMPLETE ---")
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An error occurred during setup: {str(e)}"
        ) from e
    return {"status": "Setup completed successfully."}
def add_internship(internship: InternshipData, db_client: firestore.Client = Depends(get_db)):
    """Store a new internship in Firestore and index its embedding in ChromaDB.

    Args:
        internship: The record to add. ``internship.id`` is used both as the
            Firestore document ID and as the ChromaDB record ID.
        db_client: Firestore client injected through ``get_db``.

    Returns:
        ``{"status": "success", "internship_id": <id>}``.

    Raises:
        HTTPException: 503 if the encoder or ChromaDB collection is not loaded;
            400 if a Firestore document with this ID already exists.
        Exception: a ChromaDB failure is re-raised after the Firestore write is
            rolled back.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    doc_ref = db_client.collection('internships').document(internship.id)
    if doc_ref.get().exists:
        raise HTTPException(status_code=400, detail="Internship ID already exists.")

    record = internship.dict()
    # Do the fallible work (encoding) before writing anything. An encoder error
    # then cannot leave a Firestore document that has no vector, which would
    # block every retry with "Internship ID already exists".
    text_to_encode = f"{internship.title}. {internship.description}. Skills: {', '.join(internship.skills)}"
    embedding = encoder.encode([text_to_encode])[0].tolist()
    # ChromaDB metadata must be scalar: store the skills list as a JSON string.
    # Drop None values (e.g. an omitted stipend), which ChromaDB's metadata
    # validation rejects in many versions.
    metadata_for_chroma = {
        key: (json.dumps(value) if key == "skills" else value)
        for key, value in record.items()
        if value is not None
    }

    doc_ref.set(record)
    try:
        chroma_collection.add(ids=[internship.id], embeddings=[embedding], metadatas=[metadata_for_chroma])
    except Exception:
        # Keep the two stores consistent: undo the Firestore write so the same
        # ID can be submitted again.
        doc_ref.delete()
        raise
    print(f"β Added internship to Firestore and ChromaDB: {internship.id}")
    return {"status": "success", "internship_id": internship.id}
def get_profile_recommendations(profile: UserProfile):
    """Recommend the internships whose embeddings are nearest to a user profile.

    The profile's skills, sectors and internshipType are joined into one text,
    encoded, and used as a ChromaDB nearest-neighbour query for 5-7 results.
    The count is picked at random and capped at the collection size.

    Returns:
        ``{"recommendations": [{"internship_id": str, "score": float}, ...]}``,
        where ``score`` is ``1 - distance`` (closer matches score higher).

    Raises:
        HTTPException: 503 if the encoder or ChromaDB collection is not loaded.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    # Never request more neighbours than exist; some ChromaDB versions error
    # out (rather than truncate) when n_results exceeds the collection size.
    n_results = min(random.randint(5, 7), chroma_collection.count())
    if n_results == 0:
        return {"recommendations": []}
    # UserProfile defines `internshipType` (there is no `location` field),
    # matching the label used in the query text.
    query_text = (
        f"Skills: {', '.join(profile.skills)}. "
        f"Sectors: {', '.join(profile.sectors)}. "
        f"internshipType: {profile.internshipType}"
    )
    query_embedding = encoder.encode([query_text])[0].tolist()
    results = chroma_collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
    )
    # Chroma returns one list per query embedding; we sent exactly one.
    ids = results.get('ids', [[]])[0]
    distances = results.get('distances', [[]])[0]
    recommendations = [
        {"internship_id": internship_id, "score": 1 - distance}
        for internship_id, distance in zip(ids, distances)
    ]
    return {"recommendations": recommendations}
def search_internships(search: SearchQuery):
    """Return the internships whose embeddings are nearest to a free-text query.

    Requests 3-5 results, picked at random and capped at the collection size.

    Returns:
        ``{"recommendations": [{"internship_id": str, "score": float}, ...]}``,
        where ``score`` is ``1 - distance`` (closer matches score higher).

    Raises:
        HTTPException: 503 if the encoder or ChromaDB collection is not loaded.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    # Never request more neighbours than exist; some ChromaDB versions error
    # out (rather than truncate) when n_results exceeds the collection size.
    n_results = min(random.randint(3, 5), chroma_collection.count())
    if n_results == 0:
        return {"recommendations": []}
    query_embedding = encoder.encode([search.query])[0].tolist()
    results = chroma_collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
    )
    # Chroma returns one list per query embedding; we sent exactly one.
    ids = results.get('ids', [[]])[0]
    distances = results.get('distances', [[]])[0]
    recommendations = [
        {"internship_id": internship_id, "score": 1 - distance}
        for internship_id, distance in zip(ids, distances)
    ]
    return {"recommendations": recommendations}
# --------------------------------------------------------------------
# Chat endpoints with memory
# --------------------------------------------------------------------
def create_new_chat_session():
    """Create a chat session via llm_handler and return its ID with a confirmation."""
    return {
        "session_id": create_chat_session(),
        "message": "New chat session created successfully",
    }
def chat_with_bot(message: ChatMessage):
    """Answer one chat message through get_rag_response().

    When ``message.session_id`` is missing or empty, the response is flagged
    ``is_new_session=True``. The session itself is created or continued inside
    llm_handler.get_rag_response (not visible here). That function returns the
    session ID to use.

    Raises:
        HTTPException: 500 wrapping any error raised while producing the reply.
    """
    print("π¨ Received chat request:")
    print(f" Query: {message.query}")
    print(f" Session ID: {message.session_id}")
    try:
        # session_id is Optional[str]; both None and "" count as "no session given".
        is_new_session = not message.session_id
        reply, session_id = get_rag_response(message.query, message.session_id)
        print("π€ Sending response:")
        print(f" Session ID: {session_id}")
        print(f" Is New Session: {is_new_session}")
        print(f" Response: {reply[:100]}...")
        payload = {
            "response": reply,
            "session_id": session_id,
            "is_new_session": is_new_session,
        }
        return payload
    except Exception as err:
        print(f"β Error in chat endpoint: {str(err)}")
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(err)}")
def get_session_history(session_id: str):
    """Return the stored message history for one chat session.

    Raises:
        HTTPException: 404 when get_chat_history() returns None for this ID.
    """
    history = get_chat_history(session_id)
    if history is not None:
        return {"session_id": session_id, "history": history}
    raise HTTPException(status_code=404, detail="Chat session not found")
def clear_session_history(session_id: str):
    """Clear one session's history while keeping the session ID.

    Raises:
        HTTPException: 404 when clear_chat_session() reports failure (falsy result).
    """
    if clear_chat_session(session_id):
        return {"session_id": session_id, "message": "Chat history cleared successfully"}
    raise HTTPException(status_code=404, detail="Chat session not found")
def delete_session(session_id: str):
    """Delete a chat session entirely.

    Intended to be called when the user closes the chatbot, so that unused
    sessions do not stay in server memory.

    Raises:
        HTTPException: 404 when delete_chat_session() reports failure (falsy result).
    """
    if not delete_chat_session(session_id):
        raise HTTPException(status_code=404, detail="Chat session not found")
    print(f"ποΈ Session deleted by user: {session_id}")
    return dict(session_id=session_id, message="Chat session deleted successfully")
def clear_all_sessions(secret_key: str = Query(..., example="your_admin_secret")):
    """Admin-only: clear every chat session at once.

    WARNING: this ends all ongoing conversations. Useful for maintenance and
    for reclaiming memory held by sessions.

    Args:
        secret_key: Must equal the ADMIN_SECRET_KEY environment variable.

    Returns:
        A message, the count returned by clear_all_chat_sessions(), and a local
        (naive) ISO-8601 timestamp taken after clearing.

    Raises:
        HTTPException: 403 if ADMIN_SECRET_KEY is unset or does not match.
    """
    import secrets
    from datetime import datetime

    admin_secret = os.getenv("ADMIN_SECRET_KEY")
    # Constant-time comparison so the admin key cannot be recovered from
    # response timing. Bytes are used because compare_digest raises TypeError
    # on non-ASCII str.
    if not admin_secret or not secrets.compare_digest(
        secret_key.encode("utf-8"), admin_secret.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin secret key.")
    sessions_cleared = clear_all_chat_sessions()
    return {
        "message": "Successfully cleared all chat sessions. Memory freed.",
        "sessions_cleared": sessions_cleared,
        "timestamp": datetime.now().isoformat(),
    }
def get_active_sessions():
    """Report the number of live chat sessions with a coarse memory label.

    More than 15 sessions is labelled "high_usage"; otherwise "healthy".
    """
    active = get_chat_session_count()
    memory_status = "high_usage" if active > 15 else "healthy"
    return {
        "active_sessions": active,
        "message": f"There are {active} active chat sessions",
        "memory_status": memory_status,
    }
# Health check endpoint
def health_check():
    """Report readiness of each backing component and the live chat-session count.

    NOTE(review): "status" is always "healthy", even when a component below is
    not ready -- callers should inspect the individual *_ready flags.
    """
    return dict(
        status="healthy",
        encoder_ready=encoder is not None,
        chroma_ready=chroma_collection is not None,
        firebase_ready=db is not None,
        active_chat_sessions=get_chat_session_count(),
    )