Witold Wydmański committed · Commit a06c206 · Parent(s): 9de3a79

init
- .gitattributes +2 -0
- __pycache__/transforms.cpython-38.pyc +0 -0
- app.py +121 -0
- decoder-quant.onnx +3 -0
- encoder-quant.onnx +3 -0
- requirements.txt +4 -0
- transforms.py +97 -0
.gitattributes
CHANGED
@@ -1,3 +1,5 @@
+encoder-quant.onnx filter=lfs diff=lfs merge=lfs -text
+decoder-quant.onnx filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
__pycache__/transforms.cpython-38.pyc
ADDED
Binary file (3.98 kB)
app.py
ADDED
@@ -0,0 +1,121 @@
+
+import gradio as gr
+import onnxruntime as rt
+import numpy as np
+from transforms import ResizeLongestSide
+from torch.nn import functional as F
+import torch
+import onnxruntime
+
+IMAGE_SIZE = 1024
+
+def preprocess_image(image):
+    transform = ResizeLongestSide(IMAGE_SIZE)
+    input_image = transform.apply_image(image)
+    input_image_torch = torch.as_tensor(input_image, device="cpu")
+    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+    x = (input_image_torch - pixel_mean) / pixel_std
+    h, w = x.shape[-2:]
+    padh = IMAGE_SIZE - h
+    padw = IMAGE_SIZE - w
+    x = F.pad(x, (0, padw, 0, padh))
+    x = x.numpy()
+    return x
+
+def prepare_inputs(image_embedding, input_point, image_shape):
+    transform = ResizeLongestSide(IMAGE_SIZE)
+
+    input_label = np.array([1])
+    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
+    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
+
+    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)
+
+    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+    onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+    decoder_inputs = {
+        "image_embeddings": image_embedding,
+        "point_coords": onnx_coord,
+        "point_labels": onnx_label,
+        "mask_input": onnx_mask_input,
+        "has_mask_input": onnx_has_mask_input,
+        "orig_im_size": np.array(image_shape, dtype=np.float32)
+    }
+    return decoder_inputs
+
+enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
+dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")
+
+def predict_image(img):
+    x = preprocess_image(img)
+
+    encoder_inputs = {
+        "x": x,
+    }
+
+    output = enc_session.run(None, encoder_inputs)
+    image_embedding = output[0]
+
+    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
+
+    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
+    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
+
+    # normalize the mask to the [0, 1] range
+    masks = masks[0][0]
+    masks[masks < 0] = 0
+    masks = masks / np.max(masks)
+    return masks, image_embedding, img.shape[:2]
+
+def segment_image(image_embedding, shape, evt: gr.SelectData):
+    image_embedding = np.array(image_embedding)
+    middle_of_photo = np.array([evt.index])
+    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, shape)
+    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
+
+    # normalize the mask to the [0, 1] range
+    masks = masks[0][0]
+    masks[masks < 0] = 0
+    masks = masks / np.max(masks)
+    return masks
+
+with gr.Blocks() as demo:
+    gr.Markdown("# SAM quantized (Segment Anything Model)")
+    markdown = """
+    This is a demo of the SAM model, which can segment anything in an image.
+    It returns the segmentation mask of the object that overlaps the clicked point.
+
+    The model is quantized using ONNX Runtime.
+    """
+
+    gr.Markdown(markdown)
+
+    embedding = gr.State()
+    shape = gr.State()
+    with gr.Row():
+        with gr.Column():
+            inputs = gr.Image()
+            start_segmentation = gr.Button("Segment")
+
+        with gr.Column():
+            outputs = gr.Image(label="Segmentation Mask")
+
+    start_segmentation.click(
+        predict_image,
+        inputs,
+        [outputs, embedding, shape],
+    )
+
+    outputs.select(
+        segment_image,
+        [embedding, shape],
+        outputs,
+    )
+
+
+
+
+demo.launch()
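
Note (not part of the commit): app.py splits inference into a one-time encoder pass and a cheap per-click decoder pass. The sketch below is a minimal, hypothetical smoke test of that two-stage ONNX pipeline; it assumes transforms.py and the two quantized .onnx files from this commit sit in the working directory, and it reuses the input/output names ("x", "image_embeddings", and so on) shown in app.py above.

# Hypothetical smoke test for the encoder/decoder split used in app.py (not part of this commit).
import numpy as np
import onnxruntime
import torch
from torch.nn import functional as F
from transforms import ResizeLongestSide

IMAGE_SIZE = 1024
transform = ResizeLongestSide(IMAGE_SIZE)

def preprocess(image):
    # Same steps as preprocess_image() in app.py: resize, normalize, pad to 1024x1024.
    x = torch.as_tensor(transform.apply_image(image)).permute(2, 0, 1)[None].float()
    mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (x - mean) / std
    h, w = x.shape[-2:]
    return F.pad(x, (0, IMAGE_SIZE - w, 0, IMAGE_SIZE - h)).numpy()

enc = onnxruntime.InferenceSession("encoder-quant.onnx")
dec = onnxruntime.InferenceSession("decoder-quant.onnx")

# A synthetic 480x640 RGB image stands in for a user upload.
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

# Stage 1: image -> embedding (expensive, run once per image).
embedding = enc.run(None, {"x": preprocess(img)})[0]

# Stage 2: embedding + click point -> mask (cheap, re-run for every click).
point = np.array([[[img.shape[1] / 2, img.shape[0] / 2], [0.0, 0.0]]])  # padding point appended
labels = np.array([[1, -1]], dtype=np.float32)                          # -1 marks the padding point
masks, scores, low_res_logits = dec.run(None, {
    "image_embeddings": embedding,
    "point_coords": transform.apply_coords(point, img.shape[:2]).astype(np.float32),
    "point_labels": labels,
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array(img.shape[:2], dtype=np.float32),
})
print("embedding:", embedding.shape, "mask:", masks.shape)
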
decoder-quant.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64dedbe577d41b18ccb8d5496d26916929f65c7ecd8f06d5b23c5197434bfcb0
+size 8738974
encoder-quant.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f1b905f70f4a3e769473b222f277c45c8e2aa0085b522e33f8b457f2b11faa5
+size 322569075
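
The two .onnx entries above are Git LFS pointer files rather than the weights themselves. As an optional, hypothetical check (not part of the commit), the sha256 of a file fetched with git lfs pull can be compared against the oid recorded in its pointer:

# Hypothetical integrity check of an LFS-managed file against its pointer oid (not part of the commit).
import hashlib

def sha256_of(path):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Should print the oid sha256 values shown in the pointer files above.
print(sha256_of("decoder-quant.onnx"))
print(sha256_of("encoder-quant.onnx"))
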
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch==1.8.1
+torchvision==0.9.1
+onnxruntime==1.16.1
+gradio==3.44.0
transforms.py
ADDED
@@ -0,0 +1,97 @@
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+    """
+    Resizes images to the longest side 'target_length', as well as provides
+    methods for resizing coordinates and boxes. Provides methods for
+    transforming both numpy array and batched torch tensors.
+    """
+
+    def __init__(self, target_length: int) -> None:
+        self.target_length = target_length
+
+    def apply_image(self, image: np.ndarray) -> np.ndarray:
+        """
+        Expects a numpy array with shape HxWxC in uint8 format.
+        """
+        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+        return np.array(resize(to_pil_image(image), target_size))
+
+    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array of length 2 in the final dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).astype(float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array shape Bx4. Requires the original image size
+        in (H, W) format.
+        """
+        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Expects batched images with shape BxCxHxW and float format. This
+        transformation may not exactly match apply_image. apply_image is
+        the transformation expected by the model.
+        """
+        # Expects an image in BCHW format. May not exactly match apply_image.
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
+    def apply_coords_torch(
+        self, coords: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with length 2 in the last dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).to(torch.float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes_torch(
+        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with shape Bx4. Requires the original image
+        size in (H, W) format.
+        """
+        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    @staticmethod
+    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+        """
+        Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
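
For reference (not part of the commit), a minimal sketch of how ResizeLongestSide is used by app.py to rescale image sizes and click coordinates, assuming a 480x640 input image:

# Hypothetical usage example for ResizeLongestSide (not part of the commit).
import numpy as np
from transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

# For a 480x640 (H, W) image the longest side (640) is scaled to 1024
# and the short side is scaled proportionally and rounded.
print(transform.get_preprocess_shape(480, 640, 1024))  # (768, 1024)

# A click at the image centre, given as (x, y), is mapped into the resized frame.
coords = np.array([[[320.0, 240.0]]])              # shape (1, 1, 2)
print(transform.apply_coords(coords, (480, 640)))  # [[[512. 384.]]]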