import gradio as gr
from transformers import pipeline

# Initialize the classifiers
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
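# Rough sketch of the output shapes (illustrative values only, not real model
# outputs; the exact NLI label names come from the model's config):
#   zero_shot_classifier("I love this!", ["positive", "negative"])
#     -> {"sequence": "I love this!", "labels": ["positive", "negative"],
#         "scores": [0.98, 0.02]}   # labels sorted by descending score
#   nli_classifier({"text": "premise", "text_pair": "hypothesis"}, top_k=None)
#     -> [{"label": "entailment", "score": 0.7}, {"label": "neutral", "score": 0.2},
#         {"label": "contradiction", "score": 0.1}]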

# Alternative: serve the base model's default inference widget instead of this app
# gr.load("models/answerdotai/ModernBERT-base").launch()

# Define examples
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"]
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant's waiting area is bustling, but several tables remain vacant", "The establishment is at maximum capacity"],
    ["The child is methodically arranging blocks while frowning in concentration", "The kid is experiencing joy"],
    ["Dark clouds are gathering and the pavement shows scattered wet spots", "It's been raining heavily all day"],
    ["A German Shepherd is exhibiting defensive behavior towards someone approaching the property", "The animal making noise is feline"]
]

def process_input(text_input, labels_or_premise, mode):
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(',')]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
        return results, ''
    else:  # NLI mode
        predictions = nli_classifier(
            [{"text": text_input, "text_pair": labels_or_premise}], top_k=None
        )[0]
        results = {pred["label"]: pred["score"] for pred in predictions}

        return results, ''
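# Usage sketch for process_input (placeholder scores, not actual predictions):
#   process_input("I need to buy groceries", "shopping, urgent tasks",
#                 "Zero-Shot Classification")
#     -> ({"shopping": 0.9, "urgent tasks": 0.1}, '')
#   process_input("A man is sleeping on a couch", "The man is awake",
#                 "Natural Language Inference")
#     -> ({"contradiction": 0.95, "entailment": 0.03, "neutral": 0.02}, '')
# The second element of each tuple feeds the hidden Markdown output defined below.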

def update_interface(mode):
    if mode == "Zero-Shot Classification":
        return (
            gr.update(
                label="🏷️ Categories", 
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1]
            ),
            gr.update(value=zero_shot_examples[0][0])
        )
    else:
        return (
            gr.update(
                label="πŸ”Ž Hypothesis", 
                placeholder="Enter a hypothesis to compare with the premise...",
                value=nli_examples[0][1]
            ),
            gr.update(value=nli_examples[0][0])
        )

with gr.Blocks() as demo:
    gr.Markdown("""
    # tasksource/ModernBERT-base-nli demonstration
    
    This Space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli), 
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) 
    on tasksource classification tasks. 
    This NLI model achieves high accuracy on logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
    """)

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )
    
    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0]  # Initial value
        )
        
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1]  # Initial value
        )
        
        submit_btn = gr.Button("Submit")
        
        outputs = [
            gr.Label(label="πŸ“Š Results"),
            gr.Markdown(label="πŸ“ˆ Analysis", visible=False)
        ]

        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=zero_shot_examples,
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples",
                headers=["Input Text", "Categories"]  # Add headers
            )
    
        with gr.Column(variant="panel") as nli_examples_panel:
            gr.Examples(
                examples=nli_examples,
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples",
                headers=["Premise", "Hypothesis"]  # Add headers
            )

    def update_visibility(mode):
        return (
            gr.update(visible=(mode == "Zero-Shot Classification")),
            gr.update(visible=(mode == "Natural Language Inference"))
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input]
    )
    
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel]
    )
    
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()