File size: 10,736 Bytes
04e7b78 9604b3c 9e973e8 a51d9c2 66a71cd a51d9c2 9604b3c 2adecad bcd7e20 9604b3c bcd7e20 6d5fe23 362e959 6d5fe23 b92107e 6d5fe23 d2e2fb3 bbcf520 d2e2fb3 b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd b0bab2b 66a71cd 4800973 2adecad d2e2fb3 2adecad d2e2fb3 4800973 d2e2fb3 4800973 d2e2fb3 4800973 d2e2fb3 4800973 66a71cd fb5842d bcd7e20 df02321 bcd7e20 2adecad 6d5fe23 362e959 d3061d0 d2e2fb3 6d5fe23 3922cca 7660f80 6d5fe23 2adecad d2e2fb3 2adecad d3061d0 2adecad 6d5fe23 d2e2fb3 2adecad fb5842d 6d5fe23 d2e2fb3 2adecad bcd7e20 2adecad 068f0da baad6f6 af7b7a1 6d5fe23 fb5842d b92a5dd fb5842d b92a5dd fb5842d af7b7a1 6d5fe23 fb5842d b92a5dd d2e2fb3 b38e092 fb5842d 6d5fe23 fb5842d b38e092 d2e2fb3 b38e092 d3061d0 2adecad d3061d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import gradio as gr
from transformers import pipeline
import re
# Custom sentence tokenizer
def sent_tokenize(text):
    """Split *text* into sentences with a lightweight regex heuristic.

    A break is made at a whitespace character (or the end of the text) that
    follows '.', '?' or '!', unless the preceding characters look like an
    abbreviation such as 'e.g.' or a capitalised two-letter form such as
    'Mr.'. Every piece is stripped and empty pieces are dropped.
    """
    boundary = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)')
    pieces = boundary.split(text)
    cleaned = []
    for piece in pieces:
        stripped = piece.strip()
        # re.split also returns the captured separators; they strip to ''.
        if stripped:
            cleaned.append(stripped)
    return cleaned
# Initialize the classifiers.
# Both pipelines use the same NLI checkpoint. It is loaded once, and the
# text-classification pipeline reuses the zero-shot pipeline's model and
# tokenizer instead of loading a second copy of the weights.
MODEL_ID = "tasksource/ModernBERT-base-nli"
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_ID, device="cpu")
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
    device="cpu",
)
# Example inputs for each mode. Every row is a two-item list:
#   zero-shot:    [input text, comma-separated candidate labels]
#   NLI:          [premise, hypothesis]
#   long context: [multi-sentence premise, hypothesis]
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"],
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    [
        "The restaurant's waiting area is bustling, but several tables remain vacant",
        "The establishment is at maximum capacity",
    ],
    [
        "The child is methodically arranging blocks while frowning in concentration",
        "The kid is experiencing joy",
    ],
    [
        "Dark clouds are gathering and the pavement shows scattered wet spots",
        "It's been raining heavily all day",
    ],
    [
        "A German Shepherd is exhibiting defensive behavior towards someone approaching the property",
        "The animal making noise is feline",
    ],
]

long_context_examples = [
    [
        """A company's environmental policy typically has a profound impact upon its standing in the community. There are legal regulations, with stiff penalties attached, which compel managers to ensure that any waste products are disposed of without contaminating the air or water supplies. In addition, employees can be educated about the inevitable commercial and social benefits of recycling paper and other substances produced as by-products of the manufacturing process. One popular method for gaining staff co-operation is the internal incentive scheme. These often target teams rather than individuals, since the interdependence of staff organising any reprocessing, masks the importance of a given player's role.""",
        "The regard held for an organisation may be affected by its commitment to environmental issues.",
    ],
]
def get_label_color(label, confidence=1.0):
    """Map an NLI label to a background colour for the analysis report.

    Known labels ('entailment', 'neutral', 'contradiction') get a fixed base
    colour with an alpha of ``0.3 + 0.7 * confidence``. That is 0.3 at
    confidence 0 and 1.0 at confidence 1, so stronger predictions look more
    saturated. Any other label falls back to ``'#FFFFFF'``.
    """
    palette = {
        'entailment': (144, 238, 144),     # light green
        'neutral': (255, 229, 180),        # peach
        'contradiction': (255, 182, 193),  # light pink
    }
    if label not in palette:
        return '#FFFFFF'
    red, green, blue = palette[label]
    alpha = 0.3 + (0.7 * confidence)
    return f"rgba({red},{green},{blue},{alpha})"
