File size: 2,262 Bytes
7ae2fd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc574b
 
 
 
 
 
 
 
 
7ae2fd5
dcc574b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS

st.title("Zero-shot Turkish Text Classification")

# Let the user pick a zero-shot strategy; options are listed NLI first, then NSP.
method_selection = st.radio(
    label="Select a zero-shot classification method.",
    options=[METHOD_OPTIONS[key] for key in ("nli", "nsp")],
)

# Each strategy offers its own model list; ``model`` is bound by the matching branch.
if method_selection == METHOD_OPTIONS["nli"]:
    model = st.selectbox(
        label="Select a natural language inference model.",
        options=NLI_MODEL_OPTIONS,
    )
if method_selection == METHOD_OPTIONS["nsp"]:
    model = st.selectbox(
        label="Select a BERT model for next sentence prediction.",
        options=NSP_MODEL_OPTIONS,
    )

st.header("Configure prompts and labels")
# Two side-by-side columns; the left one holds the label editor and text input.
col1, col2 = st.columns(2)

# Write into the left column by calling its methods directly (equivalent to
# rendering the same widgets inside ``with col1:``).
col1.subheader("Candidate labels")
labels = col1.text_area(
    label=(
        "These are the labels that the model will try to predict for the "
        "given text input. Your input labels should be comma separated "
        "and meaningful."
    ),
    value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
)
col1.header("Make predictions")
# NOTE(review): the text and button results are not captured here, so nothing
# in this section consumes them yet.
col1.text_area("", value="Enter some text to classify.")
col1.button("Predict")

with col2:
    st.subheader("Prompt template")
    prompt_template = st.text_area(
        label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
        value="Bu metin {} kategorisine aittir",
    )
    # Empty header used as vertical spacing (presumably to line the chart up
    # with the left column's "Make predictions" section).
    st.header("")

    # Normalise the comma-separated input: trim surrounding whitespace and drop
    # empty entries such as those produced by "a,,b" or a trailing comma.
    candidate_labels = [
        label.strip() for label in labels.split(",") if label.strip()
    ]

    # Placeholder scores until model inference is wired in. They only match the
    # five default labels; for any other label count use a uniform split so both
    # DataFrame columns always have equal length (pandas raises ValueError
    # "All arrays must be of the same length" otherwise).
    probs = [0.86, 0.10, 0.01, 0.02, 0.01]
    if candidate_labels and len(candidate_labels) != len(probs):
        probs = [1 / len(candidate_labels)] * len(candidate_labels)

    if not candidate_labels:
        st.warning("Enter at least one candidate label to see predictions.")
    else:
        data = pd.DataFrame(
            {"labels": candidate_labels, "probability": probs}
        ).sort_values(
            by="probability",
            ascending=False,
            # Stable sort keeps the user's label order among tied scores.
            kind="mergesort",
        )
        chart = px.bar(
            data,
            x="probability",
            y="labels",
            color="labels",
            orientation="h",
            height=290,
            width=500,
        ).update_layout(
            {
                "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
                "yaxis": {"title": None, "visible": True, "showticklabels": True},
                "margin": dict(
                    l=10,  # left
                    r=10,  # right
                    t=50,  # top
                    b=10,  # bottom
                ),
                "showlegend": False,
            }
        )
        st.plotly_chart(chart)