import functools
import os
import pickle
import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
import plotly.graph_objects as go
import logging

# Register tqdm's pandas integration (progress_apply and friends).
tqdm.pandas()


def file_cache(file_path):
    """Persist a function's return value to *file_path* with pickle.

    On the first call the wrapped function runs and its result is pickled to
    *file_path*; any later call (in this or a future process) loads and
    returns that file instead of recomputing.

    NOTE: the cache is keyed only on *file_path* -- the call arguments are
    ignored, so delete the file whenever the inputs change.
    Security: the file is read with ``pickle.load``; only point this at files
    written by this app itself, never at untrusted data.
    """

    def decorator(func):
        # wraps() keeps the real name/source/__wrapped__ visible, so outer
        # decorators (e.g. st.cache_data) and tooling see the original function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if os.path.exists(file_path):
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            result = func(*args, **kwargs)

            # Create e.g. ".cache/" on first use; open(..., "wb") would
            # otherwise raise FileNotFoundError.
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Dump to a temp file and atomically rename it into place, so an
            # interrupted write never leaves a truncated pickle behind.
            tmp_path = f"{file_path}.tmp"
            with open(tmp_path, "wb") as f:
                pickle.dump(result, f)
            os.replace(tmp_path, file_path)
            print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator


@st.cache_resource
def vector_compressor_from_config():
    """Return a UMAP reducer configured for 2 output components."""
    return UMAP(n_components=2)


# Cached because the CSV is downloaded from the Hugging Face hub.
@st.cache_data
def load_data():
    """Load the reddit-syac training split as a DataFrame."""
    return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")


df = load_data()


# Cached so the model and tokenizer are loaded once, not on every rerun.
@st.cache_resource
def load_model_and_tokenizer():
    """Return the ``.encoder`` of gtr-t5-base (moved to CUDA) and its tokenizer.

    Requires a CUDA device.
    """
    encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()


# Cached: the vec2text corrector is a large model.
@st.cache_resource
def load_corrector():
    """Load vec2text's pretrained corrector for "gtr-base" embeddings."""
    return vec2text.load_pretrained_corrector("gtr-base")


corrector = load_corrector()


# Cached: the precomputed embeddings file is local but large.
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from ``syac-title-embeddings.npy``.

    NOTE(review): rows are presumably aligned with ``df`` rows (the scatter
    plot pairs them by position) -- confirm.
    """
    return np.load("syac-title-embeddings.npy")


embeddings = load_embeddings()
# Reduce the embeddings to 2-D once: st.cache_data memoizes within the server
# process, and file_cache persists the result on disk across restarts.
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit a 2-component reducer on *embeddings*.

    Returns ``(vectors_2d, reducer)``. The fitted reducer is returned too so
    that clicked 2-D coordinates can later be mapped back into embedding
    space with ``inverse_transform``.

    NOTE: the disk cache ignores the argument; delete
    ``.cache/reducer_embeddings.pickle`` if the embeddings change.
    """
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer


vectors_2d, reducer = reduce_embeddings(embeddings)

# Scatter plot of the 2-D projection; hovering a point shows its post title.
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],  # red-orange marker color
)

# Transparent backgrounds so the chart blends into the page's light or dark theme.
fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Default coordinates used when no point has been clicked yet.
x, y = 0.0, 0.0

# Render the scatter plot and capture click events.
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)

# Sidebar with usage information.
st.sidebar.header("Scatter Plot Info")
st.sidebar.write("""
This scatter plot visualizes the UMAP dimensionality reduction of Reddit post titles.
Each point represents a post, with similar titles being positioned closer together.
""")
st.sidebar.write("Use the form below to select coordinates or click on a point in the scatter plot.")
st.sidebar.markdown("---")
st.sidebar.header("How to Use")
st.sidebar.write("""
1. **Click a point** in the scatter plot to see the corresponding coordinates.
2. **Adjust the coordinates** using the form inputs if needed.
3. **Submit** to see the reconstructed text output.
""")

# Coordinate form: pre-filled from the clicked point, editable before submitting.
with st.form(key="form1"):
    if selected_points:
        clicked_point = selected_points[0]
        # float() keeps the widget value type consistent with the "%.10f"
        # format even if the click event reports a whole-number coordinate.
        x = float(clicked_point['x'])
        y = float(clicked_point['y'])
    x = st.number_input("X Coordinate", value=x, format="%.10f")
    y = st.number_input("Y Coordinate", value=y, format="%.10f")
    submit_button = st.form_submit_button("Submit")

if selected_points or submit_button:
    # Map the 2-D point back into embedding space with the fitted reducer,
    # then let vec2text invert that embedding into text.
    inferred_embedding = reducer.inverse_transform(np.array([[x, y]])).astype("float32")
    output = vec2text.invert_embeddings(
        embeddings=torch.tensor(inferred_embedding).cuda(),
        corrector=corrector,
        num_steps=20,
    )
    st.text(str(output))
    st.text(str(inferred_embedding))
else:
    st.text("Click on a point in the scatterplot to see its coordinates.")