gokaygokay committed on
Commit
f25bfd1
·
1 Parent(s): 777f4d9
Files changed (45) hide show
  1. app.py +129 -261
  2. configs/instant-mesh-base.yaml +0 -22
  3. configs/instant-mesh-large.yaml +0 -22
  4. configs/instant-nerf-base.yaml +0 -21
  5. configs/instant-nerf-large.yaml +0 -21
  6. requirements.txt +8 -24
  7. src/__init__.py +0 -0
  8. src/data/__init__.py +0 -0
  9. src/data/objaverse.py +0 -329
  10. src/model.py +0 -310
  11. src/model_mesh.py +0 -325
  12. src/models/__init__.py +0 -0
  13. src/models/decoder/__init__.py +0 -0
  14. src/models/decoder/transformer.py +0 -123
  15. src/models/encoder/__init__.py +0 -0
  16. src/models/encoder/dino.py +0 -550
  17. src/models/encoder/dino_wrapper.py +0 -80
  18. src/models/geometry/__init__.py +0 -7
  19. src/models/geometry/camera/__init__.py +0 -16
  20. src/models/geometry/camera/perspective_camera.py +0 -35
  21. src/models/geometry/render/__init__.py +0 -8
  22. src/models/geometry/render/neural_render.py +0 -121
  23. src/models/geometry/rep_3d/__init__.py +0 -18
  24. src/models/geometry/rep_3d/dmtet.py +0 -504
  25. src/models/geometry/rep_3d/dmtet_utils.py +0 -20
  26. src/models/geometry/rep_3d/extract_texture_map.py +0 -40
  27. src/models/geometry/rep_3d/flexicubes.py +0 -579
  28. src/models/geometry/rep_3d/flexicubes_geometry.py +0 -120
  29. src/models/geometry/rep_3d/tables.py +0 -791
  30. src/models/lrm.py +0 -196
  31. src/models/lrm_mesh.py +0 -385
  32. src/models/renderer/__init__.py +0 -9
  33. src/models/renderer/synthesizer.py +0 -203
  34. src/models/renderer/synthesizer_mesh.py +0 -141
  35. src/models/renderer/utils/__init__.py +0 -9
  36. src/models/renderer/utils/math_utils.py +0 -118
  37. src/models/renderer/utils/ray_marcher.py +0 -72
  38. src/models/renderer/utils/ray_sampler.py +0 -141
  39. src/models/renderer/utils/renderer.py +0 -323
  40. src/utils/__init__.py +0 -0
  41. src/utils/camera_util.py +0 -111
  42. src/utils/infer_util.py +0 -84
  43. src/utils/mesh_util.py +0 -181
  44. src/utils/train_util.py +0 -26
  45. zero123plus/pipeline.py +0 -406
app.py CHANGED
@@ -1,286 +1,154 @@
1
  import spaces
2
- import os
3
- import time
4
- from os import path
5
- from huggingface_hub import hf_hub_download
6
- import numpy as np
7
  import torch
8
- import rembg
9
  from PIL import Image
10
- from torchvision.transforms import v2
11
- from einops import rearrange
12
- from pytorch_lightning import seed_everything
13
- from omegaconf import OmegaConf
14
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
15
- import gradio as gr
16
- import shutil
17
- import tempfile
18
- from src.utils.train_util import instantiate_from_config
19
- from src.utils.camera_util import (
20
- FOV_to_intrinsics,
21
- get_zero123plus_input_cameras,
22
- get_circular_camera_poses,
23
- )
24
- from src.utils.mesh_util import save_obj, save_glb
25
- from src.utils.infer_util import remove_background, resize_foreground, images_to_video
26
  import random
27
- import requests
28
- import io
 
 
29
 
30
- # Set up cache path
31
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
32
- os.environ["TRANSFORMERS_CACHE"] = cache_path
33
- os.environ["HF_HUB_CACHE"] = cache_path
34
- os.environ["HF_HOME"] = cache_path
35
 
36
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
37
 
38
- if not path.exists(cache_path):
39
- os.makedirs(cache_path, exist_ok=True)
40
-
41
- torch.backends.cuda.matmul.allow_tf32 = True
42
-
43
- class timer:
44
- def __init__(self, method_name="timed process"):
45
- self.method = method_name
46
- def __enter__(self):
47
- self.start = time.time()
48
- print(f"{self.method} starts")
49
- def __exit__(self, exc_type, exc_val, exc_tb):
50
- end = time.time()
51
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
52
-
53
- def find_cuda():
54
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
55
- if cuda_home and os.path.exists(cuda_home):
56
- return cuda_home
57
- nvcc_path = shutil.which('nvcc')
58
- if nvcc_path:
59
- cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
60
- return cuda_path
61
- return None
62
-
63
- cuda_path = find_cuda()
64
- if cuda_path:
65
- print(f"CUDA installation found at: {cuda_path}")
66
- else:
67
- print("CUDA installation not found")
68
-
69
-
70
- API_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
71
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
72
- timeout = 100
73
-
74
- device = 'cuda'
75
-
76
- # Load 3D generation models
77
- config_path = 'configs/instant-mesh-large.yaml'
78
- config = OmegaConf.load(config_path)
79
- config_name = os.path.basename(config_path).replace('.yaml', '')
80
- model_config = config.model_config
81
- infer_config = config.infer_config
82
-
83
- IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
84
-
85
- # Load diffusion model for 3D generation
86
- print('Loading diffusion model ...')
87
- pipeline = DiffusionPipeline.from_pretrained(
88
- "sudo-ai/zero123plus-v1.2",
89
- custom_pipeline="zero123plus",
90
- torch_dtype=torch.float16,
91
- )
92
- pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
93
- pipeline.scheduler.config, timestep_spacing='trailing'
94
- )
95
-
96
- # Load custom white-background UNet
97
- unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
98
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
99
- pipeline.unet.load_state_dict(state_dict, strict=True)
100
-
101
- pipeline = pipeline.to(device)
102
-
103
- # Load reconstruction model
104
- print('Loading reconstruction model ...')
105
- model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
106
- model = instantiate_from_config(model_config)
107
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
108
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
109
- model.load_state_dict(state_dict, strict=True)
110
-
111
- model = model.to(device)
112
-
113
- print('Loading Finished!')
114
-
115
- def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
116
- c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
117
- if is_flexicubes:
118
- cameras = torch.linalg.inv(c2ws)
119
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
120
- else:
121
- extrinsics = c2ws.flatten(-2)
122
- intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
123
- cameras = torch.cat([extrinsics, intrinsics], dim=-1)
124
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
125
- return cameras
126
-
127
- def preprocess(input_image, do_remove_background):
128
- rembg_session = rembg.new_session() if do_remove_background else None
129
- if do_remove_background:
130
- input_image = remove_background(input_image, rembg_session)
131
- input_image = resize_foreground(input_image, 0.85)
132
- return input_image
133
 
 
 
 
134
 
 
 
135
 
 
 
136
 
 
137
  @spaces.GPU
138
- def generate_mvs(input_image, sample_steps, sample_seed):
139
- seed_everything(sample_seed)
140
- z123_image = pipeline(
141
- input_image,
142
- num_inference_steps=sample_steps
143
- ).images[0]
144
- show_image = np.asarray(z123_image, dtype=np.uint8)
145
- show_image = torch.from_numpy(show_image)
146
- show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
147
- show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
148
- show_image = Image.fromarray(show_image.numpy())
149
- return z123_image, show_image
150
-
151
- @spaces.GPU
152
- def make3d(images):
153
- global model
154
- if IS_FLEXICUBES:
155
- model.init_flexicubes_geometry(device, use_renderer=False)
156
- model = model.eval()
157
-
158
- images = np.asarray(images, dtype=np.float32) / 255.0
159
- images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
160
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
161
-
162
- input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
163
- render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
164
-
165
- images = images.unsqueeze(0).to(device)
166
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
167
-
168
- mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
169
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
170
- mesh_dirname = os.path.dirname(mesh_fpath)
171
- mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
172
-
173
- with torch.no_grad():
174
- planes = model.forward_planes(images, input_cameras)
175
- mesh_out = model.extract_mesh(
176
- planes,
177
- use_texture_map=False,
178
- **infer_config,
179
- )
180
- vertices, faces, vertex_colors = mesh_out
181
- vertices = vertices[:, [1, 2, 0]]
182
- save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
183
- save_obj(vertices, faces, vertex_colors, mesh_fpath)
184
 
185
- return mesh_fpath, mesh_glb_fpath
186
-
187
- # Remove the FluxPipeline setup and replace with the query function
188
- def query(prompt, steps=28, cfg_scale=3.5, randomize_seed=True, seed=-1, width=1024, height=1024):
189
- if not prompt:
190
- return None
191
-
192
- lora_id = "gokaygokay/Flux-Game-Assets-LoRA-v2"
193
- API_URL = f"https://api-inference.huggingface.co/models/{lora_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  if randomize_seed:
196
- seed = random.randint(1, 4294967296)
197
 
198
- payload = {
199
- "inputs": prompt,
200
- "steps": steps,
201
- "cfg_scale": cfg_scale,
202
- "seed": seed,
203
- "parameters": {
204
- "width": width,
205
- "height": height
206
- }
207
- }
208
-
209
- response = requests.post(API_URL, headers=headers, json=payload, timeout=100)
210
- if response.status_code != 200:
211
- if response.status_code == 503:
212
- raise gr.Error("The model is being loaded")
213
- raise gr.Error(f"Error {response.status_code}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- try:
216
- image_bytes = response.content
217
- image = Image.open(io.BytesIO(image_bytes))
218
- return image
219
- except Exception as e:
220
- print(f"Error when trying to open the image: {e}")
221
- return None
222
-
223
- # Update the Gradio interface
224
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
225
- gr.Markdown(
226
- """
227
- <div style="text-align: center; max-width: 650px; margin: 0 auto;">
228
- <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem;">Flux Image to 3D Model Generator</h1>
229
- </div>
230
- """
231
- )
232
-
233
  with gr.Row():
234
- with gr.Column(scale=3):
235
- prompt = gr.Textbox(
236
- label="Your Image Description",
237
- placeholder="E.g., A serene landscape with mountains and a lake at sunset",
238
- lines=3
239
- )
240
 
241
  with gr.Accordion("Advanced Settings", open=False):
242
- with gr.Group():
243
- with gr.Row():
244
- height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
245
- width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
246
-
247
- with gr.Row():
248
- steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
249
- scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
250
-
251
- seed = gr.Number(label="Seed", value=-1, precision=0)
252
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
253
 
254
- generate_btn = gr.Button("Generate 3D Model", variant="primary")
255
-
256
- with gr.Column(scale=4):
257
- flux_output = gr.Image(label="Generated Flux Image")
258
- mv_show_images = gr.Image(label="Generated Multi-views")
259
- with gr.Row():
260
- with gr.Tab("OBJ"):
261
- output_model_obj = gr.Model3D(label="Output Model (OBJ Format)")
262
- with gr.Tab("GLB"):
263
- output_model_glb = gr.Model3D(label="Output Model (GLB Format)")
264
-
265
- mv_images = gr.State()
266
-
267
- def process_pipeline(prompt, height, width, steps, scales, seed, randomize_seed):
268
- # Generate Flux image using the API
269
- prompt_real = f"wbgmsst, {prompt}, white background"
270
- flux_image = query(prompt_real, steps, scales, randomize_seed, seed, width, height)
271
- if flux_image is None:
272
- raise gr.Error("Failed to generate image")
273
-
274
- processed_image = preprocess(flux_image, do_remove_background=True)
275
- mv_images, show_image = generate_mvs(processed_image, steps, seed)
276
- obj_path, glb_path = make3d(mv_images)
277
- return flux_image, show_image, obj_path, glb_path
278
-
279
  generate_btn.click(
280
- fn=process_pipeline,
281
- inputs=[prompt, height, width, steps, scales, seed, randomize_seed],
282
- outputs=[flux_output, mv_show_images, output_model_obj, output_model_glb]
 
 
 
283
  )
284
 
285
- if __name__ == "__main__":
286
- demo.queue().launch()
 
1
  import spaces
2
+ import gradio as gr
 
 
 
 
3
  import torch
 
4
  from PIL import Image
5
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
6
+ from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import random
8
+ import numpy as np
9
+ import os
10
+ import subprocess
11
# Install flash-attn at startup, telling its build to skip CUDA compilation.
# The flag is merged into a copy of the current environment: a bare dict passed
# as ``env`` replaces the child's whole environment (PATH, HOME, ...) rather
# than adding to it. The return code is intentionally not checked (best effort).
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
12
 
13
+ # Initialize models
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.bfloat16
 
 
16
 
17
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
18
 
19
+ # FLUX.1-dev model
20
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token = huggingface_token).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Initialize Florence model
23
+ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
24
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
25
 
26
+ # Prompt Enhancer
27
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
28
 
29
+ MAX_SEED = np.iinfo(np.int32).max
30
+ MAX_IMAGE_SIZE = 2048
31
 
32
+ # Florence caption function
33
  @spaces.GPU
34
def florence_caption(image):
    """Caption an image with Florence-2 using the ``<MORE_DETAILED_CAPTION>`` task.

    Args:
        image: A ``PIL.Image.Image``; anything else is first passed through
            ``Image.fromarray``.

    Returns:
        The ``"<MORE_DETAILED_CAPTION>"`` entry of the processor's
        post-processed generation output.
    """
    task = "<MORE_DETAILED_CAPTION>"
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(
        text=task, images=pil_image, return_tensors="pt"
    ).to(device)

    # Deterministic decoding: 3-beam search with sampling disabled.
    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text handed to post-processing.
    decoded = florence_processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        decoded,
        task=task,
        image_size=(pil_image.width, pil_image.height),
    )
    return parsed[task]
55
+
56
+ # Prompt Enhancer function
57
def enhance_prompt(input_prompt):
    """Expand a prompt with the ``enhancer_long`` summarization pipeline.

    The prompt is prefixed with ``"Enhance the description: "`` before being
    sent to the pipeline.

    Args:
        input_prompt: Prompt text to expand.

    Returns:
        The ``summary_text`` field of the first pipeline result.
    """
    outputs = enhancer_long("Enhance the description: " + input_prompt)
    return outputs[0]['summary_text']
61
+
62
+ @spaces.GPU(duration=190)
63
def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Build a prompt, optionally enhance it, and generate an image with ``pipe``.

    Args:
        image: Optional input image. When given, its Florence-2 caption is
            used as the prompt and ``text_prompt`` is ignored.
        text_prompt: Prompt used when ``image`` is ``None``.
        use_enhancer: When truthy, the prompt is passed through
            ``enhance_prompt``.
        seed: Seed for the generator; replaced when ``randomize_seed`` is set.
        randomize_seed: When truthy, draw a seed in ``[0, MAX_SEED]``.
        width: Passed to ``pipe`` as ``width``.
        height: Passed to ``pipe`` as ``height``.
        guidance_scale: Passed to ``pipe`` as ``guidance_scale``.
        num_inference_steps: Passed to ``pipe`` as ``num_inference_steps``.
        progress: Gradio progress tracker (created with ``track_tqdm=True``);
            not referenced in the body.

    Returns:
        Tuple of (generated image, prompt actually used, seed actually used).
    """
    if image is None:
        prompt = text_prompt
    else:
        # Florence captioning expects a PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        prompt = florence_caption(image)
        print(prompt)

    if use_enhancer:
        prompt = enhance_prompt(prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    rng = torch.Generator(device=device).manual_seed(seed)

    generated = pipe(
        prompt=prompt,
        generator=rng,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
    ).images[0]

    return generated, prompt, seed
92
+
93
+ custom_css = """
94
+ .input-group, .output-group {
95
+ border: 1px solid #e0e0e0;
96
+ border-radius: 10px;
97
+ padding: 20px;
98
+ margin-bottom: 20px;
99
+ background-color: #f9f9f9;
100
+ }
101
+ .submit-btn {
102
+ background-color: #2980b9 !important;
103
+ color: white !important;
104
+ }
105
+ .submit-btn:hover {
106
+ background-color: #3498db !important;
107
+ }
108
+ """
109
+
110
+ title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
111
+ <p><center>
112
+ <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
113
+ <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
114
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
115
+ <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
116
+ </center></p>
117
+ """
118
+
119
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
120
+ gr.HTML(title)
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  with gr.Row():
123
+ with gr.Column(scale=1):
124
+ with gr.Group(elem_classes="input-group"):
125
+ input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
 
 
 
126
 
127
  with gr.Accordion("Advanced Settings", open=False):
128
+ text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
129
+ use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
130
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
131
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
132
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
133
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
134
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
135
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
 
 
 
136
 
137
+ generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
138
+
139
+ with gr.Column(scale=1):
140
+ with gr.Group(elem_classes="output-group"):
141
+ output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
142
+ final_prompt = gr.Textbox(label="Final Prompt Used")
143
+ used_seed = gr.Number(label="Seed Used")
144
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  generate_btn.click(
146
+ fn=process_workflow,
147
+ inputs=[
148
+ input_image, text_prompt, use_enhancer, seed, randomize_seed,
149
+ width, height, guidance_scale, num_inference_steps
150
+ ],
151
+ outputs=[output_image, final_prompt, used_seed]
152
  )
153
 
154
+ demo.launch(debug=True)
 
configs/instant-mesh-base.yaml DELETED
@@ -1,22 +0,0 @@
1
- model_config:
2
- target: src.models.lrm_mesh.InstantMesh
3
- params:
4
- encoder_feat_dim: 768
5
- encoder_freeze: false
6
- encoder_model_name: facebook/dino-vitb16
7
- transformer_dim: 1024
8
- transformer_layers: 12
9
- transformer_heads: 16
10
- triplane_low_res: 32
11
- triplane_high_res: 64
12
- triplane_dim: 40
13
- rendering_samples_per_ray: 96
14
- grid_res: 128
15
- grid_scale: 2.1
16
-
17
-
18
- infer_config:
19
- unet_path: ckpts/diffusion_pytorch_model.bin
20
- model_path: ckpts/instant_mesh_base.ckpt
21
- texture_resolution: 1024
22
- render_resolution: 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/instant-mesh-large.yaml DELETED
@@ -1,22 +0,0 @@
1
- model_config:
2
- target: src.models.lrm_mesh.InstantMesh
3
- params:
4
- encoder_feat_dim: 768
5
- encoder_freeze: false
6
- encoder_model_name: facebook/dino-vitb16
7
- transformer_dim: 1024
8
- transformer_layers: 16
9
- transformer_heads: 16
10
- triplane_low_res: 32
11
- triplane_high_res: 64
12
- triplane_dim: 80
13
- rendering_samples_per_ray: 128
14
- grid_res: 128
15
- grid_scale: 2.1
16
-
17
-
18
- infer_config:
19
- unet_path: ckpts/diffusion_pytorch_model.bin
20
- model_path: ckpts/instant_mesh_large.ckpt
21
- texture_resolution: 1024
22
- render_resolution: 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/instant-nerf-base.yaml DELETED
@@ -1,21 +0,0 @@
1
- model_config:
2
- target: src.models.lrm.InstantNeRF
3
- params:
4
- encoder_feat_dim: 768
5
- encoder_freeze: false
6
- encoder_model_name: facebook/dino-vitb16
7
- transformer_dim: 1024
8
- transformer_layers: 12
9
- transformer_heads: 16
10
- triplane_low_res: 32
11
- triplane_high_res: 64
12
- triplane_dim: 40
13
- rendering_samples_per_ray: 96
14
-
15
-
16
- infer_config:
17
- unet_path: ckpts/diffusion_pytorch_model.bin
18
- model_path: ckpts/instant_nerf_base.ckpt
19
- mesh_threshold: 10.0
20
- mesh_resolution: 256
21
- render_resolution: 384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/instant-nerf-large.yaml DELETED
@@ -1,21 +0,0 @@
1
- model_config:
2
- target: src.models.lrm.InstantNeRF
3
- params:
4
- encoder_feat_dim: 768
5
- encoder_freeze: false
6
- encoder_model_name: facebook/dino-vitb16
7
- transformer_dim: 1024
8
- transformer_layers: 16
9
- transformer_heads: 16
10
- triplane_low_res: 32
11
- triplane_high_res: 64
12
- triplane_dim: 80
13
- rendering_samples_per_ray: 128
14
-
15
-
16
- infer_config:
17
- unet_path: ckpts/diffusion_pytorch_model.bin
18
- model_path: ckpts/instant_nerf_large.ckpt
19
- mesh_threshold: 10.0
20
- mesh_resolution: 256
21
- render_resolution: 384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,27 +1,11 @@
 
 
 
 
1
  torch==2.4.0
2
  torchvision==0.19.0
3
- torchaudio==2.4.0
4
- pytorch-lightning==2.1.2
5
- einops
6
- omegaconf
7
- deepspeed
8
- torchmetrics
9
- webdataset
10
  sentencepiece
11
- accelerate
12
- tensorboard
13
- PyMCubes
14
- trimesh
15
- rembg
16
- peft
17
- transformers==4.44.0
18
- diffusers==0.31.0
19
- bitsandbytes
20
- imageio[ffmpeg]
21
- xatlas
22
- plyfile
23
- xformers==0.0.27.post2
24
- git+https://github.com/NVlabs/nvdiffrast/
25
- huggingface-hub==0.25.2
26
- optimum-quanto==0.2.5
27
- k-diffusion
 
1
+ spaces
2
+ huggingface_hub
3
+ accelerate
4
+ git+https://github.com/huggingface/diffusers.git
5
  torch==2.4.0
6
  torchvision==0.19.0
7
+ transformers==4.42.4
8
+ xformers
 
 
 
 
 
9
  sentencepiece
