Commit
·
3eacaec
1
Parent(s):
010edb7
Add sidebar and fix distances
Browse files
app.py
CHANGED
@@ -21,6 +21,17 @@ def load_embeddings():
|
|
21 |
embeddings = load_embeddings()
|
22 |
vectors_2d, reducer = reduce_embeddings(embeddings)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
tab1, tab2 = st.tabs(["plot", "diffs"])
|
25 |
|
26 |
with tab1:
|
|
|
21 |
embeddings = load_embeddings()
|
22 |
vectors_2d, reducer = reduce_embeddings(embeddings)
|
23 |
|
24 |
+
def sidebar():
|
25 |
+
st.sidebar.title("About this app")
|
26 |
+
st.sidebar.markdown(
|
27 |
+
"This app is intended to give a more intuitive and interactive understanding of sequence embeddings (e.g. sentence), \n"
|
28 |
+
"through interactive plots and operations with these embeddings, with a focus on embedding inversion.\n"
|
29 |
+
"We explore both sequence embedding inversion using the method described in [Morris et al., 2023](https://arxiv.org/abs/2310.06816), as well as"
|
30 |
+
" dimensionality rediction transforms and inverse transforms, and its effect on embedding inversion."
|
31 |
+
)
|
32 |
+
|
33 |
+
sidebar()
|
34 |
+
|
35 |
tab1, tab2 = st.tabs(["plot", "diffs"])
|
36 |
|
37 |
with tab1:
|
views.py
CHANGED
@@ -8,6 +8,7 @@ from streamlit_plotly_events import plotly_events
|
|
8 |
from resources import reduce_embeddings
|
9 |
import utils
|
10 |
import pandas as pd
|
|
|
11 |
|
12 |
|
13 |
def diffs(embeddings: np.ndarray, corrector):
|
@@ -37,7 +38,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
|
|
37 |
|
38 |
x, y = 0.0, 0.0
|
39 |
vec = np.array([x, y]).astype("float32")
|
40 |
-
|
41 |
# Add a card container to the right of the content with Streamlit columns
|
42 |
col1, col2 = st.columns([0.6, 0.4]) # Adjusting ratio to allocate space for the card container
|
43 |
|
@@ -48,8 +49,8 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
|
|
48 |
with st.form(key="form1_main"):
|
49 |
if selected_points:
|
50 |
clicked_point = selected_points[0]
|
51 |
-
|
52 |
-
|
53 |
|
54 |
x = st.number_input("X Coordinate", value=x, format="%.10f")
|
55 |
y = st.number_input("Y Coordinate", value=y, format="%.10f")
|
@@ -75,6 +76,17 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
|
|
75 |
|
76 |
with col2:
|
77 |
closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
|
|
|
|
|
|
|
78 |
st.markdown(
|
79 |
-
f"### Selected text:\n```console\n{
|
80 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from resources import reduce_embeddings
|
9 |
import utils
|
10 |
import pandas as pd
|
11 |
+
from scipy.spatial import distance
|
12 |
|
13 |
|
14 |
def diffs(embeddings: np.ndarray, corrector):
|
|
|
38 |
|
39 |
x, y = 0.0, 0.0
|
40 |
vec = np.array([x, y]).astype("float32")
|
41 |
+
inferred_embedding = None
|
42 |
# Add a card container to the right of the content with Streamlit columns
|
43 |
col1, col2 = st.columns([0.6, 0.4]) # Adjusting ratio to allocate space for the card container
|
44 |
|
|
|
49 |
with st.form(key="form1_main"):
|
50 |
if selected_points:
|
51 |
clicked_point = selected_points[0]
|
52 |
+
x = clicked_point['x']
|
53 |
+
y = clicked_point['y']
|
54 |
|
55 |
x = st.number_input("X Coordinate", value=x, format="%.10f")
|
56 |
y = st.number_input("Y Coordinate", value=y, format="%.10f")
|
|
|
76 |
|
77 |
with col2:
|
78 |
closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
|
79 |
+
selected_sentence = df.title.iloc[closest_sentence_index] if closest_sentence_index > -1 else None
|
80 |
+
selected_sentence_embedding = embeddings[closest_sentence_index] if closest_sentence_index > -1 else None
|
81 |
+
|
82 |
st.markdown(
|
83 |
+
f"### Selected text:\n```console\n{selected_sentence}\n```"
|
84 |
)
|
85 |
+
|
86 |
+
if inferred_embedding is not None and (closest_sentence_index != -1):
|
87 |
+
couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()
|
88 |
+
st.markdown(f"### Inferred embedding distance:")
|
89 |
+
st.number_input("Euclidean", value=distance.euclidean(
|
90 |
+
*couple
|
91 |
+
), disabled=True)
|
92 |
+
st.number_input("Cosine", value=distance.cosine(*couple), disabled=True)
|