gatesla commited on
Commit
3a9ef72
·
verified ·
1 Parent(s): 7b51903

Testing if maskformer is working

Browse files
Files changed (1) hide show
  1. app.py +77 -114
app.py CHANGED
@@ -5,115 +5,66 @@ import requests, validators
5
  import torch
6
  import pathlib
7
  from PIL import Image
8
- from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
- from ultralyticsplus import YOLO, render_result
 
10
 
11
  import os
12
 
13
- # colors for visualization
14
- COLORS = [
15
- [0.000, 0.447, 0.741],
16
- [0.850, 0.325, 0.098],
17
- [0.929, 0.694, 0.125],
18
- [0.494, 0.184, 0.556],
19
- [0.466, 0.674, 0.188],
20
- [0.301, 0.745, 0.933]
21
- ]
22
-
23
- YOLOV8_LABELS = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
24
-
25
- def make_prediction(img, feature_extractor, model):
26
- inputs = feature_extractor(img, return_tensors="pt")
27
- outputs = model(**inputs)
28
- img_size = torch.tensor([tuple(reversed(img.size))])
29
- processed_outputs = feature_extractor.post_process(outputs, img_size)
30
- return processed_outputs
31
-
32
- def fig2img(fig):
33
- buf = io.BytesIO()
34
- fig.savefig(buf, bbox_inches="tight")
35
- buf.seek(0)
36
- img = Image.open(buf)
37
- return img
38
-
39
-
40
- def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
41
- keep = output_dict["scores"] > threshold
42
- boxes = output_dict["boxes"][keep].tolist()
43
- scores = output_dict["scores"][keep].tolist()
44
- labels = output_dict["labels"][keep].tolist()
45
- if id2label is not None:
46
- labels = [id2label[x] for x in labels]
47
-
48
- # print("Labels " + str(labels))
49
-
50
- plt.figure(figsize=(16, 10))
51
- plt.imshow(pil_img)
52
- ax = plt.gca()
53
- colors = COLORS * 100
54
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
55
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
56
- ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
57
- plt.axis("off")
58
- return fig2img(plt.gcf())
59
 
60
  def detect_objects(model_name,url_input,image_input,threshold):
61
 
62
 
63
- if 'yolov8' in model_name:
64
- # Working on getting this to work, another approach
65
- # https://docs.ultralytics.com/modes/predict/#key-features-of-predict-mode
66
-
67
- model = YOLO(model_name)
68
- # set model parameters
69
- model.overrides['conf'] = 0.15 # NMS confidence threshold
70
- model.overrides['iou'] = 0.05 # NMS IoU threshold https://www.google.com/search?client=firefox-b-1-d&q=intersection+over+union+meaning
71
- model.overrides['agnostic_nms'] = False # NMS class-agnostic
72
- model.overrides['max_det'] = 1000 # maximum number of detections per image
73
-
74
- results = model.predict(image_input)
75
-
76
- render = render_result(model=model, image=image_input, result=results[0])
77
-
78
- final_str = ""
79
- final_str_abv = ""
80
- final_str_else = ""
81
-
82
- for result in results:
83
- boxes = result.boxes.cpu().numpy()
84
- for i, box in enumerate(boxes):
85
- # r = box.xyxy[0].astype(int)
86
- coordinates = box.xyxy[0].astype(int)
87
- try:
88
- label = YOLOV8_LABELS[int(box.cls)]
89
- except:
90
- label = "ERROR"
91
- try:
92
- confi = float(box.conf)
93
- except:
94
- confi = 0.0
95
- # final_str_abv += str() + "__" + str(box.cls) + "__" + str(box.conf) + "__" + str(box) + "\n"
96
- if confi >= threshold:
97
- final_str_abv += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
98
- else:
99
- final_str_else += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
100
-
101
- final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
102
-
103
- return render, final_str
104
- else:
105
-
106
- #Extract model and feature extractor
107
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
108
- if 'detr' in model_name:
109
-
110
- model = DetrForObjectDetection.from_pretrained(model_name)
111
 
112
- elif 'yolos' in model_name:
113
-
114
- model = YolosForObjectDetection.from_pretrained(model_name)
115
-
116
- tb_label = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  if validators.url(url_input):
118
  image = Image.open(requests.get(url_input, stream=True).raw)
119
  tb_label = "Confidence Values URL"
@@ -121,11 +72,28 @@ def detect_objects(model_name,url_input,image_input,threshold):
121
  elif image_input:
122
  image = image_input
123
  tb_label = "Confidence Values Upload"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- #Make prediction
126
- processed_output_list = make_prediction(image, feature_extractor, model)
127
- # print("After make_prediction" + str(processed_output_list))
128
- processed_outputs = processed_output_list[0]
129
 
130
  #Visualize prediction
131
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
@@ -146,6 +114,8 @@ def detect_objects(model_name,url_input,image_input,threshold):
146
  final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
147
 
148
  return viz_img, final_str
 
 
149
 
150
  def set_example_image(example: list) -> dict:
151
  return gr.Image.update(value=example[0])
