marksverdhei's picture
Add sidebar and fix distances
3eacaec
raw
history blame
3.77 kB
import streamlit as st
import vec2text
import torch
from umap import UMAP
import plotly.express as px
import numpy as np
from streamlit_plotly_events import plotly_events
from resources import reduce_embeddings
import utils
import pandas as pd
from scipy.spatial import distance
def diffs(embeddings: np.ndarray, corrector):
    """Display summary information about the embedding matrix.

    Writes the shape of ``embeddings`` to the page as plain text, then renders
    the Flaticon attribution link for the array icon.

    Args:
        embeddings: Embedding array whose ``.shape`` is reported.
        corrector: Currently unused by this view.
    """
    shape_label = f"Embedding shape: {embeddings.shape}"
    st.text(shape_label)

    # Attribution required by the Flaticon licence for the array icon.
    attribution = (
        '<a href="https://www.flaticon.com/free-icons/array" title="array icons">'
        'Array icons created by Voysla - Flaticon</a>'
    )
    st.html(attribution)
def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector):
    """Render an interactive 2-D scatter of titles and invert a selected point.

    The left column shows a Plotly scatter of ``vectors_2d``. When a point is
    clicked, or coordinates are submitted through the form, the 2-D coordinate
    is mapped back to embedding space with ``reducer.inverse_transform``. It is
    then decoded to text with ``vec2text.invert_embeddings``, and the decoded
    text and inferred embedding are printed.

    The right column shows the title whose 2-D position is found by
    ``utils.find_exact_match`` for the selected coordinate (``decimals=3``).
    If an inversion was run and a title matched, it also shows the Euclidean
    and cosine distances between that title's embedding and the inferred one.

    Args:
        df: Frame with a ``title`` column. Assumed row-aligned with
            ``vectors_2d`` and ``embeddings`` (TODO confirm with caller).
        embeddings: High-dimensional embeddings, indexed by the matched row.
        vectors_2d: 2-D projections; columns 0 and 1 are plotted as x and y.
        reducer: Fitted reducer exposing ``inverse_transform`` (e.g. UMAP).
        corrector: vec2text corrector passed to ``vec2text.invert_embeddings``.
    """
    fig = px.scatter(
        x=vectors_2d[:, 0],
        y=vectors_2d[:, 1],
        opacity=0.6,
        hover_data={"Title": df["title"]},
        labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
        title="UMAP Scatter Plot of Reddit Titles",
        color_discrete_sequence=["#ff504c"],  # Single red-orange colour for all points
    )
    # Transparent backgrounds so the chart blends with the light/dark page theme.
    fig.update_layout(
        template=None,
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )

    x, y = 0.0, 0.0
    vec = np.array([x, y]).astype("float32")
    inferred_embedding = None

    # 60/40 split: chart and form on the left, selection details on the right.
    col1, col2 = st.columns([0.6, 0.4])
    with col1:
        selected_points = plotly_events(fig, click_event=True, hover_event=False)
        with st.form(key="form1_main"):
            if selected_points:
                clicked_point = selected_points[0]
                # Cast so number_input treats the value as float and honours the format.
                x = float(clicked_point['x'])
                y = float(clicked_point['y'])
            x = st.number_input("X Coordinate", value=x, format="%.10f")
            y = st.number_input("Y Coordinate", value=y, format="%.10f")
            vec = np.array([x, y]).astype("float32")
            submit_button = st.form_submit_button("Submit")

        if selected_points or submit_button:
            inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
            inferred_embedding = inferred_embedding.astype("float32")
            # Prefer the GPU as before; fall back to CPU instead of crashing
            # when CUDA is unavailable.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            output = vec2text.invert_embeddings(
                embeddings=torch.tensor(inferred_embedding).to(device),
                corrector=corrector,
                num_steps=20,
            )
            st.text(str(output))
            st.text(str(inferred_embedding))
        else:
            st.text("Click on a point in the scatterplot to see its coordinates.")

    with col2:
        closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
        # A negative index is treated as "no match" (presumably -1 from
        # find_exact_match; verify against utils).
        has_match = closest_sentence_index > -1
        selected_sentence = df.title.iloc[closest_sentence_index] if has_match else None
        selected_sentence_embedding = embeddings[closest_sentence_index] if has_match else None
        st.markdown(
            f"### Selected text:\n```console\n{selected_sentence}\n```"
        )
        if inferred_embedding is not None and has_match:
            couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()
            st.markdown("### Inferred embedding distance:")
            # scipy may return np.float32 for float32 inputs; st.number_input
            # only accepts Python int/float, so cast explicitly.
            st.number_input("Euclidean", value=float(distance.euclidean(*couple)), disabled=True)
            st.number_input("Cosine", value=float(distance.cosine(*couple)), disabled=True)