al1kss committed on
Commit 2acbc30 · verified · 1 Parent(s): 663e454

Update main.py

Files changed (1)
  1. main.py +116 -246
main.py CHANGED
@@ -1,34 +1,16 @@
-import gradio as gr
+# main.py - FastAPI Backend
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.utils import EmbeddingFunc
 import os
-import zipfile
 import requests
-from pathlib import Path
 import numpy as np
 from typing import List

-# Try different LightRAG imports based on version
-try:
-    from lightrag import LightRAG, QueryParam
-    from lightrag.utils import EmbeddingFunc
-    LIGHTRAG_AVAILABLE = True
-except ImportError:
-    try:
-        from lightrag.lightrag import LightRAG
-        from lightrag.query import QueryParam
-        from lightrag.utils import EmbeddingFunc
-        LIGHTRAG_AVAILABLE = True
-    except ImportError:
-        try:
-            from lightrag.core import LightRAG
-            from lightrag.core import QueryParam
-            from lightrag.utils import EmbeddingFunc
-            LIGHTRAG_AVAILABLE = True
-        except ImportError:
-            print("❌ LightRAG import failed - using fallback mode")
-            LIGHTRAG_AVAILABLE = False
-
-# Fallback CloudflareWorker with simple search
+# Your CloudflareWorker class
 class CloudflareWorker:
     def __init__(self, cloudflare_api_key: str, api_base_url: str, llm_model_name: str, embedding_model_name: str):
         self.cloudflare_api_key = cloudflare_api_key
@@ -38,21 +20,20 @@ class CloudflareWorker:
         self.max_tokens = 4080
         self.max_response_tokens = 4080

-    async def _send_request(self, model_name: str, input_: dict, debug_log: str = ""):
+    async def _send_request(self, model_name: str, input_: dict, debug_log: str):
         headers = {"Authorization": f"Bearer {self.cloudflare_api_key}"}

         try:
             response_raw = requests.post(
                 f"{self.api_base_url}{model_name}",
                 headers=headers,
-                json=input_,
-                timeout=30
+                json=input_
             ).json()

             result = response_raw.get("result", {})

             if "data" in result:
-                return np.array(result["data"]) if LIGHTRAG_AVAILABLE else result["data"]
+                return np.array(result["data"])
             if "response" in result:
                 return result["response"]

@@ -76,79 +57,23 @@ class CloudflareWorker:
             "response_token_limit": self.max_response_tokens,
         }

-        result = await self._send_request(self.llm_model_name, input_)
+        result = await self._send_request(self.llm_model_name, input_, "")
         return result if result is not None else "Error: Failed to get response"

-    async def embedding_chunk(self, texts: List[str]):
+    async def embedding_chunk(self, texts: List[str]) -> np.ndarray:
         input_ = {
             "text": texts,
             "max_tokens": self.max_tokens,
             "response_token_limit": self.max_response_tokens,
         }

-        result = await self._send_request(self.embedding_model_name, input_)
+        result = await self._send_request(self.embedding_model_name, input_, "")

         if result is None:
-            if LIGHTRAG_AVAILABLE:
-                return np.random.rand(len(texts), 1024).astype(np.float32)
-            else:
-                return [[0.0] * 1024 for _ in texts]
+            return np.random.rand(len(texts), 1024).astype(np.float32)

         return result

-# Simple fallback knowledge store if LightRAG fails
-class SimpleKnowledgeStore:
-    def __init__(self, data_dir: str):
-        self.data_dir = data_dir
-        self.chunks = []
-        self.entities = []
-        self.load_data()
-
-    def load_data(self):
-        try:
-            import json
-            chunks_file = Path(self.data_dir) / "kv_store_text_chunks.json"
-            if chunks_file.exists():
-                with open(chunks_file, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-                    self.chunks = list(data.values()) if data else []
-
-            entities_file = Path(self.data_dir) / "vdb_entities.json"
-            if entities_file.exists():
-                with open(entities_file, 'r', encoding='utf-8') as f:
-                    entities_data = json.load(f)
-                    if isinstance(entities_data, dict) and 'data' in entities_data:
-                        self.entities = entities_data['data']
-                    elif isinstance(entities_data, list):
-                        self.entities = entities_data
-                    else:
-                        self.entities = []
-
-            print(f"✅ Loaded {len(self.chunks)} chunks and {len(self.entities)} entities")
-
-        except Exception as e:
-            print(f"⚠️ Error loading data: {e}")
-            self.chunks = []
-            self.entities = []
-
-    def search(self, query: str, limit: int = 5) -> List[str]:
-        query_lower = query.lower()
-        results = []
-
-        for chunk in self.chunks:
-            if isinstance(chunk, dict) and 'content' in chunk:
-                content = chunk['content']
-                if any(word in content.lower() for word in query_lower.split()):
-                    results.append(content)
-
-        for entity in self.entities:
-            if isinstance(entity, dict):
-                entity_text = str(entity)
-                if any(word in entity_text.lower() for word in query_lower.split()):
-                    results.append(entity_text)
-
-        return results[:limit]
-
 # Configuration
 CLOUDFLARE_API_KEY = os.getenv('CLOUDFLARE_API_KEY', 'lMbDDfHi887AK243ZUenm4dHV2nwEx2NSmX6xuq5')
 API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/07c4bcfbc1891c3e528e1c439fee68bd/ai/run/"
@@ -156,181 +81,126 @@ EMBEDDING_MODEL = '@cf/baai/bge-m3'
 LLM_MODEL = "@cf/meta/llama-3.2-3b-instruct"
 WORKING_DIR = "./dickens"

-# Global instances
+# Initialize FastAPI
+app = FastAPI(title="Fire Safety AI Assistant API", version="1.0.0")
+
+# Enable CORS for frontend
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # In production, replace with your frontend domain
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global RAG instance
 rag_instance = None
-knowledge_store = None
-cloudflare_worker = None

-async def initialize_system():
-    global rag_instance, knowledge_store, cloudflare_worker
-
-    print("🔄 Initializing system...")
-
-    # Download data if needed
-    dickens_path = Path(WORKING_DIR)
-    has_data = dickens_path.exists() and len(list(dickens_path.glob("*.json"))) > 0
+# Pydantic models
+class QuestionRequest(BaseModel):
+    question: str
+    mode: str = "hybrid"  # naive, local, global, hybrid
+
+class QuestionResponse(BaseModel):
+    answer: str
+    mode: str
+    status: str
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize RAG system on startup"""
+    global rag_instance

-    if not has_data:
-        print("📥 Downloading RAG database...")
-        try:
-            # REPLACE YOUR_USERNAME with your actual GitHub username
-            data_url = "https://github.com/YOUR_USERNAME/fire-safety-ai/releases/download/v1.0-data/dickens.zip"
-
-            response = requests.get(data_url, timeout=60)
-            response.raise_for_status()
-
-            with open("dickens.zip", "wb") as f:
-                f.write(response.content)
-
-            with zipfile.ZipFile("dickens.zip", 'r') as zip_ref:
-                zip_ref.extractall(".")
-
-            os.remove("dickens.zip")
-            print("✅ Data downloaded!")
-
-        except Exception as e:
-            print(f"⚠️ Download failed: {e}")
-            os.makedirs(WORKING_DIR, exist_ok=True)
+    print("🔄 Initializing RAG system...")

-    # Initialize Cloudflare worker
     cloudflare_worker = CloudflareWorker(
         cloudflare_api_key=CLOUDFLARE_API_KEY,
         api_base_url=API_BASE_URL,
         embedding_model_name=EMBEDDING_MODEL,
         llm_model_name=LLM_MODEL,
     )
+
+    rag_instance = LightRAG(
+        working_dir=WORKING_DIR,
+        max_parallel_insert=2,
+        llm_model_func=cloudflare_worker.query,
+        llm_model_name=LLM_MODEL,
+        llm_model_max_token_size=4080,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024,
+            max_token_size=2048,
+            func=lambda texts: cloudflare_worker.embedding_chunk(texts),
+        ),
+    )

-    # Try to initialize LightRAG, fallback to simple store
-    if LIGHTRAG_AVAILABLE:
-        try:
-            rag_instance = LightRAG(
-                working_dir=WORKING_DIR,
-                max_parallel_insert=2,
-                llm_model_func=cloudflare_worker.query,
-                llm_model_name=LLM_MODEL,
-                llm_model_max_token_size=4080,
-                embedding_func=EmbeddingFunc(
-                    embedding_dim=1024,
-                    max_token_size=2048,
-                    func=lambda texts: cloudflare_worker.embedding_chunk(texts),
-                ),
-            )
-
-            await rag_instance.initialize_storages()
-            print("✅ LightRAG system initialized!")
-
-        except Exception as e:
-            print(f"⚠️ LightRAG failed, using fallback: {e}")
-            knowledge_store = SimpleKnowledgeStore(WORKING_DIR)
-    else:
-        print("🔄 Using simple knowledge store...")
-        knowledge_store = SimpleKnowledgeStore(WORKING_DIR)
-
-    print("✅ System ready!")
+    await rag_instance.initialize_storages()
+    print("✅ RAG system initialized!")
+
+@app.get("/")
+async def root():
+    return {"message": "🔥 Fire Safety AI Assistant API", "status": "running"}

-# Initialize on startup
-asyncio.run(initialize_system())
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "rag_ready": rag_instance is not None}

-async def ask_question(question, mode="hybrid"):
-    if not question.strip():
-        return "❌ Please enter a question."
+@app.post("/ask", response_model=QuestionResponse)
+async def ask_question(request: QuestionRequest):
+    """Ask a question to the Fire Safety AI"""
+
+    if not rag_instance:
+        raise HTTPException(status_code=503, detail="RAG system not initialized")
+
+    if not request.question.strip():
+        raise HTTPException(status_code=400, detail="Question cannot be empty")

     try:
-        print(f"🔍 Processing question: {question}")
+        # Query the RAG system
+        print(f"🔍 Processing question: {request.question}")

-        # Use LightRAG if available, otherwise fallback
-        if rag_instance and LIGHTRAG_AVAILABLE:
-            response = await rag_instance.aquery(
-                question,
-                param=QueryParam(mode=mode)
-            )
-            return response
+        response = await rag_instance.aquery(
+            request.question,
+            param=QueryParam(mode=request.mode)
+        )

-        elif knowledge_store and cloudflare_worker:
-            # Fallback: simple search + Cloudflare AI
-            relevant_chunks = knowledge_store.search(question, limit=3)
-            context = "\n".join(relevant_chunks) if relevant_chunks else "No specific context found."
-
-            system_prompt = """You are a Fire Safety AI Assistant specializing in Vietnamese fire safety regulations.
-            Use the provided context to answer questions about building codes, emergency exits, and fire safety requirements."""
-
-            user_prompt = f"""Context: {context}
-
-            Question: {question}
-
-            Please provide a helpful answer based on the context about Vietnamese fire safety regulations."""
-
-            response = await cloudflare_worker.query(user_prompt, system_prompt)
-            return response
+        return QuestionResponse(
+            answer=response,
+            mode=request.mode,
+            status="success"
+        )

-        else:
-            return "❌ System not initialized yet. Please wait..."
-
     except Exception as e:
-        return f"❌ Error: {str(e)}"
-
-def sync_ask_question(question, mode):
-    return asyncio.run(ask_question(question, mode))
-
-# Create Gradio interface
-with gr.Blocks(title="🔥 Fire Safety AI Assistant", theme=gr.themes.Soft()) as demo:
-    gr.HTML("<h1 style='text-align: center;'>🔥 Fire Safety AI Assistant</h1>")
-    gr.HTML("<p style='text-align: center;'>Ask questions about Vietnamese fire safety regulations</p>")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            question_input = gr.Textbox(
-                label="Your Question",
-                placeholder="What are the requirements for emergency exits?",
-                lines=3
-            )
-            mode_dropdown = gr.Dropdown(
-                choices=["hybrid", "local", "global", "naive"],
-                value="hybrid",
-                label="Search Mode",
-                info="Hybrid is recommended for best results"
-            )
-            submit_btn = gr.Button("🔍 Ask Question", variant="primary", size="lg")
-
-        with gr.Column(scale=2):
-            answer_output = gr.Textbox(
-                label="Answer",
-                lines=15,
-                show_copy_button=True
-            )
-
-    # System status
-    status_text = "✅ LightRAG System" if LIGHTRAG_AVAILABLE else "⚠️ Fallback Mode"
-    gr.HTML(f"<p style='text-align: center; color: gray;'>Status: {status_text}</p>")
-
-    # Example questions
-    gr.HTML("<h3 style='text-align: center;'>💡 Example Questions:</h3>")
-
-    with gr.Row():
-        example1 = gr.Button("What are the requirements for emergency exits?", size="sm")
-        example2 = gr.Button("How many exits does a building need?", size="sm")
-
-    with gr.Row():
-        example3 = gr.Button("What are fire safety rules for stairwells?", size="sm")
-        example4 = gr.Button("What are building safety requirements?", size="sm")
-
-    # Event handlers
-    submit_btn.click(
-        sync_ask_question,
-        inputs=[question_input, mode_dropdown],
-        outputs=answer_output
-    )
-
-    question_input.submit(
-        sync_ask_question,
-        inputs=[question_input, mode_dropdown],
-        outputs=answer_output
-    )
-
-    example1.click(lambda: "What are the requirements for emergency exits?", outputs=question_input)
-    example2.click(lambda: "How many exits does a building need?", outputs=question_input)
-    example3.click(lambda: "What are fire safety rules for stairwells?", outputs=question_input)
-    example4.click(lambda: "What are building safety requirements?", outputs=question_input)
+        print(f"❌ Error processing question: {e}")
+        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
+
+@app.get("/modes")
+async def get_available_modes():
+    """Get available query modes"""
+    return {
+        "modes": [
+            {"name": "naive", "description": "Simple text search"},
+            {"name": "local", "description": "Search specific document sections"},
+            {"name": "global", "description": "Look at overall document themes"},
+            {"name": "hybrid", "description": "Combined approach (recommended)"}
+        ]
+    }
+
+# Example questions endpoint
+@app.get("/examples")
+async def get_example_questions():
+    """Get example questions users can ask"""
+    return {
+        "examples": [
+            "What are the requirements for emergency exits?",
+            "How many exits does a building need?",
+            "What are fire safety rules for stairwells?",
+            "What are building safety requirements?",
+            "What are the fire safety regulations for high-rise buildings?",
+            "What are the requirements for fire doors?",
+            "How should evacuation routes be designed?"
+        ]
+    }

 if __name__ == "__main__":
-    demo.launch()
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
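
For reference, a minimal client sketch against the endpoints added in this commit. It assumes the server has been started locally (e.g. `python main.py`, which binds to port 8000 per the `uvicorn.run` call above); the base URL is an assumption and should point at the deployed host in practice.

# Minimal client sketch for the FastAPI backend introduced in this commit.
# Assumption: the API is reachable at http://localhost:8000 (adjust for the deployed host).
import requests

BASE_URL = "http://localhost:8000"

# /health reports whether the LightRAG instance finished its startup initialization.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Ask a question; "hybrid" is the default mode declared in QuestionRequest.
payload = {"question": "What are the requirements for emergency exits?", "mode": "hybrid"}
answer = requests.post(f"{BASE_URL}/ask", json=payload, timeout=120).json()
print(answer["answer"])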