Nexa_Labs / app.py
Allanatrix's picture
Update app.py
7b88b54 verified
""" Interactive Gradio UI for exploring the local SPECTER2 corpus."""
from __future__ import annotations
from collections import Counter, defaultdict
import subprocess
import sys
import time
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Sequence, Set, Tuple
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
# JavaScript click handler wired to the "Fullscreen" button. It looks up the
# element with id "embedding-plot", prefers its `.js-plotly-plot` child (falling
# back to the container itself), and toggles browser fullscreen on it, using the
# webkit-prefixed API when the standard one is not available.
FULLSCREEN_JS = """
() => {
const container = document.getElementById('embedding-plot');
if (!container) return;
const plot = container.querySelector('.js-plotly-plot') || container;
if (!document.fullscreenElement) {
if (plot.requestFullscreen) {
plot.requestFullscreen();
} else if (plot.webkitRequestFullscreen) {
plot.webkitRequestFullscreen();
}
} else {
if (document.exitFullscreen) {
document.exitFullscreen();
} else if (document.webkitExitFullscreen) {
document.webkitExitFullscreen();
}
}
}
"""
# JavaScript click handler wired to the "Toggle Orbit" button. The first click
# starts a 50 ms interval that advances the Plotly 3D camera eye by 2 degrees per
# tick on a circle of radius 1.6 at height z = 0.9; the next click clears it.
# The interval handle is kept on `window._plotOrbitIntervals`, and the timer also
# stops itself when the `.js-plotly-plot` element is no longer inside the
# container. Does nothing when the container holds no Plotly element.
ORBIT_JS = """
() => {
const container = document.getElementById('embedding-plot');
if (!container) return;
const plot = container.querySelector('.js-plotly-plot');
if (!plot) return;
window._plotOrbitIntervals = window._plotOrbitIntervals || {};
const key = 'embedding-plot';
if (window._plotOrbitIntervals[key]) {
clearInterval(window._plotOrbitIntervals[key]);
delete window._plotOrbitIntervals[key];
return;
}
let angle = 0;
const radius = 1.6;
window._plotOrbitIntervals[key] = setInterval(() => {
const updatedPlot = container.querySelector('.js-plotly-plot');
if (!updatedPlot) {
clearInterval(window._plotOrbitIntervals[key]);
delete window._plotOrbitIntervals[key];
return;
}
angle = (angle + 2) % 360;
const rad = angle * Math.PI / 180;
Plotly.relayout(updatedPlot, {
'scene.camera.eye': {
x: radius * Math.cos(rad),
y: radius * Math.sin(rad),
z: 0.9,
},
});
}, 50);
}
"""
# Generic JavaScript handler taking (componentId, action). "orbit" toggles the
# same camera rotation as ORBIT_JS (tracked per component id on
# `window._orbitIntervals`); "fullscreen" toggles fullscreen on the closest
# matching ancestor div, or on the element itself.
# NOTE(review): nothing in the visible part of this file references CUSTOM_JS,
# and the `svelte-1ipelgc` class it selects looks build-specific -- confirm it is
# still needed and still matches before relying on it.
CUSTOM_JS = """
function(componentId, action) {
const el = document.getElementById(componentId);
if (!el) return;
if (action === "orbit") {
if (window._orbitIntervals === undefined) {
window._orbitIntervals = {};
}
if (window._orbitIntervals[componentId]) {
clearInterval(window._orbitIntervals[componentId]);
delete window._orbitIntervals[componentId];
} else {
let angle = 0;
const interval = setInterval(() => {
angle = (angle + 2) % 360;
const rad = angle * Math.PI / 180;
const r = 1.6;
const layout = {
scene: {camera: {eye: {x: r * Math.cos(rad), y: r * Math.sin(rad), z: 0.9}}}
};
Plotly.relayout(el, layout);
}, 50);
window._orbitIntervals[componentId] = interval;
}
} else if (action === "fullscreen") {
const container = el.closest("div.svelte-1ipelgc");
const target = container || el;
if (!document.fullscreenElement) {
target.requestFullscreen?.();
} else {
document.exitFullscreen?.();
}
}
}
"""
from pipeline.embed import Specter2Embedder
from pipeline.storage import load_embeddings, load_canonical_corpus
# Location of the corpus artifacts: an "index" directory that sits next to the
# directory containing this file (i.e. under this file's grandparent directory).
INDEX_DIR = Path(__file__).resolve().parents[1] / "index"
CORPUS_PATH = INDEX_DIR / "corpus.json"
EMBEDDINGS_PATH = INDEX_DIR / "embeddings.npy"

# Defaults used for the initial render and as fallbacks for unknown UI choices.
DEFAULT_COLOR_BASIS = "Cluster"
DEFAULT_PALETTE = "Plotly"

# UI label -> internal key selecting what the point colours are based on.
COLOR_BASIS_OPTIONS: Dict[str, str] = {
    "Cluster": "cluster",
    "Primary Category": "primary_category",
}

# UI label -> Plotly Express qualitative colour sequence.
PALETTE_OPTIONS: Dict[str, List[str]] = {
    "Plotly": px.colors.qualitative.Plotly,
    "Bold": px.colors.qualitative.Bold,
    "Vivid": px.colors.qualitative.Vivid,
    "Pastel": px.colors.qualitative.Pastel,
    "Safe": px.colors.qualitative.Safe,
}

