import os
import gradio as gr
import numpy as np
import cv2
from PIL import ImageOps
from inference import SegmentPredictor, DepthPredictor
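
# Two SAM predictors: one on the default device for encoding and automatic
# segmentation, plus a CPU copy that serves interactive point prompts; the
# depth model drives the 3-D reconstruction.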
sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()
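# Marker colours (RGB) for point prompts: red = background (label 0),
# blue = foreground (label 1)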
red = (255, 0, 0)
blue = (0, 0, 255)
block = gr.Blocks()
with block:
# States
def point_coords_empty():
return []
def point_labels_empty():
return []
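
    # Per-session state: accumulated point prompts, predicted and selected
    # masks, and the cached SAM image embedding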
image_edit_trigger = gr.State(True)
point_coords = gr.State(point_coords_empty)
point_labels = gr.State(point_labels_empty)
masks = gr.State([])
cutout_idx = gr.State(set())
pred_masks = gr.State([])
prompt_masks = gr.State([])
embedding = gr.State()
# UI
with gr.Column():
gr.Markdown(
"""# Segment Anything Model (SAM)
            ## A new AI model from Meta AI that can "cut out" any object, in any image, with a single click 🚀
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
)
with gr.Row():
with gr.Column():
with gr.Tab("Upload Image"):
upload_image = gr.Image(label="Input", type="pil", tool=None)
with gr.Tab("Webcam"):
input_image = gr.Image(
label="Input", type="pil", tool=None, source="webcam"
)
with gr.Row():
sam_encode_btn = gr.Button("Encode", variant="primary")
sam_sgmt_everything_btn = gr.Button(
"Segment Everything!", variant="primary"
)
with gr.Row():
prompt_image = gr.Image(label="Segments")
lbl_image = gr.AnnotatedImage(label="Everything")
with gr.Row():
point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
text = gr.Textbox(label="Mask Name")
reset_btn = gr.Button("New Mask")
selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
with gr.Row():
with gr.Column():
pcl_figure = gr.Model3D(
label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
)
with gr.Row():
max_depth = gr.Slider(
minimum=0, maximum=10, value=3, step=0.01, label="Max Depth"
)
min_depth = gr.Slider(
minimum=0, maximum=10, step=0.01, value=1, label="Min Depth"
)
n_samples = gr.Slider(
minimum=1e3,
maximum=1e6,
step=1e3,
value=1e5,
label="Number of Samples",
)
                    cube_size = gr.Slider(
                        minimum=0.00001,
                        maximum=0.001,
                        step=0.000001,
                        value=0.00001,
                        label="Cube size",
                    )
depth_reconstruction_btn = gr.Button(
"3D Reconstruction", variant="primary"
)
depth_reconstruction_mask_btn = gr.Button(
"Mask Reconstruction", variant="primary"
)
sam_decode_btn = gr.Button("Predict using points!", variant="primary")
    # Input set handed to every event handler; each handler receives the
    # current values as a dict keyed by component
components = {
point_coords,
point_labels,
image_edit_trigger,
masks,
cutout_idx,
input_image,
embedding,
point_label_radio,
text,
reset_btn,
sam_sgmt_everything_btn,
sam_decode_btn,
depth_reconstruction_btn,
prompt_image,
lbl_image,
n_samples,
max_depth,
min_depth,
cube_size,
selected_masks_image,
}
    def on_upload_image(input_image, upload_image):
        # Mirror the upload so it matches the webcam component, which Gradio
        # mirrors by default
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]
upload_image.upload(
on_upload_image, [input_image, upload_image], [input_image, upload_image]
)
    # Event: start a new mask by clearing the accumulated point prompts
    def on_reset_btn_click(input_image):
        return input_image, point_coords_empty(), point_labels_empty()

    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [prompt_image, point_coords, point_labels],
        queue=False,
    )
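
    # Event: each click on the "Segments" image adds a labelled point and
    # re-runs conditional prediction on the CPU model with the cached embedding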
def on_prompt_image_select(
input_image,
prompt_image,
point_coords,
point_labels,
point_label_radio,
text,
pred_masks,
embedding,
evt: gr.SelectData,
):
sam_cpu.dummy_encode(input_image)
x, y = evt.index
color = red if point_label_radio == 0 else blue
if prompt_image is None:
prompt_image = np.array(input_image.copy())
cv2.circle(prompt_image, (x, y), 5, color, -1)
point_coords.append([x, y])
point_labels.append(point_label_radio)
sam_masks = sam_cpu.cond_pred(
pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding
)
return [
prompt_image,
(input_image, sam_masks),
point_coords,
point_labels,
sam_masks,
]
prompt_image.select(
on_prompt_image_select,
[
input_image,
prompt_image,
point_coords,
point_labels,
point_label_radio,
text,
pred_masks,
embedding,
],
[prompt_image, lbl_image, point_coords, point_labels, pred_masks],
queue=True,
)
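
    # Event: clicking a segment in the "Everything" view copies that mask into
    # the selected-masks list under the current mask name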
    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        i = evt.index
        mask = pred_masks[i][0]
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]
lbl_image.select(
on_everything_image_select,
[input_image, pred_masks, masks, text],
[masks, selected_masks_image],
queue=False,
)
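
    # Event: clicking a mask in the selected-masks view removes it from the list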
def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
i = evt.index
del masks[i]
anno = (input_image, masks)
return [masks, anno]
selected_masks_image.select(
on_selected_masks_image_select,
[input_image, masks],
[masks, selected_masks_image],
queue=False,
)
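    # Event: encode the input image with SAM and cache the embedding for
    # later point-prompt prediction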
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # Encode on the default device, then move the embedding to CPU so the
        # CPU predictor can reuse it
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]
sam_encode_btn.click(
on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
)
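
    # Event: predict a mask from the accumulated point prompts and store it
    # under the current mask name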
    def on_click_sam_decode_btn(inputs):
        print("running inference")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(
        on_click_sam_decode_btn,
        components,
        [prompt_image, masks, cutout_idx],
        queue=True,
    )
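
    # Event: build a 3-D reconstruction from predicted depth over the full image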
def on_depth_reconstruction_btn_click(inputs):
print("depth reconstruction")
path = dpt.generate_obj_rgb(
image=inputs[input_image],
cube_size=inputs[cube_size],
n_samples=inputs[n_samples],
min_depth=inputs[min_depth],
max_depth=inputs[max_depth],
)
return {pcl_figure: path}
depth_reconstruction_btn.click(
on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
)
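
    # Event: build the 3-D reconstruction restricted to the selected masks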
    def on_depth_reconstruction_mask_btn_click(inputs):
        print("mask reconstruction")
path = dpt.generate_obj_masks2(
image=inputs[input_image],
cube_size=inputs[cube_size],
n_samples=inputs[n_samples],
masks=inputs[masks],
min_depth=inputs[min_depth],
max_depth=inputs[max_depth],
)
return {pcl_figure: path}
depth_reconstruction_mask_btn.click(
on_depth_reconstruction_mask_btn_click, components, [pcl_figure], queue=False
)
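
    # Event: run SAM's automatic mask generation over the whole image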
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        return [(image, sam_masks), sam_masks]
sam_sgmt_everything_btn.click(
on_sam_sgmt_everything_btn_click,
components,
[lbl_image, pred_masks],
queue=True,
)
if __name__ == "__main__":
    block.queue()
    # Read credentials from the environment where available (the variable
    # names are illustrative); the original hard-coded values remain as
    # fallbacks so behaviour is unchanged when the variables are unset.
    block.launch(
        auth=(
            os.environ.get("APP_USERNAME", "novouser"),
            os.environ.get("APP_PASSWORD", "bstad2023"),
        )
    )