marksverdhei's picture
Automatically create the dir
2e0c6aa
raw
history blame
5.57 kB
import os
import pickle
import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
import plotly.graph_objects as go
import logging
import utils
# Pick the torch device once at import time: the GPU when CUDA is usable,
# otherwise the CPU. ``use_cpu`` records the same decision as a boolean flag.
_cuda_available = torch.cuda.is_available()
use_cpu = not _cuda_available
device = "cuda" if _cuda_available else "cpu"
# Custom file cache decorator
import os
import pickle
def file_cache(file_path):
    """Cache a function's return value in a pickle file on disk.

    The first call runs the wrapped function and pickles its result to
    ``file_path``. Every later call (in this or a future process) unpickles
    and returns that file instead of calling the function again.

    NOTE: the cache key is ``file_path`` alone. Once the file exists, the
    arguments passed to the wrapped function are ignored. Delete the file to
    force a recompute.

    Args:
        file_path: Path of the pickle file. Its parent directory is created on
            demand. A bare filename with no directory part is also accepted.

    Returns:
        A decorator that wraps a function with this file cache.
    """
    import functools

    def decorator(func):
        # wraps() keeps the wrapped function's name/qualname/__wrapped__, so
        # outer caches (e.g. st.cache_data) can tell decorated functions apart.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            dir_path = os.path.dirname(file_path)
            # dirname() of a bare filename is "", and os.makedirs("") raises,
            # so only create a directory when the path actually names one.
            if dir_path and not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
                print(f"Created directory {dir_path}")

            if os.path.exists(file_path):
                # pickle.load can execute arbitrary code: only point this cache
                # at files it wrote itself, never at untrusted data.
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            result = func(*args, **kwargs)
            # Dump to a sibling temp file and atomically move it into place,
            # so a failed or interrupted dump never leaves a truncated cache
            # file that would make every later load fail.
            tmp_path = f"{file_path}.tmp"
            try:
                with open(tmp_path, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, file_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator
@st.cache_resource
def vector_compressor_from_config():
    """Build the estimator that projects embeddings down to two dimensions.

    Currently a two-component PCA. A replacement (e.g. ``UMAP(n_components=2)``)
    must also offer ``fit_transform`` and ``inverse_transform``, because the
    app calls both on the returned object.

    Cached with ``st.cache_resource``, so every caller shares one instance.
    """
    target_dims = 2
    compressor = PCA(n_components=target_dims)
    return compressor
# Title dataset; fetched over the network, so it is memoized across reruns.
@st.cache_data
def load_data():
    """Download the ``reddit-syac-urls`` training split as a DataFrame.

    Memoized with ``st.cache_data`` so the CSV is fetched only once instead of
    on every Streamlit rerun.
    """
    source_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(source_url)
    return frame


df = load_data()
# Model weights are expensive to load, so keep one shared copy per process.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5-base encoder stack (moved to ``device``) and its tokenizer.

    Only the ``.encoder`` sub-module of the checkpoint is kept.
    NOTE(review): neither object appears to be used later in this file as
    visible here. Confirm they are still needed.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    full_model = AutoModel.from_pretrained(checkpoint)
    encoder = full_model.encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
# The vec2text corrector is a large model; load it once and share it.
@st.cache_resource
def load_corrector():
    """Load vec2text's pretrained corrector for the ``gtr-base`` embedder.

    The returned corrector is what ``vec2text.invert_embeddings`` uses to turn
    embeddings back into text.
    """
    corrector_name = "gtr-base"
    return vec2text.load_pretrained_corrector(corrector_name)


corrector = load_corrector()
# Precomputed title embeddings live in a local .npy file; memoize the load.
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from ``syac-title-embeddings.npy``.

    NOTE(review): rows are presumably aligned with ``df`` rows (the app pairs
    them by position). Confirm against the script that produced the file.
    """
    embeddings_path = "syac-title-embeddings.npy"
    return np.load(embeddings_path)


embeddings = load_embeddings()
# Fitting the 2-D projection is slow, so it is cached both on disk
# (file_cache) and in Streamlit's in-memory data cache.
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the 2-D projector on ``embeddings``.

    Returns:
        A tuple ``(points_2d, fitted_projector)``. The fitted projector is
        returned too, so callers can later map 2-D points back to embedding
        space.

    NOTE: the on-disk cache is keyed by its path only, not by ``embeddings``.
    Delete ``.cache/reducer_embeddings.pickle`` after changing the embeddings.
    """
    projector = vector_compressor_from_config()
    points_2d = projector.fit_transform(embeddings)
    return points_2d, projector


vectors_2d, reducer = reduce_embeddings(embeddings)
# Scatter plot of every title at its projected 2-D position.
# Axis labels and title are derived from the fitted reducer's class name
# (e.g. "PCA"), so they stay correct whichever projection
# vector_compressor_from_config returns. The previous hard-coded "UMAP" labels
# were wrong while the reducer is a PCA.
reduction_name = type(reducer).__name__
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={
        "x": f"{reduction_name} Dimension 1",
        "y": f"{reduction_name} Dimension 2",
    },
    title=f"{reduction_name} Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],  # one coral-red colour for all points
)

# Transparent plot/paper backgrounds so the chart blends into the page.
fig.update_layout(
    # NOTE(review): template=None is presumably meant to defer to the viewer's
    # light/dark theme. Verify how plotly_events renders it.
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)
# Default query point, used until the user clicks the plot or types values.
x, y = 0.0, 0.0
vec = np.array([x, y]).astype("float32")

# Two columns: plot and query form on the left (3/4 of the width), an info
# card on the right (1/4).
col1, col2 = st.columns([3, 1])

with col1:
    # Renders the figure and returns the clicked points (empty until a click).
    selected_points = plotly_events(fig, click_event=True, hover_event=False)

    with st.form(key="form1_main"):
        if selected_points:
            # Pre-fill the inputs with the clicked point's coordinates.
            clicked_point = selected_points[0]
            x = clicked_point["x"]
            y = clicked_point["y"]
        # float() keeps st.number_input a float widget matching "%.10f", even
        # if a coordinate arrives as an int.
        x = st.number_input("X Coordinate", value=float(x), format="%.10f")
        y = st.number_input("Y Coordinate", value=float(y), format="%.10f")
        vec = np.array([x, y]).astype("float32")
        submit_button = st.form_submit_button("Submit")

    if selected_points or submit_button:
        # Map the 2-D point back into embedding space, then let vec2text
        # reconstruct text for that embedding.
        inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
        inferred_embedding = inferred_embedding.astype("float32")
        output = vec2text.invert_embeddings(
            # Use the device chosen at startup; a hard-coded .cuda() raises
            # on machines without a CUDA GPU.
            embeddings=torch.tensor(inferred_embedding).to(device),
            corrector=corrector,
            num_steps=20,
        )
        st.text(str(output))
        st.text(str(inferred_embedding))
    else:
        st.text("Click on a point in the scatterplot to see its coordinates.")

with col2:
    # Look up the title whose projected position matches the query point.
    # NOTE(review): utils.find_exact_match presumably returns -1 when nothing
    # matches at 3 decimals. Confirm against utils.
    closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
    # NOTE(review): looks like leftover debug output. Consider removing.
    st.write(f"{vectors_2d.dtype} {vec.dtype}")
    if closest_sentence_index > -1:
        st.write(df["title"].iloc[closest_sentence_index])

    # Card content
    st.markdown("## Card Container")
    st.write("This is an additional card container to the right of the main content.")
    st.write("You can use this space to show additional information, actions, or insights.")
    st.button("Card Button")