# Upper bound on how many graph edges the 2D/3D plot builders will draw.
MAX_EDGE_RENDER = 2000
def _float_rgba_to_plotly(rgba: Tuple[float, float, float, float], alpha: float | None = None) -> str:
r, g, b, a = rgba
if alpha is not None:
a = alpha
return f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a:.2f})"
def _build_cluster_color_map(cluster_ids: Sequence[int], palette: Sequence[Tuple[float, float, float, float]]) -> Dict[int, Tuple[float, float, float, float]]:
unique_ids = sorted(set(int(cid) for cid in cluster_ids))
color_map: Dict[int, Tuple[float, float, float, float]] = {}
for idx, cluster_id in enumerate(unique_ids):
color_map[cluster_id] = palette[idx % len(palette)]
return color_map
def _build_cluster_overview(papers: Sequence[Dict[str, Any]]) -> pd.DataFrame:
clusters: Dict[int, Dict[str, Any]] = defaultdict(lambda: {
"cluster_id": None,
"size": 0,
"categories": Counter(),
"sample_titles": [],
})
for paper in papers:
cluster_id = int(paper.get("cluster_id", -1))
entry = clusters[cluster_id]
entry["cluster_id"] = cluster_id
entry["size"] += 1
category = paper.get("primary_category") or "unknown"
entry["categories"][category] += 1
if len(entry["sample_titles"]) < 3:
entry["sample_titles"].append(paper.get("title", "(untitled)"))
entry["major_category"] = category.split(".")[0] if "." in category else category
overview_rows = []
for data in clusters.values():
dominant_category = data["categories"].most_common(1)[0][0] if data["categories"] else "unknown"
overview_rows.append(
{
"cluster_id": data["cluster_id"],
"size": data["size"],
"major_category": data.get("major_category", "unknown"),
"dominant_category": dominant_category,
"sample_titles": " | ".join(data["sample_titles"]),
}
)
overview_rows.sort(key=lambda row: row["cluster_id"])
return pd.DataFrame(overview_rows)
def _build_cluster_hierarchy_json(papers: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
hierarchy: Dict[str, Dict[str, List[Dict[str, Any]]]] = defaultdict(lambda: defaultdict(list))
for paper in papers:
cluster_id = int(paper.get("cluster_id", -1))
category = paper.get("primary_category") or "unknown"
major = category.split(".")[0] if "." in category else category
hierarchy[major][category].append(
{
"cluster_id": cluster_id,
"paper_id": paper.get("paper_id"),
"title": paper.get("title"),
}
)
major_payload = []
for major, subcategories in hierarchy.items():
sub_payload = []
for category, clusters in sorted(subcategories.items()):
clusters_sorted = sorted(clusters, key=lambda c: c["cluster_id"])
sub_payload.append({
"category": category,
"clusters": clusters_sorted,
"cluster_ids": sorted({entry["cluster_id"] for entry in clusters_sorted}),
})
major_payload.append({
"major": major,
"subcategories": sub_payload,
})
major_payload.sort(key=lambda entry: entry["major"])
return {"major_categories": major_payload}
def _filter_edges(edges: Sequence[Dict[str, Any]], selected: Set[int]) -> List[Dict[str, Any]]:
"""Return only edges whose endpoints are in the selected set."""
return [
edge
for edge in edges
if int(edge.get("source", -1)) in selected and int(edge.get("target", -1)) in selected
]
def _normalise_embeddings(vectors: np.ndarray) -> np.ndarray:
"""Return L2-normalised embeddings, guarding against zero vectors."""
if vectors.size == 0:
return vectors
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms[norms == 0] = 1.0
return vectors / norms
@lru_cache(maxsize=1)
def load_resources() -> Tuple[
    Dict[str, Any],
    List[Dict[str, Any]],
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    List[Dict[str, Any]],
    List[Dict[str, Any]],
]:
    """Read the corpus document and embedding matrix from disk and validate them.

    The result is memoised; call ``load_resources.cache_clear()`` to force a
    reload.

    Returns:
        An 8-tuple of: the corpus document, the papers ordered by
        ``embedding_idx``, the embeddings as loaded, their L2-normalised float32
        copy, the per-paper 2D UMAP coordinates, the per-paper 3D UMAP
        coordinates, the graph edges, and the cluster metadata. Missing UMAP
        coordinates default to zeros; missing edges/clusters default to empty.

    Raises:
        FileNotFoundError: If the corpus or embeddings file does not exist.
        ValueError: If the number of embeddings differs from the number of
            papers, or the papers' ``embedding_idx`` values are not exactly
            0..n-1 once sorted.
    """
    if not (CORPUS_PATH.exists() and EMBEDDINGS_PATH.exists()):
        raise FileNotFoundError(
            "Corpus artifacts not found. Run `python -m pipeline.build_corpus` first."
        )

    document = load_canonical_corpus(CORPUS_PATH)
    raw_papers = document.get("papers", [])
    vectors = load_embeddings(EMBEDDINGS_PATH)
    if vectors.shape[0] != len(raw_papers):
        raise ValueError(
            "Mismatch between embeddings and canonical corpus entries. Rebuild the corpus to continue."
        )

    ordered = sorted(raw_papers, key=lambda record: record.get("embedding_idx", 0))
    for position, record in enumerate(ordered):
        # Row i of the embedding matrix must belong to the paper at position i.
        if record.get("embedding_idx") != position:
            raise ValueError("Embedding indices in canonical corpus do not match their positions; rebuild the corpus.")

    def _stack(field: str, dims: int) -> np.ndarray:
        origin = [0.0] * dims
        return np.array([record.get(field, origin) for record in ordered], dtype=np.float32)

    coords_2d = _stack("umap_2d", 2)
    coords_3d = _stack("umap_3d", 3)
    unit_vectors = _normalise_embeddings(vectors.astype(np.float32))
    edges = document.get("graph", {}).get("edges", [])
    clusters = document.get("clusters", [])
    return (
        document,
        ordered,
        vectors,
        unit_vectors,
        coords_2d,
        coords_3d,
        edges,
        clusters,
    )
@lru_cache(maxsize=1)
def get_embedder(device: str | None = None) -> Specter2Embedder:
    """Return a memoised ``Specter2Embedder`` for ``device``.

    The cache holds a single entry keyed on ``device``, so asking for a
    different device constructs a new embedder and evicts the previous one.
    """
    embedder = Specter2Embedder(device=device)
    return embedder
@lru_cache(maxsize=1)
def _cluster_options() -> List[str]:
    """Build the choices for the cluster filter dropdown.

    Returns "All" followed by every distinct cluster id found in the loaded
    papers (missing ids count as 0), in ascending numeric order and rendered as
    strings. The list is memoised after the first call.
    """
    papers = load_resources()[1]
    distinct_ids = {int(paper.get("cluster_id", 0)) for paper in papers}
    options = ["All"]
    options.extend(str(cluster_id) for cluster_id in sorted(distinct_ids))
    return options
def _resolve_color_basis(choice: str) -> str:
    """Translate a "Color by" UI label into its internal key.

    Labels that are not in ``COLOR_BASIS_OPTIONS`` resolve to the key of
    ``DEFAULT_COLOR_BASIS``.
    """
    fallback_key = COLOR_BASIS_OPTIONS[DEFAULT_COLOR_BASIS]
    return COLOR_BASIS_OPTIONS.get(choice, fallback_key)
def _resolve_palette(choice: str) -> List[Tuple[float, float, float, float]]:
    """Resolve a palette label into a list of float RGBA tuples.

    Unknown labels fall back to ``DEFAULT_PALETTE``. Each entry is first handed
    to matplotlib's ``to_rgba``; entries it rejects are parsed here when they are
    CSS-style ``rgb(r, g, b)`` or ``rgba(r, g, b, a)`` strings.

    Args:
        choice: Palette label as shown in the UI.

    Returns:
        A non-empty list of ``(r, g, b, a)`` tuples with colour channels scaled
        to 0-1. A single default blue is returned for an empty palette.

    Raises:
        ValueError: If an entry is neither accepted by ``to_rgba`` nor a
            parsable ``rgb``/``rgba`` string with three or four components.
    """
    palette = PALETTE_OPTIONS.get(choice, PALETTE_OPTIONS[DEFAULT_PALETTE])
    resolved: List[Tuple[float, float, float, float]] = []
    for color in palette:
        try:
            resolved.append(to_rgba(color))
        except ValueError:
            # Only CSS-style strings get the manual fallback; anything else
            # re-raises the original matplotlib error.
            if not (isinstance(color, str) and color.startswith("rgb")):
                raise
            parts = color[color.find("(") + 1 : color.find(")")].split(",")
            components = [float(part.strip()) for part in parts]
            if len(components) == 3:
                opacity = 1.0
            elif len(components) == 4:
                # The CSS alpha is already in 0-1 and must not be divided by 255.
                opacity = components.pop()
            else:
                raise
            red, green, blue = (component / 255.0 for component in components)
            resolved.append((red, green, blue, opacity))
    if not resolved:
        resolved.append((0.2, 0.4, 0.8, 1.0))
    return resolved
def _hover_text_for_papers(papers: Sequence[Dict[str, Any]]) -> np.ndarray:
"""Generate hover text for each paper."""
hover = []
for paper in papers:
hover.append(
"<br>".join(
[
paper.get("title", "(untitled)"),
f"ID: {paper.get('paper_id', 'n/a')}",
f"Cluster: {paper.get('cluster_id', 'n/a')}",
f"Category: {paper.get('primary_category', 'unknown')}",
f"Authors: {', '.join(paper.get('authors', [])[:3])}" + ("…" if len(paper.get('authors', [])) > 3 else ""),
]
)
)
return np.array(hover)
def _group_points(labels: np.ndarray, palette: Sequence[str]) -> List[Tuple[str, np.ndarray, str]]:
"""Return masking information for each unique label."""
unique = sorted(np.unique(labels))
groups: List[Tuple[str, np.ndarray, str]] = []
for idx, label in enumerate(unique):
mask = labels == label
color = palette[idx % len(palette)]
groups.append((label, mask, color))
return groups
def _build_2d_plot(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    point_color_map: Dict[str, Tuple[float, float, float, float]],
    cluster_color_map: Dict[int, Tuple[float, float, float, float]],
) -> plt.Figure:
    """Draw the 2D embedding map as a matplotlib figure.

    Args:
        coords: Point coordinates; columns 0 and 1 are plotted.
        original_indices: Corpus index of every row in ``coords``; used to map
            edge endpoints onto rows.
        labels: Per-point label; one scatter series and legend entry per label.
        hover_text: Accepted for signature parity with the 3D builder; not used.
        edges: Edge dicts with ``source``/``target`` corpus indices. Only the
            first ``MAX_EDGE_RENDER`` are considered, and edges with an endpoint
            outside ``original_indices`` are skipped.
        clusters: Cluster metadata; entries with a ``centroid_2d`` get a labelled
            diamond marker.
        cluster_ids_subset: Cluster id of every row in ``coords``; an edge is
            coloured by the cluster of its source point.
        point_color_map: Label -> RGBA colour for the points.
        cluster_color_map: Cluster id -> RGBA colour for edges and centroids.

    Returns:
        The rendered figure. With no points, a titled figure with hidden axes.
    """
    # Build the Figure directly instead of via pyplot: pyplot keeps a global
    # reference to every figure it creates, and nothing here ever closes them,
    # so each re-render in the long-running app would leak a figure.
    fig = Figure(figsize=(6.8, 6.2), dpi=120)
    ax = fig.add_subplot(111)
    if coords.shape[0] < 1:
        ax.set_title("Corpus Embedding Map (2D)")
        ax.axis("off")
        return fig

    for label in sorted(set(labels)):
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label)
        if rgba is None:
            rgba = (0.25, 0.5, 0.85, 1.0)
        ax.scatter(
            coords[mask, 0],
            coords[mask, 1],
            s=26,
            c=[rgba],
            alpha=0.9,
            linewidths=0.3,
            edgecolors="#f5f5f5",
            label=label,
        )

    if edges:
        index_map = {orig_idx: pos for pos, orig_idx in enumerate(original_indices)}
        # Collect segments per cluster so each cluster needs one LineCollection.
        segment_map: Dict[int, List[List[Tuple[float, float]]]] = defaultdict(list)
        for edge in edges[:MAX_EDGE_RENDER]:
            source = int(edge["source"])
            target = int(edge["target"])
            if source not in index_map or target not in index_map:
                continue
            src_idx = index_map[source]
            tgt_idx = index_map[target]
            cluster_id = int(cluster_ids_subset[src_idx]) if src_idx < len(cluster_ids_subset) else -1
            segment_map[cluster_id].append(
                [
                    (coords[src_idx, 0], coords[src_idx, 1]),
                    (coords[tgt_idx, 0], coords[tgt_idx, 1]),
                ]
            )
        for cluster_id, segments in segment_map.items():
            base = cluster_color_map.get(cluster_id, (0.55, 0.55, 0.55, 1.0))
            lc = LineCollection(
                segments,
                colors=[(base[0], base[1], base[2], 0.22)],
                linewidths=0.55,
            )
            ax.add_collection(lc)

    for cluster in clusters:
        centroid = cluster.get("centroid_2d")
        if not centroid or len(centroid) < 2:
            continue
        cluster_id = int(cluster.get("cluster_id", -1))
        rgba = cluster_color_map.get(cluster_id, (0.1, 0.1, 0.1, 1.0))
        ax.scatter(
            centroid[0],
            centroid[1],
            s=150,
            marker="D",
            c=[rgba],
            edgecolors="#222222",
            linewidths=0.6,
            alpha=0.95,
        )
        # Reuse the id resolved above so a cluster without the key cannot raise
        # KeyError after its marker has already been drawn.
        ax.text(
            centroid[0],
            centroid[1],
            f"C{cluster_id}",
            fontsize=9,
            ha="center",
            va="bottom",
            color="#222222",
        )

    ax.set_title("Corpus Embedding Map (2D)")
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.tick_params(labelsize=8)
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(alpha=0.15, linestyle="--", linewidth=0.45)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.16), ncol=4, fontsize=7, frameon=False)
    fig.tight_layout()
    return fig
