Kevin Hu committed on
Commit · bf00d96
Parent(s): 8b574ab

fix duplicated llm name between different suppliers (#2477)

### What problem does this PR solve?

#2465

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed:
- api/apps/chunk_app.py +6 -10
- api/db/services/dialog_service.py +9 -2
- api/db/services/llm_service.py +12 -5
- rag/app/naive.py +1 -1
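The common thread across the first three files is a `name@factory` convention: a model id may now carry its supplier after an `@`, so two suppliers offering identically named models no longer collide in lookups. A minimal sketch of the convention (the helper and the model names in the asserts are illustrative, not part of the PR; the PR applies this split inline via `split("@")`):

```python
# Illustrative only: how a factory-qualified model id decomposes.
def split_model_id(model_id):
    """Split "llm_name@factory" into (llm_name, factory); factory may be absent."""
    name, _, factory = model_id.partition("@")
    return name, factory or None

assert split_model_id("bge-reranker-v2-m3@BAAI") == ("bge-reranker-v2-m3", "BAAI")
assert split_model_id("flag-embedding") == ("flag-embedding", None)
```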
api/apps/chunk_app.py
CHANGED
```diff
@@ -27,7 +27,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
@@ -141,8 +141,7 @@ def set():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
@@ -235,8 +234,7 @@ def create():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
@@ -281,16 +279,14 @@ def retrieval_test():
         if not e:
             return get_data_error_result(retmsg="Knowledgebase not found!")
 
-        embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
         if req.get("rerank_id"):
-            rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
         if req.get("keyword", False):
-            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
```
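Aside from the naming fix, the create() hunk preserves a detail worth knowing: a chunk's stored vector blends the embedding of the document name (10%) with the embedding of the chunk content (90%). A runnable stand-in (numpy and the fake 4-d embeddings are assumptions for illustration):

```python
import numpy as np

# In create(), v comes from embd_mdl.encode([doc.name, req["content_with_weight"]]);
# here two fake 4-d embeddings show the 10%/90% title/content blend.
v = np.array([[1.0, 0.0, 0.0, 0.0],   # embedding of the document name (fake)
              [0.0, 1.0, 0.0, 0.0]])  # embedding of the chunk content (fake)
chunk_vec = 0.1 * v[0] + 0.9 * v[1]
print(chunk_vec)  # [0.1 0.9 0.  0. ]
```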
api/db/services/dialog_service.py
CHANGED
```diff
@@ -78,6 +78,7 @@ def message_fit_in(msg, max_length=4000):
 
 
 def llm_id2llm_type(llm_id):
+    llm_id = llm_id.split("@")[0]
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    llm = LLMService.query(llm_name=dialog.llm_id)
+    tmp = dialog.llm_id.split("@")
+    fid = None
+    llm_id = tmp[0]
+    if len(tmp)>1: fid = tmp[1]
+
+    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
     if not llm:
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
     max_tokens = 8192
```
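The lookup order in chat() is now: query the global LLM table, then the tenant's own registrations, factory-qualifying each query only when the id carried an `@` suffix. A self-contained sketch of that order (the in-memory catalog and the query stand-ins replace LLMService.query and TenantLLMService.query):

```python
def find_llm(llm_id, query_global, query_tenant):
    # Mirrors the diff: split off an optional factory id and pass it through,
    # so same-named models from different suppliers resolve unambiguously.
    tmp = llm_id.split("@")
    name, fid = tmp[0], (tmp[1] if len(tmp) > 1 else None)
    llm = query_global(llm_name=name) if not fid else query_global(llm_name=name, fid=fid)
    if not llm:
        llm = (query_tenant(llm_name=name) if not fid
               else query_tenant(llm_name=name, llm_factory=fid))
    if not llm:
        raise LookupError("LLM(%s) not found" % llm_id)
    return llm

# Two suppliers publish the same model name; the "@" suffix picks one.
catalog = [{"llm_name": "embed-v1", "fid": "SupplierA"},
           {"llm_name": "embed-v1", "fid": "SupplierB"}]
def query_global(**kw):
    return [r for r in catalog if all(r.get(k) == v for k, v in kw.items())]

print(find_llm("embed-v1@SupplierB", query_global, lambda **kw: []))
# -> [{'llm_name': 'embed-v1', 'fid': 'SupplierB'}]
```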
api/db/services/llm_service.py
CHANGED
```diff
@@ -17,7 +17,7 @@ from api.db.services.user_service import TenantService
 from api.settings import database_logger
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
-from api.db.db_models import DB
+from api.db.db_models import DB
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
 
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        else:
+            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
         if not objs:
             return
         return objs[0]
@@ -81,14 +85,17 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        tmp = mdlnm.split("@")
+        fid = None if len(tmp) < 2 else tmp[1]
+        mdlnm = tmp[0]
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
-                llm = LLMService.query(llm_name=llm_name)
+                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                 if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
+                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
             if not model_config:
-                if llm_name == "flag-embedding":
+                if mdlnm == "flag-embedding":
                     model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                                     "llm_name": llm_name, "api_base": ""}
                 else:
```
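get_api_key applies the same split when fetching a tenant's stored credentials, and model_instance strips the suffix before its fallback matching. A small sketch of the filter get_api_key now builds (cls.query is the real Peewee-backed lookup; here it is reduced to the keyword arguments it would receive, and the tenant/model names are made up):

```python
def query_kwargs(tenant_id, model_name):
    # Mirrors the new get_api_key: legacy ids filter by name alone, while
    # "name@factory" ids also pin the supplier, avoiding cross-supplier hits.
    arr = model_name.split("@")
    if len(arr) < 2:
        return {"tenant_id": tenant_id, "llm_name": model_name}
    return {"tenant_id": tenant_id, "llm_name": arr[0], "llm_factory": arr[1]}

print(query_kwargs("t0", "text-embedding-v2"))                 # legacy form
print(query_kwargs("t0", "text-embedding-v2@Tongyi-Qianwen"))  # factory-qualified
```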
rag/app/naive.py
CHANGED
```diff
@@ -76,7 +76,7 @@ class Docx(DocxParser):
                 if last_image:
                     image_list.insert(0, last_image)
                     last_image = None
-                lines.append((self.__clean(p.text), image_list, p.style.name))
+                lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
             else:
                 if current_image := self.get_picture(self.doc, p):
                     if lines:
```
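The rag/app/naive.py change is independent of the naming fix: python-docx can yield a paragraph whose style is None (for instance when the style cannot be resolved from the document), so the unguarded p.style.name could raise AttributeError. A dummy reproduction of the guard (this Paragraph class is a stand-in, not the python-docx type):

```python
class Paragraph:  # stand-in for docx.text.paragraph.Paragraph
    def __init__(self, style=None):
        self.style = style

p = Paragraph(style=None)
name = p.style.name if p.style else ""  # "" instead of AttributeError
print(repr(name))  # ''
```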