import numpy as np
import pandas as pd
import streamlit as st
import umap
import hdbscan
import plotly.express as px
import plotly.graph_objects as go

from utils.console_manager import console_manager
from embeddings.embedder import (
    initialize_chroma,
    initialize_embedding_model,
    extract_embeddings,
)


def reduce_dimensionality(
    embeddings: np.ndarray, n_components: int = 3
) -> np.ndarray:
    """Project embeddings into a low-dimensional space with UMAP.

    Args:
        embeddings: Embedding vectors to reduce, one row per item.
        n_components: Number of output dimensions.

    Returns:
        The array produced by ``UMAP.fit_transform`` for ``embeddings``.
    """
    reducer = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        n_components=n_components,
        metric="cosine",
        random_state=42,  # fixed seed so repeated runs use the same RNG state
    )
    reduced = reducer.fit_transform(embeddings)
    console_manager.print_info(
        f"UMAP dimensionality reduction done: {reduced.shape}"
    )
    return reduced


def cluster_embeddings(
    embedding_3d: np.ndarray, min_cluster_size: int = 20
) -> np.ndarray:
    """Cluster reduced embeddings with HDBSCAN.

    Args:
        embedding_3d: Reduced embedding coordinates, one row per item.
        min_cluster_size: Passed to ``hdbscan.HDBSCAN`` as ``min_cluster_size``.

    Returns:
        The label array returned by ``HDBSCAN.fit_predict``. The label ``-1``
        is treated as "outlier" throughout this module.
    """
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    labels = clusterer.fit_predict(embedding_3d)
    # Do not count the outlier label (-1) as a cluster, so the logged number
    # agrees with the cluster count shown in the plot title.
    n_clusters = len(set(labels) - {-1})
    console_manager.print_info(
        f"HDBSCAN clustering done: {n_clusters} clusters found"
    )
    return labels


def visualize_3d(
    embedding_3d: np.ndarray, metadata: list, labels: np.ndarray
) -> None:
    """Render an interactive 3D scatter plot of the clustered embeddings.

    One trace is drawn per cluster label; points labelled ``-1`` are drawn in
    black under the name "Outliers". The figure is displayed through
    ``st.plotly_chart``.

    Args:
        embedding_3d: Array whose first three columns are used as x, y, z.
        metadata: One mapping per point; the ``"title"``, ``"categories"`` and
            ``"year"`` keys are read (missing keys default to ``""``, ``""``
            and ``0`` respectively).
        labels: Cluster label for each point.
    """
    df_vis = pd.DataFrame(
        {
            "x": embedding_3d[:, 0],
            "y": embedding_3d[:, 1],
            "z": embedding_3d[:, 2],
            "title": [m.get("title", "") for m in metadata],
            "category": [m.get("categories", "") for m in metadata],
            "year": [m.get("year", 0) for m in metadata],
            "cluster": labels,
        }
    )

    # Count number of clusters (excluding outliers)
    unique_labels = sorted(set(labels))
    n_clusters = sum(1 for label in unique_labels if label != -1)

    # Define color map (outliers = black); the palette is cycled when there
    # are more clusters than palette entries.
    palette = px.colors.qualitative.Plotly
    color_map = {
        label: ("black" if label == -1 else palette[int(label) % len(palette)])
        for label in unique_labels
    }

    # The template is the same for every trace, so build it once. Plotly
    # hover labels use <br> for line breaks.
    hover_text = (
        "Title: %{customdata[0]}<br>"
        "Category: %{customdata[1]}<br>"
        "Year: %{customdata[2]}<br>"
        "Cluster: %{customdata[3]}"
    )
    hover_columns = ["title", "category", "year", "cluster"]

    fig = go.Figure()
    # groupby visits the labels in sorted order and splits the frame in one
    # pass instead of re-filtering the whole frame for every label.
    for label, cluster_points in df_vis.groupby("cluster", sort=True):
        name = f"Cluster {label}" if label != -1 else "Outliers"
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points["x"],
                y=cluster_points["y"],
                z=cluster_points["z"],
                mode="markers",
                marker=dict(size=4, color=color_map[label], opacity=0.8),
                name=name,
                # (n_points, 4) array indexed by the hover template above.
                customdata=cluster_points[hover_columns].to_numpy(),
                hovertemplate=hover_text,
            )
        )

    fig.update_layout(
        title=f"Clusters: {n_clusters} ",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        legend=dict(itemsizing="constant"),
    )
    st.plotly_chart(fig, use_container_width=True)


def run_clustering_pipeline(embedding_model=None, vectordb=None) -> None:
    """Run the pipeline: extract embeddings, reduce, cluster and plot.

    Args:
        embedding_model: Embedding model to use. When ``None`` it is created
            with ``initialize_embedding_model()``.
        vectordb: Vector store to read embeddings from. When ``None`` it is
            created with ``initialize_chroma(embedding_model)``.

    Returns:
        None. A Streamlit warning is shown and the function returns early
        when no vector store is available or it yields no embeddings.
    """
    with console_manager.status("Running clustering pipeline..."):
        if embedding_model is None:
            embedding_model = initialize_embedding_model()

        if vectordb is None:
            vectordb = initialize_chroma(embedding_model)
            # Only a store we tried to create ourselves is checked here,
            # matching the original control flow.
            if vectordb is None:
                st.warning("No ChromaDB found. Run embeddings generation first.")
                return

        embeddings, metadata = extract_embeddings(vectordb)
        # NOTE(review): assumes extract_embeddings returns a sized collection
        # (list or ndarray) - confirm against embeddings.embedder.
        if len(embeddings) == 0:
            st.warning(
                "No embeddings found in ChromaDB. Run embeddings generation first."
            )
            return

        embedding_3d = reduce_dimensionality(embeddings)
        labels = cluster_embeddings(embedding_3d)
        visualize_3d(embedding_3d, metadata, labels)