def _build_3d_figure(
    coords: np.ndarray,
    original_indices: Sequence[int],
    labels: np.ndarray,
    hover_text: np.ndarray,
    edges: Sequence[Dict[str, Any]],
    clusters: Sequence[Dict[str, Any]],
    cluster_ids_subset: np.ndarray,
    embedding_indices_subset: np.ndarray,
    point_color_map: Dict[str, Tuple[float, float, float, float]],
    cluster_color_map: Dict[int, Tuple[float, float, float, float]],
) -> go.Figure:
    """Generate a 3D Plotly figure for the embedding map.

    Args:
        coords: Point coordinates; columns 0-2 are plotted.
        original_indices: Corpus index of every row in ``coords``; used to map
            edge endpoints onto rows.
        labels: Per-point label; one marker trace per distinct label.
        hover_text: Per-point hover label.
        edges: Edge dicts with ``source``/``target`` corpus indices. Only the
            first ``MAX_EDGE_RENDER`` are considered, and edges with an endpoint
            outside ``original_indices`` are skipped.
        clusters: Cluster metadata; entries carrying a ``centroid_3d`` are drawn
            as labelled diamonds, the rest are skipped.
        cluster_ids_subset: Cluster id of every row in ``coords``; an edge is
            coloured by the cluster of its source point.
        embedding_indices_subset: Per-point embedding index, attached to the
            marker traces as ``customdata``.
        point_color_map: Label -> RGBA colour for the points.
        cluster_color_map: Cluster id -> RGBA colour for edges and centroids.

    Returns:
        The assembled figure; with no points, an empty figure with only a title.
    """
    fig = go.Figure()
    if coords.shape[0] < 1:
        fig.update_layout(title="Corpus Embedding Map (3D)")
        return fig

    for label in sorted(set(labels)):
        mask = labels == label
        if not np.any(mask):
            continue
        rgba = point_color_map.get(label)
        rgba_str = _float_rgba_to_plotly(rgba) if rgba else "rgba(52, 120, 198, 0.9)"
        fig.add_trace(
            go.Scatter3d(
                x=coords[mask, 0],
                y=coords[mask, 1],
                z=coords[mask, 2],
                mode="markers",
                marker=dict(color=rgba_str, size=4.8, opacity=0.9, line=dict(width=0.6, color="#101010"), symbol="circle"),
                name=str(label),
                hovertext=hover_text[mask],
                hoverinfo="text",
                customdata=embedding_indices_subset[mask][:, None],
            )
        )

    if edges:
        index_map = {orig_idx: pos for pos, orig_idx in enumerate(original_indices)}
        # One line trace per cluster; a None entry after each pair breaks the
        # polyline so consecutive edges are not joined together.
        edge_segments: Dict[int, Dict[str, List[float]]] = defaultdict(lambda: {"x": [], "y": [], "z": []})
        for edge in edges[:MAX_EDGE_RENDER]:
            source = int(edge["source"])
            target = int(edge["target"])
            if source not in index_map or target not in index_map:
                continue
            src_idx = index_map[source]
            tgt_idx = index_map[target]
            cluster_id = int(cluster_ids_subset[src_idx]) if src_idx < len(cluster_ids_subset) else -1
            seg = edge_segments[cluster_id]
            seg["x"].extend([coords[src_idx, 0], coords[tgt_idx, 0], None])
            seg["y"].extend([coords[src_idx, 1], coords[tgt_idx, 1], None])
            seg["z"].extend([coords[src_idx, 2], coords[tgt_idx, 2], None])
        for cluster_id, seg in edge_segments.items():
            cluster_color = cluster_color_map.get(cluster_id, (0.4, 0.4, 0.4, 1.0))
            fig.add_trace(
                go.Scatter3d(
                    x=seg["x"],
                    y=seg["y"],
                    z=seg["z"],
                    mode="lines",
                    line=dict(color=_float_rgba_to_plotly(cluster_color, alpha=0.18), width=1.3),
                    hoverinfo="none",
                    name=f"Cluster {cluster_id} edges",
                    showlegend=False,
                )
            )

    # Skip clusters without a usable 3D centroid (as the 2D builder does for
    # centroid_2d) instead of failing on the missing key.
    with_centroid = [
        c for c in clusters if c.get("centroid_3d") and len(c["centroid_3d"]) >= 3
    ]
    if with_centroid:
        centroid_ids = [int(c.get("cluster_id", -1)) for c in with_centroid]
        fig.add_trace(
            go.Scatter3d(
                x=[c["centroid_3d"][0] for c in with_centroid],
                y=[c["centroid_3d"][1] for c in with_centroid],
                z=[c["centroid_3d"][2] for c in with_centroid],
                mode="markers+text",
                marker=dict(
                    symbol="diamond",
                    size=12,
                    color=[
                        _float_rgba_to_plotly(cluster_color_map.get(cid, (0.3, 0.3, 0.3, 1.0)))
                        for cid in centroid_ids
                    ],
                    line=dict(width=1.5, color="#222222"),
                ),
                text=[f"C{cid}" for cid in centroid_ids],
                textposition="top center",
                hovertext=[
                    f"Cluster {cid}<br>Size: {c.get('size', 'n/a')}"
                    for cid, c in zip(centroid_ids, with_centroid)
                ],
                hoverinfo="text",
                name="Centroids",
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Corpus Embedding Map (3D)",
        scene=dict(
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            zaxis_title="UMAP 3",
            xaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            yaxis=dict(showgrid=True, zeroline=False, showbackground=False),
            zaxis=dict(showgrid=True, zeroline=False, showbackground=False),
        ),
        legend=dict(orientation="h", y=-0.1),
        margin=dict(l=10, r=10, t=60, b=10),
        template="plotly_white",
        scene_camera=dict(eye=dict(x=1.6, y=1.6, z=0.9)),
        hovermode="closest",
    )
    return fig
def render_plots(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    """Render the 2D and 3D figures with the requested options.

    Args:
        show_edges: Whether graph edges between the selected papers are drawn.
        cluster_choice: "All" (or an empty/None selection) for every paper,
            otherwise a cluster id given as a string.
        color_choice: "Color by" UI label; unknown labels use the default basis.
        palette_choice: Palette UI label. It only affects colouring by primary
            category; cluster colours always come from ``DEFAULT_PALETTE``.

    Returns:
        A 7-tuple of: the 2D figure, the 3D figure, the cluster overview table,
        the category hierarchy, a lookup of paper details keyed by embedding
        index (as a string), the ``(label, value)`` options for the paper
        dropdown, and render metrics (cluster/point/edge counts and the build
        time of each figure in milliseconds). When the selection matches no
        paper, empty figures/containers and zeroed metrics are returned.

    Raises:
        ValueError: If ``cluster_choice`` is neither "All"/empty nor an integer.
    """
    (
        _corpus,
        papers,
        _embeddings,
        _normalised,
        umap_2d,
        umap_3d,
        graph_edges,
        cluster_metadata,
    ) = load_resources()
    cluster_ids = np.array([paper.get("cluster_id", 0) for paper in papers], dtype=int)

    # A cleared dropdown yields None/""; treat that like "All" instead of
    # crashing in int().
    if cluster_choice in (None, "", "All"):
        mask = np.ones(len(papers), dtype=bool)
        clusters_for_plot = cluster_metadata
    else:
        cluster_value = int(cluster_choice)
        mask = cluster_ids == cluster_value
        clusters_for_plot = [c for c in cluster_metadata if int(c.get("cluster_id", -1)) == cluster_value]

    selected_indices = np.where(mask)[0]
    if selected_indices.size == 0:
        metrics_empty = {
            "clusters": 0,
            "points": 0,
            "edges": 0,
            "render_ms": {"2d": 0.0, "3d": 0.0},
        }
        return go.Figure(), go.Figure(), pd.DataFrame(), {}, {}, [], metrics_empty

    filtered_papers = [papers[idx] for idx in selected_indices]
    coords_2d = umap_2d[selected_indices]
    coords_3d = umap_3d[selected_indices]
    cluster_ids_subset = cluster_ids[selected_indices]
    embedding_indices_subset = np.array(
        [
            int(paper.get("embedding_idx", fallback_idx))
            for paper, fallback_idx in zip(filtered_papers, selected_indices)
        ]
    )
    selected_set = {int(idx) for idx in selected_indices.tolist()}
    filtered_edges = _filter_edges(graph_edges, selected_set) if show_edges else []

    color_basis_key = _resolve_color_basis(color_choice)
    palette = _resolve_palette(palette_choice)
    cluster_palette = _resolve_palette(DEFAULT_PALETTE)
    cluster_color_map = _build_cluster_color_map(cluster_ids, cluster_palette)
    if color_basis_key == "cluster":
        # Derive the labels from the integer ids computed above. Building them
        # from the raw records produced the label "unknown" for papers without a
        # cluster_id, which then failed in int() while building the colour map.
        label_values = cluster_ids_subset.astype(str)
        point_color_map = {
            str(cluster_id): cluster_color_map.get(int(cluster_id), (0.2, 0.4, 0.8, 1.0))
            for cluster_id in np.unique(cluster_ids_subset)
        }
    else:
        label_values = np.array([paper.get("primary_category") or "unknown" for paper in filtered_papers])
        unique_labels = sorted(set(label_values))
        point_color_map = {label: palette[idx % len(palette)] for idx, label in enumerate(unique_labels)}

    hover_text = _hover_text_for_papers(filtered_papers)

    start_2d = time.perf_counter()
    fig2d = _build_2d_plot(
        coords_2d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        point_color_map,
        cluster_color_map,
    )
    render_2d_ms = (time.perf_counter() - start_2d) * 1000.0

    start_3d = time.perf_counter()
    fig3d = _build_3d_figure(
        coords_3d,
        selected_indices,
        label_values,
        hover_text,
        filtered_edges,
        clusters_for_plot,
        cluster_ids_subset,
        embedding_indices_subset,
        point_color_map,
        cluster_color_map,
    )
    render_3d_ms = (time.perf_counter() - start_3d) * 1000.0

    overview_df = _build_cluster_overview(filtered_papers)
    hierarchy_json = _build_cluster_hierarchy_json(filtered_papers)
    paper_lookup = {
        str(int(embedding_indices_subset[i])): {
            "title": paper.get("title", "(untitled)"),
            "paper_id": paper.get("paper_id"),
            "cluster_id": paper.get("cluster_id"),
            "primary_category": paper.get("primary_category"),
            "authors": paper.get("authors", []),
            "abstract": paper.get("abstract", ""),
            "published": paper.get("published"),
            # Prefer the URL nested under "meta" when that field is a dict.
            "url": paper.get("meta", {}).get("url") if isinstance(paper.get("meta"), dict) else paper.get("url"),
        }
        for i, paper in enumerate(filtered_papers)
    }
    paper_options = [
        (f"{details['title']} (C{details['cluster_id']})", str(idx))
        for idx, details in paper_lookup.items()
    ]
    metrics = {
        "clusters": int(len(set(cluster_ids_subset))),
        "points": int(len(selected_indices)),
        "edges": int(len(filtered_edges)),
        "render_ms": {
            "2d": round(render_2d_ms, 2),
            "3d": round(render_3d_ms, 2),
        },
    }
    return fig2d, fig3d, overview_df, hierarchy_json, paper_lookup, paper_options, metrics
def refresh_embedding_plot() -> None:
    """Clear caches to force data reload and plot regeneration on next render.

    Drops the memoised corpus resources, the cached default plots and the
    cluster dropdown options, all of which are derived from the files on disk.
    """
    load_resources.cache_clear()
    get_embedding_plots.cache_clear()
    # The dropdown options are computed from load_resources() and would
    # otherwise keep describing the previous corpus.
    _cluster_options.cache_clear()
@lru_cache(maxsize=1)
def get_embedding_plots() -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    """Render once with the default settings and memoise the result.

    Defaults: edges shown, every cluster included, default colour basis and
    default palette. The return value is exactly what ``render_plots`` returns.
    """
    default_settings = {
        "show_edges": True,
        "cluster_choice": "All",
        "color_choice": DEFAULT_COLOR_BASIS,
        "palette_choice": DEFAULT_PALETTE,
    }
    return render_plots(**default_settings)
def _format_results(indices: np.ndarray, scores: np.ndarray, papers: Sequence[Dict[str, Any]]) -> List[List[Any]]:
"""Convert ranked results into display-friendly rows."""
formatted: List[List[Any]] = []
for rank, (idx, score) in enumerate(zip(indices, scores), start=1):
paper = papers[int(idx)]
abstract = str(paper.get("abstract", "")).strip()
summary = abstract[:220] + ("…" if len(abstract) > 220 else "")
formatted.append(
[
rank,
round(float(score), 4),
paper.get("title", "(untitled)"),
paper.get("paper_id", "N/A"),
summary,
]
)
return formatted
def search_corpus(query: str, top_k: int) -> List[List[Any]]:
    """Perform a cosine-similarity search over the local corpus.

    Args:
        query: Free-text query; ``None`` or whitespace-only yields no results.
        top_k: Number of hits wanted; clamped to between 1 and the corpus size.

    Returns:
        Rows as produced by ``_format_results``, best match first. Empty when
        the query is blank or its embedding has a zero or non-finite norm.
    """
    query = (query or "").strip()
    if not query:
        return []
    _, papers, _, normalised, _, _, _, _ = load_resources()
    embedder = get_embedder(None)
    query_vector = embedder.embed_query(query)
    query_length = np.linalg.norm(query_vector)
    # Dividing by a zero/NaN norm would turn every score into NaN and make the
    # ranking meaningless, so report "no results" instead.
    if not np.isfinite(query_length) or query_length == 0:
        return []
    query_norm = query_vector / query_length
    scores = normalised @ query_norm
    top_k = max(1, min(int(top_k), len(papers)))
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    ranked_scores = scores[ranked_indices]
    return _format_results(ranked_indices, ranked_scores, papers)
def _refresh_and_render(
    show_edges: bool,
    cluster_choice: str,
    color_choice: str,
    palette_choice: str,
) -> Tuple[Figure, go.Figure, pd.DataFrame, Dict[str, Any], Dict[str, Dict[str, Any]], List[Tuple[str, str]], Dict[str, Any]]:
    """Invalidate the caches, then render with the given options.

    Returns the same 7-tuple as ``render_plots``.
    """
    refresh_embedding_plot()
    rendered = render_plots(
        show_edges=show_edges,
        cluster_choice=cluster_choice,
        color_choice=color_choice,
        palette_choice=palette_choice,
    )
    return rendered
def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks interface.

    The page contains, in order: a collapsed "Corpus Builder" form that runs
    ``python -m pipeline.build_corpus`` in a subprocess, the plot controls, the
    embedding plot (2D or 3D) with orbit/fullscreen buttons, the cluster
    overview and hierarchy, a paper detail viewer, render metrics, a refresh
    button and the semantic search form. The initial plot and tables come from
    ``get_embedding_plots()``.
    """
    select_prompt = "Select a paper from the dropdown."

    def _tail(text: str, limit: int = 800) -> str:
        """Keep only the last ``limit`` characters of subprocess output."""
        return "..." + text[-limit:] if len(text) > limit else text

    def _pack_outputs(rendered, view: str):
        """Map a ``render_plots`` result onto the shared plot-related outputs."""
        fig2d, fig3d, overview, hierarchy, lookup, options, metrics = rendered
        return (
            fig2d if view == "2D" else fig3d,
            overview,
            hierarchy,
            lookup,
            # New choices invalidate whatever paper was selected before.
            gr.update(choices=options, value=None),
            select_prompt,
            metrics,
        )

    with gr.Blocks(title="NexaSci Mini Corpus Search") as demo:
        gr.Markdown(
            "\n# NexaSci Corpus Explorer\n"
            "Enter a short description or paper title to retrieve the closest papers "
            "from the locally built corpus.\n"
        )
        with gr.Accordion("Corpus Builder", open=False):
            categories_box = gr.Textbox(
                label="Categories",
                value="cs.AI cs.LG cs.CL stat.ML",
                placeholder="Space-separated arXiv categories",
            )
            max_papers_slider = gr.Slider(label="Max papers", minimum=100, maximum=1000, step=50, value=500)
            num_clusters_slider = gr.Slider(label="KMeans clusters", minimum=5, maximum=60, step=5, value=30)
            batch_size_slider = gr.Slider(label="Embedding batch size", minimum=4, maximum=64, step=4, value=16)
            build_button = gr.Button("Build Corpus", variant="primary")
            build_status = gr.Markdown()
        with gr.Row():
            show_edges_checkbox = gr.Checkbox(label="Show graph edges", value=True)
            cluster_dropdown = gr.Dropdown(
                label="Cluster filter",
                value="All",
                choices=_cluster_options(),
            )
            color_basis_dropdown = gr.Radio(
                label="Color by",
                choices=list(COLOR_BASIS_OPTIONS.keys()),
                value=DEFAULT_COLOR_BASIS,
            )
            palette_dropdown = gr.Dropdown(
                label="Color palette",
                choices=list(PALETTE_OPTIONS.keys()),
                value=DEFAULT_PALETTE,
            )
        (
            initial_2d,
            _initial_3d,
            initial_overview,
            initial_hierarchy,
            initial_lookup,
            initial_options,
            initial_metrics,
        ) = get_embedding_plots()
        view_selector = gr.Radio(
            label="Visualization",
            choices=["2D", "3D"],
            value="2D",
            interactive=True,
        )
        embedding_plot = gr.Plot(label="Embedding", value=initial_2d, elem_id="embedding-plot")
        controls_row = gr.Row()
        with controls_row:
            orbit_button = gr.Button("Toggle Orbit", variant="secondary")
            fullscreen_button = gr.Button("Fullscreen", variant="secondary")
        cluster_overview_table = gr.Dataframe(
            value=initial_overview,
            label="Cluster Overview",
            interactive=False,
        )
        cluster_hierarchy_json = gr.JSON(value=initial_hierarchy, label="Cluster Hierarchy")
        # Lookup of paper details keyed by embedding index; read by the detail viewer.
        paper_state = gr.State(initial_lookup)
        gr.Markdown("## Paper Details")
        paper_selector = gr.Dropdown(
            choices=initial_options,
            label="Select Paper",
            value=None,
        )
        paper_detail_display = gr.Markdown(select_prompt)
        metrics_json = gr.JSON(value=initial_metrics, label="Render Metrics")

        # Inputs/outputs shared by every handler that re-renders the plot.
        plot_inputs = [show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown, view_selector]
        plot_outputs = [
            embedding_plot,
            cluster_overview_table,
            cluster_hierarchy_json,
            paper_state,
            paper_selector,
            paper_detail_display,
            metrics_json,
        ]

        def _build_corpus(max_papers: int, categories: str, num_clusters: int, batch_size: int,
                          show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            """Rebuild the corpus in a subprocess, then re-render from the new files."""
            cat_list = (categories or "").split()
            if not cat_list:
                cat_list = ["cs.AI"]
            # Argument list (no shell), so the category text is never interpreted
            # by a shell.
            cmd = [
                sys.executable,
                "-m",
                "pipeline.build_corpus",
                "--categories",
                *cat_list,
                "--max-papers",
                str(int(max_papers)),
                "--num-clusters",
                str(int(num_clusters)),
                "--batch-size",
                str(int(batch_size)),
            ]
            start = time.perf_counter()
            result = subprocess.run(cmd, capture_output=True, text=True)
            elapsed = time.perf_counter() - start
            if result.returncode != 0:
                logs = _tail((result.stderr or result.stdout or "").strip())
                status = f"❌ Corpus build failed in {elapsed:.1f}s\n```\n{logs}\n```"
            else:
                logs = _tail((result.stdout or "Success").strip())
                status = f"✅ Corpus rebuilt with {int(max_papers)} papers in {elapsed:.1f}s\n```\n{logs}\n```"
            rendered = _refresh_and_render(show_edges, cluster_choice, color_choice, palette_choice)
            return (status, *_pack_outputs(rendered, view))

        refresh_button = gr.Button("Refresh Data")

        def _refresh_and_update(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            """Drop the caches and re-render with the current control values."""
            rendered = _refresh_and_render(show_edges, cluster_choice, color_choice, palette_choice)
            if view == "3D":
                rendered[1].update_layout(margin=dict(l=10, r=10, t=60, b=10))
            return _pack_outputs(rendered, view)

        refresh_button.click(_refresh_and_update, inputs=plot_inputs, outputs=plot_outputs)

        def _update_visual(show_edges: bool, cluster_choice: str, color_choice: str, palette_choice: str, view: str):
            """Re-render from the cached resources with the current control values."""
            rendered = render_plots(show_edges, cluster_choice, color_choice, palette_choice)
            return _pack_outputs(rendered, view)

        view_selector.change(_update_visual, inputs=plot_inputs, outputs=plot_outputs)
        for control in [show_edges_checkbox, cluster_dropdown, color_basis_dropdown, palette_dropdown]:
            control.change(_update_visual, inputs=plot_inputs, outputs=plot_outputs)

        # Pure client-side actions: no Python callback, only JavaScript.
        orbit_button.click(None, inputs=None, outputs=None, js=ORBIT_JS)
        fullscreen_button.click(None, inputs=None, outputs=None, js=FULLSCREEN_JS)

        build_button.click(
            _build_corpus,
            inputs=[
                max_papers_slider,
                categories_box,
                num_clusters_slider,
                batch_size_slider,
                show_edges_checkbox,
                cluster_dropdown,
                color_basis_dropdown,
                palette_dropdown,
                view_selector,
            ],
            outputs=[build_status, *plot_outputs],
        )

        gr.Markdown("## Semantic Search")
        with gr.Row():
            query_input = gr.Textbox(
                label="Query",
                placeholder="e.g. graph neural networks for chemistry",
                lines=2,
            )
            topk_slider = gr.Slider(
                label="Top K Results",
                minimum=1,
                maximum=20,
                step=1,
                value=5,
            )
        results_table = gr.Dataframe(
            headers=["rank", "score", "title", "paper_id", "summary"],
            label="Results",
            datatype=["number", "number", "str", "str", "str"],
            interactive=False,
        )
        submit_btn = gr.Button("Search")
        submit_btn.click(search_corpus, inputs=[query_input, topk_slider], outputs=[results_table])

        def _format_details(selection: str | None, paper_map: Dict[str, Dict[str, Any]]):
            """Render the selected paper's stored details as Markdown."""
            if not selection:
                return select_prompt
            details = paper_map.get(selection)
            if not details:
                return "No details available for this paper."
            # `or` fallbacks: the lookup always contains these keys, but their
            # values may be empty or None.
            authors = ", ".join(str(author) for author in (details.get("authors") or [])) or "Unknown"
            lines = [
                f"### {details.get('title', '(untitled)')}",
                f"**Paper ID:** {details.get('paper_id', 'N/A')}",
                f"**Cluster:** {details.get('cluster_id', 'N/A')} | **Category:** {details.get('primary_category', 'unknown')}",
                f"**Authors:** {authors}",
                f"**Published:** {details.get('published', 'N/A')}",
                "",
                str(details.get("abstract") or "No abstract available."),
            ]
            url = details.get("url")
            if url:
                lines.append(f"\n[View paper]({url})")
            return "\n\n".join(lines)

        paper_selector.change(_format_details, inputs=[paper_selector, paper_state], outputs=paper_detail_display)
    return demo
def main() -> None:
    """Build the interface and call ``launch()`` on it."""
    build_interface().launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":  # pragma: no cover - manual launch helper
    main()