| """Agent controller orchestrating the LLM ↔ tool server interaction loop.""" | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Sequence | |
| from tools.schemas import ToolCall, ToolResult | |
| from .client_llm import Message, NexaSciModelClient | |
| from .client_llm_remote import RemoteNexaSciClient | |
| from .tool_client import ToolClient | |
| TOOLCALL_REGEX = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL) | |
| FINAL_REGEX = re.compile(r"~~~final(.*?)~~~", re.DOTALL) | |
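
# The model is expected to wrap tool requests and the final answer in the two
# fenced forms matched above. The snippets below are illustrative only; the
# actual JSON keys are whatever tools.schemas.ToolCall defines, not these
# placeholder field names.
#
#   ~~~toolcall
#   {"name": "some_tool", "arguments": {"query": "..."}}
#   ~~~
#
#   ~~~final
#   {"answer": "..."}
#   ~~~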


@dataclass
class AgentRunResult:
    """Container describing the outcome of an agent run."""

    final_response: Dict[str, Any]
    messages: Sequence[Message]
    tool_results: Sequence[ToolResult] = field(default_factory=list)

    def pretty(self) -> str:
        """Return a readable JSON representation of the final response."""
        return json.dumps(self.final_response, indent=2)


class AgentController:
    """Core agent loop handling tool invocation and final response parsing."""

    def __init__(
        self,
        llm_client: NexaSciModelClient | RemoteNexaSciClient | None = None,
        tool_client: ToolClient | None = None,
        *,
        max_turns: int = 8,
        use_remote_model: bool = False,
        model_server_url: str = "http://127.0.0.1:8001",
    ) -> None:
        """Initialize the agent controller.

        Parameters
        ----------
        llm_client:
            Optional LLM client. If None, one is created based on ``use_remote_model``.
        tool_client:
            Optional tool client. If None, one is created from config.
        max_turns:
            Maximum number of agent turns.
        use_remote_model:
            If True, connect to a remote model server instead of loading the model locally.
        model_server_url:
            URL of the model server (used when ``use_remote_model`` is True).
        """
        if llm_client is None:
            if use_remote_model:
                llm_client = RemoteNexaSciClient(base_url=model_server_url)
            else:
                llm_client = NexaSciModelClient(lazy_load=True)
        self.llm_client = llm_client
        self.tool_client = tool_client or ToolClient.from_config()
        self.max_turns = max_turns
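
    # Construction modes (illustrative sketch): with no arguments the model is
    # loaded in-process via NexaSciModelClient(lazy_load=True); passing
    # use_remote_model=True talks to a running model server instead, e.g.
    #   AgentController(use_remote_model=True, model_server_url="http://127.0.0.1:8001")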

    def run(self, user_prompt: str) -> AgentRunResult:
        """Execute the agent loop until a final response is produced."""
        messages: List[Message] = [Message(role="user", content=user_prompt)]
        tool_results: List[ToolResult] = []
        for _ in range(self.max_turns):
            response_text = self.llm_client.generate(messages)
            messages.append(Message(role="assistant", content=response_text))
            tool_calls = _extract_tool_calls(response_text)
            if tool_calls:
                # Execute each requested tool, feed its output back as a tool
                # message, then let the model take another turn.
                for call in tool_calls:
                    result = self._dispatch_tool(call)
                    tool_results.append(result)
                    messages.append(
                        Message(
                            role="tool",
                            content=json.dumps(result.output, ensure_ascii=False),
                        )
                    )
                continue
            final_payload = _extract_final_response(response_text)
            if final_payload is not None:
                return AgentRunResult(
                    final_response=final_payload,
                    messages=messages,
                    tool_results=tool_results,
                )
        raise RuntimeError(
            "Agent did not produce a final response within the maximum number of turns."
        )

    def _dispatch_tool(self, call: ToolCall) -> ToolResult:
        """Invoke the requested tool via the ToolClient."""
        return self.tool_client.call_tool(call)


def _extract_tool_calls(text: str) -> List[ToolCall]:
    """Parse tool call JSON payloads embedded in the assistant response."""
    tool_calls: List[ToolCall] = []
    for match in TOOLCALL_REGEX.findall(text):
        snippet = match.strip()
        if not snippet:
            continue
        try:
            payload = json.loads(snippet)
            tool_calls.append(ToolCall(**payload))
        except json.JSONDecodeError:
            continue
    return tool_calls


def _extract_final_response(text: str) -> Dict[str, Any] | None:
    """Parse the final response JSON block from the assistant output."""
    match = FINAL_REGEX.search(text)
    if not match:
        return None
    snippet = match.group(1).strip()
    if not snippet:
        return {}
    return json.loads(snippet)


__all__ = ["AgentController", "AgentRunResult"]