"""Agent controller orchestrating the LLM ↔ tool server interaction loop.""" from __future__ import annotations import json import re from dataclasses import dataclass, field from typing import Any, Dict, List, Sequence from tools.schemas import ToolCall, ToolResult from .client_llm import Message, NexaSciModelClient from .client_llm_remote import RemoteNexaSciClient from .tool_client import ToolClient TOOLCALL_REGEX = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL) FINAL_REGEX = re.compile(r"~~~final(.*?)~~~", re.DOTALL) @dataclass class AgentRunResult: """Container describing the outcome of an agent run.""" final_response: Dict[str, Any] messages: Sequence[Message] tool_results: Sequence[ToolResult] = field(default_factory=list) def pretty(self) -> str: """Return a readable JSON representation of the final response.""" return json.dumps(self.final_response, indent=2) class AgentController: """Core agent loop handling tool invocation and final response parsing.""" def __init__( self, llm_client: NexaSciModelClient | RemoteNexaSciClient | None = None, tool_client: ToolClient | None = None, *, max_turns: int = 8, use_remote_model: bool = False, model_server_url: str = "http://127.0.0.1:8001", ) -> None: """Initialize the agent controller. Parameters ---------- llm_client: Optional LLM client. If None, will create one based on use_remote_model. tool_client: Optional tool client. If None, will create from config. max_turns: Maximum number of agent turns. use_remote_model: If True, connect to remote model server instead of loading locally. model_server_url: URL of the model server (if use_remote_model is True). """ if llm_client is None: if use_remote_model: llm_client = RemoteNexaSciClient(base_url=model_server_url) else: llm_client = NexaSciModelClient(lazy_load=True) self.llm_client = llm_client self.tool_client = tool_client or ToolClient.from_config() self.max_turns = max_turns def run(self, user_prompt: str) -> AgentRunResult: """Execute the agent loop until a final response is produced.""" messages: List[Message] = [Message(role="user", content=user_prompt)] tool_results: List[ToolResult] = [] for _ in range(self.max_turns): response_text = self.llm_client.generate(messages) messages.append(Message(role="assistant", content=response_text)) tool_calls = _extract_tool_calls(response_text) if tool_calls: for call in tool_calls: result = self._dispatch_tool(call) tool_results.append(result) messages.append( Message( role="tool", content=json.dumps(result.output, ensure_ascii=False), ) ) continue final_payload = _extract_final_response(response_text) if final_payload is not None: return AgentRunResult(final_response=final_payload, messages=messages, tool_results=tool_results) raise RuntimeError("Agent did not produce a final response within the maximum number of turns.") def _dispatch_tool(self, call: ToolCall) -> ToolResult: """Invoke the requested tool via the ToolClient.""" return self.tool_client.call_tool(call) def _extract_tool_calls(text: str) -> List[ToolCall]: """Parse tool call JSON payloads embedded in the assistant response.""" tool_calls: List[ToolCall] = [] for match in TOOLCALL_REGEX.findall(text): snippet = match.strip() if not snippet: continue try: payload = json.loads(snippet) tool_calls.append(ToolCall(**payload)) except json.JSONDecodeError: continue return tool_calls def _extract_final_response(text: str) -> Dict[str, Any] | None: """Parse the final response JSON block from the assistant output.""" match = FINAL_REGEX.search(text) if not match: return None snippet = match.group(1).strip() if not snippet: return {} return json.loads(snippet) __all__ = ["AgentController", "AgentRunResult"]