Travel_Assistant / modules /info_extractor.py
Eliot0110's picture
fix: 占位符{{}}
9fc53c5
raw
history blame
4.12 kB
import json
import re
from utils.logger import log
from .ai_model import AIModel
class InfoExtractor:
def __init__(self, ai_model):
self.ai_model = ai_model
self.prompt_template = self._build_prompt_template()
def _build_prompt_template(self) -> str:
# --- 重点更新:使用更严格的指令和结构化示例 ---
return """你的任务是且仅是作为文本解析器。
严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。
**核心规则:**
1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{{` 开始,到 `}}` 结束。
2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。
3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。
4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。
**强制JSON输出结构:**
```json
{{
"destination": {{
"name": "string or null"
}},
"duration": {{
"days": "integer or null"
}},
"budget": {{
"type": "string ('economy', 'comfortable', 'luxury') or null",
"amount": "number or null",
"currency": "string or null"
}}
}}
```
**示例:**
- 用户输入: "我想去柏林玩3天"
- 你的输出:
```json
{{
"destination": {{
"name": "柏林"
}},
"duration": {{
"days": 3
}},
"budget": {{
"type": null,
"amount": null,
"currency": null
}}
}}
```
---
**用户输入:**
`{user_message}`
**你的输出 (必须是纯JSON):**
"""
def extract(self, message: str) -> dict:
log.info(f"🧠 使用LLM开始提取信息,消息: '{message}'")
prompt = self.prompt_template.format(user_message=message)
raw_response = self.ai_model.generate(prompt)
if not raw_response:
log.error("❌ LLM模型没有返回任何内容。")
return {}
json_str = ""
try:
match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
if match:
json_str = match.group(1)
else:
start_index = raw_response.find('{')
end_index = raw_response.rfind('}')
if start_index != -1 and end_index != -1 and end_index > start_index:
json_str = raw_response[start_index:end_index + 1]
else:
raise json.JSONDecodeError("在LLM的返回中未找到有效的JSON对象。", raw_response, 0)
extracted_data = json.loads(json_str)
log.info(f"✅ LLM成功提取并解析JSON: {extracted_data}")
except json.JSONDecodeError as e:
log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}")
return {}
# --- 重点更新:使用更健壮、更安全的逻辑来清理数据 ---
final_info = {}
# 安全地处理 'destination'
destination_data = extracted_data.get("destination")
if isinstance(destination_data, dict) and destination_data.get("name"):
final_info["destination"] = {"name": destination_data["name"]}
# 安全地处理 'duration'
duration_data = extracted_data.get("duration")
if isinstance(duration_data, dict) and duration_data.get("days"):
try:
final_info["duration"] = {"days": int(duration_data["days"])}
except (ValueError, TypeError):
log.warning(f"⚠️ 无法将duration days '{duration_data.get('days')}' 转换为整数。")
# 安全地处理 'budget'
budget_data = extracted_data.get("budget")
if isinstance(budget_data, dict):
# 只要budget对象里有任何非null的值,就把它加进来
if any(v is not None for v in budget_data.values()):
final_info["budget"] = budget_data
log.info(f"📊 LLM最终提取结果 (安全处理后): {list(final_info.keys())}")
return final_info