Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
from typing import Dict, Union | |
from gliner import GLiNER | |
# Zero-shot NER model: GLiNER loaded from the "numind/NuNER_Zero" checkpoint.
model = GLiNER.from_pretrained("numind/NuNER_Zero")
# Transformers zero-shot classification pipeline used by the classification tab
# (checkpoint id below; presumably an NLI-style DeBERTa-v3 model — confirm on the hub).
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-base-zeroshot-v1",
)
#define a function to process your input and output | |
def zero_shot(doc, candidates): | |
given_labels = candidates.split(", ") | |
dictionary = classifier(doc, given_labels) | |
labels = dictionary['labels'] | |
scores = dictionary['scores'] | |
return dict(zip(labels, scores)) | |
# Default inputs for the NER tab, one row per example:
# [text, comma-separated labels, threshold, nested_ner].
# Index 0 seeds the text box and index 1 the labels box. Indices 2 and 3
# presumably mirror ner()'s `threshold` and `nested_ner` arguments.
examples = [
    [
        "The Moon is Earth's only natural satellite. It orbits at an average distance of 384,400 km (238,900 mi), about 30 times the diameter of Earth. Over time Earth's gravity has caused tidal locking, causing the same side of the Moon to always face Earth. Because of this, the lunar day and the lunar month are the same length, at 29.5 Earth days. The Moon's gravitational pull \u2013 and to a lesser extent, the Sun's \u2013 are the main drivers of Earth's tides.",
        "celestial body,quantity,physical concept",
        0.3,
        False,
    ],
]
def merge_entities(entities, text=None):
    """Merge consecutive entities that share a label and touch or nearly touch.

    Two neighbouring entities are merged when they have the same ``entity``
    label and the second one starts either exactly at the first one's
    ``end`` or one character after it. Only neighbours are compared, so the
    input is presumably ordered by ``start``; confirm this with the caller.

    The input list and its dicts are left unmodified; merged results are
    fresh copies.

    Args:
        entities: Sequence of dicts with at least ``entity``, ``word``,
            ``start`` and ``end`` keys. Other keys are carried over from the
            first entity of each merged run.
        text: Optional source string that ``start``/``end`` index into. When
            given, a merged ``word`` is the exact slice ``text[start:end]``,
            which keeps the real separator (e.g. a hyphen). Without it,
            abutting words are concatenated directly and words separated by
            one character are joined with a single space.

    Returns:
        A new list of entity dicts. It is empty for empty or ``None`` input.
    """
    if not entities:
        return []
    # Copy so the caller's dicts are never mutated by the in-place merge below.
    merged = [dict(entities[0])]
    for entity in entities[1:]:
        last = merged[-1]
        gap = entity["start"] - last["end"]
        if entity["entity"] == last["entity"] and gap in (0, 1):
            if text is not None:
                last["word"] = text[last["start"]:entity["end"]]
            else:
                # Abutting spans (gap 0) must not get an artificial space.
                last["word"] += ("" if gap == 0 else " ") + entity["word"]
            last["end"] = entity["end"]
        else:
            merged.append(dict(entity))
    return merged
def ner(
    text, labels: str, threshold: float, nested_ner: bool = False
) -> Dict[str, object]:
    """Run zero-shot NER and format the result for ``gr.HighlightedText``.

    Args:
        text: Input text to tag.
        labels: Comma-separated entity labels. Whitespace around each label
            is stripped and empty entries are ignored.
        threshold: Forwarded to ``model.predict_entities`` as ``threshold``.
        nested_ner: When False (the default), predictions are requested with
            ``flat_ner=True``. It defaults to False so the function also works
            when the UI supplies only ``(text, labels, threshold)``; without a
            default, that call raises ``TypeError``.

    Returns:
        ``{"text": text, "entities": [...]}``. Each entity dict has
        ``entity``, ``word``, ``start``, ``end`` and ``score`` keys, and
        adjacent same-label spans are combined by ``merge_entities``. When no
        labels are given, the entity list is empty and the model is not called.
    """
    # Strip whitespace so "a, b" yields "b" rather than " b", and drop blanks.
    label_list = [label.strip() for label in (labels or "").split(",") if label.strip()]
    if not label_list:
        return {"text": text, "entities": []}
    predictions = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )
    entities = [
        {
            "entity": entity["label"],
            "word": entity["text"],
            "start": entity["start"],
            "end": entity["end"],
            # NOTE(review): score is hard-coded to 0; any confidence the model
            # reports is not propagated.
            "score": 0,
        }
        for entity in predictions
    ]
    return {"text": text, "entities": merge_entities(entities)}
with gr.Blocks(title="Zero-Shot Demo") as demo:
    # Tab 1: zero-shot text classification, wired through a nested gr.Interface.
    with gr.Tab("Zero-Shot Text Classification"):
        input1 = gr.Textbox(label="Text")
        input2 = gr.Textbox(label="Labels")
        output = gr.Label(label="Output")
        gui = gr.Interface(
            title="Zero-Shot Text Classification",
            fn=zero_shot,
            inputs=[input1, input2],
            outputs=[output],
        )

    # Tab 2: zero-shot NER, pre-filled from the first example row.
    with gr.Tab("Zero-Shot NER"):
        gr.Markdown("# Zero-Shot Named Entity Recognition (NER)")
        input_text = gr.Textbox(
            value=examples[0][0],
            label="Text input",
            placeholder="Enter your text here",
            lines=3,
        )
        with gr.Row():
            labels = gr.Textbox(
                value=examples[0][1],
                label="Labels",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                0,
                1,
                value=examples[0][2],
                step=0.01,
                label="Threshold",
                info="Lower the threshold to increase how many entities get predicted.",
                scale=1,
            )
        # ner() takes (text, labels, threshold, nested_ner). The button used to
        # pass only three inputs, so the fourth is supplied by this checkbox.
        nested_ner = gr.Checkbox(value=examples[0][3], label="Nested NER")
        output = gr.HighlightedText(label="Predicted Entities")
        submit_btn = gr.Button("Submit")
        submit_btn.click(
            fn=ner,
            inputs=[input_text, labels, threshold, nested_ner],
            outputs=output,
        )

demo.queue()
demo.launch(debug=True)