# sileod's picture
# Update app.py
# baad6f6 verified
# raw
# history blame
# 5.23 kB
import gradio as gr
from transformers import pipeline
# Initialize the classifiers.
# Both pipelines run the same checkpoint, so its weights are loaded only once:
# the NLI (text-classification) pipeline reuses the model and tokenizer that
# the zero-shot pipeline already loaded instead of allocating a second copy.
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)

# Never executed (the condition is always False).
# NOTE(review): presumably kept so the base-model id appears in the source
# (e.g. so the Hub links this Space to answerdotai/ModernBERT-base) --
# confirm before removing.
if False:
    gr.load("models/answerdotai/ModernBERT-base").launch()
# Example rows shown in the UI and used to prefill the inputs.
# Zero-shot rows: [input text, comma-separated candidate categories].
zero_shot_examples = [
    [
        "I absolutely love this product, it's amazing!",
        "positive, negative, neutral",
    ],
    [
        "I need to buy groceries",
        "shopping, urgent tasks, leisure, philosophy",
    ],
    [
        "The sun is very bright today",
        "weather, astronomy, complaints, poetry",
    ],
    [
        "I love playing video games",
        "entertainment, sports, education, business",
    ],
    [
        "The car won't start",
        "transportation, art, cooking, literature",
    ],
]

# NLI rows: [premise, hypothesis].
nli_examples = [
    [
        "A man is sleeping on a couch",
        "The man is awake",
    ],
    [
        "The restaurant's waiting area is bustling, but several tables remain vacant",
        "The establishment is at maximum capacity",
    ],
    [
        "The child is methodically arranging blocks while frowning in concentration",
        "The kid is experiencing joy",
    ],
    [
        "Dark clouds are gathering and the pavement shows scattered wet spots",
        "It's been raining heavily all day",
    ],
    [
        "A German Shepherd is exhibiting defensive behavior towards someone approaching the property",
        "The animal making noise is feline",
    ],
]
def process_input(text_input, labels_or_premise, mode):
    """Run the model selected by ``mode`` on the user's input.

    Args:
        text_input: The text to classify (zero-shot mode) or the premise
            (NLI mode).
        labels_or_premise: In zero-shot mode, a comma-separated list of
            candidate labels. In NLI mode, the hypothesis compared against
            ``text_input``; it is passed as ``text_pair``.
        mode: ``"Zero-Shot Classification"`` selects zero-shot; any other
            value runs NLI.

    Returns:
        A ``(scores, analysis)`` tuple. ``scores`` maps each label to its
        model score, and ``analysis`` is always an empty string for the
        hidden Markdown output.

    Raises:
        gr.Error: In zero-shot mode, when no non-empty label is supplied.
    """
    if mode == "Zero-Shot Classification":
        # Drop blank entries (e.g. "a,,b", a trailing comma, or an empty box)
        # and duplicates while keeping the user's order. A duplicate would be
        # scored as a separate candidate and then collapse into a single key
        # of the result dict.
        stripped = (label.strip() for label in labels_or_premise.split(","))
        labels = list(dict.fromkeys(label for label in stripped if label))
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode: score the (premise, hypothesis) pair against every label.
    # ``top_k=None`` returns all labels and replaces the deprecated
    # ``return_all_scores=True``. The pipeline returns one list of
    # {"label", "score"} dicts per input, hence the ``[0]``.
    label_scores = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}],
        top_k=None,
    )[0]
    return {entry["label"]: entry["score"] for entry in label_scores}, ""
def update_interface(mode):
    """Relabel and prefill the input boxes when the user switches modes.

    Args:
        mode: The radio value. ``"Zero-Shot Classification"`` selects the
            zero-shot layout; any other value selects the NLI layout.

    Returns:
        A pair of ``gr.update`` objects, (labels/hypothesis box, main text
        box). Both boxes are prefilled with the first example row of the
        selected mode.
    """
    if mode == "Zero-Shot Classification":
        label = "🏷️ Categories"
        placeholder = "Enter comma-separated categories..."
        first_example = zero_shot_examples[0]
    else:
        label = "🔎 Hypothesis"
        placeholder = "Enter a hypothesis to compare with the premise..."
        first_example = nli_examples[0]

    main_text, secondary_text = first_example
    return (
        gr.update(label=label, placeholder=placeholder, value=secondary_text),
        gr.update(value=main_text),
    )
with gr.Blocks() as demo:
    # Header describing the model behind this demo.
    gr.Markdown("""
    # tasksource/ModernBERT-nli demonstration
    This Space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
    on tasksource classification tasks.
    This NLI model achieves high accuracy on logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL and FOLIO.
    """)

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification",
    )

    with gr.Column():
        # Initial values match the default (zero-shot) mode.
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],
        )
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],
        )
        submit_btn = gr.Button("Submit")
        # Order must match the (scores, analysis) tuple returned by
        # process_input. The Markdown slot is hidden and currently receives "".
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False),
        ]

    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
            headers=["Input Text", "Categories"],
        )

    # Hidden initially: the default mode is zero-shot, so only the zero-shot
    # examples should be shown until the user switches modes.
    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
            headers=["Premise", "Hypothesis"],
        )

    def update_visibility(selected_mode):
        """Show only the examples panel that matches ``selected_mode``.

        Returns updates for (zero-shot panel, NLI panel).
        """
        return (
            gr.update(visible=(selected_mode == "Zero-Shot Classification")),
            gr.update(visible=(selected_mode == "Natural Language Inference")),
        )

    # Switching modes relabels/prefills the inputs and swaps the example panels.
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()