Kevin Hu committed
Commit · 6a9fa6b
1 Parent(s): b6c5e1b

fix bug about fetching knowledge graph (#3394)

### What problem does this PR solve?

### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/chunk_app.py +1 -4
- api/apps/document_app.py +2 -2
- api/db/services/file_service.py +37 -0
- api/db/services/knowledgebase_service.py +1 -1
- deepdoc/parser/txt_parser.py +17 -13
- rag/utils/es_conn.py +62 -42
api/apps/chunk_app.py
CHANGED

@@ -301,16 +301,13 @@ def retrieval_test():
 @login_required
 def knowledge_graph():
     doc_id = request.args["doc_id"]
-    e, doc = DocumentService.get_by_id(doc_id)
-    if not e:
-        return get_data_error_result(message="Document not found!")
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
         "doc_ids":[doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids
+    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
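For reference, a minimal sketch of how the fixed endpoint might be exercised; the base URL, route prefix, session handling and response envelope below are assumptions for illustration, not part of this diff.

```python
# Hypothetical client call against the fixed knowledge_graph endpoint.
# Only doc_id and the {"graph": ..., "mind_map": ...} shape come from the code above;
# host, route prefix and the get_json_result-style envelope are assumed.
import requests

resp = requests.get(
    "http://localhost:9380/v1/chunk/knowledge_graph",  # assumed host and route prefix
    params={"doc_id": "<doc-id>"},
    cookies={"session": "<authenticated-session>"},    # @login_required expects a logged-in session
)
print(resp.json()["data"].keys())  # expected to contain 'graph' and 'mind_map'
```
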
api/apps/document_app.py
CHANGED

@@ -524,7 +524,7 @@ def upload_and_parse():
 @manager.route('/parse', methods=['POST'])
 @login_required
 def parse():
-    url = request.json.get("url")
+    url = request.json.get("url") if request.json else ""
     if url:
         if not is_valid_url(url):
             return get_json_result(

@@ -537,7 +537,7 @@ def parse():
         options.add_argument('--disable-dev-shm-usage')
         driver = Chrome(options=options)
         driver.get(url)
-        sections = RAGFlowHtmlParser()(
+        sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
         return get_json_result(data="\n".join(sections))

    if 'file' not in request.files:
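The first hunk matters when /parse receives a file upload with no JSON body. A minimal standalone sketch of the guard, assuming the Flask version in use yields a falsy request.json in that case (behaviour differs across Flask releases); the route and return values are simplified stand-ins, not RAGFlow code.

```python
# Standalone sketch of the None-guard added above.
from flask import Flask, request

app = Flask(__name__)

@app.route("/parse", methods=["POST"])
def parse():
    # Without the guard, request.json.get("url") can blow up when the request
    # carries a multipart file upload instead of a JSON body.
    url = request.json.get("url") if request.json else ""
    if url:
        return {"mode": "url", "url": url}
    if "file" not in request.files:
        return {"error": "no file and no url"}, 400
    return {"mode": "file", "filename": request.files["file"].filename}
```
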
api/db/services/file_service.py
CHANGED

@@ -15,6 +15,8 @@
 #
 import re
 import os
+from concurrent.futures import ThreadPoolExecutor
+
 from flask_login import current_user
 from peewee import fn
 

@@ -385,6 +387,41 @@ class FileService(CommonService):
 
         return err, files
 
+    @staticmethod
+    def parse_docs(file_objs, user_id):
+        from rag.app import presentation, picture, naive, audio, email
+
+        def dummy(prog=None, msg=""):
+            pass
+
+        FACTORY = {
+            ParserType.PRESENTATION.value: presentation,
+            ParserType.PICTURE.value: picture,
+            ParserType.AUDIO.value: audio,
+            ParserType.EMAIL.value: email
+        }
+        parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
+        exe = ThreadPoolExecutor(max_workers=12)
+        threads = []
+        for file in file_objs:
+            kwargs = {
+                "lang": "English",
+                "callback": dummy,
+                "parser_config": parser_config,
+                "from_page": 0,
+                "to_page": 100000,
+                "tenant_id": user_id
+            }
+            filetype = filename_type(file.filename)
+            blob = file.read()
+            threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
+
+        res = []
+        for th in threads:
+            res.append("\n".join([ck["content_with_weight"] for ck in th.result()]))
+
+        return "\n\n".join(res)
+
     @staticmethod
     def get_parser(doc_type, filename, default):
         if doc_type == FileType.VISUAL:
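The new FileService.parse_docs helper chunks a batch of uploaded files in parallel and joins the chunk texts. A rough usage sketch, assuming the same Flask helpers (manager, login_required, current_user, get_json_result) used elsewhere in api/apps; the route name is hypothetical.

```python
# Hypothetical handler showing how parse_docs might be called; only
# FileService.parse_docs(file_objs, user_id) itself comes from this diff.
@manager.route('/parse_files', methods=['POST'])  # hypothetical route
@login_required
def parse_files():
    file_objs = request.files.getlist('file')
    # Blocks until every file has been chunked by its parser, then returns
    # all chunk texts joined with blank lines.
    txt = FileService.parse_docs(file_objs, current_user.id)
    return get_json_result(data=txt)
```
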
api/db/services/knowledgebase_service.py
CHANGED

@@ -73,7 +73,7 @@ class KnowledgebaseService(CommonService):
             cls.model.id,
         ]
         kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
-        kb_ids = [kb
+        kb_ids = [kb.id for kb in kbs]
         return kb_ids
 
     @classmethod
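get_kb_ids now reads the id off each row as an attribute, which is how peewee exposes selected columns on model instances. A standalone sketch of that behaviour with a toy model, not RAGFlow's.

```python
# Minimal peewee sketch: select() yields model instances, so columns are read
# as attributes (kb.id), which is what the fixed list comprehension relies on.
from peewee import Model, CharField, SqliteDatabase

db = SqliteDatabase(":memory:")

class Knowledgebase(Model):
    id = CharField(primary_key=True)
    tenant_id = CharField()

    class Meta:
        database = db

db.create_tables([Knowledgebase])
Knowledgebase.create(id="kb1", tenant_id="t1")
Knowledgebase.create(id="kb2", tenant_id="t1")

kbs = Knowledgebase.select(Knowledgebase.id).where(Knowledgebase.tenant_id == "t1")
kb_ids = [kb.id for kb in kbs]  # same shape as the fixed line
print(kb_ids)  # e.g. ['kb1', 'kb2']
```
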
deepdoc/parser/txt_parser.py
CHANGED

@@ -10,6 +10,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
+
 from deepdoc.parser.utils import get_text
 from rag.nlp import num_tokens_from_string
 

@@ -29,8 +31,6 @@ class RAGFlowTxtParser:
         def add_chunk(t):
             nonlocal cks, tk_nums, delimiter
             tnum = num_tokens_from_string(t)
-            if tnum < 8:
-                pos = ""
             if tk_nums[-1] > chunk_token_num:
                 cks.append(t)
                 tk_nums.append(tnum)

@@ -38,15 +38,19 @@
                 cks[-1] += t
                 tk_nums[-1] += tnum
 
-
-
-
-
-
-
-
-
-
-
+        dels = []
+        s = 0
+        for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
+            f, t = m.span()
+            dels.append(m.group(1))
+            dels.extend(list(delimiter[s: f]))
+            s = t
+        if s < len(delimiter):
+            dels.extend(list(delimiter[s:]))
+        dels = [re.escape(d) for d in delimiter if d]
+        dels = [d for d in dels if d]
+        dels = "|".join(dels)
+        secs = re.split(r"(%s)" % dels, txt)
+        for sec in secs: add_chunk(sec)
 
-        return [[c,""] for c in cks]
+        return [[c, ""] for c in cks]
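The rewritten parser_txt builds a regex alternation out of the delimiter string and splits with a capturing group, so the delimiters themselves are kept and appended onto the chunks. A standalone sketch of that splitting mechanism (not RAGFlow code):

```python
# Sketch of the delimiter-splitting idea used above: escape each delimiter
# character, join into an alternation, and split with a capturing group so the
# separators stay in the output and end up appended to the chunks.
import re

delimiter = "\n!?;。;!?"
dels = "|".join(re.escape(d) for d in delimiter if d)
secs = re.split(r"(%s)" % dels, "First sentence! Second one? Third;")
print(secs)  # ['First sentence', '!', ' Second one', '?', ' Third', ';', '']
```
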
rag/utils/es_conn.py
CHANGED

@@ -13,7 +13,8 @@ from rag import settings
 from rag.utils import singleton
 from api.utils.file_utils import get_project_base_directory
 import polars as pl
-from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr,
+from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
+    FusionExpr
 from rag.nlp import is_english, rag_tokenizer
 
 

@@ -26,7 +27,8 @@ class ESConnection(DocStoreConnection):
         try:
             self.es = Elasticsearch(
                 settings.ES["hosts"].split(","),
-                basic_auth=(settings.ES["username"], settings.ES[
+                basic_auth=(settings.ES["username"], settings.ES[
+                "password"]) if "username" in settings.ES and "password" in settings.ES else None,
                 verify_certs=False,
                 timeout=600
             )

@@ -57,6 +59,7 @@ class ESConnection(DocStoreConnection):
     """
     Database operations
     """
+
     def dbType(self) -> str:
         return "elasticsearch"
 

@@ -66,6 +69,7 @@ class ESConnection(DocStoreConnection):
     """
     Table operations
     """
+
     def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
         if self.indexExist(indexName, knowledgebaseId):
             return True

@@ -97,7 +101,10 @@
     """
     CRUD operations
     """
-
+
+    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
+               orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
+               knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
         """
         Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
         """

@@ -109,8 +116,10 @@
         bqry = None
         vector_similarity_weight = 0.5
         for m in matchExprs:
-            if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
-                assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
+            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
+                assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
+                                                                                                        MatchDenseExpr) and isinstance(
+                    matchExprs[2], FusionExpr)
                 weights = m.fusion_params["weights"]
                 vector_similarity_weight = float(weights.split(",")[1])
         for m in matchExprs:

@@ -119,36 +128,41 @@
                 if "minimum_should_match" in m.extra_options:
                     minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
                 bqry = Q("bool",
-
+                         must=Q("query_string", fields=m.fields,
                                type="best_fields", query=m.matching_text,
-                    minimum_should_match
+                                minimum_should_match=minimum_should_match,
                                 boost=1),
-
-
-                if condition:
-                    for k, v in condition.items():
-                        if not isinstance(k, str) or not v:
-                            continue
-                        if isinstance(v, list):
-                            bqry.filter.append(Q("terms", **{k: v}))
-                        elif isinstance(v, str) or isinstance(v, int):
-                            bqry.filter.append(Q("term", **{k: v}))
-                        else:
-                            raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+                         boost=1.0 - vector_similarity_weight,
+                         )
             elif isinstance(m, MatchDenseExpr):
-                assert(bqry is not None)
+                assert (bqry is not None)
                 similarity = 0.0
                 if "similarity" in m.extra_options:
                     similarity = m.extra_options["similarity"]
                 s = s.knn(m.vector_column_name,
-
-
-
-
-
-
-
-
+                          m.topn,
+                          m.topn * 2,
+                          query_vector=list(m.embedding_data),
+                          filter=bqry.to_dict(),
+                          similarity=similarity,
+                          )
+
+        if condition:
+            if not bqry:
+                bqry = Q("bool", must=[])
+            for k, v in condition.items():
+                if not isinstance(k, str) or not v:
+                    continue
+                if isinstance(v, list):
+                    bqry.filter.append(Q("terms", **{k: v}))
+                elif isinstance(v, str) or isinstance(v, int):
+                    bqry.filter.append(Q("term", **{k: v}))
+                else:
+                    raise Exception(
+                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+
+        if bqry:
+            s = s.query(bqry)
         for field in highlightFields:
             s = s.highlight(field)
 

@@ -157,12 +171,13 @@
         for field, order in orderBy.fields:
             order = "asc" if order == 0 else "desc"
             orders.append({field: {"order": order, "unmapped_type": "float",
-
+                                   "mode": "avg", "numeric_type": "double"}})
         s = s.sort(*orders)
 
         if limit > 0:
             s = s[offset:limit]
         q = s.to_dict()
+        print(json.dumps(q), flush=True)
         # logger.info("ESConnection.search [Q]: " + json.dumps(q))
 
         for i in range(3):

@@ -189,7 +204,7 @@
         for i in range(3):
             try:
                 res = self.es.get(index=(indexName),
-
+                                  id=chunkId, source=True, )
                 if str(res.get("timed_out", "")).lower() == "true":
                     raise Exception("Es Timeout.")
                 if not res.get("found"):

@@ -222,7 +237,7 @@
         for _ in range(100):
             try:
                 r = self.es.bulk(index=(indexName), operations=operations,
-
+                                 refresh=False, timeout="600s")
                 if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                     return res
 

@@ -249,7 +264,8 @@
                 self.es.update(index=indexName, id=chunkId, doc=doc)
                 return True
             except Exception as e:
-                logger.exception(
+                logger.exception(
+                    f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
                 if str(e).find("Timeout") > 0:
                     continue
                 else:

@@ -263,7 +279,8 @@
             elif isinstance(v, str) or isinstance(v, int):
                 bqry.filter.append(Q("term", **{k: v}))
             else:
-                raise Exception(
+                raise Exception(
+                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
         scripts = []
         for k, v in newValue.items():
             if not isinstance(k, str) or not v:

@@ -273,7 +290,8 @@
             elif isinstance(v, int):
                 scripts.append(f"ctx._source.{k} = {v}")
             else:
-                raise Exception(
+                raise Exception(
+                    f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
         ubq = UpdateByQuery(
             index=indexName).using(
             self.es).query(bqry)

@@ -313,7 +331,7 @@
         try:
             res = self.es.delete_by_query(
                 index=indexName,
-                body
+                body=Search().query(qry).to_dict(),
                 refresh=True)
             return res["deleted"]
         except Exception as e:

@@ -325,10 +343,10 @@
             return 0
         return 0
 
-
     """
     Helper functions for search result
     """
+
     def getTotal(self, res):
         if isinstance(res["hits"]["total"], type({})):
             return res["hits"]["total"]["value"]

@@ -376,12 +394,13 @@
                 continue
 
             txt = d["_source"][fieldnm]
-            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
+            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
             txts = []
             for t in re.split(r"[.?!;\n]", txt):
                 for w in keywords:
-                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t,
-
+                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
+                               flags=re.IGNORECASE | re.MULTILINE)
+                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                     continue
                 txts.append(t)
             ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])

@@ -395,10 +414,10 @@
         bkts = res["aggregations"][agg_field]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]
 
-
     """
     SQL
     """
+
     def sql(self, sql: str, fetch_size: int, format: str):
         logger.info(f"ESConnection.sql get sql: {sql}")
         sql = re.sub(r"[ `]+", " ", sql)

@@ -413,7 +432,7 @@
                     r.group(1),
                     r.group(2),
                     r.group(3)),
-
+                 match))
 
         for p, r in replaces:
             sql = sql.replace(p, r, 1)

@@ -421,7 +440,8 @@
 
         for i in range(3):
             try:
-                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
+                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
+                                        request_timeout="2s")
                 return res
             except ConnectionTimeout:
                 logger.exception("ESConnection.sql timeout [Q]: " + sql)
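The reworked search() composes a weighted bool query for the full-text match, reuses it as the filter of a knn clause for the dense match, and applies the condition filters even when no text match is present. A standalone elasticsearch_dsl sketch of that query shape; the field names and all values are assumptions for illustration, and nothing here talks to a live cluster.

```python
# Sketch of the hybrid query shape built by the new search(); field names
# (title_tks, content_ltks, q_1024_vec, doc_ids) and the values are assumed.
import json
from elasticsearch_dsl import Q, Search

vector_similarity_weight = 0.7
bqry = Q("bool",
         must=Q("query_string", fields=["title_tks^10", "content_ltks"],
                type="best_fields", query="knowledge graph",
                minimum_should_match="30%", boost=1),
         boost=1.0 - vector_similarity_weight)
bqry.filter.append(Q("terms", doc_ids=["<doc-id>"]))  # the `condition` filter, e.g. doc_ids

s = Search()
s = s.knn("q_1024_vec", 64, 128,              # topn and topn * 2, as in the diff
          query_vector=[0.0, 0.0, 0.0, 0.0],  # toy vector; the real one comes from the embedding model
          filter=bqry.to_dict(),
          similarity=0.0)
s = s.query(bqry)
print(json.dumps(s.to_dict(), indent=2))      # request body only; no Elasticsearch call
```
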