10
+ timm
11
+ einops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/__init__.py DELETED
File without changes
src/data/__init__.py DELETED
File without changes
src/data/objaverse.py DELETED
@@ -1,329 +0,0 @@
1
- import os, sys
2
- import math
3
- import json
4
- import importlib
5
- from pathlib import Path
6
-
7
- import cv2
8
- import random
9
- import numpy as np
10
- from PIL import Image
11
- import webdataset as wds
12
- import pytorch_lightning as pl
13
-
14
- import torch
15
- import torch.nn.functional as F
16
- from torch.utils.data import Dataset
17
- from torch.utils.data import DataLoader
18
- from torch.utils.data.distributed import DistributedSampler
19
- from torchvision import transforms
20
-
21
- from src.utils.train_util import instantiate_from_config
22
- from src.utils.camera_util import (
23
- FOV_to_intrinsics,
24
- center_looking_at_camera_pose,
25
- get_surrounding_views,
26
- )
27
-
28
-
29
- class DataModuleFromConfig(pl.LightningDataModule):
30
- def __init__(
31
- self,
32
- batch_size=8,
33
- num_workers=4,
34
- train=None,
35
- validation=None,
36
- test=None,
37
- **kwargs,
38
- ):
39
- super().__init__()
40
-
41
- self.batch_size = batch_size
42
- self.num_workers = num_workers
43
-
44
- self.dataset_configs = dict()
45
- if train is not None:
46
- self.dataset_configs['train'] = train
47
- if validation is not None:
48
- self.dataset_configs['validation'] = validation
49
- if test is not None:
50
- self.dataset_configs['test'] = test
51
-
52
- def setup(self, stage):
53
-
54
- if stage in ['fit']:
55
- self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
56
- else:
57
- raise NotImplementedError
58
-
59
- def train_dataloader(self):
60
-
61
- sampler = DistributedSampler(self.datasets['train'])
62
- return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
63
-
64
- def val_dataloader(self):
65
-
66
- sampler = DistributedSampler(self.datasets['validation'])
67
- return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
68
-
69
- def test_dataloader(self):
70
-
71
- return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
72
-
73
-
74
- class ObjaverseData(Dataset):
75
- def __init__(self,
76
- root_dir='objaverse/',
77
- meta_fname='valid_paths.json',
78
- input_image_dir='rendering_random_32views',
79
- target_image_dir='rendering_random_32views',
80
- input_view_num=6,
81
- target_view_num=2,
82
- total_view_n=32,
83
- fov=50,
84
- camera_rotation=True,
85
- validation=False,
86
- ):
87
- self.root_dir = Path(root_dir)
88
- self.input_image_dir = input_image_dir
89
- self.target_image_dir = target_image_dir
90
-
91
- self.input_view_num = input_view_num
92
- self.target_view_num = target_view_num
93
- self.total_view_n = total_view_n
94
- self.fov = fov
95
- self.camera_rotation = camera_rotation
96
-
97
- with open(os.path.join(root_dir, meta_fname)) as f:
98
- filtered_dict = json.load(f)
99
- paths = filtered_dict['good_objs']
100
- self.paths = paths
101
-
102
- self.depth_scale = 4.0
103
-
104
- total_objects = len(self.paths)
105
- print('============= length of dataset %d =============' % len(self.paths))
106
-
107
- def __len__(self):
108
- return len(self.paths)
109
-
110
- def load_im(self, path, color):
111
- '''
112
- replace background pixel with random color in rendering
113
- '''
114
- pil_img = Image.open(path)
115
-
116
- image = np.asarray(pil_img, dtype=np.float32) / 255.
117
- alpha = image[:, :, 3:]
118
- image = image[:, :, :3] * alpha + color * (1 - alpha)
119
-
120
- image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
121
- alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
122
- return image, alpha
123
-
124
- def __getitem__(self, index):
125
- # load data
126
- while True:
127
- input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
128
- target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
129
-
130
- indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
131
- input_indices = indices[:self.input_view_num]
132
- target_indices = indices[self.input_view_num:]
133
-
134
- '''background color, default: white'''
135
- bg_white = [1., 1., 1.]
136
- bg_black = [0., 0., 0.]
137
-
138
- image_list = []
139
- alpha_list = []
140
- depth_list = []
141
- normal_list = []
142
- pose_list = []
143
-
144
- try:
145
- input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
146
- for idx in input_indices:
147
- image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
148
- normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
149
- depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
150
- depth = torch.from_numpy(depth).unsqueeze(0)
151
- pose = input_cameras[idx]
152
- pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
153
-
154
- image_list.append(image)
155
- alpha_list.append(alpha)
156
- depth_list.append(depth)
157
- normal_list.append(normal)
158
- pose_list.append(pose)
159
-
160
- target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
161
- for idx in target_indices:
162
- image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
163
- normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
164
- depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
165
- depth = torch.from_numpy(depth).unsqueeze(0)
166
- pose = target_cameras[idx]
167
- pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
168
-
169
- image_list.append(image)
170
- alpha_list.append(alpha)
171
- depth_list.append(depth)
172
- normal_list.append(normal)
173
- pose_list.append(pose)
174
-
175
- except Exception as e:
176
- print(e)
177
- index = np.random.randint(0, len(self.paths))
178
- continue
179
-
180
- break
181
-
182
- images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
183
- alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
184
- depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
185
- normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
186
- w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
187
- c2ws = torch.linalg.inv(w2cs).float()
188
-
189
- normals = normals * 2.0 - 1.0
190
- normals = F.normalize(normals, dim=1)
191
- normals = (normals + 1.0) / 2.0
192
- normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
193
-
194
- # random rotation along z axis
195
- if self.camera_rotation:
196
- degree = np.random.uniform(0, math.pi * 2)
197
- rot = torch.tensor([
198
- [np.cos(degree), -np.sin(degree), 0, 0],
199
- [np.sin(degree), np.cos(degree), 0, 0],
200
- [0, 0, 1, 0],
201
- [0, 0, 0, 1],
202
- ]).unsqueeze(0).float()
203
- c2ws = torch.matmul(rot, c2ws)
204
-
205
- # rotate normals
206
- N, _, H, W = normals.shape
207
- normals = normals * 2.0 - 1.0
208
- normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
209
- normals = F.normalize(normals, dim=1)
210
- normals = (normals + 1.0) / 2.0
211
- normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
212
-
213
- # random scaling
214
- if np.random.rand() < 0.5:
215
- scale = np.random.uniform(0.8, 1.0)
216
- c2ws[:, :3, 3] *= scale
217
- depths *= scale
218
-
219
- # instrinsics of perspective cameras
220
- K = FOV_to_intrinsics(self.fov)
221
- Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
222
-
223
- data = {
224
- 'input_images': images[:self.input_view_num], # (6, 3, H, W)
225
- 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
226
- 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
227
- 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
228
- 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
229
- 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
230
-
231
- # lrm generator input and supervision
232
- 'target_images': images[self.input_view_num:], # (V, 3, H, W)
233
- 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
234
- 'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
235
- 'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
236
- 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
237
- 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
238
-
239
- 'depth_available': 1,
240
- }
241
- return data
242
-
243
-
244
- class ValidationData(Dataset):
245
- def __init__(self,
246
- root_dir='objaverse/',
247
- input_view_num=6,
248
- input_image_size=256,
249
- fov=50,
250
- ):
251
- self.root_dir = Path(root_dir)
252
- self.input_view_num = input_view_num
253
- self.input_image_size = input_image_size
254
- self.fov = fov
255
-
256
- self.paths = sorted(os.listdir(self.root_dir))
257
- print('============= length of dataset %d =============' % len(self.paths))
258
-
259
- cam_distance = 2.5
260
- azimuths = np.array([30, 90, 150, 210, 270, 330])
261
- elevations = np.array([30, -20, 30, -20, 30, -20])
262
- azimuths = np.deg2rad(azimuths)
263
- elevations = np.deg2rad(elevations)
264
-
265
- x = cam_distance * np.cos(elevations) * np.cos(azimuths)
266
- y = cam_distance * np.cos(elevations) * np.sin(azimuths)
267
- z = cam_distance * np.sin(elevations)
268
-
269
- cam_locations = np.stack([x, y, z], axis=-1)
270
- cam_locations = torch.from_numpy(cam_locations).float()
271
- c2ws = center_looking_at_camera_pose(cam_locations)
272
- self.c2ws = c2ws.float()
273
- self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
274
-
275
- render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
276
- render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
277
- self.render_c2ws = render_c2ws.float()
278
- self.render_Ks = render_Ks.float()
279
-
280
- def __len__(self):
281
- return len(self.paths)
282
-
283
- def load_im(self, path, color):
284
- '''
285
- replace background pixel with random color in rendering
286
- '''
287
- pil_img = Image.open(path)
288
- pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
289
-
290
- image = np.asarray(pil_img, dtype=np.float32) / 255.
291
- if image.shape[-1] == 4:
292
- alpha = image[:, :, 3:]
293
- image = image[:, :, :3] * alpha + color * (1 - alpha)
294
- else:
295
- alpha = np.ones_like(image[:, :, :1])
296
-
297
- image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
298
- alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
299
- return image, alpha
300
-
301
- def __getitem__(self, index):
302
- # load data
303
- input_image_path = os.path.join(self.root_dir, self.paths[index])
304
-
305
- '''background color, default: white'''
306
- # color = np.random.uniform(0.48, 0.52)
307
- bkg_color = [1.0, 1.0, 1.0]
308
-
309
- image_list = []
310
- alpha_list = []
311
-
312
- for idx in range(self.input_view_num):
313
- image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
314
- image_list.append(image)
315
- alpha_list.append(alpha)
316
-
317
- images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
318
- alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
319
-
320
- data = {
321
- 'input_images': images, # (6, 3, H, W)
322
- 'input_alphas': alphas, # (6, 1, H, W)
323
- 'input_c2ws': self.c2ws, # (6, 4, 4)
324
- 'input_Ks': self.Ks, # (6, 3, 3)
325
-
326
- 'render_c2ws': self.render_c2ws,
327
- 'render_Ks': self.render_Ks,
328
- }
329
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model.py DELETED
@@ -1,310 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
- from torchvision.transforms import v2
6
- from torchvision.utils import make_grid, save_image
7
- from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
- import pytorch_lightning as pl
9
- from einops import rearrange, repeat
10
-
11
- from src.utils.train_util import instantiate_from_config
12
-
13
-
14
- class MVRecon(pl.LightningModule):
15
- def __init__(
16
- self,
17
- lrm_generator_config,
18
- lrm_path=None,
19
- input_size=256,
20
- render_size=192,
21
- ):
22
- super(MVRecon, self).__init__()
23
-
24
- self.input_size = input_size
25
- self.render_size = render_size
26
-
27
- # init modules
28
- self.lrm_generator = instantiate_from_config(lrm_generator_config)
29
- if lrm_path is not None:
30
- lrm_ckpt = torch.load(lrm_path)
31
- self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
32
-
33
- self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
34
-
35
- self.validation_step_outputs = []
36
-
37
- def on_fit_start(self):
38
- if self.global_rank == 0:
39
- os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
40
- os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
41
-
42
- def prepare_batch_data(self, batch):
43
- lrm_generator_input = {}
44
- render_gt = {} # for supervision
45
-
46
- # input images
47
- images = batch['input_images']
48
- images = v2.functional.resize(
49
- images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
50
-
51
- lrm_generator_input['images'] = images.to(self.device)
52
-
53
- # input cameras and render cameras
54
- input_c2ws = batch['input_c2ws'].flatten(-2)
55
- input_Ks = batch['input_Ks'].flatten(-2)
56
- target_c2ws = batch['target_c2ws'].flatten(-2)
57
- target_Ks = batch['target_Ks'].flatten(-2)
58
- render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
59
- render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
60
- render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
61
-
62
- input_extrinsics = input_c2ws[:, :, :12]
63
- input_intrinsics = torch.stack([
64
- input_Ks[:, :, 0], input_Ks[:, :, 4],
65
- input_Ks[:, :, 2], input_Ks[:, :, 5],
66
- ], dim=-1)
67
- cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
68
-
69
- # add noise to input cameras
70
- cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
71
-
72
- lrm_generator_input['cameras'] = cameras.to(self.device)
73
- lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
74
-
75
- # target images
76
- target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
77
- target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
78
- target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
79
-
80
- # random crop
81
- render_size = np.random.randint(self.render_size, 513)
82
- target_images = v2.functional.resize(
83
- target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
84
- target_depths = v2.functional.resize(
85
- target_depths, render_size, interpolation=0, antialias=True)
86
- target_alphas = v2.functional.resize(
87
- target_alphas, render_size, interpolation=0, antialias=True)
88
-
89
- crop_params = v2.RandomCrop.get_params(
90
- target_images, output_size=(self.render_size, self.render_size))
91
- target_images = v2.functional.crop(target_images, *crop_params)
92
- target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
93
- target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
94
-
95
- lrm_generator_input['render_size'] = render_size
96
- lrm_generator_input['crop_params'] = crop_params
97
-
98
- render_gt['target_images'] = target_images.to(self.device)
99
- render_gt['target_depths'] = target_depths.to(self.device)
100
- render_gt['target_alphas'] = target_alphas.to(self.device)
101
-
102
- return lrm_generator_input, render_gt
103
-
104
- def prepare_validation_batch_data(self, batch):
105
- lrm_generator_input = {}
106
-
107
- # input images
108
- images = batch['input_images']
109
- images = v2.functional.resize(
110
- images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
111
-
112
- lrm_generator_input['images'] = images.to(self.device)
113
-
114
- input_c2ws = batch['input_c2ws'].flatten(-2)
115
- input_Ks = batch['input_Ks'].flatten(-2)
116
-
117
- input_extrinsics = input_c2ws[:, :, :12]
118
- input_intrinsics = torch.stack([
119
- input_Ks[:, :, 0], input_Ks[:, :, 4],
120
- input_Ks[:, :, 2], input_Ks[:, :, 5],
121
- ], dim=-1)
122
- cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
123
-
124
- lrm_generator_input['cameras'] = cameras.to(self.device)
125
-
126
- render_c2ws = batch['render_c2ws'].flatten(-2)
127
- render_Ks = batch['render_Ks'].flatten(-2)
128
- render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
129
-
130
- lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
131
- lrm_generator_input['render_size'] = 384
132
- lrm_generator_input['crop_params'] = None
133
-
134
- return lrm_generator_input
135
-
136
- def forward_lrm_generator(
137
- self,
138
- images,
139
- cameras,
140
- render_cameras,
141
- render_size=192,
142
- crop_params=None,
143
- chunk_size=1,
144
- ):
145
- planes = torch.utils.checkpoint.checkpoint(
146
- self.lrm_generator.forward_planes,
147
- images,
148
- cameras,
149
- use_reentrant=False,
150
- )
151
- frames = []
152
- for i in range(0, render_cameras.shape[1], chunk_size):
153
- frames.append(
154
- torch.utils.checkpoint.checkpoint(
155
- self.lrm_generator.synthesizer,
156
- planes,
157
- cameras=render_cameras[:, i:i+chunk_size],
158
- render_size=render_size,
159
- crop_params=crop_params,
160
- use_reentrant=False
161
- )
162
- )
163
- frames = {
164
- k: torch.cat([r[k] for r in frames], dim=1)
165
- for k in frames[0].keys()
166
- }
167
- return frames
168
-
169
- def forward(self, lrm_generator_input):
170
- images = lrm_generator_input['images']
171
- cameras = lrm_generator_input['cameras']
172
- render_cameras = lrm_generator_input['render_cameras']
173
- render_size = lrm_generator_input['render_size']
174
- crop_params = lrm_generator_input['crop_params']
175
-
176
- out = self.forward_lrm_generator(
177
- images,
178
- cameras,
179
- render_cameras,
180
- render_size=render_size,
181
- crop_params=crop_params,
182
- chunk_size=1,
183
- )
184
- render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
185
- render_depths = out['images_depth']
186
- render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
187
-
188
- out = {
189
- 'render_images': render_images,
190
- 'render_depths': render_depths,
191
- 'render_alphas': render_alphas,
192
- }
193
- return out
194
-
195
- def training_step(self, batch, batch_idx):
196
- lrm_generator_input, render_gt = self.prepare_batch_data(batch)
197
-
198
- render_out = self.forward(lrm_generator_input)
199
-
200
- loss, loss_dict = self.compute_loss(render_out, render_gt)
201
-
202
- self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
203
-
204
- if self.global_step % 1000 == 0 and self.global_rank == 0:
205
- B, N, C, H, W = render_gt['target_images'].shape
206
- N_in = lrm_generator_input['images'].shape[1]
207
-
208
- input_images = v2.functional.resize(
209
- lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
210
- input_images = torch.cat(
211
- [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
212
-
213
- input_images = rearrange(
214
- input_images, 'b n c h w -> b c h (n w)')
215
- target_images = rearrange(
216
- render_gt['target_images'], 'b n c h w -> b c h (n w)')
217
- render_images = rearrange(
218
- render_out['render_images'], 'b n c h w -> b c h (n w)')
219
- target_alphas = rearrange(
220
- repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
- render_alphas = rearrange(
222
- repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
223
- target_depths = rearrange(
224
- repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
225
- render_depths = rearrange(
226
- repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
227
- MAX_DEPTH = torch.max(target_depths)
228
- target_depths = target_depths / MAX_DEPTH * target_alphas
229
- render_depths = render_depths / MAX_DEPTH
230
-
231
- grid = torch.cat([
232
- input_images,
233
- target_images, render_images,
234
- target_alphas, render_alphas,
235
- target_depths, render_depths,
236
- ], dim=-2)
237
- grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
238
-
239
- save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
240
-
241
- return loss
242
-
243
- def compute_loss(self, render_out, render_gt):
244
- # NOTE: the rgb value range of OpenLRM is [0, 1]
245
- render_images = render_out['render_images']
246
- target_images = render_gt['target_images'].to(render_images)
247
- render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
- target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
-
250
- loss_mse = F.mse_loss(render_images, target_images)
251
- loss_lpips = 2.0 * self.lpips(render_images, target_images)
252
-
253
- render_alphas = render_out['render_alphas']
254
- target_alphas = render_gt['target_alphas']
255
- loss_mask = F.mse_loss(render_alphas, target_alphas)
256
-
257
- loss = loss_mse + loss_lpips + loss_mask
258
-
259
- prefix = 'train'
260
- loss_dict = {}
261
- loss_dict.update({f'{prefix}/loss_mse': loss_mse})
262
- loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
263
- loss_dict.update({f'{prefix}/loss_mask': loss_mask})
264
- loss_dict.update({f'{prefix}/loss': loss})
265
-
266
- return loss, loss_dict
267
-
268
- @torch.no_grad()
269
- def validation_step(self, batch, batch_idx):
270
- lrm_generator_input = self.prepare_validation_batch_data(batch)
271
-
272
- render_out = self.forward(lrm_generator_input)
273
- render_images = render_out['render_images']
274
- render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
275
-
276
- self.validation_step_outputs.append(render_images)
277
-
278
- def on_validation_epoch_end(self):
279
- images = torch.cat(self.validation_step_outputs, dim=-1)
280
-
281
- all_images = self.all_gather(images)
282
- all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
283
-
284
- if self.global_rank == 0:
285
- image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
286
-
287
- grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
288
- save_image(grid, image_path)
289
- print(f"Saved image to {image_path}")
290
-
291
- self.validation_step_outputs.clear()
292
-
293
- def configure_optimizers(self):
294
- lr = self.learning_rate
295
-
296
- params = []
297
-
298
- lrm_params_fast, lrm_params_slow = [], []
299
- for n, p in self.lrm_generator.named_parameters():
300
- if 'adaLN_modulation' in n or 'camera_embedder' in n:
301
- lrm_params_fast.append(p)
302
- else:
303
- lrm_params_slow.append(p)
304
- params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
305
- params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
306
-
307
- optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
308
- scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
309
-
310
- return {'optimizer': optimizer, 'lr_scheduler': scheduler}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model_mesh.py DELETED
@@ -1,325 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
- from torchvision.transforms import v2
6
- from torchvision.utils import make_grid, save_image
7
- from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
- import pytorch_lightning as pl
9
- from einops import rearrange, repeat
10
-
11
- from src.utils.train_util import instantiate_from_config
12
-
13
-
14
- # Regulrarization loss for FlexiCubes
15
- def sdf_reg_loss_batch(sdf, all_edges):
16
- sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
17
- mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
18
- sdf_f1x6x2 = sdf_f1x6x2[mask]
19
- sdf_diff = F.binary_cross_entropy_with_logits(
20
- sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
21
- F.binary_cross_entropy_with_logits(
22
- sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
23
- return sdf_diff
24
-
25
-
26
- class MVRecon(pl.LightningModule):
27
- def __init__(
28
- self,
29
- lrm_generator_config,
30
- input_size=256,
31
- render_size=512,
32
- init_ckpt=None,
33
- ):
34
- super(MVRecon, self).__init__()
35
-
36
- self.input_size = input_size
37
- self.render_size = render_size
38
-
39
- # init modules
40
- self.lrm_generator = instantiate_from_config(lrm_generator_config)
41
-
42
- self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
43
-
44
- # Load weights from pretrained MVRecon model, and use the mlp
45
- # weights to initialize the weights of sdf and rgb mlps.
46
- if init_ckpt is not None:
47
- sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
48
- sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
49
- sd_fc = {}
50
- for k, v in sd.items():
51
- if k.startswith('lrm_generator.synthesizer.decoder.net.'):
52
- if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
53
- # Here we assume the density filed's isosurface threshold is t,
54
- # we reverse the sign of density filed to initialize SDF field.
55
- # -(w*x + b - t) = (-w)*x + (t - b)
56
- if 'weight' in k:
57
- sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
58
- else:
59
- sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
60
- sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
61
- else:
62
- sd_fc[k.replace('net.', 'net_sdf.')] = v
63
- sd_fc[k.replace('net.', 'net_rgb.')] = v
64
- else:
65
- sd_fc[k] = v
66
- sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
67
- # missing `net_deformation` and `net_weight` parameters
68
- self.lrm_generator.load_state_dict(sd_fc, strict=False)
69
- print(f'Loaded weights from {init_ckpt}')
70
-
71
- self.validation_step_outputs = []
72
-
73
- def on_fit_start(self):
74
- device = torch.device(f'cuda:{self.global_rank}')
75
- self.lrm_generator.init_flexicubes_geometry(device)
76
- if self.global_rank == 0:
77
- os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
78
- os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
79
-
80
- def prepare_batch_data(self, batch):
81
- lrm_generator_input = {}
82
- render_gt = {}
83
-
84
- # input images
85
- images = batch['input_images']
86
- images = v2.functional.resize(
87
- images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
88
-
89
- lrm_generator_input['images'] = images.to(self.device)
90
-
91
- # input cameras and render cameras
92
- input_c2ws = batch['input_c2ws']
93
- input_Ks = batch['input_Ks']
94
- target_c2ws = batch['target_c2ws']
95
-
96
- render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
97
- render_w2cs = torch.linalg.inv(render_c2ws)
98
-
99
- input_extrinsics = input_c2ws.flatten(-2)
100
- input_extrinsics = input_extrinsics[:, :, :12]
101
- input_intrinsics = input_Ks.flatten(-2)
102
- input_intrinsics = torch.stack([
103
- input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
104
- input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
105
- ], dim=-1)
106
- cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
107
-
108
- # add noise to input_cameras
109
- cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
110
-
111
- lrm_generator_input['cameras'] = cameras.to(self.device)
112
- lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
113
-
114
- # target images
115
- target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
116
- target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
117
- target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
118
- target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
119
-
120
- render_size = self.render_size
121
- target_images = v2.functional.resize(
122
- target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
123
- target_depths = v2.functional.resize(
124
- target_depths, render_size, interpolation=0, antialias=True)
125
- target_alphas = v2.functional.resize(
126
- target_alphas, render_size, interpolation=0, antialias=True)
127
- target_normals = v2.functional.resize(
128
- target_normals, render_size, interpolation=3, antialias=True)
129
-
130
- lrm_generator_input['render_size'] = render_size
131
-
132
- render_gt['target_images'] = target_images.to(self.device)
133
- render_gt['target_depths'] = target_depths.to(self.device)
134
- render_gt['target_alphas'] = target_alphas.to(self.device)
135
- render_gt['target_normals'] = target_normals.to(self.device)
136
-
137
- return lrm_generator_input, render_gt
138
-
139
- def prepare_validation_batch_data(self, batch):
140
- lrm_generator_input = {}
141
-
142
- # input images
143
- images = batch['input_images']
144
- images = v2.functional.resize(
145
- images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
146
-
147
- lrm_generator_input['images'] = images.to(self.device)
148
-
149
- # input cameras
150
- input_c2ws = batch['input_c2ws'].flatten(-2)
151
- input_Ks = batch['input_Ks'].flatten(-2)
152
-
153
- input_extrinsics = input_c2ws[:, :, :12]
154
- input_intrinsics = torch.stack([
155
- input_Ks[:, :, 0], input_Ks[:, :, 4],
156
- input_Ks[:, :, 2], input_Ks[:, :, 5],
157
- ], dim=-1)
158
- cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
159
-
160
- lrm_generator_input['cameras'] = cameras.to(self.device)
161
-
162
- # render cameras
163
- render_c2ws = batch['render_c2ws']
164
- render_w2cs = torch.linalg.inv(render_c2ws)
165
-
166
- lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
167
- lrm_generator_input['render_size'] = 384
168
-
169
- return lrm_generator_input
170
-
171
- def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
172
- planes = torch.utils.checkpoint.checkpoint(
173
- self.lrm_generator.forward_planes,
174
- images,
175
- cameras,
176
- use_reentrant=False,
177
- )
178
- out = self.lrm_generator.forward_geometry(
179
- planes,
180
- render_cameras,
181
- render_size,
182
- )
183
- return out
184
-
185
- def forward(self, lrm_generator_input):
186
- images = lrm_generator_input['images']
187
- cameras = lrm_generator_input['cameras']
188
- render_cameras = lrm_generator_input['render_cameras']
189
- render_size = lrm_generator_input['render_size']
190
-
191
- out = self.forward_lrm_generator(
192
- images, cameras, render_cameras, render_size=render_size)
193
-
194
- return out
195
-
196
- def training_step(self, batch, batch_idx):
197
- lrm_generator_input, render_gt = self.prepare_batch_data(batch)
198
-
199
- render_out = self.forward(lrm_generator_input)
200
-
201
- loss, loss_dict = self.compute_loss(render_out, render_gt)
202
-
203
- self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
204
-
205
- if self.global_step % 1000 == 0 and self.global_rank == 0:
206
- B, N, C, H, W = render_gt['target_images'].shape
207
- N_in = lrm_generator_input['images'].shape[1]
208
-
209
- target_images = rearrange(
210
- render_gt['target_images'], 'b n c h w -> b c h (n w)')
211
- render_images = rearrange(
212
- render_out['img'], 'b n c h w -> b c h (n w)')
213
- target_alphas = rearrange(
214
- repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
215
- render_alphas = rearrange(
216
- repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
217
- target_depths = rearrange(
218
- repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
219
- render_depths = rearrange(
220
- repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
- target_normals = rearrange(
222
- render_gt['target_normals'], 'b n c h w -> b c h (n w)')
223
- render_normals = rearrange(
224
- render_out['normal'], 'b n c h w -> b c h (n w)')
225
- MAX_DEPTH = torch.max(target_depths)
226
- target_depths = target_depths / MAX_DEPTH * target_alphas
227
- render_depths = render_depths / MAX_DEPTH
228
-
229
- grid = torch.cat([
230
- target_images, render_images,
231
- target_alphas, render_alphas,
232
- target_depths, render_depths,
233
- target_normals, render_normals,
234
- ], dim=-2)
235
- grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
236
-
237
- image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
238
- save_image(grid, image_path)
239
- print(f"Saved image to {image_path}")
240
-
241
- return loss
242
-
243
- def compute_loss(self, render_out, render_gt):
244
- # NOTE: the rgb value range of OpenLRM is [0, 1]
245
- render_images = render_out['img']
246
- target_images = render_gt['target_images'].to(render_images)
247
- render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
- target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
- loss_mse = F.mse_loss(render_images, target_images)
250
- loss_lpips = 2.0 * self.lpips(render_images, target_images)
251
-
252
- render_alphas = render_out['mask']
253
- target_alphas = render_gt['target_alphas']
254
- loss_mask = F.mse_loss(render_alphas, target_alphas)
255
-
256
- render_depths = render_out['depth']
257
- target_depths = render_gt['target_depths']
258
- loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
259
-
260
- render_normals = render_out['normal'] * 2.0 - 1.0
261
- target_normals = render_gt['target_normals'] * 2.0 - 1.0
262
- similarity = (render_normals * target_normals).sum(dim=-3).abs()
263
- normal_mask = target_alphas.squeeze(-3)
264
- loss_normal = 1 - similarity[normal_mask>0].mean()
265
- loss_normal = 0.2 * loss_normal
266
-
267
- # flexicubes regularization loss
268
- sdf = render_out['sdf']
269
- sdf_reg_loss = render_out['sdf_reg_loss']
270
- sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
271
- _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
272
- flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
273
- flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
274
-
275
- loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
276
-
277
- loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
278
-
279
- prefix = 'train'
280
- loss_dict = {}
281
- loss_dict.update({f'{prefix}/loss_mse': loss_mse})
282
- loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
283
- loss_dict.update({f'{prefix}/loss_mask': loss_mask})
284
- loss_dict.update({f'{prefix}/loss_normal': loss_normal})
285
- loss_dict.update({f'{prefix}/loss_depth': loss_depth})
286
- loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
287
- loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
288
- loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
289
- loss_dict.update({f'{prefix}/loss': loss})
290
-
291
- return loss, loss_dict
292
-
293
- @torch.no_grad()
294
- def validation_step(self, batch, batch_idx):
295
- lrm_generator_input = self.prepare_validation_batch_data(batch)
296
-
297
- render_out = self.forward(lrm_generator_input)
298
- render_images = render_out['img']
299
- render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
300
-
301
- self.validation_step_outputs.append(render_images)
302
-
303
- def on_validation_epoch_end(self):
304
- images = torch.cat(self.validation_step_outputs, dim=-1)
305
-
306
- all_images = self.all_gather(images)
307
- all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
308
-
309
- if self.global_rank == 0:
310
- image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
311
-
312
- grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
313
- save_image(grid, image_path)
314
- print(f"Saved image to {image_path}")
315
-
316
- self.validation_step_outputs.clear()
317
-
318
- def configure_optimizers(self):
319
- lr = self.learning_rate
320
-
321
- optimizer = torch.optim.AdamW(
322
- self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
323
- scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
324
-
325
- return {'optimizer': optimizer, 'lr_scheduler': scheduler}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/__init__.py DELETED
File without changes
src/models/decoder/__init__.py DELETED
File without changes
src/models/decoder/transformer.py DELETED
@@ -1,123 +0,0 @@
1
- # Copyright (c) 2023, Zexin He
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
-
20
- class BasicTransformerBlock(nn.Module):
21
- """
22
- Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
23
- """
24
- # use attention from torch.nn.MultiHeadAttention
25
- # Block contains a cross-attention layer, a self-attention layer, and a MLP
26
- def __init__(
27
- self,
28
- inner_dim: int,
29
- cond_dim: int,
30
- num_heads: int,
31
- eps: float,
32
- attn_drop: float = 0.,
33
- attn_bias: bool = False,
34
- mlp_ratio: float = 4.,
35
- mlp_drop: float = 0.,
36
- ):
37
- super().__init__()
38
-
39
- self.norm1 = nn.LayerNorm(inner_dim)
40
- self.cross_attn = nn.MultiheadAttention(
41
- embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
42
- dropout=attn_drop, bias=attn_bias, batch_first=True)
43
- self.norm2 = nn.LayerNorm(inner_dim)
44
- self.self_attn = nn.MultiheadAttention(
45
- embed_dim=inner_dim, num_heads=num_heads,
46
- dropout=attn_drop, bias=attn_bias, batch_first=True)
47
- self.norm3 = nn.LayerNorm(inner_dim)
48
- self.mlp = nn.Sequential(
49
- nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
50
- nn.GELU(),
51
- nn.Dropout(mlp_drop),
52
- nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
53
- nn.Dropout(mlp_drop),
54
- )
55
-
56
- def forward(self, x, cond):
57
- # x: [N, L, D]
58
- # cond: [N, L_cond, D_cond]
59
- x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
60
- before_sa = self.norm2(x)
61
- x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
62
- x = x + self.mlp(self.norm3(x))
63
- return x
64
-
65
-
66
- class TriplaneTransformer(nn.Module):
67
- """
68
- Transformer with condition that generates a triplane representation.
69
-
70
- Reference:
71
- Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
72
- """
73
- def __init__(
74
- self,
75
- inner_dim: int,
76
- image_feat_dim: int,
77
- triplane_low_res: int,
78
- triplane_high_res: int,
79
- triplane_dim: int,
80
- num_layers: int,
81
- num_heads: int,
82
- eps: float = 1e-6,
83
- ):
84
- super().__init__()
85
-
86
- # attributes
87
- self.triplane_low_res = triplane_low_res
88
- self.triplane_high_res = triplane_high_res
89
- self.triplane_dim = triplane_dim
90
-
91
- # modules
92
- # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
93
- self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
94
- self.layers = nn.ModuleList([
95
- BasicTransformerBlock(
96
- inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
97
- for _ in range(num_layers)
98
- ])
99
- self.norm = nn.LayerNorm(inner_dim, eps=eps)
100
- self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
101
-
102
- def forward(self, image_feats):
103
- # image_feats: [N, L_cond, D_cond]
104
-
105
- N = image_feats.shape[0]
106
- H = W = self.triplane_low_res
107
- L = 3 * H * W
108
-
109
- x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
110
- for layer in self.layers:
111
- x = layer(x, image_feats)
112
- x = self.norm(x)
113
-
114
- # separate each plane and apply deconv
115
- x = x.view(N, 3, H, W, -1)
116
- x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
117
- x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
118
- x = self.deconv(x) # [3*N, D', H', W']
119
- x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
120
- x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
121
- x = x.contiguous()
122
-
123
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/encoder/__init__.py DELETED
File without changes
src/models/encoder/dino.py DELETED
@@ -1,550 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch ViT model."""
16
-
17
-
18
- import collections.abc
19
- import math
20
- from typing import Dict, List, Optional, Set, Tuple, Union
21
-
22
- import torch
23
- from torch import nn
24
-
25
- from transformers.activations import ACT2FN
26
- from transformers.modeling_outputs import (
27
- BaseModelOutput,
28
- BaseModelOutputWithPooling,
29
- )
30
- from transformers import PreTrainedModel, ViTConfig
31
- from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
-
33
-
34
- class ViTEmbeddings(nn.Module):
35
- """
36
- Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
37
- """
38
-
39
- def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
40
- super().__init__()
41
-
42
- self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
43
- self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
44
- self.patch_embeddings = ViTPatchEmbeddings(config)
45
- num_patches = self.patch_embeddings.num_patches
46
- self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
47
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
48
- self.config = config
49
-
50
- def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
51
- """
52
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
53
- resolution images.
54
-
55
- Source:
56
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
57
- """
58
-
59
- num_patches = embeddings.shape[1] - 1
60
- num_positions = self.position_embeddings.shape[1] - 1
61
- if num_patches == num_positions and height == width:
62
- return self.position_embeddings
63
- class_pos_embed = self.position_embeddings[:, 0]
64
- patch_pos_embed = self.position_embeddings[:, 1:]
65
- dim = embeddings.shape[-1]
66
- h0 = height // self.config.patch_size
67
- w0 = width // self.config.patch_size
68
- # we add a small number to avoid floating point error in the interpolation
69
- # see discussion at https://github.com/facebookresearch/dino/issues/8
70
- h0, w0 = h0 + 0.1, w0 + 0.1
71
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
72
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
73
- patch_pos_embed = nn.functional.interpolate(
74
- patch_pos_embed,
75
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
76
- mode="bicubic",
77
- align_corners=False,
78
- )
79
- assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
80
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
-
83
- def forward(
84
- self,
85
- pixel_values: torch.Tensor,
86
- bool_masked_pos: Optional[torch.BoolTensor] = None,
87
- interpolate_pos_encoding: bool = False,
88
- ) -> torch.Tensor:
89
- batch_size, num_channels, height, width = pixel_values.shape
90
- embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
91
-
92
- if bool_masked_pos is not None:
93
- seq_length = embeddings.shape[1]
94
- mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
95
- # replace the masked visual tokens by mask_tokens
96
- mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
97
- embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
98
-
99
- # add the [CLS] token to the embedded patch tokens
100
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
101
- embeddings = torch.cat((cls_tokens, embeddings), dim=1)
102
-
103
- # add positional encoding to each token
104
- if interpolate_pos_encoding:
105
- embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
106
- else:
107
- embeddings = embeddings + self.position_embeddings
108
-
109
- embeddings = self.dropout(embeddings)
110
-
111
- return embeddings
112
-
113
-
114
- class ViTPatchEmbeddings(nn.Module):
115
- """
116
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
117
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
118
- Transformer.
119
- """
120
-
121
- def __init__(self, config):
122
- super().__init__()
123
- image_size, patch_size = config.image_size, config.patch_size
124
- num_channels, hidden_size = config.num_channels, config.hidden_size
125
-
126
- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
- self.image_size = image_size
130
- self.patch_size = patch_size
131
- self.num_channels = num_channels
132
- self.num_patches = num_patches
133
-
134
- self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
135
-
136
- def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
137
- batch_size, num_channels, height, width = pixel_values.shape
138
- if num_channels != self.num_channels:
139
- raise ValueError(
140
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
- f" Expected {self.num_channels} but got {num_channels}."
142
- )
143
- if not interpolate_pos_encoding:
144
- if height != self.image_size[0] or width != self.image_size[1]:
145
- raise ValueError(
146
- f"Input image size ({height}*{width}) doesn't match model"
147
- f" ({self.image_size[0]}*{self.image_size[1]})."
148
- )
149
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
150
- return embeddings
151
-
152
-
153
- class ViTSelfAttention(nn.Module):
154
- def __init__(self, config: ViTConfig) -> None:
155
- super().__init__()
156
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
- raise ValueError(
158
- f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
159
- f"heads {config.num_attention_heads}."
160
- )
161
-
162
- self.num_attention_heads = config.num_attention_heads
163
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
- self.all_head_size = self.num_attention_heads * self.attention_head_size
165
-
166
- self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
167
- self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
168
- self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
169
-
170
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
-
172
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
173
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
174
- x = x.view(new_x_shape)
175
- return x.permute(0, 2, 1, 3)
176
-
177
- def forward(
178
- self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
179
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
180
- mixed_query_layer = self.query(hidden_states)
181
-
182
- key_layer = self.transpose_for_scores(self.key(hidden_states))
183
- value_layer = self.transpose_for_scores(self.value(hidden_states))
184
- query_layer = self.transpose_for_scores(mixed_query_layer)
185
-
186
- # Take the dot product between "query" and "key" to get the raw attention scores.
187
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
188
-
189
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
190
-
191
- # Normalize the attention scores to probabilities.
192
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
193
-
194
- # This is actually dropping out entire tokens to attend to, which might
195
- # seem a bit unusual, but is taken from the original Transformer paper.
196
- attention_probs = self.dropout(attention_probs)
197
-
198
- # Mask heads if we want to
199
- if head_mask is not None:
200
- attention_probs = attention_probs * head_mask
201
-
202
- context_layer = torch.matmul(attention_probs, value_layer)
203
-
204
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
- context_layer = context_layer.view(new_context_layer_shape)
207
-
208
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
-
210
- return outputs
211
-
212
-
213
- class ViTSelfOutput(nn.Module):
214
- """
215
- The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
- layernorm applied before each block.
217
- """
218
-
219
- def __init__(self, config: ViTConfig) -> None:
220
- super().__init__()
221
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
-
224
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
- hidden_states = self.dense(hidden_states)
226
- hidden_states = self.dropout(hidden_states)
227
-
228
- return hidden_states
229
-
230
-
231
- class ViTAttention(nn.Module):
232
- def __init__(self, config: ViTConfig) -> None:
233
- super().__init__()
234
- self.attention = ViTSelfAttention(config)
235
- self.output = ViTSelfOutput(config)
236
- self.pruned_heads = set()
237
-
238
- def prune_heads(self, heads: Set[int]) -> None:
239
- if len(heads) == 0:
240
- return
241
- heads, index = find_pruneable_heads_and_indices(
242
- heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
- )
244
-
245
- # Prune linear layers
246
- self.attention.query = prune_linear_layer(self.attention.query, index)
247
- self.attention.key = prune_linear_layer(self.attention.key, index)
248
- self.attention.value = prune_linear_layer(self.attention.value, index)
249
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
-
251
- # Update hyper params and store pruned heads
252
- self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
- self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
- self.pruned_heads = self.pruned_heads.union(heads)
255
-
256
- def forward(
257
- self,
258
- hidden_states: torch.Tensor,
259
- head_mask: Optional[torch.Tensor] = None,
260
- output_attentions: bool = False,
261
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
- self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
-
264
- attention_output = self.output(self_outputs[0], hidden_states)
265
-
266
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
- return outputs
268
-
269
-
270
- class ViTIntermediate(nn.Module):
271
- def __init__(self, config: ViTConfig) -> None:
272
- super().__init__()
273
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
- if isinstance(config.hidden_act, str):
275
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
- else:
277
- self.intermediate_act_fn = config.hidden_act
278
-
279
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
- hidden_states = self.dense(hidden_states)
281
- hidden_states = self.intermediate_act_fn(hidden_states)
282
-
283
- return hidden_states
284
-
285
-
286
- class ViTOutput(nn.Module):
287
- def __init__(self, config: ViTConfig) -> None:
288
- super().__init__()
289
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
-
292
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
- hidden_states = self.dense(hidden_states)
294
- hidden_states = self.dropout(hidden_states)
295
-
296
- hidden_states = hidden_states + input_tensor
297
-
298
- return hidden_states
299
-
300
-
301
- def modulate(x, shift, scale):
302
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
-
304
-
305
- class ViTLayer(nn.Module):
306
- """This corresponds to the Block class in the timm implementation."""
307
-
308
- def __init__(self, config: ViTConfig) -> None:
309
- super().__init__()
310
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
- self.seq_len_dim = 1
312
- self.attention = ViTAttention(config)
313
- self.intermediate = ViTIntermediate(config)
314
- self.output = ViTOutput(config)
315
- self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
- self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
-
318
- self.adaLN_modulation = nn.Sequential(
319
- nn.SiLU(),
320
- nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
- )
322
- nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
- nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
-
325
- def forward(
326
- self,
327
- hidden_states: torch.Tensor,
328
- adaln_input: torch.Tensor = None,
329
- head_mask: Optional[torch.Tensor] = None,
330
- output_attentions: bool = False,
331
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
- shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
-
334
- self_attention_outputs = self.attention(
335
- modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
- head_mask,
337
- output_attentions=output_attentions,
338
- )
339
- attention_output = self_attention_outputs[0]
340
- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
-
342
- # first residual connection
343
- hidden_states = attention_output + hidden_states
344
-
345
- # in ViT, layernorm is also applied after self-attention
346
- layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
- layer_output = self.intermediate(layer_output)
348
-
349
- # second residual connection is done here
350
- layer_output = self.output(layer_output, hidden_states)
351
-
352
- outputs = (layer_output,) + outputs
353
-
354
- return outputs
355
-
356
-
357
- class ViTEncoder(nn.Module):
358
- def __init__(self, config: ViTConfig) -> None:
359
- super().__init__()
360
- self.config = config
361
- self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
- self.gradient_checkpointing = False
363
-
364
- def forward(
365
- self,
366
- hidden_states: torch.Tensor,
367
- adaln_input: torch.Tensor = None,
368
- head_mask: Optional[torch.Tensor] = None,
369
- output_attentions: bool = False,
370
- output_hidden_states: bool = False,
371
- return_dict: bool = True,
372
- ) -> Union[tuple, BaseModelOutput]:
373
- all_hidden_states = () if output_hidden_states else None
374
- all_self_attentions = () if output_attentions else None
375
-
376
- for i, layer_module in enumerate(self.layer):
377
- if output_hidden_states:
378
- all_hidden_states = all_hidden_states + (hidden_states,)
379
-
380
- layer_head_mask = head_mask[i] if head_mask is not None else None
381
-
382
- if self.gradient_checkpointing and self.training:
383
- layer_outputs = self._gradient_checkpointing_func(
384
- layer_module.__call__,
385
- hidden_states,
386
- adaln_input,
387
- layer_head_mask,
388
- output_attentions,
389
- )
390
- else:
391
- layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
-
393
- hidden_states = layer_outputs[0]
394
-
395
- if output_attentions:
396
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
-
398
- if output_hidden_states:
399
- all_hidden_states = all_hidden_states + (hidden_states,)
400
-
401
- if not return_dict:
402
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
- return BaseModelOutput(
404
- last_hidden_state=hidden_states,
405
- hidden_states=all_hidden_states,
406
- attentions=all_self_attentions,
407
- )
408
-
409
-
410
- class ViTPreTrainedModel(PreTrainedModel):
411
- """
412
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
- models.
414
- """
415
-
416
- config_class = ViTConfig
417
- base_model_prefix = "vit"
418
- main_input_name = "pixel_values"
419
- supports_gradient_checkpointing = True
420
- _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
-
422
- def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
- """Initialize the weights"""
424
- if isinstance(module, (nn.Linear, nn.Conv2d)):
425
- # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
- # `trunc_normal_cpu` not implemented in `half` issues
427
- module.weight.data = nn.init.trunc_normal_(
428
- module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
- ).to(module.weight.dtype)
430
- if module.bias is not None:
431
- module.bias.data.zero_()
432
- elif isinstance(module, nn.LayerNorm):
433
- module.bias.data.zero_()
434
- module.weight.data.fill_(1.0)
435
- elif isinstance(module, ViTEmbeddings):
436
- module.position_embeddings.data = nn.init.trunc_normal_(
437
- module.position_embeddings.data.to(torch.float32),
438
- mean=0.0,
439
- std=self.config.initializer_range,
440
- ).to(module.position_embeddings.dtype)
441
-
442
- module.cls_token.data = nn.init.trunc_normal_(
443
- module.cls_token.data.to(torch.float32),
444
- mean=0.0,
445
- std=self.config.initializer_range,
446
- ).to(module.cls_token.dtype)
447
-
448
-
449
- class ViTModel(ViTPreTrainedModel):
450
- def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
- super().__init__(config)
452
- self.config = config
453
-
454
- self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
- self.encoder = ViTEncoder(config)
456
-
457
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
- self.pooler = ViTPooler(config) if add_pooling_layer else None
459
-
460
- # Initialize weights and apply final processing
461
- self.post_init()
462
-
463
- def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
- return self.embeddings.patch_embeddings
465
-
466
- def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
- """
468
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
- class PreTrainedModel
470
- """
471
- for layer, heads in heads_to_prune.items():
472
- self.encoder.layer[layer].attention.prune_heads(heads)
473
-
474
- def forward(
475
- self,
476
- pixel_values: Optional[torch.Tensor] = None,
477
- adaln_input: Optional[torch.Tensor] = None,
478
- bool_masked_pos: Optional[torch.BoolTensor] = None,
479
- head_mask: Optional[torch.Tensor] = None,
480
- output_attentions: Optional[bool] = None,
481
- output_hidden_states: Optional[bool] = None,
482
- interpolate_pos_encoding: Optional[bool] = None,
483
- return_dict: Optional[bool] = None,
484
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
- r"""
486
- bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
- Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
- """
489
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
- output_hidden_states = (
491
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
- )
493
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
-
495
- if pixel_values is None:
496
- raise ValueError("You have to specify pixel_values")
497
-
498
- # Prepare head mask if needed
499
- # 1.0 in head_mask indicate we keep the head
500
- # attention_probs has shape bsz x n_heads x N x N
501
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
-
505
- # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
- expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
- if pixel_values.dtype != expected_dtype:
508
- pixel_values = pixel_values.to(expected_dtype)
509
-
510
- embedding_output = self.embeddings(
511
- pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
- )
513
-
514
- encoder_outputs = self.encoder(
515
- embedding_output,
516
- adaln_input=adaln_input,
517
- head_mask=head_mask,
518
- output_attentions=output_attentions,
519
- output_hidden_states=output_hidden_states,
520
- return_dict=return_dict,
521
- )
522
- sequence_output = encoder_outputs[0]
523
- sequence_output = self.layernorm(sequence_output)
524
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
-
526
- if not return_dict:
527
- head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
- return head_outputs + encoder_outputs[1:]
529
-
530
- return BaseModelOutputWithPooling(
531
- last_hidden_state=sequence_output,
532
- pooler_output=pooled_output,
533
- hidden_states=encoder_outputs.hidden_states,
534
- attentions=encoder_outputs.attentions,
535
- )
536
-
537
-
538
- class ViTPooler(nn.Module):
539
- def __init__(self, config: ViTConfig):
540
- super().__init__()
541
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
- self.activation = nn.Tanh()
543
-
544
- def forward(self, hidden_states):
545
- # We "pool" the model by simply taking the hidden state corresponding
546
- # to the first token.
547
- first_token_tensor = hidden_states[:, 0]
548
- pooled_output = self.dense(first_token_tensor)
549
- pooled_output = self.activation(pooled_output)
550
- return pooled_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/encoder/dino_wrapper.py DELETED
@@ -1,80 +0,0 @@
1
- # Copyright (c) 2023, Zexin He
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import torch.nn as nn
17
- from transformers import ViTImageProcessor
18
- from einops import rearrange, repeat
19
- from .dino import ViTModel
20
-
21
-
22
- class DinoWrapper(nn.Module):
23
- """
24
- Dino v1 wrapper using huggingface transformer implementation.
25
- """
26
- def __init__(self, model_name: str, freeze: bool = True):
27
- super().__init__()
28
- self.model, self.processor = self._build_dino(model_name)
29
- self.camera_embedder = nn.Sequential(
30
- nn.Linear(16, self.model.config.hidden_size, bias=True),
31
- nn.SiLU(),
32
- nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
33
- )
34
- if freeze:
35
- self._freeze()
36
-
37
- def forward(self, image, camera):
38
- # image: [B, N, C, H, W]
39
- # camera: [B, N, D]
40
- # RGB image with [0,1] scale and properly sized
41
- if image.ndim == 5:
42
- image = rearrange(image, 'b n c h w -> (b n) c h w')
43
- dtype = image.dtype
44
- inputs = self.processor(
45
- images=image.float(),
46
- return_tensors="pt",
47
- do_rescale=False,
48
- do_resize=False,
49
- ).to(self.model.device).to(dtype)
50
- # embed camera
51
- N = camera.shape[1]
52
- camera_embeddings = self.camera_embedder(camera)
53
- camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
54
- embeddings = camera_embeddings
55
- # This resampling of positional embedding uses bicubic interpolation
56
- outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
57
- last_hidden_states = outputs.last_hidden_state
58
- return last_hidden_states
59
-
60
- def _freeze(self):
61
- print(f"======== Freezing DinoWrapper ========")
62
- self.model.eval()
63
- for name, param in self.model.named_parameters():
64
- param.requires_grad = False
65
-
66
- @staticmethod
67
- def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
68
- import requests
69
- try:
70
- model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
71
- processor = ViTImageProcessor.from_pretrained(model_name)
72
- return model, processor
73
- except requests.exceptions.ProxyError as err:
74
- if proxy_error_retries > 0:
75
- print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
76
- import time
77
- time.sleep(proxy_error_cooldown)
78
- return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
79
- else:
80
- raise err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 
 
 
 
 
 
 
 
src/models/geometry/camera/__init__.py DELETED
@@ -1,16 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- from torch import nn
11
-
12
-
13
class Camera(nn.Module):
    """Base class for camera models; it defines no parameters or behaviour of its own."""

    def __init__(self):
        super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/camera/perspective_camera.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- from . import Camera
11
- import numpy as np
12
-
13
-
14
def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
    """Build a 4x4 float32 perspective projection matrix.

    Args:
        x: Divisor applied to ``n`` for the two diagonal scale entries.
        n: Numerator of the diagonal scale entries; also the near plane
            when ``near_plane`` is not given.
        f: Far plane distance.
        near_plane: Near plane used in the depth row; defaults to ``n``.

    Returns:
        ``np.ndarray`` of shape (4, 4) and dtype float32.
    """
    near = n if near_plane is None else near_plane
    depth_span = f - near
    rows = [
        [n / x, 0, 0, 0],
        [0, n / -x, 0, 0],
        [0, 0, -(f + near) / depth_span, -(2 * f * near) / depth_span],
        [0, 0, -1, 0],
    ]
    return np.array(rows).astype(np.float32)
22
-
23
-
24
class PerspectiveCamera(Camera):
    """Perspective camera that holds a fixed projection matrix derived from ``fovy``."""

    def __init__(self, fovy=49.0, device='cuda'):
        """Build the (1, 4, 4) projection matrix on ``device``.

        Args:
            fovy: Field of view in degrees (converted to radians and halved below).
            device: Device the projection matrix is moved to.
        """
        super().__init__()
        self.device = device
        half_fov_tan = np.tan(fovy / 180.0 * np.pi * 0.5)
        proj = projection(x=half_fov_tan, f=1000.0, n=1.0, near_plane=0.1)
        self.proj_mtx = torch.from_numpy(proj).to(self.device).unsqueeze(dim=0)

    def project(self, points_bxnx4):
        """Right-multiply homogeneous points by the transposed projection matrix."""
        return torch.matmul(points_bxnx4, self.proj_mtx.transpose(1, 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/render/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
-
3
class Renderer:
    """Minimal renderer base class; subclasses supply the actual rendering logic."""

    def __init__(self):
        pass

    def forward(self):
        """Placeholder that does nothing and returns ``None``."""
        return None
 
 
 
 
 
 
 
 
 
src/models/geometry/render/neural_render.py DELETED
@@ -1,121 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import torch.nn.functional as F
11
- import nvdiffrast.torch as dr
12
- from . import Renderer
13
-
14
- _FG_LUT = None
15
-
16
-
17
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Call ``dr.interpolate`` on a contiguous copy of ``attr``.

    Attribute derivatives are requested for all attributes when ``rast_db``
    is supplied, and not requested otherwise.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
21
-
22
-
23
def xfm_points(points, matrix, use_python=True):
    '''Transform 3D points by a batch of 4x4 matrices.

    Args:
        points: Tensor of 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: 4x4 transform matrices with shape [minibatch_size, 4, 4]
        use_python: Accepted for interface compatibility; not read by this implementation.
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    # Append w = 1 so the points can be multiplied by the 4x4 matrices.
    homogeneous = torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0)
    out = torch.matmul(homogeneous, matrix.transpose(1, 2))
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out
36
-
37
-
38
def dot(x, y):
    """Return the dot product along the last axis, keeping that axis with size 1."""
    return (x * y).sum(dim=-1, keepdim=True)
40
-
41
-
42
def compute_vertex_normal(v_pos, t_pos_idx):
    """Compute unit vertex normals by accumulating face normals onto their vertices.

    Args:
        v_pos: Vertex positions, indexed as ``v_pos[i, :]`` (one 3D point per row).
        t_pos_idx: Integer triangle indices, one triangle per row (three columns).

    Returns:
        Tensor shaped like ``v_pos`` holding normalized vertex normals. Vertices
        whose accumulated normal is (near) zero get the default ``[0, 0, 1]``.

    Raises:
        AssertionError: If the result contains non-finite values.
    """
    i0, i1, i2 = t_pos_idx[:, 0], t_pos_idx[:, 1], t_pos_idx[:, 2]
    v0, v1, v2 = v_pos[i0, :], v_pos[i1, :], v_pos[i2, :]

    # ``dim`` must be explicit: without it torch.cross uses the first axis of
    # size 3, which is the face axis whenever the mesh has exactly three faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    for corner in (i0, i1, i2):
        v_nrm.scatter_add_(0, corner[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    squared_len = torch.sum(v_nrm * v_nrm, -1, keepdim=True)
    v_nrm = torch.where(
        squared_len > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    )
    v_nrm = F.normalize(v_nrm, dim=1)
    assert torch.all(torch.isfinite(v_nrm))

    return v_nrm
67
-
68
-
69
class NeuralRender(Renderer):
    """Rasterizes a mesh with nvdiffrast and returns per-pixel features, masks, depth and normals."""

    def __init__(self, device='cuda', camera_model=None):
        """Create the CUDA rasterization context and remember the camera model."""
        super().__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    def render_mesh(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            camera_mv_bx4x4,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        """Rasterize one mesh for a batch of camera matrices.

        Args:
            mesh_v_pos_bxnx3: Vertex positions; element 0 is used for the normals.
            mesh_t_pos_idx_fx3: Triangle indices; must contain at least one face.
            camera_mv_bx4x4: Camera matrices; converted to a float32 tensor
                on ``device`` when not already a tensor.
            mesh_v_feat_bxnxd: Per-vertex features to interpolate.
            resolution: Base image resolution; multiplied by ``spp``.
            spp: Resolution multiplier.
            device: Device used only when ``camera_mv_bx4x4`` must be converted.
            hierarchical_mask: Must be falsy (asserted).

        Returns:
            ``(ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
            mask_pyramid, depth, normal)``; ``mask_pyramid`` is always ``None``.
        """
        assert not hierarchical_mask

        if torch.is_tensor(camera_mv_bx4x4):
            mtx_in = camera_mv_bx4x4
        else:
            mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        # vertex normals in world coordinates
        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        num_layers = 1
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        # Concatenate the camera-space position after the user features
        vertex_attrs = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)

        image_size = [resolution * spp, resolution * spp]
        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, image_size) as peeler:
            for _ in range(num_layers):
                rast, db = peeler.rasterize_next_layer()
                gb_feat, _ = interpolate(vertex_attrs, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)

        # The last four channels are the appended homogeneous position.
        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
        normal = F.normalize(normal, dim=-1)
        # Map normals to [0, 1] inside the mask; black background
        normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())

        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/__init__.py DELETED
@@ -1,18 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import numpy as np
11
-
12
-
13
class Geometry:
    """Minimal geometry base class; subclasses supply the actual representation."""

    def __init__(self):
        pass

    def forward(self):
        """Placeholder that does nothing and returns ``None``."""
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/dmtet.py DELETED
@@ -1,504 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import numpy as np
11
- import os
12
- from . import Geometry
13
- from .dmtet_utils import get_center_boundary_index
14
- import torch.nn.functional as F
15
-
16
-
17
- ###############################################################################
18
- # DMTet utility functions
19
- ###############################################################################
20
def create_mt_variable(device):
    """Create the marching-tetrahedra lookup tensors on ``device``.

    Returns:
        ``(triangle_table, num_triangles_table, base_tet_edges, v_id)``, all int64:
        a 16x6 table of edge slots per occupancy code (-1 padded), the triangle
        count per code, the flattened vertex pairs of the six tet edges, and the
        per-vertex bit weights ``[1, 2, 4, 8]``.
    """
    def as_long(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    triangle_table = as_long([
        [-1, -1, -1, -1, -1, -1],
        [1, 0, 2, -1, -1, -1],
        [4, 0, 3, -1, -1, -1],
        [1, 4, 2, 1, 3, 4],
        [3, 1, 5, -1, -1, -1],
        [2, 3, 0, 2, 5, 3],
        [1, 4, 0, 1, 5, 4],
        [4, 2, 5, -1, -1, -1],
        [4, 5, 2, -1, -1, -1],
        [4, 1, 0, 4, 5, 1],
        [3, 2, 0, 3, 5, 2],
        [1, 3, 5, -1, -1, -1],
        [4, 1, 2, 4, 3, 1],
        [3, 0, 4, -1, -1, -1],
        [2, 0, 1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
    ])
    num_triangles_table = as_long([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0])
    base_tet_edges = as_long([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3])
    exponents = torch.arange(4, dtype=torch.long, device=device)
    v_id = torch.pow(2, exponents)
    return triangle_table, num_triangles_table, base_tet_edges, v_id
45
-
46
-
47
def sort_edges(edges_ex2):
    """Order each edge so the smaller vertex index comes first.

    Args:
        edges_ex2: Tensor with one edge per row and the two endpoints in columns 0 and 1.

    Returns:
        Tensor of shape [E, 1, 2] holding ``(smaller, larger)`` per edge.
    """
    with torch.no_grad():
        # 1 where the row is out of order, i.e. the column holding the smaller endpoint.
        smaller_col = (edges_ex2[:, 0] > edges_ex2[:, 1]).long().unsqueeze(dim=1)
        smaller = edges_ex2.gather(1, smaller_col)
        larger = edges_ex2.gather(1, 1 - smaller_col)
    return torch.stack([smaller, larger], -1)
54
-
55
-
56
- ###############################################################################
57
- # marching tetrahedrons (differentiable)
58
- ###############################################################################
59
-
60
def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
    """Extract a triangle mesh from a tetrahedral grid with marching tetrahedra.

    Topology (which edges are crossed, how faces are indexed) is computed
    without gradients; vertex positions are interpolated outside ``no_grad``.

    Args:
        pos_nx3: Grid vertex positions.
        sdf_n: Per-vertex SDF values; a vertex counts as occupied when ``sdf > 0``.
        tet_fx4: Vertex indices of each tetrahedron.
        triangle_table, num_triangles_table, base_tet_edges, v_id: Lookup
            tensors as produced by ``create_mt_variable``.

    Returns:
        ``(verts, faces)``: interpolated surface vertices and triangle indices into them.
    """
    device = sdf_n.device
    with torch.no_grad():
        occ_n = sdf_n > 0
        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
        occ_count = torch.sum(occ_fx4, -1)
        # Only tets with mixed occupancy intersect the surface.
        valid_tets = (occ_count > 0) & (occ_count < 4)

        # find all vertices
        all_edges = sort_edges(tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2))
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
        unique_edges = unique_edges.long()

        # An edge yields a surface vertex when exactly one endpoint is occupied.
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
        idx_map = mapping[idx_map]  # map edges to verts

        interp_v = unique_edges[mask_edges]

    endpoint_idx = interp_v.reshape(-1)
    edges_to_interp = pos_nx3[endpoint_idx].reshape(-1, 2, 3)
    edges_to_interp_sdf = sdf_n[endpoint_idx].reshape(-1, 2, 1)
    # Negate the second endpoint so both values share a sign and sum to |s0| + |s1|.
    edges_to_interp_sdf[:, -1] *= -1

    denominator = edges_to_interp_sdf.sum(1, keepdim=True)
    weights = torch.flip(edges_to_interp_sdf, [1]) / denominator
    verts = (edges_to_interp * weights).sum(1)

    idx_map = idx_map.reshape(-1, 6)

    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table[tetindex]
    one_tri = num_triangles == 1
    two_tri = num_triangles == 2

    # Generate triangle indices
    faces = torch.cat(
        (
            torch.gather(
                input=idx_map[one_tri], dim=1,
                index=triangle_table[tetindex[one_tri]][:, :3]).reshape(-1, 3),
            torch.gather(
                input=idx_map[two_tri], dim=1,
                index=triangle_table[tetindex[two_tri]][:, :6]).reshape(-1, 3),
        ), dim=0)
    return verts, faces
105
-
106
-
107
def create_tetmesh_variables(device='cuda'):
    """Create the lookup tensors used to build a tetrahedral mesh on ``device``.

    Returns:
        ``(tet_table, num_tets_table)``: a 16x12 int64 table (-1 padded) and
        the int64 number of tetrahedra per occupancy code.
    """
    empty_row = [-1] * 12
    rows = [
        empty_row,
        [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
        [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
        [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
        [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
        [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
        [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
        [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
        [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
        [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
        [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
        [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
        [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
        [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
        [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
        empty_row,
    ]
    tet_table = torch.tensor(rows, dtype=torch.long, device=device)
    num_tets_table = torch.tensor(
        [0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)
    return tet_table, num_tets_table
127
-
128
-
129
def marching_tets_tetmesh(
        pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
        return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
    """Marching tetrahedra that can additionally return a tetrahedral mesh.

    Args:
        pos_nx3: Grid vertex positions.
        sdf_n: Per-vertex SDF values; a vertex counts as occupied when ``sdf > 0``.
        tet_fx4: Vertex indices of each tetrahedron.
        triangle_table, num_triangles_table, base_tet_edges, v_id: Lookup
            tensors as produced by ``create_mt_variable``.
        return_tet_mesh: When true, also build the tetrahedral mesh.
        ori_v: Vertex positions from which occupied vertices are taken
            (required when ``return_tet_mesh`` is true).
        num_tets_table, tet_table: Lookup tensors as produced by
            ``create_tetmesh_variables`` (required when ``return_tet_mesh`` is true).

    Returns:
        ``(verts, faces)`` or, with ``return_tet_mesh``,
        ``(verts, faces, tet_verts, tets)``.
    """
    device = sdf_n.device
    with torch.no_grad():
        occ_n = sdf_n > 0
        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
        occ_sum = torch.sum(occ_fx4, -1)
        # Only tets with mixed occupancy intersect the surface.
        valid_tets = (occ_sum > 0) & (occ_sum < 4)

        # find all vertices
        all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
        all_edges = sort_edges(all_edges)
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

        unique_edges = unique_edges.long()
        # An edge yields a surface vertex when exactly one endpoint is occupied.
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
        idx_map = mapping[idx_map]  # map edges to verts

        interp_v = unique_edges[mask_edges]
    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
    # Negate the second endpoint so both values share a sign and sum to |s0| + |s1|.
    edges_to_interp_sdf[:, -1] *= -1

    denominator = edges_to_interp_sdf.sum(1, keepdim=True)

    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

    idx_map = idx_map.reshape(-1, 6)

    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table[tetindex]

    # Generate triangle indices
    faces = torch.cat(
        (
            torch.gather(
                input=idx_map[num_triangles == 1], dim=1,
                index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
            torch.gather(
                input=idx_map[num_triangles == 2], dim=1,
                index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
        ), dim=0)
    if not return_tet_mesh:
        return verts, faces

    occupied_verts = ori_v[occ_n]
    # Built on the inputs' device (previously hard-coded to "cuda", which broke
    # CPU runs and runs on a non-default GPU).
    mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device=device) * -1
    mapping[occ_n] = torch.arange(occupied_verts.shape[0], device=device)
    tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))

    # Occupied grid vertices are appended after the surface vertices, hence the offset.
    idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1)  # t x 10
    tet_verts = torch.cat([verts, occupied_verts], 0)
    num_tets = num_tets_table[tetindex]

    tets = torch.cat(
        (
            torch.gather(
                input=idx_map[num_tets == 1], dim=1,
                index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(-1, 4),
            torch.gather(
                input=idx_map[num_tets == 3], dim=1,
                index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(-1, 4),
        ), dim=0)
    # add fully occupied tets
    fully_occupied = occ_fx4.sum(-1) == 4
    tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
    tets = torch.cat([tets, tet_fully_occupied])

    return verts, faces, tet_verts, tets
201
-
202
-
203
- ###############################################################################
204
- # Compact tet grid
205
- ###############################################################################
206
-
207
def compact_tets(pos_nx3, sdf_n, tet_fx4):
    """Keep only the tetrahedra crossing the surface and re-index their vertices.

    Args:
        pos_nx3: Grid vertex positions.
        sdf_n: Per-vertex SDF values; a vertex counts as occupied when ``sdf > 0``.
        tet_fx4: Vertex indices of each tetrahedron.

    Returns:
        ``(new_pos, new_sdf, new_tets)`` restricted to vertices used by
        surface tets, with ``new_tets`` indexing into the compacted arrays.
    """
    with torch.no_grad():
        # Find surface tets: some, but not all, corners are occupied.
        occupied = sdf_n > 0
        corners_occupied = occupied[tet_fx4.reshape(-1)].reshape(-1, 4).sum(-1)
        surface = (corners_occupied > 0) & (corners_occupied < 4)  # one value per tet

        used_vtx = tet_fx4[surface].reshape(-1)
        kept_vtx, remap = torch.unique(used_vtx, dim=0, return_inverse=True)
        return pos_nx3[kept_vtx], sdf_n[kept_vtx], remap.reshape(-1, 4)
221
-
222
-
223
- ###############################################################################
224
- # Subdivide volume
225
- ###############################################################################
226
-
227
def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
    """Subdivide every tetrahedron into eight by inserting edge midpoints.

    The connectivity of batch element 0 is used for the whole batch.

    Args:
        tet_pos_bxnx3: Batched grid vertex positions.
        tet_bxfx4: Batched tetrahedron indices (only element 0 is read).
        grid_sdf: Batched per-vertex SDF, concatenated to the positions
            along the last axis so midpoints average both.

    Returns:
        ``(new_v, tet, new_sdf)``: positions and SDF including the midpoints,
        and the int64 subdivided tetrahedra expanded over the batch.
    """
    device = tet_pos_bxnx3.device
    n_batch, n_vert = tet_pos_bxnx3.shape[0], tet_pos_bxnx3.shape[1]

    # get new verts
    tet_fx4 = tet_bxfx4[0]
    edge_corner_pairs = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
    all_edges = sort_edges(tet_fx4[:, edge_corner_pairs].reshape(-1, 2))
    unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
    # Midpoint vertices are appended after the original ones.
    idx_map = idx_map + n_vert

    all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
    n_channels = all_values.shape[-1]
    mid_points = all_values[:, unique_edges.reshape(-1)].reshape(
        all_values.shape[0], -1, 2, n_channels).mean(2)
    new_v = torch.cat([all_values, mid_points], 1)
    new_v, new_sdf = new_v[..., :3], new_v[..., 3]

    # get new tets
    idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
    # Six consecutive entries of idx_map belong to one tet, in edge order ab, ac, ad, bc, bd, cd.
    idx_ab, idx_ac, idx_ad, idx_bc, idx_bd, idx_cd = (idx_map[k::6] for k in range(6))

    corner_sets = [
        (idx_a, idx_ab, idx_ac, idx_ad),
        (idx_b, idx_bc, idx_ab, idx_bd),
        (idx_c, idx_ac, idx_bc, idx_cd),
        (idx_d, idx_ad, idx_cd, idx_bd),
        (idx_ab, idx_ac, idx_ad, idx_bd),
        (idx_ab, idx_ac, idx_bd, idx_bc),
        (idx_cd, idx_ac, idx_bd, idx_ad),
        (idx_cd, idx_ac, idx_bc, idx_bd),
    ]
    tet_np = torch.cat([torch.stack(list(corners), dim=1) for corners in corner_sets], dim=0)
    tet_np = tet_np.reshape(1, -1, 4).expand(n_batch, -1, -1)
    tet = tet_np.long().to(device)

    return new_v, tet, new_sdf
267
-
268
-
269
- ###############################################################################
270
- # Adjacency
271
- ###############################################################################
272
def tet_to_tet_adj_sparse(tet_tx4):
    """Build a row-normalized sparse tet-to-tet adjacency matrix.

    Two tetrahedra are adjacent when they share a triangular face; every
    tetrahedron is also adjacent to itself. Each row's non-zero entries are
    ``1 / (number of neighbours including itself)``.

    Args:
        tet_tx4: Integer tensor with the four vertex indices of each tetrahedron.

    Returns:
        Sparse COO float tensor of shape ``(t, t)`` on the device of ``tet_tx4``.
    """
    # include self connection!!!!!!!!!!!!!!!!!!!
    with torch.no_grad():
        t = tet_tx4.shape[0]
        device = tet_tx4.device
        # Local corner indices of the four faces of a tetrahedron. Created
        # directly on the target device (the legacy LongTensor constructor is CPU-only).
        idx_array = torch.tensor(
            [[0, 1, 2],
             [1, 0, 3],
             [2, 3, 0],
             [3, 2, 1]], dtype=torch.long, device=device).unsqueeze(0).expand(t, -1, -1)  # (t, 4, 3)

        # get all faces
        all_faces = torch.gather(
            input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape(-1, 3)  # (tx4, 3)
        all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)
        # sort and group: a face shared by two tets appears exactly twice
        all_faces_sorted, _ = torch.sort(all_faces, dim=1)
        _, inverse_indices, counts = torch.unique(
            all_faces_sorted, dim=0, return_counts=True, return_inverse=True)
        valid = counts[inverse_indices] == 2

        group = inverse_indices[valid]
        _, indices = torch.sort(group)
        # After sorting by face id, consecutive entries are the two tets sharing a face.
        all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices]
        tet_face_tetidx_fx2 = torch.stack(
            [all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1)

        tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])])
        adj_self = torch.arange(t, device=device)
        adj_self = torch.stack([adj_self, adj_self], -1)
        tet_adj_idx = torch.cat([tet_adj_idx, adj_self])

        tet_adj_idx = torch.unique(tet_adj_idx, dim=0)
        values = torch.ones(tet_adj_idx.shape[0], device=device).float()
        # sparse_coo_tensor replaces the deprecated, CPU-typed torch.sparse.FloatTensor.
        adj_sparse = torch.sparse_coo_tensor(tet_adj_idx.t(), values, (t, t))

        # normalization
        neighbor_num = 1.0 / torch.sparse.sum(adj_sparse, dim=1).to_dense()
        values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0])
        adj_sparse = torch.sparse_coo_tensor(tet_adj_idx.t(), values, (t, t))
        return adj_sparse
322
-
323
-
324
- ###############################################################################
325
- # Compact grid
326
- ###############################################################################
327
-
328
def get_tet_bxfx4x3(bxnxz, bxfx4):
    """Gather per-vertex values for the four corners of every tetrahedron.

    Args:
        bxnxz: Per-vertex values, shape [B, N, Z].
        bxfx4: Tetrahedron vertex indices, shape [B, F, 4].

    Returns:
        Tensor of shape [B, F, 4, Z] with ``out[b, f, k] = bxnxz[b, bxfx4[b, f, k]]``.
    """
    n_batch, n_vert, z = bxnxz.shape
    n_tet = bxfx4.shape[1]
    # Broadcast both operands to a common 4-D layout so gather can run along the vertex axis.
    values = bxnxz.unsqueeze(2).expand(n_batch, n_vert, 4, z)
    corner_idx = bxfx4.unsqueeze(-1).expand(n_batch, n_tet, 4, z).long()
    return values.gather(1, corner_idx)
338
-
339
-
340
def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
    """Drop tetrahedra far from the surface and compact the remaining grid.

    Keeps the tets with mixed occupancy plus their one-ring of neighbours.
    Only batch size 1 is supported (asserted).

    Args:
        tet_pos_bxnx3: Grid vertex positions, batch size 1.
        tet_bxfx4: Tetrahedron indices.
        grid_sdf: Per-vertex SDF; a vertex counts as occupied when ``sdf > 0``.

    Returns:
        ``(new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf)`` restricted to the
        kept vertices, with tets re-indexed accordingly.
    """
    with torch.no_grad():
        assert tet_pos_bxnx3.shape[0] == 1
        device = tet_pos_bxnx3.device

        occupied = grid_sdf[0] > 0
        corners_occupied = get_tet_bxfx4x3(
            occupied.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1)
        keep = (corners_occupied > 0) & (corners_occupied < 4)

        # build connectivity graph
        adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0])
        keep = keep.float().unsqueeze(-1)

        # Include a one ring of neighbors
        for _ in range(1):
            keep = torch.sparse.mm(adj_matrix, keep)
        keep = keep.squeeze(-1) > 0

        kept_tets = tet_bxfx4[:, keep].long()
        kept_verts = torch.unique(kept_tets)
        # Old vertex index -> position in the compacted vertex list.
        remap = torch.zeros((tet_pos_bxnx3.shape[1]), device=device, dtype=torch.long)
        remap[kept_verts] = torch.arange(kept_verts.shape[0], device=device)

        new_tet_pos_bxnx3 = tet_pos_bxnx3[:, kept_verts]
        new_tet_bxfx4 = remap[kept_tets.reshape(-1)].reshape(kept_tets.shape)
        new_grid_sdf = grid_sdf[:, kept_verts]
        return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
365
-
366
-
367
- ###############################################################################
368
- # Regularizer
369
- ###############################################################################
370
-
371
def sdf_reg_loss(sdf, all_edges):
    """SDF regularizer over the edges whose endpoints have different signs.

    For each such edge, the SDF at one endpoint is used as a logit against
    the occupancy (``sdf > 0``) of the other endpoint, in both directions;
    the two mean binary-cross-entropy terms are summed.

    Args:
        sdf: Per-vertex SDF values.
        all_edges: Vertex index pairs, one edge per row.

    Returns:
        Scalar loss tensor.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    endpoint_sdf = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    sign_change = torch.sign(endpoint_sdf[..., 0]) != torch.sign(endpoint_sdf[..., 1])
    endpoint_sdf = endpoint_sdf[sign_change]
    first, second = endpoint_sdf[..., 0], endpoint_sdf[..., 1]
    return bce(first, (second > 0).float()) + bce(second, (first > 0).float())
382
-
383
-
384
def sdf_reg_loss_batch(sdf, all_edges):
    """Batched SDF regularizer over edges whose endpoints have different signs.

    Same loss as the unbatched variant, with sign-changing edges pooled
    across the whole batch before averaging.

    Args:
        sdf: SDF values with the batch on axis 0 and vertices on axis 1.
        all_edges: Vertex index pairs, one edge per row.

    Returns:
        Scalar loss tensor.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    endpoint_sdf = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
    sign_change = torch.sign(endpoint_sdf[..., 0]) != torch.sign(endpoint_sdf[..., 1])
    endpoint_sdf = endpoint_sdf[sign_change]
    first, second = endpoint_sdf[..., 0], endpoint_sdf[..., 1]
    return bce(first, (second > 0).float()) + bce(second, (first > 0).float())
391
-
392
-
393
- ###############################################################################
394
- # Geometry interface
395
- ###############################################################################
396
class DMTetGeometry(Geometry):
    """DMTet geometry: a tetrahedral grid with mesh extraction and rendering helpers."""

    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        """Load the tetrahedral grid for ``grid_res`` and precompute lookup tables.

        Args:
            grid_res: Grid resolution; selects ``data/tets/<grid_res>_compress.npz``.
            scale: Scalar, or list of per-axis factors, applied to the
                normalized vertices.
            device: Device holding all tensors.
            renderer: Object whose ``render_mesh`` is used by ``render_mesh``.
            render_type: Only ``'neural_render'`` is implemented.
            args: Stored as-is; not read by this class.
        """
        super(DMTetGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        tets = np.load('data/tets/%d_compress.npz' % (grid_res))
        self.verts = torch.from_numpy(tets['vertices']).float().to(self.device)
        # Make sure the tet is zero-centered and length is equal to 1
        length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
        length = length.max()
        mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0
        self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
        if isinstance(scale, list):
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            # NOTE(review): z is scaled by scale[1], not scale[2] - confirm this is intended.
            self.verts[:, 2] = self.verts[:, 2] * scale[1]
        else:
            self.verts = self.verts * scale
        self.indices = torch.from_numpy(tets['tets']).long().to(self.device)
        self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
        self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
        # Parameters for regularization computation
        edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
        all_edges = self.indices[:, edges].reshape(-1, 2)
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)

        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        """Return the per-axis ``(min, max)`` corners of the grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
        """Extract ``(verts, faces)`` with marching tetrahedra.

        ``indices`` defaults to the grid's own tetrahedra. The last two
        columns of each face are swapped, flipping the winding. ``with_uv``
        is accepted but not used.
        """
        if indices is None:
            indices = self.indices
        verts, faces = marching_tets(
            v_deformed_nx3, sdf_n, indices, self.triangle_table,
            self.num_triangles_table, self.base_tet_edges, self.v_id)
        faces = torch.cat(
            [faces[:, 0:1],
             faces[:, 2:3],
             faces[:, 1:2], ], dim=-1)
        return verts, faces

    def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
        """Extract ``(verts, faces, tet_verts, tets)``; faces get the same winding flip as ``get_mesh``."""
        if indices is None:
            indices = self.indices
        verts, faces, tet_verts, tets = marching_tets_tetmesh(
            v_deformed_nx3, sdf_n, indices, self.triangle_table,
            self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
            num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
        faces = torch.cat(
            [faces[:, 0:1],
             faces[:, 2:3],
             faces[:, 1:2], ], dim=-1)
        return verts, faces, tet_verts, tets

    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        """Render one mesh through ``self.renderer`` and return the outputs as a dict.

        The vertex positions are also passed as the per-vertex features.

        Returns:
            Dict with keys ``tex_pos``, ``mask``, ``hard_mask``, ``rast``,
            ``v_pos_clip``, ``mask_pyramid``, ``depth`` and, when the renderer
            provides an eighth output, ``normal``.

        Raises:
            NotImplementedError: If ``render_type`` is not ``'neural_render'``.
        """
        return_value = dict()
        if self.render_type == 'neural_render':
            outputs = self.renderer.render_mesh(
                mesh_v_nx3.unsqueeze(dim=0),
                mesh_f_fx3.int(),
                camera_mv_bx4x4,
                mesh_v_nx3.unsqueeze(dim=0),
                resolution=resolution,
                device=self.device,
                hierarchical_mask=hierarchical_mask
            )
            # The renderer may append a normal map after depth (8 outputs instead
            # of 7); a fixed 7-way unpack raised ValueError in that case.
            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, *extra = outputs

            return_value['tex_pos'] = tex_pos
            return_value['mask'] = mask
            return_value['hard_mask'] = hard_mask
            return_value['rast'] = rast
            return_value['v_pos_clip'] = v_pos_clip
            return_value['mask_pyramid'] = mask_pyramid
            return_value['depth'] = depth
            if extra:
                return_value['normal'] = extra[0]
        else:
            raise NotImplementedError

        return return_value

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        """Extract and render a mesh for every batch element.

        Returns:
            Dict mapping each render output name to a list with one entry per
            batch element (no concatenation is done here).
        """
        # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
        n_batch = v_deformed_bxnx3.shape[0]
        all_render_output = []
        for i_batch in range(n_batch):
            verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
            all_render_output.append(render_output)

        # Collect all render output per key; we can do concatenation outside of the render
        return {
            key: [output[key] for output in all_render_output]
            for key in all_render_output[0].keys()
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/dmtet_utils.py DELETED
@@ -1,20 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
-
11
-
12
def get_center_boundary_index(verts):
    """Locate the vertex closest to the origin and the vertices on the extremes.

    Args:
        verts: Tensor of vertex coordinates; the last axis holds the components.

    Returns:
        tuple: ``(center_idx, boundary_idx)`` where ``center_idx`` is the index
        of the vertex with the smallest squared norm, and ``boundary_idx``
        lists the vertices having at least one component equal to the global
        minimum or global maximum of ``verts``.
    """
    squared_norm = (verts ** 2).sum(dim=-1)
    center_idx = torch.argmin(squared_norm)

    # A component "touches the boundary" when it equals the global min or max.
    touches_extreme = (verts == verts.min()) | (verts == verts.max())
    hits_per_vertex = touches_extreme.float().sum(dim=-1)

    boundary_idx = torch.nonzero(hits_per_vertex).squeeze(dim=-1)
    return center_idx, boundary_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/extract_texture_map.py DELETED
@@ -1,40 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import xatlas
11
- import numpy as np
12
- import nvdiffrast.torch as dr
13
-
14
-
15
- # ==============================================================================================
16
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Thin wrapper around ``dr.interpolate``.

    ``attr`` is made contiguous before the call. Attribute derivatives are
    requested for all attributes (``diff_attrs='all'``) only when ``rast_db``
    is supplied; otherwise ``diff_attrs`` is ``None``.

    Returns:
        Whatever ``dr.interpolate`` returns.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
18
-
19
-
20
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
    """Build a UV atlas with xatlas and rasterize the mesh in UV space.

    Args:
        ctx: Rasterization context forwarded to ``dr.rasterize``.
        mesh_v: Vertex position tensor; also determines the output device.
        mesh_pos_idx: Face index tensor for ``mesh_v``.
        resolution: Side length of the square texture raster.

    Returns:
        tuple: ``(uvs, mesh_tex_idx, gb_pos, mask)`` - the UV coordinates, the
        UV-space face indices, the vertex positions interpolated over the
        texture raster, and a boolean mask of texels where the fourth raster
        channel is positive.
    """
    verts_np = mesh_v.detach().cpu().numpy()
    faces_np = mesh_pos_idx.detach().cpu().numpy()
    # The vertex mapping returned by xatlas is not needed here.
    _, indices, uvs = xatlas.parametrize(verts_np, faces_np)

    # Convert to tensors; the indices are reinterpreted from uint64 as int64.
    target_device = mesh_v.device
    tex_idx_np = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device=target_device)
    mesh_tex_idx = torch.tensor(tex_idx_np, dtype=torch.int64, device=target_device)

    # Map UVs from [0, 1] to clip space [-1, 1] and pad to (x, y, 0, 1).
    uv_clip = uvs[None, ...] * 2.0 - 1.0
    first_channel = uv_clip[..., 0:1]
    uv_clip4 = torch.cat(
        (uv_clip, torch.zeros_like(first_channel), torch.ones_like(first_channel)), dim=-1)

    # Rasterize the UV layout.
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))

    # Interpolate world space position over the raster.
    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
    mask = rast[..., 3:4] > 0
    return uvs, mesh_tex_idx, gb_pos, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/flexicubes.py DELETED
@@ -1,579 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
- import torch
9
- from .tables import *
10
-
11
- __all__ = [
12
- 'FlexiCubes'
13
- ]
14
-
15
-
16
- class FlexiCubes:
17
- """
18
- This class implements the FlexiCubes method for extracting meshes from scalar fields.
19
- It maintains a series of lookup tables and indices to support the mesh extraction process.
20
- FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
21
- the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
22
- the surface representation through gradient-based optimization.
23
-
24
- During instantiation, the class loads DMC tables from a file and transforms them into
25
- PyTorch tensors on the specified device.
26
-
27
- Attributes:
28
- device (str): Specifies the computational device (default is "cuda").
29
- dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
30
- associated with each dual vertex in 256 Marching Cubes (MC) configurations.
31
- num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
32
- the 256 MC configurations.
33
- check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
34
- of the DMC configurations.
35
- tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
36
- quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
37
- along one diagonal.
38
- quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
39
- two triangles along the other diagonal.
40
- quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
41
- during training by connecting all edges to their midpoints.
42
- cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
43
- eight corners in 3D space, ordered starting from the origin (0,0,0),
44
- moving along the x-axis, then y-axis, and finally z-axis.
45
- Used as a blueprint for generating a voxel grid.
46
- cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
47
- to retrieve the case id.
48
- cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
49
- Used to retrieve edge vertices in DMC.
50
- edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
51
- their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
52
- first edge is oriented along the x-axis.
53
- dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
54
- across four adjacent cubes to the shared faces of these cubes. For instance,
55
- dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
56
- the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
57
- This tensor is only utilized during isosurface tetrahedralization.
58
- adj_pairs (torch.Tensor):
59
- A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
60
- qef_reg_scale (float):
61
- The scaling factor applied to the regularization loss to prevent issues with singularity
62
- when solving the QEF. This parameter is only used when a 'grad_func' is specified.
63
- weight_scale (float):
64
- The scale of weights in FlexiCubes. Should be between 0 and 1.
65
- """
66
-
67
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
    """Load the lookup tables as tensors and store the scalar settings.

    Args:
        device: Device on which the lookup tensors are created.
        qef_reg_scale: Regularization scale stored for the QEF solve.
        weight_scale: Scale stored for weight normalization.
    """

    def long_tensor(values):
        # All index/lookup tables share dtype, device and no-grad settings.
        return torch.tensor(values, dtype=torch.long, device=device, requires_grad=False)

    self.device = device

    # Tables imported from the sibling `tables` module.
    self.dmc_table = long_tensor(dmc_table)
    self.num_vd_table = long_tensor(num_vd_table)
    self.check_table = long_tensor(check_table)
    self.tet_table = long_tensor(tet_table)

    # Index patterns for splitting a quad into triangles.
    self.quad_split_1 = long_tensor([0, 1, 2, 0, 2, 3])
    self.quad_split_2 = long_tensor([0, 1, 3, 3, 1, 2])
    self.quad_split_train = long_tensor([0, 1, 1, 2, 2, 3, 3, 0])

    self.cube_corners = torch.tensor(
        [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
         [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]],
        dtype=torch.float, device=device)
    # NOTE: created on the default device, not on `device`.
    self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
    # Twelve cube edges stored as a flat list of corner-index pairs.
    self.cube_edges = long_tensor(
        [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4])

    self.edge_dir_table = long_tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1])
    self.dir_faces_table = long_tensor([
        [[5, 4], [3, 2], [4, 5], [2, 3]],
        [[5, 4], [1, 0], [4, 5], [0, 1]],
        [[3, 2], [1, 0], [2, 3], [0, 1]],
    ])
    self.adj_pairs = long_tensor([0, 1, 1, 3, 3, 2, 2, 0])

    self.qef_reg_scale = qef_reg_scale
    self.weight_scale = weight_scale
99
-
100
def construct_voxel_grid(self, res):
    """Build a regular voxel grid with shared (deduplicated) vertices.

    Args:
        res (int or list[int]): Grid resolution. A single integer is used
            for all three axes; otherwise three per-axis resolutions.

    Returns:
        (torch.Tensor, torch.Tensor): The unique grid vertices shifted by
        -0.5 (so the grid is centred on the origin with unit extent), and
        for each cube the indices of its 8 corners into those vertices.
    """
    dims = (res, res, res) if isinstance(res, int) else res
    occupancy = torch.ones(dims, device=self.device)
    dims_row = torch.tensor([dims], dtype=torch.float, device=self.device)

    # One origin per cell, normalized to [0, 1).
    cell_origins = torch.nonzero(occupancy).float() / dims_row  # N, 3
    corner_offsets = self.cube_corners.unsqueeze(0) / dims_row
    verts = (corner_offsets + cell_origins.unsqueeze(1)).reshape(-1, 3)

    # Before deduplication every cell owns 8 consecutive vertices.
    corner_ids = torch.arange(8).to(self.device)
    cell_ids = torch.arange(cell_origins.shape[0], device=self.device)
    flat_corner_refs = (corner_ids.unsqueeze(0) + cell_ids.unsqueeze(1) * 8).reshape(-1)

    # Round so that corners shared between cells compare equal, then merge.
    verts_rounded = torch.round(verts * 10**5) / (10**5)
    verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
    cubes = inverse_indices[flat_corner_refs.reshape(-1)].reshape(-1, 8)

    return verts_unique - 0.5, cubes
132
-
133
- def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
134
- gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
135
- r"""
136
- Main function for mesh extraction from scalar field using FlexiCubes. This function converts
137
- discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
138
- to triangle or tetrahedral meshes using a differentiable operation as described in
139
- `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
140
- mesh quality and geometric fidelity by adjusting the surface representation based on gradient
141
- optimization. The output surface is differentiable with respect to the input vertex positions,
142
- scalar field values, and weight parameters.
143
-
144
- If you intend to extract a surface mesh from a fixed Signed Distance Field without the
145
- optimization of parameters, it is suggested to provide the "grad_func" which should
146
- return the surface gradient at any given 3D position. When grad_func is provided, the process
147
- to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
148
- described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
149
- Please note, this approach is non-differentiable.
150
-
151
- For more details and example usage in optimization, refer to the
152
- `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
153
-
154
- Args:
155
- x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
156
- s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
157
- denote that the corresponding vertex resides inside the isosurface. This affects
158
- the directions of the extracted triangle faces and volume to be tetrahedralized.
159
- cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
160
- res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
161
- is used for all three dimensions. If a list or tuple of 3 integers is provided, they
162
- specify the resolution for the x, y, and z dimensions respectively.
163
- beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
164
- vertices positioning. Defaults to uniform value for all edges.
165
- alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
166
- vertices positioning. Defaults to uniform value for all vertices.
167
- gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
168
- quadrilaterals into triangles. Defaults to uniform value for all cubes.
169
- training (bool, optional): If set to True, applies differentiable quad splitting for
170
- training. Defaults to False.
171
- output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
172
- outputs a triangular mesh. Defaults to False.
173
- grad_func (callable, optional): A function to compute the surface gradient at specified
174
- 3D positions (input: Nx3 positions). The function should return gradients as an Nx3
175
- tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
176
-
177
- Returns:
178
- (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
179
- - Vertices for the extracted triangular/tetrahedral mesh.
180
- - Faces for the extracted triangular/tetrahedral mesh.
181
- - Regularizer L_dev, computed per dual vertex.
182
-
183
- .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
184
- https://research.nvidia.com/labs/toronto-ai/flexicubes/
185
- .. _Manifold Dual Contouring:
186
- https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
187
- """
188
-
189
- surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
190
- if surf_cubes.sum() == 0:
191
- return torch.zeros(
192
- (0, 3),
193
- device=self.device), torch.zeros(
194
- (0, 4),
195
- dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
196
- (0, 3),
197
- dtype=torch.long, device=self.device), torch.zeros(
198
- (0),
199
- device=self.device)
200
- beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
201
-
202
- case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
203
-
204
- surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
205
-
206
- vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
207
- x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
208
- vertices, faces, s_edges, edge_indices = self._triangulate(
209
- s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
210
- if not output_tetmesh:
211
- return vertices, faces, L_dev
212
- else:
213
- vertices, tets = self._tetrahedralize(
214
- x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
215
- surf_cubes, training)
216
- return vertices, tets, L_dev
217
-
218
- def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
219
- """
220
- Regularizer L_dev as in Equation 8
221
- """
222
- dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
223
- mean_l2 = torch.zeros_like(vd[:, 0])
224
- mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
225
- mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
226
- return mad
227
-
228
- def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
229
- """
230
- Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
231
- """
232
- n_cubes = surf_cubes.shape[0]
233
-
234
- if beta_fx12 is not None:
235
- beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
236
- else:
237
- beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
238
-
239
- if alpha_fx8 is not None:
240
- alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
241
- else:
242
- alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
243
-
244
- if gamma_f is not None:
245
- gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
246
- else:
247
- gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
248
-
249
- return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
250
-
251
- @torch.no_grad()
252
- def _get_case_id(self, occ_fx8, surf_cubes, res):
253
- """
254
- Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
255
- ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
256
- supplementary material. It should be noted that this function assumes a regular grid.
257
- """
258
- case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
259
-
260
- problem_config = self.check_table.to(self.device)[case_ids]
261
- to_check = problem_config[..., 0] == 1
262
- problem_config = problem_config[to_check]
263
- if not isinstance(res, (list, tuple)):
264
- res = [res, res, res]
265
-
266
- # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
267
- # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
268
- # This allows efficient checking on adjacent cubes.
269
- problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
270
- vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
271
- vol_idx_problem = vol_idx[surf_cubes][to_check]
272
- problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
273
- vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
274
-
275
- within_range = (
276
- vol_idx_problem_adj[..., 0] >= 0) & (
277
- vol_idx_problem_adj[..., 0] < res[0]) & (
278
- vol_idx_problem_adj[..., 1] >= 0) & (
279
- vol_idx_problem_adj[..., 1] < res[1]) & (
280
- vol_idx_problem_adj[..., 2] >= 0) & (
281
- vol_idx_problem_adj[..., 2] < res[2])
282
-
283
- vol_idx_problem = vol_idx_problem[within_range]
284
- vol_idx_problem_adj = vol_idx_problem_adj[within_range]
285
- problem_config = problem_config[within_range]
286
- problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
287
- vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
288
- # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
289
- to_invert = (problem_config_adj[..., 0] == 1)
290
- idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
291
- case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
292
- return case_ids
293
-
294
- @torch.no_grad()
295
- def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
296
- """
297
- Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
298
- can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
299
- and marks the cube edges with this index.
300
- """
301
- occ_n = s_n < 0
302
- all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
303
- unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
304
-
305
- unique_edges = unique_edges.long()
306
- mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
307
-
308
- surf_edges_mask = mask_edges[_idx_map]
309
- counts = counts[_idx_map]
310
-
311
- mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
312
- mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
313
- # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
314
- # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
315
- idx_map = mapping[_idx_map]
316
- surf_edges = unique_edges[mask_edges]
317
- return surf_edges, idx_map, counts, surf_edges_mask
318
-
319
- @torch.no_grad()
320
- def _identify_surf_cubes(self, s_n, cube_fx8):
321
- """
322
- Identifies grid cubes that intersect with the underlying surface by checking if the signs at
323
- all corners are not identical.
324
- """
325
- occ_n = s_n < 0
326
- occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
327
- _occ_sum = torch.sum(occ_fx8, -1)
328
- surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
329
- return surf_cubes, occ_fx8
330
-
331
- def _linear_interp(self, edges_weight, edges_x):
332
- """
333
- Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
334
- """
335
- edge_dim = edges_weight.dim() - 2
336
- assert edges_weight.shape[edge_dim] == 2
337
- edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
338
- torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
339
- denominator = edges_weight.sum(edge_dim)
340
- ue = (edges_x * edges_weight).sum(edge_dim) / denominator
341
- return ue
342
-
343
- def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
344
- p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
345
- norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
346
- c_bx3 = c_bx3.reshape(-1, 3)
347
- A = norm_bxnx3
348
- B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
349
-
350
- A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
351
- B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
352
- A = torch.cat([A, A_reg], 1)
353
- B = torch.cat([B, B_reg], 1)
354
- dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
355
- return dual_verts
356
-
357
- def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
358
- """
359
- Computes the location of dual vertices as described in Section 4.2
360
- """
361
- alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
362
- surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
363
- surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
364
- zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
365
-
366
- idx_map = idx_map.reshape(-1, 12)
367
- num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
368
- edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
369
-
370
- total_num_vd = 0
371
- vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
372
- if grad_func is not None:
373
- normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
374
- vd = []
375
- for num in torch.unique(num_vd):
376
- cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
377
- curr_num_vd = cur_cubes.sum() * num
378
- curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
379
- curr_edge_group_to_vd = torch.arange(
380
- curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
381
- total_num_vd += curr_num_vd
382
- curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
383
- cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
384
-
385
- curr_mask = (curr_edge_group != -1)
386
- edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
387
- edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
388
- edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
389
- vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
390
- vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
391
-
392
- if grad_func is not None:
393
- with torch.no_grad():
394
- cube_e_verts_idx = idx_map[cur_cubes]
395
- curr_edge_group[~curr_mask] = 0
396
-
397
- verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
398
- verts_group_idx[verts_group_idx == -1] = 0
399
- verts_group_pos = torch.index_select(
400
- input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
401
- v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
402
- curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
403
- verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
404
-
405
- normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
406
- -1, num.item(), 7,
407
- 3)
408
- curr_mask = curr_mask.squeeze(2)
409
- vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
410
- verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
411
- edge_group = torch.cat(edge_group)
412
- edge_group_to_vd = torch.cat(edge_group_to_vd)
413
- edge_group_to_cube = torch.cat(edge_group_to_cube)
414
- vd_num_edges = torch.cat(vd_num_edges)
415
- vd_gamma = torch.cat(vd_gamma)
416
-
417
- if grad_func is not None:
418
- vd = torch.cat(vd)
419
- L_dev = torch.zeros([1], device=self.device)
420
- else:
421
- vd = torch.zeros((total_num_vd, 3), device=self.device)
422
- beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
423
-
424
- idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
425
-
426
- x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
427
- s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
428
-
429
- zero_crossing_group = torch.index_select(
430
- input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
431
-
432
- alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
433
- index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
434
- ue_group = self._linear_interp(s_group * alpha_group, x_group)
435
-
436
- beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
437
- index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
438
- beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
439
- vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
440
- L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
441
-
442
- v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
443
-
444
- vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
445
- 12 + edge_group, src=v_idx[edge_group_to_vd])
446
-
447
- return vd, L_dev, vd_gamma, vd_idx_map
448
-
449
- def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
450
- """
451
- Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
452
- triangles based on the gamma parameter, as described in Section 4.3.
453
- """
454
- with torch.no_grad():
455
- group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
456
- group = idx_map.reshape(-1)[group_mask]
457
- vd_idx = vd_idx_map[group_mask]
458
- edge_indices, indices = torch.sort(group, stable=True)
459
- quad_vd_idx = vd_idx[indices].reshape(-1, 4)
460
-
461
- # Ensure all face directions point towards the positive SDF to maintain consistent winding.
462
- s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
463
- flip_mask = s_edges[:, 0] > 0
464
- quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
465
- quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
466
- if grad_func is not None:
467
- # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
468
- with torch.no_grad():
469
- vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
470
- quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
471
- gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
472
- gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
473
- else:
474
- quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
475
- gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
476
- 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
477
- gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
478
- 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
479
- if not training:
480
- mask = (gamma_02 > gamma_13).squeeze(1)
481
- faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
482
- faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
483
- faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
484
- faces = faces.reshape(-1, 3)
485
- else:
486
- vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
487
- vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
488
- torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
489
- vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
490
- torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
491
- weight_sum = (gamma_02 + gamma_13) + 1e-8
492
- vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
493
- weight_sum.unsqueeze(-1)).squeeze(1)
494
- vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
495
- vd = torch.cat([vd, vd_center])
496
- faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
497
- faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
498
- return vd, faces, s_edges, edge_indices
499
-
500
def _tetrahedralize(
        self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
        surf_cubes, training):
    """
    Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.

    Two families of tetrahedra are built and concatenated:

    * surface tets: every extracted triangle in ``faces`` is joined with the
      inside (negative-valued) grid vertex of the grid edge it belongs to;
    * interior tets: for every grid edge whose two endpoints are both inside and
      which is shared by exactly four of the selected cubes, a tet is formed from
      the two edge endpoints plus two vertices looked up in adjacent cubes.

    Args:
        x_nx3: grid vertex positions, indexed by the vertex ids stored in ``cube_fx8``.
        s_n: per-grid-vertex scalar value; ``s_n < 0`` is treated as "inside".
        cube_fx8: for each cube, the ids of its 8 corner grid vertices.
        vertices: mesh vertices already extracted; new vertices are appended after them.
        faces: triangle indices into ``vertices``.
        surf_edges: endpoint grid-vertex ids of the surface-crossing edges
            (reshaped to (-1, 2) below).
        s_edges: presumably the scalar values at the two endpoints of each surface
            edge -- only its sign is used here; TODO confirm against the caller.
        vd_idx_map: mesh-vertex id per (surface cube, cube edge); reshaped to (-1, 12) below.
        case_ids: per-surface-cube case id used to index ``self.tet_table``.
        edge_indices: indices into ``surf_edges``, consumed in groups of four.
        surf_cubes: boolean mask over cubes marking the surface cubes.
        training: selects how many triangles per group of four edge indices are
            expected in ``faces`` (4 when True, 2 when False).

    Returns:
        (vertices, tets): ``vertices`` is the input vertices followed by the inside
        grid vertices and the centers of fully-inside cubes; ``tets`` holds 4 vertex
        ids per row.
    """
    # A grid vertex is "occupied" when its scalar value is negative.
    occ_n = s_n < 0
    occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
    # Number of occupied corners per cube (8 means the cube is entirely inside).
    occ_sum = torch.sum(occ_fx8, -1)

    inside_verts = x_nx3[occ_n]
    # Maps a grid-vertex id to its id in the output vertex list (appended after
    # `vertices`); grid vertices that are not inside keep -1.
    mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
    mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
    """
    For each grid edge connecting two grid vertices with different
    signs, we first form a four-sided pyramid by connecting one
    of the grid vertices with four mesh vertices that correspond
    to the grid edge and then subdivide the pyramid into two tetrahedra
    """
    # Take the first edge index of every group of four, fetch that edge's two
    # endpoints, and keep the endpoint selected by `s_edges < 0` (the inside one).
    inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
        s_edges < 0]]
    # Repeat the apex id once per triangle of the group so it lines up row-for-row
    # with `faces`: 2 triangles per group outside training, 4 during training.
    if not training:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
    else:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)

    # Each surface tet = one triangle + the inside grid vertex as fourth corner.
    tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
    """
    For each grid edge connecting two grid vertices with the
    same sign, the tetrahedron is formed by the two grid vertices
    and two vertices in consecutive adjacent cells
    """
    inside_cubes = (occ_sum == 8)
    # Fully-inside cubes contribute their center (mean of the 8 corners) as a new vertex.
    inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
    # Center ids come after the existing mesh vertices and the inside grid vertices.
    inside_cubes_center_idx = torch.arange(
        inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]

    surface_n_inside_cubes = surf_cubes | inside_cubes
    # Per selected cube: columns 0..11 hold the mesh-vertex id for each cube edge
    # (surface cubes only), column 12 holds the cube-center id (inside cubes only);
    # every other entry stays -1.
    edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
                                        dtype=torch.long, device=x_nx3.device) * -1
    # From here on both masks are indexed relative to the selected-cube subset.
    surf_cubes = surf_cubes[surface_n_inside_cubes]
    inside_cubes = inside_cubes[surface_n_inside_cubes]
    edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
    edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx

    # All 12 edges of every selected cube, deduplicated; `_idx_map` maps each
    # (cube, edge) slot back to its unique edge and `counts` is the multiplicity.
    all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
    unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
    unique_edges = unique_edges.long()
    # Keep only edges whose two endpoints are both inside.
    mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
    mask = mask_edges[_idx_map]
    counts = counts[_idx_map]
    # Compact numbering of the fully-inside unique edges (-1 for the others).
    mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
    mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
    idx_map = mapping[_idx_map]

    # Only edges shared by exactly four selected cubes are used.
    group_mask = (counts == 4) & mask
    group = idx_map.reshape(-1)[group_mask]
    # Sorting makes the four occurrences of each edge contiguous so the
    # reshape(-1, 4) calls below group them per edge.
    edge_indices, indices = torch.sort(group)
    # Cube id and local edge id (0..11) for every surviving (cube, edge) slot.
    cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
                            device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
    edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
        0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
    # Identify the face shared by the adjacent cells.
    cube_idx_4 = cube_idx[indices].reshape(-1, 4)
    edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
    shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
    cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
    # Identify an edge of the face with different signs and
    # select the mesh vertex corresponding to the identified edge.
    # Cubes that are not surface cubes get case id 255 (presumably the
    # "all eight corners inside" configuration -- TODO confirm against tet_table).
    case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
    case_ids_expand[surf_cubes] = case_ids
    cases = case_ids_expand[cube_idx_4x2]
    quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
    # Drop candidate tets that reference a missing (-1) vertex.
    mask = (quad_edge == -1).sum(-1) == 0
    inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
    tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]

    tets = torch.cat([tets_surface, tets_inside])
    # Vertex order must match the id offsets used above.
    vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
    return vertices, tets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/flexicubes_geometry.py DELETED
@@ -1,120 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import numpy as np
11
- import os
12
- from . import Geometry
13
- from .flexicubes import FlexiCubes # replace later
14
- from .dmtet import sdf_reg_loss_batch
15
- import torch.nn.functional as F
16
-
17
def get_center_boundary_index(grid_res, device):
    """Return flat indices of the center vertex and of the boundary vertices.

    The grid has ``grid_res + 1`` vertices per axis. The "center" is the single
    vertex at ``grid_res // 2 + 1`` on every axis; the "boundary" is every vertex
    lying in the first two or last two layers along any axis.

    Args:
        grid_res: number of cells per axis.
        device: torch device on which the index tensors are created.

    Returns:
        (center_indices, boundary_indices): two long tensors of shape (k, 1), as
        produced by ``nonzero`` on the flattened grid mask.
    """
    side = grid_res + 1
    mid = grid_res // 2 + 1
    mask = torch.zeros((side, side, side), dtype=torch.bool, device=device)

    # Mark only the center vertex, record its flat index, then clear it again.
    mask[mid, mid, mid] = True
    center_indices = mask.reshape(-1).nonzero()
    mask[mid, mid, mid] = False

    # Mark the two outermost layers on both ends of each axis. `transpose`
    # returns a view, so writing to it updates `mask` in place.
    for axis in range(3):
        layers = mask.transpose(0, axis)
        layers[:2] = True
        layers[-2:] = True

    boundary_indices = mask.reshape(-1).nonzero()
    return center_indices, boundary_indices
31
-
32
- ###############################################################################
33
- # Geometry interface
34
- ###############################################################################
35
class FlexiCubesGeometry(Geometry):
    """Geometry backend that extracts meshes from a voxel grid with FlexiCubes.

    The constructor builds a voxel grid of resolution ``grid_res``, scales its
    vertices, and caches the unique grid edges plus the center/boundary vertex
    indices. Meshes are extracted with ``get_mesh`` and rasterized through the
    injected ``renderer``.
    """

    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        """
        Args:
            grid_res: voxel grid resolution passed to ``construct_voxel_grid``.
            scale: a single factor applied to all axes, or a list/tuple of
                per-axis factors ``[sx, sy, sz]``. A two-element sequence
                ``[sx, syz]`` reuses the second factor for z.
            device: device handed to FlexiCubes and used for index tensors.
            renderer: object exposing ``render_mesh``; used by ``render_mesh``.
            render_type: only ``'neural_render'`` is supported.
            args: opaque configuration object, stored as-is.
        """
        super(FlexiCubesGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        self.fc = FlexiCubes(device, weight_scale=0.5)
        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
        if isinstance(scale, (list, tuple)):
            # The z axis previously reused scale[1], silently ignoring a third
            # factor. Use scale[2] when present; keep the old behavior for
            # two-element sequences so existing callers still work.
            z_scale = scale[2] if len(scale) > 2 else scale[1]
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * z_scale
        else:
            self.verts = self.verts * scale

        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
        self.all_edges = torch.unique(all_edges, dim=0)

        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        """Return the (min, max) corners of the axis-aligned box around the grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
        """Extract a mesh from deformed grid vertices and their SDF values.

        Args:
            v_deformed_nx3: deformed grid vertex positions.
            sdf_n: per-vertex signed distance values.
            weight_n: optional per-cube FlexiCubes weights; columns 0..11 are
                beta, 12..19 are alpha and column 20 is gamma. When ``None`` the
                three weights are passed to FlexiCubes as ``None``.
            with_uv: accepted for interface compatibility; unused here.
            indices: cube-to-vertex indices; defaults to the grid built in ``__init__``.
            is_training: forwarded to FlexiCubes as ``training``.

        Returns:
            (verts, faces, v_reg_loss) as returned by FlexiCubes.
        """
        if indices is None:
            indices = self.indices

        if weight_n is None:
            # Previously this path crashed with a TypeError on `None[:, :12]`.
            # NOTE(review): assumes FlexiCubes.__call__ treats None weights as
            # "use defaults" -- confirm against its signature.
            beta_fx12 = alpha_fx8 = gamma_f = None
        else:
            beta_fx12 = weight_n[:, :12]
            alpha_fx8 = weight_n[:, 12:20]
            gamma_f = weight_n[:, 20]

        verts, faces, v_reg_loss = self.fc(
            v_deformed_nx3, sdf_n, indices, self.grid_res,
            beta_fx12=beta_fx12, alpha_fx8=alpha_fx8,
            gamma_f=gamma_f, training=is_training
        )
        return verts, faces, v_reg_loss

    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        """Rasterize one mesh through ``self.renderer`` and return its buffers in a dict.

        Returns:
            dict with keys ``tex_pos``, ``mask``, ``hard_mask``, ``rast``,
            ``v_pos_clip``, ``mask_pyramid``, ``depth`` and ``normal``.

        Raises:
            NotImplementedError: if ``render_type`` is not ``'neural_render'``.
        """
        if self.render_type != 'neural_render':
            raise NotImplementedError

        # The vertex positions are passed twice: once as geometry and once as
        # the fourth positional argument of the renderer.
        tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
            mesh_v_nx3.unsqueeze(dim=0),
            mesh_f_fx3.int(),
            camera_mv_bx4x4,
            mesh_v_nx3.unsqueeze(dim=0),
            resolution=resolution,
            device=self.device,
            hierarchical_mask=hierarchical_mask
        )

        return {
            'tex_pos': tex_pos,
            'mask': mask,
            'hard_mask': hard_mask,
            'rast': rast,
            'v_pos_clip': v_pos_clip,
            'mask_pyramid': mask_pyramid,
            'depth': depth,
            'normal': normal,
        }

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        """Extract and render one mesh per batch element.

        Returns:
            dict mapping each ``render_mesh`` key to a list with one entry per
            batch element (concatenation is left to the caller); an empty dict
            for an empty batch.
        """
        # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
        all_render_output = []
        for i_batch in range(v_deformed_bxnx3.shape[0]):
            # get_mesh returns (verts, faces, reg_loss); the loss is not needed
            # for rendering. Unpacking only two values used to raise ValueError.
            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            all_render_output.append(
                self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution))

        if not all_render_output:
            return {}

        # Regroup per key; we can do concatenation outside of the render.
        return {k: [out[k] for out in all_render_output] for k in all_render_output[0]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/geometry/rep_3d/tables.py DELETED
@@ -1,791 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
- dmc_table = [
9
- [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
10
- [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
11
- [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
12
- [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
13
- [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
14
- [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
15
- [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
16
- [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
17
- [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
18
- [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
19
- [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
20
- [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
21
- [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
22
- [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
23
- [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
24
- [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
25
- [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
26
- [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
27
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
28
- [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
29
- [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
30
- [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
31
- [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
32
- [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
33
- [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
34
- [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
35
- [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
36
- [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
37
- [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
38
- [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
39
- [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
40
- [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
41
- [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
42
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
43
- [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
44
- [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
45
- [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
46
- [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
47
- [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
48
- [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
49
- [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
50
- [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
51
- [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
52
- [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
53
- [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
54
- [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
55
- [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
56
- [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
57
- [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
58
- [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
59
- [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
60
- [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
61
- [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
62
- [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
63
- [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
64
- [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
65
- [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
66
- [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
67
- [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
68
- [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
69
- [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
70
- [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
71
- [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
72
- [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
73
- [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
74
- [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
75
- [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
76
- [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
77
- [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
78
- [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
79
- [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
80
- [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
81
- [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
82
- [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
83
- [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
84
- [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
85
- [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
86
- [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
87
- [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
88
- [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
89
- [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
90
- [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
91
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
92
- [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
93
- [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
94
- [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
95
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
96
- [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
97
- [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
98
- [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
99
- [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
100
- [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
101
- [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
102
- [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
103
- [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
104
- [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
105
- [[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
106
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
107
- [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
108
- [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
109
- [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
110
- [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
111
- [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
112
- [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
113
- [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
114
- [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
115
- [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
116
- [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
117
- [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
118
- [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
119
- [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
120
- [[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
121
- [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
122
- [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
123
- [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
124
- [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
125
- [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
126
- [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
127
- [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
128
- [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
129
- [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
130
- [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
131
- [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
132
- [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
133
- [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
134
- [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
135
- [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
136
- [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
137
- [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
138
- [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
139
- [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
140
- [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
141
- [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
142
- [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
143
- [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
144
- [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
145
- [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
146
- [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
147
- [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
148
- [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
149
- [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
150
- [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
151
- [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
152
- [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
153
- [[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
154
- [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
155
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
156
- [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
157
- [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
158
- [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
159
- [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
160
- [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
161
- [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
162
- [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
163
- [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
164
- [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
165
- [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
166
- [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
167
- [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
168
- [[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
169
- [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
170
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
171
- [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
172
- [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
173
- [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
174
- [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
175
- [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
176
- [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
177
- [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
178
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
179
- [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
180
- [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
181
- [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
182
- [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
183
- [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
184
- [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
185
- [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
186
- [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
187
- [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
188
- [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
189
- [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
190
- [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
191
- [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
192
- [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
193
- [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
194
- [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
195
- [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
196
- [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
197
- [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
198
- [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
199
- [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
200
- [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
201
- [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
202
- [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
203
- [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
204
- [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
205
- [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
206
- [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
207
- [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
208
- [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
209
- [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
210
- [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
211
- [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
212
- [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
213
- [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
214
- [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
215
- [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
216
- [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
217
- [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
218
- [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
219
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
220
- [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
221
- [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
222
- [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
223
- [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
224
- [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
225
- [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
226
- [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
227
- [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
228
- [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
229
- [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
230
- [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
231
- [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
232
- [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
233
- [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
234
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
235
- [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
236
- [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
237
- [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
238
- [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
239
- [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
240
- [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
241
- [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
242
- [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
243
- [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
244
- [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
245
- [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
246
- [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
247
- [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
248
- [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
249
- [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
250
- [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
251
- [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
252
- [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
253
- [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
254
- [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
255
- [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
256
- [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
257
- [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
258
- [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
259
- [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
260
- [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
261
- [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
262
- [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
263
- [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
264
- [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
265
- ]
266
- num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
267
- 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
268
- 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
269
- 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
270
- 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
271
- 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
272
- 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
273
- 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
274
- 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
275
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
276
- check_table = [
277
- [0, 0, 0, 0, 0],
278
- [0, 0, 0, 0, 0],
279
- [0, 0, 0, 0, 0],
280
- [0, 0, 0, 0, 0],
281
- [0, 0, 0, 0, 0],
282
- [0, 0, 0, 0, 0],
283
- [0, 0, 0, 0, 0],
284
- [0, 0, 0, 0, 0],
285
- [0, 0, 0, 0, 0],
286
- [0, 0, 0, 0, 0],
287
- [0, 0, 0, 0, 0],
288
- [0, 0, 0, 0, 0],
289
- [0, 0, 0, 0, 0],
290
- [0, 0, 0, 0, 0],
291
- [0, 0, 0, 0, 0],
292
- [0, 0, 0, 0, 0],
293
- [0, 0, 0, 0, 0],
294
- [0, 0, 0, 0, 0],
295
- [0, 0, 0, 0, 0],
296
- [0, 0, 0, 0, 0],
297
- [0, 0, 0, 0, 0],
298
- [0, 0, 0, 0, 0],
299
- [0, 0, 0, 0, 0],
300
- [0, 0, 0, 0, 0],
301
- [0, 0, 0, 0, 0],
302
- [0, 0, 0, 0, 0],
303
- [0, 0, 0, 0, 0],
304
- [0, 0, 0, 0, 0],
305
- [0, 0, 0, 0, 0],
306
- [0, 0, 0, 0, 0],
307
- [0, 0, 0, 0, 0],
308
- [0, 0, 0, 0, 0],
309
- [0, 0, 0, 0, 0],
310
- [0, 0, 0, 0, 0],
311
- [0, 0, 0, 0, 0],
312
- [0, 0, 0, 0, 0],
313
- [0, 0, 0, 0, 0],
314
- [0, 0, 0, 0, 0],
315
- [0, 0, 0, 0, 0],
316
- [0, 0, 0, 0, 0],
317
- [0, 0, 0, 0, 0],
318
- [0, 0, 0, 0, 0],
319
- [0, 0, 0, 0, 0],
320
- [0, 0, 0, 0, 0],
321
- [0, 0, 0, 0, 0],
322
- [0, 0, 0, 0, 0],
323
- [0, 0, 0, 0, 0],
324
- [0, 0, 0, 0, 0],
325
- [0, 0, 0, 0, 0],
326
- [0, 0, 0, 0, 0],
327
- [0, 0, 0, 0, 0],
328
- [0, 0, 0, 0, 0],
329
- [0, 0, 0, 0, 0],
330
- [0, 0, 0, 0, 0],
331
- [0, 0, 0, 0, 0],
332
- [0, 0, 0, 0, 0],
333
- [0, 0, 0, 0, 0],
334
- [0, 0, 0, 0, 0],
335
- [0, 0, 0, 0, 0],
336
- [0, 0, 0, 0, 0],
337
- [0, 0, 0, 0, 0],
338
- [1, 1, 0, 0, 194],
339
- [1, -1, 0, 0, 193],
340
- [0, 0, 0, 0, 0],
341
- [0, 0, 0, 0, 0],
342
- [0, 0, 0, 0, 0],
343
- [0, 0, 0, 0, 0],
344
- [0, 0, 0, 0, 0],
345
- [0, 0, 0, 0, 0],
346
- [0, 0, 0, 0, 0],
347
- [0, 0, 0, 0, 0],
348
- [0, 0, 0, 0, 0],
349
- [0, 0, 0, 0, 0],
350
- [0, 0, 0, 0, 0],
351
- [0, 0, 0, 0, 0],
352
- [0, 0, 0, 0, 0],
353
- [0, 0, 0, 0, 0],
354
- [0, 0, 0, 0, 0],
355
- [0, 0, 0, 0, 0],
356
- [0, 0, 0, 0, 0],
357
- [0, 0, 0, 0, 0],
358
- [0, 0, 0, 0, 0],
359
- [0, 0, 0, 0, 0],
360
- [0, 0, 0, 0, 0],
361
- [0, 0, 0, 0, 0],
362
- [0, 0, 0, 0, 0],
363
- [0, 0, 0, 0, 0],
364
- [0, 0, 0, 0, 0],
365
- [0, 0, 0, 0, 0],
366
- [0, 0, 0, 0, 0],
367
- [0, 0, 0, 0, 0],
368
- [1, 0, 1, 0, 164],
369
- [0, 0, 0, 0, 0],
370
- [0, 0, 0, 0, 0],
371
- [1, 0, -1, 0, 161],
372
- [0, 0, 0, 0, 0],
373
- [0, 0, 0, 0, 0],
374
- [0, 0, 0, 0, 0],
375
- [0, 0, 0, 0, 0],
376
- [0, 0, 0, 0, 0],
377
- [0, 0, 0, 0, 0],
378
- [0, 0, 0, 0, 0],
379
- [0, 0, 0, 0, 0],
380
- [1, 0, 0, 1, 152],
381
- [0, 0, 0, 0, 0],
382
- [0, 0, 0, 0, 0],
383
- [0, 0, 0, 0, 0],
384
- [0, 0, 0, 0, 0],
385
- [0, 0, 0, 0, 0],
386
- [0, 0, 0, 0, 0],
387
- [1, 0, 0, 1, 145],
388
- [1, 0, 0, 1, 144],
389
- [0, 0, 0, 0, 0],
390
- [0, 0, 0, 0, 0],
391
- [0, 0, 0, 0, 0],
392
- [0, 0, 0, 0, 0],
393
- [0, 0, 0, 0, 0],
394
- [0, 0, 0, 0, 0],
395
- [1, 0, 0, -1, 137],
396
- [0, 0, 0, 0, 0],
397
- [0, 0, 0, 0, 0],
398
- [0, 0, 0, 0, 0],
399
- [1, 0, 1, 0, 133],
400
- [1, 0, 1, 0, 132],
401
- [1, 1, 0, 0, 131],
402
- [1, 1, 0, 0, 130],
403
- [0, 0, 0, 0, 0],
404
- [0, 0, 0, 0, 0],
405
- [0, 0, 0, 0, 0],
406
- [0, 0, 0, 0, 0],
407
- [0, 0, 0, 0, 0],
408
- [0, 0, 0, 0, 0],
409
- [0, 0, 0, 0, 0],
410
- [0, 0, 0, 0, 0],
411
- [0, 0, 0, 0, 0],
412
- [0, 0, 0, 0, 0],
413
- [0, 0, 0, 0, 0],
414
- [0, 0, 0, 0, 0],
415
- [0, 0, 0, 0, 0],
416
- [0, 0, 0, 0, 0],
417
- [0, 0, 0, 0, 0],
418
- [0, 0, 0, 0, 0],
419
- [0, 0, 0, 0, 0],
420
- [0, 0, 0, 0, 0],
421
- [0, 0, 0, 0, 0],
422
- [0, 0, 0, 0, 0],
423
- [0, 0, 0, 0, 0],
424
- [0, 0, 0, 0, 0],
425
- [0, 0, 0, 0, 0],
426
- [0, 0, 0, 0, 0],
427
- [0, 0, 0, 0, 0],
428
- [0, 0, 0, 0, 0],
429
- [0, 0, 0, 0, 0],
430
- [0, 0, 0, 0, 0],
431
- [0, 0, 0, 0, 0],
432
- [1, 0, 0, 1, 100],
433
- [0, 0, 0, 0, 0],
434
- [1, 0, 0, 1, 98],
435
- [0, 0, 0, 0, 0],
436
- [1, 0, 0, 1, 96],
437
- [0, 0, 0, 0, 0],
438
- [0, 0, 0, 0, 0],
439
- [0, 0, 0, 0, 0],
440
- [0, 0, 0, 0, 0],
441
- [0, 0, 0, 0, 0],
442
- [0, 0, 0, 0, 0],
443
- [0, 0, 0, 0, 0],
444
- [1, 0, 1, 0, 88],
445
- [0, 0, 0, 0, 0],
446
- [0, 0, 0, 0, 0],
447
- [0, 0, 0, 0, 0],
448
- [0, 0, 0, 0, 0],
449
- [0, 0, 0, 0, 0],
450
- [1, 0, -1, 0, 82],
451
- [0, 0, 0, 0, 0],
452
- [0, 0, 0, 0, 0],
453
- [0, 0, 0, 0, 0],
454
- [0, 0, 0, 0, 0],
455
- [0, 0, 0, 0, 0],
456
- [0, 0, 0, 0, 0],
457
- [0, 0, 0, 0, 0],
458
- [1, 0, 1, 0, 74],
459
- [0, 0, 0, 0, 0],
460
- [1, 0, 1, 0, 72],
461
- [0, 0, 0, 0, 0],
462
- [1, 0, 0, -1, 70],
463
- [0, 0, 0, 0, 0],
464
- [0, 0, 0, 0, 0],
465
- [1, -1, 0, 0, 67],
466
- [0, 0, 0, 0, 0],
467
- [1, -1, 0, 0, 65],
468
- [0, 0, 0, 0, 0],
469
- [0, 0, 0, 0, 0],
470
- [0, 0, 0, 0, 0],
471
- [0, 0, 0, 0, 0],
472
- [0, 0, 0, 0, 0],
473
- [0, 0, 0, 0, 0],
474
- [0, 0, 0, 0, 0],
475
- [0, 0, 0, 0, 0],
476
- [1, 1, 0, 0, 56],
477
- [0, 0, 0, 0, 0],
478
- [0, 0, 0, 0, 0],
479
- [0, 0, 0, 0, 0],
480
- [1, -1, 0, 0, 52],
481
- [0, 0, 0, 0, 0],
482
- [0, 0, 0, 0, 0],
483
- [0, 0, 0, 0, 0],
484
- [0, 0, 0, 0, 0],
485
- [0, 0, 0, 0, 0],
486
- [0, 0, 0, 0, 0],
487
- [0, 0, 0, 0, 0],
488
- [1, 1, 0, 0, 44],
489
- [0, 0, 0, 0, 0],
490
- [0, 0, 0, 0, 0],
491
- [0, 0, 0, 0, 0],
492
- [1, 1, 0, 0, 40],
493
- [0, 0, 0, 0, 0],
494
- [1, 0, 0, -1, 38],
495
- [1, 0, -1, 0, 37],
496
- [0, 0, 0, 0, 0],
497
- [0, 0, 0, 0, 0],
498
- [0, 0, 0, 0, 0],
499
- [1, 0, -1, 0, 33],
500
- [0, 0, 0, 0, 0],
501
- [0, 0, 0, 0, 0],
502
- [0, 0, 0, 0, 0],
503
- [0, 0, 0, 0, 0],
504
- [1, -1, 0, 0, 28],
505
- [0, 0, 0, 0, 0],
506
- [1, 0, -1, 0, 26],
507
- [1, 0, 0, -1, 25],
508
- [0, 0, 0, 0, 0],
509
- [0, 0, 0, 0, 0],
510
- [0, 0, 0, 0, 0],
511
- [0, 0, 0, 0, 0],
512
- [1, -1, 0, 0, 20],
513
- [0, 0, 0, 0, 0],
514
- [1, 0, -1, 0, 18],
515
- [0, 0, 0, 0, 0],
516
- [0, 0, 0, 0, 0],
517
- [0, 0, 0, 0, 0],
518
- [0, 0, 0, 0, 0],
519
- [0, 0, 0, 0, 0],
520
- [0, 0, 0, 0, 0],
521
- [0, 0, 0, 0, 0],
522
- [0, 0, 0, 0, 0],
523
- [1, 0, 0, -1, 9],
524
- [0, 0, 0, 0, 0],
525
- [0, 0, 0, 0, 0],
526
- [1, 0, 0, -1, 6],
527
- [0, 0, 0, 0, 0],
528
- [0, 0, 0, 0, 0],
529
- [0, 0, 0, 0, 0],
530
- [0, 0, 0, 0, 0],
531
- [0, 0, 0, 0, 0],
532
- [0, 0, 0, 0, 0]
533
- ]
534
- tet_table = [
535
- [-1, -1, -1, -1, -1, -1],
536
- [0, 0, 0, 0, 0, 0],
537
- [0, 0, 0, 0, 0, 0],
538
- [1, 1, 1, 1, 1, 1],
539
- [4, 4, 4, 4, 4, 4],
540
- [0, 0, 0, 0, 0, 0],
541
- [4, 0, 0, 4, 4, -1],
542
- [1, 1, 1, 1, 1, 1],
543
- [4, 4, 4, 4, 4, 4],
544
- [0, 4, 0, 4, 4, -1],
545
- [0, 0, 0, 0, 0, 0],
546
- [1, 1, 1, 1, 1, 1],
547
- [5, 5, 5, 5, 5, 5],
548
- [0, 0, 0, 0, 0, 0],
549
- [0, 0, 0, 0, 0, 0],
550
- [1, 1, 1, 1, 1, 1],
551
- [2, 2, 2, 2, 2, 2],
552
- [0, 0, 0, 0, 0, 0],
553
- [2, 0, 2, -1, 0, 2],
554
- [1, 1, 1, 1, 1, 1],
555
- [2, -1, 2, 4, 4, 2],
556
- [0, 0, 0, 0, 0, 0],
557
- [2, 0, 2, 4, 4, 2],
558
- [1, 1, 1, 1, 1, 1],
559
- [2, 4, 2, 4, 4, 2],
560
- [0, 4, 0, 4, 4, 0],
561
- [2, 0, 2, 0, 0, 2],
562
- [1, 1, 1, 1, 1, 1],
563
- [2, 5, 2, 5, 5, 2],
564
- [0, 0, 0, 0, 0, 0],
565
- [2, 0, 2, 0, 0, 2],
566
- [1, 1, 1, 1, 1, 1],
567
- [1, 1, 1, 1, 1, 1],
568
- [0, 1, 1, -1, 0, 1],
569
- [0, 0, 0, 0, 0, 0],
570
- [2, 2, 2, 2, 2, 2],
571
- [4, 1, 1, 4, 4, 1],
572
- [0, 1, 1, 0, 0, 1],
573
- [4, 0, 0, 4, 4, 0],
574
- [2, 2, 2, 2, 2, 2],
575
- [-1, 1, 1, 4, 4, 1],
576
- [0, 1, 1, 4, 4, 1],
577
- [0, 0, 0, 0, 0, 0],
578
- [2, 2, 2, 2, 2, 2],
579
- [5, 1, 1, 5, 5, 1],
580
- [0, 1, 1, 0, 0, 1],
581
- [0, 0, 0, 0, 0, 0],
582
- [2, 2, 2, 2, 2, 2],
583
- [1, 1, 1, 1, 1, 1],
584
- [0, 0, 0, 0, 0, 0],
585
- [0, 0, 0, 0, 0, 0],
586
- [8, 8, 8, 8, 8, 8],
587
- [1, 1, 1, 4, 4, 1],
588
- [0, 0, 0, 0, 0, 0],
589
- [4, 0, 0, 4, 4, 0],
590
- [4, 4, 4, 4, 4, 4],
591
- [1, 1, 1, 4, 4, 1],
592
- [0, 4, 0, 4, 4, 0],
593
- [0, 0, 0, 0, 0, 0],
594
- [4, 4, 4, 4, 4, 4],
595
- [1, 1, 1, 5, 5, 1],
596
- [0, 0, 0, 0, 0, 0],
597
- [0, 0, 0, 0, 0, 0],
598
- [5, 5, 5, 5, 5, 5],
599
- [6, 6, 6, 6, 6, 6],
600
- [6, -1, 0, 6, 0, 6],
601
- [6, 0, 0, 6, 0, 6],
602
- [6, 1, 1, 6, 1, 6],
603
- [4, 4, 4, 4, 4, 4],
604
- [0, 0, 0, 0, 0, 0],
605
- [4, 0, 0, 4, 4, 4],
606
- [1, 1, 1, 1, 1, 1],
607
- [6, 4, -1, 6, 4, 6],
608
- [6, 4, 0, 6, 4, 6],
609
- [6, 0, 0, 6, 0, 6],
610
- [6, 1, 1, 6, 1, 6],
611
- [5, 5, 5, 5, 5, 5],
612
- [0, 0, 0, 0, 0, 0],
613
- [0, 0, 0, 0, 0, 0],
614
- [1, 1, 1, 1, 1, 1],
615
- [2, 2, 2, 2, 2, 2],
616
- [0, 0, 0, 0, 0, 0],
617
- [2, 0, 2, 2, 0, 2],
618
- [1, 1, 1, 1, 1, 1],
619
- [2, 2, 2, 2, 2, 2],
620
- [0, 0, 0, 0, 0, 0],
621
- [2, 0, 2, 2, 2, 2],
622
- [1, 1, 1, 1, 1, 1],
623
- [2, 4, 2, 2, 4, 2],
624
- [0, 4, 0, 4, 4, 0],
625
- [2, 0, 2, 2, 0, 2],
626
- [1, 1, 1, 1, 1, 1],
627
- [2, 2, 2, 2, 2, 2],
628
- [0, 0, 0, 0, 0, 0],
629
- [0, 0, 0, 0, 0, 0],
630
- [1, 1, 1, 1, 1, 1],
631
- [6, 1, 1, 6, -1, 6],
632
- [6, 1, 1, 6, 0, 6],
633
- [6, 0, 0, 6, 0, 6],
634
- [6, 2, 2, 6, 2, 6],
635
- [4, 1, 1, 4, 4, 1],
636
- [0, 1, 1, 0, 0, 1],
637
- [4, 0, 0, 4, 4, 4],
638
- [2, 2, 2, 2, 2, 2],
639
- [6, 1, 1, 6, 4, 6],
640
- [6, 1, 1, 6, 4, 6],
641
- [6, 0, 0, 6, 0, 6],
642
- [6, 2, 2, 6, 2, 6],
643
- [5, 1, 1, 5, 5, 1],
644
- [0, 1, 1, 0, 0, 1],
645
- [0, 0, 0, 0, 0, 0],
646
- [2, 2, 2, 2, 2, 2],
647
- [1, 1, 1, 1, 1, 1],
648
- [0, 0, 0, 0, 0, 0],
649
- [0, 0, 0, 0, 0, 0],
650
- [6, 6, 6, 6, 6, 6],
651
- [1, 1, 1, 1, 1, 1],
652
- [0, 0, 0, 0, 0, 0],
653
- [0, 0, 0, 0, 0, 0],
654
- [4, 4, 4, 4, 4, 4],
655
- [1, 1, 1, 1, 4, 1],
656
- [0, 4, 0, 4, 4, 0],
657
- [0, 0, 0, 0, 0, 0],
658
- [4, 4, 4, 4, 4, 4],
659
- [1, 1, 1, 1, 1, 1],
660
- [0, 0, 0, 0, 0, 0],
661
- [0, 5, 0, 5, 0, 5],
662
- [5, 5, 5, 5, 5, 5],
663
- [5, 5, 5, 5, 5, 5],
664
- [0, 5, 0, 5, 0, 5],
665
- [-1, 5, 0, 5, 0, 5],
666
- [1, 5, 1, 5, 1, 5],
667
- [4, 5, -1, 5, 4, 5],
668
- [0, 5, 0, 5, 0, 5],
669
- [4, 5, 0, 5, 4, 5],
670
- [1, 5, 1, 5, 1, 5],
671
- [4, 4, 4, 4, 4, 4],
672
- [0, 4, 0, 4, 4, 4],
673
- [0, 0, 0, 0, 0, 0],
674
- [1, 1, 1, 1, 1, 1],
675
- [6, 6, 6, 6, 6, 6],
676
- [0, 0, 0, 0, 0, 0],
677
- [0, 0, 0, 0, 0, 0],
678
- [1, 1, 1, 1, 1, 1],
679
- [2, 5, 2, 5, -1, 5],
680
- [0, 5, 0, 5, 0, 5],
681
- [2, 5, 2, 5, 0, 5],
682
- [1, 5, 1, 5, 1, 5],
683
- [2, 5, 2, 5, 4, 5],
684
- [0, 5, 0, 5, 0, 5],
685
- [2, 5, 2, 5, 4, 5],
686
- [1, 5, 1, 5, 1, 5],
687
- [2, 4, 2, 4, 4, 2],
688
- [0, 4, 0, 4, 4, 4],
689
- [2, 0, 2, 0, 0, 2],
690
- [1, 1, 1, 1, 1, 1],
691
- [2, 6, 2, 6, 6, 2],
692
- [0, 0, 0, 0, 0, 0],
693
- [2, 0, 2, 0, 0, 2],
694
- [1, 1, 1, 1, 1, 1],
695
- [1, 1, 1, 1, 1, 1],
696
- [0, 1, 1, 1, 0, 1],
697
- [0, 0, 0, 0, 0, 0],
698
- [2, 2, 2, 2, 2, 2],
699
- [4, 1, 1, 1, 4, 1],
700
- [0, 1, 1, 1, 0, 1],
701
- [4, 0, 0, 4, 4, 0],
702
- [2, 2, 2, 2, 2, 2],
703
- [1, 1, 1, 1, 1, 1],
704
- [0, 1, 1, 1, 1, 1],
705
- [0, 0, 0, 0, 0, 0],
706
- [2, 2, 2, 2, 2, 2],
707
- [1, 1, 1, 1, 1, 1],
708
- [0, 0, 0, 0, 0, 0],
709
- [0, 0, 0, 0, 0, 0],
710
- [2, 2, 2, 2, 2, 2],
711
- [1, 1, 1, 1, 1, 1],
712
- [0, 0, 0, 0, 0, 0],
713
- [0, 0, 0, 0, 0, 0],
714
- [5, 5, 5, 5, 5, 5],
715
- [1, 1, 1, 1, 4, 1],
716
- [0, 0, 0, 0, 0, 0],
717
- [4, 0, 0, 4, 4, 0],
718
- [4, 4, 4, 4, 4, 4],
719
- [1, 1, 1, 1, 1, 1],
720
- [0, 0, 0, 0, 0, 0],
721
- [0, 0, 0, 0, 0, 0],
722
- [4, 4, 4, 4, 4, 4],
723
- [1, 1, 1, 1, 1, 1],
724
- [6, 0, 0, 6, 0, 6],
725
- [0, 0, 0, 0, 0, 0],
726
- [6, 6, 6, 6, 6, 6],
727
- [5, 5, 5, 5, 5, 5],
728
- [5, 5, 0, 5, 0, 5],
729
- [5, 5, 0, 5, 0, 5],
730
- [5, 5, 1, 5, 1, 5],
731
- [4, 4, 4, 4, 4, 4],
732
- [0, 0, 0, 0, 0, 0],
733
- [4, 4, 0, 4, 4, 4],
734
- [1, 1, 1, 1, 1, 1],
735
- [4, 4, 4, 4, 4, 4],
736
- [4, 4, 0, 4, 4, 4],
737
- [0, 0, 0, 0, 0, 0],
738
- [1, 1, 1, 1, 1, 1],
739
- [8, 8, 8, 8, 8, 8],
740
- [0, 0, 0, 0, 0, 0],
741
- [0, 0, 0, 0, 0, 0],
742
- [1, 1, 1, 1, 1, 1],
743
- [2, 2, 2, 2, 2, 2],
744
- [0, 0, 0, 0, 0, 0],
745
- [2, 2, 2, 2, 0, 2],
746
- [1, 1, 1, 1, 1, 1],
747
- [2, 2, 2, 2, 2, 2],
748
- [0, 0, 0, 0, 0, 0],
749
- [2, 2, 2, 2, 2, 2],
750
- [1, 1, 1, 1, 1, 1],
751
- [2, 2, 2, 2, 2, 2],
752
- [0, 0, 0, 0, 0, 0],
753
- [0, 0, 0, 0, 0, 0],
754
- [4, 1, 1, 4, 4, 1],
755
- [2, 2, 2, 2, 2, 2],
756
- [0, 0, 0, 0, 0, 0],
757
- [0, 0, 0, 0, 0, 0],
758
- [1, 1, 1, 1, 1, 1],
759
- [1, 1, 1, 1, 1, 1],
760
- [1, 1, 1, 1, 0, 1],
761
- [0, 0, 0, 0, 0, 0],
762
- [2, 2, 2, 2, 2, 2],
763
- [1, 1, 1, 1, 1, 1],
764
- [0, 0, 0, 0, 0, 0],
765
- [0, 0, 0, 0, 0, 0],
766
- [2, 4, 2, 4, 4, 2],
767
- [1, 1, 1, 1, 1, 1],
768
- [1, 1, 1, 1, 1, 1],
769
- [0, 0, 0, 0, 0, 0],
770
- [2, 2, 2, 2, 2, 2],
771
- [1, 1, 1, 1, 1, 1],
772
- [0, 0, 0, 0, 0, 0],
773
- [0, 0, 0, 0, 0, 0],
774
- [2, 2, 2, 2, 2, 2],
775
- [1, 1, 1, 1, 1, 1],
776
- [0, 0, 0, 0, 0, 0],
777
- [0, 0, 0, 0, 0, 0],
778
- [5, 5, 5, 5, 5, 5],
779
- [1, 1, 1, 1, 1, 1],
780
- [0, 0, 0, 0, 0, 0],
781
- [0, 0, 0, 0, 0, 0],
782
- [4, 4, 4, 4, 4, 4],
783
- [1, 1, 1, 1, 1, 1],
784
- [0, 0, 0, 0, 0, 0],
785
- [0, 0, 0, 0, 0, 0],
786
- [4, 4, 4, 4, 4, 4],
787
- [1, 1, 1, 1, 1, 1],
788
- [0, 0, 0, 0, 0, 0],
789
- [0, 0, 0, 0, 0, 0],
790
- [12, 12, 12, 12, 12, 12]
791
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/lrm.py DELETED
@@ -1,196 +0,0 @@
1
- # Copyright (c) 2023, Zexin He
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn as nn
18
- import mcubes
19
- import nvdiffrast.torch as dr
20
- from einops import rearrange, repeat
21
-
22
- from .encoder.dino_wrapper import DinoWrapper
23
- from .decoder.transformer import TriplaneTransformer
24
- from .renderer.synthesizer import TriplaneSynthesizer
25
- from ..utils.mesh_util import xatlas_uvmap
26
-
27
-
28
class InstantNeRF(nn.Module):
    """
    Full model of the large reconstruction model.

    Chains three sub-modules: an image encoder (``DinoWrapper``), a triplane
    decoder (``TriplaneTransformer``) and a triplane renderer
    (``TriplaneSynthesizer``).
    """
    def __init__(
        self,
        encoder_freeze: bool = False,
        encoder_model_name: str = 'facebook/dino-vitb16',
        encoder_feat_dim: int = 768,
        transformer_dim: int = 1024,
        transformer_layers: int = 16,
        transformer_heads: int = 16,
        triplane_low_res: int = 32,
        triplane_high_res: int = 64,
        triplane_dim: int = 80,
        rendering_samples_per_ray: int = 128,
    ):
        super().__init__()

        # modules
        self.encoder = DinoWrapper(
            model_name=encoder_model_name,
            freeze=encoder_freeze,
        )

        self.transformer = TriplaneTransformer(
            inner_dim=transformer_dim,
            num_layers=transformer_layers,
            num_heads=transformer_heads,
            image_feat_dim=encoder_feat_dim,
            triplane_low_res=triplane_low_res,
            triplane_high_res=triplane_high_res,
            triplane_dim=triplane_dim,
        )

        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=triplane_dim,
            samples_per_ray=rendering_samples_per_ray,
        )

    def forward_planes(self, images, cameras):
        '''
        Encode the input views and decode them into triplane features.
        :param images: input views, [B, V, C_img, H_img, W_img]
        :param cameras: camera parameters of the input views, [B, V, 16]
        :return: the triplane features produced by the transformer
        '''
        B = images.shape[0]

        # encode images
        image_feats = self.encoder(images, cameras)
        # merge the per-view token sequences of each sample into one sequence
        image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)

        # transformer generating planes
        planes = self.transformer(image_feats)

        return planes

    def forward(self, images, cameras, render_cameras, render_size: int):
        '''
        Predict triplanes from the input views and render the target views.
        :param images: input views, [B, V, C_img, H_img, W_img]
        :param cameras: camera parameters of the input views, [B, V, 16]
        :param render_cameras: cameras to render, [B, M, D_cam_render]
        :param render_size: resolution passed to the synthesizer
        :return: dict with key 'planes' plus every entry returned by the synthesizer
        '''
        planes = self.forward_planes(images, cameras)

        # render target views
        render_results = self.synthesizer(planes, render_cameras, render_size)

        return {
            'planes': planes,
            **render_results,
        }

    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
        '''
        Predict Texture given triplanes
        :param planes: the triplane feature map
        :param tex_pos: Position we want to query the texture field
            (a list of tensors that is concatenated along dim 0)
        :param hard_mask: 2D silhouette of the rendered image; when given, only
            the positions inside the mask are sent to the texture field and the
            remaining output entries are zero
        :return: texture features laid out like the query positions, with the
            trailing xyz axis replaced by the feature axis
        '''
        tex_pos = torch.cat(tex_pos, dim=0)
        if hard_mask is not None:
            tex_pos = tex_pos * hard_mask.float()
        # Remember the un-flattened layout so the mask-free path can restore it.
        # NOTE(review): assumes the last axis of tex_pos holds the xyz coordinates.
        query_shape = tex_pos.shape[:-1]
        batch_size = tex_pos.shape[0]
        tex_pos = tex_pos.reshape(batch_size, -1, 3)
        ###################
        # We use mask to get the texture location (to save the memory)
        if hard_mask is not None:
            n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
            sample_tex_pose_list = []
            max_point = int(n_point_list.max())
            expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
            for i in range(tex_pos.shape[0]):
                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
                # zero-pad so every sample contributes the same number of points
                if tex_pos_one_shape.shape[1] < max_point:
                    tex_pos_one_shape = torch.cat(
                        [tex_pos_one_shape, torch.zeros(
                            1, max_point - tex_pos_one_shape.shape[1], 3,
                            device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
                sample_tex_pose_list.append(tex_pos_one_shape)
            tex_pos = torch.cat(sample_tex_pose_list, dim=0)

        tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb']

        if hard_mask is None:
            # No mask: every query position was evaluated, so only the original
            # layout has to be restored (previously this path dereferenced
            # ``hard_mask.shape`` and crashed).
            return tex_feat.reshape(*query_shape, tex_feat.shape[-1])

        # Scatter the compacted predictions back to their masked pixel locations.
        final_tex_feat = torch.zeros(
            planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
        expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
        for i in range(planes.shape[0]):
            # drop the padding rows before writing back
            final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
        tex_feat = final_tex_feat

        return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])

    def extract_mesh(
        self,
        planes: torch.Tensor,
        mesh_resolution: int = 256,
        mesh_threshold: float = 10.0,
        use_texture_map: bool = False,
        texture_resolution: int = 1024,
        **kwargs,
    ):
        '''
        Extract a 3D mesh from triplane nerf. Only support batch_size 1.
        :param planes: triplane features
        :param mesh_resolution: marching cubes resolution
        :param mesh_threshold: iso-surface threshold
        :param use_texture_map: use texture map or vertex color
        :param texture_resolution: the resolution of texture map
        :return: (vertices, faces, vertices_colors) when ``use_texture_map`` is
            False, otherwise (vertices, faces, uvs, mesh_tex_idx, texture_map)
        '''
        assert planes.shape[0] == 1
        device = planes.device

        grid_out = self.synthesizer.forward_grid(
            planes=planes,
            grid_size=mesh_resolution,
        )

        vertices, faces = mcubes.marching_cubes(
            grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
            mesh_threshold,
        )
        # rescale marching-cubes grid coordinates (0 .. mesh_resolution - 1) to [-1, 1]
        vertices = vertices / (mesh_resolution - 1) * 2 - 1

        if not use_texture_map:
            # query vertex colors
            vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
            vertices_colors = self.synthesizer.forward_points(
                planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy()
            # clip before the cast so values outside [0, 1] cannot wrap around in uint8
            vertices_colors = np.clip(vertices_colors * 255, 0, 255).astype(np.uint8)

            return vertices, faces, vertices_colors

        # use x-atlas to get uv mapping for the mesh
        vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
        faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)

        ctx = dr.RasterizeCudaContext(device=device)
        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
            ctx, vertices, faces, resolution=texture_resolution)
        tex_hard_mask = tex_hard_mask.float()

        # query the texture field to get the RGB color for texture map
        tex_feat = self.get_texture_prediction(
            planes, [gb_pos], tex_hard_mask)
        # texels outside the mask stay at the zero background
        background_feature = torch.zeros_like(tex_feat)
        img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
        texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)

        return vertices, faces, uvs, mesh_tex_idx, texture_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/lrm_mesh.py DELETED
@@ -1,385 +0,0 @@
1
- # Copyright (c) 2023, Tencent Inc
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn as nn
18
- import nvdiffrast.torch as dr
19
- from einops import rearrange, repeat
20
-
21
- from .encoder.dino_wrapper import DinoWrapper
22
- from .decoder.transformer import TriplaneTransformer
23
- from .renderer.synthesizer_mesh import TriplaneSynthesizer
24
- from .geometry.camera.perspective_camera import PerspectiveCamera
25
- from .geometry.render.neural_render import NeuralRender
26
- from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
27
- from ..utils.mesh_util import xatlas_uvmap
28
-
29
-
30
- class InstantMesh(nn.Module):
31
- """
32
- Full model of the large reconstruction model.
33
- """
34
def __init__(
    self,
    encoder_freeze: bool = False,
    encoder_model_name: str = 'facebook/dino-vitb16',
    encoder_feat_dim: int = 768,
    transformer_dim: int = 1024,
    transformer_layers: int = 16,
    transformer_heads: int = 16,
    triplane_low_res: int = 32,
    triplane_high_res: int = 64,
    triplane_dim: int = 80,
    rendering_samples_per_ray: int = 128,
    grid_res: int = 128,
    grid_scale: float = 2.0,
):
    '''
    Build the encoder, the triplane transformer and the triplane synthesizer.
    :param encoder_freeze: forwarded to DinoWrapper as ``freeze``
    :param encoder_model_name: forwarded to DinoWrapper as ``model_name``
    :param encoder_feat_dim: image feature dimension given to the transformer
    :param transformer_dim: inner dimension of the transformer
    :param transformer_layers: number of transformer layers
    :param transformer_heads: number of attention heads
    :param triplane_low_res: low triplane resolution given to the transformer
    :param triplane_high_res: high triplane resolution given to the transformer
    :param triplane_dim: triplane channel count (transformer and synthesizer)
    :param rendering_samples_per_ray: forwarded to the synthesizer
    :param grid_res: stored as ``self.grid_res``
    :param grid_scale: stored as ``self.grid_scale``
    '''
    super().__init__()

    # plain attributes, read later when the geometry is built / queried
    self.grid_res = grid_res
    self.grid_scale = grid_scale
    self.deformation_multiplier = 4.0

    # sub-modules, registered in the order encoder -> transformer -> synthesizer
    self.encoder = DinoWrapper(model_name=encoder_model_name, freeze=encoder_freeze)

    transformer_kwargs = dict(
        inner_dim=transformer_dim,
        num_layers=transformer_layers,
        num_heads=transformer_heads,
        image_feat_dim=encoder_feat_dim,
        triplane_low_res=triplane_low_res,
        triplane_high_res=triplane_high_res,
        triplane_dim=triplane_dim,
    )
    self.transformer = TriplaneTransformer(**transformer_kwargs)

    self.synthesizer = TriplaneSynthesizer(
        triplane_dim=triplane_dim,
        samples_per_ray=rendering_samples_per_ray,
    )
76
-
77
def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True):
    '''
    Create the FlexiCubes geometry and store it as ``self.geometry``.
    :param device: device handed to the camera, the renderer and the geometry
    :param fovy: field of view passed to the perspective camera
    :param use_renderer: when False the geometry is created with ``renderer=None``
    '''
    # the camera is always constructed, even when no renderer is requested
    camera = PerspectiveCamera(fovy=fovy, device=device)
    renderer = NeuralRender(device, camera_model=camera) if use_renderer else None

    geometry_kwargs = dict(
        grid_res=self.grid_res,
        scale=self.grid_scale,
        renderer=renderer,
        render_type='neural_render',
        device=device,
    )
    self.geometry = FlexiCubesGeometry(**geometry_kwargs)
90
-
91
def forward_planes(self, images, cameras):
    '''
    Encode the input views and decode them into triplane features.
    :param images: input views, [B, V, C_img, H_img, W_img]
    :param cameras: camera parameters of the input views, [B, V, 16]
    :return: the triplane features produced by the transformer
    '''
    batch = images.shape[0]

    # encoder output is stacked as (b v); fold the views of a sample into one sequence
    tokens = rearrange(
        self.encoder(images, cameras),
        '(b v) l d -> b (v l) d',
        b=batch,
    )

    # decode triplanes
    return self.transformer(tokens)
104
-
105
def get_sdf_deformation_prediction(self, planes):
    '''
    Predict SDF and deformation for the grid vertices
    :param planes: triplane feature map for the geometry
    :return: (sdf, deformation, sdf_reg_loss, weight)
    '''
    n_batch = planes.shape[0]
    init_position = self.geometry.verts.unsqueeze(0).expand(n_batch, -1, -1)

    # Step 1: predict the SDF and deformation (checkpointed call into the synthesizer)
    sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
        self.synthesizer.get_geometry_prediction,
        planes,
        init_position,
        self.geometry.indices,
        use_reentrant=False,
    )

    # Step 2: Normalize the deformation to avoid the flipped triangles.
    # tanh bounds the offset to +/- 1 / (grid_res * deformation_multiplier)
    deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
    sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)

    # Step 3: Fix some sdf if we observe empty shape (full positive or full negative)
    side = self.grid_res + 1
    sdf_grid = sdf.reshape((sdf.shape[0], side, side, side))
    interior = sdf_grid[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
    n_positive = torch.sum((interior > 0).int(), dim=-1)
    n_negative = torch.sum((interior < 0).int(), dim=-1)
    zero_surface = torch.bitwise_or(n_positive == 0, n_negative == 0)
    if zero_surface.any().item():
        update_sdf = torch.zeros_like(sdf[0:1])
        max_sdf = sdf.max()
        min_sdf = sdf.min()
        update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf)  # greater than zero
        update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf)  # smaller than zero
        new_sdf = torch.zeros_like(sdf)
        for i_batch, is_empty in enumerate(zero_surface):
            if is_empty:
                new_sdf[i_batch:i_batch + 1] += update_sdf
        update_mask = (new_sdf == 0).float()
        # Regularization here is used to push the sdf to be a different sign
        # (make it not fully positive or fully negative); it is computed on the
        # sdf before the override below is applied.
        sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
        sdf_reg_loss = sdf_reg_loss * zero_surface.float()
        sdf = sdf * update_mask + new_sdf * (1 - update_mask)

    # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative)
    rows = range(zero_surface.shape[0])
    sdf = torch.cat(
        [sdf[b: b + 1].detach() if zero_surface[b] else sdf[b: b + 1] for b in rows],
        dim=0)
    deformation = torch.cat(
        [deformation[b: b + 1].detach() if zero_surface[b] else deformation[b: b + 1] for b in rows],
        dim=0)
    return sdf, deformation, sdf_reg_loss, weight
161
-
162
- def get_geometry_prediction(self, planes=None):
163
- '''
164
- Function to generate mesh with give triplanes
165
- :param planes: triplane features
166
- '''
167
- # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid.
168
- sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
169
- v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
170
- tets = self.geometry.indices
171
- n_batch = planes.shape[0]
172
- v_list = []
173
- f_list = []
174
- flexicubes_surface_reg_list = []
175
-
176
- # Step 2: Using marching tet to obtain the mesh
177
- for i_batch in range(n_batch):
178
- verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
179
- v_deformed[i_batch],
180
- sdf[i_batch].squeeze(dim=-1),
181
- with_uv=False,
182
- indices=tets,
183
- weight_n=weight[i_batch].squeeze(dim=-1),
184
- is_training=self.training,
185
- )
186
- flexicubes_surface_reg_list.append(flexicubes_surface_reg)
187
- v_list.append(verts)
188
- f_list.append(faces)
189
-
190
- flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
191
- flexicubes_weight_reg = (weight ** 2).mean()
192
-
193
- return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
194
-
195
def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
    '''
    Predict Texture given triplanes
    :param planes: the triplane feature map
    :param tex_pos: Position we want to query the texture field
        (a list of tensors that is concatenated along dim 0)
    :param hard_mask: 2D silhouette of the rendered image; when given, only
        the positions inside the mask are sent to the texture field and the
        remaining output entries are zero
    :return: texture features laid out like the query positions, with the
        trailing xyz axis replaced by the feature axis
    '''
    tex_pos = torch.cat(tex_pos, dim=0)
    if hard_mask is not None:
        tex_pos = tex_pos * hard_mask.float()
    # Remember the un-flattened layout so the mask-free path can restore it.
    # NOTE(review): assumes the last axis of tex_pos holds the xyz coordinates.
    query_shape = tex_pos.shape[:-1]
    batch_size = tex_pos.shape[0]
    tex_pos = tex_pos.reshape(batch_size, -1, 3)
    ###################
    # We use mask to get the texture location (to save the memory)
    if hard_mask is not None:
        n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
        sample_tex_pose_list = []
        max_point = int(n_point_list.max())
        expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
        for i in range(tex_pos.shape[0]):
            tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
            # zero-pad so every sample contributes the same number of points
            if tex_pos_one_shape.shape[1] < max_point:
                tex_pos_one_shape = torch.cat(
                    [tex_pos_one_shape, torch.zeros(
                        1, max_point - tex_pos_one_shape.shape[1], 3,
                        device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
            sample_tex_pose_list.append(tex_pos_one_shape)
        tex_pos = torch.cat(sample_tex_pose_list, dim=0)

    # checkpointed call into the synthesizer's texture head
    tex_feat = torch.utils.checkpoint.checkpoint(
        self.synthesizer.get_texture_prediction,
        planes,
        tex_pos,
        use_reentrant=False,
    )

    if hard_mask is None:
        # No mask: every query position was evaluated, so only the original
        # layout has to be restored (previously this path dereferenced
        # ``hard_mask.shape`` and crashed).
        return tex_feat.reshape(*query_shape, tex_feat.shape[-1])

    # Scatter the compacted predictions back to their masked pixel locations.
    final_tex_feat = torch.zeros(
        planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
    expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
    for i in range(planes.shape[0]):
        # drop the padding rows before writing back
        final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
    tex_feat = final_tex_feat

    return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
240
-
241
def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
    '''
    Render each generated mesh through ``self.geometry.render_mesh`` and gather the buffers.

    :param mesh_v: List of vertices for the mesh
    :param mesh_f: List of faces for the mesh (cast to int before rendering)
    :param cam_mv: per-mesh camera matrices, indexed along dim 0
    :param render_size: resolution forwarded to the geometry renderer
    :return: (mask, hard_mask, tex_pos, depth, normal); everything is concatenated along
        dim 0 except ``tex_pos``, which stays a list with one entry per mesh
    '''
    per_mesh = [
        self.geometry.render_mesh(
            verts,
            mesh_f[idx].int(),
            cam_mv[idx],
            resolution=render_size,
            hierarchical_mask=False
        )
        for idx, verts in enumerate(mesh_v)
    ]

    # Turn the list of per-mesh dicts into one dict of per-key lists.
    gathered = {key: [result[key] for result in per_mesh] for key in per_mesh[0].keys()}

    mask = torch.cat(gathered['mask'], dim=0)
    hard_mask = torch.cat(gathered['hard_mask'], dim=0)
    tex_pos = gathered['tex_pos']
    depth = torch.cat(gathered['depth'], dim=0)
    normal = torch.cat(gathered['normal'], dim=0)
    return mask, hard_mask, tex_pos, depth, normal
272
-
273
def forward_geometry(self, planes, render_cameras, render_size=256):
    '''
    Main function of our Generator. It first generates the 3D mesh, then renders it into
    2D images with the given `render_cameras`.

    :param planes: triplane features
    :param render_cameras: cameras to render the generated 3D shape; the first two
        dimensions are read as (batch, number of views)
    :param render_size: side length of every rendered view
    :return: dict with 'img', 'mask', 'depth', 'normal', 'sdf', 'mesh_v', 'mesh_f'
        and 'sdf_reg_loss'
    '''
    B, NV = render_cameras.shape[:2]

    # Generate 3D mesh first
    mesh_v, mesh_f, sdf, _deformation, _v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)

    # Render the mesh into 2D image (get 3d position of each image plane)
    n_views = render_cameras.shape[1]
    view_ids = range(n_views)
    antialias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(
        mesh_v, mesh_f, render_cameras, render_size=render_size)

    # Lay the views of each shape side by side along dim 2 so one texture query covers them all.
    tex_pos = [torch.cat([pos[v:v + 1] for v in view_ids], dim=2) for pos in tex_pos]
    mask_rows = []
    for shape_idx in range(planes.shape[0]):
        first = shape_idx * n_views
        mask_rows.append(
            torch.cat([hard_mask[first + v: first + v + 1] for v in view_ids], dim=2))
    tex_hard_mask = torch.cat(mask_rows, dim=0)

    # Querying the texture field to predict the texture feature for each pixel on the image
    tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
    white = torch.ones_like(tex_feat)  # white background

    # Merge them together
    img_feat = tex_feat * tex_hard_mask + white * (1 - tex_hard_mask)

    # We should split it back to the original image shape
    tiles = []
    for shape_idx in range(len(tex_pos)):
        for v in view_ids:
            tiles.append(img_feat[shape_idx:shape_idx + 1, :, render_size * v: render_size * (v + 1)])
    img_feat = torch.cat(tiles, dim=0)

    def to_views(channels_last):
        # (B * NV, H, W, C) -> (B, NV, C, H, W)
        return channels_last.permute(0, 3, 1, 2).unflatten(0, (B, NV))

    return {
        'img': to_views(img_feat.clamp(0, 1)),
        'mask': to_views(antialias_mask),
        'depth': -to_views(depth),  # transform negative depth to positive
        'normal': to_views(normal),
        'sdf': sdf,
        'mesh_v': mesh_v,
        'mesh_f': mesh_f,
        'sdf_reg_loss': sdf_reg_loss,
    }
327
-
328
def forward(self, images, cameras, render_cameras, render_size: int):
    '''
    Encode the input views into triplanes, then generate and render the mesh.

    :param images: [B, V, C_img, H_img, W_img]
    :param cameras: [B, V, 16]
    :param render_cameras: [B, M, D_cam_render]
    :param render_size: int
    :return: the dict produced by `forward_geometry`, plus the triplanes under 'planes'
    '''
    planes = self.forward_planes(images, cameras)
    out = self.forward_geometry(planes, render_cameras, render_size=render_size)

    return {
        'planes': planes,
        **out
    }
342
-
343
def extract_mesh(
    self,
    planes: torch.Tensor,
    use_texture_map: bool = False,
    texture_resolution: int = 1024,
    **kwargs,
):
    '''
    Extract a 3D mesh from FlexiCubes. Only support batch_size 1.

    :param planes: triplane features
    :param use_texture_map: use texture map or vertex color
    :param texture_resolution: the resolution of texture map
    :return: ``(vertices, faces, vertex_colors)`` as numpy arrays (colors are uint8) when
        ``use_texture_map`` is False, otherwise
        ``(vertices, faces, uvs, mesh_tex_idx, texture_map)`` with the texture map in
        channel-first layout
    '''
    assert planes.shape[0] == 1

    # predict geometry first; only the mesh itself is needed here
    mesh_v, mesh_f, *_ = self.get_geometry_prediction(planes)
    vertices, faces = mesh_v[0], mesh_f[0]

    if not use_texture_map:
        # query vertex colors
        vertices_tensor = vertices.unsqueeze(0)
        vertices_colors = self.synthesizer.get_texture_prediction(
            planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy()
        vertices_colors = (vertices_colors * 255).astype(np.uint8)

        return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors

    # use x-atlas to get uv mapping for the mesh; the geometry renderer's rasterization
    # context is reused rather than building a fresh one for every call
    uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
        self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
    tex_hard_mask = tex_hard_mask.float()

    # query the texture field to get the RGB color for texture map
    tex_feat = self.get_texture_prediction(
        planes, [gb_pos], tex_hard_mask)
    background_feature = torch.zeros_like(tex_feat)
    img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
    texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)

    return vertices, faces, uvs, mesh_tex_idx, texture_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
- #
4
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
- # property and proprietary rights in and to this material, related
6
- # documentation and any modifications thereto. Any use, reproduction,
7
- # disclosure or distribution of this material and related documentation
8
- # without an express license agreement from NVIDIA CORPORATION or
9
- # its affiliates is strictly prohibited.
 
 
 
 
 
 
 
 
 
 
src/models/renderer/synthesizer.py DELETED
@@ -1,203 +0,0 @@
1
- # ORIGINAL LICENSE
2
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4
- #
5
- # Modified by Jiale Xu
6
- # The modifications are subject to the same license as the original.
7
-
8
-
9
- import itertools
10
- import torch
11
- import torch.nn as nn
12
-
13
- from .utils.renderer import ImportanceRenderer
14
- from .utils.ray_sampler import RaySampler
15
-
16
-
17
class OSGDecoder(nn.Module):
    """
    Triplane decoder that gives RGB and sigma values from sampled features.
    Using ReLU here instead of Softplus in the original implementation.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """
    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        # `activation` is called with no arguments to create each non-linearity.
        layers = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        layers.append(nn.Linear(hidden_dim, 1 + 3))
        self.net = nn.Sequential(*layers)

        # init all bias to zero
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, sampled_features, ray_directions):
        """
        Decode sampled triplane features of shape (N, n_planes, M, C).

        The per-plane features are aggregated by concatenation (not by mean).
        ``ray_directions`` is accepted but not used by this decoder.
        Returns a dict with 'rgb' (N, M, 3) and 'sigma' (N, M, 1).
        """
        batch, n_planes, n_points, channels = sampled_features.shape
        width = n_planes * channels
        flat = sampled_features.permute(0, 2, 1, 3).reshape(batch, n_points, width)
        flat = flat.contiguous().view(batch * n_points, width)

        decoded = self.net(flat).view(batch, n_points, -1)
        sigma = decoded[..., 0:1]
        rgb = torch.sigmoid(decoded[..., 1:])*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF

        return {'rgb': rgb, 'sigma': sigma}
59
-
60
-
61
class TriplaneSynthesizer(nn.Module):
    """
    Synthesizer that renders a triplane volume with planes and a camera.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 2.,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'sampler_bbox_min': -1.,
        'sampler_bbox_max': 1.,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int):
        super().__init__()

        # attributes
        self.triplane_dim = triplane_dim
        # The ray budget is split evenly between the two depth-resolution settings.
        half_samples = samples_per_ray // 2
        self.rendering_kwargs = dict(
            self.DEFAULT_RENDERING_KWARGS,
            depth_resolution=half_samples,
            depth_resolution_importance=half_samples,
        )

        # renderings
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()

        # modules
        self.decoder = OSGDecoder(n_features=triplane_dim)

    def forward(self, planes, cameras, render_size=128, crop_params=None):
        """
        Volume-render the triplanes from every camera.

        planes: (N, 3, D', H', W')
        cameras: (N, M, D_cam); entries [:16] are the flattened cam2world matrix and
            entries [16:25] the flattened intrinsics
        render_size: int
        crop_params: optional (i, j, h, w) window; only rays inside it are rendered

        Returns a dict with 'images_rgb', 'images_depth' and 'images_weight'.
        """
        assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
        N, M = cameras.shape[:2]

        cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
        intrinsics = cameras[..., 16:25].view(N, M, 3, 3)

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
            intrinsics=intrinsics.reshape(-1, 3, 3),
            render_size=render_size,
        )
        assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"

        if crop_params is None:
            Himg = Wimg = render_size
        else:
            # Crop rays: view them as an image grid, cut the window, flatten again.
            top, left, Himg, Wimg = crop_params
            rows = slice(top, top + Himg)
            cols = slice(left, left + Wimg)
            ray_origins = ray_origins.reshape(
                N*M, render_size, render_size, 3)[:, rows, cols, :].reshape(N*M, -1, 3)
            ray_directions = ray_directions.reshape(
                N*M, render_size, render_size, 3)[:, rows, cols, :].reshape(N*M, -1, 3)

        # Perform volume rendering; every plane set is repeated once per camera.
        rgb_samples, depth_samples, weights_samples = self.renderer(
            planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
        )

        # Reshape into 'raw' neural-rendered image
        return {
            'images_rgb': rgb_samples.permute(0, 2, 1).reshape(
                N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous(),
            'images_depth': depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg),
            'images_weight': weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg),
        }

    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
        """
        Query the triplanes on a regular grid_size^3 grid.

        planes: (N, 3, D', H', W')
        grid_size: int
        aabb: (N, 2, 3) min/max corners; defaults to the sampler bounding box

        Returns the dict from `forward_points` with every value reshaped to
        (N, grid_size, grid_size, grid_size, -1).
        """
        if aabb is None:
            corners = torch.tensor([
                [self.rendering_kwargs['sampler_bbox_min']] * 3,
                [self.rendering_kwargs['sampler_bbox_max']] * 3,
            ], device=planes.device, dtype=planes.dtype)
            aabb = corners.unsqueeze(0).repeat(planes.shape[0], 1, 1)
        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]

        # create grid points for triplane query
        per_sample = []
        for box in aabb[:N]:
            axes = [
                torch.linspace(box[0, axis], box[1, axis], grid_size, device=planes.device)
                for axis in range(3)
            ]
            lattice = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
            per_sample.append(lattice.reshape(-1, 3))
        cube_grid = torch.stack(per_sample, dim=0)

        features = self.forward_points(planes, cube_grid)

        # reshape into grid
        return {
            name: value.reshape(N, grid_size, grid_size, grid_size, -1)
            for name, value in features.items()
        }

    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
        """
        Query the triplanes at arbitrary points, `chunk_size` points at a time.

        planes: (N, 3, D', H', W')
        points: (N, P, 3)

        Returns a dict whose values are the per-chunk outputs concatenated along dim 1.
        """
        _, n_points = points.shape[:2]

        # query triplane in chunks
        chunk_outputs = []
        for start in range(0, n_points, chunk_size):
            chunk = points[:, start:start+chunk_size]
            chunk_outputs.append(self.renderer.run_model_activated(
                planes=planes,
                decoder=self.decoder,
                sample_coordinates=chunk,
                sample_directions=torch.zeros_like(chunk),
                options=self.rendering_kwargs,
            ))

        # concatenate the outputs
        return {
            name: torch.cat([chunk_out[name] for chunk_out in chunk_outputs], dim=1)
            for name in chunk_outputs[0].keys()
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/synthesizer_mesh.py DELETED
@@ -1,141 +0,0 @@
1
- # ORIGINAL LICENSE
2
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4
- #
5
- # Modified by Jiale Xu
6
- # The modifications are subject to the same license as the original.
7
-
8
- import itertools
9
- import torch
10
- import torch.nn as nn
11
-
12
- from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes
13
-
14
-
15
class OSGDecoder(nn.Module):
    """
    Triplane decoder with separate MLP heads for SDF, RGB, deformation and
    FlexiCubes weights, all fed from sampled triplane features.
    Using ReLU here instead of Softplus in the original implementation.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """
    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        point_dim = 3 * n_features

        def build_mlp(in_dim, out_dim):
            # Layers are created strictly in order so parameter initialisation
            # consumes the RNG exactly as four hand-written nn.Sequential blocks would.
            stack = [nn.Linear(in_dim, hidden_dim), activation()]
            for _ in range(num_layers - 2):
                stack.append(nn.Linear(hidden_dim, hidden_dim))
                stack.append(activation())
            stack.append(nn.Linear(hidden_dim, out_dim))
            return nn.Sequential(*stack)

        self.net_sdf = build_mlp(point_dim, 1)
        self.net_rgb = build_mlp(point_dim, 3)
        self.net_deformation = build_mlp(point_dim, 3)
        # The weight head consumes the concatenated features of 8 points per grid
        # cell and emits 21 values per cell.
        self.net_weight = build_mlp(8 * point_dim, 21)

        # init all bias to zero
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def get_geometry_prediction(self, sampled_features, flexicubes_indices):
        """
        Decode SDF, deformation and per-cell weights.

        sampled_features: (N, n_planes, M, C); planes are aggregated by concatenation
        flexicubes_indices: 2D index tensor; each row lists the point indices of one cell

        Returns (sdf, deformation, weight); the weight head output is scaled by 0.1.
        """
        batch, n_planes, n_points, channels = sampled_features.shape
        feats = sampled_features.permute(0, 2, 1, 3).reshape(batch, n_points, n_planes*channels)

        sdf = self.net_sdf(feats)
        deformation = self.net_deformation(feats)

        # Gather the features of every cell's points and flatten them per cell.
        n_cells, points_per_cell = flexicubes_indices.shape[0], flexicubes_indices.shape[1]
        cell_feats = torch.index_select(input=feats, index=flexicubes_indices.reshape(-1), dim=1)
        cell_feats = cell_feats.reshape(feats.shape[0], n_cells, points_per_cell * feats.shape[-1])
        weight = self.net_weight(cell_feats) * 0.1

        return sdf, deformation, weight

    def get_texture_prediction(self, sampled_features):
        """
        Decode RGB from sampled features of shape (N, n_planes, M, C); returns (N, M, 3).
        """
        batch, n_planes, n_points, channels = sampled_features.shape
        feats = sampled_features.permute(0, 2, 1, 3).reshape(batch, n_points, n_planes*channels)

        raw = self.net_rgb(feats)
        return torch.sigmoid(raw)*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
91
-
92
-
93
class TriplaneSynthesizer(nn.Module):
    """
    Synthesizer that decodes geometry and texture predictions from triplanes
    sampled at 3D coordinates.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 2.,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'sampler_bbox_min': -1.,
        'sampler_bbox_max': 1.,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int):
        super().__init__()

        # attributes
        self.triplane_dim = triplane_dim
        half_samples = samples_per_ray // 2
        self.rendering_kwargs = dict(
            self.DEFAULT_RENDERING_KWARGS,
            depth_resolution=half_samples,
            depth_resolution_importance=half_samples,
        )

        # modules
        self.plane_axes = generate_planes()
        self.decoder = OSGDecoder(n_features=triplane_dim)

    def _sample_features(self, planes, sample_coordinates):
        # plane_axes is a plain attribute (set in __init__), so it is moved to the
        # planes' device on every query.
        return sample_from_planes(
            self.plane_axes.to(planes.device),
            planes,
            sample_coordinates,
            padding_mode='zeros',
            box_warp=self.rendering_kwargs['box_warp'])

    def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
        """
        Sample the triplanes at `sample_coordinates` and decode them to
        (sdf, deformation, weight).
        """
        features = self._sample_features(planes, sample_coordinates)
        sdf, deformation, weight = self.decoder.get_geometry_prediction(features, flexicubes_indices)
        return sdf, deformation, weight

    def get_texture_prediction(self, planes, sample_coordinates):
        """
        Sample the triplanes at `sample_coordinates` and decode them to RGB.
        """
        features = self._sample_features(planes, sample_coordinates)
        return self.decoder.get_texture_prediction(features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/utils/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
- #
4
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
- # property and proprietary rights in and to this material, related
6
- # documentation and any modifications thereto. Any use, reproduction,
7
- # disclosure or distribution of this material and related documentation
8
- # without an express license agreement from NVIDIA CORPORATION or
9
- # its affiliates is strictly prohibited.
 
 
 
 
 
 
 
 
 
 
src/models/renderer/utils/math_utils.py DELETED
@@ -1,118 +0,0 @@
1
- # MIT License
2
-
3
- # Copyright (c) 2022 Petr Kellnhofer
4
-
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
-
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
-
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- import torch
24
-
25
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply an MxM matrix to N row vectors (NxM) by right-multiplying with its
    transpose, i.e. left-multiplying each vector by the matrix. Returns NxM.
    """
    return torch.matmul(vectors4, matrix.T)
31
-
32
-
33
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Normalize vector lengths along the last dimension.

    Zero-length vectors are returned as zeros instead of NaN: the divisor is
    clamped to the smallest positive normal number of the dtype, which leaves
    every other vector's result unchanged.
    """
    lengths = torch.norm(vectors, dim=-1, keepdim=True)
    return vectors / lengths.clamp_min(torch.finfo(lengths.dtype).tiny)
38
-
39
def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of two tensors along their last dimension
    (element-wise product, then sum over dim -1).
    """
    return torch.sum(x * y, dim=-1)
44
-
45
-
46
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with an axis-aligned cube of side `box_side_length` centred
    at the origin (the [-1, 1] volume for a side length of 2), using the slab method.
    Returns the entry and exit distances, each with shape (*rays_o.shape[:-1], 1).
    Rays that miss the box get an entry distance of -1 and an exit distance of -2.
    Both inputs are detached, so no gradient flows through the result.
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    lead_shape = rays_o.shape[:-1]
    origins = rays_o.detach().reshape(-1, 3)
    directions = rays_d.detach().reshape(-1, 3)

    half_side = box_side_length/2
    bounds = torch.tensor(
        [[-1*half_side] * 3, [1*half_side] * 3],
        dtype=origins.dtype, device=origins.device)
    hit = torch.ones(origins.shape[:-1], dtype=bool, device=origins.device)

    # Precompute inverse for stability.
    inv_dir = 1 / directions
    # 1 where the ray travels in the negative direction along an axis: the near and
    # far planes of that slab swap roles.
    negative = (inv_dir < 0).long()

    def slab(axis):
        # Distances to the near and far plane of the slab perpendicular to `axis`.
        near = (bounds.index_select(0, negative[..., axis])[..., axis] - origins[..., axis]) * inv_dir[..., axis]
        far = (bounds.index_select(0, 1 - negative[..., axis])[..., axis] - origins[..., axis]) * inv_dir[..., axis]
        return near, far

    # Start with the x slab, then narrow the interval with the y and z slabs.
    tmin, tmax = slab(0)
    for axis in (1, 2):
        near, far = slab(axis)

        # The running interval and this slab's interval do not overlap: the ray misses.
        hit[torch.logical_or(tmin > far, near > tmax)] = False

        # Keep the overlap of the two intervals.
        tmin = torch.max(tmin, near)
        tmax = torch.min(tmax, far)

    # Mark invalid.
    missed = torch.logical_not(hit)
    tmin[missed] = -1
    tmax[missed] = -2

    return tmin.reshape(*lead_shape, 1), tmax.reshape(*lead_shape, 1)
99
-
100
-
101
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
    As in numpy, num == 1 yields just `start`.
    """
    # create a tensor of 'num' steps from 0 to 1; the divisor is clamped so that
    # num == 1 gives the single step 0 instead of 0 / 0 = NaN
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / max(num - 1, 1)

    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the code below
    for _ in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # the output starts at 'start' and increments until 'stop' in each dimension
    return start[None] + steps * (stop - start)[None]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/utils/ray_marcher.py DELETED
@@ -1,72 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
- #
4
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
- # property and proprietary rights in and to this material, related
6
- # documentation and any modifications thereto. Any use, reproduction,
7
- # disclosure or distribution of this material and related documentation
8
- # without an express license agreement from NVIDIA CORPORATION or
9
- # its affiliates is strictly prohibited.
10
- #
11
- # Modified by Jiale Xu
12
- # The modifications are subject to the same license as the original.
13
-
14
-
15
- """
16
- The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
17
- Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
18
- """
19
-
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
-
24
-
25
- class MipRayMarcher2(nn.Module):
26
- def __init__(self, activation_factory):
27
- super().__init__()
28
- self.activation_factory = activation_factory
29
-
30
- def run_forward(self, colors, densities, depths, rendering_options, normals=None):
31
- dtype = colors.dtype
32
- deltas = depths[:, :, 1:] - depths[:, :, :-1]
33
- colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
34
- densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
35
- depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
36
-
37
- # using factory mode for better usability
38
- densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
39
-
40
- density_delta = densities_mid * deltas
41
-
42
- alpha = 1 - torch.exp(-density_delta).to(dtype)
43
-
44
- alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
45
- weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
46
- weights = weights.to(dtype)
47
-
48
- composite_rgb = torch.sum(weights * colors_mid, -2)
49
- weight_total = weights.sum(2)
50
- # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
51
- composite_depth = torch.sum(weights * depths_mid, -2)
52
-
53
- # clip the composite to min/max range of depths
54
- composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
55
- composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
56
-
57
- if rendering_options.get('white_back', False):
58
- composite_rgb = composite_rgb + 1 - weight_total
59
-
60
- # rendered value scale is 0-1, comment out original mipnerf scaling
61
- # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
62
-
63
- return composite_rgb, composite_depth, weights
64
-
65
-
66
- def forward(self, colors, densities, depths, rendering_options, normals=None):
67
- if normals is not None:
68
- composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals)
69
- return composite_rgb, composite_depth, composite_normals, weights
70
-
71
- composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
72
- return composite_rgb, composite_depth, weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/utils/ray_sampler.py DELETED
@@ -1,141 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
- #
4
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
- # property and proprietary rights in and to this material, related
6
- # documentation and any modifications thereto. Any use, reproduction,
7
- # disclosure or distribution of this material and related documentation
8
- # without an express license agreement from NVIDIA CORPORATION or
9
- # its affiliates is strictly prohibited.
10
- #
11
- # Modified by Jiale Xu
12
- # The modifications are subject to the same license as the original.
13
-
14
-
15
- """
16
- The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
17
- Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
18
- """
19
-
20
- import torch
21
-
22
- class RaySampler(torch.nn.Module):
23
- def __init__(self):
24
- super().__init__()
25
- self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
26
-
27
-
28
- def forward(self, cam2world_matrix, intrinsics, render_size):
29
- """
30
- Create batches of rays and return origins and directions.
31
-
32
- cam2world_matrix: (N, 4, 4)
33
- intrinsics: (N, 3, 3)
34
- render_size: int
35
-
36
- ray_origins: (N, M, 3)
37
- ray_dirs: (N, M, 2)
38
- """
39
-
40
- dtype = cam2world_matrix.dtype
41
- device = cam2world_matrix.device
42
- N, M = cam2world_matrix.shape[0], render_size**2
43
- cam_locs_world = cam2world_matrix[:, :3, 3]
44
- fx = intrinsics[:, 0, 0]
45
- fy = intrinsics[:, 1, 1]
46
- cx = intrinsics[:, 0, 2]
47
- cy = intrinsics[:, 1, 2]
48
- sk = intrinsics[:, 0, 1]
49
-
50
- uv = torch.stack(torch.meshgrid(
51
- torch.arange(render_size, dtype=dtype, device=device),
52
- torch.arange(render_size, dtype=dtype, device=device),
53
- indexing='ij',
54
- ))
55
- uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
56
- uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
57
-
58
- x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
59
- y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
60
- z_cam = torch.ones((N, M), dtype=dtype, device=device)
61
-
62
- x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
63
- y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
64
-
65
- cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype)
66
-
67
- _opencv2blender = torch.tensor([
68
- [1, 0, 0, 0],
69
- [0, -1, 0, 0],
70
- [0, 0, -1, 0],
71
- [0, 0, 0, 1],
72
- ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1)
73
-
74
- cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
75
-
76
- world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
77
-
78
- ray_dirs = world_rel_points - cam_locs_world[:, None, :]
79
- ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype)
80
-
81
- ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
82
-
83
- return ray_origins, ray_dirs
84
-
85
-
86
- class OrthoRaySampler(torch.nn.Module):
87
- def __init__(self):
88
- super().__init__()
89
- self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
90
-
91
-
92
- def forward(self, cam2world_matrix, ortho_scale, render_size):
93
- """
94
- Create batches of rays and return origins and directions.
95
-
96
- cam2world_matrix: (N, 4, 4)
97
- ortho_scale: float
98
- render_size: int
99
-
100
- ray_origins: (N, M, 3)
101
- ray_dirs: (N, M, 3)
102
- """
103
-
104
- N, M = cam2world_matrix.shape[0], render_size**2
105
-
106
- uv = torch.stack(torch.meshgrid(
107
- torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
108
- torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
109
- indexing='ij',
110
- ))
111
- uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
112
- uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
113
-
114
- x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
115
- y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
116
- z_cam = torch.zeros((N, M), device=cam2world_matrix.device)
117
-
118
- x_lift = (x_cam - 0.5) * ortho_scale
119
- y_lift = (y_cam - 0.5) * ortho_scale
120
-
121
- cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
122
-
123
- _opencv2blender = torch.tensor([
124
- [1, 0, 0, 0],
125
- [0, -1, 0, 0],
126
- [0, 0, -1, 0],
127
- [0, 0, 0, 1],
128
- ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
129
-
130
- cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
131
-
132
- ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
133
-
134
- ray_dirs_cam = torch.stack([
135
- torch.zeros((N, M), device=cam2world_matrix.device),
136
- torch.zeros((N, M), device=cam2world_matrix.device),
137
- torch.ones((N, M), device=cam2world_matrix.device),
138
- ], dim=-1)
139
- ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1)
140
-
141
- return ray_origins, ray_dirs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/renderer/utils/renderer.py DELETED
@@ -1,323 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
- #
4
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
- # property and proprietary rights in and to this material, related
6
- # documentation and any modifications thereto. Any use, reproduction,
7
- # disclosure or distribution of this material and related documentation
8
- # without an express license agreement from NVIDIA CORPORATION or
9
- # its affiliates is strictly prohibited.
10
- #
11
- # Modified by Jiale Xu
12
- # The modifications are subject to the same license as the original.
13
-
14
-
15
- """
16
- The renderer is a module that takes in rays, decides where to sample along each
17
- ray, and computes pixel colors using the volume rendering equation.
18
- """
19
-
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
-
24
- from .ray_marcher import MipRayMarcher2
25
- from . import math_utils
26
-
27
-
28
def generate_planes():
    """
    Return the axes of the three feature planes as a (3, 3, 3) float32 tensor.

    Each plane is described by the three row vectors that form its "axes".
    Should work with an arbitrary number of planes and planes of arbitrary
    orientation.

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
    """
    first_plane = [[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]]
    second_plane = [[1, 0, 0],
                    [0, 0, 1],
                    [0, 1, 0]]
    third_plane = [[0, 0, 1],
                   [0, 1, 0],
                   [1, 0, 0]]
    return torch.tensor([first_plane, second_plane, third_plane], dtype=torch.float32)
45
-
46
def project_onto_planes(planes, coordinates):
    """
    Project batches of 3D points onto every plane and return 2D plane coordinates.

    planes: (n_planes, 3, 3) plane axes
    coordinates: (N, M, 3)
    returns: (N*n_planes, M, 2)
    """
    batch, num_points, _ = coordinates.shape
    num_planes, _, _ = planes.shape

    # Repeat every point set once per plane: (N, M, 3) -> (N*n_planes, M, 3).
    tiled_coords = coordinates.unsqueeze(1).expand(-1, num_planes, -1, -1)
    tiled_coords = tiled_coords.reshape(batch * num_planes, num_points, 3)

    # Repeat the inverted plane axes once per batch entry: -> (N*n_planes, 3, 3).
    plane_inverses = torch.linalg.inv(planes)
    tiled_inverses = plane_inverses.unsqueeze(0).expand(batch, -1, -1, -1)
    tiled_inverses = tiled_inverses.reshape(batch * num_planes, 3, 3)

    # Express the points in each plane's basis and keep the two in-plane components.
    return torch.bmm(tiled_coords, tiled_inverses)[..., :2]
61
-
62
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Sample per-plane features at 3D query points.

    plane_axes: (n_planes, 3, 3), passed through to project_onto_planes
    plane_features: (N, n_planes, C, H, W)
    coordinates: (N, M, 3), rescaled by 2 / box_warp before projection
    returns: (N, n_planes, M, C)

    Only padding_mode='zeros' is accepted. `box_warp` must be a number; the
    default of None fails in the division below.
    """
    assert padding_mode == 'zeros'
    batch, num_planes, channels, height, width = plane_features.shape
    _, num_points, _ = coordinates.shape
    flat_planes = plane_features.view(batch * num_planes, channels, height, width)
    feature_dtype = flat_planes.dtype

    # add specific box bounds
    scaled_coordinates = (2/box_warp) * coordinates

    # grid_sample expects an (N*n_planes, H_out=1, W_out=M, 2) sampling grid.
    sampling_grid = project_onto_planes(plane_axes, scaled_coordinates).unsqueeze(1)
    sampled = torch.nn.functional.grid_sample(
        flat_planes,
        sampling_grid.to(feature_dtype),
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False,
    )
    # (N*n_planes, C, 1, M) -> (N*n_planes, M, 1, C) -> (N, n_planes, M, C)
    return sampled.permute(0, 3, 2, 1).reshape(batch, num_planes, num_points, channels)
80
-
81
def sample_from_3dgrid(grid, coordinates):
    """
    Sample a dense feature volume at 3D query points.

    coordinates: (batch_size, num_points_per_batch, 3)
    grid: (1, channels, H, W, D); a grid that already has a batch size also works
    returns: (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, n_coords, n_dims = coordinates.shape

    batched_grid = grid.expand(batch_size, -1, -1, -1, -1)
    # grid_sample wants a 5D query of shape (batch, 1, 1, points, 3).
    query = coordinates.reshape(batch_size, 1, 1, -1, n_dims)
    sampled = torch.nn.functional.grid_sample(
        batched_grid,
        query,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )

    # (batch, C, 1, 1, points) -> (batch, points, C)
    n, c, h, w, d = sampled.shape
    return sampled.permute(0, 4, 3, 2, 1).reshape(n, h*w*d, c)
99
-
100
- class ImportanceRenderer(torch.nn.Module):
101
- """
102
- Modified original version to filter out-of-box samples as TensoRF does.
103
-
104
- Reference:
105
- TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
106
- """
107
- def __init__(self):
108
- super().__init__()
109
- self.activation_factory = self._build_activation_factory()
110
- self.ray_marcher = MipRayMarcher2(self.activation_factory)
111
- self.plane_axes = generate_planes()
112
-
113
- def _build_activation_factory(self):
114
- def activation_factory(options: dict):
115
- if options['clamp_mode'] == 'softplus':
116
- return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
117
- else:
118
- assert False, "Renderer only supports `clamp_mode`=`softplus`!"
119
- return activation_factory
120
-
121
- def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
122
- planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
123
- """
124
- Additional filtering is applied to filter out-of-box samples.
125
- Modifications made by Zexin He.
126
- """
127
-
128
- # context related variables
129
- batch_size, num_rays, samples_per_ray, _ = depths.shape
130
- device = depths.device
131
-
132
- # define sample points with depths
133
- sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
134
- sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
135
-
136
- # filter out-of-box samples
137
- mask_inbox = \
138
- (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
139
- (sample_coordinates <= rendering_options['sampler_bbox_max'])
140
- mask_inbox = mask_inbox.all(-1)
141
-
142
- # forward model according to all samples
143
- _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
144
-
145
- # set out-of-box samples to zeros(rgb) & -inf(sigma)
146
- SAFE_GUARD = 3
147
- DATA_TYPE = _out['sigma'].dtype
148
- colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
149
- densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
150
- colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
151
-
152
- # reshape back
153
- colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
154
- densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
155
-
156
- return colors_pass, densities_pass
157
-
158
- def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
159
- # self.plane_axes = self.plane_axes.to(ray_origins.device)
160
-
161
- if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
162
- ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
163
- is_ray_valid = ray_end > ray_start
164
- if torch.any(is_ray_valid).item():
165
- ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
166
- ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
167
- depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
168
- else:
169
- # Create stratified depth samples
170
- depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
171
-
172
- # Coarse Pass
173
- colors_coarse, densities_coarse = self._forward_pass(
174
- depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
175
- planes=planes, decoder=decoder, rendering_options=rendering_options)
176
-
177
- # Fine Pass
178
- N_importance = rendering_options['depth_resolution_importance']
179
- if N_importance > 0:
180
- _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
181
-
182
- depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
183
-
184
- colors_fine, densities_fine = self._forward_pass(
185
- depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
186
- planes=planes, decoder=decoder, rendering_options=rendering_options)
187
-
188
- all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
189
- depths_fine, colors_fine, densities_fine)
190
-
191
- rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
192
- else:
193
- rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
194
-
195
- return rgb_final, depth_final, weights.sum(2)
196
-
197
- def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
198
- plane_axes = self.plane_axes.to(planes.device)
199
- sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
200
-
201
- out = decoder(sampled_features, sample_directions)
202
- if options.get('density_noise', 0) > 0:
203
- out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
204
- return out
205
-
206
- def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
207
- out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
208
- out['sigma'] = self.activation_factory(options)(out['sigma'])
209
- return out
210
-
211
- def sort_samples(self, all_depths, all_colors, all_densities):
212
- _, indices = torch.sort(all_depths, dim=-2)
213
- all_depths = torch.gather(all_depths, -2, indices)
214
- all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
215
- all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
216
- return all_depths, all_colors, all_densities
217
-
218
- def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None):
219
- all_depths = torch.cat([depths1, depths2], dim = -2)
220
- all_colors = torch.cat([colors1, colors2], dim = -2)
221
- all_densities = torch.cat([densities1, densities2], dim = -2)
222
-
223
- if normals1 is not None and normals2 is not None:
224
- all_normals = torch.cat([normals1, normals2], dim = -2)
225
- else:
226
- all_normals = None
227
-
228
- _, indices = torch.sort(all_depths, dim=-2)
229
- all_depths = torch.gather(all_depths, -2, indices)
230
- all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
231
- all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
232
-
233
- if all_normals is not None:
234
- all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1]))
235
- return all_depths, all_colors, all_normals, all_densities
236
-
237
- return all_depths, all_colors, all_densities
238
-
239
- def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
240
- """
241
- Return depths of approximately uniformly spaced samples along rays.
242
- """
243
- N, M, _ = ray_origins.shape
244
- if disparity_space_sampling:
245
- depths_coarse = torch.linspace(0,
246
- 1,
247
- depth_resolution,
248
- device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
249
- depth_delta = 1/(depth_resolution - 1)
250
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
251
- depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
252
- else:
253
- if type(ray_start) == torch.Tensor:
254
- depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
255
- depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
256
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
257
- else:
258
- depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
259
- depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
260
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
261
-
262
- return depths_coarse
263
-
264
- def sample_importance(self, z_vals, weights, N_importance):
265
- """
266
- Return depths of importance sampled points along rays. See NeRF importance sampling for more.
267
- """
268
- with torch.no_grad():
269
- batch_size, num_rays, samples_per_ray, _ = z_vals.shape
270
-
271
- z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
272
- weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
273
-
274
- # smooth weights
275
- weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1)
276
- weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
277
- weights = weights + 0.01
278
-
279
- z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
280
- importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
281
- N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
282
- return importance_z_vals
283
-
284
- def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
285
- """
286
- Sample @N_importance samples from @bins with distribution defined by @weights.
287
- Inputs:
288
- bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
289
- weights: (N_rays, N_samples_)
290
- N_importance: the number of samples to draw from the distribution
291
- det: deterministic or not
292
- eps: a small number to prevent division by zero
293
- Outputs:
294
- samples: the sampled samples
295
- """
296
- N_rays, N_samples_ = weights.shape
297
- weights = weights + eps # prevent division by zero (don't do inplace op!)
298
- pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
299
- cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
300
- cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
301
- # padded to 0~1 inclusive
302
-
303
- if det:
304
- u = torch.linspace(0, 1, N_importance, device=bins.device)
305
- u = u.expand(N_rays, N_importance)
306
- else:
307
- u = torch.rand(N_rays, N_importance, device=bins.device)
308
- u = u.contiguous()
309
-
310
- inds = torch.searchsorted(cdf, u, right=True)
311
- below = torch.clamp_min(inds-1, 0)
312
- above = torch.clamp_max(inds, N_samples_)
313
-
314
- inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
315
- cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
316
- bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
317
-
318
- denom = cdf_g[...,1]-cdf_g[...,0]
319
- denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
320
- # anyway, therefore any value for it is fine (set to 1 here)
321
-
322
- samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
323
- return samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/__init__.py DELETED
File without changes
src/utils/camera_util.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
-
5
-
6
def pad_camera_extrinsics_4x4(extrinsics):
    """
    Append the homogeneous row [0, 0, 0, 1] to 3x4 camera extrinsics.

    extrinsics: (..., 3, 4) or (..., 4, 4), with any number of leading batch
        dimensions
    return: (..., 4, 4); a tensor that already has 4 rows is returned as is

    The appended row takes the dtype and device of `extrinsics`.
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    bottom_row = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the row over all leading batch dimensions (none for a single 3x4).
    bottom_row = bottom_row.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, bottom_row], dim=-2)
14
-
15
-
16
def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
    """
    Create OpenGL camera extrinsics from camera locations and look-at position.

    camera_position: (M, 3) or (3,)
    look_at: (3)
    up_world: (3)
    return: (M, 4, 4) or (4, 4) -- the 3x4 pose is padded to 4x4 before returning

    Default `look_at` / `up_world` are created on the device of
    `camera_position`, so non-CPU inputs no longer hit a device mismatch.
    """
    device = camera_position.device
    # by default, looking at the origin and world up is z-axis
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    if camera_position.ndim == 2:
        # One copy of the target and up vector per camera.
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    # OpenGL camera: z-backward, x-right, y-up
    z_axis = camera_position - look_at
    z_axis = F.normalize(z_axis, dim=-1).float()
    x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
    x_axis = F.normalize(x_axis, dim=-1).float()
    y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
    y_axis = F.normalize(y_axis, dim=-1).float()

    # Columns are the camera axes followed by the camera position.
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    extrinsics = pad_camera_extrinsics_4x4(extrinsics)
    return extrinsics
45
-
46
-
47
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """
    Build camera poses on a sphere of the given radius.

    azimuths, elevations: angles in degrees
    radius: distance of every camera from the origin
    return: the poses produced by center_looking_at_camera_pose for these positions
    """
    azimuths_rad = np.deg2rad(azimuths)
    elevations_rad = np.deg2rad(elevations)

    # Spherical -> Cartesian, with z along the elevation direction.
    cam_xyz = np.stack(
        [
            radius * np.cos(elevations_rad) * np.cos(azimuths_rad),
            radius * np.cos(elevations_rad) * np.sin(azimuths_rad),
            radius * np.sin(elevations_rad),
        ],
        axis=-1,
    )
    cam_locations = torch.from_numpy(cam_xyz).float()

    return center_looking_at_camera_pose(cam_locations)
60
-
61
-
62
def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
    """
    Camera poses evenly spaced on a horizontal circle around the origin.

    M: number of circular views
    radius: camera dist to center
    elevation: elevation degrees of the camera
    return: (M, 4, 4)
    """
    assert M > 0 and radius > 0

    elevation = np.deg2rad(elevation)

    # Azimuths split the full turn into M equal steps starting at 0.
    azimuths = [2 * np.pi * i / M for i in range(M)]
    camera_positions = np.array([
        [
            radius * np.cos(elevation) * np.cos(azimuth),
            radius * np.cos(elevation) * np.sin(azimuth),
            radius * np.sin(elevation),
        ]
        for azimuth in azimuths
    ])
    camera_positions = torch.from_numpy(camera_positions).float()
    return center_looking_at_camera_pose(camera_positions)
82
-
83
-
84
def FOV_to_intrinsics(fov, device='cpu'):
    """
    Build a 3x3 camera intrinsics matrix from a field of view given in degrees.

    The intrinsics are normalized by image size rather than expressed in
    pixel units, and the principal point is placed at the image center
    (0.5, 0.5).
    """
    half_fov_rad = np.deg2rad(fov) * 0.5
    focal_length = 0.5 / np.tan(half_fov_rad)
    rows = [
        [focal_length, 0, 0.5],
        [0, focal_length, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)
93
-
94
-
95
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """
    Get the input camera parameters.

    Six fixed views are used. Each camera is a 16-vector: the first 12 entries
    of the flattened pose followed by (fx, fy, cx, cy). The result has shape
    (batch_size, 6, 16).
    """
    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)

    # Flattened poses; keep the first 12 values (the top three rows of each 4x4).
    poses_flat = spherical_camera_pose(azimuths, elevations, radius).float().flatten(-2)
    extrinsics = poses_flat[:, :12]

    # Flattened 3x3 intrinsics: index 0 = fx, 4 = fy, 2 = cx, 5 = cy.
    intrinsics_flat = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
    fx, fy = intrinsics_flat[:, 0], intrinsics_flat[:, 4]
    cx, cy = intrinsics_flat[:, 2], intrinsics_flat[:, 5]
    intrinsics = torch.stack([fx, fy, cx, cy], dim=-1)

    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/infer_util.py DELETED
@@ -1,84 +0,0 @@
1
- import os
2
- import imageio
3
- import rembg
4
- import torch
5
- import numpy as np
6
- import PIL.Image
7
- from PIL import Image
8
- from typing import Any
9
-
10
-
11
def remove_background(image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """
    Strip the background from `image` with rembg unless it already has one removed.

    An RGBA image whose minimum alpha is below 255 is treated as already
    matted and returned unchanged, unless `force` is set. `rembg_session`
    and `rembg_kwargs` are forwarded to `rembg.remove`.
    """
    already_matted = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_matted:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
23
-
24
-
25
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """
    Crop an RGBA image to its foreground and re-pad it to a centred square.

    image: RGBA image; pixels with alpha > 0 count as foreground
    ratio: foreground side length divided by the output side length
    return: square RGBA image padded with fully transparent black pixels

    Raises ValueError when the image has no foreground pixels.
    """
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    if alpha[0].size == 0:
        raise ValueError("image has no foreground pixels (alpha is zero everywhere)")
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground; y2 / x2 are inclusive indices, so extend the slice
    # by one to keep the last foreground row and column
    fg = image[y1:y2 + 1, x1:x2 + 1]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image
64
-
65
-
66
def images_to_video(
    images: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """
    Write a batch of frames to a video file with imageio.

    images: (N, C, H, W); values are multiplied by 255 and cast to uint8, so
        they are presumably expected in [0, 1] -- TODO confirm with callers
    output_path: destination file; its parent directory is created if needed
    fps: frames per second of the written video
    """
    # images: (N, C, H, W)
    video_dir = os.path.dirname(output_path)
    # A bare file name has no directory part, and os.makedirs('') raises.
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    frames = []
    for image in images:
        # (C, H, W) -> (H, W, C) uint8
        frame = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
        frames.append(frame)
    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/mesh_util.py DELETED
@@ -1,181 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- import xatlas
11
- import trimesh
12
- import cv2
13
- import numpy as np
14
- import nvdiffrast.torch as dr
15
- from PIL import Image
16
-
17
-
18
def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
    """Export a vertex-coloured mesh to ``fpath`` in OBJ format.

    Before export the z coordinate of every vertex is negated and the
    vertex order of every face is reversed.

    Args:
        pointnp_px3: Vertex positions, one row per vertex.
        facenp_fx3: Vertex indices, one row per triangle.
        colornp_px3: Per-vertex colours handed to ``trimesh.Trimesh``.
        fpath: Output path.
    """
    flip_z = np.diag([1, 1, -1])
    vertices = pointnp_px3 @ flip_z
    faces = facenp_fx3[:, [2, 1, 0]]

    trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=colornp_px3,
    ).export(fpath, 'obj')
29
-
30
-
31
def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
    """Export a vertex-coloured mesh to ``fpath`` in GLB format.

    Before export the x and z coordinates of every vertex are negated;
    the faces are passed through unchanged.

    Args:
        pointnp_px3: Vertex positions, one row per vertex.
        facenp_fx3: Vertex indices, one row per triangle.
        colornp_px3: Per-vertex colours handed to ``trimesh.Trimesh``.
        fpath: Output path.
    """
    flip_xz = np.diag([-1, 1, -1])
    vertices = pointnp_px3 @ flip_xz

    trimesh.Trimesh(
        vertices=vertices,
        faces=facenp_fx3,
        vertex_colors=colornp_px3,
    ).export(fpath, 'glb')
41
-
42
-
43
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
    """Write a textured mesh as an OBJ file plus its MTL and PNG sidecars.

    The ``.mtl`` and ``.png`` files are written next to ``fname`` and share
    its base name.

    Args:
        pointnp_px3: Vertex positions, one row per vertex.
        tcoords_px2: Texture coordinates, one row per UV vertex.
        facenp_fx3: Zero-based vertex indices per triangle (NumPy array rows).
        facetex_fx3: Zero-based texture-coordinate indices per triangle.
        texmap_hxwx3: Texture image; values are scaled from [0, 1] to
            [0, 255] and clipped before saving.
        fname: Path of the ``.obj`` file to write.
    """
    import os

    fol, na = os.path.split(fname)
    na, _ = os.path.splitext(na)

    # os.path.join keeps the sidecars beside ``fname`` even when it has no
    # directory part; plain '%s/%s' formatting would target '/<name>.mtl'.
    with open(os.path.join(fol, '%s.mtl' % na), 'w') as fid:
        fid.write('newmtl material_0\n')
        fid.write('Kd 1 1 1\n')
        fid.write('Ka 0 0 0\n')
        fid.write('Ks 0.4 0.4 0.4\n')
        fid.write('Ns 10\n')
        fid.write('illum 2\n')
        fid.write('map_Kd %s.png\n' % na)

    with open(fname, 'w') as fid:
        fid.write('mtllib %s.mtl\n' % na)
        fid.writelines('v %f %f %f\n' % (p[0], p[1], p[2]) for p in pointnp_px3)
        fid.writelines('vt %f %f\n' % (p[0], p[1]) for p in tcoords_px2)

        fid.write('usemtl material_0\n')
        for i, f in enumerate(facenp_fx3):
            # OBJ indices are one-based.
            f1 = f + 1
            f2 = facetex_fx3[i] + 1
            fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))

    # Save the texture map.
    lo, hi = 0, 1
    img = np.asarray(texmap_hxwx3, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = img.clip(0, 255)
    # Texels whose channel sum is <= 3 are treated as empty and replaced by
    # the 3x3-dilated image at that position.
    mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
    mask = (mask <= 3.0).astype(np.float32)
    kernel = np.ones((3, 3), 'uint8')
    dilate_img = cv2.dilate(img, kernel, iterations=1)
    img = img * (1 - mask) + dilate_img * mask
    img = img.clip(0, 255).astype(np.uint8)
    # Rows are reversed (vertical flip) before the image is written.
    Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(os.path.join(fol, '%s.png' % na))
90
-
91
-
92
def loadobj(meshfile):
    """Read vertex positions and triangle indices from an OBJ file.

    Only ``v`` and ``f`` records with exactly three values are used; every
    other line is ignored. For faces written as ``v/vt/vn`` only the vertex
    index is kept.

    Args:
        meshfile: Path of the OBJ file.

    Returns:
        A tuple ``(pointnp_px3, facenp_fx3)``: float32 positions of shape
        (P, 3) and zero-based int64 face indices of shape (F, 3). Both keep
        their second dimension when the file contains no such records.
    """
    vertices = []
    faces = []
    # ``with`` closes the file even if a malformed number raises below.
    with open(meshfile, 'r') as meshfp:
        for line in meshfp:
            # split() with no argument accepts tabs and repeated spaces.
            data = line.split()
            if len(data) != 4:
                continue
            if data[0] == 'v':
                vertices.append([float(d) for d in data[1:]])
            elif data[0] == 'f':
                faces.append([int(d.split('/')[0]) for d in data[1:]])

    # torch needs int64; OBJ indices are one-based, hence the -1.
    facenp_fx3 = np.array(faces, dtype=np.int64).reshape(-1, 3) - 1
    pointnp_px3 = np.array(vertices, dtype=np.float32).reshape(-1, 3)
    return pointnp_px3, facenp_fx3
112
-
113
-
114
def loadobjtex(meshfile):
    """Read positions, faces and texture coordinates from an OBJ file.

    Lines with 3, 4 or 5 whitespace-separated tokens are considered:
    ``v`` records must have three values, ``vt`` records contribute their
    first two values, and ``f`` records may be triangles or quads. A quad
    ``a b c d`` is split into the triangles ``a b c`` and ``a c d``. Every
    face corner must be written as ``v/vt[...]``.

    Args:
        meshfile: Path of the OBJ file.

    Returns:
        A tuple ``(pointnp_px3, facenp_fx3, uvs, ftnp_fx3)``: float32
        positions (P, 3), zero-based int64 vertex indices (F, 3), float32
        texture coordinates (T, 2) and zero-based int64 texture indices
        (F, 3).
    """
    v = []
    vt = []
    f = []
    ft = []
    # ``with`` closes the file even if a malformed record raises below.
    with open(meshfile, 'r') as meshfp:
        for line in meshfp:
            # split() with no argument accepts tabs and repeated spaces.
            data = line.split()
            if len(data) not in (3, 4, 5):
                continue
            tag = data[0]
            if tag == 'v':
                assert len(data) == 4
                v.append([float(d) for d in data[1:]])
            elif tag == 'vt':
                if len(data) in (3, 4):
                    vt.append([float(d) for d in data[1:3]])
            elif tag == 'f':
                corners = [da.split('/') for da in data[1:]]
                if len(corners) == 3:
                    triangles = [corners]
                elif len(corners) == 4:
                    # Fan-triangulate the quad around its first corner.
                    triangles = [
                        [corners[0], corners[1], corners[2]],
                        [corners[0], corners[2], corners[3]],
                    ]
                else:
                    triangles = []
                for tri in triangles:
                    f.append([int(c[0]) for c in tri])
                    ft.append([int(c[1]) for c in tri])

    # torch needs int64; OBJ indices are one-based, hence the -1.
    facenp_fx3 = np.array(f, dtype=np.int64).reshape(-1, 3) - 1
    ftnp_fx3 = np.array(ft, dtype=np.int64).reshape(-1, 3) - 1
    pointnp_px3 = np.array(v, dtype=np.float32).reshape(-1, 3)
    uvs = np.array(vt, dtype=np.float32).reshape(-1, 2)
    return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
154
-
155
-
156
- # ==============================================================================================
157
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Interpolate ``attr`` over a rasterization with ``dr.interpolate``.

    ``attr`` is made contiguous first. Attribute derivatives (``'all'``)
    are requested only when ``rast_db`` is supplied.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
159
-
160
-
161
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
    """Unwrap a mesh with xatlas and rasterize its positions into UV space.

    Args:
        ctx: Rasterization context passed to ``dr.rasterize``.
        mesh_v: Vertex position tensor.
        mesh_pos_idx: Triangle vertex-index tensor.
        resolution: Side length of the square raster.

    Returns:
        A tuple ``(uvs, mesh_tex_idx, gb_pos, mask)``: the UV coordinates and
        UV triangle indices from xatlas (as tensors on ``mesh_v``'s device),
        the vertex positions interpolated over the UV raster, and a boolean
        mask that is true where the last rasterizer channel is positive.
    """
    device = mesh_v.device
    _vmapping, indices, uv_np = xatlas.parametrize(
        mesh_v.detach().cpu().numpy(),
        mesh_pos_idx.detach().cpu().numpy(),
    )

    # Convert the xatlas outputs to tensors (indices reinterpreted as int64).
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uv_np, dtype=torch.float32, device=device)
    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=device)

    # Texture-space vertices: rescale the UVs by 2 and shift by -1
    # (presumably [0, 1] UVs to [-1, 1] clip space -- confirm).
    uv_clip = uvs[None, ...] * 2.0 - 1.0

    # Pad to four-component coordinates (z = 0, w = 1).
    pad_zero = torch.zeros_like(uv_clip[..., 0:1])
    pad_one = torch.ones_like(uv_clip[..., 0:1])
    uv_clip4 = torch.cat((uv_clip, pad_zero, pad_one), dim=-1)

    # Rasterize the UV layout.
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))

    # Interpolate the mesh positions over the UV raster.
    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
    mask = rast[..., 3:4] > 0
    return uvs, mesh_tex_idx, gb_pos, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/train_util.py DELETED
@@ -1,26 +0,0 @@
1
- import importlib
2
-
3
-
4
def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    When ``verbose`` is true the count is also printed in millions,
    prefixed with the model's class name.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
9
-
10
-
11
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; the result is called with ``config["params"]``
    (or no arguments when absent) as keyword arguments.

    Returns:
        The constructed object, or ``None`` when ``config`` is one of the
        markers ``'__is_first_stage__'`` / ``'__is_unconditional__'``.

    Raises:
        KeyError: If ``config`` has no ``target`` entry and is not a marker.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        return factory(**config.get("params", dict()))

    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
19
-
20
-
21
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute. With ``reload=True`` the module is
    reloaded before the lookup.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module_obj = importlib.import_module(module_name, package=None)
    return getattr(module_obj, attr_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
zero123plus/pipeline.py DELETED
@@ -1,406 +0,0 @@
1
- from typing import Any, Dict, Optional
2
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
3
- from diffusers.schedulers import KarrasDiffusionSchedulers
4
-
5
- import numpy
6
- import torch
7
- import torch.nn as nn
8
- import torch.utils.checkpoint
9
- import torch.distributed
10
- import transformers
11
- from collections import OrderedDict
12
- from PIL import Image
13
- from torchvision import transforms
14
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
15
-
16
- import diffusers
17
- from diffusers import (
18
- AutoencoderKL,
19
- DDPMScheduler,
20
- DiffusionPipeline,
21
- EulerAncestralDiscreteScheduler,
22
- UNet2DConditionModel,
23
- ImagePipelineOutput
24
- )
25
- from diffusers.image_processor import VaeImageProcessor
26
- from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
27
- from diffusers.utils.import_utils import is_xformers_available
28
-
29
-
30
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of ``maybe_rgba``.

    RGB images are returned unchanged. RGBA images are pasted onto a white
    background of the same size, using their alpha channel as the mask.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # A solid white canvas; Image.new states that directly instead of
        # drawing "random" integers from the single-value range [255, 256).
        img = Image.new('RGB', maybe_rgba.size, (255, 255, 255))
        img.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return img
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
41
-
42
-
43
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention processor wrapper that records or injects reference states.

    The wrapped ``chained_proc`` performs the actual attention. When
    ``enabled`` the wrapper additionally acts on ``ref_dict`` under its own
    ``name`` according to ``mode``:

    * ``'w'``: store the encoder hidden states in ``ref_dict``.
    * ``'r'``: concatenate (dim 1) the stored states onto the encoder hidden
      states and remove them from ``ref_dict``.
    * ``'m'``: like ``'r'`` but leave the stored states in ``ref_dict``.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance=False
    ) -> Any:
        """Run attention, applying the reference-state handling for ``mode``.

        With ``is_cfg_guidance`` (and ``enabled``) the first batch entry is
        processed on its own, without reference states, and its result is
        concatenated back in front of the rest.

        Raises:
            ValueError: If the processor is enabled and ``mode`` is not one
                of ``'w'``, ``'r'`` or ``'m'``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        split_first = self.enabled and is_cfg_guidance
        if split_first:
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]

        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                # An explicit raise survives ``python -O``; an assert would be
                # stripped and the unknown mode would silently be ignored.
                raise ValueError(f"Unknown reference attention mode: {mode!r}")

        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        if split_first:
            res = torch.cat([res0, res])
        return res
78
-
79
-
80
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that conditions denoising on a noised reference latent.

    Every attention processor of the wrapped UNet is replaced by a
    ``ReferenceOnlyAttnProc`` (enabled for processors whose name ends in
    ``"attn1.processor"``). Each forward call first runs the UNet on the
    noised condition latent in write mode to collect reference states, then
    runs it on the actual sample in read mode.

    Attributes not found on the wrapper are looked up on the wrapped UNet.
    """

    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Feature-detect rather than compare version strings: a string
        # comparison such as '10.0' >= '2.0' is lexicographic and False.
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            default_proc_cls = AttnProcessor2_0
        elif is_xformers_available():
            default_proc_cls = XFormersAttnProcessor
        else:
            default_proc_cls = AttnProcessor

        # Each processor gets its own instance of the default processor.
        unet_lora_attn_procs = {
            name: ReferenceOnlyAttnProc(
                default_proc_cls(), enabled=name.endswith("attn1.processor"), name=name
            )
            for name in unet.attn_processors
        }
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped UNet for anything the wrapper lacks.
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Run the UNet on the condition latent to fill ``ref_dict`` (write mode).

        With ``is_cfg_guidance`` the first entry of ``encoder_hidden_states``
        and ``class_labels`` is dropped before the call. The UNet output is
        discarded; only the side effect on ``ref_dict`` matters.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        """Denoise ``sample`` with reference attention on the condition latent.

        ``cross_attention_kwargs`` must contain ``'cond_lat'`` and may
        contain ``'is_cfg_guidance'``. The condition latent is noised with
        the training scheduler in training mode and with the validation
        scheduler otherwise. Optional ControlNet residuals are cast to the
        UNet dtype before being forwarded.
        """
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype
        down_residuals = None
        if down_block_res_samples is not None:
            down_residuals = [res.to(dtype=weight_dtype) for res in down_block_res_samples]
        mid_residual = None
        if mid_block_res_sample is not None:
            mid_residual = mid_block_res_sample.to(dtype=weight_dtype)

        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            **kwargs
        )
154
-
155
-
156
def scale_latents(latents):
    """Shift ``latents`` by -0.22 and scale by 0.75 (inverse of ``unscale_latents``)."""
    return (latents - 0.22) * 0.75
159
-
160
-
161
def unscale_latents(latents):
    """Divide ``latents`` by 0.75 and shift by +0.22 (inverse of ``scale_latents``)."""
    return latents / 0.75 + 0.22
164
-
165
-
166
def scale_image(image):
    """Multiply ``image`` by 0.5 and divide by 0.8 (inverse of ``unscale_image``)."""
    return image * 0.5 / 0.8
169
-
170
-
171
def unscale_image(image):
    """Divide ``image`` by 0.5 and multiply by 0.8 (inverse of ``scale_image``)."""
    return image / 0.5 * 0.8
174
-
175
-
176
class DepthControlUNet(torch.nn.Module):
    """Pairs a ``RefOnlyNoisedUNet`` with a ControlNet driven by a depth input.

    When no ControlNet is supplied, one is created from the inner UNet via
    ``ControlNetModel.from_unet``. Attributes not found on this wrapper are
    looked up on the wrapped UNet.
    """

    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
        super().__init__()
        self.unet = unet
        self.controlnet = (
            diffusers.ControlNetModel.from_unet(unet.unet) if controlnet is None else controlnet
        )
        # xformers attention is preferred whenever it is available.
        proc_cls = XFormersAttnProcessor if is_xformers_available() else AttnProcessor2_0
        self.controlnet.set_attn_processor(proc_cls())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped UNet for anything the wrapper lacks.
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
        """Run the ControlNet, then the UNet with the ControlNet residuals.

        ``cross_attention_kwargs`` must contain ``'control_depth'``. It is
        removed from a copy of the dict (the caller's dict is untouched) and
        the remaining entries are forwarded to the UNet.
        """
        attention_kwargs = dict(cross_attention_kwargs)
        control_depth = attention_kwargs.pop('control_depth')

        control_residuals = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        down_block_res_samples, mid_block_res_sample = control_residuals

        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=attention_kwargs
        )
215
-
216
-
217
class ModuleListDict(torch.nn.Module):
    """Holds a dict of modules as a ``ModuleList`` ordered by sorted key.

    ``self.keys`` is the sorted key list and ``self.values`` the modules in
    the same order; indexing by key returns the matching module.
    """

    def __init__(self, procs: dict) -> None:
        super().__init__()
        self.keys = sorted(procs)
        self.values = torch.nn.ModuleList([procs[k] for k in self.keys])

    def __getitem__(self, key):
        position = self.keys.index(key)
        return self.values[position]
225
-
226
-
227
class SuperNet(torch.nn.Module):
    """Container that saves and loads its layers under their original names.

    The values of ``state_dict`` are stored in a ``ModuleList`` in sorted key
    order (so they are expected to be modules, despite the annotation).
    Hooks rename ``layers.<i>`` prefixes to the original keys on
    ``state_dict()`` and back again on ``load_state_dict()``.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        ordered_names = sorted(state_dict.keys())
        state_dict = OrderedDict((name, state_dict[name]) for name in ordered_names)
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(ordered_names))
        self.rev_mapping = {name: index for index, name in enumerate(ordered_names)}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # Hooks on state_dict() and load_state_dict() make the key names
        # line up with `unet.attn_processors`.
        def map_to(module, state_dict, *args, **kwargs):
            renamed = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                renamed[key.replace(f"layers.{num}", module.mapping[num])] = value
            return renamed

        def remap_key(key, state_dict):
            # Keep everything up to and including the first split marker
            # found; otherwise use the first dotted component.
            for marker in self.split_keys:
                if marker in key:
                    return key.split(marker)[0] + marker
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            # Iterate over a snapshot because the dict is edited in place.
            for key in list(state_dict.keys()):
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
265
-
266
-
267
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Stable Diffusion pipeline conditioned on a single input image.

    The input image is encoded twice: by the VAE into a condition latent
    that is handed to the UNet through ``cross_attention_kwargs``, and by
    the CLIP vision encoder into a global embedding that is added to the
    prompt embeddings, weighted by ``ramping_coefficients``.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Converts a depth image to a tensor normalized with mean 0.5 / std 0.5.
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        """Register the sub-models; ``safety_checker`` is always registered as ``None``."""
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain ``UNet2DConditionModel`` in ``RefOnlyNoisedUNet`` (eval mode).

        Does nothing to the UNet if it has already been wrapped.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        """Wrap the UNet in ``DepthControlUNet`` and return a ``SuperNet`` holding its ControlNet."""
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and return a sample of the latent distribution."""
        return self.vae.encode(image).latent_dist.sample()

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt="",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs
    ):
        """Generate an image conditioned on ``image``.

        Args:
            image: Condition image (PIL, RGB or RGBA). Required; tensors are
                rejected.
            prompt: Text prompt that is encoded and combined with the image
                embedding.
            num_images_per_prompt: Forwarded to the base pipeline.
            guidance_scale: Values above 1 also encode an all-zero image and
                prepend its latent to the condition latent.
            depth_image: Optional depth image, used only when the UNet has a
                ``controlnet`` attribute.
            output_type: ``"latent"`` returns the unscaled latents; any other
                value decodes them with the VAE first.
            width, height, num_inference_steps: Forwarded to the base
                pipeline.
            return_dict: When false a one-element tuple is returned.

        Returns:
            ``ImagePipelineOutput`` with the post-processed images, or
            ``(images,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)
        vae_pixels = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        clip_pixels = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values

        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )

        image = vae_pixels.to(device=self.vae.device, dtype=self.vae.dtype)
        clip_pixels = clip_pixels.to(device=self.vae.device, dtype=self.vae.dtype)

        # The condition latent is sampled before the negative one; keep this
        # order so random draws happen in the same sequence.
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])

        encoded = self.vision_encoder(clip_pixels, output_hidden_states=False)
        global_embeds = encoded.image_embeds.unsqueeze(-2)

        # Newer base pipelines expose encode_prompt (tuple result); older
        # ones only have _encode_prompt.
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )[0]
        else:
            encoder_hidden_states = self._encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )

        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp

        attention_kwargs = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            attention_kwargs['control_depth'] = depth_image

        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=attention_kwargs,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **kwargs
        ).images
        latents = unscale_latents(latents)

        if output_type != "latent":
            decoded = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = unscale_image(decoded)
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)