Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -62,7 +62,7 @@ def load_sentence_embeddings_and_index():
|
|
62 |
|
63 |
|
64 |
@st.cache(allow_output_mutation=True)
|
65 |
-
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
|
66 |
with torch.no_grad():
|
67 |
inputs = tokenizer.encode_plus(
|
68 |
input_text,
|
@@ -78,11 +78,19 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
|
|
78 |
|
79 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
80 |
retrieved_sentences = []
|
|
|
81 |
|
82 |
for id in ids[0]:
|
83 |
retrieved_sentences.append(sentence_df.loc[id, "sentence"])
|
|
|
84 |
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
|
88 |
if __name__ == "__main__":
|
@@ -93,9 +101,11 @@ if __name__ == "__main__":
|
|
93 |
|
94 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
95 |
|
96 |
-
input_text = st.text_area("text input", "
|
97 |
-
top_k = st.number_input('top_k', min_value=1, value=
|
|
|
98 |
|
99 |
if st.button('search'):
|
100 |
-
|
|
|
101 |
st.table(df)
|
|
|
62 |
|
63 |
|
64 |
@st.cache(allow_output_mutation=True)
|
65 |
+
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
|
66 |
with torch.no_grad():
|
67 |
inputs = tokenizer.encode_plus(
|
68 |
input_text,
|
|
|
78 |
|
79 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
80 |
retrieved_sentences = []
|
81 |
+
retrieved_paper_id = []
|
82 |
|
83 |
for id in ids[0]:
|
84 |
retrieved_sentences.append(sentence_df.loc[id, "sentence"])
|
85 |
+
retrieved_paper_id.append(f"https://aclanthology.org/{sentence_df.loc[id, 'file_id']}")
|
86 |
|
87 |
+
all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_id})
|
88 |
+
|
89 |
+
if len(exclude_word_list) == 0:
|
90 |
+
return all_df
|
91 |
+
else:
|
92 |
+
exclude_word_list_regex = '|'.join(exclude_word_list)
|
93 |
+
return all_df[~all_df["sentence"].str.contains(exclude_word_list_regex)]
|
94 |
|
95 |
|
96 |
if __name__ == "__main__":
|
|
|
101 |
|
102 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
103 |
|
104 |
+
input_text = st.text_area("text input", "We saw difference in the results between A and B.", placeholder="Write something here...")
|
105 |
+
top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
|
106 |
+
input_words = st.text_input("exclude words (comma separated)", "see, saw")
|
107 |
|
108 |
if st.button('search'):
|
109 |
+
exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
|
110 |
+
df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list)
|
111 |
st.table(df)
|