marksverdhei commited on
Commit
3eacaec
·
1 Parent(s): 010edb7

Add sidebar and fix distances

Browse files
Files changed (2) hide show
  1. app.py +11 -0
  2. views.py +16 -4
app.py CHANGED
@@ -21,6 +21,17 @@ def load_embeddings():
21
  embeddings = load_embeddings()
22
  vectors_2d, reducer = reduce_embeddings(embeddings)
23
 
 
 
 
 
 
 
 
 
 
 
 
24
  tab1, tab2 = st.tabs(["plot", "diffs"])
25
 
26
  with tab1:
 
21
  embeddings = load_embeddings()
22
  vectors_2d, reducer = reduce_embeddings(embeddings)
23
 
24
+ def sidebar():
25
+ st.sidebar.title("About this app")
26
+ st.sidebar.markdown(
27
+ "This app is intended to give a more intuitive and interactive understanding of sequence embeddings (e.g. sentence), \n"
28
+ "through interactive plots and operations with these embeddings, with a focus on embedding inversion.\n"
29
+ "We explore both sequence embedding inversion using the method described in [Morris et al., 2023](https://arxiv.org/abs/2310.06816), as well as"
30
+ " dimensionality rediction transforms and inverse transforms, and its effect on embedding inversion."
31
+ )
32
+
33
+ sidebar()
34
+
35
  tab1, tab2 = st.tabs(["plot", "diffs"])
36
 
37
  with tab1:
views.py CHANGED
@@ -8,6 +8,7 @@ from streamlit_plotly_events import plotly_events
8
  from resources import reduce_embeddings
9
  import utils
10
  import pandas as pd
 
11
 
12
 
13
  def diffs(embeddings: np.ndarray, corrector):
@@ -37,7 +38,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
37
 
38
  x, y = 0.0, 0.0
39
  vec = np.array([x, y]).astype("float32")
40
-
41
  # Add a card container to the right of the content with Streamlit columns
42
  col1, col2 = st.columns([0.6, 0.4]) # Adjusting ratio to allocate space for the card container
43
 
@@ -48,8 +49,8 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
48
  with st.form(key="form1_main"):
49
  if selected_points:
50
  clicked_point = selected_points[0]
51
- x_coord = x = clicked_point['x']
52
- y_coord = y = clicked_point['y']
53
 
54
  x = st.number_input("X Coordinate", value=x, format="%.10f")
55
  y = st.number_input("Y Coordinate", value=y, format="%.10f")
@@ -75,6 +76,17 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
75
 
76
  with col2:
77
  closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
 
 
 
78
  st.markdown(
79
- f"### Selected text:\n```console\n{df.title.iloc[closest_sentence_index] if closest_sentence_index > -1 else '[no selected text]'}\n```"
80
  )
 
 
 
 
 
 
 
 
 
8
  from resources import reduce_embeddings
9
  import utils
10
  import pandas as pd
11
+ from scipy.spatial import distance
12
 
13
 
14
  def diffs(embeddings: np.ndarray, corrector):
 
38
 
39
  x, y = 0.0, 0.0
40
  vec = np.array([x, y]).astype("float32")
41
+ inferred_embedding = None
42
  # Add a card container to the right of the content with Streamlit columns
43
  col1, col2 = st.columns([0.6, 0.4]) # Adjusting ratio to allocate space for the card container
44
 
 
49
  with st.form(key="form1_main"):
50
  if selected_points:
51
  clicked_point = selected_points[0]
52
+ x = clicked_point['x']
53
+ y = clicked_point['y']
54
 
55
  x = st.number_input("X Coordinate", value=x, format="%.10f")
56
  y = st.number_input("Y Coordinate", value=y, format="%.10f")
 
76
 
77
  with col2:
78
  closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
79
+ selected_sentence = df.title.iloc[closest_sentence_index] if closest_sentence_index > -1 else None
80
+ selected_sentence_embedding = embeddings[closest_sentence_index] if closest_sentence_index > -1 else None
81
+
82
  st.markdown(
83
+ f"### Selected text:\n```console\n{selected_sentence}\n```"
84
  )
85
+
86
+ if inferred_embedding is not None and (closest_sentence_index != -1):
87
+ couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()
88
+ st.markdown(f"### Inferred embedding distance:")
89
+ st.number_input("Euclidean", value=distance.euclidean(
90
+ *couple
91
+ ), disabled=True)
92
+ st.number_input("Cosine", value=distance.cosine(*couple), disabled=True)