SIH-ML-Backend / main.py
Pulastya0's picture
Update main.py
ae4f620
raw
history blame
13.4 kB
import os
import json
import random
import chromadb
import math # βœ… Add the math library for ceiling division
from fastapi import FastAPI, HTTPException, Depends, Query
from pydantic import BaseModel, Field
from typing import List, Optional
import firebase_admin
from firebase_admin import credentials, firestore
# --- Local Imports ---
from encoder import SentenceEncoder
from populate_chroma import populate_vector_db
from llm_handler import (
initialize_llm, get_rag_response, create_chat_session,
clear_chat_session, delete_chat_session, get_chat_history,
get_chat_session_count, clear_all_chat_sessions
)
import llm_handler
# --------------------------------------------------------------------
# Cache & Root Path Setup
# --------------------------------------------------------------------
# Send both the Hugging Face and sentence-transformers caches to /data/cache.
# NOTE(review): these variables are set after the imports above have already
# run; if any imported module reads its cache location at import time the
# setting may arrive too late — confirm the ordering is intentional.
os.environ.update(
    HF_HOME="/data/cache",
    SENTENCE_TRANSFORMERS_HOME="/data/cache",
)
# Root path handed to the FastAPI constructor; empty when the variable is unset.
root_path = os.environ.get("HF_SPACE_ROOT_PATH", "")
# --------------------------------------------------------------------
# Pydantic Models
# --------------------------------------------------------------------
class UserProfile(BaseModel):
    # Request body for the profile-based recommendation endpoint.
    # All three fields are required.

    # Skills the user has, as free-text entries.
    skills: List[str] = Field(
        ...,
        example=["python", "data analysis"],
    )
    # Sectors the user is interested in, as free-text entries.
    sectors: List[str] = Field(
        ...,
        example=["machine learning", "web development"],
    )
    # NOTE(review): the example value is a city name although the field is
    # called internshipType — confirm what callers are expected to send.
    internshipType: str = Field(
        ...,
        example="Bengaluru",
    )
class SearchQuery(BaseModel):
    # Request body for the free-text search endpoint.

    # Required natural-language search text.
    query: str = Field(
        ...,
        example="marketing internship in mumbai",
    )
class InternshipData(BaseModel):
    # Payload describing one internship to be stored.

    # Caller-chosen unique identifier; used as the document ID on insert.
    id: str = Field(..., example="int_021")
    title: str
    description: str
    skills: List[str]
    # NOTE(review): unit of duration is not stated anywhere in this file.
    duration: int
    # Creation time as text; no format is enforced by this model.
    createdAt: str
    # Optional: the annotation was `int` with a None default, which declares
    # a type the default itself violates. Optional[int] makes None a valid
    # value both as the default and when sent explicitly.
    stipend: Optional[int] = None
class SimpleRecommendation(BaseModel):
    # One recommended internship together with its match score.

    # ID of the recommended internship.
    internship_id: str
    # Match score; the endpoints in this file compute it as 1 - distance.
    score: float
class RecommendationResponse(BaseModel):
    # Response envelope shared by the recommendation and search endpoints.

    # Matches in the order the vector query returned them.
    recommendations: List[SimpleRecommendation]
class StatusResponse(BaseModel):
    # Response returned after an internship has been added.

    # Outcome label, e.g. "success".
    status: str
    # ID of the internship the status refers to.
    internship_id: str
# --- ✅ UPDATED CHAT MODELS ---
class ChatMessage(BaseModel):
    # Request body for the chat endpoint.

    # The user's message to the bot.
    query: str
    # Omit (or send null) to let the server pick a session.
    session_id: Optional[str] = Field(
        default=None,
        description="Chat session ID (optional - will be auto-created if not provided)",
    )
class ChatResponse(BaseModel):
    # Response body for the chat endpoint.

    # The bot's reply text.
    response: str
    # Session the reply belongs to; clients send it back to continue the chat.
    session_id: str
    is_new_session: bool = Field(
        default=False,
        description="True if this was a new session created automatically",
    )
class NewChatSessionResponse(BaseModel):
    # Response returned when a chat session is created explicitly.

    # ID of the freshly created session.
    session_id: str
    # Human-readable confirmation text.
    message: str
class ChatHistoryResponse(BaseModel):
    # Response carrying the stored history of one chat session.

    session_id: str
    # History entries as plain dicts; their keys are defined by llm_handler,
    # not by this file.
    history: List[dict]
class ClearChatResponse(BaseModel):
    # Response used by both the clear-history and delete-session endpoints.

    # Session the operation was applied to.
    session_id: str
    # Human-readable confirmation text.
    message: str
