NLP / zeroshot_clf.py
ashishraics's picture
updated architecture
e867b58
raw
history blame
1.48 kB
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px
# Load the sequence-classification model and its tokenizer once, at import time.
# 'zero_shot_clf/' is presumably a local directory holding the saved checkpoint
# (TODO confirm). Both objects are used as the default `model` / `tokenizer`
# arguments of zero_shot_classification().
model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    """Score ``premise`` against comma-separated ``labels`` and plot the result.

    Every label is turned into the hypothesis ``"this is an example of <label>"``
    and scored against the lower-cased premise by ``model``. The per-label
    scores are normalised so they sum to 100 and are drawn as a bar chart
    through ``streamlit.plotly_chart``.

    Args:
        premise: Text to classify.
        labels: Candidate labels separated by commas, e.g.
            ``"science, sports, museum"``. Surrounding whitespace is stripped
            and empty entries are ignored.
        model: Callable taking the encoded premise/hypothesis pair and
            returning an output indexable with ``'logits'``.
        tokenizer: Tokenizer whose ``encode`` accepts a text pair.

    Returns:
        Whatever ``streamlit.plotly_chart`` returns for the generated figure.

    Raises:
        ValueError: If ``labels`` is not a string or holds fewer than two
            non-empty labels.
    """
    if not isinstance(labels, str):
        raise ValueError("please pass atleast 2 labels to classify")

    # Strip each label so "a, b" does not produce " b" (and a double space in
    # the hypothesis); drop empties left by stray or trailing commas.
    candidate_labels = [part.strip().lower() for part in labels.split(',')]
    candidate_labels = [label for label in candidate_labels if label]
    if len(candidate_labels) < 2:
        # With a single label the normalised score is always 100%.
        raise ValueError("please pass atleast 2 labels to classify")

    premise = premise.lower()

    entail_probs = []
    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        for label in candidate_labels:
            hypothesis = f'this is an example of {label}'
            # Truncate only the premise so the hypothesis is never cut off.
            encoded = tokenizer.encode(premise, hypothesis,
                                       return_tensors='pt',
                                       truncation='only_first')
            output = model(encoded)
            # Softmax over logit columns 0 and 2 only, keeping column 2's
            # share. NOTE(review): assumes the checkpoint orders its logits
            # (contradiction, neutral, entailment) as MNLI models usually do;
            # confirm for the model stored in 'zero_shot_clf/'.
            entail_prob = output['logits'][:, [0, 2]].softmax(dim=1)[:, 1].item()
            entail_probs.append(entail_prob)

    # Rescale so the scores across all labels add up to 100.
    total = np.sum(entail_probs)
    labels_prob_norm = [np.round(100 * prob / total, 1) for prob in entail_probs]

    df = pd.DataFrame({'labels': candidate_labels,
                       'Probability': labels_prob_norm})
    fig = px.bar(x=df['Probability'],
                 y=df['labels'])
    return streamlit.plotly_chart(fig)
# Example usage:
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')