Mark7549 committed on
Commit
8fb441e
·
1 Parent(s): 7088ca8

The nearest-neighbours (nn) function now compares the target word's vector only with vectors from the same model

Browse files
Files changed (2) hide show
  1. app.py +33 -34
  2. word2vec.py +67 -6
app.py CHANGED
@@ -12,6 +12,9 @@ from streamlit_tags import st_tags, st_tags_sidebar
12
 
13
  st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
14
 
 
 
 
15
  # Horizontal menu
16
  active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
17
  menu_icon="cast", default_index=0, orientation="horizontal")
@@ -29,59 +32,55 @@ if active_tab == "Nearest neighbours":
29
  all_words = load_compressed_word_list(compressed_word_list_filename)
30
  eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
31
 
 
 
 
32
  with st.container():
33
- with col1:
34
- word = st.multiselect("Enter a word", all_words, max_selections=1)
35
- if len(word) > 0:
36
- word = word[0]
37
-
38
- # Check which models contain the word
39
- eligible_models = check_word_in_models(word)
40
 
41
- with col2:
42
- time_slice = st.selectbox("Time slice", eligible_models)
43
 
44
  models = st.multiselect(
45
  "Select models to search for neighbours",
46
- ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
47
  )
48
  n = st.slider("Number of neighbours", 1, 50, 15)
49
 
50
- nearest_neighbours_button = st.button("Find nearest neighbours")
51
 
52
  # If the button to calculate nearest neighbours is clicked
53
- if nearest_neighbours_button:
54
-
55
- # Rewrite timeslices to model names: Archaic -> archaic_cbow
56
- if time_slice == 'Hellenistic':
57
- time_slice = 'hellen'
58
- elif time_slice == 'Early Roman':
59
- time_slice = 'early_roman'
60
- elif time_slice == 'Late Roman':
61
- time_slice = 'late_roman'
62
-
63
- time_slice = time_slice.lower() + "_cbow"
64
-
65
 
66
  # Check if all fields are filled in
67
- if validate_nearest_neighbours(word, time_slice, n, models) == False:
68
  st.error('Please fill in all fields')
69
  else:
70
  # Rewrite models to list of all loaded models
71
  models = load_selected_models(models)
72
 
73
- nearest_neighbours = get_nearest_neighbours(word, time_slice, n, models)
74
-
75
- df = pd.DataFrame(
76
- nearest_neighbours,
77
- columns=["Word", "Time slice", "Similarity"],
78
- index = range(1, len(nearest_neighbours) + 1)
79
- )
80
- st.table(df)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
 
83
  # Store content in a temporary file
84
- tmp_file = store_df_in_temp_file(df)
85
 
86
  # Open the temporary file and read its content
87
  with open(tmp_file, "rb") as file:
@@ -91,7 +90,7 @@ if active_tab == "Nearest neighbours":
91
  st.download_button(
92
  "Download results",
93
  data=file_byte,
94
- file_name = f'nearest_neighbours_{word}_{time_slice}.xlsx',
95
  mime='application/octet-stream'
96
  )
97
 
 
12
 
13
  st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
14
 
15
def click_nn_button():
    """
    Callback for the "Find nearest neighbours" button.

    Records in the Streamlit session state that a nearest-neighbour search
    has been requested, so the results stay visible across script reruns
    (for example after the download button is pressed).
    """
    # Set the flag instead of toggling it: with a toggle, a second click on
    # the button would flip the flag back to False and hide the results
    # rather than running the search again.
    st.session_state.nearest_neighbours = True
17
+
18
  # Horizontal menu
19
  active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
20
  menu_icon="cast", default_index=0, orientation="horizontal")
 
32
  all_words = load_compressed_word_list(compressed_word_list_filename)
33
  eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
34
 
35
+ if 'nearest_neighbours' not in st.session_state:
36
+ st.session_state.nearest_neighbours = False
37
+
38
  with st.container():
39
+
40
+ word = st.multiselect("Enter a word", all_words, max_selections=1)
41
+ if len(word) > 0:
42
+ word = word[0]
43
+
44
+ # Check which models contain the word
45
+ eligible_models = check_word_in_models(word)
46
 
 
 
47
 
48
  models = st.multiselect(
49
  "Select models to search for neighbours",
50
+ eligible_models
51
  )
52
  n = st.slider("Number of neighbours", 1, 50, 15)
53
 
54
+ nearest_neighbours_button = st.button("Find nearest neighbours", on_click = click_nn_button)
55
 
56
  # If the button to calculate nearest neighbours is clicked
57
+ if st.session_state.nearest_neighbours:
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Check if all fields are filled in
60
+ if validate_nearest_neighbours(word, n, models) == False:
61
  st.error('Please fill in all fields')
62
  else:
63
  # Rewrite models to list of all loaded models
64
  models = load_selected_models(models)
65
 
66
+ nearest_neighbours = get_nearest_neighbours(word, n, models)
 
 
 
 
 
 
 
67
 
68
+ all_dfs = []
69
+
70
+ # Create dataframes
71
+ for model in nearest_neighbours.keys():
72
+ st.write(f"### {model}")
73
+ df = pd.DataFrame(
74
+ nearest_neighbours[model],
75
+ columns = ['Word', 'Cosine Similarity']
76
+ )
77
+
78
+ all_dfs.append((model, df))
79
+ st.table(df)
80
 
