|
import os |
|
import pickle |
|
import streamlit as st |
|
import pandas as pd |
|
import vec2text |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
from umap import UMAP |
|
from tqdm import tqdm |
|
import plotly.express as px |
|
import numpy as np |
|
from sklearn.decomposition import PCA |
|
from streamlit_plotly_events import plotly_events |
|
import plotly.graph_objects as go |
|
import logging |
|
import utils |
|
|
|
# Run on the GPU when PyTorch can see a CUDA device, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_cpu = device == "cpu"
|
|
|
|
|
import os |
|
import pickle |
|
|
|
def file_cache(file_path):
    """Build a decorator that persists a function's result to a pickle file.

    The first call of the decorated function runs it and pickles the return
    value to ``file_path``. Every later call loads and returns the pickled
    value without running the function again. Call arguments are not part of
    the cache key, so all calls share the single stored result.

    Args:
        file_path: Path of the pickle file. A missing parent directory is
            created on demand; a bare file name (no directory part) is
            written to the current working directory.

    Returns:
        A decorator that wraps a function with the file-backed cache.
    """
    import functools

    def decorator(func):
        # Keep the wrapped function's name, docstring and signature visible
        # to callers and to tools that introspect it.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            dir_path = os.path.dirname(file_path)
            # dirname() is "" for a bare file name, and os.makedirs("") raises
            # FileNotFoundError, so only create a directory when one is named.
            if dir_path and not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
                print(f"Created directory {dir_path}")

            if os.path.exists(file_path):
                # NOTE: unpickling can execute arbitrary code; the cache file
                # must come from a trusted location.
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            # Cache miss: compute the value and store it for later calls.
            result = func(*args, **kwargs)
            with open(file_path, "wb") as f:
                pickle.dump(result, f)
                print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator
|
|
|
@st.cache_resource
def vector_compressor_from_config():
    """Create the dimensionality reducer used for the 2-D projection.

    Returns:
        An unfitted ``PCA`` instance configured for two output components.
        ``st.cache_resource`` keeps the created object across reruns.
    """
    compressor = PCA(n_components=2)
    return compressor
|
|
|
|
|
|
|
@st.cache_data
def load_data():
    """Read the reddit-syac-urls training CSV from the Hugging Face Hub.

    Returns:
        The CSV contents as a ``pandas.DataFrame``; ``st.cache_data`` caches
        the result between reruns.
    """
    dataset_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(dataset_url)


# Loaded at import time; the "title" column is read further down the script.
df = load_data()
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5 encoder and its tokenizer.

    Returns:
        A ``(encoder, tokenizer)`` tuple. The encoder is the ``.encoder``
        sub-module of ``sentence-transformers/gtr-t5-base`` moved to the
        module-level ``device``.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    full_model = AutoModel.from_pretrained(checkpoint)
    encoder = full_model.encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
|
|
|
|
|
@st.cache_resource
def load_corrector():
    """Load the pretrained vec2text corrector registered as ``"gtr-base"``.

    Returns:
        The object returned by ``vec2text.load_pretrained_corrector``;
        ``st.cache_resource`` keeps it across reruns.
    """
    pretrained_corrector = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained_corrector


corrector = load_corrector()
|
|
|
|
|
@st.cache_data
def load_embeddings():
    """Load the array stored in ``syac-title-embeddings.npy``.

    Returns:
        The NumPy array read from that file, resolved relative to the current
        working directory. Presumably one embedding row per title in the
        dataset; confirm against how the file was produced.
    """
    embeddings_file = "syac-title-embeddings.npy"
    return np.load(embeddings_file)


embeddings = load_embeddings()
|
|
|
|
|
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    The result is also persisted by ``file_cache``, so once the pickle file
    exists it is returned as-is regardless of the ``embeddings`` passed in.

    Args:
        embeddings: Array of vectors to fit the reducer on.

    Returns:
        A ``(projected, reducer)`` tuple: the transformed vectors and the
        fitted reducer that produced them.
    """
    fitted_reducer = vector_compressor_from_config()
    projected = fitted_reducer.fit_transform(embeddings)
    return projected, fitted_reducer


vectors_2d, reducer = reduce_embeddings(embeddings)
|
|
|
|
|
# Label the plot after the reducer that actually produced the projection
# instead of hard-coding "UMAP": the configured compressor is PCA, and a
# reducer restored from the pickle cache may be of either type.
reduction_name = type(reducer).__name__

# Scatter plot of the 2-D projection; hovering a point shows its title.
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={
        "x": f"{reduction_name} Dimension 1",
        "y": f"{reduction_name} Dimension 2",
    },
    title=f"{reduction_name} Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],
)

# Drop the default template and use fully transparent backgrounds (alpha 0).
fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)
|
|
|
# Default query point, used until the user clicks the plot or edits the form.
x, y = 0.0, 0.0
vec = np.array([x, y]).astype("float32")

# Two columns in a 3:1 width ratio: plot and inversion output on the left,
# side card on the right.
col1, col2 = st.columns([3, 1])

with col1:
    # Render the figure and collect click events; the result is indexed below
    # as a sequence of point dicts with "x" and "y" keys.
    selected_points = plotly_events(fig, click_event=True, hover_event=False)

    with st.form(key="form1_main"):
        if selected_points:
            # Seed the coordinate inputs from the first clicked point.
            clicked_point = selected_points[0]
            x_coord = x = clicked_point["x"]
            y_coord = y = clicked_point["y"]

        x = st.number_input("X Coordinate", value=x, format="%.10f")
        y = st.number_input("Y Coordinate", value=y, format="%.10f")
        vec = np.array([x, y]).astype("float32")

        submit_button = st.form_submit_button("Submit")

    if selected_points or submit_button:
        # Map the 2-D point back into the embedding space with the fitted
        # reducer. The input is the same for every reducer type, so no
        # per-type branching is needed.
        inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
        inferred_embedding = inferred_embedding.astype("float32")

        output = vec2text.invert_embeddings(
            # Use the module-level device instead of a hard-coded .cuda(),
            # which raises on hosts without CUDA.
            embeddings=torch.tensor(inferred_embedding).to(device),
            corrector=corrector,
            num_steps=20,
        )

        st.text(str(output))
        st.text(str(inferred_embedding))
    else:
        st.text("Click on a point in the scatterplot to see its coordinates.")

with col2:
    # Look up the dataset row whose projected coordinates match the query
    # point. Presumably find_exact_match returns a negative index when nothing
    # matches (the check below relies on that); confirm against utils.
    closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
    st.write(f"{vectors_2d.dtype} {vec.dtype}")
    if closest_sentence_index > -1:
        st.write(df["title"].iloc[closest_sentence_index])

    st.markdown("## Card Container")
    st.write("This is an additional card container to the right of the main content.")
    st.write("You can use this space to show additional information, actions, or insights.")
    st.button("Card Button")