Spaces:
Paused
Paused
| """HTTP client for interacting with the NexaSci tool server.""" | |
| from __future__ import annotations | |
| import os | |
| from contextlib import AbstractContextManager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Optional, Type, TypeVar | |
| import httpx | |
| import yaml | |
| from pydantic import BaseModel | |
| from tools.schemas import ( | |
| CorpusSearchRequest, | |
| CorpusSearchResponse, | |
| PaperFetchRequest, | |
| PaperFetchResponse, | |
| PaperSearchRequest, | |
| PaperSearchResponse, | |
| PythonRunRequest, | |
| PythonRunResponse, | |
| ToolCall, | |
| ToolResult, | |
| ) | |
# Generic response-model type: ``ToolClient._post`` is annotated to take a
# ``Type[T]`` and return a ``T``, so callers get back the same concrete
# pydantic model class they asked for.
T = TypeVar("T", bound=BaseModel)
@dataclass
class ToolClientConfig:
    """Configuration required to initialise the ToolClient.

    Attributes:
        base_url: Base URL of the tool server, passed to ``httpx.Client``.
        timeout_s: Request timeout in seconds (defaults to 30).
    """

    base_url: str
    timeout_s: int = 30

    @classmethod
    def from_yaml(cls, path: Path | str = "agent/config.yaml") -> "ToolClientConfig":
        """Load configuration, preferring environment variables over the YAML file.

        If ``TOOL_SERVER_BASE_URL`` is set to a non-empty value, the YAML file is
        not consulted at all. The base URL comes from that variable, and the
        timeout comes from ``TOOL_SERVER_TIMEOUT`` (default 30).

        Otherwise the ``tool_server`` section of the YAML file at ``path``
        supplies ``base_url`` (default ``http://127.0.0.1:8000``) and
        ``request_timeout_s`` (default 30). An empty file, or an empty
        ``tool_server`` section, yields those defaults.

        Args:
            path: Location of the YAML configuration file.

        Returns:
            A populated ``ToolClientConfig``.

        Raises:
            FileNotFoundError: If no env override is set and ``path`` is missing.
            ValueError: If a timeout value is not an integer, or the file's top
                level / ``tool_server`` section is not a mapping.
        """
        env_base_url = os.environ.get("TOOL_SERVER_BASE_URL")
        if env_base_url:
            env_timeout = os.environ.get("TOOL_SERVER_TIMEOUT")
            timeout_value = int(env_timeout) if env_timeout else 30
            return cls(base_url=env_base_url, timeout_s=timeout_value)

        config_path = Path(path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")
        with config_path.open("r", encoding="utf-8") as handle:
            # safe_load never constructs arbitrary Python objects from the file.
            data = yaml.safe_load(handle)

        # An empty document loads as None; treat it as "no settings".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(f"Expected a mapping at the top level of {config_path}")

        # A bare ``tool_server:`` key also loads as None.
        tool_cfg = data.get("tool_server")
        if tool_cfg is None:
            tool_cfg = {}
        if not isinstance(tool_cfg, dict):
            raise ValueError(f"Expected 'tool_server' in {config_path} to be a mapping")

        return cls(
            base_url=str(tool_cfg.get("base_url", "http://127.0.0.1:8000")),
            timeout_s=int(tool_cfg.get("request_timeout_s", 30)),
        )
class ToolClient(AbstractContextManager["ToolClient"]):
    """Synchronous HTTP client for the NexaSci tool server.

    Wraps a single ``httpx.Client`` bound to ``config.base_url``. Instances can
    be used in a ``with`` block; leaving the block closes the HTTP client.
    """

    # Public tool name -> POST endpoint path (relative to the configured base_url).
    _tool_to_endpoint: Dict[str, str] = {
        "python.run": "/tools/python.run",
        "papers.search": "/tools/papers.search",
        "papers.fetch": "/tools/papers.fetch",
        "papers.search_corpus": "/tools/papers.search_corpus",
    }

    def __init__(self, config: ToolClientConfig) -> None:
        """Create the underlying HTTP client from ``config``'s URL and timeout."""
        self._config = config
        self._client = httpx.Client(base_url=self._config.base_url, timeout=self._config.timeout_s)

    @classmethod
    def from_config(cls, path: Path | str = "agent/config.yaml") -> "ToolClient":
        """Construct a ToolClient by loading configuration via ``ToolClientConfig.from_yaml``."""
        return cls(ToolClientConfig.from_yaml(path))

    def __exit__(self, exc_type, exc, exc_tb) -> None:  # type: ignore[override]
        """Close the HTTP client; returning None lets any in-flight exception propagate."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def call_tool(self, call: ToolCall) -> ToolResult:
        """Invoke a tool using the generic tool call schema.

        ``call.arguments`` is sent as the JSON body to the tool's endpoint.

        Returns:
            ``ToolResult.ok`` with the decoded JSON payload on success. Returns
            ``ToolResult.failed`` in three cases: a non-2xx status, a transport
            error (connection failure, timeout, ...), or a success response
            whose body is not valid JSON.

        Raises:
            ValueError: If ``call.tool`` is not a known tool name.
        """
        endpoint = self._resolve_endpoint(call.tool)
        try:
            response = self._client.post(endpoint, json=call.arguments)
        except httpx.RequestError as exc:
            # Transport-level failures are reported like HTTP failures so an
            # agent loop driving this method gets a result instead of a crash.
            return ToolResult.failed(
                call.tool, f"Tool invocation failed: {type(exc).__name__}: {exc}"
            )
        if not response.is_success:
            return ToolResult.failed(call.tool, f"Tool invocation failed: {response.text}")
        try:
            payload = response.json()
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            return ToolResult.failed(call.tool, f"Tool invocation failed: invalid JSON response: {exc}")
        return ToolResult.ok(call.tool, payload)

    def python_run(self, request: PythonRunRequest) -> PythonRunResponse:
        """Execute code snippets inside the Python sandbox."""
        return self._post("python.run", request, PythonRunResponse)

    def papers_search(self, request: PaperSearchRequest) -> PaperSearchResponse:
        """Search the arXiv API via the tool server."""
        return self._post("papers.search", request, PaperSearchResponse)

    def papers_fetch(self, request: PaperFetchRequest) -> PaperFetchResponse:
        """Fetch a single paper's metadata from arXiv."""
        return self._post("papers.fetch", request, PaperFetchResponse)

    def papers_search_corpus(self, request: CorpusSearchRequest) -> CorpusSearchResponse:
        """Search the local SPECTER2 corpus."""
        return self._post("papers.search_corpus", request, CorpusSearchResponse)

    def _post(self, tool: str, model: BaseModel, response_model: Type[T]) -> T:
        """POST ``model`` to ``tool``'s endpoint and parse the reply into ``response_model``.

        Raises:
            ValueError: If ``tool`` is not a known tool name.
            httpx.HTTPStatusError: If the server responds with a non-2xx status.
        """
        endpoint = self._resolve_endpoint(tool)
        # NOTE(review): ``.dict()`` / ``.parse_obj()`` are the pydantic v1 names
        # (deprecated aliases in v2) -- confirm the pinned pydantic version.
        response = self._client.post(endpoint, json=model.dict())
        response.raise_for_status()
        return response_model.parse_obj(response.json())

    def _resolve_endpoint(self, tool: str) -> str:
        """Map a tool name to its endpoint path, raising ValueError if unknown."""
        try:
            return self._tool_to_endpoint[tool]
        except KeyError as exc:
            raise ValueError(f"Unknown tool: {tool}") from exc