81
+
82
  # Store content in a temporary file
83
+ tmp_file = store_df_in_temp_file(all_dfs)
84
 
85
  # Open the temporary file and read its content
86
  with open(tmp_file, "rb") as file:
 
90
  st.download_button(
91
  "Download results",
92
  data=file_byte,
93
+ file_name = f'nearest_neighbours_{word}_TEST.xlsx',
94
  mime='application/octet-stream'
95
  )
96
 
word2vec.py CHANGED
@@ -148,11 +148,11 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
148
 
149
 
150
 
151
- def validate_nearest_neighbours(word, time_slice_model, n, models):
152
  '''
153
  Validate the input of the nearest neighbours function
154
  '''
155
- if word == '' or time_slice_model == [] or n == '' or models == []:
156
  return False
157
  return True
158
 
@@ -198,7 +198,7 @@ def convert_time_name_to_model(time_name):
198
  elif time_name == 'archaic':
199
  return 'Archaic'
200
 
201
- def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
202
  '''
203
  Return the nearest neighbours of a word
204
 
@@ -243,6 +243,51 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
243
 
244
 
245
  return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
  def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
@@ -287,7 +332,7 @@ def write_to_file(data):
287
  return temp_file_path
288
 
289
 
290
- def store_df_in_temp_file(df):
291
  '''
292
  Store the dataframe in a temporary file
293
  '''
@@ -300,9 +345,25 @@ def store_df_in_temp_file(df):
300
  # Create random tmp file name
301
  _, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
302
 
303
- # Write data to the temporary file
 
 
 
 
 
304
  with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
305
- df.to_excel(writer, index=False)
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  return temp_file_path
308
 
 
148
 
149
 
150
 
151
def validate_nearest_neighbours(word, n, models):
    '''
    Validate the input of the nearest neighbours function

    word: the target word; invalid when empty ('' or an empty selection
          list, which is what a multiselect yields when nothing is chosen)
          or None
    n: the number of neighbours; invalid when '' or None
    models: the selected models; invalid when empty or None

    Return: True if all fields are filled in, False otherwise
    '''
    # Use truthiness for word and models so that '', [] and None are all
    # rejected; comparing against '' alone lets an empty selection through.
    # n is compared explicitly so a numeric 0 is not mistaken for "missing".
    if not word or n is None or n == '' or not models:
        return False
    return True
158
 
 
198
  elif time_name == 'archaic':
199
  return 'Archaic'
200
 
201
+ def get_nearest_neighbours2(word, n=10, models=load_all_models()):
202
  '''
203
  Return the nearest neighbours of a word
204
 
 
243
 
244
 
245
  return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
246
+
247
+
248
def get_nearest_neighbours(target_word, n=10, models=None):
    """
    Return the nearest neighbours of a word for the given models

    Each model is searched separately: the vector of the target word is
    only compared with the vectors of words from that same model.

    target_word: the word for which the nearest neighbours are calculated
    n: the number of nearest neighbours to return per model (default: 10)
    models: list of tuples with the name of the time slice and the word2vec
            model (default: None, meaning everything load_all_models() returns)

    Return: { 'model_name': [(word, cosine_similarity), ...], ... }
            with each list sorted by cosine similarity, highest first
    """
    import heapq

    # Resolve the default at call time; a `load_all_models()` default in the
    # signature would be evaluated once, when the module is imported.
    if models is None:
        models = load_all_models()

    nearest_neighbours = {}

    # Iterate over models and compute nearest neighbours
    for entry in models:
        model_name = convert_model_to_time_name(entry[0])
        w2v_model = entry[1]
        target_vector = get_word_vector(w2v_model, target_word)

        # Lazily score every word in this model's vocabulary against the
        # target word. NOTE: the target word itself is not filtered out.
        scored_words = (
            (word, cosine_similarity(target_vector, get_word_vector(w2v_model, word)))
            for word in w2v_model.wv.key_to_index
        )

        # nlargest keeps only the n best candidates and returns them already
        # sorted from most to least similar.
        nearest_neighbours[model_name] = heapq.nlargest(
            n, scored_words, key=lambda pair: pair[1]
        )

    return nearest_neighbours
289
+
290
+
291
 
292
 
293
  def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
 
332
  return temp_file_path
333
 
334
 
335
+ def store_df_in_temp_file(all_dfs):
336
  '''
337
  Store the dataframe in a temporary file
338
  '''
 
345
  # Create random tmp file name
346
  _, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
347
 
348
+
349
+ # Concatenate all dataframes
350
+ df = pd.concat([df for _, df in all_dfs], axis=1, keys=[model for model, _ in all_dfs])
351
+
352
+
353
+ # Create an ExcelWriter object
354
  with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
355
+ # Create a new sheet
356
+ worksheet = writer.book.add_worksheet('Results')
357
+
358
+ # Write text before DataFrames
359
+ start_row = 0
360
+ for model, df in all_dfs:
361
+ # Write model name as text
362
+ worksheet.write(start_row, 0, f"Model: {model}")
363
+ # Write DataFrame
364
+ df.to_excel(writer, sheet_name='Results', index=False, startrow=start_row + 1, startcol=0)
365
+ # Update start_row for the next model
366
+ start_row += df.shape[0] + 3 # Add some space between models
367
 
368
  return temp_file_path
369