Nearest-neighbours function now compares the target word's vector only with vectors from the same model
Browse files- app.py +33 -34
- word2vec.py +67 -6
app.py
CHANGED
@@ -12,6 +12,9 @@ from streamlit_tags import st_tags, st_tags_sidebar
|
|
12 |
|
13 |
st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
|
14 |
|
|
|
|
|
|
|
15 |
# Horizontal menu
|
16 |
active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
|
17 |
menu_icon="cast", default_index=0, orientation="horizontal")
|
@@ -29,59 +32,55 @@ if active_tab == "Nearest neighbours":
|
|
29 |
all_words = load_compressed_word_list(compressed_word_list_filename)
|
30 |
eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
|
31 |
|
|
|
|
|
|
|
32 |
with st.container():
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
|
41 |
-
with col2:
|
42 |
-
time_slice = st.selectbox("Time slice", eligible_models)
|
43 |
|
44 |
models = st.multiselect(
|
45 |
"Select models to search for neighbours",
|
46 |
-
|
47 |
)
|
48 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
49 |
|
50 |
-
nearest_neighbours_button = st.button("Find nearest neighbours")
|
51 |
|
52 |
# If the button to calculate nearest neighbours is clicked
|
53 |
-
if
|
54 |
-
|
55 |
-
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
56 |
-
if time_slice == 'Hellenistic':
|
57 |
-
time_slice = 'hellen'
|
58 |
-
elif time_slice == 'Early Roman':
|
59 |
-
time_slice = 'early_roman'
|
60 |
-
elif time_slice == 'Late Roman':
|
61 |
-
time_slice = 'late_roman'
|
62 |
-
|
63 |
-
time_slice = time_slice.lower() + "_cbow"
|
64 |
-
|
65 |
|
66 |
# Check if all fields are filled in
|
67 |
-
if validate_nearest_neighbours(word,
|
68 |
st.error('Please fill in all fields')
|
69 |
else:
|
70 |
# Rewrite models to list of all loaded models
|
71 |
models = load_selected_models(models)
|
72 |
|
73 |
-
nearest_neighbours = get_nearest_neighbours(word,
|
74 |
-
|
75 |
-
df = pd.DataFrame(
|
76 |
-
nearest_neighbours,
|
77 |
-
columns=["Word", "Time slice", "Similarity"],
|
78 |
-
index = range(1, len(nearest_neighbours) + 1)
|
79 |
-
)
|
80 |
-
st.table(df)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
83 |
# Store content in a temporary file
|
84 |
-
tmp_file = store_df_in_temp_file(
|
85 |
|
86 |
# Open the temporary file and read its content
|
87 |
with open(tmp_file, "rb") as file:
|
@@ -91,7 +90,7 @@ if active_tab == "Nearest neighbours":
|
|
91 |
st.download_button(
|
92 |
"Download results",
|
93 |
data=file_byte,
|
94 |
-
file_name = f'nearest_neighbours_{word}
|
95 |
mime='application/octet-stream'
|
96 |
)
|
97 |
|
|
|
12 |
|
13 |
st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
|
14 |
|
15 |
+
def click_nn_button():
    """Flip the ``nearest_neighbours`` flag kept in Streamlit's session state."""
    state = st.session_state
    state.nearest_neighbours = not state.nearest_neighbours
|
17 |
+
|
18 |
# Horizontal menu
|
19 |
active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
|
20 |
menu_icon="cast", default_index=0, orientation="horizontal")
|
|
|
32 |
all_words = load_compressed_word_list(compressed_word_list_filename)
|
33 |
eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
|
34 |
|
35 |
+
if 'nearest_neighbours' not in st.session_state:
|
36 |
+
st.session_state.nearest_neighbours = False
|
37 |
+
|
38 |
with st.container():
|
39 |
+
|
40 |
+
word = st.multiselect("Enter a word", all_words, max_selections=1)
|
41 |
+
if len(word) > 0:
|
42 |
+
word = word[0]
|
43 |
+
|
44 |
+
# Check which models contain the word
|
45 |
+
eligible_models = check_word_in_models(word)
|
46 |
|
|
|
|
|
47 |
|
48 |
models = st.multiselect(
|
49 |
"Select models to search for neighbours",
|
50 |
+
eligible_models
|
51 |
)
|
52 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
53 |
|
54 |
+
nearest_neighbours_button = st.button("Find nearest neighbours", on_click = click_nn_button)
|
55 |
|
56 |
# If the button to calculate nearest neighbours is clicked
|
57 |
+
if st.session_state.nearest_neighbours:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
# Check if all fields are filled in
|
60 |
+
if validate_nearest_neighbours(word, n, models) == False:
|
61 |
st.error('Please fill in all fields')
|
62 |
else:
|
63 |
# Rewrite models to list of all loaded models
|
64 |
models = load_selected_models(models)
|
65 |
|
66 |
+
nearest_neighbours = get_nearest_neighbours(word, n, models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
all_dfs = []
|
69 |
+
|
70 |
+
# Create dataframes
|
71 |
+
for model in nearest_neighbours.keys():
|
72 |
+
st.write(f"### {model}")
|
73 |
+
df = pd.DataFrame(
|
74 |
+
nearest_neighbours[model],
|
75 |
+
columns = ['Word', 'Cosine Similarity']
|
76 |
+
)
|
77 |
+
|
78 |
+
all_dfs.append((model, df))
|
79 |
+
st.table(df)
|
80 |
|
81 |
+
|
82 |
# Store content in a temporary file
|
83 |
+
tmp_file = store_df_in_temp_file(all_dfs)
|
84 |
|
85 |
# Open the temporary file and read its content
|
86 |
with open(tmp_file, "rb") as file:
|
|
|
90 |
st.download_button(
|
91 |
"Download results",
|
92 |
data=file_byte,
|
93 |
+
file_name = f'nearest_neighbours_{word}_TEST.xlsx',
|
94 |
mime='application/octet-stream'
|
95 |
)
|
96 |
|
word2vec.py
CHANGED
@@ -148,11 +148,11 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
|
|
148 |
|
149 |
|
150 |
|
151 |
-
def validate_nearest_neighbours(word,
|
152 |
'''
|
153 |
Validate the input of the nearest neighbours function
|
154 |
'''
|
155 |
-
if word == '' or
|
156 |
return False
|
157 |
return True
|
158 |
|
@@ -198,7 +198,7 @@ def convert_time_name_to_model(time_name):
|
|
198 |
elif time_name == 'archaic':
|
199 |
return 'Archaic'
|
200 |
|
201 |
-
def
|
202 |
'''
|
203 |
Return the nearest neighbours of a word
|
204 |
|
@@ -243,6 +243,51 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
243 |
|
244 |
|
245 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
|
247 |
|
248 |
def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
|
@@ -287,7 +332,7 @@ def write_to_file(data):
|
|
287 |
return temp_file_path
|
288 |
|
289 |
|
290 |
-
def store_df_in_temp_file(
|
291 |
'''
|
292 |
Store the dataframe in a temporary file
|
293 |
'''
|
@@ -300,9 +345,25 @@ def store_df_in_temp_file(df):
|
|
300 |
# Create random tmp file name
|
301 |
_, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
|
302 |
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
304 |
with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
|
305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
return temp_file_path
|
308 |
|
|
|
148 |
|
149 |
|
150 |
|
151 |
+
def validate_nearest_neighbours(word, n, models):
    '''
    Validate the input of the nearest neighbours function

    word: the selected target word; any empty value ('', [], None) is rejected
    n: the requested number of neighbours; None and '' are rejected
    models: the selected models; any empty value ([], (), None) is rejected

    Return: True if all fields are filled in, False otherwise
    '''
    # Truthiness instead of `== ''` / `== []`: an empty multiselect yields an
    # empty list for the word, which the equality checks let through.
    if not word or n is None or n == '' or not models:
        return False
    return True
|
158 |
|
|
|
198 |
elif time_name == 'archaic':
|
199 |
return 'Archaic'
|
200 |
|
201 |
+
def get_nearest_neighbours2(word, n=10, models=load_all_models()):
|
202 |
'''
|
203 |
Return the nearest neighbours of a word
|
204 |
|
|
|
243 |
|
244 |
|
245 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
246 |
+
|
247 |
+
|
248 |
+
def get_nearest_neighbours(target_word, n=10, models=None):
    """
    Return the nearest neighbours of a word for the given models

    target_word: the word for which the nearest neighbours are calculated
    n: the number of nearest neighbours to return (default: 10);
       values <= 0 give an empty list per model
    models: list of tuples with the name of the time slice and the word2vec model
            (default: None, meaning whatever load_all_models() returns)

    Return: { 'model_name': [(word, cosine_similarity), ...], ... }
            with each list sorted by descending cosine similarity
    """
    import heapq

    # Resolve the default lazily: a `models=load_all_models()` default would be
    # evaluated once, when the function is defined, not when it is called.
    if models is None:
        models = load_all_models()

    nearest_neighbours = {}

    # The target word is only ever compared with words of the same model
    for model_entry in models:
        model_name = convert_model_to_time_name(model_entry[0])
        w2v_model = model_entry[1]
        target_vector = get_word_vector(w2v_model, target_word)

        # NOTE(review): the whole vocabulary is scored without skipping
        # target_word, so the target can appear among its own neighbours
        # - confirm this is intended.
        scored_words = (
            (word, cosine_similarity(target_vector, get_word_vector(w2v_model, word)))
            for word in w2v_model.wv.key_to_index
        )

        # nlargest keeps the n best-scoring words and returns them best-first
        nearest_neighbours[model_name] = heapq.nlargest(
            n, scored_words, key=lambda pair: pair[1]
        )

    return nearest_neighbours
|
289 |
+
|
290 |
+
|
291 |
|
292 |
|
293 |
def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
|
|
|
332 |
return temp_file_path
|
333 |
|
334 |
|
335 |
+
def store_df_in_temp_file(all_dfs):
|
336 |
'''
|
337 |
Store the dataframe in a temporary file
|
338 |
'''
|
|
|
345 |
# Create random tmp file name
|
346 |
_, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
|
347 |
|
348 |
+
|
349 |
+
# Concatenate all dataframes
|
350 |
+
df = pd.concat([df for _, df in all_dfs], axis=1, keys=[model for model, _ in all_dfs])
|
351 |
+
|
352 |
+
|
353 |
+
# Create an ExcelWriter object
|
354 |
with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
|
355 |
+
# Create a new sheet
|
356 |
+
worksheet = writer.book.add_worksheet('Results')
|
357 |
+
|
358 |
+
# Write text before DataFrames
|
359 |
+
start_row = 0
|
360 |
+
for model, df in all_dfs:
|
361 |
+
# Write model name as text
|
362 |
+
worksheet.write(start_row, 0, f"Model: {model}")
|
363 |
+
# Write DataFrame
|
364 |
+
df.to_excel(writer, sheet_name='Results', index=False, startrow=start_row + 1, startcol=0)
|
365 |
+
# Update start_row for the next model
|
366 |
+
start_row += df.shape[0] + 3 # Add some space between models
|
367 |
|
368 |
return temp_file_path
|
369 |
|