Pulastya0 committed on
Commit 0239397 · Parent(s): 74db693

Update main.py

Files changed (1)
  1. main.py +110 -21
main.py CHANGED
@@ -4,14 +4,18 @@ import chromadb
 import math  # ✅ Add the math library for ceiling division
 from fastapi import FastAPI, HTTPException, Depends, Query
 from pydantic import BaseModel, Field
-from typing import List
+from typing import List, Optional
 import firebase_admin
 from firebase_admin import credentials, firestore
 
 # --- Local Imports ---
 from encoder import SentenceEncoder
 from populate_chroma import populate_vector_db
-from llm_handler import initialize_llm, get_rag_response
+from llm_handler import (
+    initialize_llm, get_rag_response, create_chat_session,
+    clear_chat_session, delete_chat_session, get_chat_history,
+    get_chat_session_count
+)
 import llm_handler
 
 # --------------------------------------------------------------------
@@ -44,7 +48,6 @@ class SimpleRecommendation(BaseModel):
     internship_id: str
     score: float
 
-# --- ✅ UPDATED RESPONSE MODEL ---
 class RecommendationResponse(BaseModel):
     recommendations: List[SimpleRecommendation]
     page: int
@@ -54,19 +57,34 @@ class StatusResponse(BaseModel):
     status: str
     internship_id: str
 
+# --- ✅ UPDATED CHAT MODELS ---
 class ChatMessage(BaseModel):
     query: str
+    session_id: Optional[str] = Field(None, description="Chat session ID (optional)")
 
 class ChatResponse(BaseModel):
     response: str
+    session_id: str
+
+class NewChatSessionResponse(BaseModel):
+    session_id: str
+    message: str
+
+class ChatHistoryResponse(BaseModel):
+    session_id: str
+    history: List[dict]
+
+class ClearChatResponse(BaseModel):
+    session_id: str
+    message: str
 
 # --------------------------------------------------------------------
 # FastAPI App
 # --------------------------------------------------------------------
 app = FastAPI(
     title="Internship Recommendation & Chatbot API",
-    description="An API using Firestore for metadata, ChromaDB for vector search, and an LLM chatbot.",
-    version="3.1.0",
+    description="An API using Firestore for metadata, ChromaDB for vector search, and an LLM chatbot with memory.",
+    version="3.2.0",
     root_path=root_path
 )
 
@@ -118,13 +136,11 @@ def load_model_and_data():
         raise
 
 # --------------------------------------------------------------------
-# Endpoints
+# Existing Endpoints
 # --------------------------------------------------------------------
 @app.get("/")
 def read_root():
-    return {"message": "Welcome to the Internship Recommendation API!"}
-
-# ... (setup, add-internship, and healthz endpoints are unchanged)
+    return {"message": "Welcome to the Internship Recommendation API with Chat Memory!"}
 
 @app.post("/setup")
 def run_initial_setup(secret_key: str = Query(..., example="your_secret_password")):
@@ -155,8 +171,6 @@ def add_internship(internship: InternshipData, db_client: firestore.Client = Dep
     print(f"✅ Added internship to Firestore and ChromaDB: {internship.id}")
     return {"status": "success", "internship_id": internship.id}
 
-
-# --- ✅ ENDPOINT UPDATED FOR PAGINATION ---
 @app.post("/profile-recommendations", response_model=RecommendationResponse)
 def get_profile_recommendations(profile: UserProfile, page: int = 1, page_size: int = 4):
     if chroma_collection is None or encoder is None:
@@ -165,16 +179,12 @@ def get_profile_recommendations(profile: UserProfile, page: int = 1, page_size:
     query_text = f"Skills: {', '.join(profile.skills)}. Sectors: {', '.join(profile.sectors)}"
     query_embedding = encoder.encode([query_text])[0].tolist()
 
-    # Query for all results to sort them, then paginate
-    # This is less efficient at scale, but simple and effective for this project
-    # A more advanced approach would use ChromaDB's offset/limit if available or other methods.
     total_items = chroma_collection.count()
     results = chroma_collection.query(
         query_embeddings=[query_embedding],
-        n_results=total_items  # Get all results to sort
+        n_results=total_items
    )
 
-    # Process and sort all hits by score
     ids = results.get('ids', [[]])[0]
     distances = results.get('distances', [[]])[0]
     all_recommendations = [
@@ -182,7 +192,6 @@ def get_profile_recommendations(profile: UserProfile, page: int = 1, page_size:
         for i in range(len(ids))
     ]
 
-    # --- PAGINATION LOGIC ---
     start_index = (page - 1) * page_size
     end_index = start_index + page_size
     paginated_results = all_recommendations[start_index:end_index]
@@ -195,9 +204,8 @@ def get_profile_recommendations(profile: UserProfile, page: int = 1, page_size:
         "total_pages": total_pages
     }
 
-# --- ✅ ENDPOINT UPDATED FOR PAGINATION ---
 @app.post("/search", response_model=RecommendationResponse)
-def search_internships(search: SearchQuery, page: int = 1, page_size: int =4 ):
+def search_internships(search: SearchQuery, page: int = 1, page_size: int = 4):
     if chroma_collection is None or encoder is None:
         raise HTTPException(status_code=503, detail="Server is not ready.")
 
@@ -227,7 +235,88 @@ def search_internships(search: SearchQuery, page: int = 1, page_size: int =4 ):
         "total_pages": total_pages
     }
 
+# --------------------------------------------------------------------
+# ✅ NEW CHAT ENDPOINTS WITH MEMORY
+# --------------------------------------------------------------------
+
+@app.post("/chat/new-session", response_model=NewChatSessionResponse)
+def create_new_chat_session():
+    """Create a new chat session."""
+    session_id = create_chat_session()
+    return {
+        "session_id": session_id,
+        "message": "New chat session created successfully"
+    }
+
 @app.post("/chat", response_model=ChatResponse)
 def chat_with_bot(message: ChatMessage):
-    response = get_rag_response(message.query)
-    return {"response": response}
+    """Chat with the bot. Maintains memory within the session."""
+    try:
+        response, session_id = get_rag_response(message.query, message.session_id)
+        return {
+            "response": response,
+            "session_id": session_id
+        }
+    except Exception as e:
+        # Handle the case where get_rag_response returns only response (backward compatibility)
+        if isinstance(e, ValueError):
+            response = get_rag_response(message.query, message.session_id)
+            return {
+                "response": response,
+                "session_id": message.session_id or "unknown"
+            }
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/chat/{session_id}/history", response_model=ChatHistoryResponse)
+def get_session_history(session_id: str):
+    """Get the chat history for a specific session."""
+    history = get_chat_history(session_id)
+    if history is None:
+        raise HTTPException(status_code=404, detail="Chat session not found")
+    return {
+        "session_id": session_id,
+        "history": history
+    }
+
+@app.delete("/chat/{session_id}/clear", response_model=ClearChatResponse)
+def clear_session_history(session_id: str):
+    """Clear the chat history for a specific session."""
+    success = clear_chat_session(session_id)
+    if not success:
+        raise HTTPException(status_code=404, detail="Chat session not found")
+    return {
+        "session_id": session_id,
+        "message": "Chat history cleared successfully"
+    }
+
+@app.delete("/chat/{session_id}/delete", response_model=ClearChatResponse)
+def delete_session(session_id: str):
+    """Delete a chat session completely."""
+    success = delete_chat_session(session_id)
+    if not success:
+        raise HTTPException(status_code=404, detail="Chat session not found")
+    return {
+        "session_id": session_id,
+        "message": "Chat session deleted successfully"
+    }
+
+@app.get("/chat/sessions/count")
+def get_active_sessions():
+    """Get the number of active chat sessions."""
+    count = get_chat_session_count()
+    return {
+        "active_sessions": count,
+        "message": f"There are {count} active chat sessions"
+    }
+
+# Health check endpoint
+@app.get("/healthz")
+def health_check():
+    status = {
+        "status": "healthy",
+        "encoder_ready": encoder is not None,
+        "chroma_ready": chroma_collection is not None,
+        "firebase_ready": db is not None,
+        "active_chat_sessions": get_chat_session_count()
+    }
+    return status
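
Note: the sketch below is not part of this commit. It is a minimal client-side example of how the new chat-memory endpoints fit together, assuming the API is reachable at http://localhost:8000 (adjust for your deployment), that llm_handler implements the session helpers imported above, and that the example queries are placeholders.

# Hypothetical usage sketch for the new chat endpoints (not in the commit).
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

# 1. Create a session so follow-up questions share memory.
session = requests.post(f"{BASE_URL}/chat/new-session").json()
session_id = session["session_id"]

# 2. Chat twice with the same session_id; the second query can rely on
#    context from the first because the session keeps the history.
first = requests.post(
    f"{BASE_URL}/chat",
    json={"query": "Show me data science internships.", "session_id": session_id},
).json()
follow_up = requests.post(
    f"{BASE_URL}/chat",
    json={"query": "Which of those are remote?", "session_id": session_id},
).json()
print(first["response"])
print(follow_up["response"])

# 3. Inspect, clear, and finally delete the stored history.
history = requests.get(f"{BASE_URL}/chat/{session_id}/history").json()
print(history["history"])
requests.delete(f"{BASE_URL}/chat/{session_id}/clear")
requests.delete(f"{BASE_URL}/chat/{session_id}/delete")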