@@ -162,16 +132,9 @@ Links to HuggingFace Models:
162
  - [facebook/detr-resnet-50-panoptic](https://huggingface.co/facebook/detr-resnet-50-panoptic)
163
  - [facebook/detr-resnet-101-panoptic](https://huggingface.co/facebook/detr-resnet-101-panoptic)
164
  - [facebook/maskformer-swin-large-coco](https://huggingface.co/facebook/maskformer-swin-large-coco)
165
- - [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
166
- - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
167
- - [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
168
- - [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
169
- - [mshamrai/yolov8x-visdrone](https://huggingface.co/mshamrai/yolov8x-visdrone)
170
-
171
  """
172
 
173
- models = ["facebook/detr-resnet-50-panoptic","facebook/detr-resnet-101-panoptic","facebook/maskformer-swin-large-coco",
174
- 'hustvl/yolos-small','hustvl/yolos-tiny','facebook/detr-resnet-101-dc5', 'hustvl/yolos-small-300', 'mshamrai/yolov8x-visdrone']
175
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
176
 
177
  # twitter_link = """
@@ -196,7 +159,7 @@ with demo:
196
  gr.Markdown(title)
197
  gr.Markdown(description)
198
  # gr.Markdown(twitter_link)
199
- options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
200
 
201
  slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.7,label='Prediction Threshold')
202
 
 
5
  import torch
6
  import pathlib
7
  from PIL import Image
8
+
9
+ from transformers import DetrFeatureExtractor, DetrForSegmentation
10
+ from transformers.models.detr.feature_extraction_detr import rgb_to_id
11
 
12
  import os
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def detect_objects(model_name,url_input,image_input,threshold):
16
 
17
 
18
+ if 'maskformer' in model_name:
19
+ if validators.url(url_input):
20
+ image = Image.open(requests.get(url_input, stream=True).raw)
21
+ tb_label = "Confidence Values URL"
22
+
23
+ elif image_input:
24
+ image = image_input
25
+ tb_label = "Confidence Values Upload"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # NOTE: Pulling from the example on https://huggingface.co/facebook/maskformer-swin-large-coco
28
+ # and https://huggingface.co/spaces/ajcdp/Image-Segmentation-Gradio/blob/main/app.py
29
+
30
+ processor = MaskFormerImageProcessor.from_pretrained(model_name)
31
+ model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)
32
+
33
+ target_size = (img.shape[0], img.shape[1])
34
+ inputs = preprocessor(images=img, return_tensors="pt")
35
+ with torch.no_grad():
36
+ outputs = model(**inputs)
37
+ outputs.class_queries_logits = outputs.class_queries_logits.cpu()
38
+ outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
39
+ results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
40
+ results = torch.argmax(results, dim=0).numpy()
41
+ results = visualize_instance_seg_mask(results)
42
+ return results, "EMPTY"
43
+
44
+ # for result in results:
45
+ # boxes = result.boxes.cpu().numpy()
46
+ # for i, box in enumerate(boxes):
47
+ # # r = box.xyxy[0].astype(int)
48
+ # coordinates = box.xyxy[0].astype(int)
49
+ # try:
50
+ # label = YOLOV8_LABELS[int(box.cls)]
51
+ # except:
52
+ # label = "ERROR"
53
+ # try:
54
+ # confi = float(box.conf)
55
+ # except:
56
+ # confi = 0.0
57
+ # # final_str_abv += str() + "__" + str(box.cls) + "__" + str(box.conf) + "__" + str(box) + "\n"
58
+ # if confi >= threshold:
59
+ # final_str_abv += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
60
+ # else:
61
+ # final_str_else += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
62
+
63
+ # final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
64
+
65
+ # return render, final_str
66
+ elif "detr" in model_name:
67
+ # NOTE: Using the example on https://huggingface.co/facebook/detr-resnet-50-panoptic
68
  if validators.url(url_input):
69
  image = Image.open(requests.get(url_input, stream=True).raw)
70
  tb_label = "Confidence Values URL"
 
72
  elif image_input:
73
  image = image_input
74
  tb_label = "Confidence Values Upload"
75
+
76
+ feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
77
+ model = DetrForSegmentation.from_pretrained(model_name)
78
+ inputs = feature_extractor(images=image, return_tensors="pt")
79
+
80
+ outputs = model(**inputs)
81
+
82
+ # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
83
+ processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
84
+ result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
85
+
86
+ # the segmentation is stored in a special-format png
87
+ panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
88
+ panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
89
+ # retrieve the ids corresponding to each mask
90
+ panoptic_seg_id = rgb_to_id(panoptic_seg)
91
+
92
 
93
+
94
+ return gr.Image.update(), "EMPTY"
95
+
96
+
97
 
98
  #Visualize prediction
99
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
 
114
  final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
115
 
116
  return viz_img, final_str
117
+ else:
118
+ raise NameError(f"Model name {model_name} not prepared")
119
 
120
  def set_example_image(example: list) -> dict:
121
  return gr.Image.update(value=example[0])
 
132
  - [facebook/detr-resnet-50-panoptic](https://huggingface.co/facebook/detr-resnet-50-panoptic)
133
  - [facebook/detr-resnet-101-panoptic](https://huggingface.co/facebook/detr-resnet-101-panoptic)
134
  - [facebook/maskformer-swin-large-coco](https://huggingface.co/facebook/maskformer-swin-large-coco)
 
 
 
 
 
 
135
  """
136
 
137
+ models = ["facebook/detr-resnet-50-panoptic","facebook/detr-resnet-101-panoptic","facebook/maskformer-swin-large-coco"]
 
138
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
139
 
140
  # twitter_link = """
 
159
  gr.Markdown(title)
160
  gr.Markdown(description)
161
  # gr.Markdown(twitter_link)
162
+ options = gr.Dropdown(choices=models,label='Select Image Segmentation Model',show_label=True)
163
 
164
  slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.7,label='Prediction Threshold')
165