"""Interactive Gradio UI for exploring the local SPECTER2 corpus."""

from __future__ import annotations

import subprocess
import sys
import time
from collections import Counter, defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Sequence, Set, Tuple

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept for backward compatibility of the module namespace)
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure

from pipeline.embed import Specter2Embedder
from pipeline.storage import load_canonical_corpus, load_embeddings

# Client-side handler: toggle browser fullscreen on the Plotly plot inside #embedding-plot
# (falls back to the container itself when no Plotly plot is mounted, e.g. the 2D image).
FULLSCREEN_JS = """
() => {
    const container = document.getElementById('embedding-plot');
    if (!container) return;
    const plot = container.querySelector('.js-plotly-plot') || container;
    if (!document.fullscreenElement) {
        if (plot.requestFullscreen) {
            plot.requestFullscreen();
        } else if (plot.webkitRequestFullscreen) {
            plot.webkitRequestFullscreen();
        }
    } else {
        if (document.exitFullscreen) {
            document.exitFullscreen();
        } else if (document.webkitExitFullscreen) {
            document.webkitExitFullscreen();
        }
    }
}
"""

# Client-side handler: start/stop a camera orbit around the 3D Plotly scene. The interval
# handle is stored on `window` so a second click stops the animation.
ORBIT_JS = """
() => {
    const container = document.getElementById('embedding-plot');
    if (!container) return;
    const plot = container.querySelector('.js-plotly-plot');
    if (!plot) return;
    window._plotOrbitIntervals = window._plotOrbitIntervals || {};
    const key = 'embedding-plot';
    if (window._plotOrbitIntervals[key]) {
        clearInterval(window._plotOrbitIntervals[key]);
        delete window._plotOrbitIntervals[key];
        return;
    }
    let angle = 0;
    const radius = 1.6;
    window._plotOrbitIntervals[key] = setInterval(() => {
        const updatedPlot = container.querySelector('.js-plotly-plot');
        if (!updatedPlot) {
            clearInterval(window._plotOrbitIntervals[key]);
            delete window._plotOrbitIntervals[key];
            return;
        }
        angle = (angle + 2) % 360;
        const rad = angle * Math.PI / 180;
        Plotly.relayout(updatedPlot, {
            'scene.camera.eye': {
                x: radius * Math.cos(rad),
                y: radius * Math.sin(rad),
                z: 0.9,
            },
        });
    }, 50);
}
"""

# Generic (component id, action) variant of the two handlers above. Not wired to any
# event in this module; kept for external callers.
CUSTOM_JS = """
function(componentId, action) {
    const el = document.getElementById(componentId);
    if (!el) return;
    if (action === "orbit") {
        if (window._orbitIntervals === undefined) {
            window._orbitIntervals = {};
        }
        if (window._orbitIntervals[componentId]) {
            clearInterval(window._orbitIntervals[componentId]);
            delete window._orbitIntervals[componentId];
        } else {
            let angle = 0;
            const interval = setInterval(() => {
                angle = (angle + 2) % 360;
                const rad = angle * Math.PI / 180;
                const r = 1.6;
                const layout = {
                    scene: {camera: {eye: {x: r * Math.cos(rad), y: r * Math.sin(rad), z: 0.9}}}
                };
                Plotly.relayout(el, layout);
            }, 50);
            window._orbitIntervals[componentId] = interval;
        }
    } else if (action === "fullscreen") {
        const container = el.closest("div.svelte-1ipelgc");
        const target = container || el;
        if (!document.fullscreenElement) {
            target.requestFullscreen?.();
        } else {
            document.exitFullscreen?.();
        }
    }
}
"""

INDEX_DIR = Path(__file__).resolve().parents[1] / "index"
CORPUS_PATH = INDEX_DIR / "corpus.json"
EMBEDDINGS_PATH = INDEX_DIR / "embeddings.npy"

DEFAULT_COLOR_BASIS = "Cluster"
DEFAULT_PALETTE = "Plotly"

COLOR_BASIS_OPTIONS: Dict[str, str] = {
    "Cluster": "cluster",
    "Primary Category": "primary_category",
}

PALETTE_OPTIONS: Dict[str, List[str]] = {
    "Plotly": px.colors.qualitative.Plotly,
    "Bold": px.colors.qualitative.Bold,
    "Vivid": px.colors.qualitative.Vivid,
    "Pastel": px.colors.qualitative.Pastel,
    "Safe": px.colors.qualitative.Safe,
}

# Upper bound on the number of graph edges drawn per figure, to keep rendering responsive.
MAX_EDGE_RENDER = 2000

_RGBA = Tuple[float, float, float, float]

