Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| from utils.logger import log | |
| from .ai_model import AIModel | |
| class InfoExtractor: | |
| def __init__(self, ai_model): | |
| self.ai_model = ai_model | |
| self.prompt_template = self._build_prompt_template() | |
| def _build_prompt_template(self) -> str: | |
| # --- 重点更新:使用更严格的指令和结构化示例 --- | |
| return """你的任务是且仅是作为文本解析器。 | |
| 严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。 | |
| **核心规则:** | |
| 1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{{` 开始,到 `}}` 结束。 | |
| 2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。 | |
| 3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。 | |
| 4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。 | |
| **强制JSON输出结构:** | |
| ```json | |
| {{ | |
| "destination": {{ | |
| "name": "string or null" | |
| }}, | |
| "duration": {{ | |
| "days": "integer or null" | |
| }}, | |
| "budget": {{ | |
| "type": "string ('economy', 'comfortable', 'luxury') or null", | |
| "amount": "number or null", | |
| "currency": "string or null" | |
| }} | |
| }} | |
| ``` | |
| **示例:** | |
| - 用户输入: "我想去柏林玩3天" | |
| - 你的输出: | |
| ```json | |
| {{ | |
| "destination": {{ | |
| "name": "柏林" | |
| }}, | |
| "duration": {{ | |
| "days": 3 | |
| }}, | |
| "budget": {{ | |
| "type": null, | |
| "amount": null, | |
| "currency": null | |
| }} | |
| }} | |
| ``` | |
| --- | |
| **用户输入:** | |
| `{user_message}` | |
| **你的输出 (必须是纯JSON):** | |
| """ | |
| def extract(self, message: str) -> dict: | |
| log.info(f"🧠 使用LLM开始提取信息,消息: '{message}'") | |
| prompt = self.prompt_template.format(user_message=message) | |
| raw_response = self.ai_model.generate(prompt) | |
| if not raw_response: | |
| log.error("❌ LLM模型没有返回任何内容。") | |
| return {} | |
| json_str = "" | |
| try: | |
| match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL) | |
| if match: | |
| json_str = match.group(1) | |
| else: | |
| start_index = raw_response.find('{') | |
| end_index = raw_response.rfind('}') | |
| if start_index != -1 and end_index != -1 and end_index > start_index: | |
| json_str = raw_response[start_index:end_index + 1] | |
| else: | |
| raise json.JSONDecodeError("在LLM的返回中未找到有效的JSON对象。", raw_response, 0) | |
| extracted_data = json.loads(json_str) | |
| log.info(f"✅ LLM成功提取并解析JSON: {extracted_data}") | |
| except json.JSONDecodeError as e: | |
| log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}") | |
| return {} | |
| # --- 重点更新:使用更健壮、更安全的逻辑来清理数据 --- | |
| final_info = {} | |
| # 安全地处理 'destination' | |
| destination_data = extracted_data.get("destination") | |
| if isinstance(destination_data, dict) and destination_data.get("name"): | |
| final_info["destination"] = {"name": destination_data["name"]} | |
| # 安全地处理 'duration' | |
| duration_data = extracted_data.get("duration") | |
| if isinstance(duration_data, dict) and duration_data.get("days"): | |
| try: | |
| final_info["duration"] = {"days": int(duration_data["days"])} | |
| except (ValueError, TypeError): | |
| log.warning(f"⚠️ 无法将duration days '{duration_data.get('days')}' 转换为整数。") | |
| # 安全地处理 'budget' | |
| budget_data = extracted_data.get("budget") | |
| if isinstance(budget_data, dict): | |
| # 只要budget对象里有任何非null的值,就把它加进来 | |
| if any(v is not None for v in budget_data.values()): | |
| final_info["budget"] = budget_data | |
| log.info(f"📊 LLM最终提取结果 (安全处理后): {list(final_info.keys())}") | |
| return final_info |