Matthijs Hollemans
make noice
f1cff84
raw
history blame
4.26 kB
import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
# Identifier of the pretrained checkpoint; the preprocessor and the model are
# both loaded from it so that they stay in sync.
model_checkpoint = "apple/deeplabv3-mobilevit-small"

# Preprocessor that turns an input image into the tensors the model expects.
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)

# Segmentation network, switched to evaluation mode because this script only
# runs inference.
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint)
model.eval()
# RGB colour for each class index. Row i is the colour used for labels[i],
# both in the legend and in the rendered segmentation map.
palette = np.array(
    [
        [0, 0, 0],        # 0  background
        [192, 0, 0],      # 1  aeroplane
        [0, 192, 0],      # 2  bicycle
        [192, 192, 0],    # 3  bird
        [0, 0, 192],      # 4  boat
        [192, 0, 192],    # 5  bottle
        [0, 192, 192],    # 6  bus
        [192, 192, 192],  # 7  car
        [128, 0, 0],      # 8  cat
        [255, 0, 0],      # 9  chair
        [128, 192, 0],    # 10 cow
        [255, 192, 0],    # 11 diningtable
        [128, 0, 192],    # 12 dog
        [255, 0, 192],    # 13 horse
        [128, 192, 192],  # 14 motorbike
        [255, 192, 192],  # 15 person
        [0, 128, 0],      # 16 pottedplant
        [192, 128, 0],    # 17 sheep
        [0, 255, 0],      # 18 sofa
        [192, 255, 0],    # 19 train
        [0, 128, 192],    # 20 tvmonitor
    ],
    dtype=np.uint8,
)
# Names of the classes the model predicts (Pascal VOC), in class-index order.
# Position i corresponds to row i of the colour palette.
labels = [
    "background", "aeroplane", "bicycle", "bird",
    "boat", "bottle", "bus", "car",
    "cat", "chair", "cow", "diningtable",
    "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train",
    "tvmonitor",
]
# Build the HTML legend: one <span> per class, filled with that class's
# palette colour. Palette indices listed in `inverted` are the dark colours and
# get white text; every other (light) colour gets black text.
inverted = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20]

labels_colored = []
for index, name in enumerate(labels):
    red, green, blue = palette[index]
    text_color = "black" if index not in inverted else "white"
    labels_colored.append(
        "<span style='background-color: rgb(%d, %d, %d); color: %s; padding: 2px 4px;'>%s</span>"
        % (red, green, blue, text_color, name)
    )

# Comma-separated legend, ready to be embedded in the page description.
labels_text = ", ".join(labels_colored)
# Heading shown at the top of the demo page.
title = "Semantic Segmentation with MobileViT and DeepLabV3"

# Intro blurb shown under the title. It is a single HTML paragraph: the opening
# <p> here pairs with the closing </p> appended after the colour-coded class
# legend (`labels_text`).
description = """
<p>The input image is resized and center cropped to 512×512 pixels. The segmentation output is 32×32 pixels.<br>
This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
The classes are:
""" + labels_text + "</p>"
# Footer HTML shown below the demo: links to the paper, the original weights,
# and the source of the example images. Every <p> is explicitly closed.
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📜 <a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>
<p>🏋️ Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>
<p>🏙 Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a></p>
</div>
"""
# Sample images offered as ready-made inputs in the demo. Each entry is a
# one-element row because the interface takes a single input (the image).
examples = [
    [filename]
    for filename in (
        "cat-3.jpg",
        "construction-site.jpg",
        "dog-cat.jpg",
        "football-match.jpg",
    )
]
def predict(image):
    """Segment *image* and return a class map and a foreground overlay.

    The image is preprocessed by the module-level ``feature_extractor`` and
    classified per pixel by the module-level ``model``. Both results are
    rendered at the size of the preprocessed model input, not at the size of
    the original image.

    Args:
        image: The image to segment, in whatever form ``feature_extractor``
            accepts (presumably the array Gradio passes in; confirm against
            the interface's input component).

    Returns:
        A ``(colored, highlighted)`` tuple of PIL images:

        * ``colored`` paints each pixel with the ``palette`` colour of its
          predicted class.
        * ``highlighted`` is the preprocessed input blended with a mask that
          is white wherever the predicted class is not background (class 0)
          and black elsewhere; the mask contributes 40% of the blend.
    """
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Recover the preprocessed image as an (H, W, 3) uint8 array: drop the
    # batch axis, move channels last, and reverse the channel order
    # (presumably the extractor emits BGR; confirm). The values are scaled by
    # 255 without any un-normalisation, as this model does not need it.
    pixels = inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1]
    # Round instead of truncating so that floating-point error in the scaling
    # cannot push a pixel one level darker.
    resized = np.rint(pixels * 255).astype(np.uint8)

    # PIL sizes are (width, height), the reverse of the array's leading axes.
    input_size = (resized.shape[1], resized.shape[0])

    # Class prediction for each output pixel.
    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)

    # Map every class index to its RGB colour in one vectorised lookup, then
    # scale the (lower-resolution) prediction up to the input size. Nearest
    # neighbour keeps the class colours crisp instead of mixing them.
    colored = Image.fromarray(palette[classes])
    colored = colored.resize(input_size, resample=Image.Resampling.NEAREST)

    # Keep everything that is not background.
    mask = np.where(classes != 0, 255, 0).astype(np.uint8)
    mask = Image.fromarray(mask).convert("RGB")
    mask = mask.resize(input_size, resample=Image.Resampling.NEAREST)

    # Blend the mask over the input image.
    highlighted = Image.blend(Image.fromarray(resized), mask, 0.4)

    return colored, highlighted
# Assemble the web UI: one image in, two images out (the colour-coded class
# map and the foreground overlay returned by `predict`).
# NOTE(review): `gr.inputs` / `gr.outputs` look like the legacy Gradio
# component namespaces; confirm against the pinned Gradio version before
# upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Classes"),
        gr.outputs.Image(label="Overlay"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start serving the demo.
demo.launch()