mrdbourke commited on
Commit
dd1fd86
ยท
verified ยท
1 Parent(s): 4845fdd

Uploading Trashify V2 box detection model (with data augmentation) app.py

Browse files
Files changed (2) hide show
  1. README.md +16 -5
  2. app.py +79 -18
README.md CHANGED
@@ -1,13 +1,24 @@
1
  ---
2
- title: Trashify Demo V2
3
- emoji: ๐Ÿ‘
4
  colorFrom: purple
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.41.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Trashify Demo V2 ๐Ÿšฎ
3
+ emoji: ๐Ÿ—‘๏ธ
4
  colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
+ # ๐Ÿšฎ Trashify Object Detector Demo V2
14
+
15
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
16
+
17
+ Used as example for encouraging people to cleanup their local area.
18
+
19
+ If `trash`, `hand`, `bin` all detected = +1 point.
20
+
21
+ * V1 = model trained *without* data augmentation
22
+ * V2 = model trained *with* data augmentation
23
+
24
+ TK - finish the README.md + update with links to materials
app.py CHANGED
@@ -1,29 +1,45 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image, ImageDraw
4
 
5
  from transformers import AutoImageProcessor
6
  from transformers import AutoModelForObjectDetection
7
 
8
- from PIL import Image
9
-
10
- model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_data_only"
11
 
 
12
  image_processor = AutoImageProcessor.from_pretrained(model_save_path)
13
  model = AutoModelForObjectDetection.from_pretrained(model_save_path)
14
 
 
 
 
 
15
  id2label = model.config.id2label
16
- color_dict = {
17
- "not_trash": "red",
 
18
  "bin": "green",
19
  "trash": "blue",
20
- "hand": "purple"
 
 
 
 
21
  }
22
 
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- model = model.to(device)
 
 
 
 
 
 
25
 
26
- def predict_on_image(image, conf_threshold=0.25):
27
  with torch.no_grad():
28
  inputs = image_processor(images=[image], return_tensors="pt")
29
  outputs = model(**inputs.to(device))
@@ -43,6 +59,12 @@ def predict_on_image(image, conf_threshold=0.25):
43
  # Can return results as plotted on a PIL image (then display the image)
44
  draw = ImageDraw.Draw(image)
45
 
 
 
 
 
 
 
46
  for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
47
  # Create coordinates
48
  x, y, x2, y2 = tuple(box.tolist())
@@ -50,6 +72,7 @@ def predict_on_image(image, conf_threshold=0.25):
50
  # Get label_name
51
  label_name = id2label[label.item()]
52
  targ_color = color_dict[label_name]
 
53
 
54
  # Draw the rectangle
