Kevin Hu committed

Commit 6d672a7 · Parent(s): e10ed78

add prompt to message (#2099)
### What problem does this PR solve?

#2098

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
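In short, every answer dict produced by `chat()` now carries the rendered system prompt under a `"prompt"` key, and the conversation app persists it on the assistant message. A minimal sketch of the resulting payload shape (field values below are illustrative, not taken from the source; only the keys mirror the diffs):

```python
# Illustrative payload shape after this change; values are made up,
# only the keys mirror the diffs that follow.
ans = {
    "answer": "Paris is the capital of France [ID:0].",
    "reference": {"chunks": [], "doc_aggs": []},
    "prompt": "You are an intelligent assistant. Answer based on the knowledge base.",
}
print(ans["prompt"])
```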
api/apps/conversation_app.py CHANGED

```diff
@@ -140,7 +140,8 @@ def completion():
         if not conv.reference:
             conv.reference.append(ans["reference"])
         else: conv.reference[-1] = ans["reference"]
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+        conv.message[-1] = {"role": "assistant", "content": ans["answer"],
+                            "id": message_id, "prompt": ans.get("prompt", "")}
 
     def stream():
         nonlocal dia, msg, req, conv
```
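The `ans.get("prompt", "")` fallback matters here: any caller that still yields answers without a `"prompt"` key stores an empty string instead of raising `KeyError`. A minimal sketch of the persisted message, assuming a `conv` object with a `message` list as in `completion()` (the `Conv` class, `message_id`, and `ans` values below are hypothetical stand-ins, not RAGFlow objects):

```python
# Minimal sketch of the message update in completion() above.
# Conv, message_id, and ans are illustrative stand-ins.
class Conv:
    def __init__(self):
        self.message = [{"role": "user", "content": "What is RAGFlow?"},
                        {"role": "assistant", "content": ""}]

conv = Conv()
message_id = "msg-001"  # hypothetical id
ans = {"answer": "RAGFlow is an open-source RAG engine.",
       "reference": {"chunks": [], "doc_aggs": []}}  # note: no "prompt" key

# ans.get("prompt", "") tolerates answers produced without a prompt key.
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
                    "id": message_id, "prompt": ans.get("prompt", "")}
print(conv.message[-1])  # ... 'prompt': ''
```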
api/db/services/dialog_service.py CHANGED

```diff
@@ -179,6 +179,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                     for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
+    prompt = msg[0]["content"]
 
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(
@@ -186,7 +187,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
@@ -210,17 +211,16 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs}
+        return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
         answer = ""
-        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
+        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
             answer = ans
-            yield {"answer": answer, "reference": {}}
+            yield {"answer": answer, "reference": {}, "prompt": prompt}
         yield decorate_answer(answer)
     else:
-        answer = chat_mdl.chat(
-            msg[0]["content"], msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
         chat_logger.info("User: {}|Assistant: {}".format(
             msg[-1]["content"], answer))
         yield decorate_answer(answer)
@@ -334,7 +334,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         chat_logger.warning("SQL missing field: " + sql)
         return {
             "answer": "\n".join([clmns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []}
+            "reference": {"chunks": [], "doc_aggs": []},
+            "prompt": sys_prompt
         }
 
     docid_idx = list(docid_idx)[0]
@@ -348,7 +349,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                                   doc_aggs.items()]}
+                                   doc_aggs.items()]},
+        "prompt": sys_prompt
     }
 
 
```
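The `dialog_service.py` change hinges on `decorate_answer` being a closure inside `chat()`: adding `prompt` to the `nonlocal` list lets every decorated answer carry the system prompt without widening any function signature. A standalone sketch of that pattern under stated assumptions (`fake_llm` is a stand-in for `chat_mdl`; nothing here is RAGFlow's real API):

```python
# Standalone sketch of the closure pattern used in chat() above.
# fake_llm stands in for chat_mdl; everything here is illustrative.
def fake_llm(system_prompt, history, gen_conf):
    return "echo: " + history[-1]["content"]

def chat(messages):
    msg = [{"role": "system", "content": "You are an intelligent assistant."}] + messages
    prompt = msg[0]["content"]

    def decorate_answer(answer):
        # As in the diff: nonlocal exposes the enclosing prompt here,
        # so the returned dict carries it alongside the answer.
        nonlocal prompt
        return {"answer": answer, "reference": [], "prompt": prompt}

    answer = fake_llm(prompt, msg[1:], {})
    yield decorate_answer(answer)

print(next(chat([{"role": "user", "content": "hi"}])))
```

Because only keys are added to the yielded dicts, existing consumers that read `"answer"` and `"reference"` keep working, which is why the PR is marked non-breaking.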