JUGGHM commited on
Commit
65edc3a
·
verified ·
1 Parent(s): e8d42f0

Update reconstruction (beta version)

Browse files
Files changed (1) hide show
  1. app.py +109 -15
app.py CHANGED
@@ -54,7 +54,10 @@ device = "cuda"
54
  model_large.to(device)
55
  model_small.to(device)
56
 
57
- def depth_normal(img, model_selection="vit-small"):
 
 
 
58
  if model_selection == "vit-small":
59
  model = model_small
60
  cfg = cfg_small
@@ -65,7 +68,10 @@ def depth_normal(img, model_selection="vit-small"):
65
  else:
66
  raise NotImplementedError
67
 
 
 
68
  cv_image = np.array(img)
 
69
  img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
70
  intrinsic = [1000.0, 1000.0, img.shape[1]/2, img.shape[0]/2]
71
  rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(img, intrinsic, cfg.data_basic)
@@ -89,37 +95,125 @@ def depth_normal(img, model_selection="vit-small"):
89
  pred_depth = pred_depth.squeeze().cpu().numpy()
90
  pred_depth[pred_depth<0] = 0
91
  pred_color = gray_to_colormap(pred_depth)
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  pred_normal = pred_normal.squeeze()
94
  if pred_normal.size(0) == 3:
95
  pred_normal = pred_normal.permute(1,2,0)
96
  pred_color_normal = vis_surface_normal(pred_normal)
 
 
 
 
 
 
97
 
98
- ##formatted = (output * 255 / np.max(output)).astype('uint8')
 
 
99
  img = Image.fromarray(pred_color)
 
 
100
  img_normal = Image.fromarray(pred_color_normal)
101
- return img, img_normal
 
 
102
 
103
- #inputs = gr.inputs.Image(type='pil', label="Original Image")
104
- #depth = gr.outputs.Image(type="pil",label="Output Depth")
105
- #normal = gr.outputs.Image(type="pil",label="Output Normal")
 
 
 
 
 
 
 
106
 
107
  title = "Metric3D"
108
- description = "Gradio demo for Metric3D v1/v2 which takes in a single image for computing metric depth and surface normal. To use it, simply upload your image, or click one of the examples to load them. Learn more from our paper linked below."
109
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2307.10984.pdf'>Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image</a> | <a href='https://github.com/YvanYin/Metric3D'>Github Repo</a></p>"
110
 
111
  examples = [
112
- #["turtle.jpg"],
113
- #["lions.jpg"]
114
- #["files/gundam.jpg"],
115
  ["files/museum.jpg"],
116
  ["files/terra.jpg"],
117
  ["files/underwater.jpg"],
118
  ["files/venue.jpg"]
119
  ]
120
 
