Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import umap | |
| import hdbscan | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from utils.console_manager import console_manager | |
| from embeddings.embedder import ( | |
| initialize_chroma, | |
| initialize_embedding_model, | |
| extract_embeddings, | |
| ) | |
def reduce_dimensionality(
    embeddings: np.ndarray,
    n_components: int = 3,
    *,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    metric: str = "cosine",
    random_state: "int | None" = 42,
):
    """Project embeddings into a low-dimensional space with UMAP.

    Args:
        embeddings: Input vectors, one row per item.
        n_components: Number of output dimensions (3 for the 3-D plot).
        n_neighbors: UMAP neighbourhood size; larger values favour global
            structure over local detail.
        min_dist: Minimum distance between embedded points; smaller values
            pack clusters more tightly.
        metric: Distance metric UMAP uses on the input space.
        random_state: Seed for a reproducible layout. Per the UMAP docs, a
            fixed seed disables parallel execution; pass ``None`` to trade
            reproducibility for speed.

    Returns:
        The result of ``UMAP.fit_transform`` (one row per input, with
        ``n_components`` columns).
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        metric=metric,
        random_state=random_state,
    )
    reduced = reducer.fit_transform(embeddings)
    console_manager.print_info(
        f"UMAP dimensionality reduction done: {reduced.shape}"
    )
    return reduced
def cluster_embeddings(embedding_3d: np.ndarray, min_cluster_size: int = 20):
    """Cluster reduced embeddings with HDBSCAN.

    Args:
        embedding_3d: Reduced vectors, one row per item.
        min_cluster_size: Smallest group HDBSCAN will report as a cluster.

    Returns:
        The label array from ``HDBSCAN.fit_predict``. Label ``-1`` marks
        noise points (outliers) that belong to no cluster.
    """
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    labels = clusterer.fit_predict(embedding_3d)
    # -1 is HDBSCAN's noise label, not a cluster. Exclude it from the count
    # so the log agrees with the count shown in the plot title.
    n_clusters = len(set(labels) - {-1})
    n_outliers = int(np.count_nonzero(np.asarray(labels) == -1))
    console_manager.print_info(
        f"HDBSCAN clustering done: {n_clusters} clusters found "
        f"({n_outliers} outliers)"
    )
    return labels
def visualize_3d(embedding_3d: np.ndarray, metadata: list, labels: np.ndarray):
    """Render a 3-D scatter plot of clustered embeddings in the Streamlit page.

    One Plotly trace is drawn per cluster label. Noise points (label -1) are
    drawn in black and named "Outliers". Hovering a point shows its title,
    category, year and cluster.

    Args:
        embedding_3d: Array whose first three columns are used as x/y/z.
        metadata: One mapping per point. The ``"title"``, ``"categories"``
            and ``"year"`` keys are read when present; ``None`` entries are
            treated as empty.
        labels: Cluster label per point (e.g. from ``cluster_embeddings``).

    Raises:
        ValueError: If ``embedding_3d`` is not 2-D with at least 3 columns.
    """
    embedding_3d = np.asarray(embedding_3d)
    if embedding_3d.ndim != 2 or embedding_3d.shape[1] < 3:
        raise ValueError(
            "visualize_3d needs an array of shape (n_points, >=3), "
            f"got {embedding_3d.shape}"
        )

    # NOTE(review): metadata entries may be None for records stored without
    # metadata; treat those as empty instead of failing on .get().
    records = [m or {} for m in metadata]
    df_vis = pd.DataFrame(
        {
            "x": embedding_3d[:, 0],
            "y": embedding_3d[:, 1],
            "z": embedding_3d[:, 2],
            "title": [m.get("title", "") for m in records],
            "category": [m.get("categories", "") for m in records],
            "year": [m.get("year", 0) for m in records],
            "cluster": labels,
        }
    )

    # Noise (-1) is not a cluster, so it is left out of the headline count.
    n_clusters = len(set(labels) - {-1})
    palette = px.colors.qualitative.Plotly
    # The hover layout is the same for every trace. The customdata columns
    # are, in order: title, category, year, cluster.
    hover_text = (
        "Title: %{customdata[0]}<br>"
        "Category: %{customdata[1]}<br>"
        "Year: %{customdata[2]}<br>"
        "Cluster: %{customdata[3]}"
    )

    fig = go.Figure()
    # groupby partitions the frame in one pass, in sorted label order,
    # instead of re-scanning it with a boolean mask for every label.
    for label, cluster_points in df_vis.groupby("cluster", sort=True):
        is_noise = label == -1
        color = "black" if is_noise else palette[int(label) % len(palette)]
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points["x"],
                y=cluster_points["y"],
                z=cluster_points["z"],
                mode="markers",
                marker=dict(size=4, color=color, opacity=0.8),
                name="Outliers" if is_noise else f"Cluster {label}",
                customdata=np.stack(
                    [
                        cluster_points["title"],
                        cluster_points["category"],
                        cluster_points["year"],
                        cluster_points["cluster"],
                    ],
                    axis=-1,
                ),
                hovertemplate=hover_text,
            )
        )

    fig.update_layout(
        title=f"Clusters: {n_clusters} ",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        legend=dict(itemsizing="constant"),
    )
    st.plotly_chart(fig, use_container_width=True)
def run_clustering_pipeline(embedding_model=None, vectordb=None):
    """Load stored embeddings, reduce them with UMAP, cluster with HDBSCAN, and plot.

    Args:
        embedding_model: Model handed to ``initialize_chroma``. It is created
            with ``initialize_embedding_model()`` only when a vector store
            has to be opened and no model was supplied.
        vectordb: Vector store to read from. It is opened with
            ``initialize_chroma`` when omitted.

    Returns:
        None. Results are rendered into the Streamlit page. A warning is shown
        instead when no vector store, or no stored embeddings, is available.
    """
    with console_manager.status("Running clustering pipeline..."):
        if vectordb is None:
            # The model is only needed to open the store, so skip loading it
            # when the caller already supplied one.
            if embedding_model is None:
                embedding_model = initialize_embedding_model()
            vectordb = initialize_chroma(embedding_model)
        if vectordb is None:
            st.warning("No ChromaDB found. Run embeddings generation first.")
            return
        embeddings, metadata = extract_embeddings(vectordb)
        # NOTE(review): the return shape of extract_embeddings is not visible
        # here. Guard against an empty store so UMAP is never given zero rows.
        if embeddings is None or len(embeddings) == 0:
            st.warning(
                "No embeddings found in ChromaDB. Run embeddings generation first."
            )
            return
        embedding_3d = reduce_dimensionality(embeddings)
        labels = cluster_embeddings(embedding_3d)
        visualize_3d(embedding_3d, metadata, labels)