File size: 7,574 Bytes
b7137b1 afd47cf a03cf87 6f5d2d2 167f69d b7137b1 3b3fa96 b7137b1 6f4ba26 b7137b1 b9fa3c7 b7137b1 08b9f95 ecce248 b7137b1 1f4b5d0 b7137b1 1f4b5d0 1c53eb1 c6d5fcb 1f4b5d0 b7137b1 f8dc81b b7137b1 1f4b5d0 b7137b1 08b9f95 293e817 1f4b5d0 b9fa3c7 1f4b5d0 b7137b1 f8dc81b b7137b1 1c53eb1 f8dc81b b7137b1 b9f419a f8dc81b b9f419a a03cf87 167f69d b9f419a f8dc81b b9f419a 3f2b07b 7190b6a d549833 7190b6a d549833 3f2b07b 7190b6a d549833 7190b6a d549833 caf6c21 815063d 6f5d2d2 caf6c21 c273b5f caf6c21 3f2b07b c273b5f caf6c21 b9f419a 77d733c 181a8b0 7190b6a 33d77c3 6f5d2d2 b7137b1 6f5d2d2 5ce6476 b7137b1 6f5d2d2 167f69d b7137b1 b9f419a 3b3fa96 b7137b1 3b3fa96 6f5d2d2 fce5f58 caf6c21 77d733c 7190b6a caf6c21 3f2b07b b27f63f d57b0b0 fde541a b9f419a b7137b1 6f5d2d2 b7137b1 6f5d2d2 b7137b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import time
import streamlit as st
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM

# Module-level state shared by main() and the Streamlit on_change callbacks.
bert_tokenizer = None  # tokenizer of the currently loaded model (None until loaded)
bert_model = None  # BertForMaskedLM of the currently loaded model (None until loaded)
top_k = 20  # number of predictions to show; also the sidebar slider's default
model_name = None  # name of the selected model; read by the on_change callbacks
st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')
@st.cache(allow_output_mutation=True)
def load_bert_model(model_name):
    """Load a BERT tokenizer and masked-LM head for ``model_name``.

    The result is memoized by ``st.cache`` keyed on ``model_name``.
    ``allow_output_mutation=True`` stops Streamlit from hashing the large
    model/tokenizer objects on every call.

    Args:
        model_name: a name or path accepted by ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``. The tokenizer is created with
        ``do_lower_case=False``, so input casing is preserved. The model is put
        in eval mode.

    Raises:
        Any exception raised while loading. It is no longer swallowed:
        previously this function returned ``None`` and callers failed later
        with an unrelated unpacking error.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
    model = BertForMaskedLM.from_pretrained(model_name).eval()
    return tokenizer, model
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted token ids into a newline-separated list of candidate words.

    Each id in ``pred_idx`` is decoded with ``tokenizer.decode`` and all
    whitespace is removed from the result. A candidate is dropped when it:
    - is a single character (or empty);
    - is a contiguous substring of ``string.punctuation``;
    - starts with ``.`` or ``[`` (e.g. ``[CLS]``).

    Word-piece markers such as ``##`` are kept. At most ``top_clean`` surviving
    words are returned, in the order they appear in ``pred_idx``.
    """
    punctuation = string.punctuation
    survivors = []
    for token_id in pred_idx:
        word = "".join(tokenizer.decode(token_id).split())
        too_short = len(word) <= 1
        # Substring test against the punctuation *string*, not a per-char set.
        is_punctuation = word in punctuation
        is_special = word.startswith((".", "["))
        if too_short or is_punctuation or is_special:
            continue
        survivors.append(word)
    return "\n".join(survivors[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Encode a sentence and locate its mask token.

    ``<mask>`` in the input is first replaced by ``tokenizer.mask_token``.

    Args:
        tokenizer: tokenizer providing ``mask_token``, ``mask_token_id``,
            ``tokenize`` and ``encode``.
        text_sentence: the raw user input.
        add_special_tokens: forwarded to ``tokenizer.encode``.

    Returns:
        ``(input_ids, mask_idx, tokenized_text)``:
        - ``input_ids`` is a ``(1, seq_len)`` tensor.
        - ``mask_idx`` is the position of the first mask token in
          ``input_ids``, or 0 when there is none (the caller then reads
          the [CLS] position).
        - ``tokenized_text`` is ``tokenizer.tokenize`` of the sentence.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # Use the tokenizer that was passed in, not the module-level global.
    tokenized_text = tokenizer.tokenize(text_sentence)
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    # Search the encoded ids rather than whitespace-split words, so that a mask
    # glued to punctuation (e.g. "[MASK].") is still found.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    mask_idx = mask_positions[0] if mask_positions else 0
    return input_ids, mask_idx, tokenized_text
def get_all_predictions(text_sentence, model_name, top_clean=5):
    """Predict candidate words for the masked position and the [CLS] position.

    Uses the module-level ``bert_tokenizer`` and ``bert_model``. For each
    position, the ``top_k * 5`` highest-scoring vocabulary ids are taken
    (``top_k`` is the module-level value). They are passed to ``decode``, which
    keeps at most ``top_clean`` of them.

    Returns a dict with these keys, in this order:
    - ``'Input sentence'``, ``'Tokenized text'``, ``'Model'``;
    - ``'Masked position'`` (only when the text contains ``[MASK]`` or
      ``<mask>``);
    - ``'[CLS]'``.

    The prediction values are newline-joined strings.
    """
    # ========================= BERT =================================
    input_ids, mask_idx, tokenized_text = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        logits = bert_model(input_ids)[0]
    n_candidates = top_k * 5

    def _words_at(position):
        best_ids = logits[0, position, :].topk(n_candidates).indices.tolist()
        return decode(bert_tokenizer, best_ids, top_clean)

    result = {
        'Input sentence': text_sentence,
        'Tokenized text': tokenized_text,
        'Model': model_name,
    }
    if "[MASK]" in text_sentence or "<mask>" in text_sentence:
        result['Masked position'] = _words_at(mask_idx)
    result['[CLS]'] = _words_at(0)
    return result
def get_bert_prediction(input_text, top_k, model_name):
    """Run the predictions for ``input_text`` and return the result dict.

    ``top_k`` is converted to ``int`` and used as the number of words to keep
    per position.

    Errors now propagate to the caller; ``run_test`` reports them with
    ``st.error``. Previously they were swallowed and ``None`` was returned, so
    the user only saw an empty result.
    """
    return get_all_predictions(input_text, model_name, top_clean=int(top_k))
def run_test(sent, top_k, model_name):
    """Run a prediction for ``sent`` and render the JSON result and its timing.

    If no model has been loaded yet, ``model_name`` is first loaded into the
    module-level ``bert_tokenizer`` / ``bert_model``. Any exception raised
    while predicting or rendering is shown with ``st.error``, and then
    ``st.stop()`` is called.
    """
    global bert_tokenizer, bert_model
    if bert_tokenizer is None:
        bert_tokenizer, bert_model = load_bert_model(model_name)
    with st.spinner("Computing"):
        started_at = time.time()
        try:
            result = get_bert_prediction(sent, top_k, model_name)
            st.caption("Results in JSON")
            st.json(result)
        except Exception as err:
            st.error("Some error occurred during prediction" + str(err))
            st.stop()
    elapsed = time.time() - started_at
    st.text(f"prediction took {elapsed:.2f}s")
def on_text_change():
    """``on_change`` callback of the free-text input (key ``my_text``).

    Runs the prediction on the typed text, using the module-level ``top_k``
    and ``model_name``.
    """
    run_test(st.session_state.my_text, top_k, model_name)
def on_option_change():
    """``on_change`` callback of the example-sentence selectbox (key ``my_choice``).

    Runs the prediction on the chosen sentence, using the module-level
    ``top_k`` and ``model_name``.
    """
    chosen_sentence = st.session_state.my_choice
    run_test(chosen_sentence, top_k, model_name)
def on_results_count_change():
    """``on_change`` callback of the results-count slider (key ``my_slider``).

    Stores the new count in the module-level ``top_k``. Without the ``global``
    declaration, the assignment created a throw-away local and had no effect.
    """
    global top_k
    top_k = int(st.session_state.my_slider)
def on_model_change1():
    """``on_change`` callback of the sidebar model selectbox (key ``my_model1``).

    Announces the chosen model, then loads it into the module-level
    ``model_name``, ``bert_tokenizer`` and ``bert_model``. Previously these
    were assigned as locals, so the loaded model was discarded.
    """
    global model_name, bert_tokenizer, bert_model
    model_name = st.session_state.my_model1
    st.info("Pre-selected model chosen: " + model_name)
    bert_tokenizer, bert_model = load_bert_model(model_name)
def on_model_change2():
    """``on_change`` callback of the custom-model text input (key ``my_model2``).

    Announces the typed model name, then loads it into the module-level
    ``model_name``, ``bert_tokenizer`` and ``bert_model``. Previously these
    were assigned as locals and the model was discarded.

    An empty (or whitespace-only) entry is ignored. This happens when the user
    clears the field, and it must not trigger a load of model ``''``.
    """
    global model_name, bert_tokenizer, bert_model
    custom_name = st.session_state.my_model2.strip()
    if not custom_name:
        return
    model_name = custom_name
    st.info("Custom model chosen: " + model_name)
    bert_tokenizer, bert_model = load_bert_model(model_name)
def init_selectbox():
    """Render the selectbox of canned example sentences and return the selection.

    The first entry is the empty string. Changing the selection triggers
    ``on_option_change``, and the widget value is stored under the
    ``my_choice`` key.
    """
    example_sentences = (
        '',
        "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
        "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
        "Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's",
        "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
        "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
        "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
        "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
        "Parkinson who lives in New York and works for XCorp suffers from [MASK]",
        "Lou Gehrig",
        "Parkinson",
        "Lou Gehrigh's is a [MASK]",
        "Parkinson is a [MASK]",
        "New York is a [MASK]",
        "New York",
        "XCorp",
        "XCorp is a [MASK]",
        "acute lymphoblastic leukemia",
        "acute lymphoblastic leukemia is a [MASK]",
    )
    return st.selectbox(
        'Choose any of these sentences or type any text below',
        example_sentences,
        on_change=on_option_change,
        key='my_choice',
    )
def main():
    """Render the page and make sure the selected model is loaded.

    Steps:
    - Draw the header, the instructions, the results-count slider, the model
      selectbox, the example-sentence selectbox and the two text inputs.
      Predictions themselves are triggered by the widgets' ``on_change``
      callbacks.
    - Resolve the effective model: a non-empty custom model name overrides the
      sidebar choice.
    - Load the model into the module-level ``bert_tokenizer`` / ``bert_model``.
      ``load_bert_model`` is memoized, so repeated calls are cheap.

    Any loading error is shown with ``st.error``, and then ``st.stop()`` is
    called.
    """
    # These names are module globals that the callbacks (on_text_change,
    # on_option_change, run_test) read. Without this declaration, reading
    # ``top_k`` as the slider default raises UnboundLocalError, and
    # ``model_name`` never becomes visible to the callbacks.
    global top_k, model_name, bert_tokenizer, bert_model
    st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of any pretrained BERT model</h3>", unsafe_allow_html=True)
    st.markdown("""
<small style="font-size:18px; color: #7f7f7f">Pretrained BERT models can be used as is, <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning to perform tasks like NER</b></a> <i>ideally if both fill-mask and CLS predictions are good, or minimally if fill-mask predictions are adequate</i></small>
""", unsafe_allow_html=True)
    st.write("This app can be used to examine both model prediction for a masked position as well as the neighborhood of CLS vector")
    st.write(" - To examine model prediction for a position, enter the token [MASK] or <mask>")
    st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
    # The number of returned words may be smaller than requested, because
    # decode() filters out punctuation and special tokens.
    top_k = st.sidebar.slider("Select how many predictions do you need", 1 , 50, top_k, key='my_slider', on_change=on_results_count_change)
    try:
        model_name = st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased','dmis-lab/biobert-v1.1'], index=0, key="my_model1", on_change=on_model_change1)
        init_selectbox()
        st.text_input("Enter text below", "", on_change=on_text_change, key='my_text')
        custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask BERT models only)", "", key="my_model2", on_change=on_model_change2)
        # A custom model name, when given, takes precedence over the sidebar.
        if custom_model_name.strip():
            model_name = custom_model_name.strip()
        # Always (re)load for the current selection so the globals never keep
        # a stale model after the user switches; the call is memoized.
        bert_tokenizer, bert_model = load_bert_model(model_name)
    except Exception as e:
        st.error("Some error occurred during loading" + str(e))
        st.stop()
    st.write("---")


if __name__ == "__main__":
    main()
|