import numpy as np
import pandas as pd
import streamlit as st
import umap
import hdbscan
import plotly.express as px
import plotly.graph_objects as go
from utils.console_manager import console_manager
from embeddings.embedder import (
initialize_chroma,
initialize_embedding_model,
extract_embeddings,
)
def reduce_dimensionality(embeddings: np.ndarray, n_components: int = 3):
    """Project embedding vectors into a low-dimensional space with UMAP.

    Args:
        embeddings: Matrix of embedding vectors, one row per item.
        n_components: Target dimensionality of the projection (default 3).

    Returns:
        The projection produced by ``umap.UMAP.fit_transform``.
    """
    # Fixed hyperparameters; cosine distance suits embedding vectors and the
    # fixed random_state keeps the layout reproducible between runs.
    umap_params = dict(
        n_neighbors=15,
        min_dist=0.1,
        n_components=n_components,
        metric="cosine",
        random_state=42,
    )
    projected = umap.UMAP(**umap_params).fit_transform(embeddings)

    console_manager.print_info(
        f"UMAP dimensionality reduction done: {projected.shape}"
    )
    return projected
def cluster_embeddings(embedding_3d: np.ndarray, min_cluster_size: int = 20):
    """Cluster projected embeddings with HDBSCAN.

    Args:
        embedding_3d: Low-dimensional points to cluster, one row per item.
        min_cluster_size: Smallest group HDBSCAN will report as a cluster.

    Returns:
        The label array from ``HDBSCAN.fit_predict``: one integer per row,
        with ``-1`` marking noise points that belong to no cluster.
    """
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    labels = clusterer.fit_predict(embedding_3d)

    # -1 is HDBSCAN's noise label, not a cluster, so it must not be counted
    # as one; report the noise points separately instead.
    label_array = np.asarray(labels)
    n_clusters = len(set(label_array.tolist()) - {-1})
    n_noise = int(np.count_nonzero(label_array == -1))
    console_manager.print_info(
        f"HDBSCAN clustering done: {n_clusters} clusters found "
        f"({n_noise} noise points)"
    )
    return labels
def visualize_3d(embedding_3d: np.ndarray, metadata: list, labels: np.ndarray):
    """Render an interactive 3-D scatter plot of clustered points in Streamlit.

    Each distinct label becomes its own Plotly trace, and therefore its own
    legend entry. Points labelled ``-1`` are drawn in black as "Outliers".

    Args:
        embedding_3d: Projected points; columns 0, 1 and 2 are used as the
            x, y and z coordinates.
        metadata: Per-point metadata dicts aligned with the rows of
            ``embedding_3d``. The ``title``, ``categories`` and ``year`` keys
            are shown on hover. Missing keys fall back to ``""``, ``""`` and
            ``0`` respectively, and ``None`` entries are treated as empty dicts.
        labels: Cluster label for each point.
    """
    # Vector stores may return None for records without metadata.
    records = [m or {} for m in metadata]
    df_vis = pd.DataFrame(
        {
            "x": embedding_3d[:, 0],
            "y": embedding_3d[:, 1],
            "z": embedding_3d[:, 2],
            "title": [m.get("title", "") for m in records],
            "category": [m.get("categories", "") for m in records],
            "year": [m.get("year", 0) for m in records],
            "cluster": labels,
        }
    )

    unique_labels = sorted(set(labels))
    # Count real clusters only; -1 marks outliers, not a cluster.
    n_clusters = sum(1 for label in unique_labels if label != -1)

    # Outliers are black; clusters cycle through the Plotly qualitative palette.
    palette = px.colors.qualitative.Plotly
    color_map = {
        label: ("black" if label == -1 else palette[label % len(palette)])
        for label in unique_labels
    }

    # Plotly hover templates need HTML "<br>" for line breaks. The template is
    # identical for every trace, so it is built once outside the loop.
    hover_text = (
        "Title: %{customdata[0]}<br>"
        "Category: %{customdata[1]}<br>"
        "Year: %{customdata[2]}<br>"
        "Cluster: %{customdata[3]}"
    )
    # Column order must match the customdata[i] indices in hover_text.
    hover_columns = ["title", "category", "year", "cluster"]

    fig = go.Figure()
    for label in unique_labels:
        cluster_points = df_vis[df_vis["cluster"] == label]
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points["x"],
                y=cluster_points["y"],
                z=cluster_points["z"],
                mode="markers",
                marker=dict(size=4, color=color_map[label], opacity=0.8),
                name=f"Cluster {label}" if label != -1 else "Outliers",
                customdata=cluster_points[hover_columns].to_numpy(),
                hovertemplate=hover_text,
            )
        )

    fig.update_layout(
        title=f"Clusters: {n_clusters} ",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        legend=dict(itemsizing="constant"),
    )
    st.plotly_chart(fig, use_container_width=True)
def run_clustering_pipeline(embedding_model=None, vectordb=None):
    """Load stored embeddings, reduce them with UMAP, cluster, and plot.

    Args:
        embedding_model: Embedding model used to open the Chroma store. It is
            only needed (and only created via ``initialize_embedding_model()``)
            when ``vectordb`` is not supplied.
        vectordb: Chroma store to read embeddings from. When omitted it is
            opened with ``initialize_chroma(embedding_model)``.

    Shows a Streamlit warning and returns early when no store is available
    or the store yields no embeddings.
    """
    with console_manager.status("Running clustering pipeline..."):
        if vectordb is None:
            # The model is only used to open the store, so avoid loading it
            # when the caller already supplied one.
            if embedding_model is None:
                embedding_model = initialize_embedding_model()
            vectordb = initialize_chroma(embedding_model)
        if vectordb is None:
            st.warning("No ChromaDB found. Run embeddings generation first.")
            return

        embeddings, metadata = extract_embeddings(vectordb)
        # Fitting UMAP on an empty matrix raises. Tell the user instead of
        # surfacing a stack trace.
        if embeddings is None or len(embeddings) == 0:
            st.warning(
                "ChromaDB contains no embeddings. Run embeddings generation first."
            )
            return

        embedding_3d = reduce_dimensionality(embeddings)
        labels = cluster_embeddings(embedding_3d)
        visualize_3d(embedding_3d, metadata, labels)