Enhance BM25 index creation and retrieval functionality; save index to output directory
Browse files
app.py
CHANGED
|
@@ -357,22 +357,29 @@ bm25_index = BM25Index.build_from_documents(
|
|
| 357 |
|
| 358 |
|
| 359 |
class Hit(TypedDict):
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
|
| 365 |
demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
|
| 366 |
return_type = List[Hit]
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
## YOUR_CODE_STARTS_HERE
|
| 370 |
def retrieve(query: str, topk: int = 10) -> return_type:
|
| 371 |
-
ranking = bm25_retriever.retrieve(query=query, topk=
|
| 372 |
hits = []
|
| 373 |
for cid, score in ranking.items():
|
| 374 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
| 375 |
-
hits.append(
|
| 376 |
return hits
|
| 377 |
|
| 378 |
|
|
@@ -388,5 +395,5 @@ demo = gr.Interface(
|
|
| 388 |
["What are the symptoms of immunodeficiency?"],
|
| 389 |
],
|
| 390 |
)
|
| 391 |
-
## YOUR_CODE_ENDS_HERE
|
| 392 |
demo.launch()
|
|
|
|
|
|
| 357 |
|
| 358 |
|
| 359 |
class Hit(TypedDict):
|
| 360 |
+
cid: str
|
| 361 |
+
score: float
|
| 362 |
+
text: str
|
|
|
|
| 363 |
|
| 364 |
demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
|
| 365 |
return_type = List[Hit]
|
| 366 |
+
bm25_index = BM25Index.build_from_documents(
|
| 367 |
+
documents=iter(sciq.corpus),
|
| 368 |
+
ndocs=12160,
|
| 369 |
+
show_progress_bar=True,
|
| 370 |
+
k1=0.9,
|
| 371 |
+
b=0.4,
|
| 372 |
+
)
|
| 373 |
+
bm25_index.save("output/bm25_index")
|
| 374 |
+
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
|
| 375 |
|
| 376 |
## YOUR_CODE_STARTS_HERE
|
| 377 |
def retrieve(query: str, topk: int = 10) -> return_type:
|
| 378 |
+
ranking = bm25_retriever.retrieve(query=query, topk=topk)
|
| 379 |
hits = []
|
| 380 |
for cid, score in ranking.items():
|
| 381 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
| 382 |
+
hits.append(Hit(cid=cid, score=score, text=text))
|
| 383 |
return hits
|
| 384 |
|
| 385 |
|
|
|
|
| 395 |
["What are the symptoms of immunodeficiency?"],
|
| 396 |
],
|
| 397 |
)
|
|
|
|
| 398 |
demo.launch()
|
| 399 |
+
## YOUR_CODE_ENDS_HERE
|