| """ Interactive Gradio UI for exploring the local SPECTER2 corpus.""" | |
| from __future__ import annotations | |
| from collections import Counter, defaultdict | |
| import subprocess | |
| import sys | |
| import time | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Sequence, Set, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import matplotlib.pyplot as plt | |
| from matplotlib.collections import LineCollection | |
| from matplotlib.colors import to_rgba | |
| from matplotlib.figure import Figure | |
| FULLSCREEN_JS = """ | |
| () => { | |
| const container = document.getElementById('embedding-plot'); | |
| if (!container) return; | |
| const plot = container.querySelector('.js-plotly-plot') || container; | |
| if (!document.fullscreenElement) { | |
| if (plot.requestFullscreen) { | |
| plot.requestFullscreen(); | |
| } else if (plot.webkitRequestFullscreen) { | |
| plot.webkitRequestFullscreen(); | |
| } | |
| } else { | |
| if (document.exitFullscreen) { | |
| document.exitFullscreen(); | |
| } else if (document.webkitExitFullscreen) { | |
| document.webkitExitFullscreen(); | |
| } | |
| } | |
| } | |
| """ | |
| ORBIT_JS = """ | |
| () => { | |
| const container = document.getElementById('embedding-plot'); | |
| if (!container) return; | |
| const plot = container.querySelector('.js-plotly-plot'); | |
| if (!plot) return; | |
| window._plotOrbitIntervals = window._plotOrbitIntervals || {}; | |
| const key = 'embedding-plot'; | |
| if (window._plotOrbitIntervals[key]) { | |
| clearInterval(window._plotOrbitIntervals[key]); | |
| delete window._plotOrbitIntervals[key]; | |
| return; | |
| } | |
| let angle = 0; | |
| const radius = 1.6; | |
| window._plotOrbitIntervals[key] = setInterval(() => { | |
| const updatedPlot = container.querySelector('.js-plotly-plot'); | |
| if (!updatedPlot) { | |
| clearInterval(window._plotOrbitIntervals[key]); | |
| delete window._plotOrbitIntervals[key]; | |
| return; | |
| } | |
| angle = (angle + 2) % 360; | |
| const rad = angle * Math.PI / 180; | |
| Plotly.relayout(updatedPlot, { | |
| 'scene.camera.eye': { | |
| x: radius * Math.cos(rad), | |
| y: radius * Math.sin(rad), | |
| z: 0.9, | |
| }, | |
| }); | |
| }, 50); | |
| } | |
| """ | |
| CUSTOM_JS = """ | |
| function(componentId, action) { | |
| const el = document.getElementById(componentId); | |
| if (!el) return; | |
| if (action === "orbit") { | |
| if (window._orbitIntervals === undefined) { | |
| window._orbitIntervals = {}; | |
| } | |
| if (window._orbitIntervals[componentId]) { | |
| clearInterval(window._orbitIntervals[componentId]); | |
| delete window._orbitIntervals[componentId]; | |
| } else { | |
| let angle = 0; | |
| const interval = setInterval(() => { | |
| angle = (angle + 2) % 360; | |
| const rad = angle * Math.PI / 180; | |
| const r = 1.6; | |
| const layout = { | |
| scene: {camera: {eye: {x: r * Math.cos(rad), y: r * Math.sin(rad), z: 0.9}}} | |
| }; | |
| Plotly.relayout(el, layout); | |
| }, 50); | |
| window._orbitIntervals[componentId] = interval; | |
| } | |
| } else if (action === "fullscreen") { | |
| const container = el.closest("div.svelte-1ipelgc"); | |
| const target = container || el; | |
| if (!document.fullscreenElement) { | |
| target.requestFullscreen?.(); | |
| } else { | |
| document.exitFullscreen?.(); | |
| } | |
| } | |
| } | |
| """ | |

from pipeline.embed import Specter2Embedder
from pipeline.storage import load_embeddings, load_canonical_corpus

INDEX_DIR = Path(__file__).resolve().parents[1] / "index"
CORPUS_PATH = INDEX_DIR / "corpus.json"
EMBEDDINGS_PATH = INDEX_DIR / "embeddings.npy"

DEFAULT_COLOR_BASIS = "Cluster"
DEFAULT_PALETTE = "Plotly"

COLOR_BASIS_OPTIONS: Dict[str, str] = {
    "Cluster": "cluster",
    "Primary Category": "primary_category",
}

PALETTE_OPTIONS: Dict[str, List[str]] = {
    "Plotly": px.colors.qualitative.Plotly,
    "Bold": px.colors.qualitative.Bold,
    "Vivid": px.colors.qualitative.Vivid,
    "Pastel": px.colors.qualitative.Pastel,
    "Safe": px.colors.qualitative.Safe,
}

MAX_EDGE_RENDER = 2000
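# MAX_EDGE_RENDER caps how many corpus-graph edges are drawn per figure; edges
# beyond this limit are skipped so dense graphs do not dominate render time.
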

def _float_rgba_to_plotly(rgba: Tuple[float, float, float, float], alpha: float | None = None) -> str:
    r, g, b, a = rgba
    if alpha is not None:
        a = alpha
    return f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a:.2f})"


def _build_cluster_color_map(cluster_ids: Sequence[int], palette: Sequence[Tuple[float, float, float, float]]) -> Dict[int, Tuple[float, float, float, float]]:
    unique_ids = sorted(set(int(cid) for cid in cluster_ids))
    color_map: Dict[int, Tuple[float, float, float, float]] = {}
    for idx, cluster_id in enumerate(unique_ids):
        color_map[cluster_id] = palette[idx % len(palette)]
    return color_map


def _build_cluster_overview(papers: Sequence[Dict[str, Any]]) -> pd.DataFrame:
    clusters: Dict[int, Dict[str, Any]] = defaultdict(lambda: {
        "cluster_id": None,
        "size": 0,
        "categories": Counter(),
        "sample_titles": [],
    })
    for paper in papers:
        cluster_id = int(paper.get("cluster_id", -1))
        entry = clusters[cluster_id]
        entry["cluster_id"] = cluster_id
        entry["size"] += 1
        category = paper.get("primary_category") or "unknown"
        entry["categories"][category] += 1
        if len(entry["sample_titles"]) < 3:
            entry["sample_titles"].append(paper.get("title", "(untitled)"))
    overview_rows = []
    for data in clusters.values():
        dominant_category = data["categories"].most_common(1)[0][0] if data["categories"] else "unknown"
        # Derive the major category from the dominant one so it describes the
        # whole cluster rather than whichever paper happened to be seen last.
        major_category = dominant_category.split(".")[0] if "." in dominant_category else dominant_category
        overview_rows.append(
            {
                "cluster_id": data["cluster_id"],
                "size": data["size"],
                "major_category": major_category,
                "dominant_category": dominant_category,
                "sample_titles": " | ".join(data["sample_titles"]),
            }
        )
    overview_rows.sort(key=lambda row: row["cluster_id"])
    return pd.DataFrame(overview_rows)


def _build_cluster_hierarchy_json(papers: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    hierarchy: Dict[str, Dict[str, List[Dict[str, Any]]]] = defaultdict(lambda: defaultdict(list))
    for paper in papers:
        cluster_id = int(paper.get("cluster_id", -1))
        category = paper.get("primary_category") or "unknown"
        major = category.split(".")[0] if "." in category else category
        hierarchy[major][category].append(
            {
                "cluster_id": cluster_id,
                "paper_id": paper.get("paper_id"),
                "title": paper.get("title"),
            }
        )
    major_payload = []
    for major, subcategories in hierarchy.items():
        sub_payload = []
        for category, clusters in sorted(subcategories.items()):
            clusters_sorted = sorted(clusters, key=lambda c: c["cluster_id"])
            sub_payload.append({
                "category": category,
                "clusters": clusters_sorted,
                "cluster_ids": sorted({entry["cluster_id"] for entry in clusters_sorted}),
            })
        major_payload.append({
            "major": major,
            "subcategories": sub_payload,
        })
    major_payload.sort(key=lambda entry: entry["major"])
    return {"major_categories": major_payload}


def _filter_edges(edges: Sequence[Dict[str, Any]], selected: Set[int]) -> List[Dict[str, Any]]:
    """Return only edges whose endpoints are in the selected set."""
    return [
        edge
        for edge in edges
        if int(edge.get("source", -1)) in selected and int(edge.get("target", -1)) in selected
    ]


def _normalise_embeddings(vectors: np.ndarray) -> np.ndarray:
    """Return L2-normalised embeddings, guarding against zero vectors."""
    if vectors.size == 0:
        return vectors
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return vectors / norms


@lru_cache(maxsize=1)
def load_resources() -> Tuple[
    Dict[str, Any],
    List[Dict[str, Any]],
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    List[Dict[str, Any]],
    List[Dict[str, Any]],
]:
    """Load canonical corpus data, embeddings, and graph metadata from disk.

    Cached so the corpus is read only once per process; refresh_embedding_plot()
    calls load_resources.cache_clear() after a rebuild.
    """
    if not CORPUS_PATH.exists() or not EMBEDDINGS_PATH.exists():
        raise FileNotFoundError(
            "Corpus artifacts not found. Run `python -m pipeline.build_corpus` first."
        )
    corpus_doc = load_canonical_corpus(CORPUS_PATH)
    papers = corpus_doc.get("papers", [])
    embeddings = load_embeddings(EMBEDDINGS_PATH)
    if embeddings.shape[0] != len(papers):
        raise ValueError(
            "Mismatch between embeddings and canonical corpus entries. Rebuild the corpus to continue."
        )
    papers_sorted = sorted(papers, key=lambda entry: entry.get("embedding_idx", 0))
    if not all(paper.get("embedding_idx") == idx for idx, paper in enumerate(papers_sorted)):
        raise ValueError("Embedding indices in canonical corpus do not match their positions; rebuild the corpus.")
    umap_2d = np.array([paper.get("umap_2d", [0.0, 0.0]) for paper in papers_sorted], dtype=np.float32)
    umap_3d = np.array([paper.get("umap_3d", [0.0, 0.0, 0.0]) for paper in papers_sorted], dtype=np.float32)
    normalised = _normalise_embeddings(embeddings.astype(np.float32))
    graph_edges = corpus_doc.get("graph", {}).get("edges", [])
    cluster_metadata = corpus_doc.get("clusters", [])
    return (
        corpus_doc,
        papers_sorted,
        embeddings,
        normalised,
        umap_2d,
        umap_3d,
        graph_edges,
        cluster_metadata,
    )


@lru_cache(maxsize=1)
def get_embedder(device: str | None = None) -> Specter2Embedder:
    """Instantiate the Specter2 embedder once and reuse it across queries."""
    return Specter2Embedder(device=device)


def _cluster_options() -> List[str]:
    """Return the cluster dropdown options (All + IDs)."""
    (_, papers, *_rest) = load_resources()
    cluster_ids = sorted({int(paper.get("cluster_id", 0)) for paper in papers})
    return ["All"] + [str(cluster_id) for cluster_id in cluster_ids]


def _resolve_color_basis(choice: str) -> str:
    return COLOR_BASIS_OPTIONS.get(choice, COLOR_BASIS_OPTIONS[DEFAULT_COLOR_BASIS])


def _resolve_palette(choice: str) -> List[Tuple[float, float, float, float]]:
    palette = PALETTE_OPTIONS.get(choice, PALETTE_OPTIONS[DEFAULT_PALETTE])
    resolved: List[Tuple[float, float, float, float]] = []
    for color in palette:
        try:
            resolved.append(to_rgba(color))
        except ValueError:
            if color.startswith("rgb"):
                parts = color[color.find("(") + 1 : color.find(")")].split(",")
                floats = tuple(float(part.strip()) / 255.0 for part in parts)
                resolved.append((*floats, 1.0))
            else:
                raise
    if not resolved:
        resolved.append((0.2, 0.4, 0.8, 1.0))
    return resolved


def _hover_text_for_papers(papers: Sequence[Dict[str, Any]]) -> np.ndarray:
    """Generate hover text for each paper."""
    hover = []
    for paper in papers:
        hover.append(
            "<br>".join(
                [
                    paper.get("title", "(untitled)"),
                    f"ID: {paper.get('paper_id', 'n/a')}",
                    f"Cluster: {paper.get('cluster_id', 'n/a')}",
                    f"Category: {paper.get('primary_category', 'unknown')}",
                    f"Authors: {', '.join(paper.get('authors', [])[:3])}" + ("…" if len(paper.get('authors', [])) > 3 else ""),
                ]
            )
        )
    return np.array(hover)


def _group_points(labels: np.ndarray, palette: Sequence[str]) -> List[Tuple[str, np.ndarray, str]]:
    """Return masking information for each unique label."""
    unique = sorted(np.unique(labels))
    groups: List[Tuple[str, np.ndarray, str]] = []
    for idx, label in enumerate(unique):
        mask = labels == label
        color = palette[idx % len(palette)]
        groups.append((label, mask, color))
    return groups


def _build_2d_plot(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    point_color_map: Dict[str, Tuple[float, float, float, float]],
    cluster_color_map: Dict[int, Tuple[float, float, float, float]],
) -> plt.Figure:
    fig, ax = plt.subplots(figsize=(6.8, 6.2), dpi=120)
    if coords.shape[0] < 1:
        ax.set_title("Corpus Embedding Map (2D)")
        ax.axis("off")
        return fig
    label_order = sorted(set(labels))
    for label in label_order:
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label)
        if rgba is None:
            rgba = (0.25, 0.5, 0.85, 1.0)
        ax.scatter(
            coords[mask, 0],
            coords[mask, 1],
            s=26,
            c=[rgba],
            alpha=0.9,
            linewidths=0.3,
            edgecolors="#f5f5f5",
            label=label,
        )
    if edges:
        index_map = {orig_idx: pos for pos, orig_idx in enumerate(original_indices)}
        segment_map: Dict[int, List[List[Tuple[float, float]]]] = defaultdict(list)
        for edge in edges[:MAX_EDGE_RENDER]:
            source = int(edge["source"])
            target = int(edge["target"])
            if source not in index_map or target not in index_map:
                continue
            src_idx = index_map[source]
            tgt_idx = index_map[target]
            cluster_id = int(cluster_ids_subset[src_idx]) if src_idx < len(cluster_ids_subset) else -1
            segment_map[cluster_id].append(
                [
                    (coords[src_idx, 0], coords[src_idx, 1]),
                    (coords[tgt_idx, 0], coords[tgt_idx, 1]),
                ]
            )
        for cluster_id, segments in segment_map.items():
            base = cluster_color_map.get(cluster_id, (0.55, 0.55, 0.55, 1.0))
            lc = LineCollection(
                segments,
                colors=[(base[0], base[1], base[2], 0.22)],
                linewidths=0.55,
            )
            ax.add_collection(lc)
    for cluster in clusters:
        centroid = cluster.get("centroid_2d")
        if not centroid:
            continue
        cluster_id = int(cluster.get("cluster_id", -1))
        rgba = cluster_color_map.get(cluster_id, (0.1, 0.1, 0.1, 1.0))
        ax.scatter(
            centroid[0],
            centroid[1],
            s=150,
            marker="D",
            c=[rgba],
            edgecolors="#222222",
            linewidths=0.6,
            alpha=0.95,
        )
        ax.text(
            centroid[0],
            centroid[1],
            f"C{cluster['cluster_id']}",
            fontsize=9,
            ha="center",
            va="bottom",
            color="#222222",
        )
    ax.set_title("Corpus Embedding Map (2D)")
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.tick_params(labelsize=8)
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(alpha=0.15, linestyle="--", linewidth=0.45)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.16), ncol=4, fontsize=7, frameon=False)
    fig.tight_layout()
    return fig


def _build_3d_figure(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    embedding_indices_subset: np.ndarray,
    point_color_map: Dict[str, Tuple[float, float, float, float]],
    cluster_color_map: Dict[int, Tuple[float, float, float, float]],
) -> go.Figure:
    """Generate a 3D Plotly figure for the embedding map."""
    fig = go.Figure()
    if coords.shape[0] < 1:
        fig.update_layout(title="Corpus Embedding Map (3D)")
        return fig
    label_order = sorted(set(labels))
    for label in label_order:
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label)
        rgba_str = _float_rgba_to_plotly(rgba) if rgba else "rgba(52, 120, 198, 0.9)"
        fig.add_trace(
            go.Scatter3d(
                x=coords[mask, 0],
                y=coords[mask, 1],
                z=coords[mask, 2],
                mode="markers",
                marker=dict(color=rgba_str, size=4.8, opacity=0.9, line=dict(width=0.6, color="#101010"), symbol="circle"),
                name=str(label),
                hovertext=hover_text[mask],
                hoverinfo="text",
                customdata=embedding_indices_subset[mask][:, None],
            )
        )
    if edges:
        index_map = {orig_idx: pos for pos, orig_idx in enumerate(original_indices)}
        edge_segments: Dict[int, Dict[str, List[float]]] = defaultdict(lambda: {"x": [], "y": [], "z": []})
        for edge in edges[:MAX_EDGE_RENDER]:
            source = int(edge["source"])
            target = int(edge["target"])
            if source not in index_map or target not in index_map:
                continue
            src_idx = index_map[source]
            tgt_idx = index_map[target]
            cluster_id = int(cluster_ids_subset[src_idx]) if src_idx < len(cluster_ids_subset) else -1
            seg = edge_segments[cluster_id]
            seg["x"].extend([coords[src_idx, 0], coords[tgt_idx, 0], None])
            seg["y"].extend([coords[src_idx, 1], coords[tgt_idx, 1], None])
            seg["z"].extend([coords[src_idx, 2], coords[tgt_idx, 2], None])
        for cluster_id, seg in edge_segments.items():
            cluster_color = cluster_color_map.get(cluster_id, (0.4, 0.4, 0.4, 1.0))
            fig.add_trace(
                go.Scatter3d(
                    x=seg["x"],
                    y=seg["y"],
                    z=seg["z"],
                    mode="lines",
                    line=dict(color=_float_rgba_to_plotly(cluster_color, alpha=0.18), width=1.3),
                    hoverinfo="none",
                    name=f"Cluster {cluster_id} edges",
                    showlegend=False,
                )
            )
    if clusters:
        fig.add_trace(
            go.Scatter3d(
                x=[c["centroid_3d"][0] for c in clusters],
                y=[c["centroid_3d"][1] for c in clusters],
                z=[c["centroid_3d"][2] for c in clusters],
                mode="markers+text",
                marker=dict(
                    symbol="diamond",
                    size=12,
                    color=[_float_rgba_to_plotly(cluster_color_map.get(int(c["cluster_id"]), (0.3, 0.3, 0.3, 1.0))) for c in clusters],
                    line=dict(width=1.5, color="#222222"),
                ),
                text=[f"C{c['cluster_id']}" for c in clusters],
                textposition="top center",
                hovertext=[f"Cluster {c['cluster_id']}<br>Size: {c['size']}" for c in clusters],
                hoverinfo="text",
                name="Centroids",
                showlegend=False,
            )
        )
    fig.update_layout(
        title="Corpus Embedding Map (3D)",
        scene=dict(
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            zaxis_title="UMAP 3",
            xaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            yaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            zaxis=dict(showgrid=True, zeroline=False, showbackground=False),
        ),
        legend=dict(orientation="h", y=-0.1),
        margin=dict(l=10, r=10, t=60, b=10),
        template="plotly_white",
        scene_camera=dict(eye=dict(x=1.6, y=1.6, z=0.9)),
        hovermode="closest",
    )
    return fig


def render_plots(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    """Render the 2D and 3D figures with the requested options."""
    (
        _corpus,
        papers,
        _embeddings,
        _normalised,
        umap_2d,
        umap_3d,
        graph_edges,
        cluster_metadata,
    ) = load_resources()
    cluster_ids = np.array([paper.get("cluster_id", 0) for paper in papers], dtype=int)
    if cluster_choice != "All":
        cluster_value = int(cluster_choice)
        mask = cluster_ids == cluster_value
        clusters_for_plot = [c for c in cluster_metadata if int(c.get("cluster_id", -1)) == cluster_value]
    else:
        mask = np.ones(len(papers), dtype=bool)
        clusters_for_plot = cluster_metadata
    selected_indices = np.where(mask)[0]
    if selected_indices.size == 0:
        metrics_empty = {
            "clusters": 0,
            "points": 0,
            "edges": 0,
            "render_ms": {"2d": 0.0, "3d": 0.0},
        }
        # Keep the empty-selection return consistent with the populated path:
        # a matplotlib figure for the 2D slot and a Plotly figure for 3D.
        return plt.figure(), go.Figure(), pd.DataFrame(), {}, {}, [], metrics_empty
    filtered_papers = [papers[idx] for idx in selected_indices]
    coords_2d = umap_2d[selected_indices]
    coords_3d = umap_3d[selected_indices]
    cluster_ids_subset = cluster_ids[selected_indices]
    embedding_indices_subset = np.array([int(filtered_papers[i].get("embedding_idx", selected_indices[i])) for i in range(len(filtered_papers))])
    selected_set = {int(idx) for idx in selected_indices.tolist()}
    filtered_edges = _filter_edges(graph_edges, selected_set) if show_edges else []
    color_basis_key = _resolve_color_basis(color_choice)
    palette = _resolve_palette(palette_choice)
    cluster_palette = _resolve_palette(DEFAULT_PALETTE)
    cluster_color_map = _build_cluster_color_map(cluster_ids, cluster_palette)
    if color_basis_key == "cluster":
        label_values = np.array([str(paper.get("cluster_id", "unknown")) for paper in filtered_papers])
        point_color_map = {str(cluster_id): cluster_color_map.get(int(cluster_id), (0.2, 0.4, 0.8, 1.0)) for cluster_id in label_values}
    else:
        label_values = np.array([paper.get("primary_category") or "unknown" for paper in filtered_papers])
        unique_labels = sorted(set(label_values))
        point_color_map = {label: palette[idx % len(palette)] for idx, label in enumerate(unique_labels)}
    hover_text = _hover_text_for_papers(filtered_papers)
    start_2d = time.perf_counter()
    fig2d = _build_2d_plot(
        coords_2d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        point_color_map,
        cluster_color_map,
    )
    render_2d_ms = (time.perf_counter() - start_2d) * 1000.0
    start_3d = time.perf_counter()
    fig3d = _build_3d_figure(
        coords_3d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        embedding_indices_subset,
        point_color_map,
        cluster_color_map,
    )
    render_3d_ms = (time.perf_counter() - start_3d) * 1000.0
    overview_df = _build_cluster_overview(filtered_papers)
    hierarchy_json = _build_cluster_hierarchy_json(filtered_papers)
    paper_lookup = {
        str(int(embedding_indices_subset[i])): {
            "title": paper.get("title", "(untitled)"),
            "paper_id": paper.get("paper_id"),
            "cluster_id": paper.get("cluster_id"),
            "primary_category": paper.get("primary_category"),
            "authors": paper.get("authors", []),
            "abstract": paper.get("abstract", ""),
            "published": paper.get("published"),
            "url": paper.get("meta", {}).get("url") if isinstance(paper.get("meta"), dict) else paper.get("url"),
        }
        for i, paper in enumerate(filtered_papers)
    }
    paper_options = [
        (f"{details['title']} (C{details['cluster_id']})", str(idx))
        for idx, details in paper_lookup.items()
    ]
    metrics = {
        "clusters": int(len(set(cluster_ids_subset))),
        "points": int(len(selected_indices)),
        "edges": int(len(filtered_edges)),
        "render_ms": {
            "2d": round(render_2d_ms, 2),
            "3d": round(render_3d_ms, 2),
        },
    }
    return fig2d, fig3d, overview_df, hierarchy_json, paper_lookup, paper_options, metrics


def refresh_embedding_plot() -> None:
    """Clear caches to force plot regeneration on next render."""
    load_resources.cache_clear()
    get_embedding_plots.cache_clear()


@lru_cache(maxsize=1)
def get_embedding_plots() -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    """Return cached 2D and 3D plots plus cluster summaries using default settings."""
    return render_plots(
        show_edges=True,
        cluster_choice="All",
        color_choice=DEFAULT_COLOR_BASIS,
        palette_choice=DEFAULT_PALETTE,
    )


def _format_results(indices: np.ndarray, scores: np.ndarray, papers: Sequence[Dict[str, Any]]) -> List[List[Any]]:
    """Convert ranked results into display-friendly rows."""
    formatted: List[List[Any]] = []
    for rank, (idx, score) in enumerate(zip(indices, scores), start=1):
        paper = papers[int(idx)]
        abstract = str(paper.get("abstract", "")).strip()
        summary = abstract[:220] + ("…" if len(abstract) > 220 else "")
        formatted.append(
            [
                rank,
                round(float(score), 4),
                paper.get("title", "(untitled)"),
                paper.get("paper_id", "N/A"),
                summary,
            ]
        )
    return formatted

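
# Note: search relies on the corpus embeddings having been L2-normalised in
# load_resources(), so cosine similarity against the query reduces to a single
# matrix-vector product (`normalised @ query_norm`).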
def search_corpus(query: str, top_k: int) -> List[List[Any]]:
    """Perform a cosine-similarity search over the local corpus."""
    query = (query or "").strip()
    if not query:
        return []
    _, papers, embeddings, normalised, _, _, _, _ = load_resources()
    embedder = get_embedder(None)
    query_vector = embedder.embed_query(query)
    query_norm = query_vector / np.linalg.norm(query_vector)
    scores = normalised @ query_norm
    top_k = max(1, min(int(top_k), len(papers)))
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    ranked_scores = scores[ranked_indices]
    return _format_results(ranked_indices, ranked_scores, papers)


def _refresh_and_render(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    refresh_embedding_plot()
    return render_plots(show_edges, cluster_choice, color_choice, palette_choice)


def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks interface."""
    with gr.Blocks(title="NexaSci Mini Corpus Search") as demo:
        gr.Markdown(
            """
            # NexaSci Corpus Explorer
            Enter a short description or paper title to retrieve the closest papers from the locally built corpus.
            """
        )
        with gr.Accordion("Corpus Builder", open=False):
            categories_box = gr.Textbox(
                label="Categories",
                value="cs.AI cs.LG cs.CL stat.ML",
                placeholder="Space-separated arXiv categories",
            )
            max_papers_slider = gr.Slider(label="Max papers", minimum=100, maximum=1000, step=50, value=500)
            num_clusters_slider = gr.Slider(label="KMeans clusters", minimum=5, maximum=60, step=5, value=30)
            batch_size_slider = gr.Slider(label="Embedding batch size", minimum=4, maximum=64, step=4, value=16)
            build_button = gr.Button("Build Corpus", variant="primary")
            build_status = gr.Markdown()
        with gr.Row():
            show_edges_checkbox = gr.Checkbox(label="Show graph edges", value=True)
            cluster_dropdown = gr.Dropdown(
                label="Cluster filter",
                value="All",
                choices=_cluster_options(),
            )
            color_basis_dropdown = gr.Radio(
                label="Color by",
                choices=list(COLOR_BASIS_OPTIONS.keys()),
                value=DEFAULT_COLOR_BASIS,
            )
            palette_dropdown = gr.Dropdown(
                label="Color palette",
                choices=list(PALETTE_OPTIONS.keys()),
                value=DEFAULT_PALETTE,
            )
        initial_2d, initial_3d, initial_overview, initial_hierarchy, initial_lookup, initial_options, initial_metrics = get_embedding_plots()
        view_selector = gr.Radio(
            label="Visualization",
            choices=["2D", "3D"],
            value="2D",
            interactive=True,
        )
        embedding_plot = gr.Plot(label="Embedding", value=initial_2d, elem_id="embedding-plot")
        controls_row = gr.Row()
        with controls_row:
            orbit_button = gr.Button("Toggle Orbit", variant="secondary")
            fullscreen_button = gr.Button("Fullscreen", variant="secondary")
        cluster_overview_table = gr.Dataframe(
            value=initial_overview,
            label="Cluster Overview",
            interactive=False,
        )
        cluster_hierarchy_json = gr.JSON(value=initial_hierarchy, label="Cluster Hierarchy")
        paper_state = gr.State(initial_lookup)
        gr.Markdown("## Paper Details")
        paper_selector = gr.Dropdown(
            choices=initial_options,
            label="Select Paper",
            value=None,
        )
        paper_detail_display = gr.Markdown("Select a paper from the dropdown.")
        metrics_json = gr.JSON(value=initial_metrics, label="Render Metrics")

        def _build_corpus(max_papers: int, categories: str, num_clusters: int, batch_size: int,
                          show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            cat_list = [c.strip() for c in categories.split() if c.strip()]
            if not cat_list:
                cat_list = ["cs.AI"]
            cmd = [
                sys.executable,
                "-m",
                "pipeline.build_corpus",
                "--categories",
                *cat_list,
                "--max-papers",
                str(int(max_papers)),
                "--num-clusters",
                str(int(num_clusters)),
                "--batch-size",
                str(int(batch_size)),
            ]
            start = time.perf_counter()
            result = subprocess.run(cmd, capture_output=True, text=True)
            elapsed = time.perf_counter() - start
            if result.returncode != 0:
                logs = (result.stderr or result.stdout or "").strip()
                if len(logs) > 800:
                    logs = "..." + logs[-800:]
                status = f"❌ Corpus build failed in {elapsed:.1f}s\n```\n{logs}\n```"
            else:
                logs = (result.stdout or "Success").strip()
                if len(logs) > 800:
                    logs = "..." + logs[-800:]
                status = f"✅ Corpus rebuilt with {int(max_papers)} papers in {elapsed:.1f}s\n```\n{logs}\n```"
            fig2d, fig3d, overview, hierarchy, lookup, options, metrics = _refresh_and_render(
                show_edges, cluster_choice, color_choice, palette_choice
            )
            return (
                status,
                fig2d if view == "2D" else fig3d,
                overview,
                hierarchy,
                lookup,
                gr.update(choices=options, value=None),
                "Select a paper from the dropdown.",
                metrics,
            )

        def _update_plots(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str):
            return render_plots(show_edges, cluster_choice, color_choice, palette_choice)

        refresh_button = gr.Button("Refresh Data")

        def _refresh_and_update(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            fig2d, fig3d, overview, hierarchy, lookup, options, metrics = _refresh_and_render(
                show_edges, cluster_choice, color_choice, palette_choice
            )
            if view == "3D":
                fig3d.update_layout(margin=dict(l=10, r=10, t=60, b=10))
            return (
                fig2d if view == "2D" else fig3d,
                overview,
                hierarchy,
                lookup,
                gr.update(choices=options, value=None),
                "Select a paper from the dropdown.",
                metrics,
            )

        refresh_button.click(
            _refresh_and_update,
            inputs=[show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown, view_selector],
            outputs=[embedding_plot, cluster_overview_table, cluster_hierarchy_json, paper_state, paper_selector, paper_detail_display, metrics_json],
        )

        def _update_visual(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            fig2d, fig3d, overview, hierarchy, lookup, options, metrics = _update_plots(
                show_edges, cluster_choice, color_choice, palette_choice
            )
            return (
                fig2d if view == "2D" else fig3d,
                overview,
                hierarchy,
                lookup,
                gr.update(choices=options, value=None),
                "Select a paper from the dropdown.",
                metrics,
            )

        view_selector.change(
            _update_visual,
            inputs=[show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown, view_selector],
            outputs=[embedding_plot, cluster_overview_table, cluster_hierarchy_json, paper_state, paper_selector, paper_detail_display, metrics_json],
        )
        for control in [show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown]:
            control.change(
                _update_visual,
                inputs=[show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown, view_selector],
                outputs=[embedding_plot, cluster_overview_table, cluster_hierarchy_json, paper_state, paper_selector, paper_detail_display, metrics_json],
            )
        orbit_button.click(None, inputs=None, outputs=None, js=ORBIT_JS)
        fullscreen_button.click(None, inputs=None, outputs=None, js=FULLSCREEN_JS)
        build_button.click(
            _build_corpus,
            inputs=[
                max_papers_slider,
                categories_box,
                num_clusters_slider,
                batch_size_slider,
                show_edges_checkbox,
                cluster_dropdown,
                color_basis_dropdown,
                palette_dropdown,
                view_selector,
            ],
            outputs=[
                build_status,
                embedding_plot,
                cluster_overview_table,
                cluster_hierarchy_json,
                paper_state,
                paper_selector,
                paper_detail_display,
                metrics_json,
            ],
        )
        gr.Markdown("## Semantic Search")
        with gr.Row():
            query_input = gr.Textbox(
                label="Query",
                placeholder="e.g. graph neural networks for chemistry",
                lines=2,
            )
            topk_slider = gr.Slider(
                label="Top K Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )
        results_table = gr.Dataframe(
            headers=["rank", "score", "title", "paper_id", "summary"],
            label="Results",
            datatype=["number", "number", "str", "str", "str"],
            interactive=False,
        )
        submit_btn = gr.Button("Search")
        submit_btn.click(search_corpus, inputs=[query_input, topk_slider], outputs=[results_table])

        def _format_details(selection: str | None, paper_map: Dict[str, Dict[str, Any]]):
            if not selection:
                return "Select a paper from the dropdown."
            details = paper_map.get(selection)
            if not details:
                return "No details available for this paper."
            authors = ", ".join(details.get("authors", [])) or "Unknown"
            lines = [
                f"### {details.get('title', '(untitled)')}",
                f"**Paper ID:** {details.get('paper_id', 'N/A')}",
                f"**Cluster:** {details.get('cluster_id', 'N/A')} | **Category:** {details.get('primary_category', 'unknown')}",
                f"**Authors:** {authors}",
                f"**Published:** {details.get('published', 'N/A')}",
                "",
                details.get("abstract", "No abstract available."),
            ]
            url = details.get("url")
            if url:
                lines.append(f"\n[View paper]({url})")
            return "\n\n".join(lines)

        paper_selector.change(_format_details, inputs=[paper_selector, paper_state], outputs=paper_detail_display)
    return demo


def main() -> None:
    """Launch the Gradio demo."""
    interface = build_interface()
    interface.launch()


if __name__ == "__main__":  # pragma: no cover - manual launch helper
    main()