class MasterClearResponse(BaseModel):
    # Response returned by the admin endpoint that clears every chat session.

    # Human-readable confirmation text.
    message: str
    # Value reported by the session store for how many sessions were cleared.
    sessions_cleared: int
    # ISO-8601 text produced when the clear ran.
    timestamp: str
# --------------------------------------------------------------------
# FastAPI App
# --------------------------------------------------------------------
# The single application object; every route below is registered on it.
app = FastAPI(
    # Prefix read from HF_SPACE_ROOT_PATH (empty string when unset).
    root_path=root_path,
    version="3.2.0",
    title="Internship Recommendation & Chatbot API",
    description="An API using Firestore for metadata, ChromaDB for vector search, and an LLM chatbot with memory.",
)
# --------------------------------------------------------------------
# Firebase Initialization
# --------------------------------------------------------------------
# Firestore client shared by the whole module; stays None when setup fails so
# that the app can still start and report the problem per request.
db = None
try:
    firebase_creds = os.getenv("FIREBASE_CREDS_JSON")
    if not firebase_creds:
        # Raised inside the try on purpose: it is reported by the handler
        # below exactly like any other initialization failure.
        raise Exception("FIREBASE_CREDS_JSON not found")
    # The secret holds the service-account credentials as a JSON string.
    creds_dict = json.loads(firebase_creds)
    cred = credentials.Certificate(creds_dict)
    # Initialize the default app only if no app has been registered yet.
    if not firebase_admin._apps:
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("βœ… Firebase initialized with Hugging Face secret.")
except Exception as e:
    print(f"❌ Could not initialize Firebase: {e}")
def get_db():
    """Return the module-level Firestore client for use as a dependency.

    Raises:
        HTTPException: 503 when the client was never initialized.
    """
    if db is not None:
        return db
    raise HTTPException(status_code=503, detail="Firestore connection not available.")
# --------------------------------------------------------------------
# Global Variables (encoder + chroma)
# --------------------------------------------------------------------
# Populated by the startup hook below; endpoints treat None as "not ready".
encoder = chroma_collection = None


@app.on_event("startup")
def load_model_and_data():
    """Load the sentence encoder, open the Chroma collection and start the LLM.

    Stores the encoder and collection in this module's globals and also
    assigns them onto ``llm_handler`` before calling ``initialize_llm()``.
    Any failure while setting up ChromaDB or the LLM is printed and re-raised.
    """
    global encoder, chroma_collection
    print("πŸš€ Loading sentence encoder model...")
    encoder = SentenceEncoder()
    try:
        chroma_client = chromadb.PersistentClient(path="/data/chroma_db")
        chroma_collection = chroma_client.get_or_create_collection(name="internships")
        print("βœ… ChromaDB client initialized and collection is ready.")
        print(f" - Internships in DB: {chroma_collection.count()}")
        # llm_handler receives the same encoder/collection objects; they are
        # assigned before initialize_llm() is called.
        llm_handler.encoder = encoder
        llm_handler.chroma_collection = chroma_collection
        initialize_llm()
    except Exception as e:
        print(f"❌ Error initializing ChromaDB or LLM: {e}")
        raise
# --------------------------------------------------------------------
# Existing Endpoints
# --------------------------------------------------------------------
@app.get("/")
def read_root():
    """Return a static welcome message."""
    welcome = {"message": "Welcome to the Internship Recommendation API with Chat Memory!"}
    return welcome