def create_analysis_html(sentence_results, global_label, global_confidence):
    """Create an HTML report for long-context NLI results.

    The report consists of:
    - a <style> block;
    - a summary box for the whole-context prediction;
    - a table with one row per sentence.

    The box and the rows are tinted by ``get_label_color`` according to
    their label and confidence.

    Args:
        sentence_results: iterable of dicts with keys ``'sentence'``,
            ``'prediction'`` and ``'confidence'``. The confidence is a
            numeric score rendered as a percentage.
        global_label: label predicted for the full context.
        global_confidence: score of ``global_label``, rendered as a percentage.

    Returns:
        The report as a single HTML string.
    """
    from html import escape  # local import keeps this block self-contained

    style = """
<style>
    .analysis-table {
        width: 100%;
        border-collapse: collapse;
        margin: 20px 0;
        font-family: Arial, sans-serif;
    }
    .analysis-table th, .analysis-table td {
        padding: 12px;
        border: 1px solid #ddd;
        text-align: left;
    }
    .analysis-table th {
        background-color: #f5f5f5;
    }
    .global-prediction {
        padding: 15px;
        margin: 20px 0;
        border-radius: 5px;
        font-weight: bold;
    }
    .confidence {
        font-size: 0.9em;
        color: #666;
    }
</style>
"""
    parts = [style]

    # Global prediction box with confidence.
    global_color = get_label_color(global_label, global_confidence)
    parts.append(
        f'<div class="global-prediction" style="background-color: {global_color}">'
        f'Global Prediction: {escape(str(global_label))} '
        f'<span class="confidence">(Confidence: {global_confidence:.2%})</span>'
        '</div>'
    )

    # Per-sentence table.
    parts.append(
        '<table class="analysis-table">'
        '<tr><th>Sentence</th><th>Prediction</th><th>Confidence</th></tr>'
    )
    for result in sentence_results:
        row_color = get_label_color(result['prediction'], result['confidence'])
        # Sentences come from user input. Escape them so markup such as
        # "<img onerror=...>" or a stray "<" is displayed as text instead of
        # being rendered or injected into the page.
        parts.append(
            f'<tr style="background-color: {row_color}">'
            f'<td>{escape(str(result["sentence"]))}</td>'
            f'<td>{escape(str(result["prediction"]))}</td>'
            f'<td>{result["confidence"]:.2%}</td>'
            '</tr>'
        )
    parts.append('</table>')
    return ''.join(parts)
def _nli_scores(pairs):
    """Score (premise, hypothesis) pairs with the NLI pipeline in a single call.

    Args:
        pairs: sequence of ``(premise, hypothesis)`` string tuples.

    Returns:
        A list with one ``{label: score}`` dict per pair, in input order.
    """
    predictions = nli_classifier(
        [{"text": premise, "text_pair": hypothesis} for premise, hypothesis in pairs],
        # top_k=None returns every label's score. It is the documented
        # replacement for the deprecated return_all_scores=True.
        top_k=None,
    )
    return [{p["label"]: p["score"] for p in prediction} for prediction in predictions]


def _long_context_nli(premise, hypothesis):
    """Classify the whole premise, then each of its sentences, against *hypothesis*.

    Returns:
        ``(global_results, analysis_html)``. ``global_results`` maps each label
        to its score for the full premise. ``analysis_html`` is the
        per-sentence report built by ``create_analysis_html``.
    """
    sentences = sent_tokenize(premise)
    # One batched pipeline call: index 0 is the full context, the rest are
    # the individual sentences in order.
    all_scores = _nli_scores([(premise, hypothesis)] + [(s, hypothesis) for s in sentences])
    global_results, sentence_scores = all_scores[0], all_scores[1:]
    global_label, global_confidence = max(global_results.items(), key=lambda item: item[1])

    sentence_results = []
    for sentence, scores in zip(sentences, sentence_scores):
        label, confidence = max(scores.items(), key=lambda item: item[1])
        sentence_results.append({
            'sentence': sentence,
            'prediction': label,
            'confidence': confidence,
        })
    analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
    return global_results, analysis_html