121
- gr.Interface(
122
- depth_normal,
123
- inputs=[gr.Image(type='pil', label="Original Image"), gr.Dropdown(["vit-small", "vit-large"], label="Model", info="Select a model type", value="vit-large")],
124
- outputs=[gr.Image(type="pil",label="Output Depth"), gr.Image(type="pil",label="Output Normal")],
125
- title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  model_large.to(device)
55
  model_small.to(device)
56
 
57
+
58
+ outputs_dir = "./outs"
59
+
60
+ def depth_normal(img_path, model_selection="vit-small"):
61
  if model_selection == "vit-small":
62
  model = model_small
63
  cfg = cfg_small
 
68
  else:
69
  raise NotImplementedError
70
 
71
+ img = Image.open(img_path)
72
+
73
  cv_image = np.array(img)
74
+ img = cv_image
75
  img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
76
  intrinsic = [1000.0, 1000.0, img.shape[1]/2, img.shape[0]/2]
77
  rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(img, intrinsic, cfg.data_basic)
 
95
  pred_depth = pred_depth.squeeze().cpu().numpy()
96
  pred_depth[pred_depth<0] = 0
97
  pred_color = gray_to_colormap(pred_depth)
98
+
99
+ ##formatted = (output * 255 / np.max(output)).astype('uint8')
100
+
101
+ path_output_dir = os.path.splitext(os.path.basename(img_path))[0] + datetime.now().strftime('%Y%m%d-%H%M%S')
102
+ path_output_dir = os.path.join(path_output_dir, outputs_dir)
103
+ os.makedirs(path_output_dir, exist_ok=True)
104
+
105
+ name_base = os.path.splitext(os.path.basename(img_path))[0]
106
+
107
+ depth_np = pred_depth
108
+ normal_np = torch.nn.functional.interpolate(pred_normal, [img.shape[0], img.shape[1]], mode='bilinear').squeeze().cpu().numpy()
109
+ normal_np = normal_np.transpose(1,2,0)
110
 
111
  pred_normal = pred_normal.squeeze()
112
  if pred_normal.size(0) == 3:
113
  pred_normal = pred_normal.permute(1,2,0)
114
  pred_color_normal = vis_surface_normal(pred_normal)
115
+
116
+ depth_path = os.path.join(path_output_dir, f"{name_base}_depth.npy")
117
+ normal_path = os.path.join(path_output_dir, f"{name_base}_normal.npy")
118
+
119
+ np.save(normal_path, normal_np)
120
+ np.save(depth_path, depth_np)
121
 
122
+ ori_w = img.shape[1]
123
+ ori_h = img.shape[0]
124
+
125
  img = Image.fromarray(pred_color)
126
+ #img = img.resize((int(300 * ori_w/ ori_h), 300))
127
+
128
  img_normal = Image.fromarray(pred_color_normal)
129
+ #img_normal = img_normal.resize((int(300 * ori_w/ ori_h), 300))
130
+
131
+ return img, img_normal, [depth_path, normal_path]
132
 
133
+ def reconstruction(img_path, files, focal_length, reconstructed_file):
134
+ img = Image.open(img_path)
135
+ cv_image = np.array(img)
136
+ img = cv_image
137
+
138
+ depth_np = np.load(files[0])
139
+ pcd = reconstruct_pcd(depth_np * focal_length / 1000, focal_length, focal_length, img.shape[1]/2, img.shape[0]/2)
140
+ pcd_path = files[0].replace('_depth.npy', '.ply')
141
+ save_point_cloud(pcd.reshape((-1, 3)), img.reshape(-1, 3), pcd_path)
142
+ return [pcd_path]
143
 
144
  title = "Metric3D"
145
+ description = "Gradio demo for Metric3D which takes in a single image for computing metric depth and surface normal. To use it, simply upload your image, or click one of the examples to load them. Learn more from our paper linked below."
146
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2307.10984.pdf'>Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image</a> | <a href='https://github.com/YvanYin/Metric3D'>Github Repo</a></p>"
147
 
148
  examples = [
 
 
 
149
  ["files/museum.jpg"],
150
  ["files/terra.jpg"],
151
  ["files/underwater.jpg"],
152
  ["files/venue.jpg"]
153
  ]
154
 
155
+ def run_demo():
156
+
157
+ _TITLE = '''Metric3Dv2: A versatile monocular geometric foundation model for zero-shot metric depth and surface normal estimation'''
158
+ _DESCRIPTION = description
159
+
160
+ with gr.Blocks(title=_TITLE) as demo:
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ gr.Markdown('# ' + _TITLE)
164
+ gr.Markdown(_DESCRIPTION)
165
+ with gr.Row(variant='panel'):
166
+ with gr.Column(scale=1):
167
+ #input_image = gr.Image(type='pil', label='Original Image')
168
+ input_image = gr.Image(type='filepath', height=300, label='Input image')
169
+
170
+ example_folder = os.path.join(os.path.dirname(__file__), "./files")
171
+ example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
172
+ gr.Examples(
173
+ examples=example_fns,
174
+ inputs=[input_image],
175
+ cache_examples=False,
176
+ label='Examples (click one of the images below to start)',
177
+ examples_per_page=30
178
+ )
179
+
180
+ model_choice = gr.Dropdown(["vit-small", "vit-large"], label="Model", info="Select a model type", value="vit-small")
181
+ run_btn = gr.Button('Predict', variant='primary', interactive=True)
182
+
183
+ with gr.Column(scale=1):
184
+ depth = gr.Image(interactive=False, label="Depth")
185
+ normal = gr.Image(interactive=False, label="Normal")
186
+
187
+ with gr.Row():
188
+ files = gr.Files(
189
+ label = "Depth and Normal (numpy)",
190
+ elem_id = "download",
191
+ interactive=False,
192
+ )
193
+
194
+ with gr.Row():
195
+ recon_btn = gr.Button('Is focal length available? If Yes, Enter and Click Here for Metric 3D Reconstruction', variant='primary', interactive=True)
196
+ focal_length = gr.Number(value=1000, label="Focal Length")
197
+
198
+ with gr.Row():
199
+ reconstructed_file = gr.Files(
200
+ label = "3D pointclouds (plyfile)",
201
+ elem_id = "download",
202
+ interactive=False
203
+ )
204
+
205
+ run_btn.click(fn=depth_normal,
206
+ inputs=[input_image,
207
+ model_choice],
208
+ outputs=[depth, normal, files]
209
+ )
210
+ recon_btn.click(fn=reconstruction,
211
+ inputs=[input_image, files, focal_length],
212
+ outputs=[reconstructed_file]
213
+ )
214
+
215
+ demo.queue().launch(share=True, max_threads=80)
216
+
217
+
218
+ if __name__ == '__main__':
219
+ fire.Fire(run_demo)