Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -27,14 +27,41 @@ def load_sentence_data():
|
|
27 |
|
28 |
|
29 |
@st.cache(allow_output_mutation=True)
|
30 |
-
def
|
31 |
npz_comp = np.load("sentence_embeddings_789k.npz")
|
32 |
sentence_embeddings = npz_comp["arr_0"]
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
|
|
|
|
|
|
36 |
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
|
39 |
with torch.no_grad():
|
40 |
inputs = tokenizer.encode_plus(
|
@@ -58,47 +85,17 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
|
|
58 |
return pd.DataFrame({"sentences": retrieved_sentences})
|
59 |
|
60 |
|
61 |
-
def main(model, tokenizer, sentence_df, index):
|
62 |
-
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
63 |
-
|
64 |
-
input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
|
65 |
-
top_k = st.number_input('top_k', min_value=1, value=10, step=1)
|
66 |
-
|
67 |
-
df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
|
68 |
-
st.table(df)
|
69 |
-
|
70 |
-
|
71 |
if __name__ == "__main__":
|
72 |
model, tokenizer = load_model_and_tokenizer()
|
73 |
sentence_df = load_sentence_data()
|
74 |
-
sentence_embeddings =
|
75 |
-
|
76 |
-
faiss.normalize_L2(sentence_embeddings)
|
77 |
-
|
78 |
-
D = 768
|
79 |
-
N = 789188
|
80 |
-
Xt = sentence_embeddings[:39000]
|
81 |
-
X = sentence_embeddings
|
82 |
-
|
83 |
-
# Param of PQ
|
84 |
-
M = 16 # The number of sub-vector. Typically this is 8, 16, 32, etc.
|
85 |
-
nbits = 8 # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
|
86 |
-
# Param of IVF
|
87 |
-
nlist = 1000 # The number of cells (space partition). Typical value is sqrt(N)
|
88 |
-
# Param of HNSW
|
89 |
-
hnsw_m = 32 # The number of neighbors for HNSW. This is typically 32
|
90 |
-
|
91 |
-
# Setup
|
92 |
-
quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
|
93 |
-
index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
|
94 |
|
95 |
-
# Train
|
96 |
-
index.train(Xt)
|
97 |
|
98 |
-
|
99 |
-
index.add(X)
|
100 |
|
101 |
-
|
102 |
-
|
103 |
|
104 |
-
|
|
|
|
|
|
27 |
|
28 |
|
29 |
@st.cache(allow_output_mutation=True)
|
30 |
+
def load_sentence_embeddings_and_index():
|
31 |
npz_comp = np.load("sentence_embeddings_789k.npz")
|
32 |
sentence_embeddings = npz_comp["arr_0"]
|
33 |
|
34 |
+
faiss.normalize_L2(sentence_embeddings)
|
35 |
+
D = 768
|
36 |
+
N = 789188
|
37 |
+
Xt = sentence_embeddings[:39000]
|
38 |
+
X = sentence_embeddings
|
39 |
+
|
40 |
+
# Param of PQ
|
41 |
+
M = 16 # The number of sub-vector. Typically this is 8, 16, 32, etc.
|
42 |
+
nbits = 8 # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
|
43 |
+
# Param of IVF
|
44 |
+
nlist = 888 # The number of cells (space partition). Typical value is sqrt(N)
|
45 |
+
# Param of HNSW
|
46 |
+
hnsw_m = 32 # The number of neighbors for HNSW. This is typically 32
|
47 |
|
48 |
+
# Setup
|
49 |
+
quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
|
50 |
+
index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
|
51 |
|
52 |
+
# Train
|
53 |
+
index.train(Xt)
|
54 |
+
|
55 |
+
# Add
|
56 |
+
index.add(X)
|
57 |
+
|
58 |
+
# Search
|
59 |
+
index.nprobe = 8 # Runtime param. The number of cells that are visited for search.
|
60 |
+
|
61 |
+
return sentence_embeddings, index
|
62 |
+
|
63 |
+
|
64 |
+
@st.cache(allow_output_mutation=True)
|
65 |
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
|
66 |
with torch.no_grad():
|
67 |
inputs = tokenizer.encode_plus(
|
|
|
85 |
return pd.DataFrame({"sentences": retrieved_sentences})
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
if __name__ == "__main__":
|
89 |
model, tokenizer = load_model_and_tokenizer()
|
90 |
sentence_df = load_sentence_data()
|
91 |
+
sentence_embeddings, index = load_sentence_embeddings_and_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
|
|
|
|
93 |
|
94 |
+
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
|
|
95 |
|
96 |
+
input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
|
97 |
+
top_k = st.number_input('top_k', min_value=1, value=10, step=1)
|
98 |
|
99 |
+
if st.button('search'):
|
100 |
+
df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
|
101 |
+
st.table(df)
|