import logging

import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
import plotly.graph_objects as go

# Activate tqdm with pandas
tqdm.pandas()

# Prefer the GPU when one is present, but keep the app usable on CPU-only hosts
# instead of crashing on a hard-coded "cuda".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def vector_compressor_from_config():
    """Return the unfitted reducer that projects embeddings down to 2-D.

    TODO: choose the reducer from configuration. ``PCA(n_components=2)`` is the
    alternative that was previously sketched here.
    """
    # UMAP's first positional parameter is ``n_neighbors``, so the target
    # dimensionality must be passed by keyword.
    return UMAP(n_components=2)


# Caching the dataframe since loading from external source can be time-consuming
@st.cache_data
def load_data():
    """Download the reddit-syac-urls training split and return it as a DataFrame."""
    return pd.read_csv(
        "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    )


df = load_data()


# Caching the model and tokenizer to avoid reloading
@st.cache_resource
def load_model_and_tokenizer():
    """Load the gtr-t5-base encoder (moved to ``DEVICE``) and its tokenizer.

    Returns:
        A ``(encoder, tokenizer)`` tuple.
    """
    encoder = AutoModel.from_pretrained(
        "sentence-transformers/gtr-t5-base"
    ).encoder.to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()


# Caching the vec2text corrector
@st.cache_resource
def load_corrector():
    """Load the pretrained vec2text corrector for the ``gtr-base`` embedder."""
    return vec2text.load_pretrained_corrector("gtr-base")


corrector = load_corrector()


# Caching the precomputed embeddings since they are stored locally and large
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from the local ``.npy`` file."""
    return np.load("syac-title-embeddings.npy")


embeddings = load_embeddings()


# Caching UMAP reduction as it's a heavy computation
@st.cache_resource
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    Args:
        embeddings: Array of embedding vectors, one row per item.

    Returns:
        A ``(projected, reducer)`` tuple: the reduced coordinates and the
        fitted reducer, which is kept so clicks can be mapped back with
        ``inverse_transform``.
    """
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer


vectors_2d, reducer = reduce_embeddings(embeddings)

# Add a scatter plot using Plotly
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#01a8d3"],  # Set default blue color for points
)

# Customize the layout to adapt to browser settings (light/dark mode)
fig.update_layout(
    template=None,  # Let Plotly adapt automatically based on user settings
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Display the scatterplot and capture click events
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)

# If a point is clicked, handle the embedding inversion
if selected_points:
    clicked_point = selected_points[0]
    x = clicked_point['x']
    y = clicked_point['y']

    st.text(f"Embeddings shape: {embeddings.shape}")
    st.text(f"2dvector shapes shape: {vectors_2d.shape}")
    st.text(f"Clicked point coordinates: x = {x}, y = {y}")

    # The "fOO"/"Foo"/"Bar" outputs below are progress markers kept unchanged
    # from the original script.
    st.text("fOO")
    logging.info("Foo")

    # Map the clicked 2-D location back into the embedding space.
    inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
    logging.info("Bar")
    st.text("Bar")

    # Cast to float32 before building the tensor handed to vec2text.
    inferred_embedding = inferred_embedding.astype("float32")
    st.text("Bar")

    # Put the query on the same device as the corrector's weights rather than
    # assuming CUDA. NOTE(review): relies on ``corrector.model`` being a torch
    # module, as vec2text's own API uses it - confirm on upgrade.
    corrector_device = next(corrector.model.parameters()).device
    output = vec2text.invert_embeddings(
        embeddings=torch.tensor(inferred_embedding).to(corrector_device),
        corrector=corrector,
        num_steps=20,
    )
    st.text("Bar")
    st.text(str(output))
    st.text(str(inferred_embedding))
else:
    st.text("Click on a point in the scatterplot to see its coordinates.")