Nexa_Labs / agent /tool_client.py
Allanatrix's picture
Upload 57 files
d8328bf verified
"""HTTP client for interacting with the NexaSci tool server."""
from __future__ import annotations
import os
from contextlib import AbstractContextManager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar
import httpx
import yaml
from pydantic import BaseModel
from tools.schemas import (
CorpusSearchRequest,
CorpusSearchResponse,
PaperFetchRequest,
PaperFetchResponse,
PaperSearchRequest,
PaperSearchResponse,
PythonRunRequest,
PythonRunResponse,
ToolCall,
ToolResult,
)
# Generic response-model type for ToolClient._post: the caller passes a pydantic
# model class and gets an instance of that same class back.
T = TypeVar("T", bound=BaseModel)
@dataclass(frozen=True)
class ToolClientConfig:
    """Configuration required to initialise the ToolClient.

    Attributes:
        base_url: Root URL of the tool server; endpoint paths are resolved against it.
        timeout_s: Request timeout in seconds handed to the HTTP client.
    """

    base_url: str
    timeout_s: int = 30

    @classmethod
    def from_yaml(cls, path: Path | str = "agent/config.yaml") -> "ToolClientConfig":
        """Load configuration, preferring environment variables over the YAML file.

        If ``TOOL_SERVER_BASE_URL`` is set, the file is not read at all and the
        timeout comes from ``TOOL_SERVER_TIMEOUT`` (default 30). Otherwise the
        ``tool_server`` section of the YAML file at *path* is used, with
        ``base_url`` defaulting to ``http://127.0.0.1:8000`` and
        ``request_timeout_s`` defaulting to 30.

        Raises:
            FileNotFoundError: if no env override is set and *path* does not exist.
            ValueError: if the YAML document or its ``tool_server`` section is not
                a mapping, or a timeout string is not a valid integer.
        """
        env_base_url = os.environ.get("TOOL_SERVER_BASE_URL")
        if env_base_url:
            env_timeout = os.environ.get("TOOL_SERVER_TIMEOUT")
            timeout_value = int(env_timeout) if env_timeout else 30
            return cls(base_url=env_base_url, timeout_s=timeout_value)
        config_path = Path(path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")
        with config_path.open("r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle)
        return cls._from_mapping(data, source=config_path)

    @classmethod
    def _from_mapping(cls, data: object, source: object = "<config>") -> "ToolClientConfig":
        """Build a config from an already-parsed YAML document.

        ``None`` values (empty file, bare ``tool_server:`` key, or ``key: null``)
        are treated as missing so the defaults apply instead of crashing.
        """
        # yaml.safe_load returns None for an empty document.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Configuration in {source} must be a mapping, got {type(data).__name__}"
            )
        tool_cfg = data.get("tool_server")
        # A `tool_server:` key with no body also parses to None.
        if tool_cfg is None:
            tool_cfg = {}
        if not isinstance(tool_cfg, dict):
            raise ValueError(
                f"'tool_server' in {source} must be a mapping, got {type(tool_cfg).__name__}"
            )
        raw_url = tool_cfg.get("base_url")
        raw_timeout = tool_cfg.get("request_timeout_s")
        return cls(
            # Avoid str(None) == "None" when the key is present but null.
            base_url="http://127.0.0.1:8000" if raw_url is None else str(raw_url),
            timeout_s=30 if raw_timeout is None else int(raw_timeout),
        )
class ToolClient(AbstractContextManager["ToolClient"]):
    """Synchronous HTTP client for the NexaSci tool server.

    Each supported tool name maps to a POST endpoint on the server. The client
    owns an ``httpx.Client``; use it as a context manager (or call ``close``)
    to release it.
    """

    # Tool name -> endpoint path, resolved relative to ToolClientConfig.base_url.
    _tool_to_endpoint: Dict[str, str] = {
        "python.run": "/tools/python.run",
        "papers.search": "/tools/papers.search",
        "papers.fetch": "/tools/papers.fetch",
        "papers.search_corpus": "/tools/papers.search_corpus",
    }

    def __init__(self, config: ToolClientConfig) -> None:
        """Create the underlying HTTP client from *config* (base URL and timeout)."""
        self._config = config
        self._client = httpx.Client(base_url=self._config.base_url, timeout=self._config.timeout_s)

    @classmethod
    def from_config(cls, path: Path | str = "agent/config.yaml") -> "ToolClient":
        """Construct a ToolClient by loading configuration via ToolClientConfig.from_yaml."""
        return cls(ToolClientConfig.from_yaml(path))

    def __exit__(self, exc_type, exc, exc_tb) -> None:  # type: ignore[override]
        # Always close the HTTP client; returning None lets any exception propagate.
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def call_tool(self, call: ToolCall) -> ToolResult:
        """Invoke a tool using the generic tool call schema.

        ``call.arguments`` is sent as the JSON request body. Runtime failures
        are returned as ``ToolResult.failed`` instead of being raised, so the
        caller can feed them back to the agent. These failures are:

        * a transport error (connection failure, timeout, ...);
        * a non-2xx status;
        * a 2xx response whose body is not valid JSON.

        Raises:
            ValueError: if ``call.tool`` is not a known tool name.
        """
        endpoint = self._resolve_endpoint(call.tool)
        try:
            response = self._client.post(endpoint, json=call.arguments)
        except httpx.RequestError as exc:
            # Previously escaped as an exception, unlike HTTP-level failures below.
            return ToolResult.failed(
                call.tool, f"Tool invocation failed: {type(exc).__name__}: {exc}"
            )
        if not response.is_success:
            return ToolResult.failed(call.tool, f"Tool invocation failed: {response.text}")
        try:
            payload = response.json()
        except ValueError as exc:
            # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError.
            return ToolResult.failed(call.tool, f"Tool returned invalid JSON: {exc}")
        return ToolResult.ok(call.tool, payload)

    def python_run(self, request: PythonRunRequest) -> PythonRunResponse:
        """Execute code snippets inside the Python sandbox."""
        return self._post("python.run", request, PythonRunResponse)

    def papers_search(self, request: PaperSearchRequest) -> PaperSearchResponse:
        """Search the arXiv API via the tool server."""
        return self._post("papers.search", request, PaperSearchResponse)

    def papers_fetch(self, request: PaperFetchRequest) -> PaperFetchResponse:
        """Fetch a single paper's metadata from arXiv."""
        return self._post("papers.fetch", request, PaperFetchResponse)

    def papers_search_corpus(self, request: CorpusSearchRequest) -> CorpusSearchResponse:
        """Search the local SPECTER2 corpus."""
        return self._post("papers.search_corpus", request, CorpusSearchResponse)

    def _post(self, tool: str, model: BaseModel, response_model: Type[T]) -> T:
        """POST *model* to *tool*'s endpoint and parse the JSON reply as *response_model*.

        Unlike ``call_tool``, errors are raised: ``raise_for_status`` raises on
        non-2xx responses.
        """
        endpoint = self._resolve_endpoint(tool)
        # NOTE(review): .dict()/.parse_obj are the pydantic v1 API (deprecated in
        # v2 in favour of model_dump/model_validate) — confirm the installed version.
        response = self._client.post(endpoint, json=model.dict())
        response.raise_for_status()
        return response_model.parse_obj(response.json())

    def _resolve_endpoint(self, tool: str) -> str:
        """Map a tool name to its endpoint path, raising ValueError if unknown."""
        try:
            return self._tool_to_endpoint[tool]
        except KeyError as exc:
            raise ValueError(f"Unknown tool: {tool}") from exc