marksverdhei committed on
Commit
78022ff
·
1 Parent(s): 867cf07

Apply ruff

Browse files
Files changed (2) hide show
  1. app.py +8 -132
  2. utils.py +28 -1
app.py CHANGED
@@ -1,74 +1,17 @@
1
- import os
2
- import pickle
3
  import streamlit as st
4
- import pandas as pd
5
- import vec2text
6
  import torch
7
- from transformers import AutoModel, AutoTokenizer
8
- from umap import UMAP
9
- import plotly.express as px
10
  import numpy as np
11
- from sklearn.decomposition import PCA
12
- from streamlit_plotly_events import plotly_events
13
- import utils
14
 
15
  use_cpu = not torch.cuda.is_available()
16
  device = "cpu" if use_cpu else "cuda"
17
 
18
- # Custom file cache decorator
19
-
20
- def file_cache(file_path):
21
- def decorator(func):
22
- def wrapper(*args, **kwargs):
23
- # Ensure the directory exists
24
- dir_path = os.path.dirname(file_path)
25
- if not os.path.exists(dir_path):
26
- os.makedirs(dir_path, exist_ok=True)
27
- print(f"Created directory {dir_path}")
28
-
29
- # Check if the file already exists
30
- if os.path.exists(file_path):
31
- # Load from cache
32
- with open(file_path, "rb") as f:
33
- print(f"Loading cached data from {file_path}")
34
- return pickle.load(f)
35
- else:
36
- # Compute and save to cache
37
- result = func(*args, **kwargs)
38
- with open(file_path, "wb") as f:
39
- pickle.dump(result, f)
40
- print(f"Saving new cache to {file_path}")
41
- return result
42
- return wrapper
43
- return decorator
44
-
45
- @st.cache_resource
46
- def vector_compressor_from_config():
47
- # Return UMAP with 2 components for dimensionality reduction
48
- # return UMAP(n_components=2)
49
- return PCA(n_components=2)
50
-
51
-
52
- # Caching the dataframe since loading from an external source can be time-consuming
53
- @st.cache_data
54
- def load_data():
55
- return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
56
-
57
  df = load_data()
58
 
59
- # Caching the model and tokenizer to avoid reloading
60
- @st.cache_resource
61
- def load_model_and_tokenizer():
62
- encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(device)
63
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
64
- return encoder, tokenizer
65
-
66
  encoder, tokenizer = load_model_and_tokenizer()
67
 
68
- # Caching the vec2text corrector
69
- @st.cache_resource
70
- def load_corrector():
71
- return vec2text.load_pretrained_corrector("gtr-base")
72
 
73
  corrector = load_corrector()
74
 
@@ -79,78 +22,11 @@ def load_embeddings():
79
 
80
  embeddings = load_embeddings()
81
 
82
- # Custom cache the UMAP reduction using file_cache decorator
83
- @st.cache_data
84
- @file_cache(".cache/reducer_embeddings.pickle")
85
- def reduce_embeddings(embeddings):
86
- reducer = vector_compressor_from_config()
87
- return reducer.fit_transform(embeddings), reducer
88
-
89
- vectors_2d, reducer = reduce_embeddings(embeddings)
90
-
91
- # Add a scatter plot using Plotly
92
- fig = px.scatter(
93
- x=vectors_2d[:, 0],
94
- y=vectors_2d[:, 1],
95
- opacity=0.6,
96
- hover_data={"Title": df["title"]},
97
- labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
98
- title="UMAP Scatter Plot of Reddit Titles",
99
- color_discrete_sequence=["#ff504c"] # Set default blue color for points
100
- )
101
-
102
- # Customize the layout to adapt to browser settings (light/dark mode)
103
- fig.update_layout(
104
- template=None, # Let Plotly adapt automatically based on user settings
105
- plot_bgcolor="rgba(0, 0, 0, 0)",
106
- paper_bgcolor="rgba(0, 0, 0, 0)"
107
- )
108
-
109
- x, y = 0.0, 0.0
110
- vec = np.array([x, y]).astype("float32")
111
-
112
- # Add a card container to the right of the content with Streamlit columns
113
- col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
114
-
115
- with col1:
116
- # Main content stays here (scatterplot, form, etc.)
117
- selected_points = plotly_events(fig, click_event=True, hover_event=False, #override_height=600, override_width="100%"
118
- )
119
- with st.form(key="form1_main"):
120
- if selected_points:
121
- clicked_point = selected_points[0]
122
- x_coord = x = clicked_point['x']
123
- y_coord = y = clicked_point['y']
124
-
125
- x = st.number_input("X Coordinate", value=x, format="%.10f")
126
- y = st.number_input("Y Coordinate", value=y, format="%.10f")
127
- vec = np.array([x, y]).astype("float32")
128
-
129
-
130
- submit_button = st.form_submit_button("Submit")
131
-
132
- if selected_points or submit_button:
133
- inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
134
- inferred_embedding = inferred_embedding.astype("float32")
135
 
