lemonaddie commited on
Commit
fbf7415
·
verified ·
1 Parent(s): 47e2130

Upload app_recon.py

Browse files
Files changed (1) hide show
  1. app_recon.py +295 -0
app_recon.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import git
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch as torch
10
+ from PIL import Image
11
+
12
+ from gradio_imageslider import ImageSlider
13
+ from bilateral_normal_integration.bilateral_normal_integration_cupy import bilateral_normal_integration_function
14
+
15
+ import spaces
16
+
17
+ import fire
18
+
19
+ import argparse
20
+ import os
21
+ import logging
22
+
23
+ import numpy as np
24
+ import torch
25
+ from PIL import Image
26
+ from tqdm.auto import tqdm
27
+ import glob
28
+ import json
29
+ import cv2
30
+
31
+ from rembg import remove
32
+ from segment_anythi ng import sam_model_registry, SamPredictor
33
+ from datetime import datetime
34
+ import time
35
+
36
+
37
+ import sys
38
+ sys.path.append("../")
39
+ from models.geowizard_pipeline import DepthNormalEstimationPipeline
40
+ from utils.seed_all import seed_all
41
+ import matplotlib.pyplot as plt
42
+ from utils.de_normalized import align_scale_shift
43
+ from utils.depth2normal import *
44
+
45
+ from diffusers import DiffusionPipeline, DDIMScheduler, AutoencoderKL
46
+ from models.unet_2d_condition import UNet2DConditionModel
47
+
48
+ from transformers import CLIPTextModel, CLIPTokenizer
49
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
50
+ import torchvision.transforms.functional as TF
51
+ from torchvision.transforms import InterpolationMode
52
+
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1-unclip"
56
+ vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder='vae')
57
+ scheduler = DDIMScheduler.from_pretrained(stable_diffusion_repo_path, subfolder='scheduler')
58
+ sd_image_variations_diffusers_path = 'lambdalabs/sd-image-variations-diffusers'
59
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_image_variations_diffusers_path, subfolder="image_encoder")
60
+ feature_extractor = CLIPImageProcessor.from_pretrained(sd_image_variations_diffusers_path, subfolder="feature_extractor")
61
+ unet = UNet2DConditionModel.from_pretrained('.', subfolder="unet")
62
+
63
+ pipe = DepthNormalEstimationPipeline(vae=vae,
64
+ image_encoder=image_encoder,
65
+ feature_extractor=feature_extractor,
66
+ unet=unet,
67
+ scheduler=scheduler)
68
+
69
+ try:
70
+ import xformers
71
+ pipe.enable_xformers_memory_efficient_attention()
72
+ except:
73
+ pass # run without xformers
74
+
75
+ pipe = pipe.to(device)
76
+
77
+ def sam_init():
78
+ sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_l_0b3195.pth")
79
+ model_type = "vit_l"
80
+
81
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda")
82
+ predictor = SamPredictor(sam)
83
+ return predictor
84
+
85
+ sam_predictor = sam_init()
86
+
87
+ def sam_segment(predictor, input_image, *bbox_coords):
88
+ bbox = np.array(bbox_coords)
89
+ image = np.asarray(input_image)
90
+
91
+ start_time = time.time()
92
+ predictor.set_image(image)
93
+
94
+ masks_bbox, scores_bbox, logits_bbox = predictor.predict(
95
+ box=bbox,
96
+ multimask_output=True
97
+ )
98
+
99
+ print(f"SAM Time: {time.time() - start_time:.3f}s")
100
+ out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
101
+ out_image[:, :, :3] = image
102
+ out_image_bbox = out_image.copy()
103
+ out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
104
+ torch.cuda.empty_cache()
105
+ return Image.fromarray(out_image_bbox, mode='RGBA'), masks_bbox
106
+
107
+ @spaces.GPU
108
+ def depth_normal(img_path,
109
+ denoising_steps,
110
+ ensemble_size,
111
+ processing_res,
112
+ seed,
113
+ domain):
114
+
115
+ seed = int(seed)
116
+ if seed >= 0:
117
+ torch.manual_seed(seed)
118
+
119
+ img = Image.open(img_path)
120
+
121
+ pipe_out = pipe(
122
+ img,
123
+ denoising_steps=denoising_steps,
124
+ ensemble_size=ensemble_size,
125
+ processing_res=processing_res,
126
+ batch_size=0,
127
+ domain=domain,
128
+ show_progress_bar=True,
129
+ )
130
+
131
+ depth_colored = pipe_out.depth_colored
132
+ normal_colored = pipe_out.normal_colored
133
+
134
+ depth_np = pipe_out.depth_np
135
+ normal_np = pipe_out.normal_np
136
+
137
+ path_output_dir = os.path.splitext(os.path.basename(img_path))[0] + datetime.now().strftime('%Y%m%d-%H%M%S')
138
+ os.makedirs(path_output_dir, exist_ok=True)
139
+
140
+ name_base = os.path.splitext(os.path.basename(img_path))[0]
141
+ depth_path = os.path.join(path_output_dir, f"{name_base}_depth.npy")
142
+ normal_path = os.path.join(path_output_dir, f"{name_base}_normal.npy")
143
+
144
+ np.save(normal_path, normal_np)
145
+ np.save(depth_path, depth_np)
146
+
147
+ return depth_colored, normal_colored, [depth_path, normal_path]
148
+
149
+ def reconstruction(image, files):
150
+
151
+ torch.cuda.empty_cache()
152
+
153
+ img = Image.open(image)
154
+
155
+ image_rem = img.convert('RGBA')
156
+ image_nobg = remove(image_rem, alpha_matting=True)
157
+ arr = np.asarray(image_nobg)[:,:,-1]
158
+ x_nonzero = np.nonzero(arr.sum(axis=0))
159
+ y_nonzero = np.nonzero(arr.sum(axis=1))
160
+ x_min = int(x_nonzero[0].min())
161
+ y_min = int(y_nonzero[0].min())
162
+ x_max = int(x_nonzero[0].max())
163
+ y_max = int(y_nonzero[0].max())
164
+ masked_image, mask = sam_segment(sam_predictor, img.convert('RGB'), x_min, y_min, x_max, y_max)
165
+
166
+ depth_np = np.load(files[0])
167
+ normal_np = np.load(files[1])
168
+
169
+ dir_name = os.path.dirname(os.path.realpath(files[0]))
170
+ mask_output_temp = mask[-1]
171
+ name_base = os.path.splitext(os.path.basename(files[0]))[0][:-6]
172
+
173
+ normal_np[:, :, 0] *= -1
174
+ _, surface, _, _, _ = bilateral_normal_integration_function(normal_np, mask_output_temp, k=2, K=None, max_iter=100, tol=1e-4, cg_max_iter=5000, cg_tol=1e-3)
175
+ ply_path = os.path.join(dir_name, f"{name_base}_mask.ply")
176
+ surface.save(ply_path, binary=False)
177
+ return ply_path
178
+
179
+ def run_demo():
180
+
181
+
182
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
183
+ button_secondary_background_fill="*neutral_100",
184
+ button_secondary_background_fill_hover="*neutral_200")
185
+ custom_css = '''#disp_image {
186
+ text-align: center; /* Horizontally center the content */
187
+ }'''
188
+
189
+ _TITLE = '''GeoWizard: Unleashing the Diffusion Priors for 3D Geometry Estimation from a Single Image'''
190
+ _DESCRIPTION = '''
191
+ <div>
192
+ Generate consistent depth and normal from single image. High quality and rich details. (PS: We find the demo running on ZeroGPU output slightly inferior results compared to A100 or 3060 with everything exactly the same.)
193
+ <a style="display:inline-block; margin-left: .5em" href='https://github.com/fuxiao0719/GeoWizard/'><img src='https://img.shields.io/github/stars/fuxiao0719/GeoWizard?style=social' /></a>
194
+ </div>
195
+ '''
196
+ _GPU_ID = 0
197
+
198
+ with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
199
+ with gr.Row():
200
+ with gr.Column(scale=1):
201
+ gr.Markdown('# ' + _TITLE)
202
+ gr.Markdown(_DESCRIPTION)
203
+ with gr.Row(variant='panel'):
204
+ with gr.Column(scale=1):
205
+ input_image = gr.Image(type='filepath', height=320, label='Input image')
206
+
207
+ example_folder = os.path.join(os.path.dirname(__file__), "./files")
208
+ example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
209
+ gr.Examples(
210
+ examples=example_fns,
211
+ inputs=[input_image],
212
+ cache_examples=False,
213
+ label='Examples (click one of the images below to start)',
214
+ examples_per_page=30
215
+ )
216
+ with gr.Column(scale=1):
217
+
218
+ with gr.Accordion('Advanced options', open=True):
219
+ with gr.Column():
220
+
221
+ domain = gr.Radio(
222
+ [
223
+ ("Outdoor", "outdoor"),
224
+ ("Indoor", "indoor"),
225
+ ("Object", "object"),
226
+ ],
227
+ label="Data Type (Must Select One matches your image)",
228
+ value="indoor",
229
+ )
230
+ denoising_steps = gr.Slider(
231
+ label="Number of denoising steps (More steps, better quality)",
232
+ minimum=1,
233
+ maximum=50,
234
+ step=1,
235
+ value=10,
236
+ )
237
+ ensemble_size = gr.Slider(
238
+ label="Ensemble size (More steps, higher accuracy)",
239
+ minimum=1,
240
+ maximum=15,
241
+ step=1,
242
+ value=3,
243
+ )
244
+ seed = gr.Number(0, label='Random Seed. Negative values for not specifying')
245
+
246
+ processing_res = gr.Radio(
247
+ [
248
+ ("Native", 0),
249
+ ("Recommended", 768),
250
+ ],
251
+ label="Processing resolution",
252
+ value=768,
253
+ )
254
+
255
+
256
+ run_btn = gr.Button('Generate', variant='primary', interactive=True)
257
+ with gr.Row():
258
+ with gr.Column():
259
+ depth = gr.Image(interactive=False, show_label=False)
260
+ with gr.Column():
261
+ normal = gr.Image(interactive=False, show_label=False)
262
+
263
+ with gr.Row():
264
+ files = gr.Files(
265
+ label = "Depth and Normal (numpy)",
266
+ elem_id = "download",
267
+ interactive=False,
268
+ )
269
+
270
+ with gr.Row():
271
+ recon_btn = gr.Button('Is there a salient foreground object? If yes, Click here to Reconstruct its 3D model.', variant='primary', interactive=True)
272
+
273
+ with gr.Row():
274
+ reconstructed_3d = gr.Model3D(
275
+ label = 'Bini post-processed 3D model', height=320, interactive=False,
276
+ )
277
+
278
+
279
+ run_btn.click(fn=depth_normal,
280
+ inputs=[input_image, denoising_steps,
281
+ ensemble_size,
282
+ processing_res,
283
+ seed,
284
+ domain],
285
+ outputs=[depth, normal, files]
286
+ )
287
+ recon_btn.click(fn=reconstruction,
288
+ inputs=[input_image, files],
289
+ outputs=[reconstructed_3d]
290
+ )
291
+ demo.queue().launch(share=True, max_threads=80)
292
+
293
+
294
+ if __name__ == '__main__':
295
+ fire.Fire(run_demo)