Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -6,7 +6,8 @@ import streamlit as st
|
|
6 |
import torch
|
7 |
|
8 |
import math
|
9 |
-
import os
|
|
|
10 |
|
11 |
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
12 |
|
@@ -14,7 +15,7 @@ os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
|
14 |
@st.cache(allow_output_mutation=True)
|
15 |
def load_model_and_tokenizer():
|
16 |
tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
|
17 |
-
model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
|
18 |
model.eval()
|
19 |
|
20 |
return model, tokenizer
|
@@ -63,7 +64,56 @@ def load_sentence_embeddings_and_index():
|
|
63 |
|
64 |
|
65 |
@st.cache(allow_output_mutation=True)
|
66 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
with torch.no_grad():
|
68 |
inputs = tokenizer.encode_plus(
|
69 |
input_text,
|
@@ -79,19 +129,28 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
|
|
79 |
|
80 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
81 |
retrieved_sentences = []
|
82 |
-
|
83 |
|
84 |
for id in ids[0]:
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
|
|
|
|
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
|
97 |
if __name__ == "__main__":
|
@@ -102,11 +161,23 @@ if __name__ == "__main__":
|
|
102 |
|
103 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
104 |
|
105 |
-
input_text = st.text_area("text input", "
|
106 |
-
top_k = st.number_input('top_k (upperbound)', min_value=1, value=
|
107 |
-
input_words = st.text_input("exclude words (comma separated)", "
|
|
|
|
|
108 |
|
109 |
if st.button('search'):
|
110 |
exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import torch
|
7 |
|
8 |
import math
|
9 |
+
import os
|
10 |
+
import re
|
11 |
|
12 |
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
13 |
|
|
|
15 |
@st.cache(allow_output_mutation=True)
|
16 |
def load_model_and_tokenizer():
|
17 |
tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
|
18 |
+
model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
|
19 |
model.eval()
|
20 |
|
21 |
return model, tokenizer
|
|
|
64 |
|
65 |
|
66 |
@st.cache(allow_output_mutation=True)
|
67 |
+
def formulaic_phrase_extraction(sentences, model, tokenizer):
|
68 |
+
THRESHOLD = 0.01
|
69 |
+
LAYER = 10
|
70 |
+
|
71 |
+
output_sentences = []
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
inputs = tokenizer.batch_encode_plus(
|
75 |
+
sentences,
|
76 |
+
padding=True,
|
77 |
+
truncation=True,
|
78 |
+
max_length=512,
|
79 |
+
return_tensors='pt'
|
80 |
+
)
|
81 |
+
outputs = model(**inputs)
|
82 |
+
attention = outputs[-1]
|
83 |
+
|
84 |
+
cls_attentions = torch.mean(attention[LAYER][0], dim=0)
|
85 |
+
|
86 |
+
for sentence, cls_attention in zip(sentences, cls_attentions):
|
87 |
+
check_bool_arr = list((cls_attention > THRESHOLD).numpy())[1:-1]
|
88 |
+
tokens = tokenizer.tokenize(sentence)
|
89 |
+
|
90 |
+
cur_tokens = tokens.copy()
|
91 |
+
|
92 |
+
while True:
|
93 |
+
flg = False
|
94 |
+
|
95 |
+
for idx, token in enumerate(cur_tokens):
|
96 |
+
if token.startswith("##"):
|
97 |
+
flg = True
|
98 |
+
back_token = token.replace("##", "")
|
99 |
+
front_token = cur_tokens.pop(idx-1)
|
100 |
+
cur_tokens[idx-1] = front_token + back_token
|
101 |
+
|
102 |
+
back_bool_val = check_bool_arr[idx]
|
103 |
+
front_bool_val = check_bool_arr.pop(idx-1)
|
104 |
+
check_bool_arr[idx-1] = front_bool_val and back_bool_val
|
105 |
+
|
106 |
+
if not flg:
|
107 |
+
break
|
108 |
+
|
109 |
+
result = " ".join([f'<font color="coral">{original_word}</font>' if b else original_word for (b, original_word) in zip(check_bool_arr, sentence.split())])
|
110 |
+
output_sentences.append(result)
|
111 |
+
|
112 |
+
return output_sentences
|
113 |
+
|
114 |
+
|
115 |
+
@st.cache(allow_output_mutation=True)
|
116 |
+
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
|
117 |
with torch.no_grad():
|
118 |
inputs = tokenizer.encode_plus(
|
119 |
input_text,
|
|
|
129 |
|
130 |
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
131 |
retrieved_sentences = []
|
132 |
+
retrieved_paper_ids = []
|
133 |
|
134 |
for id in ids[0]:
|
135 |
+
cur_sentence = sentence_df.loc[id, "sentence"]
|
136 |
+
cur_link = f"https://aclanthology.org/{sentence_df.loc[id, 'file_id']}"
|
137 |
|
138 |
+
if len(exclude_word_list) == 0:
|
139 |
+
retrieved_sentences.append(cur_sentence)
|
140 |
+
retrieved_paper_ids.append(cur_link)
|
141 |
|
142 |
+
else:
|
143 |
+
exclude_word_list_regex = '|'.join(exclude_word_list)
|
144 |
+
pat = re.compile(f'{exclude_word_list_regex}')
|
145 |
+
|
146 |
+
if not bool(pat.search(cur_sentence)):
|
147 |
+
retrieved_sentences.append(cur_sentence)
|
148 |
+
retrieved_paper_ids.append(cur_link)
|
149 |
+
|
150 |
+
if phrase_annotated:
|
151 |
+
retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)
|
152 |
+
|
153 |
+
return retrieved_sentences, retrieved_paper_ids
|
154 |
|
155 |
|
156 |
if __name__ == "__main__":
|
|
|
161 |
|
162 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
163 |
|
164 |
+
input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
|
165 |
+
top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
|
166 |
+
input_words = st.text_input("exclude words (comma separated)", "good, result")
|
167 |
+
|
168 |
+
agree = st.checkbox('Include phrase annotation')
|
169 |
|
170 |
if st.button('search'):
|
171 |
exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
|
172 |
+
retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)
|
173 |
+
|
174 |
+
result_table_markdown = "| sentence | source link |\n|:---|:---|\n"
|
175 |
+
|
176 |
+
for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids):
|
177 |
+
result_table_markdown += f"| {retrieved_sentence} | {retrieved_paper_id} |\n"
|
178 |
+
|
179 |
+
st.markdown(result_table_markdown, unsafe_allow_html=True)
|
180 |
+
|
181 |
+
st.markdown("---\n#### How this works")
|
182 |
+
|
183 |
+
st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging the attention patterns within ScitoricsBERT.")
|