Spaces:
Runtime error
Runtime error
polejowska
commited on
Commit
•
edbbf31
1
Parent(s):
074be6a
Update visualization.py
Browse files- visualization.py +12 -1
visualization.py
CHANGED
@@ -7,8 +7,17 @@ from constants import COLORS
|
|
7 |
from utils import fig2img
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def visualize_prediction(
|
11 |
-
pil_img, output_dict, threshold=0.7, id2label=None
|
12 |
):
|
13 |
keep = output_dict["scores"] > threshold
|
14 |
boxes = output_dict["boxes"][keep].tolist()
|
@@ -19,6 +28,8 @@ def visualize_prediction(
|
|
19 |
|
20 |
fig, ax = plt.subplots(figsize=(12, 12))
|
21 |
ax.imshow(pil_img)
|
|
|
|
|
22 |
colors = COLORS * 100
|
23 |
for score, (xmin, ymin, xmax, ymax), label, color in zip(
|
24 |
scores, boxes, labels, colors
|
|
|
7 |
from utils import fig2img
|
8 |
|
9 |
|
10 |
+
def visualize_mask(mask, img, alpha=0.5):
|
11 |
+
mask = mask.cpu().squeeze().numpy()
|
12 |
+
img = img.cpu().squeeze().permute(1, 2, 0).numpy()
|
13 |
+
plt.imshow(img)
|
14 |
+
plt.imshow(mask, alpha=alpha)
|
15 |
+
plt.axis("off")
|
16 |
+
return fig2img(plt.gcf())
|
17 |
+
|
18 |
+
|
19 |
def visualize_prediction(
|
20 |
+
pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
|
21 |
):
|
22 |
keep = output_dict["scores"] > threshold
|
23 |
boxes = output_dict["boxes"][keep].tolist()
|
|
|
28 |
|
29 |
fig, ax = plt.subplots(figsize=(12, 12))
|
30 |
ax.imshow(pil_img)
|
31 |
+
if display_mask:
|
32 |
+
ax.imshow(mask, alpha=0.5)
|
33 |
colors = COLORS * 100
|
34 |
for score, (xmin, ymin, xmax, ymax), label, color in zip(
|
35 |
scores, boxes, labels, colors
|