JeffLiang commited on
Commit
f9b1bcf
·
1 Parent(s): 8c62972

try to fix memory with fixed input resolution

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. open_vocab_seg/utils/predictor.py +13 -3
app.py CHANGED
@@ -55,7 +55,7 @@ def inference(class_names, proposal_gen, granularity, input_img):
55
 
56
  examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
57
  ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
58
- ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.8, './resources/demo_samples/sample_05.png'],
59
  ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
60
  ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
61
  ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
@@ -89,7 +89,7 @@ gr.Interface(
89
  gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (most coarse) to 1 (most precise)"),
90
  gr.Image(type='filepath'),
91
  ],
92
- outputs=gr.outputs.Image(label='segmentation map'),
93
  title=title,
94
  description=description,
95
  article=article,
 
55
 
56
  examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
57
  ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
58
+ ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.7, './resources/demo_samples/sample_05.png'],
59
  ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
60
  ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
61
  ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
 
89
  gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (most coarse) to 1 (most precise)"),
90
  gr.Image(type='filepath'),
91
  ],
92
+ outputs=gr.components.Image(type="pil", label='segmentation map'),
93
  title=title,
94
  description=description,
95
  article=article,
open_vocab_seg/utils/predictor.py CHANGED
@@ -153,11 +153,19 @@ class SAMVisualizationDemo(object):
153
  sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda()
154
  self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16)
155
  self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
156
- self.clip_model.cuda()
157
 
158
- def run_on_image(self, image, class_names):
 
 
 
 
 
 
 
 
159
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
160
- visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
 
161
  with torch.no_grad(), torch.cuda.amp.autocast():
162
  masks = self.predictor.generate(image)
163
  pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))]
@@ -192,6 +200,7 @@ class SAMVisualizationDemo(object):
192
  img_batches = torch.split(imgs, 32, dim=0)
193
 
194
  with torch.no_grad(), torch.cuda.amp.autocast():
 
195
  text_features = self.clip_model.encode_text(text.cuda())
196
  text_features /= text_features.norm(dim=-1, keepdim=True)
197
  image_features = []
@@ -224,6 +233,7 @@ class SAMVisualizationDemo(object):
224
  pred_mask = r.argmax(dim=0).to('cpu')
225
  pred_mask[blank_area] = 255
226
  pred_mask = np.array(pred_mask, dtype=np.int)
 
227
 
228
  vis_output = visualizer.draw_sem_seg(
229
  pred_mask
 
153
  sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda()
154
  self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16)
155
  self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
 
156
 
157
+ def run_on_image(self, ori_image, class_names):
158
+ height, width, _ = ori_image.shape
159
+ if width > height:
160
+ new_width = 1280
161
+ new_height = int((new_width / width) * height)
162
+ else:
163
+ new_height = 1280
164
+ new_width = int((new_height / height) * width)
165
+ image = cv2.resize(ori_image, (new_width, new_height))
166
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
167
+ ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
168
+ visualizer = OVSegVisualizer(ori_image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
169
  with torch.no_grad(), torch.cuda.amp.autocast():
170
  masks = self.predictor.generate(image)
171
  pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))]
 
200
  img_batches = torch.split(imgs, 32, dim=0)
201
 
202
  with torch.no_grad(), torch.cuda.amp.autocast():
203
+ self.clip_model.cuda()
204
  text_features = self.clip_model.encode_text(text.cuda())
205
  text_features /= text_features.norm(dim=-1, keepdim=True)
206
  image_features = []
 
233
  pred_mask = r.argmax(dim=0).to('cpu')
234
  pred_mask[blank_area] = 255
235
  pred_mask = np.array(pred_mask, dtype=np.int)
236
+ pred_mask = cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)
237
 
238
  vis_output = visualizer.draw_sem_seg(
239
  pred_mask