File size: 8,213 Bytes
b7137b1 afd47cf b72b8d9 167f69d b7137b1 ddeee45 3b3fa96 b7137b1 6f4ba26 b7137b1 b9fa3c7 b7137b1 08b9f95 ecce248 b7137b1 a03c359 b7137b1 424d29e 1c53eb1 c6d5fcb 1f4b5d0 b7137b1 f8dc81b a03c359 424d29e a03c359 b7137b1 1f4b5d0 b7137b1 de31505 293e817 2406613 b9fa3c7 2406613 b7137b1 f8dc81b b7137b1 1c53eb1 f8dc81b b7137b1 b9f419a f8dc81b b9f419a d1b63cc 41b19ce d1b63cc b9f419a d1b63cc b9f419a 3f2b07b 7190b6a d1b63cc 7190b6a d9f055f 3f2b07b 7190b6a d1b63cc 7190b6a 3eb018f d9f055f caf6c21 815063d d1b63cc caf6c21 c273b5f d1b63cc 3f2b07b 0e5769d d1b63cc b9f419a 77d733c d1357c0 181a8b0 7190b6a 6f5d2d2 d1b63cc b7137b1 6f5d2d2 d1b63cc eb9cce7 a340b6b ddeee45 6f5d2d2 3856ec6 b7137b1 8a3b8f4 6f5d2d2 1cfaf5b d1b63cc 167f69d 6f5d2d2 d1b63cc d1357c0 d1b63cc 8a3b8f4 a7a3b55 3eb018f fde541a b9f419a b7137b1 6f5d2d2 b7137b1 6f5d2d2 316bafd 6f5d2d2 b7137b1 |
|
import time
import streamlit as st
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM
# Page-level settings for the app (title, layout, sidebar).
# NOTE(review): Streamlit expects this before any other st.* call; keep it at module top.
st.set_page_config(
    page_title='Compare BERT models qualitatively',
    page_icon=None,
    layout='centered',
    initial_sidebar_state='auto',
)
@st.cache(allow_output_mutation=True)
def load_bert_model(model_name):
    """Load the BERT tokenizer and masked-LM model for ``model_name``.

    Results are memoized by ``st.cache``. ``allow_output_mutation=True`` stops
    Streamlit from hashing the (large) returned model on every call.

    Args:
        model_name: Hugging Face model id or local path of a fill-mask BERT model.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.

    Raises:
        Whatever ``from_pretrained`` raises. Errors are intentionally NOT
        swallowed: returning None here would be cached for this model name, and
        callers would fail later with an opaque tuple-unpacking TypeError.
    """
    # NOTE(review): do_lower_case=False is forced for every model, including the
    # "-uncased" PubMedBERT option offered in the sidebar; confirm this is intended.
    bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
    bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
    return bert_tokenizer, bert_model