@app.post("/setup")
def run_initial_setup(secret_key: str = Query(..., example="your_secret_password")):
    """Run the vector-database population script, guarded by a shared secret.

    Args:
        secret_key: Must equal the SETUP_SECRET_KEY environment variable.

    Returns:
        A dict with a ``status`` message when population finishes.

    Raises:
        HTTPException: 403 when the key is wrong or SETUP_SECRET_KEY is unset;
            500 when the population script raises.
    """
    import hmac  # local import: only this check needs it

    correct_key = os.getenv("SETUP_SECRET_KEY")
    # compare_digest avoids leaking how many leading characters matched via
    # response timing; bytes are used because the str form only accepts ASCII.
    if not correct_key or not hmac.compare_digest(
        secret_key.encode("utf-8"), correct_key.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid secret key.")
    try:
        print("--- RUNNING DATABASE POPULATION SCRIPT ---")
        populate_vector_db()
        print("--- SETUP COMPLETE ---")
        return {"status": "Setup completed successfully."}
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=f"An error occurred during setup: {str(e)}") from e
@app.post("/add-internship", response_model=StatusResponse)
def add_internship(internship: InternshipData, db_client: firestore.Client = Depends(get_db)):
    """Store a new internship in Firestore and index it in ChromaDB.

    Args:
        internship: The internship to add; its ``id`` must not exist yet.
        db_client: Firestore client supplied by the ``get_db`` dependency.

    Returns:
        A dict with ``status`` "success" and the stored ``internship_id``.

    Raises:
        HTTPException: 503 when the encoder or collection is not loaded;
            400 when a document with the same ID already exists.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    doc_ref = db_client.collection('internships').document(internship.id)
    if doc_ref.get().exists:
        raise HTTPException(status_code=400, detail="Internship ID already exists.")

    record = internship.dict()

    # Compute everything that can fail locally BEFORE writing anywhere, so an
    # encoder error cannot leave a Firestore document without a vector entry.
    text_to_encode = f"{internship.title}. {internship.description}. Skills: {', '.join(internship.skills)}"
    embedding = encoder.encode([text_to_encode])[0].tolist()

    # Unset optional fields (e.g. stipend) are None. Chroma metadata values
    # are scalar (str/int/float/bool), so None entries are dropped here.
    # NOTE(review): confirm against the pinned chromadb version.
    metadata_for_chroma = {key: value for key, value in record.items() if value is not None}
    # Lists are not valid metadata values either; store skills as JSON text.
    metadata_for_chroma['skills'] = json.dumps(record['skills'])

    doc_ref.set(record)
    chroma_collection.add(ids=[internship.id], embeddings=[embedding], metadatas=[metadata_for_chroma])
    print(f"βœ… Added internship to Firestore and ChromaDB: {internship.id}")
    return {"status": "success", "internship_id": internship.id}
@app.post("/profile-recommendations", response_model=RecommendationResponse)
def get_profile_recommendations(profile: UserProfile):
    """Recommend internships that match a user's profile.

    The profile is flattened into one text string, embedded, and used to
    query the Chroma collection for the nearest internships.

    Args:
        profile: The user's skills, sectors and internship type.

    Returns:
        A dict with a ``recommendations`` list of
        ``{"internship_id", "score"}`` entries, where score is 1 - distance.

    Raises:
        HTTPException: 503 when the encoder or collection is not loaded.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    # UserProfile declares `internshipType`; the previous `profile.location`
    # does not exist on the model and raised AttributeError on every request.
    query_text = (
        f"Skills: {', '.join(profile.skills)}. "
        f"Sectors: {', '.join(profile.sectors)}. "
        f"internshipType: {profile.internshipType}"
    )
    query_embedding = encoder.encode([query_text])[0].tolist()
    results = chroma_collection.query(
        query_embeddings=[query_embedding],
        n_results=random.randint(5, 7)  # Get 5 to 7 results
    )
    # Results are nested one list per query; only one query was sent.
    ids = results.get('ids', [[]])[0]
    distances = results.get('distances', [[]])[0]
    recommendations = [
        {"internship_id": internship_id, "score": 1 - distance}
        for internship_id, distance in zip(ids, distances)
    ]
    return {"recommendations": recommendations}
@app.post("/search", response_model=RecommendationResponse)
def search_internships(search: SearchQuery):
    """Find internships whose embeddings are closest to a free-text query.

    Args:
        search: Wrapper holding the query text.

    Returns:
        A dict with a ``recommendations`` list of
        ``{"internship_id", "score"}`` entries, where score is 1 - distance.

    Raises:
        HTTPException: 503 when the encoder or collection is not loaded.
    """
    if chroma_collection is None or encoder is None:
        raise HTTPException(status_code=503, detail="Server is not ready.")
    vector = encoder.encode([search.query])[0].tolist()
    hits = chroma_collection.query(
        query_embeddings=[vector],
        n_results=random.randint(3, 5),  # Get 3 to 5 results
    )
    # Results are nested one list per query; only one query was sent.
    ids = hits.get('ids', [[]])[0]
    distances = hits.get('distances', [[]])[0]
    recommendations = [
        {"internship_id": internship_id, "score": 1 - distances[position]}
        for position, internship_id in enumerate(ids)
    ]
    return {"recommendations": recommendations}
# --------------------------------------------------------------------
# ✅ NEW CHAT ENDPOINTS WITH MEMORY
# --------------------------------------------------------------------
@app.post("/chat/new-session", response_model=NewChatSessionResponse)
def create_new_chat_session():
    """Open a fresh chat session and return its ID."""
    new_id = create_chat_session()
    return {"session_id": new_id, "message": "New chat session created successfully"}
