Kevin Hu
commited on
Commit
·
1ca7adb
1
Parent(s):
bdb8bf3
fix term weight issue (#3306)
Browse files### What problem does this PR solve?
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/benchmark.py +18 -10
- rag/nlp/search.py +2 -2
rag/benchmark.py
CHANGED
|
@@ -34,12 +34,13 @@ from tqdm import tqdm
|
|
| 34 |
|
| 35 |
class Benchmark:
|
| 36 |
def __init__(self, kb_id):
|
| 37 |
-
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
| 38 |
-
self.similarity_threshold = kb.similarity_threshold
|
| 39 |
-
self.vector_similarity_weight = kb.vector_similarity_weight
|
| 40 |
-
self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
| 41 |
|
| 42 |
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
|
|
|
| 43 |
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
| 44 |
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
| 45 |
return sres
|
|
@@ -48,11 +49,15 @@ class Benchmark:
|
|
| 48 |
run = defaultdict(dict)
|
| 49 |
query_list = list(qrels.keys())
|
| 50 |
for query in query_list:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
return run
|
| 57 |
|
| 58 |
def embedding(self, docs, batch_size=16):
|
|
@@ -99,7 +104,8 @@ class Benchmark:
|
|
| 99 |
query = data.iloc[i]['query']
|
| 100 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
| 101 |
d = {
|
| 102 |
-
"id": get_uuid()
|
|
|
|
| 103 |
}
|
| 104 |
tokenize(d, text, "english")
|
| 105 |
docs.append(d)
|
|
@@ -208,6 +214,8 @@ class Benchmark:
|
|
| 208 |
scores = sorted(scores, key=lambda kk: kk[1])
|
| 209 |
for score in scores[:10]:
|
| 210 |
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
|
|
|
|
|
|
| 211 |
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
| 212 |
|
| 213 |
def __call__(self, dataset, file_path, miracl_corpus=''):
|
|
|
|
| 34 |
|
| 35 |
class Benchmark:
|
| 36 |
def __init__(self, kb_id):
|
| 37 |
+
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
| 38 |
+
self.similarity_threshold = self.kb.similarity_threshold
|
| 39 |
+
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
| 40 |
+
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
| 41 |
|
| 42 |
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
| 43 |
+
|
| 44 |
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
| 45 |
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
| 46 |
return sres
|
|
|
|
| 49 |
run = defaultdict(dict)
|
| 50 |
query_list = list(qrels.keys())
|
| 51 |
for query in query_list:
|
| 52 |
+
|
| 53 |
+
ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
|
| 54 |
+
[self.kb.id], 0, 30,
|
| 55 |
+
0.0, self.vector_similarity_weight)
|
| 56 |
+
for c in ranks["chunks"]:
|
| 57 |
+
if "vector" in c:
|
| 58 |
+
del c["vector"]
|
| 59 |
+
run[query][c["chunk_id"]] = c["similarity"]
|
| 60 |
+
|
| 61 |
return run
|
| 62 |
|
| 63 |
def embedding(self, docs, batch_size=16):
|
|
|
|
| 104 |
query = data.iloc[i]['query']
|
| 105 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
| 106 |
d = {
|
| 107 |
+
"id": get_uuid(),
|
| 108 |
+
"kb_id": self.kb.id
|
| 109 |
}
|
| 110 |
tokenize(d, text, "english")
|
| 111 |
docs.append(d)
|
|
|
|
| 214 |
scores = sorted(scores, key=lambda kk: kk[1])
|
| 215 |
for score in scores[:10]:
|
| 216 |
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
| 217 |
+
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
|
| 218 |
+
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
|
| 219 |
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
| 220 |
|
| 221 |
def __call__(self, dataset, file_path, miracl_corpus=''):
|
rag/nlp/search.py
CHANGED
|
@@ -211,8 +211,8 @@ class Dealer:
|
|
| 211 |
continue
|
| 212 |
if not isinstance(v, type("")):
|
| 213 |
m[n] = str(m[n])
|
| 214 |
-
if n.find("tks") > 0:
|
| 215 |
-
|
| 216 |
|
| 217 |
if m:
|
| 218 |
res[d["id"]] = m
|
|
|
|
| 211 |
continue
|
| 212 |
if not isinstance(v, type("")):
|
| 213 |
m[n] = str(m[n])
|
| 214 |
+
#if n.find("tks") > 0:
|
| 215 |
+
# m[n] = rmSpace(m[n])
|
| 216 |
|
| 217 |
if m:
|
| 218 |
res[d["id"]] = m
|