"""Gradio demo for the tasksource/ModernBERT-base-nli model.

The app offers two modes, both backed by the same NLI checkpoint:

* Zero-shot classification: scores free text against user-supplied,
  comma-separated candidate categories.
* Natural language inference: scores a premise/hypothesis pair.
"""

import gradio as gr
from transformers import pipeline

MODEL_ID = "tasksource/ModernBERT-base-nli"

# Values of the mode radio button; also compared against in the callbacks.
ZERO_SHOT_MODE = "Zero-Shot Classification"
NLI_MODE = "Natural Language Inference"

# Initialize the classifiers. Both pipelines wrap the same checkpoint, so the
# NLI pipeline reuses the zero-shot pipeline's model and tokenizer instead of
# loading (and keeping in memory) a second copy of the weights.
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_ID)
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)

# Examples: each row is [input text, comma-separated categories].
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"],
]

# Examples: each row is [premise, hypothesis].
nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant's waiting area is bustling, but several tables remain vacant",
     "The establishment is at maximum capacity"],
    ["The child is methodically arranging blocks while frowning in concentration",
     "The kid is experiencing joy"],
    ["Dark clouds are gathering and the pavement shows scattered wet spots",
     "It's been raining heavily all day"],
    ["A German Shepherd is exhibiting defensive behavior towards someone approaching the property",
     "The animal making noise is feline"],
]


def process_input(text_input, labels_or_premise, mode):
    """Run the model for the selected mode and return scores for ``gr.Label``.

    Args:
        text_input: The text to classify (zero-shot mode) or the premise
            (NLI mode).
        labels_or_premise: Comma-separated candidate categories in zero-shot
            mode, or the hypothesis in NLI mode (the parameter name predates
            the UI wording).
        mode: Value of the mode radio button. Any value other than
            ``ZERO_SHOT_MODE`` is treated as NLI.

    Returns:
        A ``(scores, analysis)`` tuple. ``scores`` maps each label to its
        score. ``analysis`` is always an empty string; it feeds a hidden
        Markdown component.

    Raises:
        gr.Error: If a required input is empty, so Gradio shows a readable
            message instead of a raw pipeline exception.
    """
    if not (text_input or "").strip():
        raise gr.Error("Please enter some input text.")

    if mode == ZERO_SHOT_MODE:
        # Drop blank entries (e.g. from a trailing comma) and duplicates while
        # keeping the user's order: blank labels are meaningless, and duplicate
        # labels would split probability mass between identical hypotheses.
        labels = list(dict.fromkeys(
            label.strip()
            for label in (labels_or_premise or "").split(",")
            if label.strip()
        ))
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        results = dict(zip(prediction["labels"], prediction["scores"]))
        return results, ""

    # NLI mode: the second text box holds the hypothesis.
    hypothesis = labels_or_premise
    if not (hypothesis or "").strip():
        raise gr.Error("Please enter a hypothesis.")
    # top_k=None returns the score of every label; it is the documented
    # replacement for the deprecated return_all_scores=True. Passing a list
    # yields one result per input pair, hence the [0].
    label_scores = nli_classifier(
        [{"text": text_input, "text_pair": hypothesis}], top_k=None
    )[0]
    results = {entry["label"]: entry["score"] for entry in label_scores}
    return results, ""


def update_interface(mode):
    """Relabel and refill the two text boxes when the mode changes.

    Returns:
        ``(labels_or_premise_update, text_input_update)``, matching the order
        of the ``outputs`` list wired to ``mode.change``.
    """
    if mode == ZERO_SHOT_MODE:
        return (
            gr.update(
                label="🏷️ Categories",
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1],
            ),
            gr.update(label="✍️ Input Text", value=zero_shot_examples[0][0]),
        )
    return (
        gr.update(
            label="🔎 Hypothesis",
            placeholder="Enter a hypothesis to compare with the premise...",
            value=nli_examples[0][1],
        ),
        # In NLI mode the first box holds the premise; label it as such.
        gr.update(label="✍️ Premise", value=nli_examples[0][0]),
    )


with gr.Blocks() as demo:
    gr.Markdown("""
    # tasksource/ModernBERT-nli demonstration

    This Space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
    on tasksource classification tasks.
    This NLI model achieves high accuracy on logical reasoning and long-context NLI,
    outperforming Llama 3 8B on ConTRoL and FOLIO.
    """)

    mode = gr.Radio(
        [ZERO_SHOT_MODE, NLI_MODE],
        label="Select Mode",
        value=ZERO_SHOT_MODE,
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0],  # Initial value
        )
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1],  # Initial value
        )
        submit_btn = gr.Button("Submit")
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False),
        ]

    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
            headers=["Input Text", "Categories"],
        )

    # Hidden at first: the radio starts in zero-shot mode, and
    # update_visibility only runs after the mode changes.
    with gr.Column(variant="panel", visible=False) as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
            headers=["Premise", "Hypothesis"],
        )

    def update_visibility(mode):
        """Show only the example panel that matches the selected mode."""
        return (
            gr.update(visible=(mode == ZERO_SHOT_MODE)),
            gr.update(visible=(mode == NLI_MODE)),
        )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input],
    )
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()