kaisugi committed on
Commit
00eee05
·
1 Parent(s): 60689d7
Files changed (1) hide show
  1. app.py +88 -17
app.py CHANGED
@@ -6,7 +6,8 @@ import streamlit as st
6
  import torch
7
 
8
  import math
9
- import os
 
10
 
11
  os.environ['KMP_DUPLICATE_LIB_OK']='True'
12
 
@@ -14,7 +15,7 @@ os.environ['KMP_DUPLICATE_LIB_OK']='True'
14
  @st.cache(allow_output_mutation=True)
15
  def load_model_and_tokenizer():
16
  tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
17
- model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
18
  model.eval()
19
 
20
  return model, tokenizer
@@ -63,7 +64,56 @@ def load_sentence_embeddings_and_index():
63
 
64
 
65
  @st.cache(allow_output_mutation=True)
66
- def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with torch.no_grad():
68
  inputs = tokenizer.encode_plus(
69
  input_text,
@@ -79,19 +129,28 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
79
 
80
  _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
81
  retrieved_sentences = []
82
- retrieved_paper_id = []
83
 
84
  for id in ids[0]:
85
- retrieved_sentences.append(sentence_df.loc[id, "sentence"])
86
- retrieved_paper_id.append(f"https://aclanthology.org/{sentence_df.loc[id, 'file_id']}")
87
 
88
- all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_id})
 
 
89
 
90
- if len(exclude_word_list) == 0:
91
- return all_df
92
- else:
93
- exclude_word_list_regex = '|'.join(exclude_word_list)
94
- return all_df[~all_df["sentence"].str.contains(exclude_word_list_regex)]
 
 
 
 
 
 
 
95
 
96
 
97
  if __name__ == "__main__":
@@ -102,11 +161,23 @@ if __name__ == "__main__":
102
 
103
  st.markdown("## AI-based Paraphrasing for Academic Writing")
104
 
105
- input_text = st.text_area("text input", "We saw difference in the results between A and B.", placeholder="Write something here...")
106
- top_k = st.number_input('top_k (upperbound)', min_value=1, value=200, step=1)
107
- input_words = st.text_input("exclude words (comma separated)", "see, saw")
 
 
108
 
109
  if st.button('search'):
110
  exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
111
- df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list)
112
- st.table(df)
 
 
 
 
 
 
 
 
 
 
 
6
  import torch
7
 
8
  import math
9
+ import os
10
+ import re
11
 
12
  os.environ['KMP_DUPLICATE_LIB_OK']='True'
13
 
 
15
  @st.cache(allow_output_mutation=True)
16
  def load_model_and_tokenizer():
17
  tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
18
+ model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
19
  model.eval()
20
 
21
  return model, tokenizer
 
64
 
65
 
66
  @st.cache(allow_output_mutation=True)
67
def formulaic_phrase_extraction(sentences, model, tokenizer, *, threshold=0.01, layer=10):
    """Mark the words of each sentence that the first token attends to.

    All sentences are encoded as one padded batch and passed through
    ``model``.  The attention maps of layer ``layer`` are averaged over heads,
    and for every sentence the row of the first position (the leading special
    token) is compared with ``threshold``.  A whitespace-separated word is
    wrapped in ``<font color="coral">`` only when all of its word pieces
    exceed the threshold.

    Args:
        sentences: Sentences (strings) to annotate.
        model: Callable that takes the encoded batch as keyword arguments.
            The last element of its output must hold the per-layer attention
            maps, each indexed as (sentence, head, query position, key
            position).
        tokenizer: Object providing ``batch_encode_plus`` (whose result
            contains ``attention_mask``) and ``tokenize``.
        threshold: Head-averaged attention a word piece must exceed.
        layer: Index of the attention layer that is read.

    Returns:
        One annotated string per input sentence, in input order, with the
        words re-joined by single spaces.
    """
    sentences = list(sentences)
    if not sentences:
        # Nothing to encode; avoids handing an empty batch to the tokenizer.
        return []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        outputs = model(**inputs)
        attention = outputs[-1]

        # Take query position 0 of EVERY sentence and average over heads.
        # Indexing the batch with [0] instead would pair sentence i with the
        # attention row of token i of the first sentence.
        cls_attentions = attention[layer][:, :, 0, :].mean(dim=1)

    output_sentences = []

    for sentence, cls_attention, mask_row in zip(
        sentences, cls_attentions, inputs["attention_mask"]
    ):
        # Assumes each encoded row is one leading special token, the word
        # pieces, one trailing special token, then padding (the original
        # ``[1:-1]`` slice made the same assumption) -- TODO confirm for
        # tokenizers other than BERT-style ones.
        n_encoded_pieces = max(int(mask_row.sum()) - 2, 0)
        piece_flags = (cls_attention > threshold).tolist()[1:1 + n_encoded_pieces]

        marked_words = []
        cursor = 0

        for word in sentence.split():
            # Tokenize word by word so flags stay aligned with the words that
            # are printed, even when punctuation becomes its own word piece.
            # Assumes tokenization never crosses whitespace.
            n_pieces = len(tokenizer.tokenize(word))
            word_flags = piece_flags[cursor:cursor + n_pieces]
            cursor += n_pieces

            # A word cut off by truncation has fewer flags than pieces and is
            # left unmarked.
            is_marked = (
                n_pieces > 0
                and len(word_flags) == n_pieces
                and all(word_flags)
            )
            marked_words.append(
                f'<font color="coral">{word}</font>' if is_marked else word
            )

        output_sentences.append(" ".join(marked_words))

    return output_sentences
113
+
114
+
115
+ @st.cache(allow_output_mutation=True)
116
+ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
117
  with torch.no_grad():
118
  inputs = tokenizer.encode_plus(
119
  input_text,
 
129
 
130
  _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
131
  retrieved_sentences = []
132
+ retrieved_paper_ids = []
133
 
134
  for id in ids[0]:
135
+ cur_sentence = sentence_df.loc[id, "sentence"]
136
+ cur_link = f"https://aclanthology.org/{sentence_df.loc[id, 'file_id']}"
137
 
138
+ if len(exclude_word_list) == 0:
139
+ retrieved_sentences.append(cur_sentence)
140
+ retrieved_paper_ids.append(cur_link)
141
 
142
+ else:
143
+ exclude_word_list_regex = '|'.join(exclude_word_list)
144
+ pat = re.compile(f'{exclude_word_list_regex}')
145
+
146
+ if not bool(pat.search(cur_sentence)):
147
+ retrieved_sentences.append(cur_sentence)
148
+ retrieved_paper_ids.append(cur_link)
149
+
150
+ if phrase_annotated:
151
+ retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)
152
+
153
+ return retrieved_sentences, retrieved_paper_ids
154
 
155
 
156
  if __name__ == "__main__":
 
161
 
162
  st.markdown("## AI-based Paraphrasing for Academic Writing")
163
 
164
+ input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
165
+ top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
166
+ input_words = st.text_input("exclude words (comma separated)", "good, result")
167
+
168
+ agree = st.checkbox('Include phrase annotation')
169
 
170
  if st.button('search'):
171
  exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
172
+ retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)
173
+
174
+ result_table_markdown = "| sentence | source link |\n|:---|:---|\n"
175
+
176
+ for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids):
177
+ result_table_markdown += f"| {retrieved_sentence} | {retrieved_paper_id} |\n"
178
+
179
+ st.markdown(result_table_markdown, unsafe_allow_html=True)
180
+
181
+ st.markdown("---\n#### How this works")
182
+
183
+ st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging the attention patterns within ScitoricsBERT.")