# Author: mrdbourke
# Commit aab9a3a (verified): "Uploading Trashify V2 box detection model (with data augmentation)"
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
# Note: Can load from Hugging Face or can load from local.
# You will have to replace {mrdbourke} with your own username if the model is on your Hugging Face account.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"

# Load the model and preprocessor
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Get the id2label dictionary from the model (maps predicted class ids to class names)
id2label = model.config.id2label

# Set up a colour dictionary for plotting boxes with different colours
color_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}
# Create helper functions for seeing if items from one list are in another
def any_in_list(list_a, list_b):
    """Return True if any item from ``list_a`` is in ``list_b``, otherwise False.

    ``list_a`` may be any iterable and ``list_b`` any container supporting
    ``in``. An empty ``list_a`` gives False.
    """
    return any(item in list_b for item in list_a)
def all_in_list(list_a, list_b):
    """Return True if all items from ``list_a`` are in ``list_b``, otherwise False.

    ``list_a`` may be any iterable and ``list_b`` any container supporting
    ``in``. An empty ``list_a`` gives True (vacuous truth, as with ``all``).
    """
    return all(item in list_b for item in list_a)
def _run_detection(image, conf_threshold):
    """Run the detector on one PIL image and return its post-processed results on CPU.

    Returns the single result dict from ``post_process_object_detection``,
    holding "scores", "labels" and "boxes" tensors. Boxes are scaled back to
    the original image size.
    """
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        # Send the inputs to wherever the model's weights live (CPU or CUDA).
        outputs = model(**inputs.to(model.device))
        # PIL's .size is (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs,
            threshold=conf_threshold,
            target_sizes=target_sizes,
        )[0]
    # Every value is a tensor, so .cpu() is all that is needed (a Python scalar
    # from .item() has no .cpu(), which is why the old scalar branch always failed).
    return {key: value.cpu() for key, value in results.items()}


def _draw_detections(image, results):
    """Draw a labelled box for each detection onto ``image`` (in place).

    Returns the list of detected class names, in detection order.
    """
    draw = ImageDraw.Draw(image)
    # load_default(size=...) needs Pillow >= 10.1.
    font = ImageFont.load_default(size=20)

    class_name_text_labels = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        # Labels without a dedicated colour fall back to red instead of raising KeyError.
        targ_color = color_dict.get(label_name, "red")
        class_name_text_labels.append(label_name)

        draw.rectangle(xy=(x, y, x2, y2), outline=targ_color, width=3)
        # Label text is drawn at the box's top-left corner, in the box's colour.
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
        draw.text(xy=(x, y), text=text_string_to_show, fill=targ_color, font=font)

    del draw
    return class_name_text_labels


def _summarize_detections(class_name_text_labels, conf_threshold):
    """Build the user-facing message describing which target items were found."""
    target_items = ["trash", "bin", "hand"]

    # Nothing detected, or none of trash/bin/hand among the detections.
    if not any_in_list(list_a=target_items, list_b=class_name_text_labels):
        return f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."

    # Some, but not all, of the target items were found: report the missing ones.
    if not all_in_list(list_a=target_items, list_b=class_name_text_labels):
        missing_items = [item for item in target_items if item not in class_name_text_labels]
        return f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."

    # All 3 of trash, bin, hand occur = +1
    return f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"


def predict_on_image(image, conf_threshold):
    """Detect trash/bin/hand in ``image`` and return an annotated image plus a summary.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input; may be
            None if nothing was uploaded (NOTE(review): confirm with Gradio version).
        conf_threshold: minimum score a detection needs to be kept.

    Returns:
        A ``(image, message)`` tuple: the input image with boxes drawn on it
        (modified in place), and a text summary of what was or wasn't found.
    """
    if image is None:
        return None, "No image provided. Please upload an image."

    results = _run_detection(image, conf_threshold)
    class_name_text_labels = _draw_detections(image, results)
    return image, _summarize_detections(class_name_text_labels, conf_threshold)
# Create the interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    # U+1F6AE PUT LITTER IN ITS PLACE SYMBOL, escaped so the source stays ASCII
    # and can't be re-mangled by a wrong text encoding.
    title="\U0001F6AE Trashify Object Detection Demo V2",
    # Built from plain string pieces (not an indented triple-quoted string) so
    # Markdown doesn't render the second paragraph as a code block.
    description=(
        "Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.\n\n"
        "The [model](https://huggingface.co/mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug) in V2 has been trained with data augmentation preprocessing (color jitter, horizontal flipping) to improve robustness."
    ),
    # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
    examples=[
        ["examples/trashify_example_1.jpeg", 0.25],
        ["examples/trashify_example_2.jpeg", 0.25],
        ["examples/trashify_example_3.jpeg", 0.25],
    ],
)
# Launch the demo