"""Pydantic schemas shared between the NexaSci tool server and agent."""

from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence

from pydantic import BaseModel, Field, HttpUrl, root_validator, validator


def _is_blank(value: Any) -> bool:
    """Return True when *value* is falsy or a whitespace-only string.

    Shared by the request validators in this module so that "empty" means the
    same thing for search queries and for paper identifiers.
    """
    if not value:
        return True
    return isinstance(value, str) and not value.strip()


class ToolCall(BaseModel):
    """Represents a tool invocation emitted by the language model."""

    tool: str = Field(..., description="Identifier of the tool to invoke.")
    arguments: Dict[str, Any] = Field(
        default_factory=dict,
        description="JSON-serializable arguments for the tool.",
    )


class ToolResult(BaseModel):
    """Standard response payload returned to the agent after tool execution.

    The ``ok`` and ``failed`` constructors build the two consistent shapes of
    this payload: a success with an output and no error, or a failure with an
    error message and an empty output.
    """

    tool: str = Field(..., description="Identifier of the tool that produced this result.")
    success: bool = Field(..., description="Whether the tool completed without error.")
    output: Dict[str, Any] = Field(
        default_factory=dict,
        description="Structured result payload returned by the tool.",
    )
    error: Optional[str] = Field(
        default=None,
        description="Human-readable error message when success is False.",
    )

    @classmethod
    def ok(cls, tool: str, output: Dict[str, Any]) -> "ToolResult":
        """Construct a successful tool response.

        Args:
            tool: Identifier of the tool that produced the result.
            output: Structured result payload to return to the agent.

        Returns:
            A ``ToolResult`` with ``success=True`` and ``error=None``.
        """
        return cls(tool=tool, success=True, output=output, error=None)

    @classmethod
    def failed(cls, tool: str, error: str) -> "ToolResult":
        """Construct an error response for tool execution failures.

        Args:
            tool: Identifier of the tool that failed.
            error: Human-readable description of the failure.

        Returns:
            A ``ToolResult`` with ``success=False`` and an empty ``output``.
        """
        return cls(tool=tool, success=False, output={}, error=error)


class PythonRunRequest(BaseModel):
    """Request model for the sandboxed Python execution tool."""

    code: str = Field(..., description="Python source code to execute.")
    # Constrained to 1..60 seconds inclusive; defaults to 10.
    timeout_s: int = Field(10, ge=1, le=60, description="Execution timeout in seconds.")


class ArtifactRecord(BaseModel):
    """Metadata describing an artifact file produced by the sandbox."""

    name: str
    path: str
    # Optional; None when no MIME type was supplied.
    mime_type: Optional[str] = None


class PythonRunResponse(BaseModel):
    """Response payload returned by the python.run tool."""

    stdout: str = ""
    stderr: str = ""
    artifacts: List[ArtifactRecord] = Field(default_factory=list)


class PaperSearchRequest(BaseModel):
    """Request payload for papers.search endpoint."""

    query: str = Field(..., description="Search query string.")
    top_k: int = Field(10, ge=1, le=50, description="Number of results to return.")

    @validator("query")
    def validate_query(cls, value: str) -> str:
        """Reject queries that are empty or contain only whitespace.

        The value is returned unchanged (not stripped) when it is accepted.

        Raises:
            ValueError: If the query is blank.
        """
        if _is_blank(value):
            raise ValueError("Query must not be empty.")
        return value


class PaperMetadata(BaseModel):
    """Normalized metadata for a single scientific paper."""

    # ``title`` is the only required field; everything else is optional.
    title: str
    abstract: Optional[str] = None
    authors: Sequence[str] = Field(default_factory=list)
    doi: Optional[str] = None
    arxiv_id: Optional[str] = None
    published: Optional[datetime] = None
    primary_category: Optional[str] = None
    url: Optional[HttpUrl] = None
    source: Optional[str] = Field(default=None, description="Originating corpus or API.")


class PaperSearchResponse(BaseModel):
    """Response payload for papers.search endpoint."""

    results: List[PaperMetadata] = Field(default_factory=list)


class PaperFetchRequest(BaseModel):
    """Request payload for retrieving full metadata for a given paper."""

    doi: Optional[str] = Field(default=None, description="Digital object identifier for the paper.")
    arxiv_id: Optional[str] = Field(default=None, description="arXiv identifier for the paper.")

    @root_validator(pre=True)
    def ensure_identifier(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure that at least one usable identifier is provided.

        Runs before field validation (``pre=True``), so it inspects the raw
        input values. An identifier that is missing, empty, or whitespace-only
        does not count. This matches how the query validators treat blank
        strings.

        Raises:
            ValueError: If neither ``doi`` nor ``arxiv_id`` is usable.
        """
        if _is_blank(values.get("doi")) and _is_blank(values.get("arxiv_id")):
            raise ValueError("Either doi or arxiv_id must be provided.")
        return values


class PaperFetchResponse(BaseModel):
    """Response payload for the papers.fetch endpoint."""

    paper: PaperMetadata


class CorpusSearchRequest(BaseModel):
    """Request payload for performing semantic search over the local corpus."""

    query: str
    top_k: int = Field(5, ge=1, le=50)

    @validator("query")
    def validate_query(cls, value: str) -> str:
        """Reject queries that are empty or contain only whitespace.

        The value is returned unchanged (not stripped) when it is accepted.

        Raises:
            ValueError: If the query is blank.
        """
        if _is_blank(value):
            raise ValueError("Query must not be empty.")
        return value


class CorpusDocument(BaseModel):
    """Single document entry returned from the local corpus search."""

    title: str
    paper_id: Optional[str] = None
    # Required; ranking score reported by the corpus search.
    # NOTE(review): scale/direction is not defined in this module — confirm with the search backend.
    score: float
    abstract: Optional[str] = None
    cluster_id: Optional[int] = None
    url: Optional[HttpUrl] = None


class CorpusSearchResponse(BaseModel):
    """Response payload for the papers.search_corpus endpoint."""

    results: List[CorpusDocument] = Field(default_factory=list)


__all__ = [
    "ArtifactRecord",
    "CorpusDocument",
    "CorpusSearchRequest",
    "CorpusSearchResponse",
    "PaperFetchRequest",
    "PaperFetchResponse",
    "PaperMetadata",
    "PaperSearchRequest",
    "PaperSearchResponse",
    "PythonRunRequest",
    "PythonRunResponse",
    "ToolCall",
    "ToolResult",
]