jens commited on
Commit
1689431
·
1 Parent(s): 8f90f14
Files changed (3) hide show
  1. app.py +16 -6
  2. app_legacy.py +1 -1
  3. utils.py +15 -2
app.py CHANGED
@@ -5,10 +5,11 @@ import cv2
5
  from PIL import Image
6
  import torch
7
  from inference import SegmentPredictor
 
8
 
9
 
10
 
11
- sam = SegmentPredictor() #service.get_sam(configs.model_type, configs.model_ckpt_path, configs.device)
12
  red = (255,0,0)
13
  blue = (0,0,255)
14
  annos = []
@@ -32,20 +33,30 @@ with block:
32
  with gr.Row():
33
  input_image = gr.Image(label='Input', height=512, type='pil')
34
  masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
35
- cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
 
36
  with gr.Row():
37
  with gr.Column(scale=1):
38
  with gr.Row():
39
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
40
  text = gr.Textbox(label='Mask Name')
 
41
  sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant = 'primary')
42
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
43
- reset_btn = gr.Button('Reset')
44
-
45
  # components
46
  components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
47
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
48
- sam_decode_btn, masks_annotated_image}
 
 
 
 
 
 
 
 
 
49
  # event - init coords
50
  def on_reset_btn_click(raw_image):
51
  return raw_image, point_coords_empty(), point_labels_empty(), None, []
@@ -79,7 +90,6 @@ with block:
79
  inputs[masks].append((generated_mask, inputs[text]))
80
  return {masks_annotated_image: (image, inputs[masks])}
81
  sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
82
- #sam_sgmt_everything_btn.click(on_sam_sgmt_everything_click, components, [masks_annotated_image, masks, cutout_idx], queue=True)
83
 
84
 
85
  if __name__ == '__main__':
 
5
  from PIL import Image
6
  import torch
7
  from inference import SegmentPredictor
8
+ from utils import generate_PCL
9
 
10
 
11
 
12
+ sam = SegmentPredictor()
13
  red = (255,0,0)
14
  blue = (0,0,255)
15
  annos = []
 
33
  with gr.Row():
34
  input_image = gr.Image(label='Input', height=512, type='pil')
35
  masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
36
+ pcl_figure = gr.Plot(label='3D Reconstruction')
37
+ #cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
38
  with gr.Row():
39
  with gr.Column(scale=1):
40
  with gr.Row():
41
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
42
  text = gr.Textbox(label='Mask Name')
43
+ reset_btn = gr.Button('New Mask')
44
  sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant = 'primary')
45
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
46
+ depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
 
47
  # components
48
  components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
49
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
50
+ sam_decode_btn, depth_reconstruction_btn, masks_annotated_image}
51
+ def on_depth_reconstruction_btn_click(inputs):
52
+ print("depth reconstruction")
53
+ image = inputs[raw_image]
54
+ # depth reconstruction
55
+ fig = generate_PCL(image)
56
+ return {pcl_figure: fig}
57
+
58
+ depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
59
+
60
  # event - init coords
61
  def on_reset_btn_click(raw_image):
62
  return raw_image, point_coords_empty(), point_labels_empty(), None, []
 
90
  inputs[masks].append((generated_mask, inputs[text]))
91
  return {masks_annotated_image: (image, inputs[masks])}
92
  sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
 
93
 
94
 
95
  if __name__ == '__main__':
app_legacy.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
  import supervision as sv
4
  from inference import DepthPredictor, SegmentPredictor
5
- from utils import create_3d_obj, create_3d_pc, point_cloud
6
  import numpy as np
7
 
8
  def produce_depth_map(image):
 
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
  import supervision as sv
4
  from inference import DepthPredictor, SegmentPredictor
5
+ from utils import create_3d_obj, create_3d_pc, point_cloud, generate_PCL
6
  import numpy as np
7
 
8
  def produce_depth_map(image):
utils.py CHANGED
@@ -4,7 +4,7 @@ import open3d as o3d
4
  import plotly.express as px
5
  import numpy as np
6
  import pandas as pd
7
-
8
 
9
  def create_3d_obj(rgb_image, depth_image, depth=10, path='./image.gltf'):
10
  depth_o3d = o3d.geometry.Image(depth_image)
@@ -139,4 +139,17 @@ def array_PCL(rgb_image, depth_image):
139
  xx_rgb = ((rgb_image[:, 0] * FX_RGB) / rgb_image[:, 2] + CX_RGB + width / 2).astype(int).clip(0, width - 1)
140
  yy_rgb = ((rgb_image[:, 1] * FY_RGB) / rgb_image[:, 2] + CY_RGB).astype(int).clip(0, height - 1)
141
  colors = rgb_image[yy_rgb, xx_rgb]/255
142
- return pcd, colors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import plotly.express as px
5
  import numpy as np
6
  import pandas as pd
7
+ from inference import DepthPredictor
8
 
9
  def create_3d_obj(rgb_image, depth_image, depth=10, path='./image.gltf'):
10
  depth_o3d = o3d.geometry.Image(depth_image)
 
139
  xx_rgb = ((rgb_image[:, 0] * FX_RGB) / rgb_image[:, 2] + CX_RGB + width / 2).astype(int).clip(0, width - 1)
140
  yy_rgb = ((rgb_image[:, 1] * FY_RGB) / rgb_image[:, 2] + CY_RGB).astype(int).clip(0, height - 1)
141
  colors = rgb_image[yy_rgb, xx_rgb]/255
142
+ return pcd, colors
143
+
144
+ def generate_PCL(image):
145
+ depth_predictor = DepthPredictor()
146
+ depth_result = depth_predictor.predict(image)
147
+ pcd, colors = array_PCL(image, depth_result)
148
+ fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], color=colors, size_max=0.1)
149
+ return fig
150
+
151
+
152
+ def plot_PCL(rgb_image, depth_image):
153
+ pcd, colors = array_PCL(rgb_image, depth_image)
154
+ fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], color=colors, size_max=0.1)
155
+ return fig