# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import ast

import numpy as np
from PIL import Image
from plyfile import PlyData
import gradio as gr
import plotly.graph_objs as go

from sam_3d import SAM3DDemo


def pc_to_plot(pc):
    """Render a ScanNet point cloud (ply vertex array) as a Plotly 3D scatter figure."""
    return go.Figure(
        data=[
            go.Scatter3d(
                x=pc['x'], y=pc['y'], z=pc['z'],
                mode='markers',
                marker=dict(
                    size=2,
                    # Per-point RGB colors taken straight from the ply vertex attributes.
                    color=['rgb({},{},{})'.format(r, g, b) for r, g, b in zip(pc['red'], pc['green'], pc['blue'])],
                ),
            )
        ],
        layout=dict(
            # Hide the axes so only the colored points are visible.
            scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
        ),
    )
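
# Usage sketch (not executed by the app), assuming the scene ply has been
# downloaded to ./scannet_data/:
#     data = PlyData.read('./scannet_data/scene0000_00/scene0000_00.ply').elements[0].data
#     pc_to_plot(data).show()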


def inference(scene_name, granularity, coords, plot):
    """Select a 3D anchor point in the scene and run the SAM3D pipeline on it."""
    print(scene_name, coords)
    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    # Coordinates arrive from the textbox as a string such as "[0, -2.5, 0.7]".
    coords = ast.literal_eval(coords)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = sam_3d.run_with_coord(coords, int(granularity))
    return pc_to_plot(data_point_select), Image.fromarray(rgb_img_w_points), Image.fromarray(rgb_img_w_masks), pc_to_plot(data_final)
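
# Hedged example of calling inference() directly (assumes the SAM ViT-B
# checkpoint sam_vit_b_01ec64.pth and the scene data are available locally):
#     sel_fig, pts_img, masks_img, out_fig = inference('scene0000_00', 0, '[0, -2.5, 0.7]', None)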


# Preload the vertex data for each supported ScanNet scene; every entry is a
# numpy structured array with fields 'x', 'y', 'z', 'red', 'green', 'blue'.
plydatas = []
for scene_name in ['scene0000_00', 'scene0001_00', 'scene0002_00']:
    plydata = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply")
    data = plydata.elements[0].data
    plydatas.append(data)

examples = [['scene0000_00', 0, [0, -2.5, 0.7], pc_to_plot(plydatas[0])],
            ['scene0001_00', 0, [1.9, 1.1, 0.5], pc_to_plot(plydatas[1])],
            ['scene0002_00', 0, [0.58, 0.47, 0.25], pc_to_plot(plydatas[2])]]

title = 'Segment Anything on 3D indoor point clouds'
description = """ | |
Gradio Demo for Segmenting Anything on 3D Indoor Scenes (ScanNet supported). \n | |
The logic is straightforward: 1) Find a point in 3D. 2) Project the 3D point onto valid images. 3) Perform 2D SAM on the valid images. 4) Reproject the 2D results back to 3D. 5) Visualize the results. \n | |
Unfortunately, this demo does not support automatically generating coordinates by clicking on the point cloud. You may need to manually write down the coordinates and input them. \n | |
Play with the examples below first and try to modify the coordinates and mask granularity. \n | |
""" | |
article = """ | |
<p style='text-align: center'> | |
<a href='https://arxiv.org/abs/2210.04150' target='_blank'> | |
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP | |
</a> | |
| | |
<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p> | |
""" | |

gr.Interface(
    inference,
    inputs=[
        gr.Dropdown(choices=['scene0000_00', 'scene0001_00', 'scene0002_00'], label="ScanNet scene name (limited scenes supported)"),
        gr.Dropdown(choices=[0, 1, 2], label="Mask granularity, from 0 (coarsest) to 2 (finest)"),
        gr.Textbox(lines=1, label='Coordinates'),
        gr.Plot(label="Input point cloud (for visualization and point finding only; click response is not supported yet)"),
    ],
    outputs=[
        gr.Plot(label='Selected point(s): red points mark the 10 points closest to your input anchor point'),
        gr.Image(label='Selected image with projected points'),
        gr.Image(label='Selected image after SAM processing'),
        gr.Plot(label='Output point cloud: blue points represent the mask'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).queue().launch()