import os

import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch

from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud

sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()
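# Note: SegmentPredictor and DepthPredictor come from the local inference.py (not
# shown here). The assumption behind the two SAM instances is that the default one
# runs the heavy image encoder on the available accelerator, while the "cpu"
# instance only re-runs the lightweight prompt decoder, so the per-click callbacks
# below stay cheap.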
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []

block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
    embedding = gr.State()
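    # gr.State objects hold per-session values (clicked points, predicted masks,
    # the cached SAM image embedding) that are threaded through the event handlers
    # below; they are never rendered directly in the UI.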
    # UI
    with gr.Column():
        gr.Markdown(
            """# Segment Anything Model (SAM)
## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Upload Image"):
                    # mirror_webcam = False
                    upload_image = gr.Image(label="Input", type="pil", tool=None)
                with gr.Tab("Webcam"):
                    # mirror_webcam = False
                    input_image = gr.Image(
                        label="Input", type="pil", tool=None, source="webcam"
                    )
                with gr.Row():
                    sam_encode_btn = gr.Button("Encode", variant="primary")
                    sam_sgmt_everything_btn = gr.Button(
                        "Segment Everything!", variant="primary"
                    )
                    # sam_encode_status = gr.Label('Not encoded yet')
        with gr.Row():
            prompt_image = gr.Image(label="Segments")
            # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
            lbl_image = gr.AnnotatedImage(label="Everything")
        with gr.Row():
            point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
            text = gr.Textbox(label="Mask Name")
            reset_btn = gr.Button("New Mask")
        selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(
                    label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
                )
                with gr.Row():
                    max_depth = gr.Slider(
                        minimum=0, maximum=10, value=3, step=0.01, label="Max Depth"
                    )
                    min_depth = gr.Slider(
                        minimum=0, maximum=10, step=0.01, value=1, label="Min Depth"
                    )
                    n_samples = gr.Slider(
                        minimum=1e3,
                        maximum=1e6,
                        step=1e3,
                        value=1e5,
                        label="Number of Samples",
                    )
                    cube_size = gr.Slider(
                        minimum=0.00001,
                        maximum=0.001,
                        step=0.000001,
                        value=0.00001,
                        label="Cube size",
                    )
                depth_reconstruction_btn = gr.Button(
                    "3D Reconstruction", variant="primary"
                )
                depth_reconstruction_mask_btn = gr.Button(
                    "Mask Reconstruction", variant="primary"
                )
        sam_decode_btn = gr.Button("Predict using points!", variant="primary")
    # components
    components = {
        point_coords,
        point_labels,
        image_edit_trigger,
        masks,
        cutout_idx,
        input_image,
        embedding,
        point_label_radio,
        text,
        reset_btn,
        sam_sgmt_everything_btn,
        sam_decode_btn,
        depth_reconstruction_btn,
        prompt_image,
        lbl_image,
        n_samples,
        max_depth,
        min_depth,
        cube_size,
        selected_masks_image,
    }
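    # Passing a *set* of components as the inputs of an event means the handler
    # receives a single dict keyed by component (e.g. inputs[input_image]) instead
    # of positional arguments. All handlers wired with `components` below rely on
    # this Gradio calling convention.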
    def on_upload_image(input_image, upload_image):
        # Mirror because gradio.image webcam has mirror = True
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]

    upload_image.upload(
        on_upload_image, [input_image, upload_image], [input_image, upload_image]
    )
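    # The upload handler above mirrors the uploaded picture before copying it into
    # the webcam component, so both input paths end up with the same orientation
    # (the webcam Image component mirrors its frames by default).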
    # event - init coords
    def on_reset_btn_click(input_image):
        # Reset the prompt image and clear stored points/masks for a new mask.
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [prompt_image, point_coords, point_labels, lbl_image, pred_masks],
        queue=False,
    )
    def on_prompt_image_select(
        input_image,
        prompt_image,
        point_coords,
        point_labels,
        point_label_radio,
        text,
        pred_masks,
        embedding,
        evt: gr.SelectData,
    ):
        sam_cpu.dummy_encode(input_image)
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam_cpu.cond_pred(
            pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding
        )
        return [
            prompt_image,
            (input_image, sam_masks),
            point_coords,
            point_labels,
            sam_masks,
        ]

    prompt_image.select(
        on_prompt_image_select,
        [
            input_image,
            prompt_image,
            point_coords,
            point_labels,
            point_label_radio,
            text,
            pred_masks,
            embedding,
        ],
        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
        queue=True,
    )
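    # Point prompts are accumulated across clicks and passed to the predictor as
    # parallel arrays, roughly:
    #   point_coords -> np.array([[x1, y1], [x2, y2], ...])
    #   point_labels -> np.array([1, 0, ...])  # 1 = foreground, 0 = background
    # cond_pred() (defined in inference.py, not shown) is assumed to decode these
    # against the cached image embedding.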
    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        i = evt.index
        mask = pred_masks[i][0]
        print(mask)
        print(type(mask))
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(
        on_everything_image_select,
        [input_image, pred_masks, masks, text],
        [masks, selected_masks_image],
        queue=False,
    )
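    # For gr.AnnotatedImage, evt.index in a .select() handler is the index of the
    # clicked annotation; it is used above to pick a mask out of pred_masks and, in
    # the handler below, to delete a previously selected mask.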
    def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
        i = evt.index
        del masks[i]
        anno = (input_image, masks)
        return [masks, anno]

    selected_masks_image.select(
        on_selected_masks_image_select,
        [input_image, masks],
        [masks, selected_masks_image],
        queue=False,
    )

    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]

    sam_encode_btn.click(
        on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
    )
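    # Encode once, decode many: the expensive SAM image embedding is computed here a
    # single time, moved to CPU, and stored in the `embedding` state so that every
    # later point click only runs the cheap mask decoder.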
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(
        on_click_sam_decode_btn,
        components,
        [prompt_image, masks, cutout_idx],
        queue=True,
    )
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_rgb(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            # masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_btn.click(
        on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
    )
    def on_depth_reconstruction_mask_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_masks2(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_mask_btn.click(
        on_depth_reconstruction_mask_btn_click, components, [pcl_figure], queue=False
    )
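    # Both reconstruction handlers forward the slider values to DepthPredictor
    # (generate_obj_rgb / generate_obj_masks2 in inference.py, not shown). The
    # assumption is that min/max depth clip the predicted depth map, n_samples
    # controls how many points are sampled into the cloud, and cube_size sets the
    # size of the cube placed at each sampled point in the exported mesh.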
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        print(image)
        print(sam_masks)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(
        on_sam_sgmt_everything_btn_click,
        components,
        [lbl_image, pred_masks],
        queue=True,
    )
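# Queue the app so the long-running handlers (encoding, segment-everything, 3-D
# reconstruction) are processed sequentially; launch() is protected with HTTP basic
# auth using the hard-coded credentials below.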
if __name__ == "__main__":
    block.queue()
    block.launch(auth=("novouser", "bstad2023"))