Spaces:
Sleeping
Sleeping
File size: 4,121 Bytes
d81440f b180c39 632df2f d81440f 632df2f d81440f b180c39 d81440f b180c39 9fc53c5 b180c39 6d36a0d b180c39 240c11f b180c39 240c11f 520908e b180c39 520908e b180c39 520908e b180c39 240c11f 520908e 240c11f d81440f 240c11f b180c39 240c11f b180c39 d81440f e53ed5b d81440f e53ed5b d81440f b180c39 d81440f 240c11f d81440f 240c11f d81440f b180c39 e53ed5b b180c39 d81440f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import json
import re
from utils.logger import log
from .ai_model import AIModel
class InfoExtractor:
def __init__(self, ai_model):
self.ai_model = ai_model
self.prompt_template = self._build_prompt_template()
def _build_prompt_template(self) -> str:
# --- 重点更新:使用更严格的指令和结构化示例 ---
return """你的任务是且仅是作为文本解析器。
严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。
**核心规则:**
1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{{` 开始,到 `}}` 结束。
2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。
3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。
4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。
**强制JSON输出结构:**
```json
{{
"destination": {{
"name": "string or null"
}},
"duration": {{
"days": "integer or null"
}},
"budget": {{
"type": "string ('economy', 'comfortable', 'luxury') or null",
"amount": "number or null",
"currency": "string or null"
}}
}}
```
**示例:**
- 用户输入: "我想去柏林玩3天"
- 你的输出:
```json
{{
"destination": {{
"name": "柏林"
}},
"duration": {{
"days": 3
}},
"budget": {{
"type": null,
"amount": null,
"currency": null
}}
}}
```
---
**用户输入:**
`{user_message}`
**你的输出 (必须是纯JSON):**
"""
def extract(self, message: str) -> dict:
log.info(f"🧠 使用LLM开始提取信息,消息: '{message}'")
prompt = self.prompt_template.format(user_message=message)
raw_response = self.ai_model.generate(prompt)
if not raw_response:
log.error("❌ LLM模型没有返回任何内容。")
return {}
json_str = ""
try:
match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
if match:
json_str = match.group(1)
else:
start_index = raw_response.find('{')
end_index = raw_response.rfind('}')
if start_index != -1 and end_index != -1 and end_index > start_index:
json_str = raw_response[start_index:end_index + 1]
else:
raise json.JSONDecodeError("在LLM的返回中未找到有效的JSON对象。", raw_response, 0)
extracted_data = json.loads(json_str)
log.info(f"✅ LLM成功提取并解析JSON: {extracted_data}")
except json.JSONDecodeError as e:
log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}")
return {}
# --- 重点更新:使用更健壮、更安全的逻辑来清理数据 ---
final_info = {}
# 安全地处理 'destination'
destination_data = extracted_data.get("destination")
if isinstance(destination_data, dict) and destination_data.get("name"):
final_info["destination"] = {"name": destination_data["name"]}
# 安全地处理 'duration'
duration_data = extracted_data.get("duration")
if isinstance(duration_data, dict) and duration_data.get("days"):
try:
final_info["duration"] = {"days": int(duration_data["days"])}
except (ValueError, TypeError):
log.warning(f"⚠️ 无法将duration days '{duration_data.get('days')}' 转换为整数。")
# 安全地处理 'budget'
budget_data = extracted_data.get("budget")
if isinstance(budget_data, dict):
# 只要budget对象里有任何非null的值,就把它加进来
if any(v is not None for v in budget_data.values()):
final_info["budget"] = budget_data
log.info(f"📊 LLM最终提取结果 (安全处理后): {list(final_info.keys())}")
return final_info |