# Shape of everything `render_plots` hands back to the UI:
# (2D matplotlib figure, 3D plotly figure, overview table, hierarchy JSON,
#  paper lookup keyed by embedding index, dropdown options, render metrics).
_PlotBundle = Tuple[
    Figure,
    go.Figure,
    pd.DataFrame,
    Dict[str, Any],
    Dict[str, Dict[str, Any]],
    List[Tuple[str, str]],
    Dict[str, Any],
]

_OVERVIEW_COLUMNS = ["cluster_id", "size", "major_category", "dominant_category", "sample_titles"]
_PAPER_PROMPT = "Select a paper from the dropdown."


def _float_rgba_to_plotly(rgba: _RGBA, alpha: float | None = None) -> str:
    """Convert a 0-1 float RGBA tuple to a Plotly ``rgba(r, g, b, a)`` string.

    ``alpha``, when given, overrides the tuple's own alpha channel.
    """
    r, g, b, a = rgba
    if alpha is not None:
        a = alpha
    return f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a:.2f})"


def _build_cluster_color_map(cluster_ids: Sequence[int], palette: Sequence[_RGBA]) -> Dict[int, _RGBA]:
    """Assign palette colours to the sorted unique cluster ids, cycling when the palette is short."""
    unique_ids = sorted({int(cid) for cid in cluster_ids})
    return {cluster_id: palette[idx % len(palette)] for idx, cluster_id in enumerate(unique_ids)}


def _major_category(category: str) -> str:
    """Return the part of an arXiv category before the first dot (``cs.LG`` -> ``cs``)."""
    return category.partition(".")[0]


def _build_cluster_overview(papers: Sequence[Dict[str, Any]]) -> pd.DataFrame:
    """Summarise papers per cluster: size, dominant category, and up to three sample titles.

    ``major_category`` is derived from the dominant category of the cluster, so the two
    columns always agree.
    """
    clusters: Dict[int, Dict[str, Any]] = defaultdict(
        lambda: {"size": 0, "categories": Counter(), "sample_titles": []}
    )
    for paper in papers:
        entry = clusters[int(paper.get("cluster_id", -1))]
        entry["size"] += 1
        entry["categories"][paper.get("primary_category") or "unknown"] += 1
        if len(entry["sample_titles"]) < 3:
            entry["sample_titles"].append(str(paper.get("title") or "(untitled)"))

    overview_rows = []
    for cluster_id in sorted(clusters):
        data = clusters[cluster_id]
        dominant_category = data["categories"].most_common(1)[0][0] if data["categories"] else "unknown"
        overview_rows.append(
            {
                "cluster_id": cluster_id,
                "size": data["size"],
                "major_category": _major_category(dominant_category),
                "dominant_category": dominant_category,
                "sample_titles": " | ".join(data["sample_titles"]),
            }
        )
    return pd.DataFrame(overview_rows, columns=_OVERVIEW_COLUMNS)


