polejowska commited on
Commit
edbbf31
1 Parent(s): 074be6a

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +12 -1
visualization.py CHANGED
@@ -7,8 +7,17 @@ from constants import COLORS
7
  from utils import fig2img
8
 
9
 
 
 
 
 
 
 
 
 
 
10
  def visualize_prediction(
11
- pil_img, output_dict, threshold=0.7, id2label=None
12
  ):
13
  keep = output_dict["scores"] > threshold
14
  boxes = output_dict["boxes"][keep].tolist()
@@ -19,6 +28,8 @@ def visualize_prediction(
19
 
20
  fig, ax = plt.subplots(figsize=(12, 12))
21
  ax.imshow(pil_img)
 
 
22
  colors = COLORS * 100
23
  for score, (xmin, ymin, xmax, ymax), label, color in zip(
24
  scores, boxes, labels, colors
 
7
  from utils import fig2img
8
 
9
 
10
+ def visualize_mask(mask, img, alpha=0.5):
11
+ mask = mask.cpu().squeeze().numpy()
12
+ img = img.cpu().squeeze().permute(1, 2, 0).numpy()
13
+ plt.imshow(img)
14
+ plt.imshow(mask, alpha=alpha)
15
+ plt.axis("off")
16
+ return fig2img(plt.gcf())
17
+
18
+
19
  def visualize_prediction(
20
+ pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
21
  ):
22
  keep = output_dict["scores"] > threshold
23
  boxes = output_dict["boxes"][keep].tolist()
 
28
 
29
  fig, ax = plt.subplots(figsize=(12, 12))
30
  ax.imshow(pil_img)
31
+ if display_mask:
32
+ ax.imshow(mask, alpha=0.5)
33
  colors = COLORS * 100
34
  for score, (xmin, ymin, xmax, ymax), label, color in zip(
35
  scores, boxes, labels, colors