import gradio as gr
from transformers import pipeline
import re

# Custom sentence tokenizer
def sent_tokenize(text):
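    """Regex-based sentence splitter (no external tokenizer dependency).

    Illustrative example (heuristic, so some abbreviations/edge cases may slip through):
        sent_tokenize("It rained. We stayed inside!")
        -> ["It rained.", "We stayed inside!"]
    """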
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)')
    sentences = sentence_endings.split(text)
    return [s.strip() for s in sentences if s.strip()]

# Initialize the classifiers
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli", device="cpu")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli", device="cpu")
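# Expected outputs (standard transformers pipeline behavior; the label names below are
# what the rest of this app assumes for this NLI model):
#   zero_shot_classifier(text, labels) -> {"sequence": ..., "labels": [...], "scores": [...]}
#   nli_classifier([{"text": premise, "text_pair": hypothesis}], return_all_scores=True)
#     -> one list of {"label": ..., "score": ...} dicts per input pair
#        (labels: entailment / neutral / contradiction)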

# Example inputs for each demo mode
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"]
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant's waiting area is bustling, but several tables remain vacant", "The establishment is at maximum capacity"],
    ["The child is methodically arranging blocks while frowning in concentration", "The kid is experiencing joy"],
    ["Dark clouds are gathering and the pavement shows scattered wet spots", "It's been raining heavily all day"],
    ["A German Shepherd is exhibiting defensive behavior towards someone approaching the property", "The animal making noise is feline"]
]

long_context_examples = [
    ["""A company's environmental policy typically has a profound impact upon its standing in the community. There are legal regulations, with stiff penalties attached, which compel managers to ensure that any waste products are disposed of without contaminating the air or water supplies. In addition, employees can be educated about the inevitable commercial and social benefits of recycling paper and other substances produced as by-products of the manufacturing process. One popular method for gaining staff co-operation is the internal incentive scheme. These often target teams rather than individuals, since the interdependence of staff organising any reprocessing, masks the importance of a given player's role.""",
     "The regard held for an organisation may be affected by its commitment to environmental issues."]
]

def get_label_color(label, confidence=1.0):
    """Return color based on NLI label with confidence-based saturation."""
    base_colors = {
        'entailment': 'rgb(144, 238, 144)',    # Light green
        'neutral': 'rgb(255, 229, 180)',       # Peach
        'contradiction': 'rgb(255, 182, 193)'   # Light pink
    }
    
    # Convert RGB color to RGBA with confidence-based alpha
    if label in base_colors:
        rgb = base_colors[label].replace('rgb(', '').replace(')', '').split(',')
        r, g, b = map(int, rgb)
        # Adjust the color based on confidence
        alpha = 0.3 + (0.7 * confidence)  # Range from 0.3 to 1.0
        return f"rgba({r},{g},{b},{alpha})"
    return '#FFFFFF'
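
# Illustrative behavior: get_label_color('entailment', 1.0) yields a nearly opaque
# light green (alpha approaches 1.0), while confidence 0.0 gives alpha 0.3;
# unrecognized labels fall back to plain white ('#FFFFFF').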

def create_analysis_html(sentence_results, global_label, global_confidence):
    """Create HTML table for sentence analysis with color coding and confidence."""
    html = """
    <style>
        .analysis-table {
            width: 100%;
            border-collapse: collapse;
            margin: 20px 0;
            font-family: Arial, sans-serif;
        }
        .analysis-table th, .analysis-table td {
            padding: 12px;
            border: 1px solid #ddd;
            text-align: left;
        }
        .analysis-table th {
            background-color: #f5f5f5;
        }
        .global-prediction {
            padding: 15px;
            margin: 20px 0;
            border-radius: 5px;
            font-weight: bold;
        }
        .confidence {
            font-size: 0.9em;
            color: #666;
        }
    </style>
    """
    
    # Add global prediction box with confidence
    html += f"""
    <div class="global-prediction" style="background-color: {get_label_color(global_label, global_confidence)}">
        Global Prediction: {global_label} 
        <span class="confidence">(Confidence: {global_confidence:.2%})</span>
    </div>
    """
    
    # Create table
    html += """
    <table class="analysis-table">
        <tr>
            <th>Sentence</th>
            <th>Prediction</th>
            <th>Confidence</th>
        </tr>
    """
    
    # Add rows for each sentence
    for result in sentence_results:
        html += f"""
        <tr style="background-color: {get_label_color(result['prediction'], result['confidence'])}">
            <td>{result['sentence']}</td>
            <td>{result['prediction']}</td>
            <td>{result['confidence']:.2%}</td>
        </tr>
        """
    
    html += "</table>"
    return html


def process_input(text_input, labels_or_premise, mode):
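    """Run the selected mode and return (label_scores, analysis_html).

    label_scores feeds the gr.Label output; analysis_html feeds the gr.HTML output
    and is empty except in Long Context NLI mode.
    """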
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(',')]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
        return results, ''
    elif mode == "Natural Language Inference":
        pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
        results = {pred['label']: pred['score'] for pred in pred}
        return results, ''
    else:  # Long Context NLI
        # Global prediction
        global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
        global_results = {p['label']: p['score'] for p in global_pred}
        global_label = max(global_results.items(), key=lambda x: x[1])[0]
        global_confidence = max(global_results.values())
        
        # Sentence-level analysis
        sentences = sent_tokenize(text_input)
        sentence_results = []
        
        for sentence in sentences:
            sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
            # Get the prediction and confidence for the sentence
            pred_scores = [(p['label'], p['score']) for p in sent_pred]
            max_pred = max(pred_scores, key=lambda x: x[1])
            max_label, confidence = max_pred
            
            sentence_results.append({
                'sentence': sentence,
                'prediction': max_label,
                'confidence': confidence
            })
        
        analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
        return global_results, analysis_html

def update_interface(mode):
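    """Relabel the second textbox and reset both textboxes to the selected mode's first example."""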
    if mode == "Zero-Shot Classification":
        return (
            gr.update(
                label="🏷️ Categories", 
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1]
            ),
            gr.update(value=zero_shot_examples[0][0])
        )
    elif mode == "Natural Language Inference":
        return (
            gr.update(
                label="πŸ”Ž Hypothesis", 
                placeholder="Enter a hypothesis to compare with the premise...",
                value=nli_examples[0][1]
            ),
            gr.update(value=nli_examples[0][0])
        )
    else:  # Long Context NLI
        return (
            gr.update(
                label="πŸ”Ž Hypothesis",
                placeholder="Enter a hypothesis to test against the full context...",
                value=long_context_examples[0][1]
            ),
            gr.update(value=long_context_examples[0][0])
        )

def update_visibility(mode):
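    """Show only the example panel that matches the selected mode."""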
    return (
        gr.update(visible=(mode == "Zero-Shot Classification")),
        gr.update(visible=(mode == "Natural Language Inference")),
        gr.update(visible=(mode == "Long Context NLI"))
    )

# Build the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # tasksource/ModernBERT-nli demonstration
    
    This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli), 
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) 
    on tasksource classification tasks. 
    This NLI model achieves high accuracy on categorization, logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL (long-context NLI) and FOLIO (logical reasoning).
    """)

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )
    
    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0]
        )
        
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1]
        )
        
        submit_btn = gr.Button("Submit")
        
        outputs = [
            gr.Label(label="πŸ“Š Results"),
            gr.HTML(label="πŸ“ˆ Sentence Analysis")
        ]

        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=zero_shot_examples,
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples",
            )
    
        with gr.Column(variant="panel") as nli_examples_panel:
            gr.Examples(
                examples=nli_examples,
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples",
            )
            
        with gr.Column(variant="panel") as long_context_examples_panel:
            gr.Examples(
                examples=long_context_examples,
                inputs=[text_input, labels_or_premise],
                label="Long Context NLI Examples",
            )
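
    # Event wiring: mode changes refresh the inputs and example-panel visibility;
    # the Submit button routes the current inputs through process_input.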

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input]
    )
    
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel]
    )
    
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()