import functools
import logging
import os
import pickle
import tempfile

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
import vec2text
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from umap import UMAP

import utils

use_cpu = not torch.cuda.is_available()
device = "cpu" if use_cpu else "cuda"


# Custom file cache decorator
def file_cache(file_path):
    """Persist a function's return value to ``file_path`` using pickle.

    The first call computes the result and writes it to disk. Every later call,
    including calls in later processes, loads the pickled file instead of calling
    the function.

    NOTE: the cache key is the file path alone; the wrapped function's arguments
    are ignored. Delete the file whenever the inputs change, or stale results will
    be returned.
    """

    def decorator(func):
        # wraps() keeps the real name/qualname/source visible to outer decorators
        # such as st.cache_data, which key their own cache on the function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            dir_path = os.path.dirname(file_path)
            # dirname() is "" for a bare filename, and os.makedirs("") would raise.
            if dir_path and not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
                print(f"Created directory {dir_path}")

            if os.path.exists(file_path):
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    # pickle.load can execute arbitrary code: only ever point this
                    # at cache files written by this app, never at external input.
                    return pickle.load(f)

            result = func(*args, **kwargs)
            # Dump to a temp file in the same directory, then rename it atomically,
            # so an interrupted write never leaves a truncated cache that would
            # fail to unpickle on the next run.
            fd, tmp_path = tempfile.mkstemp(dir=dir_path or ".", suffix=".tmp")
            try:
                with os.fdopen(fd, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, file_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator


@st.cache_resource
def vector_compressor_from_config():
    """Return an unfitted 2-component dimensionality reducer (currently PCA)."""
    # Alternative: UMAP(n_components=2)
    return PCA(n_components=2)


# Cached because the CSV is fetched from a remote URL on every call otherwise.
@st.cache_data
def load_data():
    """Load the reddit-syac training split (must contain a "title" column)."""
    return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")


df = load_data()


# Cached so the model and tokenizer are loaded once per server process.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5-base encoder (moved to ``device``) and its tokenizer."""
    encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()


# Cached so the vec2text corrector is loaded once per server process.
@st.cache_resource
def load_corrector():
    """Load the pretrained vec2text corrector for GTR-base embeddings."""
    return vec2text.load_pretrained_corrector("gtr-base")


corrector = load_corrector()


# Cached so the local .npy file is read once rather than on every rerun.
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from ``syac-title-embeddings.npy``."""
    return np.load("syac-title-embeddings.npy")


embeddings = load_embeddings()


# In-process cache (st.cache_data) on top of an on-disk cache (file_cache) for
# the 2-D reduction. The disk cache ignores its argument: delete
# .cache/reducer_embeddings.pickle if the embeddings or the reducer change.
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit a fresh 2-D reducer on ``embeddings``; return ``(points_2d, reducer)``."""
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer


vectors_2d, reducer = reduce_embeddings(embeddings)

# Label the axes after the reducer actually in use (e.g. "PCA" or "UMAP").
reduction_name = type(reducer).__name__

# Add a scatter plot using Plotly
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={"x": f"{reduction_name} Dimension 1", "y": f"{reduction_name} Dimension 2"},
    title=f"{reduction_name} Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],  # single fixed point colour (coral red)
)

# Transparent backgrounds so the chart blends into the surrounding page theme.
fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Default query point until the user clicks or submits coordinates. Use the same
# dtype as vectors_2d: a float32 cast of a float64 coordinate can round to a
# different 3rd decimal and defeat the exact match in col2.
x, y = 0.0, 0.0
vec = np.array([x, y], dtype=vectors_2d.dtype)

# Two columns: wide main content on the left, a narrow card container on the right.
col1, col2 = st.columns([3, 1])

with col1:
    # Main content: scatterplot, coordinate form and inversion output.
    selected_points = plotly_events(
        fig,
        click_event=True,
        hover_event=False,
        # override_height=600,
        override_width="100%",
    )

    with st.form(key="form1_main"):
        if selected_points:
            clicked_point = selected_points[0]
            # float() keeps number_input in float mode, matching format="%.10f".
            x = float(clicked_point["x"])
            y = float(clicked_point["y"])
        x = st.number_input("X Coordinate", value=x, format="%.10f")
        y = st.number_input("Y Coordinate", value=y, format="%.10f")
        vec = np.array([x, y], dtype=vectors_2d.dtype)
        submit_button = st.form_submit_button("Submit")

    if selected_points or submit_button:
        # Map the 2-D point back to embedding space, then invert it to text.
        inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
        inferred_embedding = inferred_embedding.astype("float32")
        output = vec2text.invert_embeddings(
            # .to(device) rather than .cuda(), so CPU-only hosts do not crash.
            embeddings=torch.tensor(inferred_embedding).to(device),
            corrector=corrector,
            num_steps=20,
        )
        st.text(str(output))
        st.text(str(inferred_embedding))
    else:
        st.text("Click on a point in the scatterplot to see its coordinates.")

with col2:
    # utils.find_exact_match appears to return -1 when no 2-D point matches `vec`
    # at 3 decimals (inferred from the check below; confirm in utils).
    closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
    st.write(f"{vectors_2d.dtype} {vec.dtype}")  # debug: dtypes being compared
    if closest_sentence_index > -1:
        st.write(df["title"].iloc[closest_sentence_index])

    # Card content
    st.markdown("## Card Container")
    st.write("This is an additional card container to the right of the main content.")
    st.write("You can use this space to show additional information, actions, or insights.")
    st.button("Card Button")