"""Tools for visualising embedding spaces using UMAP.""" from __future__ import annotations import argparse from pathlib import Path from typing import List import matplotlib.pyplot as plt import numpy as np from matplotlib.lines import Line2D from mpl_toolkits.mplot3d import Axes3D # noqa: F401 - needed for 3D projections from umap import UMAP from pipeline.storage import load_corpus, load_embeddings DEFAULT_INDEX_DIR = Path("index") DEFAULT_CORPUS_PATH = DEFAULT_INDEX_DIR / "corpus.json" DEFAULT_EMBEDDINGS_PATH = DEFAULT_INDEX_DIR / "embeddings.npy" def parse_args(argv: List[str] | None = None) -> argparse.Namespace: """Parse command-line options for the visualiser. Parameters ---------- argv: List[str] | None, default None Optional argument list override for testing. Returns ------- argparse.Namespace Parsed CLI arguments. """ parser = argparse.ArgumentParser(description="Visualise SPECTER2 embeddings in 2D or 3D.") parser.add_argument( "--embeddings", type=Path, default=DEFAULT_EMBEDDINGS_PATH, help="Path to embeddings.npy (default: index/embeddings.npy)", ) parser.add_argument( "--corpus", type=Path, default=DEFAULT_CORPUS_PATH, help="Path to corpus.json metadata (default: index/corpus.json)", ) parser.add_argument( "--dims", type=int, choices=(2, 3), default=2, help="Number of UMAP dimensions (2 or 3, default: 2)", ) parser.add_argument( "--output", type=Path, default=None, help="Optional output path for the generated plot (default: derived from embeddings path)", ) parser.add_argument( "--show", action="store_true", help="Display the plot interactively after saving", ) return parser.parse_args(argv) def plot_embeddings( embeddings_path: Path, corpus_path: Path, dims: int = 2, output_path: Path | None = None, show: bool = False, ) -> Path: """Create a UMAP projection and save the resulting plot. Parameters ---------- embeddings_path: Path Location of the `embeddings.npy` file. corpus_path: Path Location of the `corpus.json` file. dims: int, default 2 Number of UMAP dimensions (2 or 3). output_path: Path | None, default None Optional destination for the saved figure. Uses a default if not provided. show: bool, default False Whether to display the plot after saving. Returns ------- Path The path to the saved figure. """ embeddings = load_embeddings(embeddings_path) corpus = load_corpus(corpus_path) if embeddings.shape[0] != len(corpus): raise ValueError( "Embeddings and corpus lengths do not match. Ensure the inputs originate from the same build run." ) reducer = UMAP(n_components=dims, n_neighbors=15, min_dist=0.1, random_state=42) coordinates = reducer.fit_transform(embeddings) categories = [metadata.get("categories", []) for metadata in corpus] primary_labels = [category[0] if category else "unknown" for category in categories] label_to_index = {label: idx for idx, label in enumerate(sorted(set(primary_labels)))} colour_indices = np.array([label_to_index[label] for label in primary_labels]) fig = _create_figure(coordinates, colour_indices, primary_labels, label_to_index, dims) derived_output = embeddings_path.with_name(f"embedding_plot_{dims}d.png") output = output_path or derived_output output.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output, dpi=200, bbox_inches="tight") if show: plt.show() else: plt.close(fig) print(f"Saved {dims}D embedding visualisation to {output}") return output def _create_figure( coordinates: np.ndarray, colour_indices: np.ndarray, labels: List[str], label_to_index: dict[str, int], dims: int, ) -> plt.Figure: """Create a matplotlib figure for the requested dimensionality. Parameters ---------- coordinates: np.ndarray UMAP-reduced coordinates of shape (n_samples, dims). colour_indices: np.ndarray Integer indices representing colour assignments per sample. labels: List[str] Primary category labels aligned with the coordinates. label_to_index: dict[str, int] Mapping from label names to integer colour indices. dims: int Dimensionality of the embedding visualisation (2 or 3). Returns ------- plt.Figure The generated matplotlib figure. """ plt.rcdefaults() fig = plt.figure(figsize=(10, 8)) if dims == 2: ax = fig.add_subplot(111) scatter = ax.scatter( coordinates[:, 0], coordinates[:, 1], c=colour_indices, cmap="tab20", s=20, alpha=0.85, ) ax.set_xlabel("UMAP 1") ax.set_ylabel("UMAP 2") else: ax = fig.add_subplot(111, projection="3d") scatter = ax.scatter( coordinates[:, 0], coordinates[:, 1], coordinates[:, 2], c=colour_indices, cmap="tab20", s=20, alpha=0.85, ) ax.set_xlabel("UMAP 1") ax.set_ylabel("UMAP 2") ax.set_zlabel("UMAP 3") ax.set_title(f"SPECTER2 Embeddings ({dims}D UMAP)") # Build a small legend using the primary labels. unique_labels = sorted(set(labels)) handles = [] for label in unique_labels: colour_value = label_to_index[label] rgba = scatter.cmap(scatter.norm(colour_value)) handle = Line2D([0], [0], marker="o", color="w", label=label, markerfacecolor=rgba, markersize=8) handles.append(handle) if len(handles) <= 12: ax.legend(handles=handles, title="Primary Category", bbox_to_anchor=(1.05, 1), loc="upper left") return fig def main(argv: List[str] | None = None) -> None: """Entry point for the visualisation CLI. Parameters ---------- argv: List[str] | None, default None Optional argument override when invoking programmatically. """ args = parse_args(argv) plot_embeddings( embeddings_path=args.embeddings, corpus_path=args.corpus, dims=args.dims, output_path=args.output, show=args.show, ) if __name__ == "__main__": # pragma: no cover - CLI entry point main()