# zero-shot-demo / app.py
# Hugging Face Space file (4.31 kB).
# Commit 186b145 (verified): "Create app.py" by Saripudin.
import gradio as gr
from transformers import pipeline
from typing import Dict, Union
from gliner import GLiNER
# Checkpoints used by the two tabs. Both models are instantiated once, at
# import time, and shared by every request handled by this module.
NER_CHECKPOINT = "numind/NuNER_Zero"
CLASSIFIER_CHECKPOINT = "MoritzLaurer/deberta-v3-base-zeroshot-v1"

# GLiNER model used by ner() for zero-shot entity extraction.
model = GLiNER.from_pretrained(NER_CHECKPOINT)

# transformers pipeline used by zero_shot() for zero-shot classification.
classifier = pipeline(
    "zero-shot-classification",
    model=CLASSIFIER_CHECKPOINT,
)
# Handler for the "Zero-Shot Text Classification" tab.
def zero_shot(doc, candidates):
    """Score ``doc`` against a comma-separated list of candidate labels.

    Args:
        doc: Text to classify.
        candidates: Candidate labels separated by commas. Whitespace around
            each label is ignored and empty entries are dropped, so
            ``"a, b"``, ``"a,b"`` and ``"a , b,"`` all yield ``["a", "b"]``.

    Returns:
        A dict mapping each label to the score reported by the module-level
        ``classifier`` pipeline.

    Raises:
        ValueError: If ``candidates`` contains no non-blank label.
    """
    # Split on the bare comma and strip each piece; splitting on ", " would
    # treat "a,b" as a single label and keep stray whitespace in "a , b".
    given_labels = [
        label.strip() for label in candidates.split(",") if label.strip()
    ]
    if not given_labels:
        # Fail early with a clear message instead of classifying against an
        # empty label.
        raise ValueError("Provide at least one candidate label.")

    result = classifier(doc, given_labels)
    return dict(zip(result["labels"], result["scores"]))
# Sample rows for the NER tab. Each row is [text, labels, threshold, flag];
# the UI below reads only the first two fields of the first row, as the
# default text and default labels.
# NOTE(review): the last two fields look like defaults for ner()'s
# `threshold` and `nested_ner` parameters -- confirm before relying on them.
examples = [
    [
        (
            "The Moon is Earth's only natural satellite. "
            "It orbits at an average distance of 384,400 km (238,900 mi), "
            "about 30 times the diameter of Earth. "
            "Over time Earth's gravity has caused tidal locking, "
            "causing the same side of the Moon to always face Earth. "
            "Because of this, the lunar day and the lunar month are the same length, "
            "at 29.5 Earth days. "
            "The Moon's gravitational pull – and to a lesser extent, the Sun's – "
            "are the main drivers of Earth's tides."
        ),
        "celestial body,quantity,physical concept",
        0.3,
        False,
    ],
]
def merge_entities(entities):
    """Merge runs of consecutive same-label entities into single spans.

    Two neighbouring entities are merged when they carry the same ``entity``
    label and the second one starts either exactly where the first ends or
    one character after it.

    Args:
        entities: List of dicts with at least the keys ``entity``, ``word``,
            ``start`` and ``end``. Only list order is used to decide which
            entities are neighbours.
            NOTE(review): presumably ordered by ``start`` -- confirm against
            the producer.

    Returns:
        A new list of new dicts. The input list and its dicts are left
        untouched. A merged entity keeps every other key (e.g. ``score``)
        from the first entity of its run.
    """
    if not entities:
        return []

    merged = []
    # Copy before editing so the caller's dicts are never modified in place.
    current = dict(entities[0])
    for candidate in entities[1:]:
        gap = candidate["start"] - current["end"]
        if candidate["entity"] == current["entity"] and gap in (0, 1):
            # Touching spans (gap 0) have nothing between them in the text;
            # only a one-character gap gets a separating space.
            separator = " " if gap == 1 else ""
            current["word"] += separator + candidate["word"]
            current["end"] = candidate["end"]
        else:
            merged.append(current)
            current = dict(candidate)
    merged.append(current)
    return merged
def ner(
    text, labels: str, threshold: float, nested_ner: bool = False
) -> Dict[str, Union[str, list]]:
    """Run zero-shot NER on ``text`` for a comma-separated list of labels.

    Args:
        text: Input text to tag.
        labels: Entity labels separated by commas. Surrounding whitespace is
            ignored, blank entries are dropped and labels are lower-cased.
        threshold: Passed to ``model.predict_entities`` as ``threshold``.
        nested_ner: When true, ``flat_ner=False`` is passed to the model.
            Defaults to ``False`` because the UI in this file calls this
            function with three inputs only (text, labels, threshold).

    Returns:
        ``{"text": text, "entities": [...]}`` where each entity is a dict
        with the keys ``entity``, ``word``, ``start``, ``end`` and ``score``,
        after merging neighbouring same-label entities with
        ``merge_entities``. ``entities`` is empty when no usable label was
        given.
    """
    # Strip so that "a, b" does not produce the label " b"; lower-case
    # because the NuNER Zero model card asks for lower-cased labels.
    label_list = [
        label.strip().lower() for label in labels.split(",") if label.strip()
    ]
    if not label_list:
        # Nothing to look for: skip the model call entirely.
        return {"text": text, "entities": []}

    predictions = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )
    entities = [
        {
            "entity": prediction["label"],
            "word": prediction["text"],
            "start": prediction["start"],
            "end": prediction["end"],
            # Fixed placeholder; no model score is forwarded.
            "score": 0,
        }
        for prediction in predictions
    ]
    return {"text": text, "entities": merge_entities(entities)}
# Two-tab Gradio application: zero-shot text classification and zero-shot NER.
with gr.Blocks(title="Zero-Shot Demo") as demo:
    # ---- Tab 1: text classification, served by zero_shot() ----
    with gr.Tab("Zero-Shot Text Classification"):
        classification_text = gr.Textbox(label="Text")
        classification_labels = gr.Textbox(label="Labels")
        classification_output = gr.Label(label="Output")
        gui = gr.Interface(
            fn=zero_shot,
            inputs=[classification_text, classification_labels],
            outputs=[classification_output],
            title="Zero-Shot Text Classification",
        )

    # ---- Tab 2: named entity recognition, served by ner() ----
    with gr.Tab("Zero-Shot NER"):
        gr.Markdown("\n# Zero-Shot Named Entity Recognition (NER)\n")

        # Text and labels are pre-filled from the first example row.
        input_text = gr.Textbox(
            value=examples[0][0],
            label="Text input",
            placeholder="Enter your text here",
            lines=3,
        )
        with gr.Row():
            labels = gr.Textbox(
                value=examples[0][1],
                label="Labels",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.3,
                step=0.01,
                label="Threshold",
                info="Lower the threshold to increase how many entities get predicted.",
                scale=1,
            )
        ner_output = gr.HighlightedText(label="Predicted Entities")
        submit_btn = gr.Button("Submit")

        # Inference is triggered by the button only; no submit/release
        # handlers are attached to the text, label or threshold inputs.
        # NOTE(review): ner() declares four parameters (text, labels,
        # threshold, nested_ner) while three inputs are wired here.
        submit_btn.click(
            fn=ner,
            inputs=[input_text, labels, threshold],
            outputs=ner_output,
        )

# Enable request queuing, then start the server with debug output.
demo.queue()
demo.launch(debug=True)