Update app.py
app.py CHANGED
@@ -56,7 +56,9 @@ def predict(input, topk):
     t_image = img_resize.apply_image_torch(image)
     t_orig_size = t_image.shape[-2:]
     # pad to 1024x1024
+    pixel_mask = torch.ones(1, t_orig_size[0], t_orig_size[1], device=device)
     t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
+    pixel_mask = torch.nn.functional.pad(pixel_mask, (0, 1024 - t_orig_size[1], 0, 1024 - t_orig_size[0]))

     # get box prompt
     valid_boxes = []
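Note on the new pixel_mask: torch.nn.functional.pad uses constant-0 padding by default, so the mask stays 1 over the original t_orig_size region and 0 over everything added to reach 1024x1024, which lets downstream code tell real pixels from padding. A minimal standalone sketch of that invariant (the 768x1024 size is made up for illustration):

    import torch
    import torch.nn.functional as F

    h, w = 768, 1024  # hypothetical post-resize size standing in for t_orig_size
    pixel_mask = torch.ones(1, h, w)
    # F.pad's tuple is (left, right, top, bottom) over the last two dims, filled with 0
    pixel_mask = F.pad(pixel_mask, (0, 1024 - w, 0, 1024 - h))
    assert pixel_mask.shape == (1, 1024, 1024)
    assert pixel_mask.sum().item() == h * w  # ones only where real image content is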
@@ -69,7 +71,7 @@ def predict(input, topk):
     t_boxes = np.array(valid_boxes)
     t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
     box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
-    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
+    batched_inputs = [{"image": t_image[0], "boxes": box_torch, "pixel_mask": pixel_mask}]
     with torch.no_grad():
         outputs = sam.infer(batched_inputs, multimask_output=False)
         # visualize and post on tensorboard
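For context, apply_boxes here is assumed to be SAM's ResizeLongestSide transform, which scales XYXY boxes by the same factor used to bring the image's longest side to 1024. A rough sketch of that assumed behavior (not the repo's actual implementation):

    import numpy as np

    def apply_boxes_sketch(boxes, orig_size, long_side=1024):
        # scale XYXY boxes by long_side / longest original side
        oh, ow = orig_size
        scale = long_side / max(oh, ow)
        return boxes * scale

    boxes = np.array([[10.0, 20.0, 110.0, 220.0]])
    print(apply_boxes_sketch(boxes, (480, 640)))  # [[ 16.  32. 176. 352.]]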
@@ -87,7 +89,7 @@ def predict(input, topk):
         pred_logits = outputs.logits[i].detach().cpu().numpy()
         top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
         pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
-        coded_grasp = GraspCoder(
+        coded_grasp = GraspCoder(t_orig_size[0], t_orig_size[1], None, grasp_annos_reformat=pred_grasp)
         _ = coded_grasp.decode()
         decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
 
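The new GraspCoder call passes the resized image height and width (t_orig_size) plus the raw predictions via grasp_annos_reformat; decode() presumably converts them back into grasp rectangles in that frame. The top-k selection above it is a standard NumPy idiom: ascending argsort, slice the last k indices, reverse to descending. A tiny standalone check:

    import numpy as np

    pred_logits = np.array([[0.1], [0.9], [0.4], [0.7]])  # column 0 = score (assumed)
    topk = 2
    top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
    print(top_ind)  # [1 3] -- indices of the two highest scores, best first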
@@ -125,7 +127,4 @@ if __name__ == "__main__":
     btn.click(predict,
               inputs=[prompter, top_k],
               outputs=[image_output])
-    app.launch()
-
-
-
+    app.launch()
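The launch wiring follows the usual Gradio Blocks pattern. A self-contained skeleton of the same shape, with guessed component types (the Space's prompter is likely a box-drawing component rather than a plain Image, and its predict() does the real inference):

    import gradio as gr

    def predict(image, topk):
        # placeholder standing in for the Space's real predict(); echoes the image
        return image

    with gr.Blocks() as app:
        prompter = gr.Image(label="input")
        top_k = gr.Slider(1, 10, value=3, step=1, label="top k")
        image_output = gr.Image(label="output")
        btn = gr.Button("Predict")
        btn.click(predict, inputs=[prompter, top_k], outputs=[image_output])

    app.launch()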