# Last commit: 58e61ff "Update app.py" (author: Shield)
import streamlit as st
import pandas as pd
import torch
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
import os
def getTop95(predictions, threshold=0.95):
    """Return indices of the smallest set of top classes covering ``threshold``.

    Classes are taken in descending order of score until their summed score
    reaches ``threshold``. If the threshold is never reached, every index is
    returned, so the result is never ``None``.

    Args:
        predictions: 1-D tensor of class scores; in this app it is the softmax
            output of the classifier, so the scores sum to ~1.
        threshold: Cumulative score mass the returned classes must cover.
            Defaults to 0.95.

    Returns:
        1-D tensor of class indices, highest-scoring first.
    """
    sorted_vals, sorted_ids = torch.sort(predictions, descending=True)
    # cumulative[k] is the mass of the top-k classes; the leading 0 stands
    # for k = 0, so the first index meeting the threshold is exactly k.
    cumulative = torch.cat(
        (sorted_vals.new_zeros(1), torch.cumsum(sorted_vals, dim=0))
    )
    reached = torch.nonzero(cumulative >= threshold).flatten()
    # Never reached (e.g. rounding leaves the total just under threshold):
    # fall back to all classes instead of returning None.
    k = int(reached[0]) if reached.numel() > 0 else sorted_ids.numel()
    return sorted_ids[:k]
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_resources():
    """Load the class table, tokenizer and fine-tuned model exactly once.

    Cached separately from ``predict`` so the expensive model load is not
    repeated for every distinct input text.

    Returns:
        Tuple ``(classes, tokenizer, model)``: ``classes`` is the DataFrame
        read from ``classes.csv``; ``model`` has been put in eval mode.
    """
    classes = pd.read_csv('classes.csv')
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
    # Fine-tuned weights are loaded from the current working directory; the
    # classification head is sized to the number of rows in classes.csv.
    model = DistilBertForSequenceClassification.from_pretrained(
        os.getcwd(),
        num_labels=len(classes),
    )
    model.eval()
    return classes, tokenizer, model


@st.cache(show_spinner=False)
def predict(text):
    """Predict the ArXiv tags for ``text``.

    Args:
        text: Input to classify; the UI passes ``title + '|' + summary``.

    Returns:
        Array of tag names from the ``tag`` column of ``classes.csv``, most
        probable first, covering ~95% of the predicted probability mass
        (as selected by ``getTop95``).
    """
    classes, tokenizer, model = _load_resources()
    # Classify the argument itself (not module-level globals) so the
    # st.cache key and the actual input always agree.
    encoded = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(encoded['input_ids'], encoded['attention_mask'])[0][0]
    # Explicit dim avoids the deprecated implicit-dim Softmax behaviour.
    predictions = torch.softmax(logits, dim=-1)
    ids = getTop95(predictions)
    return classes.tag.to_numpy()[ids]
# Page metadata must be configured before any other element is rendered.
st.set_page_config(
    page_title="ArXiv classificator",
    page_icon=":book:",
)
st.header("Theme classification of ArXiv articles")
# User-facing instructions (misspellings "arcticle"/"taxonometry" corrected).
st.markdown("""
Please enter title and summary (at least one is required) and oracul will predict classes of the article according to the ArXiv taxonomy.
""")
# Input form: fields are submitted together when the button is pressed.
with st.form(key='input_form'):
    title = st.text_input(label='Enter title of the article here')
    summary = st.text_area("Enter summary of the article here")
    submit = st.form_submit_button(label='Analyze')

if submit:
    # Treat whitespace-only fields as empty so blank input is rejected.
    if not title.strip() and not summary.strip():
        st.markdown('Please enter at least one: title or summary')
    else:
        with st.spinner(text='Oracul thinks, please wait for his wise prediction'):
            prediction = predict(title + '|' + summary)
        # Top five tags are listed individually; the rest on one line.
        st.markdown("Most likely it is:")
        for tag in prediction[:5]:
            st.markdown(f"- {tag}")
        # Only show the "other variants" section when there are any.
        if len(prediction) > 5:
            st.markdown("Other possible variants:")
            st.write(', '.join(map(str, prediction[5:])))
# Inject CSS that hides Streamlit's main menu and the default page footer.
hide_streamlit_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden;}\n"
    "footer {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_streamlit_style, unsafe_allow_html=True)