Spaces:
Runtime error
Runtime error
from transformers import AutoModel, AutoTokenizer | |
import faiss | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import torch | |
import math | |
import os | |
import re | |
# Let the process keep running when more than one OpenMP runtime gets loaded.
# NOTE(review): presumably needed because several of the native libraries
# imported above bundle their own OpenMP copy -- confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def load_model_and_tokenizer():
    """Load the "kaisugi/scitoricsbert" checkpoint and its tokenizer.

    The model is instantiated with ``output_attentions=True`` and switched to
    evaluation mode before it is handed back.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "kaisugi/scitoricsbert"
    bert_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint, output_attentions=True)
    encoder.eval()
    return encoder, bert_tokenizer
def load_sentence_data():
    """Read the sentence table from ``sentence_data_858k.csv.gz``.

    Returns:
        The ``pandas.DataFrame`` parsed from that CSV file.
    """
    return pd.read_csv("sentence_data_858k.csv.gz")
def load_sentence_embeddings_and_index():
    """Load the sentence embeddings and build an IVF-PQ Faiss index over them.

    The embeddings are read from ``sentence_embeddings_858k.npz`` (array
    ``arr_0``), L2-normalised in place, and indexed with an ``IndexIVFPQ``
    whose coarse quantizer is an ``IndexHNSWFlat``. The index is trained on
    the first 100000 vectors and then populated with all of them.

    The vector count and dimensionality are taken from the loaded array
    rather than hard-coded, so the function keeps working if the data file
    is regenerated with a different size.

    Returns:
        A ``(sentence_embeddings, index)`` tuple: the normalised embedding
        matrix and the trained, populated Faiss index.
    """
    # An NpzFile keeps the archive open until closed; read the array and
    # release the file handle right away.
    with np.load("sentence_embeddings_858k.npz") as npz_comp:
        sentence_embeddings = npz_comp["arr_0"]

    # Faiss works on C-contiguous float32 matrices. This is a no-op when the
    # stored array already has that layout.
    sentence_embeddings = np.ascontiguousarray(sentence_embeddings, dtype=np.float32)
    faiss.normalize_L2(sentence_embeddings)

    n_vectors, dim = sentence_embeddings.shape

    train_vectors = sentence_embeddings[:100000]

    # Param of PQ
    M = 16  # The number of sub-vectors. Typically this is 8, 16, 32, etc.
    nbits = 8  # Bits per sub-vector. Typically 8, so each sub-vector is encoded by 1 byte.
    # Param of IVF
    nlist = int(math.sqrt(n_vectors))  # The number of cells (space partition). Typical value is sqrt(N).
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.

    # Setup
    quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, M, nbits)

    # Train
    index.train(train_vectors)

    # Add
    index.add(sentence_embeddings)

    # Search
    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.

    return sentence_embeddings, index
def formulaic_phrase_extraction(sentences, model, tokenizer):
    """Mark the words of each sentence that receive high [CLS] attention.

    Every sentence is encoded in one batch. For each sentence, the attention
    paid by its first position ([CLS]) in layer ``LAYER`` is averaged over
    heads and thresholded per token. Word pieces (tokens starting with
    ``##``) are folded into the preceding token, and a merged word counts as
    highlighted only if all of its pieces are above the threshold. The flags
    are then paired positionally with ``sentence.split()``.

    Args:
        sentences: List of sentence strings.
        model: Model called as ``model(**inputs)`` whose last output element
            is the tuple of per-layer attention tensors.
        tokenizer: Tokenizer providing ``batch_encode_plus`` and ``tokenize``.

    Returns:
        A list with one string per input sentence, in which highlighted
        words are wrapped in ``<font color="coral">...</font>``. Words for
        which no flag is available are emitted unhighlighted instead of
        being dropped. An empty input yields an empty list.
    """
    THRESHOLD = 0.01  # minimum head-averaged attention for a token to count
    LAYER = 10  # index into the tuple of per-layer attentions

    # Encoding an empty batch fails inside the tokenizer; there is nothing to
    # annotate anyway.
    if not sentences:
        return []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        attention = outputs[-1]
        # attention[LAYER] is expected to be (batch, heads, query, key).
        # Average over heads, then keep query position 0 ([CLS]) of EVERY
        # sentence. Indexing with [0] first would select only the first
        # sentence and pair later sentences with the wrong rows.
        cls_attentions = torch.mean(attention[LAYER], dim=1)[:, 0, :]

    output_sentences = []
    for sentence, cls_attention in zip(sentences, cls_attentions):
        tokens = tokenizer.tokenize(sentence)
        # Position 0 is [CLS]; the following len(tokens) positions belong to
        # the sentence's own tokens (anything after that is [SEP]/padding).
        token_flags = (cls_attention > THRESHOLD).tolist()[1:1 + len(tokens)]

        # Fold "##" continuation pieces into the preceding word in one pass.
        word_flags = []
        for token, flag in zip(tokens, token_flags):
            if token.startswith("##") and word_flags:
                word_flags[-1] = word_flags[-1] and flag
            else:
                word_flags.append(flag)

        # Align flags with whitespace-separated words; pad with False so no
        # word of the sentence is lost when fewer flags are available.
        words = sentence.split()
        word_flags = word_flags[:len(words)]
        word_flags += [False] * (len(words) - len(word_flags))

        result = " ".join(
            f'<font color="coral">{original_word}</font>' if b else original_word
            for (b, original_word) in zip(word_flags, words)
        )
        output_sentences.append(result)

    return output_sentences
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
    """Retrieve sentences similar to ``input_text`` from the Faiss index.

    The query is encoded with ``model``; the hidden state at position 0 is
    L2-normalised and used to search ``index`` for ``top_k`` neighbours.
    Hits whose sentence contains any of the excluded words are discarded, so
    fewer than ``top_k`` results may be returned.

    Args:
        index: Faiss index to search.
        input_text: Query text.
        top_k: Number of neighbours requested from the index.
        model: Encoder called as ``model(**inputs)``.
        tokenizer: Tokenizer providing ``encode_plus``.
        sentence_df: DataFrame with ``sentence`` and ``file_id`` columns,
            looked up by the ids the index returns.
        exclude_word_list: Words to filter on; each is matched literally as a
            substring of the sentence. Empty entries are ignored.
        phrase_annotated: When true, the kept sentences are passed through
            ``formulaic_phrase_extraction`` before being returned.

    Returns:
        A ``(retrieved_sentences, retrieved_paper_ids)`` tuple of two lists
        of equal length; the second holds ``https://aclanthology.org/...``
        links built from ``file_id``.
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)

    # Build the filter once, outside the loop. The words come from free-form
    # user input, so escape them: otherwise "c++" or "(" would be read as
    # regex syntax and either raise or match the wrong thing.
    exclude_words = [word for word in exclude_word_list if word]
    exclude_pattern = None
    if exclude_words:
        exclude_pattern = re.compile('|'.join(re.escape(word) for word in exclude_words))

    retrieved_sentences = []
    retrieved_paper_ids = []
    for row_id in ids[0]:
        # Faiss fills missing results with -1 when it finds fewer than k hits.
        if row_id < 0:
            continue

        cur_sentence = sentence_df.loc[row_id, "sentence"]
        if exclude_pattern is not None and exclude_pattern.search(cur_sentence):
            continue

        cur_link = f"https://aclanthology.org/{sentence_df.loc[row_id, 'file_id']}"
        retrieved_sentences.append(cur_sentence)
        retrieved_paper_ids.append(cur_link)

    # Skip annotation when everything was filtered out: there is nothing to
    # encode.
    if phrase_annotated and retrieved_sentences:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)

    return retrieved_sentences, retrieved_paper_ids
if __name__ == "__main__":
    # Streamlit entry point: load resources, draw the form, and render the
    # retrieval results as a markdown table when "search" is pressed.
    #
    # NOTE(review): Streamlit re-executes the script on each interaction, so
    # the model, the CSV and the Faiss index appear to be rebuilt on every
    # rerun -- consider wrapping the loaders in Streamlit's caching decorator
    # that matches the deployed Streamlit version.
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "good, result")
    agree = st.checkbox('Include phrase annotation')

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)

        # A literal "|" inside a cell (e.g. "P(y|x)") would be read as a
        # column separator and break the row, so escape it.
        escaped_pipe = "\\|"
        table_rows = [
            f"| {str(retrieved_sentence).replace('|', escaped_pipe)} | {retrieved_paper_id} |\n"
            for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids)
        ]
        result_table_markdown = "| sentence | source link |\n|:---|:---|\n" + "".join(table_rows)

        # unsafe_allow_html is needed so the <font> highlight tags render.
        st.markdown(result_table_markdown, unsafe_allow_html=True)

    st.markdown("---\n#### How this works")
    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging the attention patterns within ScitoricsBERT.")