136
- output = vec2text.invert_embeddings(
137
- embeddings=torch.tensor(inferred_embedding).cuda(),
138
- corrector=corrector,
139
- num_steps=20,
140
- )
141
 
142
- st.text(str(output))
143
- st.text(str(inferred_embedding))
144
- else:
145
- st.text("Click on a point in the scatterplot to see its coordinates.")
146
 
147
- with col2:
148
- closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
149
- st.write(f"{vectors_2d.dtype} {vec.dtype}")
150
- if closest_sentence_index > -1:
151
- st.write(df["title"].iloc[closest_sentence_index])
152
- # Card content
153
- st.markdown("## Card Container")
154
- st.write("This is an additional card container to the right of the main content.")
155
- st.write("You can use this space to show additional information, actions, or insights.")
156
- st.button("Card Button")
 
 
 
1
  import streamlit as st
 
 
2
  import torch
 
 
 
3
  import numpy as np
4
+ import views
5
+ from resources import load_corrector, load_data, load_model_and_tokenizer
 
6
 
7
  use_cpu = not torch.cuda.is_available()
8
  device = "cpu" if use_cpu else "cuda"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  df = load_data()
11
 
 
 
 
 
 
 
 
12
  encoder, tokenizer = load_model_and_tokenizer()
13
 
14
+
 
 
 
15
 
16
  corrector = load_corrector()
17
 
 
22
 
23
  embeddings = load_embeddings()
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ tab1, tab2 = st.tabs(["plot", "diffs"])
 
 
 
 
27
 
28
+ with tab1:
29
+ views.plot()
 
 
30
 
31
+ with tab2:
32
+ views.diffs()
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
 
3
 
@@ -25,4 +27,29 @@ def find_exact_match(matrix, query_vector, decimals=9):
25
  if np.any(matches):
26
  return np.where(matches)[0][0] # Return the first match
27
  else:
28
- return -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
  import numpy as np
4
 
5
 
 
27
  if np.any(matches):
28
  return np.where(matches)[0][0] # Return the first match
29
  else:
30
+ return -1
31
+
32
def file_cache(file_path):
    """Return a decorator that caches a function's result in a pickle file.

    On the first call the wrapped function is executed and its return value
    is pickled to ``file_path``. On every later call, as long as that file
    exists, the pickled value is loaded and returned and the wrapped
    function is not executed.

    The cache is keyed on ``file_path`` only: the arguments passed to the
    wrapped function are NOT part of the key, so calling it with different
    arguments still returns whatever was stored first. Delete the file to
    force a recomputation.

    Args:
        file_path: Path of the pickle file used as the cache. Missing parent
            directories are created. A bare filename (no directory part) is
            stored relative to the current working directory.

    Returns:
        A decorator that wraps a function with the file-backed cache.

    Raises:
        Whatever ``open``, ``pickle.load`` or ``pickle.dump`` raise, e.g.
        ``OSError`` for an unwritable path or a pickling error for a result
        that cannot be pickled. In the latter case no cache file is left
        behind.
    """
    # Imported locally so the block does not rely on a module-level import.
    import functools

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            # Ensure the directory exists. dirname() is "" for a bare
            # filename, and os.makedirs("") raises FileNotFoundError, so
            # only create a directory when there is one to create.
            dir_path = os.path.dirname(file_path)
            if dir_path and not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
                print(f"Created directory {dir_path}")

            # Cache hit: load and return the stored result.
            if os.path.exists(file_path):
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point file_path at cache files this app wrote itself.
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            # Cache miss: compute, then publish the cache atomically so a
            # failed or interrupted dump never leaves a truncated file that
            # would break every subsequent load.
            result = func(*args, **kwargs)
            tmp_path = f"{file_path}.tmp"
            try:
                with open(tmp_path, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, file_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator