File size: 5,737 Bytes
b7137b1 afd47cf b7137b1 3b3fa96 b7137b1 6f4ba26 b7137b1 b9fa3c7 b7137b1 08b9f95 ecce248 b7137b1 1c53eb1 c6d5fcb b7137b1 08b9f95 293e817 b9fa3c7 293e817 b7137b1 1c53eb1 b7137b1 b9f419a b7137b1 35a5cd4 5408f33 35a5cd4 b7137b1 d1cc326 9c3de2e 5408f33 b7137b1 b9f419a 3b3fa96 b7137b1 3b3fa96 5408f33 fce5f58 0a7c967 9edc4d0 b9f419a 1012c21 7affda1 fade61c 2056bb6 9c3de2e fce5f58 b9f419a b7137b1 b9f419a b7137b1 d4804d5 5408f33 b7137b1 32acf13 b7137b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import time
import streamlit as st
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM
# Configure the browser tab title and the overall page layout for the app.
st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')
@st.cache()
def load_bert_model(model_name):
    """Load the tokenizer and masked-LM model for ``model_name``.

    The result is wrapped in ``st.cache`` so the same model name is not
    reloaded on every call.

    Args:
        model_name: Model identifier handed to ``from_pretrained``.

    Returns:
        A ``(bert_tokenizer, bert_model)`` tuple; the model is put in eval mode.

    Raises:
        Exception: whatever ``from_pretrained`` raises. Failures are no longer
            swallowed: returning ``None`` made the caller's tuple unpacking
            fail with an unrelated "cannot unpack" error that hid the cause.
    """
    # NOTE(review): do_lower_case=False is applied to every model, including
    # the uncased one offered in the model selector -- confirm this is intended.
    bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
    bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
    return bert_tokenizer, bert_model
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted vocabulary ids into a newline-separated list of tokens.

    Each id is decoded with ``tokenizer.decode`` and stripped of all
    whitespace. A decoded token is dropped when it is a substring of
    ``string.punctuation``, is shorter than two characters, or starts with
    ``.`` or ``[``.

    Args:
        tokenizer: Object exposing ``decode(id) -> str``.
        pred_idx: Iterable of vocabulary ids, in ranking order.
        top_clean: Slice bound applied to the surviving tokens.

    Returns:
        The first ``top_clean`` surviving tokens joined with ``"\\n"``.
    """
    punctuation = string.punctuation
    rejected_prefixes = ('.', '[')

    def _is_clean(candidate):
        # Substring test against the punctuation string (not a per-character
        # set): this also rejects the empty string and runs such as "()".
        if candidate in punctuation:
            return False
        if len(candidate) <= 1:
            return False
        return not candidate.startswith(rejected_prefixes)

    # Collapse any whitespace inside a decoded piece before filtering.
    candidates = (''.join(tokenizer.decode(idx).split()) for idx in pred_idx)
    kept = [candidate for candidate in candidates if _is_clean(candidate)]
    return '\n'.join(kept[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize ``text_sentence`` and locate its first mask position.

    ``<mask>`` is accepted as an alias and rewritten to the tokenizer's own
    mask token before encoding.

    Args:
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: Raw input text; may be empty.
        add_special_tokens: Forwarded to ``tokenizer.encode``.

    Returns:
        ``(input_ids, mask_idx)`` where ``input_ids`` is a tensor holding one
        row of ids and ``mask_idx`` is the column of the first mask id in that
        row, or ``0`` when the encoded input contains no mask id.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    words = text_sentence.split()
    # If the mask is the last token, append a "." so that models don't
    # predict punctuation. Guarding on `words` keeps empty/whitespace-only
    # input from raising IndexError.
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    # Locate the mask from the encoded ids rather than from whitespace-split
    # words, so a mask glued to punctuation (e.g. "([MASK])") is still found.
    mask_columns = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    mask_idx = mask_columns[0] if mask_columns else 0
    return input_ids, mask_idx
def get_all_predictions(text_sentence, top_clean=5):
    """Run the loaded model on ``text_sentence`` and collect top predictions.

    Uses the module-level ``bert_tokenizer`` and ``bert_model``.

    Args:
        text_sentence: Input text, optionally containing ``[MASK]``/``<mask>``.
        top_clean: Number of cleaned tokens to keep per position.

    Returns:
        A dict with ``'Input sentence'`` and ``'[CLS]'`` (predictions at
        position 0). ``'Masked position'`` is added when the text contains
        ``[MASK]`` or ``<mask>``.
    """
    # ========================= BERT =================================
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    # Over-fetch 5x candidates because decode() discards punctuation, short
    # and special tokens. The count is derived from `top_clean` (previously it
    # silently read the global slider value `top_k`) and is clamped to the
    # size of the last dimension so topk() cannot be asked for too many.
    num_candidates = min(top_clean * 5, predict.shape[-1])
    bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(num_candidates).indices.tolist(), top_clean)
    cls = decode(bert_tokenizer, predict[0, 0, :].topk(num_candidates).indices.tolist(), top_clean)
    if "[MASK]" in text_sentence or "<mask>" in text_sentence:
        return {'Input sentence': text_sentence, 'Masked position': bert, '[CLS]': cls}
    return {'Input sentence': text_sentence, '[CLS]': cls}
def get_bert_prediction(input_text, top_k):
    """Return the prediction dict for ``input_text``.

    Args:
        input_text: Text to analyse.
        top_k: Number of predictions wanted; converted with ``int``.

    Returns:
        The dict produced by ``get_all_predictions``.

    Raises:
        Exception: any error raised while predicting. Errors are deliberately
            not swallowed here: returning ``None`` hid every failure from the
            caller, which has its own error reporting.
    """
    return get_all_predictions(input_text, top_clean=int(top_k))
def run_test(sent, top_k):
    """Run a prediction for ``sent`` and render the outcome in the page.

    Shows a spinner while computing, then the result as JSON followed by the
    elapsed time. On failure an error box is shown and the script is stopped.

    Args:
        sent: Input text to analyse.
        top_k: Number of predictions to request.
    """
    with st.spinner("Computing"):
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        start = time.perf_counter()
        try:
            res = get_bert_prediction(sent, top_k)
            if res is None:
                # A missing result would otherwise be rendered as JSON "null"
                # with no indication that anything went wrong.
                raise RuntimeError("no prediction was produced")
            st.caption("Results in JSON")
            st.json(res)
        except Exception as e:
            st.error("Some error occurred during prediction: " + str(e))
            st.stop()
    st.text(f"prediction took {time.perf_counter() - start:.2f}s")
# ---- Page header and usage notes -------------------------------------------
st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of Pretrained BERT models</h3>", unsafe_allow_html=True)
# The closing tags are emitted in nesting order (</b></a></small>); they were
# previously interleaved, which is invalid HTML.
st.markdown("""
<small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER , <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning</b></a></small>
""", unsafe_allow_html=True)
st.write("Model prediction for a masked position as well as the neighborhood of CLS vector for input text can be examined")
st.write(" - To examine model prediction for a position, enter the token [MASK] or <mask>")
st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")

# ---- Controls ---------------------------------------------------------------
# Fewer words than requested may come back, because predictions are filtered.
top_k = st.sidebar.slider("Select how many predictions do you need", 1, 50, 20)
print(top_k)  # trace of the selected value on the server's stdout

try:
    model_name = st.sidebar.selectbox(
        label='Select Model to Apply',
        options=[
            'ajitrajasekharan/biomedical',
            'bert-base-cased',
            'bert-large-cased',
            'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
            'allenai/scibert_scivocab_cased',
        ],
        index=0,
        key="model_name",
    )
    # Two of the sample sentences used to be wrapped in literal single quotes,
    # so a trailing "[MASK]'" was not recognised as a mask token.
    option = st.selectbox(
        'Choose any of these sentences or type any text below',
        (
            '',
            "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
            "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
            "Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's",
            "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
            "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
            "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
            "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
            "Parkinson who lives in New York and works for XCorp suffers from [MASK]",
            "Lou Gehrig",
            "Parkinson",
            "Lou Gehrigh's is a [MASK]",
            "Parkinson is a [MASK]",
            "New York is a [MASK]",
            "New York",
            "XCorp",
            "XCorp is a [MASK]",
            "acute lymphoblastic leukemia",
            "acute lymphoblastic leukemia is a [MASK]",
        ),
    )
    # These two names are read as module-level globals by get_all_predictions.
    bert_tokenizer, bert_model = load_bert_model(model_name)
    default_text = "Imatinib is used to [MASK] acute lymphoblastic leukemia"
    input_text = st.text_area(
        label="Enter text below",
        value=default_text,
    )
    # The Submit button runs the free-text box; otherwise a non-empty choice
    # from the sample list is run.
    if st.button("Submit"):
        run_test(input_text, top_k)
    elif option:
        run_test(option, top_k)
except Exception as e:
    st.error("Some error occurred during loading: " + str(e))
    st.stop()
st.write("---")
|