def process_input(text_input, labels_or_premise, mode):
    """Run the model selected by *mode* and return ``(scores, html)`` for the UI.

    Args:
        text_input: the main input text (the premise in the two NLI modes).
        labels_or_premise: comma-separated candidate labels in zero-shot mode.
            In the two NLI modes it is the hypothesis, despite the parameter
            name.
        mode: the selected radio choice. Anything other than
            "Zero-Shot Classification" or "Natural Language Inference" is
            handled as "Long Context NLI".

    Returns:
        ``(results, analysis_html)``. ``results`` maps each label to its score.
        ``analysis_html`` is the sentence-level report in long-context mode
        and ``''`` otherwise.

    Raises:
        gr.Error: when the input text is blank, no category is given in
            zero-shot mode, or the hypothesis is blank in the NLI modes.
            Gradio shows the message to the user.
    """
    if not text_input or not text_input.strip():
        raise gr.Error("Please enter some input text.")

    if mode == "Zero-Shot Classification":
        # Drop blank entries so inputs like "a, b," do not add an empty label.
        labels = [label.strip() for label in (labels_or_premise or "").split(',') if label.strip()]
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction['labels'], prediction['scores'])), ''

    if not labels_or_premise or not labels_or_premise.strip():
        raise gr.Error("Please enter a hypothesis.")

    if mode == "Natural Language Inference":
        return _nli_scores([(text_input, labels_or_premise)])[0], ''

    # Long Context NLI
    return _long_context_nli(text_input, labels_or_premise)
def update_interface(mode):
    """Reconfigure the two text boxes when the selected mode changes.

    Returns a pair of ``gr.update`` objects:
    - the first sets the second textbox's label, placeholder and value (taken
      from the first example of the mode);
    - the second sets the main input textbox to that example's text.

    Any mode other than "Zero-Shot Classification" or
    "Natural Language Inference" gets the long-context settings.
    """
    if mode == "Zero-Shot Classification":
        box_label = "π·οΈ Categories"
        box_placeholder = "Enter comma-separated categories..."
        first_example = zero_shot_examples[0]
    elif mode == "Natural Language Inference":
        box_label = "π Hypothesis"
        box_placeholder = "Enter a hypothesis to compare with the premise..."
        first_example = nli_examples[0]
    else:  # Long Context NLI
        box_label = "π Hypothesis"
        box_placeholder = "Enter a hypothesis to test against the full context..."
        first_example = long_context_examples[0]

    second_box_update = gr.update(
        label=box_label,
        placeholder=box_placeholder,
        value=first_example[1],
    )
    input_box_update = gr.update(value=first_example[0])
    return second_box_update, input_box_update
def update_visibility(mode):
    """Show only the example panel that belongs to the selected *mode*.

    Returns three ``gr.update`` objects, one per example panel, in the order
    zero-shot, NLI, long-context. Only the panel matching *mode* is visible.
    """
    panel_modes = ("Zero-Shot Classification", "Natural Language Inference", "Long Context NLI")
    return tuple(gr.update(visible=(mode == panel_mode)) for panel_mode in panel_modes)
# Now define the Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("""
# tasksource/ModernBERT-nli demonstration
This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on tasksource classification tasks.
This NLI model achieves high accuracy on categorization, logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL (long-context NLI) and FOLIO (logical reasoning).
""")

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="βοΈ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],
        )
        labels_or_premise = gr.Textbox(
            label="π·οΈ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],
        )
        submit_btn = gr.Button("Submit")
        outputs = [
            gr.Label(label="π Results"),
            gr.HTML(label="π Sentence Analysis"),
        ]

    # Only the panel for the initially selected mode (Zero-Shot) starts
    # visible. The other panels are revealed by update_visibility when the
    # mode changes. Without visible=False all three panels show on page load.
    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )
    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )
    with gr.Column(variant="panel", visible=False) as long_context_examples_panel:
        gr.Examples(
            examples=long_context_examples,
            inputs=[text_input, labels_or_premise],
            label="Long Context NLI Examples",
        )

    # Switching modes relabels the inputs and swaps the example panels.
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()