jens commited on
Commit
01bc85d
·
1 Parent(s): 769894a
Files changed (3) hide show
  1. app.py +3 -2
  2. inference.py +29 -1
  3. utils.py +4 -1
app.py CHANGED
@@ -4,12 +4,13 @@ import numpy as np
4
  import cv2
5
  from PIL import Image
6
  import torch
7
- from inference import SegmentPredictor
8
  from utils import generate_PCL, PCL3, point_cloud
9
 
10
 
11
 
12
  sam = SegmentPredictor()
 
13
  red = (255,0,0)
14
  blue = (0,0,255)
15
  annos = []
@@ -52,7 +53,7 @@ with block:
52
  print("depth reconstruction")
53
  image = inputs[raw_image]
54
  # depth reconstruction
55
- fig = point_cloud(image)
56
  return {pcl_figure: fig}
57
 
58
  depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
 
4
  import cv2
5
  from PIL import Image
6
  import torch
7
+ from inference import SegmentPredictor, DepthPredictor
8
  from utils import generate_PCL, PCL3, point_cloud
9
 
10
 
11
 
12
  sam = SegmentPredictor()
13
+ dpt = DepthPredictor()
14
  red = (255,0,0)
15
  blue = (0,0,255)
16
  annos = []
 
53
  print("depth reconstruction")
54
  image = inputs[raw_image]
55
  # depth reconstruction
56
+ fig = dpt.generate_fig(image)
57
  return {pcl_figure: fig}
58
 
59
  depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
inference.py CHANGED
@@ -6,6 +6,9 @@ import torch
6
  import numpy as np
7
  from PIL import Image
8
  import requests
 
 
 
9
 
10
  class DepthPredictor:
11
  def __init__(self):
@@ -17,7 +20,7 @@ class DepthPredictor:
17
  def predict(self, image):
18
  # prepare image for the model
19
  encoding = self.feature_extractor(image, return_tensors="pt")
20
-
21
  # forward pass
22
  with torch.no_grad():
23
  outputs = self.model(**encoding)
@@ -36,6 +39,31 @@ class DepthPredictor:
36
  #img = Image.fromarray(formatted)
37
  return formatted
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
 
 
6
  import numpy as np
7
  from PIL import Image
8
  import requests
9
+ import open3d as o3d
10
+ import pandas as pd
11
+ import plotly.express as px
12
 
13
  class DepthPredictor:
14
  def __init__(self):
 
20
  def predict(self, image):
21
  # prepare image for the model
22
  encoding = self.feature_extractor(image, return_tensors="pt")
23
+ self.img = image
24
  # forward pass
25
  with torch.no_grad():
26
  outputs = self.model(**encoding)
 
39
  #img = Image.fromarray(formatted)
40
  return formatted
41
 
42
+ def generate_pcl(self, image):
43
+ depth = self.predict(image)
44
+ # Step 2: Create an RGBD image from the RGB and depth image
45
+ depth_o3d = o3d.geometry.Image(depth)
46
+ image_o3d = o3d.geometry.Image(image)
47
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(image_o3d, depth_o3d, convert_rgb_to_intensity=False)
48
+ # Step 3: Create a PointCloud from the RGBD image
49
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, o3d.camera.PinholeCameraIntrinsic(o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault))
50
+ # Step 4: Convert PointCloud data to a NumPy array
51
+ points = np.asarray(pcd.points)
52
+ colors = np.asarray(pcd.colors)
53
+ return points, colors
54
+
55
+ def generate_fig(self, image):
56
+ points, colors = self.generate_pcl(image)
57
+ data = {'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2],
58
+ 'red': colors[:, 0], 'green': colors[:, 1], 'blue': colors[:, 2]}
59
+ df = pd.DataFrame(data)
60
+ size = np.zeros(len(df))
61
+ size[:] = 0.01
62
+ # Step 6: Create a 3D scatter plot using Plotly Express
63
+ fig = px.scatter_3d(df, x='x', y='y', z='z', color='red', size=size)
64
+ return fig
65
+
66
+
67
 
68
 
69
 
utils.py CHANGED
@@ -93,7 +93,9 @@ def create_3d_pc(rgb_image, depth_image, depth=10):
93
  return filename # Return the file path where the PLY file is saved
94
 
95
 
96
- def point_cloud(rgb_image, depth_image):
 
 
97
  # Step 2: Create an RGBD image from the RGB and depth image
98
  depth_o3d = o3d.geometry.Image(depth_image)
99
  image_o3d = o3d.geometry.Image(rgb_image)
@@ -112,6 +114,7 @@ def point_cloud(rgb_image, depth_image):
112
  # Step 6: Create a 3D scatter plot using Plotly Express
113
  fig = px.scatter_3d(df, x='x', y='y', z='z', color='red', size=size)
114
 
 
115
  return fig
116
 
117
  def array_PCL(rgb_image, depth_image):
 
93
  return filename # Return the file path where the PLY file is saved
94
 
95
 
96
+ def point_cloud(rgb_image):
97
+ depth_predictor = DepthPredictor()
98
+ depth_result = depth_predictor.predict(rgb_image)
99
  # Step 2: Create an RGBD image from the RGB and depth image
100
  depth_o3d = o3d.geometry.Image(depth_image)
101
  image_o3d = o3d.geometry.Image(rgb_image)
 
114
  # Step 6: Create a 3D scatter plot using Plotly Express
115
  fig = px.scatter_3d(df, x='x', y='y', z='z', color='red', size=size)
116
 
117
+
118
  return fig
119
 
120
  def array_PCL(rgb_image, depth_image):