kaisugi committed on
Commit
e1a3f25
·
1 Parent(s): b6363d9
Files changed (1) hide show
  1. app.py +37 -40
app.py CHANGED
@@ -27,14 +27,41 @@ def load_sentence_data():
27
 
28
 
29
  @st.cache(allow_output_mutation=True)
30
- def load_sentence_embeddings():
31
  npz_comp = np.load("sentence_embeddings_789k.npz")
32
  sentence_embeddings = npz_comp["arr_0"]
33
 
34
- return sentence_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
36
 
37
- @st.cache
 
 
 
 
 
 
 
 
 
 
 
 
38
  def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
39
  with torch.no_grad():
40
  inputs = tokenizer.encode_plus(
@@ -58,47 +85,17 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
58
  return pd.DataFrame({"sentences": retrieved_sentences})
59
 
60
 
61
- def main(model, tokenizer, sentence_df, index):
62
- st.markdown("## AI-based Paraphrasing for Academic Writing")
63
-
64
- input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
65
- top_k = st.number_input('top_k', min_value=1, value=10, step=1)
66
-
67
- df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
68
- st.table(df)
69
-
70
-
71
  if __name__ == "__main__":
72
  model, tokenizer = load_model_and_tokenizer()
73
  sentence_df = load_sentence_data()
74
- sentence_embeddings = load_sentence_embeddings()
75
-
76
- faiss.normalize_L2(sentence_embeddings)
77
-
78
- D = 768
79
- N = 789188
80
- Xt = sentence_embeddings[:39000]
81
- X = sentence_embeddings
82
-
83
- # Param of PQ
84
- M = 16 # The number of sub-vector. Typically this is 8, 16, 32, etc.
85
- nbits = 8 # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
86
- # Param of IVF
87
- nlist = 1000 # The number of cells (space partition). Typical value is sqrt(N)
88
- # Param of HNSW
89
- hnsw_m = 32 # The number of neighbors for HNSW. This is typically 32
90
-
91
- # Setup
92
- quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
93
- index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
94
 
95
- # Train
96
- index.train(Xt)
97
 
98
- # Add
99
- index.add(X)
100
 
101
- # Search
102
- index.nprobe = 8 # Runtime param. The number of cells that are visited for search.
103
 
104
- main(model, tokenizer, sentence_df, index)
 
 
 
27
 
28
 
29
  @st.cache(allow_output_mutation=True)
30
+ def load_sentence_embeddings_and_index():
31
  npz_comp = np.load("sentence_embeddings_789k.npz")
32
  sentence_embeddings = npz_comp["arr_0"]
33
 
34
+ faiss.normalize_L2(sentence_embeddings)
35
+ D = 768
36
+ N = 789188
37
+ Xt = sentence_embeddings[:39000]
38
+ X = sentence_embeddings
39
+
40
+ # Param of PQ
41
+ M = 16 # The number of sub-vector. Typically this is 8, 16, 32, etc.
42
+ nbits = 8 # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
43
+ # Param of IVF
44
+ nlist = 888 # The number of cells (space partition). Typical value is sqrt(N)
45
+ # Param of HNSW
46
+ hnsw_m = 32 # The number of neighbors for HNSW. This is typically 32
47
 
48
+ # Setup
49
+ quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
50
+ index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
51
 
52
+ # Train
53
+ index.train(Xt)
54
+
55
+ # Add
56
+ index.add(X)
57
+
58
+ # Search
59
+ index.nprobe = 8 # Runtime param. The number of cells that are visited for search.
60
+
61
+ return sentence_embeddings, index
62
+
63
+
64
+ @st.cache(allow_output_mutation=True)
65
  def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
66
  with torch.no_grad():
67
  inputs = tokenizer.encode_plus(
 
85
  return pd.DataFrame({"sentences": retrieved_sentences})
86
 
87
 
 
 
 
 
 
 
 
 
 
 
88
  if __name__ == "__main__":
89
  model, tokenizer = load_model_and_tokenizer()
90
  sentence_df = load_sentence_data()
91
+ sentence_embeddings, index = load_sentence_embeddings_and_index()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
 
 
93
 
94
+ st.markdown("## AI-based Paraphrasing for Academic Writing")
 
95
 
96
+ input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
97
+ top_k = st.number_input('top_k', min_value=1, value=10, step=1)
98
 
99
+ if st.button('search'):
100
+ df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
101
+ st.table(df)