def decode(tokenizer, pred_idx, top_clean):
    """Turn candidate token ids into a newline-separated list of clean tokens.

    A candidate is dropped when its decoded text (with whitespace removed) is a
    single character, starts with '.' or '[' (e.g. special tokens such as
    [CLS]/[MASK]), or consists entirely of punctuation characters.

    Args:
        tokenizer: object exposing ``decode(token_id) -> str``.
        pred_idx: iterable of token ids, presumably best-first (callers pass
            ``topk(...).indices``).
        top_clean: maximum number of tokens to keep.

    Returns:
        Up to ``top_clean`` surviving tokens joined with '\\n'.
    """
    punctuation = set(string.punctuation)
    tokens = []
    for token_id in pred_idx:
        # Stop decoding once enough clean tokens have been collected.
        if 0 <= top_clean <= len(tokens):
            break
        token = ''.join(tokenizer.decode(token_id).split())
        if len(token) <= 1 or token.startswith(('.', '[')):
            continue
        # A per-character test: the previous `token in string.punctuation` was a
        # substring test and let runs such as "--" or "''" through.
        if all(ch in punctuation for ch in token):
            continue
        tokens.append(token)
    return '\n'.join(tokens[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence and locate its (first) mask token.

    ``<mask>`` is accepted as an alias and rewritten to ``tokenizer.mask_token``.

    Args:
        tokenizer: BERT-style tokenizer (``tokenize``, ``encode``, ``mask_token``,
            ``mask_token_id``).
        text_sentence: input text.
        add_special_tokens: forwarded to ``tokenizer.encode``.

    Returns:
        ``(input_ids, mask_idx, tokenized_text)``: ``input_ids`` is a tensor of
        shape (1, seq_len); ``mask_idx`` is the position of the first mask id in
        it, or 0 when the sentence has no mask; ``tokenized_text`` is the list
        of word pieces.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    tokenized_text = tokenizer.tokenize(text_sentence)
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    # Look for the mask in the encoded ids rather than in text.split(): a mask
    # glued to punctuation ("... [MASK].") is not a whitespace-separated word,
    # and the old check silently fell back to position 0.
    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    mask_idx = int(mask_positions[0]) if mask_positions.numel() > 0 else 0
    return input_ids, mask_idx, tokenized_text
def get_all_predictions(text_sentence, model_name, top_clean=5):
    """Run the session's BERT model on a sentence and collect its predictions.

    Reads the tokenizer, model and ``top_k`` from ``st.session_state``.

    Args:
        text_sentence: input text, optionally containing ``[MASK]``/``<mask>``.
        model_name: model id, echoed into the result.
        top_clean: number of cleaned tokens to report per position.

    Returns:
        A dict with the input, its word pieces, ``results_count``, the model
        name, the ``'Masked position'`` predictions (only when the text contains
        a mask marker) and the ``'[CLS]'`` predictions (position 0).
    """
    bert_tokenizer = st.session_state['bert_tokenizer']
    bert_model = st.session_state['bert_model']
    top_k = st.session_state['top_k']
    # ========================= BERT =================================
    input_ids, mask_idx, tokenized_text = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    # Over-fetch candidates because decode() discards punctuation, single-char
    # and bracketed tokens; never ask topk for more entries than the vocabulary has.
    candidate_count = min(top_k * 10, predict.shape[-1])

    result = {
        'Input sentence': text_sentence,
        'Tokenized text': tokenized_text,
        'results_count': top_k,
        'Model': model_name,
    }
    # Only decode the masked slot when there is one; key order matches the
    # original output ('Masked position' before '[CLS]').
    if "[MASK]" in text_sentence or "<mask>" in text_sentence:
        result['Masked position'] = decode(
            bert_tokenizer, predict[0, mask_idx, :].topk(candidate_count).indices.tolist(), top_clean)
    result['[CLS]'] = decode(
        bert_tokenizer, predict[0, 0, :].topk(candidate_count).indices.tolist(), top_clean)
    return result
def get_bert_prediction(input_text, top_k, model_name):
    """Return ``get_all_predictions`` output, keeping ``int(top_k)`` tokens per position.

    Exceptions deliberately propagate: the caller (``run_test``) wraps this call
    in try/except and reports failures with ``st.error``. The former
    ``except: pass`` returned None, which was then rendered as JSON ``null``.
    """
    return get_all_predictions(input_text, model_name, top_clean=int(top_k))
def run_test(sent, top_k, model_name):
    """Predict for ``sent`` and render the result as JSON plus the elapsed time.

    Lazily loads the model named in session state if none is loaded yet.
    Loading or prediction errors are shown with ``st.error`` and stop the run.

    ``top_k`` and ``model_name`` are kept for interface compatibility; the
    values actually used come from ``st.session_state``.
    """
    if st.session_state['bert_tokenizer'] is None:
        st.info("Loading model:" + st.session_state['model_name'])
        try:
            tokenizer, model = load_bert_model(st.session_state['model_name'])
        except Exception as e:
            st.error("Some error occurred during model loading: " + str(e))
            st.stop()
        st.session_state['bert_tokenizer'], st.session_state['bert_model'] = tokenizer, model
    with st.spinner("Computing"):
        start = time.time()
        try:
            res = get_bert_prediction(sent, st.session_state['top_k'], st.session_state['model_name'])
            st.caption("Results in JSON")
            st.json(res)
        except Exception as e:
            st.error("Some error occurred during prediction: " + str(e))
            st.stop()
    st.text(f"prediction took {time.time() - start:.2f}s")
def on_text_change():
    """Callback for the free-text input: run a prediction on the new text."""
    text = st.session_state.my_text
    # Clearing the box also fires on_change; there is nothing to predict for blank input.
    if not text.strip():
        return
    run_test(text, st.session_state['top_k'], st.session_state['model_name'])
def on_option_change():
    """Callback for the sample-sentence selectbox: run a prediction on the chosen sentence."""
    text = st.session_state.my_choice
    # The first option is the '' placeholder; selecting it should not run the model.
    if not text.strip():
        return
    run_test(text, st.session_state['top_k'], st.session_state['model_name'])
def on_results_count_change():
    """Callback for the sidebar slider: store the new results count and announce it."""
    count = int(st.session_state.my_slider)
    st.session_state['top_k'] = count
    st.info(f"Results count changed {count}")
def on_model_change1():
    """Callback for the sidebar model picker: load the chosen model into session state."""
    model_name = st.session_state.my_model1
    st.info("Pre-selected model chosen: " + model_name)
    # Load before touching session state, so a failed load does not leave
    # model_name pointing at a model that was never loaded.
    tokenizer, model = load_bert_model(model_name)
    st.session_state['model_name'] = model_name
    st.session_state['bert_tokenizer'], st.session_state['bert_model'] = tokenizer, model
def on_model_change2():
    """Callback for the custom-model text box: load the typed model into session state.

    Surrounding whitespace is stripped; an empty entry (e.g. the box was
    cleared) is ignored and the current model is kept.
    """
    model_name = st.session_state.my_model2.strip()
    if not model_name:
        return
    st.info("Custom model chosen: " + model_name)
    # Load before touching session state, so a failed load does not leave
    # model_name pointing at a model that was never loaded.
    tokenizer, model = load_bert_model(model_name)
    st.session_state['model_name'] = model_name
    st.session_state['bert_tokenizer'], st.session_state['bert_model'] = tokenizer, model
def init_selectbox():
st.selectbox(
'Choose any of these sentences or type any text below',
('', "[MASK] who lives in New York and works for XCorp suffers from Parkinson's", "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's","Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's","Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]","[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's", "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's","Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's","Parkinson who lives in New York and works for XCorp suffers from [MASK]","Lou Gehrig","Parkinson","Lou Gehrigh's is a [MASK]","Parkinson is a [MASK]","New York is a [MASK]","New York","XCorp","XCorp is a [MASK]","acute lymphoblastic leukemia","acute lymphoblastic leukemia is a [MASK]"),on_change=on_option_change,key='my_choice')
def init_session_states():
    """Seed st.session_state with defaults for any keys that are not set yet."""
    defaults = (
        ('top_k', 20),
        ('bert_tokenizer', None),
        ('bert_model', None),
        ('model_name', "ajitrajasekharan/biomedical"),
    )
    for key, value in defaults:
        if key not in st.session_state:
            st.session_state[key] = value
def main():
init_session_states()
st.markdown("<h3 style='text-align: center;'>Compare BERT models qualitatively</h3>", unsafe_allow_html=True)
st.markdown("""
<small style="font-size:20px; color: #2f2f2f"><br/>Why compare pretrained models <b>before fine-tuning</b>?</small><br/><small style="font-size:18px; color: #7f7f7f">Pretrained BERT models can be used as is, <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning to perform tasks like NER.</b><br/></a>This can be done ideally by using both fill-mask and CLS predictions, or minimally using fill-mask predictions alone if they are adequate</small>
""", unsafe_allow_html=True)
st.write("This app can be used to examine both model prediction for a masked position as well as the neighborhood of CLS vector")
st.write(" - To examine model prediction for a position, enter the token [MASK] or <mask>")
st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
st.sidebar.slider("Select count of predictions to display", 1 , 50, 20,key='my_slider',on_change=on_results_count_change) #some times it is possible to have less words
try:
st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased','dmis-lab/biobert-v1.1'], index=0, key = "my_model1",on_change=on_model_change1)
init_selectbox()
st.text_input("Enter text below", "",on_change=on_text_change,key='my_text')
st.text_input("Model not listed on left? Type the model name (fill-mask BERT models only)", "",key="my_model2",on_change=on_model_change2)
st.info("Current status:")
st.info("Selected results count = " + str(st.session_state['top_k']))
st.info("Selected Model name = " + st.session_state['model_name'])
#if (st.session_state['bert_tokenizer'] is None):
# st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(st.session_state['model_name'])
except Exception as e:
st.error("Some error occurred during loading" + str(e))
st.stop()
st.write("---")
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()
|