55
  draw.rectangle(xy=(x, y, x2, y2),
@@ -62,23 +85,61 @@ def predict_on_image(image, conf_threshold=0.25):
62
  # Draw the text on the image
63
  draw.text(xy=(x, y),
64
  text=text_string_to_show,
65
- fill="white")
 
66
 
67
  # Remove the draw each time
68
  del draw
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- return image
71
 
 
72
  demo = gr.Interface(
73
  fn=predict_on_image,
74
  inputs=[
75
- gr.Image(type="pil", label="Upload Target Image"),
76
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
77
  ],
78
- outputs=gr.Image(type="pil"),
79
- title="๐Ÿšฎ Trashify Object Detection Demo",
80
- description="Upload an image to detect whether there's a bin, a hand or trash in it. Model trained on synthetically generated images by Flux and labels creating by GroundingDINO."
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
 
83
- if __name__ == "__main__":
84
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
 
5
  from transformers import AutoImageProcessor
6
  from transformers import AutoModelForObjectDetection
7
 
8
+ # Note: Can load from Hugging Face or can load from local.
9
+ # You will have to replace {mrdbourke} for your own username if the model is on your Hugging Face account.
10
+ model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"
11
 
12
+ # Load the model and preprocessor
13
  image_processor = AutoImageProcessor.from_pretrained(model_save_path)
14
  model = AutoModelForObjectDetection.from_pretrained(model_save_path)
15
 
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = model.to(device)
18
+
19
+ # Get the id2label dictionary from the model
20
  id2label = model.config.id2label
21
+
22
+ # Set up a colour dictionary for plotting boxes with different colours
23
+ color_dict = {
24
  "bin": "green",
25
  "trash": "blue",
26
+ "hand": "purple",
27
+ "trash_arm": "yellow",
28
+ "not_trash": "red",
29
+ "not_bin": "red",
30
+ "not_hand": "red",
31
  }
32
 
33
+ # Create helper functions for seeing if items from one list are in another
34
+ def any_in_list(list_a, list_b):
35
+ "Returns True if any item from list_a is in list_b, otherwise False."
36
+ return any(item in list_b for item in list_a)
37
+
38
+ def all_in_list(list_a, list_b):
39
+ "Returns True if all items from list_a are in list_b, otherwise False."
40
+ return all(item in list_b for item in list_a)
41
 
42
+ def predict_on_image(image, conf_threshold):
43
  with torch.no_grad():
44
  inputs = image_processor(images=[image], return_tensors="pt")
45
  outputs = model(**inputs.to(device))
 
59
  # Can return results as plotted on a PIL image (then display the image)
60
  draw = ImageDraw.Draw(image)
61
 
62
+ # Get a font from ImageFont
63
+ font = ImageFont.load_default(size=20)
64
+
65
+ # Get class names as text for print out
66
+ class_name_text_labels = []
67
+
68
  for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
69
  # Create coordinates
70
  x, y, x2, y2 = tuple(box.tolist())
 
72
  # Get label_name
73
  label_name = id2label[label.item()]
74
  targ_color = color_dict[label_name]
75
+ class_name_text_labels.append(label_name)
76
 
77
  # Draw the rectangle
78
  draw.rectangle(xy=(x, y, x2, y2),
 
85
  # Draw the text on the image
86
  draw.text(xy=(x, y),
87
  text=text_string_to_show,
88
+ fill="white",
89
+ font=font)
90
 
91
  # Remove the draw each time
92
  del draw
93
+
94
+ # Setup blank string to print out
95
+ return_string = ""
96
+
97
+ # Setup list of target items to discover
98
+ target_items = ["trash", "bin", "hand"]
99
+
100
+ # If no items detected or trash, bin, hand not in list, return notification
101
+ if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
102
+ return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
103
+ return image, return_string
104
+
105
+ # If there are some missing, print the ones which are missing
106
+ elif not all_in_list(list_a=target_items, list_b=class_name_text_labels):
107
+ missing_items = []
108
+ for item in target_items:
109
+ if item not in class_name_text_labels:
110
+ missing_items.append(item)
111
+ return_string = f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
112
+
113
+ # If all 3 trash, bin, hand occur = + 1
114
+ if all_in_list(list_a=target_items, list_b=class_name_text_labels):
115
+ return_string = f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"
116
+
117
+ print(return_string)
118
 
119
+ return image, return_string
120
 
121
+ # Create the interface
122
  demo = gr.Interface(
123
  fn=predict_on_image,
124
  inputs=[
125
+ gr.Image(type="pil", label="Target Image"),
126
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
127
  ],
128
+ outputs=[
129
+ gr.Image(type="pil", label="Image Output"),
130
+ gr.Text(label="Text Output")
131
+ ],
132
+ title="๐Ÿšฎ Trashify Object Detection Demo V2",
133
+ description="""Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
134
+ Model in V2 has been trained with data augmentation (tk - add link to model).
135
+ """,
136
+ # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
137
+ examples=[
138
+ ["examples/trashify_example_1.jpeg", 0.25],
139
+ ["examples/trashify_example_2.jpeg", 0.25]
140
+ ],
141
+ cache_examples=True
142
  )
143
 
144
+ # Launch the demo
145
+ demo.launch()