kaisugi commited on
Commit
7db6000
·
1 Parent(s): e1a3f25
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -62,7 +62,7 @@ def load_sentence_embeddings_and_index():
62
 
63
 
64
  @st.cache(allow_output_mutation=True)
65
- def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
66
  with torch.no_grad():
67
  inputs = tokenizer.encode_plus(
68
  input_text,
@@ -78,11 +78,19 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
78
 
79
  _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
80
  retrieved_sentences = []
 
81
 
82
  for id in ids[0]:
83
  retrieved_sentences.append(sentence_df.loc[id, "sentence"])
 
84
 
85
- return pd.DataFrame({"sentences": retrieved_sentences})
 
 
 
 
 
 
86
 
87
 
88
  if __name__ == "__main__":
@@ -93,9 +101,11 @@ if __name__ == "__main__":
93
 
94
  st.markdown("## AI-based Paraphrasing for Academic Writing")
95
 
96
- input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
97
- top_k = st.number_input('top_k', min_value=1, value=10, step=1)
 
98
 
99
  if st.button('search'):
100
- df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
 
101
  st.table(df)
 
62
 
63
 
64
  @st.cache(allow_output_mutation=True)
65
+ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
66
  with torch.no_grad():
67
  inputs = tokenizer.encode_plus(
68
  input_text,
 
78
 
79
  _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
80
  retrieved_sentences = []
81
+ retrieved_paper_id = []
82
 
83
  for id in ids[0]:
84
  retrieved_sentences.append(sentence_df.loc[id, "sentence"])
85
+ retrieved_paper_id.append(f"https://aclanthology.org/{sentence_df.loc[id, 'file_id']}")
86
 
87
+ all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_id})
88
+
89
+ if len(exclude_word_list) == 0:
90
+ return all_df
91
+ else:
92
+ exclude_word_list_regex = '|'.join(exclude_word_list)
93
+ return all_df[~all_df["sentence"].str.contains(exclude_word_list_regex)]
94
 
95
 
96
  if __name__ == "__main__":
 
101
 
102
  st.markdown("## AI-based Paraphrasing for Academic Writing")
103
 
104
+ input_text = st.text_area("text input", "We saw difference in the results between A and B.", placeholder="Write something here...")
105
+ top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
106
+ input_words = st.text_input("exclude words (comma separated)", "see, saw")
107
 
108
  if st.button('search'):
109
+ exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
110
+ df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list)
111
  st.table(df)