Witold Wydmański committed
Commit a06c206 · 1 Parent(s): 9de3a79
.gitattributes CHANGED
@@ -1,3 +1,5 @@
+ encoder-quant.onnx filter=lfs diff=lfs merge=lfs -text
+ decoder-quant.onnx filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
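The two added patterns mark the quantized ONNX models as Git LFS objects; entries like these are typically produced by running git lfs track "encoder-quant.onnx" (and likewise for the decoder), or added to .gitattributes by hand.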
__pycache__/transforms.cpython-38.pyc ADDED
Binary file (3.98 kB)
 
app.py ADDED
@@ -0,0 +1,121 @@
+ import gradio as gr
+ import numpy as np
+ from transforms import ResizeLongestSide
+ from torch.nn import functional as F
+ import torch
+ import onnxruntime
+
+ IMAGE_SIZE = 1024
+
+ def preprocess_image(image):
+     # Resize the longest side to IMAGE_SIZE, normalize with SAM's pixel
+     # statistics, and pad to a square IMAGE_SIZE x IMAGE_SIZE tensor.
+     transform = ResizeLongestSide(IMAGE_SIZE)
+     input_image = transform.apply_image(image)
+     input_image_torch = torch.as_tensor(input_image, device="cpu")
+     input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+     pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+     pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+     x = (input_image_torch - pixel_mean) / pixel_std
+     h, w = x.shape[-2:]
+     padh = IMAGE_SIZE - h
+     padw = IMAGE_SIZE - w
+     x = F.pad(x, (0, padw, 0, padh))
+     x = x.numpy()
+     return x
+
+ def prepare_inputs(image_embedding, input_point, image_shape):
+     # Build the decoder's input dictionary for a single positive point prompt.
+     transform = ResizeLongestSide(IMAGE_SIZE)
+
+     input_label = np.array([1])
+     onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
+     onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
+
+     onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)
+
+     onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+     onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+     decoder_inputs = {
+         "image_embeddings": image_embedding,
+         "point_coords": onnx_coord,
+         "point_labels": onnx_label,
+         "mask_input": onnx_mask_input,
+         "has_mask_input": onnx_has_mask_input,
+         "orig_im_size": np.array(image_shape, dtype=np.float32)
+     }
+     return decoder_inputs
+
+ enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
+ dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")
+
+ def predict_image(img):
+     x = preprocess_image(img)
+
+     encoder_inputs = {
+         "x": x,
+     }
+
+     output = enc_session.run(None, encoder_inputs)
+     image_embedding = output[0]
+
+     middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
+
+     decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
+     masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
+
+     # clip negative logits to zero and normalize the mask to [0, 1]
+     masks = masks[0][0]
+     masks[masks < 0] = 0
+     masks = masks / np.max(masks)
+     return masks, image_embedding, img.shape[:2]
+
+ def segment_image(image_embedding, shape, evt: gr.SelectData):
+     image_embedding = np.array(image_embedding)
+     middle_of_photo = np.array([evt.index])
+     decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, shape)
+     masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
+
+     # clip negative logits to zero and normalize the mask to [0, 1]
+     masks = masks[0][0]
+     masks[masks < 0] = 0
+     masks = masks / np.max(masks)
+     return masks
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# SAM quantized (Segment Anything Model)")
+     markdown = """
+     This is a demo of SAM (Segment Anything Model), which segments objects in an image.
+     It returns the segmentation mask of the region that overlaps with the clicked point.
+
+     The model is quantized using ONNX Runtime.
+     """
+
+     gr.Markdown(markdown)
+
+     embedding = gr.State()
+     shape = gr.State()
+     with gr.Row():
+         with gr.Column():
+             inputs = gr.Image()
+             start_segmentation = gr.Button("Segment")
+
+         with gr.Column():
+             outputs = gr.Image(label="Segmentation Mask")
+
+     start_segmentation.click(
+         predict_image,
+         inputs,
+         [outputs, embedding, shape],
+     )
+
+     outputs.select(
+         segment_image,
+         [embedding, shape],
+         outputs,
+     )
+
+ demo.launch()
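For a quick sanity check outside the browser, a minimal sketch (assuming it is pasted into app.py just before demo.launch(), so the helpers and sessions above are in scope, and that sample.jpg is a hypothetical local test image):

# Hypothetical smoke test: one encoder pass and one decoder pass, prompted at the image centre.
import numpy as np
from PIL import Image

img = np.array(Image.open("sample.jpg").convert("RGB"))    # HxWx3 uint8
mask, embedding, shape = predict_image(img)                # mask already normalized to [0, 1]
print(mask.shape, float(mask.min()), float(mask.max()))

Clicking the displayed mask in the UI then re-runs only the decoder via segment_image, which is why the image embedding and original shape are kept in gr.State.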
decoder-quant.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64dedbe577d41b18ccb8d5496d26916929f65c7ecd8f06d5b23c5197434bfcb0
+ size 8738974
encoder-quant.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f1b905f70f4a3e769473b222f277c45c8e2aa0085b522e33f8b457f2b11faa5
+ size 322569075
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==1.8.1
+ torchvision==0.9.1
+ onnxruntime==1.16.1
+ gradio==3.44.0
transforms.py ADDED
@@ -0,0 +1,97 @@
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+
+ from copy import deepcopy
+ from typing import Tuple
+
+
+ class ResizeLongestSide:
+     """
+     Resizes images to the longest side 'target_length', as well as provides
+     methods for resizing coordinates and boxes. Provides methods for
+     transforming both numpy array and batched torch tensors.
+     """
+
+     def __init__(self, target_length: int) -> None:
+         self.target_length = target_length
+
+     def apply_image(self, image: np.ndarray) -> np.ndarray:
+         """
+         Expects a numpy array with shape HxWxC in uint8 format.
+         """
+         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+         return np.array(resize(to_pil_image(image), target_size))
+
+     def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array shape Bx4. Requires the original image size
+         in (H, W) format.
+         """
+         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Expects batched images with shape BxCxHxW and float format. This
+         transformation may not exactly match apply_image. apply_image is
+         the transformation expected by the model.
+         """
+         # Expects an image in BCHW format. May not exactly match apply_image.
+         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+         return F.interpolate(
+             image, target_size, mode="bilinear", align_corners=False, antialias=True
+         )
+
+     def apply_coords_torch(
+         self, coords: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with length 2 in the last dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).to(torch.float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes_torch(
+         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with shape Bx4. Requires the original image
+         size in (H, W) format.
+         """
+         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     @staticmethod
+     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         neww = int(neww + 0.5)
+         newh = int(newh + 0.5)
+         return (newh, neww)
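For intuition, a small sketch of the coordinate mapping for a hypothetical 1500x2250 (HxW) image and the 1024-pixel target side used by app.py:

import numpy as np
from transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)
# The longest side (2250) is scaled to 1024; the shorter side scales proportionally and is rounded.
print(transform.get_preprocess_shape(1500, 2250, 1024))   # (683, 1024)
# A click at the image centre is mapped into resized-image coordinates.
point = np.array([[1125.0, 750.0]])                       # (x, y)
print(transform.apply_coords(point, (1500, 2250)))        # [[512.  341.5]]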