import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud

# Predictors: a default (GPU if available) instance for encoding and
# whole-image segmentation, plus a CPU instance for cheap conditional
# predictions while the user clicks prompt points.
sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []

block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
    embedding = gr.State()

    # UI
    with gr.Column():
        gr.Markdown(
            """# Segment Anything Model (SAM)
## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click 🚀
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training.
[**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Upload Image"):
                    # mirror_webcam = False
                    upload_image = gr.Image(label="Input", type="pil", tool=None)
                with gr.Tab("Webcam"):
                    # mirror_webcam = False
                    input_image = gr.Image(
                        label="Input", type="pil", tool=None, source="webcam"
                    )
                with gr.Row():
                    sam_encode_btn = gr.Button("Encode", variant="primary")
                    sam_sgmt_everything_btn = gr.Button(
                        "Segment Everything!", variant="primary"
                    )
                    # sam_encode_status = gr.Label('Not encoded yet')
                with gr.Row():
                    prompt_image = gr.Image(label="Segments")
                    # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                    lbl_image = gr.AnnotatedImage(label="Everything")
                with gr.Row():
                    point_label_radio = gr.Radio(
                        label="Point Label", choices=[1, 0], value=1
                    )
                    text = gr.Textbox(label="Mask Name")
                    reset_btn = gr.Button("New Mask")
                selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(
                    label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
                )
                with gr.Row():
                    # gr.Slider takes `value=` for its initial setting;
                    # `default=` is not a valid keyword in Gradio 3.x.
                    max_depth = gr.Slider(
                        minimum=0, maximum=10, step=0.01, value=1, label="Max Depth"
                    )
                    min_depth = gr.Slider(
                        minimum=0, maximum=10, step=0.01, value=0.1, label="Min Depth"
                    )
                    n_samples = gr.Slider(
                        minimum=1e3,
                        maximum=1e6,
                        step=1e3,
                        value=1e3,
                        label="Number of Samples",
                    )
                    cube_size = gr.Slider(
                        minimum=0.00001,
                        maximum=0.001,
                        step=0.000001,
                        value=0.00001,
                        label="Cube size",
                    )
                depth_reconstruction_btn = gr.Button(
                    "Depth Reconstruction", variant="primary"
                )
                sam_decode_btn = gr.Button("Predict using points!", variant="primary")

    # Components are passed to the click handlers as a set, so each handler
    # receives a single dict keyed by component (e.g. `inputs[input_image]`).
    components = {
        point_coords,
        point_labels,
        image_edit_trigger,
        masks,
        cutout_idx,
        input_image,
        embedding,
        point_label_radio,
        text,
        reset_btn,
        sam_sgmt_everything_btn,
        sam_decode_btn,
        depth_reconstruction_btn,
        prompt_image,
        lbl_image,
        n_samples,
        max_depth,
        min_depth,
        cube_size,
        selected_masks_image,
    }

    def on_upload_image(input_image, upload_image):
        # Mirror because the gradio.Image webcam source has mirror=True
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]

    upload_image.upload(
        on_upload_image, [input_image, upload_image], [input_image, upload_image]
    )

    # event - reset prompt points and masks
    def on_reset_btn_click(input_image):
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    # Five return values, so five output components: the prompt image is
    # cleared (None) and the predicted masks are emptied.
    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [input_image, point_coords, point_labels, prompt_image, pred_masks],
        queue=False,
    )
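    # The handlers below take an extra `evt: gr.SelectData` argument that
    # Gradio injects for `.select()` events: for an Image click, `evt.index`
    # is the (x, y) pixel position; for an AnnotatedImage, it is the index of
    # the clicked annotation.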
    def on_prompt_image_select(
        input_image,
        prompt_image,
        point_coords,
        point_labels,
        point_label_radio,
        text,
        pred_masks,
        embedding,
        evt: gr.SelectData,
    ):
        sam_cpu.dummy_encode(input_image)
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        # Draw the clicked point, record it, and rerun the conditional
        # prediction on the CPU predictor with the cached embedding.
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam_cpu.cond_pred(
            pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding
        )
        return [
            prompt_image,
            (input_image, sam_masks),
            point_coords,
            point_labels,
            sam_masks,
        ]

    prompt_image.select(
        on_prompt_image_select,
        [
            input_image,
            prompt_image,
            point_coords,
            point_labels,
            point_label_radio,
            text,
            pred_masks,
            embedding,
        ],
        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
        queue=True,
    )

    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        # Clicking a mask in the "Everything" view copies it into the
        # user's named mask collection.
        i = evt.index
        mask = pred_masks[i][0]
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(
        on_everything_image_select,
        [input_image, pred_masks, masks, text],
        [masks, selected_masks_image],
        queue=False,
    )

    def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
        # Clicking a selected mask removes it from the collection.
        i = evt.index
        del masks[i]
        anno = (input_image, masks)
        return [masks, anno]

    selected_masks_image.select(
        on_selected_masks_image_select,
        [input_image, masks],
        [masks, selected_masks_image],
        queue=False,
    )

    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)

    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # Encode the image on click; move the embedding to the CPU so it can
        # be stored in gr.State and reused by the CPU predictor.
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]

    sam_encode_btn.click(
        on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
    )

    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(
        on_click_sam_decode_btn,
        components,
        [prompt_image, masks, cutout_idx],
        queue=True,
    )

    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_rgb(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            # masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_btn.click(
        on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
    )

    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(
        on_sam_sgmt_everything_btn_click,
        components,
        [lbl_image, pred_masks],
        queue=True,
    )

if __name__ == "__main__":
    block.queue()
    block.launch()
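# ---------------------------------------------------------------------------
# The `inference` module is local to this repo and not shown here. Judging
# only from the call sites above, a rough sketch of the interface this app
# assumes (names and signatures may differ in the real module) is:
#
#     class SegmentPredictor:
#         def __init__(self, device=None): ...
#         def encode(self, image): ...        # image -> embedding (torch tensor)
#         def dummy_encode(self, image): ...  # cache image metadata, no heavy pass
#         def cond_pred(self, pts, lbls, embedding=None): ...  # point-prompted masks
#         def segment_everything(self, image): ...  # masks for gr.AnnotatedImage
#
#     class DepthPredictor:
#         def generate_obj_rgb(self, image, cube_size, n_samples,
#                              min_depth, max_depth): ...  # -> path to an .obj file
# ---------------------------------------------------------------------------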