# Hugging Face Spaces page header captured with this file — status: Build error
import functools
import io
import itertools

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, YolosForObjectDetection
# Box colours as matplotlib RGB triples (floats in [0, 1]); plot_results
# reuses them in order, one colour per detected box.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
@functools.lru_cache(maxsize=2)
def _load_yolos(model_name):
    """Load (and memoize) the feature extractor and model for ``hustvl/<model_name>``.

    Deserializing the weights costs far more than a forward pass, so recently
    used checkpoints are kept in memory instead of being reloaded on every
    request. ``maxsize`` bounds memory when users switch between models.
    """
    checkpoint = f"hustvl/{model_name}"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    return feature_extractor, model


def infer(img, model_name, threshold=0.9):
    """Run YOLOS object detection on ``img`` and return an annotated image.

    Args:
        img: image array from the Gradio ``Image`` input; it is converted with
            ``PIL.Image.fromarray`` (presumably an RGB uint8 array — confirm
            against the component configuration).
        model_name: checkpoint name under the ``hustvl`` Hub namespace,
            e.g. ``"yolos-small"``.
        threshold: a detection is drawn only when its highest class
            probability is strictly greater than this value. Defaults to 0.9,
            the value that was previously hard-coded.

    Returns:
        The image returned by ``plot_results`` (boxes and labels drawn over
        the input).

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("Please provide an image.")
    feature_extractor, model = _load_yolos(model_name)
    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
    with torch.no_grad():
        # Attention maps are never used below, so don't request them.
        outputs = model(pixel_values)
    # Softmax over the class axis, then drop the last column
    # (presumably the DETR-style "no object" class — confirm with the model card).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    # PIL's size is (width, height); reversing it gives (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    # NOTE(review): newer transformers releases deprecate post_process in favour
    # of post_process_object_detection — revisit if the pinned version changes.
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]
    return plot_results(img, probas[keep], bboxes_scaled[keep], model)
def plot_results(pil_img, prob, boxes, model):
    """Draw detection boxes and class labels over ``pil_img``.

    Args:
        pil_img: the image to draw on.
        prob: class scores, one row per detection; each row's argmax is taken
            as the predicted class and its value is shown as the score.
        boxes: one ``(xmin, ymin, xmax, ymax)`` row per detection
            (presumably in image pixel coordinates — confirm with the caller).
        model: the detection model; only ``model.config.id2label`` is used, to
            map class indices to names.

    Returns:
        The rendered figure converted by ``fig2img``.
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        # Use explicit figure/axes handles instead of plt.gca()/plt.gcf(), which
        # resolve pyplot's global "current figure".
        ax = fig.add_subplot(111)
        ax.imshow(pil_img)
        # cycle() reuses the palette for any number of boxes.
        for p, (xmin, ymin, xmax, ymax), color in zip(
            prob, boxes.tolist(), itertools.cycle(COLORS)
        ):
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=3,
                )
            )
            cl = p.argmax()
            text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis('off')
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # request would leak a figure for the lifetime of the server process.
        plt.close(fig)
def fig2img(fig):
    """Render a Matplotlib figure into a PIL image.

    The figure is written to an in-memory PNG and reopened with PIL. The format
    is passed explicitly so the result does not depend on the
    ``savefig.format`` rcParam; a vector format there would make
    ``Image.open`` fail.

    Args:
        fig: the Matplotlib figure to render.

    Returns:
        A ``PIL.Image.Image`` opened (lazily) from the in-memory buffer.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # rewind so PIL reads from the start of the PNG data
    return Image.open(buf)
description = """Object Detection with YOLOS. Choose your model and you're good to go."""

# Checkpoint names under the "hustvl" Hub namespace (infer loads f"hustvl/{name}").
# Hub ids use hyphens: the base model is "yolos-base", not "yolos_base".
MODEL_CHOICES = ["yolos-tiny", "yolos-small", "yolos-base", "yolos-small-300", "yolos-small-dwr"]
DEFAULT_MODEL = "yolos-small"

image_in = gr.components.Image()
image_out = gr.components.Image()
model_choice = gr.components.Dropdown(MODEL_CHOICES, value=DEFAULT_MODEL)

EXAMPLE_IMAGES = [
    "examples/10_People_Marching_People_Marching_2_120.jpg",
    "examples/12_Group_Group_12_Group_Group_12_26.jpg",
    "examples/43_Row_Boat_Canoe_43_247.jpg",
]

# Keep a handle to the Interface itself (previously this name was bound to
# launch()'s return value). Each example row supplies one value per input
# component (image, model), so clicking an example also sets the model.
Iface = gr.Interface(
    fn=infer,
    inputs=[image_in, model_choice],
    outputs=image_out,
    examples=[[path, DEFAULT_MODEL] for path in EXAMPLE_IMAGES],
    title="Object Detection with YOLOS",
    description=description,
)
Iface.launch()