|
import streamlit as st |
|
from transformers import pipeline |
|
import plotly.express as px |
|
import pandas as pd |
|
|
|
|
|
st.set_page_config(layout="wide")

# Hugging Face model used for zero-shot classification of every review.
_MODEL_NAME = "typeform/distilbert-base-uncased-mnli"

# `st.cache` is deprecated in newer Streamlit releases in favour of
# `st.cache_resource` (meant for unhashable global resources such as ML models).
# Prefer the new decorator when it exists; otherwise fall back to the legacy
# decorator with mutation allowed, exactly as before.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def get_classifier_model():
    """Return the zero-shot classification pipeline.

    The result is cached by Streamlit, so script reruns reuse the already
    loaded pipeline instead of loading the model again.
    """
    return pipeline("zero-shot-classification", model=_MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _merge_labels(selected, extra):
    """Return *selected* followed by the comma-separated labels typed in *extra*.

    Each typed label is stripped of surrounding whitespace. Empty entries (from
    an empty box, doubled commas or a trailing comma) are dropped, and labels
    already present are skipped. This keeps blank or duplicate candidate labels
    away from the classifier.
    """
    merged = list(selected)
    for label in extra.split(","):
        label = label.strip()
        if label and label not in merged:
            merged.append(label)
    return merged


st.title("Review Analyzer")
st.markdown("***")

text = st.text_area(label="Paste/Type the review here..")

st.markdown("***")

col1, col2, col3 = st.columns((1, 1, 1))

# --- Sentiments column ---
col1.header("Select Sentiments")
sentiments = col1.multiselect("", ["Happy", "Sad", "Neutral"], ["Happy", "Sad", "Neutral"])
col1.markdown(" \n")
col1.markdown(" \n")
additional_sentiments = col1.text_input("Enter comma separated sentiments.")
sentiments = _merge_labels(sentiments, additional_sentiments)

# --- Topics column ---
col2.header("Select Topics")
entities = col2.multiselect(
    "",
    ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"],
    ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"],
)
additional_entities = col2.text_input("Enter comma separated entities.")
entities = _merge_labels(entities, additional_entities)

# --- Reasons column ---
col3.header("Select Reasons")
reasons = col3.multiselect(
    "",
    ["Poor Service", "No Empathy", "Abuse"],
    ["Poor Service", "No Empathy", "Abuse"],
)
additional_reasons = col3.text_input("Enter comma separated reasons.")
reasons = _merge_labels(reasons, additional_reasons)

is_multi_class = st.checkbox("Can have more than one class", value=True)

st.markdown("***")

classify_button_clicked = st.button("Classify")
|
|
|
def get_classification(candidate_labels, sequence=None, model=None, multi_label=None):
    """Score a review against *candidate_labels* and return a horizontal bar chart.

    Args:
        candidate_labels: Labels to score the text against; passed straight to
            the zero-shot classification pipeline.
        sequence: Text to classify. Defaults to the module-level
            ``sequence_to_classify`` set by the Classify handler.
        model: Zero-shot classification pipeline to call. Defaults to the
            module-level ``classifier``.
        multi_label: Forwarded to the pipeline's ``multi_label`` option.
            Defaults to the module-level ``is_multi_class`` checkbox value.

    Returns:
        A Plotly Express horizontal bar figure with one bar per label, sorted
        by score with the highest score drawn at the top.
    """
    # Fall back to the script-level globals so existing one-argument calls
    # keep working unchanged.
    if sequence is None:
        sequence = sequence_to_classify
    if model is None:
        model = classifier
    if multi_label is None:
        multi_label = is_multi_class

    # `multi_label` is the current name of the pipeline option; `multi_class`
    # is deprecated in transformers and only kept as an alias.
    output = model(sequence, candidate_labels, multi_label=multi_label)

    df = pd.DataFrame({"Class": output["labels"], "Scores": output["scores"]})
    df = df.sort_values(by="Scores", ascending=False)

    fig = px.bar(df, x="Scores", y="Class", orientation="h", width=400, height=500)
    # Reverse the category axis so the first (highest-scoring) row is drawn at the top.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig
|
|
|
|
|
|
|
if classify_button_clicked:
    # Reject empty or whitespace-only input instead of sending it to the model
    # (the original `if text:` let "   " through and ignored empty input silently).
    if text.strip():
        st.markdown("***")

        with st.spinner(" Please wait while the text is being classified.."):
            # get_classification() reads these module-level names.
            classifier = get_classifier_model()
            sequence_to_classify = text

            # (target column, candidate labels, add a spacer before the chart?)
            # Only the sentiments column gets the extra spacer, as before.
            panels = (
                (col1, sentiments, True),
                (col2, entities, False),
                (col3, reasons, False),
            )
            for column, labels, add_spacer in panels:
                if not labels:
                    # Nothing selected or typed for this column: skip its chart.
                    continue
                fig = get_classification(labels)
                if add_spacer:
                    column.markdown(" \n")
                column.write(fig)
    else:
        st.warning("Please paste or type a review before clicking Classify.")