"""HTTP client for interacting with the NexaSci tool server.""" from __future__ import annotations import os from contextlib import AbstractContextManager from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional, Type, TypeVar import httpx import yaml from pydantic import BaseModel from tools.schemas import ( CorpusSearchRequest, CorpusSearchResponse, PaperFetchRequest, PaperFetchResponse, PaperSearchRequest, PaperSearchResponse, PythonRunRequest, PythonRunResponse, ToolCall, ToolResult, ) T = TypeVar("T", bound=BaseModel) @dataclass(frozen=True) class ToolClientConfig: """Configuration required to initialise the ToolClient.""" base_url: str timeout_s: int = 30 @classmethod def from_yaml(cls, path: Path | str = "agent/config.yaml") -> "ToolClientConfig": """Load configuration parameters from the shared YAML file.""" env_base_url = os.environ.get("TOOL_SERVER_BASE_URL") env_timeout = os.environ.get("TOOL_SERVER_TIMEOUT") if env_base_url: timeout_value = int(env_timeout) if env_timeout else 30 return cls(base_url=env_base_url, timeout_s=timeout_value) config_path = Path(path) if not config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {config_path}") with config_path.open("r", encoding="utf-8") as handle: data = yaml.safe_load(handle) tool_cfg = data.get("tool_server", {}) return cls( base_url=str(tool_cfg.get("base_url", "http://127.0.0.1:8000")), timeout_s=int(tool_cfg.get("request_timeout_s", 30)), ) class ToolClient(AbstractContextManager["ToolClient"]): """Synchronous HTTP client for the NexaSci tool server.""" _tool_to_endpoint: Dict[str, str] = { "python.run": "/tools/python.run", "papers.search": "/tools/papers.search", "papers.fetch": "/tools/papers.fetch", "papers.search_corpus": "/tools/papers.search_corpus", } def __init__(self, config: ToolClientConfig) -> None: self._config = config self._client = httpx.Client(base_url=self._config.base_url, timeout=self._config.timeout_s) @classmethod def from_config(cls, path: Path | str = "agent/config.yaml") -> "ToolClient": """Construct a ToolClient by loading configuration from disk.""" return cls(ToolClientConfig.from_yaml(path)) def __exit__(self, exc_type, exc, exc_tb) -> None: # type: ignore[override] self.close() def close(self) -> None: """Close the underlying HTTP client.""" self._client.close() def call_tool(self, call: ToolCall) -> ToolResult: """Invoke a tool using the generic tool call schema.""" endpoint = self._resolve_endpoint(call.tool) response = self._client.post(endpoint, json=call.arguments) if not response.is_success: return ToolResult.failed(call.tool, f"Tool invocation failed: {response.text}") payload = response.json() return ToolResult.ok(call.tool, payload) def python_run(self, request: PythonRunRequest) -> PythonRunResponse: """Execute code snippets inside the Python sandbox.""" return self._post("python.run", request, PythonRunResponse) def papers_search(self, request: PaperSearchRequest) -> PaperSearchResponse: """Search the arXiv API via the tool server.""" return self._post("papers.search", request, PaperSearchResponse) def papers_fetch(self, request: PaperFetchRequest) -> PaperFetchResponse: """Fetch a single paper's metadata from arXiv.""" return self._post("papers.fetch", request, PaperFetchResponse) def papers_search_corpus(self, request: CorpusSearchRequest) -> CorpusSearchResponse: """Search the local SPECTER2 corpus.""" return self._post("papers.search_corpus", request, CorpusSearchResponse) def _post(self, tool: str, model: BaseModel, response_model: Type[T]) -> T: endpoint = self._resolve_endpoint(tool) response = self._client.post(endpoint, json=model.dict()) response.raise_for_status() return response_model.parse_obj(response.json()) def _resolve_endpoint(self, tool: str) -> str: try: return self._tool_to_endpoint[tool] except KeyError as exc: raise ValueError(f"Unknown tool: {tool}") from exc