Spaces:
Runtime error
Runtime error
File size: 2,262 Bytes
7ae2fd5 dcc574b 7ae2fd5 dcc574b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
st.title("Zero-shot Turkish Text Classification")

# One entry per zero-shot method offered by the app:
# (key into METHOD_OPTIONS, prompt for its model picker, model choices).
# Both the method radio buttons and the model picker are driven by this table.
_METHOD_MODEL_PICKERS = (
    ("nli", "Select a natural language inference model.", NLI_MODEL_OPTIONS),
    (
        "nsp",
        "Select a BERT model for next sentence prediction.",
        NSP_MODEL_OPTIONS,
    ),
)

method_selection = st.radio(
    "Select a zero-shot classification method.",
    [METHOD_OPTIONS[method_key] for method_key, _, _ in _METHOD_MODEL_PICKERS],
)

# Render the model picker that belongs to the chosen method; the picked
# model name is stored in `model`.
for method_key, picker_prompt, picker_choices in _METHOD_MODEL_PICKERS:
    if method_selection == METHOD_OPTIONS[method_key]:
        model = st.selectbox(picker_prompt, picker_choices)
st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

# Left column: candidate labels and the text to classify.
with col1:
    st.subheader("Candidate labels")
    # Raw comma-separated label string as typed by the user.
    labels = st.text_area(
        label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
        value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
    )

    st.header("Make predictions")
    # Give the widget a real label: an empty label is an accessibility
    # problem, and recent Streamlit versions log a warning for it.
    input_text = st.text_area(
        "Text to classify", value="Enter some text to classify."
    )
    # Keep the widget values bound so prediction logic can read them.
    # Nothing in this script consumes them yet.
    predict_clicked = st.button("Predict")
# Right column: prompt template and the per-label probability chart.
with col2:
    st.subheader("Prompt template")
    prompt_template = st.text_area(
        label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
        value="Bu metin {} kategorisine aittir",
    )
    # Empty header used as a vertical spacer so the chart lines up with the
    # "Make predictions" header rendered in the left column.
    st.header("")

    # Split the comma-separated input, trimming surrounding whitespace and
    # dropping empty entries (e.g. from "a, b," or a doubled comma).
    label_list = [part.strip() for part in labels.split(",") if part.strip()]

    if not label_list:
        st.warning("Enter at least one candidate label.")
    else:
        # Placeholder scores until model inference is wired in. The demo
        # values only fit exactly five labels; any other count used to make
        # pd.DataFrame raise "All arrays must be of the same length", so fall
        # back to a uniform distribution over the given labels.
        demo_probs = [0.86, 0.10, 0.01, 0.02, 0.01]
        if len(label_list) == len(demo_probs):
            probs = demo_probs
        else:
            probs = [1 / len(label_list)] * len(label_list)

        data = pd.DataFrame(
            {"labels": label_list, "probability": probs}
        ).sort_values(by="probability", ascending=False)

        chart = px.bar(
            data,
            x="probability",
            y="labels",
            color="labels",
            orientation="h",
            height=290,
            width=500,
        ).update_layout(
            {
                "xaxis": {
                    "title": "probability",
                    "visible": True,
                    "showticklabels": True,
                },
                # A category y axis draws the first row at the bottom; reverse
                # it so the most probable label (first after the descending
                # sort) is shown at the top.
                "yaxis": {
                    "title": None,
                    "visible": True,
                    "showticklabels": True,
                    "autorange": "reversed",
                },
                "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
                "showlegend": False,
            }
        )
        st.plotly_chart(chart)
|