def _build_cluster_hierarchy_json(papers: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Group papers as major category -> category -> papers (with their cluster ids).

    Majors and categories are sorted alphabetically; papers inside a category are sorted
    by cluster id.
    """
    hierarchy: Dict[str, Dict[str, List[Dict[str, Any]]]] = defaultdict(lambda: defaultdict(list))
    for paper in papers:
        category = paper.get("primary_category") or "unknown"
        hierarchy[_major_category(category)][category].append(
            {
                "cluster_id": int(paper.get("cluster_id", -1)),
                "paper_id": paper.get("paper_id"),
                "title": paper.get("title"),
            }
        )

    major_payload = []
    for major in sorted(hierarchy):
        sub_payload = []
        for category, members in sorted(hierarchy[major].items()):
            members_sorted = sorted(members, key=lambda member: member["cluster_id"])
            sub_payload.append(
                {
                    "category": category,
                    "clusters": members_sorted,
                    "cluster_ids": sorted({member["cluster_id"] for member in members_sorted}),
                }
            )
        major_payload.append({"major": major, "subcategories": sub_payload})
    return {"major_categories": major_payload}


def _filter_edges(edges: Sequence[Dict[str, Any]], selected: Set[int]) -> List[Dict[str, Any]]:
    """Return only edges whose endpoints are in the selected set."""
    return [
        edge
        for edge in edges
        if int(edge.get("source", -1)) in selected and int(edge.get("target", -1)) in selected
    ]


def _normalise_embeddings(vectors: np.ndarray) -> np.ndarray:
    """Return L2-normalised embeddings, guarding against zero vectors."""
    if vectors.size == 0:
        return vectors
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return vectors / norms


@lru_cache(maxsize=1)
def load_resources() -> Tuple[
    Dict[str, Any],
    List[Dict[str, Any]],
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    List[Dict[str, Any]],
    List[Dict[str, Any]],
]:
    """Load canonical corpus data, embeddings, and graph metadata from disk.

    Returns ``(corpus_doc, papers_sorted, embeddings, normalised, umap_2d, umap_3d,
    graph_edges, cluster_metadata)`` where ``papers_sorted[i]`` corresponds to row ``i``
    of the embedding matrix.

    Raises:
        FileNotFoundError: the corpus or embeddings file is missing.
        ValueError: embeddings and corpus disagree in size or ordering.
    """
    if not CORPUS_PATH.exists() or not EMBEDDINGS_PATH.exists():
        raise FileNotFoundError(
            "Corpus artifacts not found. Run `python -m pipeline.build_corpus` first."
        )

    corpus_doc = load_canonical_corpus(CORPUS_PATH)
    papers = corpus_doc.get("papers", [])
    embeddings = load_embeddings(EMBEDDINGS_PATH)
    if embeddings.shape[0] != len(papers):
        raise ValueError(
            "Mismatch between embeddings and canonical corpus entries. Rebuild the corpus to continue."
        )

    papers_sorted = sorted(papers, key=lambda entry: entry.get("embedding_idx", 0))
    # Every downstream index (edges, lookups, plots) assumes position == embedding_idx.
    if not all(paper.get("embedding_idx") == idx for idx, paper in enumerate(papers_sorted)):
        raise ValueError("Embedding indices in canonical corpus do not match their positions; rebuild the corpus.")

    umap_2d = np.array([paper.get("umap_2d", [0.0, 0.0]) for paper in papers_sorted], dtype=np.float32)
    umap_3d = np.array([paper.get("umap_3d", [0.0, 0.0, 0.0]) for paper in papers_sorted], dtype=np.float32)
    normalised = _normalise_embeddings(embeddings.astype(np.float32))
    graph_edges = (corpus_doc.get("graph") or {}).get("edges") or []
    cluster_metadata = corpus_doc.get("clusters") or []
    return (
        corpus_doc,
        papers_sorted,
        embeddings,
        normalised,
        umap_2d,
        umap_3d,
        graph_edges,
        cluster_metadata,
    )


@lru_cache(maxsize=1)
def get_embedder(device: str | None = None) -> Specter2Embedder:
    """Instantiate the Specter2 embedder once."""
    return Specter2Embedder(device=device)


@lru_cache(maxsize=1)
def _cluster_options() -> List[str]:
    """Return the cluster dropdown options (All + IDs)."""
    (_, papers, *_rest) = load_resources()
    cluster_ids = sorted({int(paper.get("cluster_id", 0)) for paper in papers})
    return ["All"] + [str(cluster_id) for cluster_id in cluster_ids]


def _resolve_color_basis(choice: str) -> str:
    """Map a "Color by" label to its internal key, falling back to the default basis."""
    return COLOR_BASIS_OPTIONS.get(choice, COLOR_BASIS_OPTIONS[DEFAULT_COLOR_BASIS])


def _parse_plotly_rgb(color: str) -> _RGBA:
    """Parse a Plotly ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` string into 0-1 floats.

    Channels are 0-255; the optional alpha is already 0-1 and is kept as-is.
    """
    inner = color[color.find("(") + 1 : color.find(")")]
    parts = [float(part.strip()) for part in inner.split(",")]
    red, green, blue = (channel / 255.0 for channel in parts[:3])
    alpha = parts[3] if len(parts) > 3 else 1.0
    return (red, green, blue, alpha)


def _resolve_palette(choice: str) -> List[_RGBA]:
    """Return the named palette as 0-1 RGBA tuples (never empty)."""
    palette = PALETTE_OPTIONS.get(choice, PALETTE_OPTIONS[DEFAULT_PALETTE])
    resolved: List[_RGBA] = []
    for color in palette:
        try:
            resolved.append(to_rgba(color))
        except ValueError:
            # Several Plotly palettes use CSS "rgb(...)" strings that matplotlib rejects.
            if not color.startswith("rgb"):
                raise
            resolved.append(_parse_plotly_rgb(color))
    if not resolved:
        resolved.append((0.2, 0.4, 0.8, 1.0))
    return resolved


def _hover_text_for_papers(papers: Sequence[Dict[str, Any]]) -> np.ndarray:
    """Generate Plotly hover text (``<br>``-separated lines) for each paper."""
    hover = []
    for paper in papers:
        authors = paper.get("authors") or []
        author_text = ", ".join(authors[:3]) + ("…" if len(authors) > 3 else "")
        hover.append(
            "<br>".join(
                [
                    str(paper.get("title") or "(untitled)"),
                    f"ID: {paper.get('paper_id', 'n/a')}",
                    f"Cluster: {paper.get('cluster_id', 'n/a')}",
                    f"Category: {paper.get('primary_category', 'unknown')}",
                    f"Authors: {author_text}",
                ]
            )
        )
    return np.array(hover)


def _group_points(labels: np.ndarray, palette: Sequence[str]) -> List[Tuple[str, np.ndarray, str]]:
    """Return masking information for each unique label."""
    unique = sorted(np.unique(labels))
    groups: List[Tuple[str, np.ndarray, str]] = []
    for idx, label in enumerate(unique):
        mask = labels == label
        color = palette[idx % len(palette)]
        groups.append((label, mask, color))
    return groups


def _edges_by_cluster(
    edges: Sequence[Dict[str, Any]],
    original_indices: Sequence[int],
    cluster_ids_subset: np.ndarray,
) -> Dict[int, List[Tuple[int, int]]]:
    """Map up to ``MAX_EDGE_RENDER`` edges to (source, target) positions in the plotted subset.

    Edges are grouped by the cluster of their source point (``-1`` when unknown) and
    edges touching points outside the subset are dropped.
    """
    index_map = {int(orig_idx): pos for pos, orig_idx in enumerate(original_indices)}
    grouped: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for edge in edges[:MAX_EDGE_RENDER]:
        src_pos = index_map.get(int(edge["source"]))
        tgt_pos = index_map.get(int(edge["target"]))
        if src_pos is None or tgt_pos is None:
            continue
        cluster_id = int(cluster_ids_subset[src_pos]) if src_pos < len(cluster_ids_subset) else -1
        grouped[cluster_id].append((src_pos, tgt_pos))
    return grouped


def _build_2d_plot(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    point_color_map: Dict[str, _RGBA],
    cluster_color_map: Dict[int, _RGBA],
) -> Figure:
    """Render the 2D embedding map as a static matplotlib figure.

    The figure is created directly (not through pyplot) so that repeated renders are not
    retained by pyplot's global figure manager. ``hover_text`` is accepted for signature
    parity with the 3D builder but is unused by a static image.
    """
    fig = Figure(figsize=(6.8, 6.2), dpi=120)
    ax = fig.add_subplot(1, 1, 1)
    if coords.shape[0] < 1:
        ax.set_title("Corpus Embedding Map (2D)")
        ax.axis("off")
        return fig

    for label in sorted(set(labels)):
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label) or (0.25, 0.5, 0.85, 1.0)
        ax.scatter(
            coords[mask, 0],
            coords[mask, 1],
            s=26,
            c=[rgba],
            alpha=0.9,
            linewidths=0.3,
            edgecolors="#f5f5f5",
            label=label,
        )

    if edges:
        for cluster_id, pairs in _edges_by_cluster(edges, original_indices, cluster_ids_subset).items():
            segments = [
                [(coords[src, 0], coords[src, 1]), (coords[tgt, 0], coords[tgt, 1])]
                for src, tgt in pairs
            ]
            base = cluster_color_map.get(cluster_id, (0.55, 0.55, 0.55, 1.0))
            ax.add_collection(
                LineCollection(segments, colors=[(base[0], base[1], base[2], 0.22)], linewidths=0.55)
            )

    for cluster in clusters:
        centroid = cluster.get("centroid_2d")
        if not centroid:
            continue
        cluster_id = int(cluster.get("cluster_id", -1))
        rgba = cluster_color_map.get(cluster_id, (0.1, 0.1, 0.1, 1.0))
        ax.scatter(
            centroid[0],
            centroid[1],
            s=150,
            marker="D",
            c=[rgba],
            edgecolors="#222222",
            linewidths=0.6,
            alpha=0.95,
        )
        ax.text(
            centroid[0],
            centroid[1],
            f"C{cluster['cluster_id']}",
            fontsize=9,
            ha="center",
            va="bottom",
            color="#222222",
        )

    ax.set_title("Corpus Embedding Map (2D)")
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.tick_params(labelsize=8)
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(alpha=0.15, linestyle="--", linewidth=0.45)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.16), ncol=4, fontsize=7, frameon=False)
    fig.tight_layout()
    return fig


def _build_3d_figure(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    embedding_indices_subset: np.ndarray,
    point_color_map: Dict[str, _RGBA],
    cluster_color_map: Dict[int, _RGBA],
) -> go.Figure:
    """Generate a 3D Plotly figure for the embedding map.

    Clusters without a ``centroid_3d`` entry are skipped when drawing centroids, matching
    how the 2D builder treats a missing ``centroid_2d``.
    """
    fig = go.Figure()
    if coords.shape[0] < 1:
        fig.update_layout(title="Corpus Embedding Map (3D)")
        return fig

    for label in sorted(set(labels)):
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label)
        rgba_str = _float_rgba_to_plotly(rgba) if rgba else "rgba(52, 120, 198, 0.9)"
        fig.add_trace(
            go.Scatter3d(
                x=coords[mask, 0],
                y=coords[mask, 1],
                z=coords[mask, 2],
                mode="markers",
                marker=dict(
                    color=rgba_str,
                    size=4.8,
                    opacity=0.9,
                    line=dict(width=0.6, color="#101010"),
                    symbol="circle",
                ),
                name=str(label),
                hovertext=hover_text[mask],
                hoverinfo="text",
                customdata=embedding_indices_subset[mask][:, None],
            )
        )

    if edges:
        for cluster_id, pairs in _edges_by_cluster(edges, original_indices, cluster_ids_subset).items():
            # Plotly draws one polyline per trace; ``None`` breaks it into separate segments.
            xs: List[float | None] = []
            ys: List[float | None] = []
            zs: List[float | None] = []
            for src, tgt in pairs:
                xs.extend([float(coords[src, 0]), float(coords[tgt, 0]), None])
                ys.extend([float(coords[src, 1]), float(coords[tgt, 1]), None])
                zs.extend([float(coords[src, 2]), float(coords[tgt, 2]), None])
            cluster_color = cluster_color_map.get(cluster_id, (0.4, 0.4, 0.4, 1.0))
            fig.add_trace(
                go.Scatter3d(
                    x=xs,
                    y=ys,
                    z=zs,
                    mode="lines",
                    line=dict(color=_float_rgba_to_plotly(cluster_color, alpha=0.18), width=1.3),
                    hoverinfo="none",
                    name=f"Cluster {cluster_id} edges",
                    showlegend=False,
                )
            )

    centroids = [c for c in clusters if c.get("centroid_3d")]
    if centroids:
        fig.add_trace(
            go.Scatter3d(
                x=[c["centroid_3d"][0] for c in centroids],
                y=[c["centroid_3d"][1] for c in centroids],
                z=[c["centroid_3d"][2] for c in centroids],
                mode="markers+text",
                marker=dict(
                    symbol="diamond",
                    size=12,
                    color=[
                        _float_rgba_to_plotly(
                            cluster_color_map.get(int(c["cluster_id"]), (0.3, 0.3, 0.3, 1.0))
                        )
                        for c in centroids
                    ],
                    line=dict(width=1.5, color="#222222"),
                ),
                text=[f"C{c['cluster_id']}" for c in centroids],
                textposition="top center",
                hovertext=[f"Cluster {c['cluster_id']}<br>Size: {c.get('size', 'n/a')}" for c in centroids],
                hoverinfo="text",
                name="Centroids",
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Corpus Embedding Map (3D)",
        scene=dict(
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            zaxis_title="UMAP 3",
            xaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            yaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            zaxis=dict(showgrid=True, zeroline=False, showbackground=False),
        ),
        legend=dict(orientation="h", y=-0.1),
        margin=dict(l=10, r=10, t=60, b=10),
        template="plotly_white",
        scene_camera=dict(eye=dict(x=1.6, y=1.6, z=0.9)),
        hovermode="closest",
    )
    return fig


def _empty_plot_bundle() -> _PlotBundle:
    """Return placeholder outputs (with the same types as a real render) for an empty selection."""
    fig2d = _build_2d_plot(
        np.zeros((0, 2), dtype=np.float32), [], np.array([]), np.array([]),
        [], [], np.array([], dtype=int), {}, {},
    )
    fig3d = _build_3d_figure(
        np.zeros((0, 3), dtype=np.float32), [], np.array([]), np.array([]),
        [], [], np.array([], dtype=int), np.array([], dtype=int), {}, {},
    )
    metrics = {"clusters": 0, "points": 0, "edges": 0, "render_ms": {"2d": 0.0, "3d": 0.0}}
    return (
        fig2d,
        fig3d,
        pd.DataFrame(columns=_OVERVIEW_COLUMNS),
        {"major_categories": []},
        {},
        [],
        metrics,
    )


def render_plots(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> _PlotBundle:
    """Render the 2D and 3D figures with the requested options.

    ``cluster_choice`` is ``"All"`` (or empty/None) for the full corpus, otherwise a
    cluster id as a string. ``palette_choice`` drives both category colours and cluster
    colours (points, edges, centroids).
    """
    (
        _corpus,
        papers,
        _embeddings,
        _normalised,
        umap_2d,
        umap_3d,
        graph_edges,
        cluster_metadata,
    ) = load_resources()

    cluster_ids = np.array([paper.get("cluster_id", 0) for paper in papers], dtype=int)
    if cluster_choice in (None, "", "All"):
        mask = np.ones(len(papers), dtype=bool)
        clusters_for_plot = list(cluster_metadata)
    else:
        cluster_value = int(cluster_choice)
        mask = cluster_ids == cluster_value
        clusters_for_plot = [c for c in cluster_metadata if int(c.get("cluster_id", -1)) == cluster_value]

    selected_indices = np.flatnonzero(mask)
    if selected_indices.size == 0:
        return _empty_plot_bundle()

    filtered_papers = [papers[idx] for idx in selected_indices]
    coords_2d = umap_2d[selected_indices]
    coords_3d = umap_3d[selected_indices]
    cluster_ids_subset = cluster_ids[selected_indices]
    embedding_indices_subset = np.array(
        [int(paper.get("embedding_idx", pos)) for paper, pos in zip(filtered_papers, selected_indices)]
    )

    selected_set = {int(idx) for idx in selected_indices.tolist()}
    filtered_edges = _filter_edges(graph_edges, selected_set) if show_edges else []

    palette = _resolve_palette(palette_choice)
    # Built over *all* clusters so a cluster keeps its colour when the view is filtered.
    cluster_color_map = _build_cluster_color_map(cluster_ids, palette)

    if _resolve_color_basis(color_choice) == "cluster":
        label_values = np.array([str(int(cid)) for cid in cluster_ids_subset])
        point_color_map = {
            label: cluster_color_map.get(int(label), (0.2, 0.4, 0.8, 1.0)) for label in set(label_values)
        }
    else:
        label_values = np.array([paper.get("primary_category") or "unknown" for paper in filtered_papers])
        point_color_map = {
            label: palette[idx % len(palette)] for idx, label in enumerate(sorted(set(label_values)))
        }

    hover_text = _hover_text_for_papers(filtered_papers)

    start_2d = time.perf_counter()
    fig2d = _build_2d_plot(
        coords_2d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        point_color_map,
        cluster_color_map,
    )
    render_2d_ms = (time.perf_counter() - start_2d) * 1000.0

    start_3d = time.perf_counter()
    fig3d = _build_3d_figure(
        coords_3d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        embedding_indices_subset,
        point_color_map,
        cluster_color_map,
    )
    render_3d_ms = (time.perf_counter() - start_3d) * 1000.0

    overview_df = _build_cluster_overview(filtered_papers)
    hierarchy_json = _build_cluster_hierarchy_json(filtered_papers)

    paper_lookup = {
        str(int(emb_idx)): {
            "title": paper.get("title", "(untitled)"),
            "paper_id": paper.get("paper_id"),
            "cluster_id": paper.get("cluster_id"),
            "primary_category": paper.get("primary_category"),
            "authors": paper.get("authors", []),
            "abstract": paper.get("abstract", ""),
            "published": paper.get("published"),
            "url": paper.get("meta", {}).get("url") if isinstance(paper.get("meta"), dict) else paper.get("url"),
        }
        for emb_idx, paper in zip(embedding_indices_subset, filtered_papers)
    }
    paper_options = [
        (f"{details['title']} (C{details['cluster_id']})", str(idx))
        for idx, details in paper_lookup.items()
    ]

    metrics = {
        "clusters": int(len(set(cluster_ids_subset))),
        "points": int(len(selected_indices)),
        "edges": int(len(filtered_edges)),
        "render_ms": {
            "2d": round(render_2d_ms, 2),
            "3d": round(render_3d_ms, 2),
        },
    }

    return fig2d, fig3d, overview_df, hierarchy_json, paper_lookup, paper_options, metrics


def refresh_embedding_plot() -> None:
    """Clear caches so the next render reloads corpus artifacts from disk."""
    load_resources.cache_clear()
    _cluster_options.cache_clear()
    get_embedding_plots.cache_clear()


@lru_cache(maxsize=1)
def get_embedding_plots() -> _PlotBundle:
    """Return cached 2D and 3D plots plus cluster summaries using default settings."""
    return render_plots(
        show_edges=True,
        cluster_choice="All",
        color_choice=DEFAULT_COLOR_BASIS,
        palette_choice=DEFAULT_PALETTE,
    )


def _format_results(indices: np.ndarray, scores: np.ndarray, papers: Sequence[Dict[str, Any]]) -> List[List[Any]]:
    """Convert ranked results into display-friendly rows."""
    formatted: List[List[Any]] = []
    for rank, (idx, score) in enumerate(zip(indices, scores), start=1):
        paper = papers[int(idx)]
        abstract = str(paper.get("abstract", "")).strip()
        summary = abstract[:220] + ("…" if len(abstract) > 220 else "")
        formatted.append(
            [
                rank,
                round(float(score), 4),
                paper.get("title", "(untitled)"),
                paper.get("paper_id", "N/A"),
                summary,
            ]
        )
    return formatted


def search_corpus(query: str, top_k: int) -> List[List[Any]]:
    """Perform a cosine-similarity search over the local corpus.

    Returns an empty list for a blank query, an empty corpus, or a query that embeds to
    the zero vector (cosine similarity is undefined there).
    """
    query = (query or "").strip()
    if not query:
        return []

    _, papers, _embeddings, normalised, _, _, _, _ = load_resources()
    if not papers:
        return []

    embedder = get_embedder(None)
    query_vector = np.asarray(embedder.embed_query(query)).reshape(-1)
    query_length = float(np.linalg.norm(query_vector))
    if query_length == 0.0:
        return []
    scores = normalised @ (query_vector / query_length)

    top_k = max(1, min(int(top_k), len(papers)))
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    ranked_scores = scores[ranked_indices]
    return _format_results(ranked_indices, ranked_scores, papers)


def _refresh_and_render(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> _PlotBundle:
    """Drop all caches, then render with the given options."""
    refresh_embedding_plot()
    return render_plots(show_edges, cluster_choice, color_choice, palette_choice)


def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks interface."""
    with gr.Blocks(title="NexaSci Mini Corpus Search") as demo:
        gr.Markdown(
            "# NexaSci Corpus Explorer\n"
            "Enter a short description or paper title to retrieve the closest papers "
            "from the locally built corpus."
        )

        with gr.Accordion("Corpus Builder", open=False):
            categories_box = gr.Textbox(
                label="Categories",
                value="cs.AI cs.LG cs.CL stat.ML",
                placeholder="Space-separated arXiv categories",
            )
            max_papers_slider = gr.Slider(label="Max papers", minimum=100, maximum=1000, step=50, value=500)
            num_clusters_slider = gr.Slider(label="KMeans clusters", minimum=5, maximum=60, step=5, value=30)
            batch_size_slider = gr.Slider(label="Embedding batch size", minimum=4, maximum=64, step=4, value=16)
            build_button = gr.Button("Build Corpus", variant="primary")
            build_status = gr.Markdown()

        with gr.Row():
            show_edges_checkbox = gr.Checkbox(label="Show graph edges", value=True)
            cluster_dropdown = gr.Dropdown(
                label="Cluster filter",
                value="All",
                choices=_cluster_options(),
            )
            color_basis_dropdown = gr.Radio(
                label="Color by",
                choices=list(COLOR_BASIS_OPTIONS.keys()),
                value=DEFAULT_COLOR_BASIS,
            )
            palette_dropdown = gr.Dropdown(
                label="Color palette",
                choices=list(PALETTE_OPTIONS.keys()),
                value=DEFAULT_PALETTE,
            )

        (
            initial_2d,
            _initial_3d,
            initial_overview,
            initial_hierarchy,
            initial_lookup,
            initial_options,
            initial_metrics,
        ) = get_embedding_plots()

        view_selector = gr.Radio(
            label="Visualization",
            choices=["2D", "3D"],
            value="2D",
            interactive=True,
        )
        embedding_plot = gr.Plot(label="Embedding", value=initial_2d, elem_id="embedding-plot")
        with gr.Row():
            orbit_button = gr.Button("Toggle Orbit", variant="secondary")
            fullscreen_button = gr.Button("Fullscreen", variant="secondary")

        cluster_overview_table = gr.Dataframe(
            value=initial_overview,
            label="Cluster Overview",
            interactive=False,
        )
        cluster_hierarchy_json = gr.JSON(value=initial_hierarchy, label="Cluster Hierarchy")
        paper_state = gr.State(initial_lookup)

        gr.Markdown("## Paper Details")
        paper_selector = gr.Dropdown(
            choices=initial_options,
            label="Select Paper",
            value=None,
        )
        paper_detail_display = gr.Markdown(_PAPER_PROMPT)
        metrics_json = gr.JSON(value=initial_metrics, label="Render Metrics")

        visual_inputs = [show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown, view_selector]
        visual_outputs = [
            embedding_plot,
            cluster_overview_table,
            cluster_hierarchy_json,
            paper_state,
            paper_selector,
            paper_detail_display,
            metrics_json,
        ]

        def _to_visual_outputs(bundle: _PlotBundle, view: str) -> Tuple[Any, ...]:
            """Pick the figure for ``view`` and reset the paper selector/details."""
            fig2d, fig3d, overview, hierarchy, lookup, options, metrics = bundle
            return (
                fig2d if view == "2D" else fig3d,
                overview,
                hierarchy,
                lookup,
                gr.update(choices=options, value=None),
                _PAPER_PROMPT,
                metrics,
            )

        def _build_corpus(
            max_papers: int,
            categories: str,
            num_clusters: int,
            batch_size: int,
            show_edges: bool,
            cluster_choice: str,
            color_choice: str,
            palette_choice: str,
            view: str,
        ):
            """Run the corpus build pipeline in a subprocess, then reload and re-render."""
            cat_list = [c.strip() for c in (categories or "").split() if c.strip()] or ["cs.AI"]
            # Argument list (no shell), so user-supplied categories are never shell-interpreted.
            cmd = [
                sys.executable,
                "-m",
                "pipeline.build_corpus",
                "--categories",
                *cat_list,
                "--max-papers",
                str(int(max_papers)),
                "--num-clusters",
                str(int(num_clusters)),
                "--batch-size",
                str(int(batch_size)),
            ]
            start = time.perf_counter()
            result = subprocess.run(cmd, capture_output=True, text=True)
            elapsed = time.perf_counter() - start

            if result.returncode != 0:
                logs = (result.stderr or result.stdout or "").strip()
                if len(logs) > 800:
                    logs = "..." + logs[-800:]
                status = f"❌ Corpus build failed in {elapsed:.1f}s\n```\n{logs}\n```"
            else:
                logs = (result.stdout or "Success").strip()
                if len(logs) > 800:
                    logs = "..." + logs[-800:]
                status = f"✅ Corpus rebuilt with {int(max_papers)} papers in {elapsed:.1f}s\n```\n{logs}\n```"

            bundle = _refresh_and_render(show_edges, cluster_choice, color_choice, palette_choice)
            return (status, *_to_visual_outputs(bundle, view))

        refresh_button = gr.Button("Refresh Data")

        def _refresh_and_update(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            """Reload artifacts from disk and re-render."""
            bundle = _refresh_and_render(show_edges, cluster_choice, color_choice, palette_choice)
            return _to_visual_outputs(bundle, view)

        def _update_visual(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            """Re-render from cached artifacts with the current control values."""
            bundle = render_plots(show_edges, cluster_choice, color_choice, palette_choice)
            return _to_visual_outputs(bundle, view)

        refresh_button.click(_refresh_and_update, inputs=visual_inputs, outputs=visual_outputs)
        view_selector.change(_update_visual, inputs=visual_inputs, outputs=visual_outputs)
        for control in [show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown]:
            control.change(_update_visual, inputs=visual_inputs, outputs=visual_outputs)

        orbit_button.click(None, inputs=None, outputs=None, js=ORBIT_JS)
        fullscreen_button.click(None, inputs=None, outputs=None, js=FULLSCREEN_JS)

        build_button.click(
            _build_corpus,
            inputs=[
                max_papers_slider,
                categories_box,
                num_clusters_slider,
                batch_size_slider,
                show_edges_checkbox,
                cluster_dropdown,
                color_basis_dropdown,
                palette_dropdown,
                view_selector,
            ],
            outputs=[build_status, *visual_outputs],
        )

        gr.Markdown("## Semantic Search")
        with gr.Row():
            query_input = gr.Textbox(
                label="Query",
                placeholder="e.g. graph neural networks for chemistry",
                lines=2,
            )
            topk_slider = gr.Slider(
                label="Top K Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )
        results_table = gr.Dataframe(
            headers=["rank", "score", "title", "paper_id", "summary"],
            label="Results",
            datatype=["number", "number", "str", "str", "str"],
            interactive=False,
        )
        submit_btn = gr.Button("Search")
        submit_btn.click(search_corpus, inputs=[query_input, topk_slider], outputs=[results_table])

        def _format_details(selection: str | None, paper_map: Dict[str, Dict[str, Any]]):
            """Render the selected paper's metadata as Markdown."""
            if not selection:
                return _PAPER_PROMPT
            details = paper_map.get(selection)
            if not details:
                return "No details available for this paper."
            authors = ", ".join(details.get("authors", [])) or "Unknown"
            lines = [
                f"### {details.get('title', '(untitled)')}",
                f"**Paper ID:** {details.get('paper_id', 'N/A')}",
                f"**Cluster:** {details.get('cluster_id', 'N/A')} | **Category:** {details.get('primary_category', 'unknown')}",
                f"**Authors:** {authors}",
                f"**Published:** {details.get('published', 'N/A')}",
                "",
                details.get("abstract", "No abstract available."),
            ]
            url = details.get("url")
            if url:
                lines.append(f"\n[View paper]({url})")
            return "\n\n".join(lines)

        paper_selector.change(_format_details, inputs=[paper_selector, paper_state], outputs=paper_detail_display)

    return demo


def main() -> None:
    """Launch the Gradio demo."""
    interface = build_interface()
    interface.launch()


if __name__ == "__main__":  # pragma: no cover - manual launch helper
    main()