LouisLi committed on
Commit 3f9651e · verified · 1 Parent(s): 4bba054

Update app.py

Files changed (1): app.py +366 -1
app.py CHANGED
@@ -21,6 +21,261 @@ from segment_anything import sam_model_registry
import easyocr
import tts

+###############################################################################
+#############           this part is for 3D generation           #############
+###############################################################################
+
+import spaces
+
+import os
+import imageio
+import numpy as np
+import torch
+import rembg
+from PIL import Image
+from torchvision.transforms import v2
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from tqdm import tqdm
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    get_zero123plus_input_cameras,
+    get_circular_camera_poses,
+)
+from src.utils.mesh_util import save_obj, save_glb
+from src.utils.infer_util import remove_background, resize_foreground, images_to_video
+
+import tempfile
+from functools import partial
+
+from huggingface_hub import hf_hub_download
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    """
+    Get the rendering camera parameters.
+    """
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
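+# A shape sketch, assuming get_circular_camera_poses returns (M, 4, 4)
+# camera-to-world matrices as its use above suggests:
+#   is_flexicubes=True  -> (batch_size, M, 4, 4) world-to-camera matrices
+#   is_flexicubes=False -> (batch_size, M, 25), i.e. 16 flattened extrinsics
+#                          concatenated with 9 flattened intrinsics per view
+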
+def images_to_video(images, output_path, fps=30):
+    # images: (N, C, H, W), float values expected in [0, 1].
+    # Note: this local definition shadows the images_to_video imported
+    # from src.utils.infer_util above.
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    frames = []
+    for i in range(images.shape[0]):
+        # clip before casting to uint8; clipping after the cast would be a no-op
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
+
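+# Note: writing codec='h264' through imageio.mimwrite relies on an ffmpeg
+# backend (the imageio-ffmpeg package) being installed; in this diff the
+# function is only called from the commented-out video block in make3d below.
+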
+###############################################################################
+# Configuration.
+###############################################################################
+
+import shutil
+
+def find_cuda():
+    # Check if CUDA_HOME or CUDA_PATH environment variables are set
+    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    # Search for the nvcc executable in the system's PATH
+    nvcc_path = shutil.which('nvcc')
+    if nvcc_path:
+        # Remove the 'bin/nvcc' part to get the CUDA installation path
+        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+        return cuda_path
+
+    return None
+
+cuda_path = find_cuda()
+
+if cuda_path:
+    print(f"CUDA installation found at: {cuda_path}")
+else:
+    print("CUDA installation not found")
+
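+# cuda_path is informational only in this file (it is only printed);
+# presumably the check exists because CUDA extensions compiled at runtime
+# need a discoverable toolkit, but nothing below depends on it.
+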
+config_path = 'configs/instant-mesh-large.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = config_name.startswith('instant-mesh')
+
+device = torch.device('cuda')
+
+# load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+# load custom white-background UNet
+unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+
+# load reconstruction model
+print('Loading reconstruction model ...')
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+model = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+# keep only lrm_generator.* weights and strip the 'lrm_generator.' prefix
+# (14 characters) so the keys match the bare reconstruction model
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+
+print('Loading Finished!')
+
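+# For example, a (hypothetical) checkpoint key
+#   'lrm_generator.synthesizer.decoder.net.0.weight'
+# is renamed to
+#   'synthesizer.decoder.net.0.weight'
+# before loading; len('lrm_generator.') == 14 explains the k[14:] slice.
+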
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background):
+    rembg_session = rembg.new_session() if do_remove_background else None
+
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+
+    return input_image
+
+
+@spaces.GPU
+def generate_mvs(input_image, sample_steps, sample_seed):
+    seed_everything(sample_seed)
+
+    # sampling
+    z123_image = pipeline(
+        input_image,
+        num_inference_steps=sample_steps
+    ).images[0]
+
+    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = torch.from_numpy(show_image)  # (960, 640, 3)
+    # repack the six 320x320 views from a 3x2 grid into a 2x3 grid for display
+    show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+    show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+    show_image = Image.fromarray(show_image.numpy())
+
+    return z123_image, show_image
+
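+# Grid bookkeeping for the preview: 960 x 640 = (3*320) x (2*320); the first
+# rearrange splits the sheet into six 320x320 view tiles, the second lays the
+# same tiles back out as (2*320) x (3*320) = 640 x 960, a wider preview image.
+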
+
+@spaces.GPU
+def make3d(images):
+    global model
+    if IS_FLEXICUBES:
+        model.init_flexicubes_geometry(device, use_renderer=False)
+    model = model.eval()
+
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)     # (6, 3, 320, 320)
+
+    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
+    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    # interpolation=3 corresponds to InterpolationMode.BICUBIC
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
+    mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        # # get video
+        # chunk_size = 20 if IS_FLEXICUBES else 1
+        # render_size = 384
+
+        # frames = []
+        # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
+        #     if IS_FLEXICUBES:
+        #         frame = model.forward_geometry(
+        #             planes,
+        #             render_cameras[:, i:i+chunk_size],
+        #             render_size=render_size,
+        #         )['img']
+        #     else:
+        #         frame = model.synthesizer(
+        #             planes,
+        #             cameras=render_cameras[:, i:i+chunk_size],
+        #             render_size=render_size,
+        #         )['images_rgb']
+        #     frames.append(frame)
+        # frames = torch.cat(frames, dim=1)
+
+        # images_to_video(
+        #     frames[0],
+        #     video_fpath,
+        #     fps=30,
+        # )
+
+        # print(f"Video saved to {video_fpath}")
+
+        # get mesh
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            **infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+        # reorder the vertex axes to match the exporters' coordinate convention
+        vertices = vertices[:, [1, 2, 0]]
+
+        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        print(f"Mesh saved to {mesh_fpath}")
+
+    return mesh_fpath, mesh_glb_fpath
+
+
+###############################################################################
+#############           above part is for 3D generation          #############
+###############################################################################
+

gpt_state = 0

 
@@ -735,7 +990,117 @@ def create_ui():
submit_tts = gr.Button(value="Submit", interactive=True)
clear_tts = gr.Button(value="Clear", interactive=True)

-
+    ###############################################################################
+    # this part is for 3D generation.
+    ###############################################################################
+
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    #width=256,
+                    #height=256,
+                    type="pil",
+                    elem_id="content_image",
+                )
+                processed_image = gr.Image(
+                    label="Processed Image",
+                    image_mode="RGBA",
+                    #width=256,
+                    #height=256,
+                    type="pil",
+                    interactive=False
+                )
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(
+                        label="Remove Background", value=True
+                    )
+                    sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+
+                    sample_steps = gr.Slider(
+                        label="Sample Steps",
+                        minimum=30,
+                        maximum=75,
+                        value=75,
+                        step=5
+                    )
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=[input_image],
+                    label="Examples",
+                    cache_examples=False,
+                    examples_per_page=16
+                )
+
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    mv_show_images = gr.Image(
+                        label="Generated Multi-views",
+                        type="pil",
+                        width=379,
+                        interactive=False
+                    )
+
+                # with gr.Column():
+                #     output_video = gr.Video(
+                #         label="video", format="mp4",
+                #         width=379,
+                #         autoplay=True,
+                #         interactive=False
+                #     )
+
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    output_model_obj = gr.Model3D(
+                        label="Output Model (OBJ Format)",
+                        interactive=False,
+                    )
+                    gr.Markdown("Note: the downloaded .obj model will be flipped. Export the .glb instead, or flip it manually before use.")
+                with gr.Tab("GLB"):
+                    output_model_glb = gr.Model3D(
+                        label="Output Model (GLB Format)",
+                        interactive=False,
+                    )
+                    gr.Markdown("Note: the preview shown here appears darker than the actual model. Download it to get the correct colors.")
+
+            with gr.Row():
+                gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfactory (default: 42).''')
+
+    mv_images = gr.State()
+
+    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess,
+        inputs=[input_image, do_remove_background],
+        outputs=[processed_image],
+    ).success(
+        fn=generate_mvs,
+        inputs=[processed_image, sample_steps, sample_seed],
+        outputs=[mv_images, mv_show_images],
+    ).success(
+        fn=make3d,
+        inputs=[mv_images],
+        outputs=[output_model_obj, output_model_glb],
+    )
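+    # Event chain: each .success() step runs only if the previous one raised
+    # no error, so a missing image (check_input_image) short-circuits
+    # preprocessing, multi-view sampling (generate_mvs) and reconstruction
+    # (make3d).
+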
+    ###############################################################################
+    # above part is for 3D generation.
+    ###############################################################################
+

def clear_tts_fields():
    return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]