gatesla committed
Commit e3ab040 · verified · 1 Parent(s): 80a09bd

Reset the main functions

Files changed (1):
  1. app.py +67 -110

app.py CHANGED
@@ -12,111 +12,68 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
 
 import os
 
-
-def detect_objects(model_name,url_input,image_input,threshold):
-
-
-    if 'maskformer' in model_name:
-        if validators.url(url_input):
-            image = Image.open(requests.get(url_input, stream=True).raw)
-            tb_label = "Confidence Values URL"
-
-        elif image_input:
-            image = image_input
-            tb_label = "Confidence Values Upload"
-
-        # NOTE: Pulling from the example on https://huggingface.co/facebook/maskformer-swin-large-coco
-        # and https://huggingface.co/spaces/ajcdp/Image-Segmentation-Gradio/blob/main/app.py
-
-        processor = MaskFormerImageProcessor.from_pretrained(model_name)
-        model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)
-
-        target_size = (image.size[0], image.size[1])
-        inputs = processor(images=image, return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(**inputs)
-        outputs.class_queries_logits = outputs.class_queries_logits.cpu()
-        outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
-        results = processor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
-        results = torch.argmax(results, dim=0).numpy()
-        results = visualize_instance_seg_mask(results)
-        return results, "EMPTY"
-
-        # for result in results:
-        #     boxes = result.boxes.cpu().numpy()
-        #     for i, box in enumerate(boxes):
-        #         # r = box.xyxy[0].astype(int)
-        #         coordinates = box.xyxy[0].astype(int)
-        #         try:
-        #             label = YOLOV8_LABELS[int(box.cls)]
-        #         except:
-        #             label = "ERROR"
-        #         try:
-        #             confi = float(box.conf)
-        #         except:
-        #             confi = 0.0
-        #         # final_str_abv += str() + "__" + str(box.cls) + "__" + str(box.conf) + "__" + str(box) + "\n"
-        #         if confi >= threshold:
-        #             final_str_abv += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
-        #         else:
-        #             final_str_else += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
-
-        # final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
-
-        # return render, final_str
-    elif "detr" in model_name:
-        # NOTE: Using the example on https://huggingface.co/facebook/detr-resnet-50-panoptic
-        if validators.url(url_input):
-            image = Image.open(requests.get(url_input, stream=True).raw)
-            tb_label = "Confidence Values URL"
-
-        elif image_input:
-            image = image_input
-            tb_label = "Confidence Values Upload"
-
-        feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
-        model = DetrForSegmentation.from_pretrained(model_name)
-        inputs = feature_extractor(images=image, return_tensors="pt")
-
-        outputs = model(**inputs)
-
-        # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
-        processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
-        result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
-
-        # the segmentation is stored in a special-format png
-        panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
-        panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
-        # retrieve the ids corresponding to each mask
-        panoptic_seg_id = rgb_to_id(panoptic_seg)
-
-
-
-        return gr.Image.update(), "EMPTY"
-
-
-
-        #Visualize prediction
-        viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
-
-        # return [viz_img, processed_outputs]
-        # print(type(viz_img))
-
-        final_str_abv = ""
-        final_str_else = ""
-        for score, label, box in sorted(zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]), key = lambda x: x[0].item(), reverse=True):
-            box = [round(i, 2) for i in box.tolist()]
-            if score.item() >= threshold:
-                final_str_abv += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
-            else:
-                final_str_else += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
-
-        # https://docs.python.org/3/library/string.html#format-examples
-        final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
-
-        return viz_img, final_str
+# colors for visualization
+COLORS = [
+    [0.000, 0.447, 0.741],
+    [0.850, 0.325, 0.098],
+    [0.929, 0.694, 0.125],
+    [0.494, 0.184, 0.556],
+    [0.466, 0.674, 0.188],
+    [0.301, 0.745, 0.933]
+]
+
+YOLOV8_LABELS = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
+
+def make_prediction(img, feature_extractor, model):
+    inputs = feature_extractor(img, return_tensors="pt")
+    outputs = model(**inputs)
+    img_size = torch.tensor([tuple(reversed(img.size))])
+    processed_outputs = feature_extractor.post_process(outputs, img_size)
+    return processed_outputs
+
+def fig2img(fig):
+    buf = io.BytesIO()
+    fig.savefig(buf, bbox_inches="tight")
+    buf.seek(0)
+    img = Image.open(buf)
+    return img
+
+
+def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
+    keep = output_dict["scores"] > threshold
+    boxes = output_dict["boxes"][keep].tolist()
+    scores = output_dict["scores"][keep].tolist()
+    labels = output_dict["labels"][keep].tolist()
+    if id2label is not None:
+        labels = [id2label[x] for x in labels]
+
+    # print("Labels " + str(labels))
+
+    plt.figure(figsize=(16, 10))
+    plt.imshow(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
+        ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
+    plt.axis("off")
+    return fig2img(plt.gcf())
+
+def segment_images(model_name,url_input,image_input,threshold):
+    ####
+    # Get Image Object
+    if validators.url(url_input):
+        image = Image.open(requests.get(url_input, stream=True).raw)
+    elif image_input:
+        image = image_input
+    ####
+
+    if "detr" in model_name:
+        pass
+    elif "maskformer" in model_name.lower():
+        pass
     else:
-        raise NameError(f"Model name {model_name} not prepared")
+        raise NameError("Model is not implemented")
 
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])
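The commit introduces `make_prediction` and `visualize_prediction` but, with both branches of `segment_images` stubbed out, nothing calls them yet. A minimal usage sketch, assuming the helpers above are in scope; the checkpoint name and image URL below are illustrative assumptions, not taken from this commit:

```python
import requests
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForObjectDetection

# Assumed example checkpoint; any DETR detection checkpoint with this API works.
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image
image = Image.open(requests.get(url, stream=True).raw)

# make_prediction returns one dict per image with "scores", "labels", "boxes".
processed = make_prediction(image, feature_extractor, model)[0]

# visualize_prediction draws the boxes above the threshold and returns a PIL image.
viz = visualize_prediction(image, processed, threshold=0.7, id2label=model.config.id2label)
viz.save("detections.png")
```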
@@ -197,13 +154,13 @@ with demo:
     options.change(fn=changing, inputs=[], outputs=[img_but, url_but])
 
 
-    url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, output_text1],queue=True)
-    img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, output_text1],queue=True)
-    # url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, _],queue=True)
-    # img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, _],queue=True)
+    url_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, output_text1],queue=True)
+    img_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, output_text1],queue=True)
+    # url_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, _],queue=True)
+    # img_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, _],queue=True)
 
-    # url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
-    # img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
+    # url_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
+    # img_but.click(segment_images,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
 
 
     example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
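Both branches of the new `segment_images` are left as `pass`. A hedged sketch of how the maskformer branch could be filled back in, adapted from the code removed above: it swaps the old `post_process_segmentation` call for `post_process_semantic_segmentation`, and stands in a simple color mapping for the `visualize_instance_seg_mask` helper that is not shown in this diff (`run_maskformer` is a hypothetical name):

```python
import numpy as np
import torch
from PIL import Image
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

def run_maskformer(image, model_name):  # hypothetical helper, not in the commit
    processor = MaskFormerImageProcessor.from_pretrained(model_name)
    model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # One (H, W) tensor of class ids per image; target_sizes is (height, width).
    seg = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[(image.size[1], image.size[0])]
    )[0].cpu().numpy()

    # Color each class id so the map can be displayed in the gr.Image output,
    # standing in for the old visualize_instance_seg_mask helper.
    rng = np.random.default_rng(0)
    palette = rng.integers(0, 256, size=(int(seg.max()) + 1, 3), dtype=np.uint8)
    return Image.fromarray(palette[seg]), "EMPTY"
```

The `(image, "EMPTY")` return mirrors the tuple the removed `detect_objects` produced, matching the image-plus-text outputs the click handlers above expect.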
 