@app.post("/chat", response_model=ChatResponse)
def chat_with_bot(message: ChatMessage):
    """
    Send one message to the bot and return its reply.

    Session handling:
    - No session_id supplied: a new session is created automatically
    - session_id supplied but unknown: a new session is created with that ID
    - session_id supplied and known: the existing conversation continues

    Raises HTTPException 500 when processing the message fails.
    """
    print("πŸ“¨ Received chat request:")
    print(f" Query: {message.query}")
    print(f" Session ID: {message.session_id}")
    try:
        # True only when the client sent no ID (None or empty string).
        # NOTE(review): an unknown-but-supplied ID still yields False here,
        # although the docstring says a session is created in that case too.
        is_new_session = not message.session_id
        response, session_id = get_rag_response(message.query, message.session_id)
        print("πŸ“€ Sending response:")
        print(f" Session ID: {session_id}")
        print(f" Is New Session: {is_new_session}")
        print(f" Response: {response[:100]}...")
        payload = {
            "response": response,
            "session_id": session_id,
            "is_new_session": is_new_session,
        }
    except Exception as e:
        print(f"❌ Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")
    return payload
@app.get("/chat/{session_id}/history", response_model=ChatHistoryResponse)
def get_session_history(session_id: str):
    """Return the stored history of one chat session.

    Raises HTTPException 404 when the history lookup returns None.
    """
    history = get_chat_history(session_id)
    if history is not None:
        return {"session_id": session_id, "history": history}
    raise HTTPException(status_code=404, detail="Chat session not found")
@app.delete("/chat/{session_id}/clear", response_model=ClearChatResponse)
def clear_session_history(session_id: str):
    """Clear the history of one chat session.

    Raises HTTPException 404 when the session store reports failure.
    """
    if clear_chat_session(session_id):
        return {"session_id": session_id, "message": "Chat history cleared successfully"}
    raise HTTPException(status_code=404, detail="Chat session not found")
@app.delete("/chat/{session_id}/delete", response_model=ClearChatResponse)
def delete_session(session_id: str):
    """
    Remove a chat session entirely.

    ⭐ RECOMMENDED: call this when the user closes the chatbot so the server
    can release the session instead of keeping it around unused.

    Raises HTTPException 404 when the session store reports failure.
    """
    deleted = delete_chat_session(session_id)
    if deleted:
        print(f"πŸ—‘οΈ Session deleted by user: {session_id}")
        return {"session_id": session_id, "message": "Chat session deleted successfully"}
    raise HTTPException(status_code=404, detail="Chat session not found")
@app.delete("/chat/sessions/clear-all", response_model=MasterClearResponse)
def clear_all_sessions(secret_key: str = Query(..., example="your_admin_secret")):
    """
    🚨 MASTER ENDPOINT: Clear all chat sessions at once.

    Requires the admin secret key. Intended for maintenance: every active
    chat session is cleared in one call.

    ⚠️ WARNING: This will terminate all ongoing conversations!

    Args:
        secret_key: Must equal the ADMIN_SECRET_KEY environment variable.

    Returns:
        A dict with a confirmation ``message``, the ``sessions_cleared``
        value reported by the session store, and an ISO-8601 ``timestamp``.

    Raises:
        HTTPException: 403 when the key is wrong or ADMIN_SECRET_KEY is unset.
    """
    import hmac  # local import: only this check needs it
    from datetime import datetime

    admin_secret = os.getenv("ADMIN_SECRET_KEY")
    # compare_digest avoids leaking how many leading characters matched via
    # response timing; bytes are used because the str form only accepts ASCII.
    if not admin_secret or not hmac.compare_digest(
        secret_key.encode("utf-8"), admin_secret.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin secret key.")

    sessions_cleared = clear_all_chat_sessions()
    # Naive local time, kept as-is so the response format does not change.
    timestamp = datetime.now().isoformat()
    return {
        "message": "Successfully cleared all chat sessions. Memory freed.",
        "sessions_cleared": sessions_cleared,
        "timestamp": timestamp,
    }
@app.get("/chat/sessions/count")
def get_active_sessions():
    """Report how many chat sessions are active, with a coarse usage label."""
    count = get_chat_session_count()
    # More than 15 sessions is labelled high usage; 15 or fewer is healthy.
    memory_status = "high_usage" if count > 15 else "healthy"
    return {
        "active_sessions": count,
        "message": f"There are {count} active chat sessions",
        "memory_status": memory_status,
    }
# Health check endpoint
@app.get("/healthz")
def health_check():
    """Report which backing components are initialized.

    ``status`` is always "healthy"; the readiness flags say whether the
    encoder, Chroma collection and Firestore client are set.
    """
    return {
        "status": "healthy",
        "encoder_ready": encoder is not None,
        "chroma_ready": chroma_collection is not None,
        "firebase_ready": db is not None,
        "active_chat_sessions": get_chat_session_count(),
    }