Spaces:
Running
on
Zero
Running
on
Zero
gokaygokay
commited on
Commit
·
f25bfd1
1
Parent(s):
777f4d9
back
Browse files- app.py +129 -261
- configs/instant-mesh-base.yaml +0 -22
- configs/instant-mesh-large.yaml +0 -22
- configs/instant-nerf-base.yaml +0 -21
- configs/instant-nerf-large.yaml +0 -21
- requirements.txt +8 -24
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/objaverse.py +0 -329
- src/model.py +0 -310
- src/model_mesh.py +0 -325
- src/models/__init__.py +0 -0
- src/models/decoder/__init__.py +0 -0
- src/models/decoder/transformer.py +0 -123
- src/models/encoder/__init__.py +0 -0
- src/models/encoder/dino.py +0 -550
- src/models/encoder/dino_wrapper.py +0 -80
- src/models/geometry/__init__.py +0 -7
- src/models/geometry/camera/__init__.py +0 -16
- src/models/geometry/camera/perspective_camera.py +0 -35
- src/models/geometry/render/__init__.py +0 -8
- src/models/geometry/render/neural_render.py +0 -121
- src/models/geometry/rep_3d/__init__.py +0 -18
- src/models/geometry/rep_3d/dmtet.py +0 -504
- src/models/geometry/rep_3d/dmtet_utils.py +0 -20
- src/models/geometry/rep_3d/extract_texture_map.py +0 -40
- src/models/geometry/rep_3d/flexicubes.py +0 -579
- src/models/geometry/rep_3d/flexicubes_geometry.py +0 -120
- src/models/geometry/rep_3d/tables.py +0 -791
- src/models/lrm.py +0 -196
- src/models/lrm_mesh.py +0 -385
- src/models/renderer/__init__.py +0 -9
- src/models/renderer/synthesizer.py +0 -203
- src/models/renderer/synthesizer_mesh.py +0 -141
- src/models/renderer/utils/__init__.py +0 -9
- src/models/renderer/utils/math_utils.py +0 -118
- src/models/renderer/utils/ray_marcher.py +0 -72
- src/models/renderer/utils/ray_sampler.py +0 -141
- src/models/renderer/utils/renderer.py +0 -323
- src/utils/__init__.py +0 -0
- src/utils/camera_util.py +0 -111
- src/utils/infer_util.py +0 -84
- src/utils/mesh_util.py +0 -181
- src/utils/train_util.py +0 -26
- zero123plus/pipeline.py +0 -406
app.py
CHANGED
@@ -1,286 +1,154 @@
|
|
1 |
import spaces
|
2 |
-
import
|
3 |
-
import time
|
4 |
-
from os import path
|
5 |
-
from huggingface_hub import hf_hub_download
|
6 |
-
import numpy as np
|
7 |
import torch
|
8 |
-
import rembg
|
9 |
from PIL import Image
|
10 |
-
from
|
11 |
-
from
|
12 |
-
from pytorch_lightning import seed_everything
|
13 |
-
from omegaconf import OmegaConf
|
14 |
-
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
15 |
-
import gradio as gr
|
16 |
-
import shutil
|
17 |
-
import tempfile
|
18 |
-
from src.utils.train_util import instantiate_from_config
|
19 |
-
from src.utils.camera_util import (
|
20 |
-
FOV_to_intrinsics,
|
21 |
-
get_zero123plus_input_cameras,
|
22 |
-
get_circular_camera_poses,
|
23 |
-
)
|
24 |
-
from src.utils.mesh_util import save_obj, save_glb
|
25 |
-
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
|
26 |
import random
|
27 |
-
import
|
28 |
-
import
|
|
|
|
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
os.environ["HF_HUB_CACHE"] = cache_path
|
34 |
-
os.environ["HF_HOME"] = cache_path
|
35 |
|
36 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
42 |
-
|
43 |
-
class timer:
|
44 |
-
def __init__(self, method_name="timed process"):
|
45 |
-
self.method = method_name
|
46 |
-
def __enter__(self):
|
47 |
-
self.start = time.time()
|
48 |
-
print(f"{self.method} starts")
|
49 |
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
50 |
-
end = time.time()
|
51 |
-
print(f"{self.method} took {str(round(end - self.start, 2))}s")
|
52 |
-
|
53 |
-
def find_cuda():
|
54 |
-
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
55 |
-
if cuda_home and os.path.exists(cuda_home):
|
56 |
-
return cuda_home
|
57 |
-
nvcc_path = shutil.which('nvcc')
|
58 |
-
if nvcc_path:
|
59 |
-
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
|
60 |
-
return cuda_path
|
61 |
-
return None
|
62 |
-
|
63 |
-
cuda_path = find_cuda()
|
64 |
-
if cuda_path:
|
65 |
-
print(f"CUDA installation found at: {cuda_path}")
|
66 |
-
else:
|
67 |
-
print("CUDA installation not found")
|
68 |
-
|
69 |
-
|
70 |
-
API_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
71 |
-
headers = {"Authorization": f"Bearer {API_TOKEN}"}
|
72 |
-
timeout = 100
|
73 |
-
|
74 |
-
device = 'cuda'
|
75 |
-
|
76 |
-
# Load 3D generation models
|
77 |
-
config_path = 'configs/instant-mesh-large.yaml'
|
78 |
-
config = OmegaConf.load(config_path)
|
79 |
-
config_name = os.path.basename(config_path).replace('.yaml', '')
|
80 |
-
model_config = config.model_config
|
81 |
-
infer_config = config.infer_config
|
82 |
-
|
83 |
-
IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
|
84 |
-
|
85 |
-
# Load diffusion model for 3D generation
|
86 |
-
print('Loading diffusion model ...')
|
87 |
-
pipeline = DiffusionPipeline.from_pretrained(
|
88 |
-
"sudo-ai/zero123plus-v1.2",
|
89 |
-
custom_pipeline="zero123plus",
|
90 |
-
torch_dtype=torch.float16,
|
91 |
-
)
|
92 |
-
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
93 |
-
pipeline.scheduler.config, timestep_spacing='trailing'
|
94 |
-
)
|
95 |
-
|
96 |
-
# Load custom white-background UNet
|
97 |
-
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
|
98 |
-
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
99 |
-
pipeline.unet.load_state_dict(state_dict, strict=True)
|
100 |
-
|
101 |
-
pipeline = pipeline.to(device)
|
102 |
-
|
103 |
-
# Load reconstruction model
|
104 |
-
print('Loading reconstruction model ...')
|
105 |
-
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
|
106 |
-
model = instantiate_from_config(model_config)
|
107 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
108 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
|
109 |
-
model.load_state_dict(state_dict, strict=True)
|
110 |
-
|
111 |
-
model = model.to(device)
|
112 |
-
|
113 |
-
print('Loading Finished!')
|
114 |
-
|
115 |
-
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
|
116 |
-
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
|
117 |
-
if is_flexicubes:
|
118 |
-
cameras = torch.linalg.inv(c2ws)
|
119 |
-
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
120 |
-
else:
|
121 |
-
extrinsics = c2ws.flatten(-2)
|
122 |
-
intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
|
123 |
-
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
|
124 |
-
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
125 |
-
return cameras
|
126 |
-
|
127 |
-
def preprocess(input_image, do_remove_background):
|
128 |
-
rembg_session = rembg.new_session() if do_remove_background else None
|
129 |
-
if do_remove_background:
|
130 |
-
input_image = remove_background(input_image, rembg_session)
|
131 |
-
input_image = resize_foreground(input_image, 0.85)
|
132 |
-
return input_image
|
133 |
|
|
|
|
|
|
|
134 |
|
|
|
|
|
135 |
|
|
|
|
|
136 |
|
|
|
137 |
@spaces.GPU
|
138 |
-
def
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
num_inference_steps=sample_steps
|
143 |
-
).images[0]
|
144 |
-
show_image = np.asarray(z123_image, dtype=np.uint8)
|
145 |
-
show_image = torch.from_numpy(show_image)
|
146 |
-
show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
147 |
-
show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
|
148 |
-
show_image = Image.fromarray(show_image.numpy())
|
149 |
-
return z123_image, show_image
|
150 |
-
|
151 |
-
@spaces.GPU
|
152 |
-
def make3d(images):
|
153 |
-
global model
|
154 |
-
if IS_FLEXICUBES:
|
155 |
-
model.init_flexicubes_geometry(device, use_renderer=False)
|
156 |
-
model = model.eval()
|
157 |
-
|
158 |
-
images = np.asarray(images, dtype=np.float32) / 255.0
|
159 |
-
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
|
160 |
-
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
|
161 |
-
|
162 |
-
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
|
163 |
-
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
|
164 |
-
|
165 |
-
images = images.unsqueeze(0).to(device)
|
166 |
-
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
|
167 |
-
|
168 |
-
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
|
169 |
-
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
170 |
-
mesh_dirname = os.path.dirname(mesh_fpath)
|
171 |
-
mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
|
172 |
-
|
173 |
-
with torch.no_grad():
|
174 |
-
planes = model.forward_planes(images, input_cameras)
|
175 |
-
mesh_out = model.extract_mesh(
|
176 |
-
planes,
|
177 |
-
use_texture_map=False,
|
178 |
-
**infer_config,
|
179 |
-
)
|
180 |
-
vertices, faces, vertex_colors = mesh_out
|
181 |
-
vertices = vertices[:, [1, 2, 0]]
|
182 |
-
save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
|
183 |
-
save_obj(vertices, faces, vertex_colors, mesh_fpath)
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
if randomize_seed:
|
196 |
-
seed = random.randint(
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
-
try:
|
216 |
-
image_bytes = response.content
|
217 |
-
image = Image.open(io.BytesIO(image_bytes))
|
218 |
-
return image
|
219 |
-
except Exception as e:
|
220 |
-
print(f"Error when trying to open the image: {e}")
|
221 |
-
return None
|
222 |
-
|
223 |
-
# Update the Gradio interface
|
224 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
225 |
-
gr.Markdown(
|
226 |
-
"""
|
227 |
-
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
228 |
-
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem;">Flux Image to 3D Model Generator</h1>
|
229 |
-
</div>
|
230 |
-
"""
|
231 |
-
)
|
232 |
-
|
233 |
with gr.Row():
|
234 |
-
with gr.Column(scale=
|
235 |
-
|
236 |
-
label="
|
237 |
-
placeholder="E.g., A serene landscape with mountains and a lake at sunset",
|
238 |
-
lines=3
|
239 |
-
)
|
240 |
|
241 |
with gr.Accordion("Advanced Settings", open=False):
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
seed = gr.Number(label="Seed", value=-1, precision=0)
|
252 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
253 |
|
254 |
-
generate_btn = gr.Button("Generate
|
255 |
-
|
256 |
-
with gr.Column(scale=
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
with gr.Tab("GLB"):
|
263 |
-
output_model_glb = gr.Model3D(label="Output Model (GLB Format)")
|
264 |
-
|
265 |
-
mv_images = gr.State()
|
266 |
-
|
267 |
-
def process_pipeline(prompt, height, width, steps, scales, seed, randomize_seed):
|
268 |
-
# Generate Flux image using the API
|
269 |
-
prompt_real = f"wbgmsst, {prompt}, white background"
|
270 |
-
flux_image = query(prompt_real, steps, scales, randomize_seed, seed, width, height)
|
271 |
-
if flux_image is None:
|
272 |
-
raise gr.Error("Failed to generate image")
|
273 |
-
|
274 |
-
processed_image = preprocess(flux_image, do_remove_background=True)
|
275 |
-
mv_images, show_image = generate_mvs(processed_image, steps, seed)
|
276 |
-
obj_path, glb_path = make3d(mv_images)
|
277 |
-
return flux_image, show_image, obj_path, glb_path
|
278 |
-
|
279 |
generate_btn.click(
|
280 |
-
fn=
|
281 |
-
inputs=[
|
282 |
-
|
|
|
|
|
|
|
283 |
)
|
284 |
|
285 |
-
|
286 |
-
demo.queue().launch()
|
|
|
1 |
import spaces
|
2 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
3 |
import torch
|
|
|
4 |
from PIL import Image
|
5 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
|
6 |
+
from diffusers import DiffusionPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import random
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
12 |
|
13 |
+
# Initialize models
|
14 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
dtype = torch.bfloat16
|
|
|
|
|
16 |
|
17 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
18 |
|
19 |
+
# FLUX.1-dev model
|
20 |
+
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token = huggingface_token).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
# Initialize Florence model
|
23 |
+
florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
|
24 |
+
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
|
25 |
|
26 |
+
# Prompt Enhancer
|
27 |
+
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
|
28 |
|
29 |
+
MAX_SEED = np.iinfo(np.int32).max
|
30 |
+
MAX_IMAGE_SIZE = 2048
|
31 |
|
32 |
+
# Florence caption function
|
33 |
@spaces.GPU
|
34 |
+
def florence_caption(image):
|
35 |
+
# Convert image to PIL if it's not already
|
36 |
+
if not isinstance(image, Image.Image):
|
37 |
+
image = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
|
40 |
+
generated_ids = florence_model.generate(
|
41 |
+
input_ids=inputs["input_ids"],
|
42 |
+
pixel_values=inputs["pixel_values"],
|
43 |
+
max_new_tokens=1024,
|
44 |
+
early_stopping=False,
|
45 |
+
do_sample=False,
|
46 |
+
num_beams=3,
|
47 |
+
)
|
48 |
+
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
49 |
+
parsed_answer = florence_processor.post_process_generation(
|
50 |
+
generated_text,
|
51 |
+
task="<MORE_DETAILED_CAPTION>",
|
52 |
+
image_size=(image.width, image.height)
|
53 |
+
)
|
54 |
+
return parsed_answer["<MORE_DETAILED_CAPTION>"]
|
55 |
+
|
56 |
+
# Prompt Enhancer function
|
57 |
+
def enhance_prompt(input_prompt):
|
58 |
+
result = enhancer_long("Enhance the description: " + input_prompt)
|
59 |
+
enhanced_text = result[0]['summary_text']
|
60 |
+
return enhanced_text
|
61 |
+
|
62 |
+
@spaces.GPU(duration=190)
|
63 |
+
def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
|
64 |
+
if image is not None:
|
65 |
+
# Convert image to PIL if it's not already
|
66 |
+
if not isinstance(image, Image.Image):
|
67 |
+
image = Image.fromarray(image)
|
68 |
+
|
69 |
+
prompt = florence_caption(image)
|
70 |
+
print(prompt)
|
71 |
+
else:
|
72 |
+
prompt = text_prompt
|
73 |
+
|
74 |
+
if use_enhancer:
|
75 |
+
prompt = enhance_prompt(prompt)
|
76 |
|
77 |
if randomize_seed:
|
78 |
+
seed = random.randint(0, MAX_SEED)
|
79 |
|
80 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
81 |
+
|
82 |
+
image = pipe(
|
83 |
+
prompt=prompt,
|
84 |
+
generator=generator,
|
85 |
+
num_inference_steps=num_inference_steps,
|
86 |
+
width=width,
|
87 |
+
height=height,
|
88 |
+
guidance_scale=guidance_scale
|
89 |
+
).images[0]
|
90 |
+
|
91 |
+
return image, prompt, seed
|
92 |
+
|
93 |
+
custom_css = """
|
94 |
+
.input-group, .output-group {
|
95 |
+
border: 1px solid #e0e0e0;
|
96 |
+
border-radius: 10px;
|
97 |
+
padding: 20px;
|
98 |
+
margin-bottom: 20px;
|
99 |
+
background-color: #f9f9f9;
|
100 |
+
}
|
101 |
+
.submit-btn {
|
102 |
+
background-color: #2980b9 !important;
|
103 |
+
color: white !important;
|
104 |
+
}
|
105 |
+
.submit-btn:hover {
|
106 |
+
background-color: #3498db !important;
|
107 |
+
}
|
108 |
+
"""
|
109 |
+
|
110 |
+
title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
|
111 |
+
<p><center>
|
112 |
+
<a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
|
113 |
+
<a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
|
114 |
+
<a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
|
115 |
+
<p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
|
116 |
+
</center></p>
|
117 |
+
"""
|
118 |
+
|
119 |
+
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
|
120 |
+
gr.HTML(title)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
with gr.Row():
|
123 |
+
with gr.Column(scale=1):
|
124 |
+
with gr.Group(elem_classes="input-group"):
|
125 |
+
input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
|
|
|
|
|
|
|
126 |
|
127 |
with gr.Accordion("Advanced Settings", open=False):
|
128 |
+
text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
|
129 |
+
use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
|
130 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
131 |
+
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
132 |
+
width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
133 |
+
height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
134 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
|
135 |
+
num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
|
|
|
|
|
|
|
136 |
|
137 |
+
generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
|
138 |
+
|
139 |
+
with gr.Column(scale=1):
|
140 |
+
with gr.Group(elem_classes="output-group"):
|
141 |
+
output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
|
142 |
+
final_prompt = gr.Textbox(label="Final Prompt Used")
|
143 |
+
used_seed = gr.Number(label="Seed Used")
|
144 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
generate_btn.click(
|
146 |
+
fn=process_workflow,
|
147 |
+
inputs=[
|
148 |
+
input_image, text_prompt, use_enhancer, seed, randomize_seed,
|
149 |
+
width, height, guidance_scale, num_inference_steps
|
150 |
+
],
|
151 |
+
outputs=[output_image, final_prompt, used_seed]
|
152 |
)
|
153 |
|
154 |
+
demo.launch(debug=True)
|
|
configs/instant-mesh-base.yaml
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
model_config:
|
2 |
-
target: src.models.lrm_mesh.InstantMesh
|
3 |
-
params:
|
4 |
-
encoder_feat_dim: 768
|
5 |
-
encoder_freeze: false
|
6 |
-
encoder_model_name: facebook/dino-vitb16
|
7 |
-
transformer_dim: 1024
|
8 |
-
transformer_layers: 12
|
9 |
-
transformer_heads: 16
|
10 |
-
triplane_low_res: 32
|
11 |
-
triplane_high_res: 64
|
12 |
-
triplane_dim: 40
|
13 |
-
rendering_samples_per_ray: 96
|
14 |
-
grid_res: 128
|
15 |
-
grid_scale: 2.1
|
16 |
-
|
17 |
-
|
18 |
-
infer_config:
|
19 |
-
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
-
model_path: ckpts/instant_mesh_base.ckpt
|
21 |
-
texture_resolution: 1024
|
22 |
-
render_resolution: 512
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/instant-mesh-large.yaml
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
model_config:
|
2 |
-
target: src.models.lrm_mesh.InstantMesh
|
3 |
-
params:
|
4 |
-
encoder_feat_dim: 768
|
5 |
-
encoder_freeze: false
|
6 |
-
encoder_model_name: facebook/dino-vitb16
|
7 |
-
transformer_dim: 1024
|
8 |
-
transformer_layers: 16
|
9 |
-
transformer_heads: 16
|
10 |
-
triplane_low_res: 32
|
11 |
-
triplane_high_res: 64
|
12 |
-
triplane_dim: 80
|
13 |
-
rendering_samples_per_ray: 128
|
14 |
-
grid_res: 128
|
15 |
-
grid_scale: 2.1
|
16 |
-
|
17 |
-
|
18 |
-
infer_config:
|
19 |
-
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
-
model_path: ckpts/instant_mesh_large.ckpt
|
21 |
-
texture_resolution: 1024
|
22 |
-
render_resolution: 512
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/instant-nerf-base.yaml
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
model_config:
|
2 |
-
target: src.models.lrm.InstantNeRF
|
3 |
-
params:
|
4 |
-
encoder_feat_dim: 768
|
5 |
-
encoder_freeze: false
|
6 |
-
encoder_model_name: facebook/dino-vitb16
|
7 |
-
transformer_dim: 1024
|
8 |
-
transformer_layers: 12
|
9 |
-
transformer_heads: 16
|
10 |
-
triplane_low_res: 32
|
11 |
-
triplane_high_res: 64
|
12 |
-
triplane_dim: 40
|
13 |
-
rendering_samples_per_ray: 96
|
14 |
-
|
15 |
-
|
16 |
-
infer_config:
|
17 |
-
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
-
model_path: ckpts/instant_nerf_base.ckpt
|
19 |
-
mesh_threshold: 10.0
|
20 |
-
mesh_resolution: 256
|
21 |
-
render_resolution: 384
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/instant-nerf-large.yaml
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
model_config:
|
2 |
-
target: src.models.lrm.InstantNeRF
|
3 |
-
params:
|
4 |
-
encoder_feat_dim: 768
|
5 |
-
encoder_freeze: false
|
6 |
-
encoder_model_name: facebook/dino-vitb16
|
7 |
-
transformer_dim: 1024
|
8 |
-
transformer_layers: 16
|
9 |
-
transformer_heads: 16
|
10 |
-
triplane_low_res: 32
|
11 |
-
triplane_high_res: 64
|
12 |
-
triplane_dim: 80
|
13 |
-
rendering_samples_per_ray: 128
|
14 |
-
|
15 |
-
|
16 |
-
infer_config:
|
17 |
-
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
-
model_path: ckpts/instant_nerf_large.ckpt
|
19 |
-
mesh_threshold: 10.0
|
20 |
-
mesh_resolution: 256
|
21 |
-
render_resolution: 384
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,27 +1,11 @@
|
|
|
|
|
|
|
|
|
|
1 |
torch==2.4.0
|
2 |
torchvision==0.19.0
|
3 |
-
|
4 |
-
|
5 |
-
einops
|
6 |
-
omegaconf
|
7 |
-
deepspeed
|
8 |
-
torchmetrics
|
9 |
-
webdataset
|
10 |
sentencepiece
|
11 |
-
|
12 |
-
|
13 |
-
PyMCubes
|
14 |
-
trimesh
|
15 |
-
rembg
|
16 |
-
peft
|
17 |
-
transformers==4.44.0
|
18 |
-
diffusers==0.31.0
|
19 |
-
bitsandbytes
|
20 |
-
imageio[ffmpeg]
|
21 |
-
xatlas
|
22 |
-
plyfile
|
23 |
-
xformers==0.0.27.post2
|
24 |
-
git+https://github.com/NVlabs/nvdiffrast/
|
25 |
-
huggingface-hub==0.25.2
|
26 |
-
optimum-quanto==0.2.5
|
27 |
-
k-diffusion
|
|
|
1 |
+
spaces
|
2 |
+
huggingface_hub
|
3 |
+
accelerate
|
4 |
+
git+https://github.com/huggingface/diffusers.git
|
5 |
torch==2.4.0
|
6 |
torchvision==0.19.0
|
7 |
+
transformers==4.42.4
|
8 |
+
xformers
|
|
|
|
|
|
|
|
|
|
|
9 |
sentencepiece
|
10 |
+
timm
|
11 |
+
einops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
DELETED
File without changes
|
src/data/__init__.py
DELETED
File without changes
|
src/data/objaverse.py
DELETED
@@ -1,329 +0,0 @@
|
|
1 |
-
import os, sys
|
2 |
-
import math
|
3 |
-
import json
|
4 |
-
import importlib
|
5 |
-
from pathlib import Path
|
6 |
-
|
7 |
-
import cv2
|
8 |
-
import random
|
9 |
-
import numpy as np
|
10 |
-
from PIL import Image
|
11 |
-
import webdataset as wds
|
12 |
-
import pytorch_lightning as pl
|
13 |
-
|
14 |
-
import torch
|
15 |
-
import torch.nn.functional as F
|
16 |
-
from torch.utils.data import Dataset
|
17 |
-
from torch.utils.data import DataLoader
|
18 |
-
from torch.utils.data.distributed import DistributedSampler
|
19 |
-
from torchvision import transforms
|
20 |
-
|
21 |
-
from src.utils.train_util import instantiate_from_config
|
22 |
-
from src.utils.camera_util import (
|
23 |
-
FOV_to_intrinsics,
|
24 |
-
center_looking_at_camera_pose,
|
25 |
-
get_surrounding_views,
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
class DataModuleFromConfig(pl.LightningDataModule):
|
30 |
-
def __init__(
|
31 |
-
self,
|
32 |
-
batch_size=8,
|
33 |
-
num_workers=4,
|
34 |
-
train=None,
|
35 |
-
validation=None,
|
36 |
-
test=None,
|
37 |
-
**kwargs,
|
38 |
-
):
|
39 |
-
super().__init__()
|
40 |
-
|
41 |
-
self.batch_size = batch_size
|
42 |
-
self.num_workers = num_workers
|
43 |
-
|
44 |
-
self.dataset_configs = dict()
|
45 |
-
if train is not None:
|
46 |
-
self.dataset_configs['train'] = train
|
47 |
-
if validation is not None:
|
48 |
-
self.dataset_configs['validation'] = validation
|
49 |
-
if test is not None:
|
50 |
-
self.dataset_configs['test'] = test
|
51 |
-
|
52 |
-
def setup(self, stage):
|
53 |
-
|
54 |
-
if stage in ['fit']:
|
55 |
-
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
56 |
-
else:
|
57 |
-
raise NotImplementedError
|
58 |
-
|
59 |
-
def train_dataloader(self):
|
60 |
-
|
61 |
-
sampler = DistributedSampler(self.datasets['train'])
|
62 |
-
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
63 |
-
|
64 |
-
def val_dataloader(self):
|
65 |
-
|
66 |
-
sampler = DistributedSampler(self.datasets['validation'])
|
67 |
-
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
68 |
-
|
69 |
-
def test_dataloader(self):
|
70 |
-
|
71 |
-
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
72 |
-
|
73 |
-
|
74 |
-
class ObjaverseData(Dataset):
|
75 |
-
def __init__(self,
|
76 |
-
root_dir='objaverse/',
|
77 |
-
meta_fname='valid_paths.json',
|
78 |
-
input_image_dir='rendering_random_32views',
|
79 |
-
target_image_dir='rendering_random_32views',
|
80 |
-
input_view_num=6,
|
81 |
-
target_view_num=2,
|
82 |
-
total_view_n=32,
|
83 |
-
fov=50,
|
84 |
-
camera_rotation=True,
|
85 |
-
validation=False,
|
86 |
-
):
|
87 |
-
self.root_dir = Path(root_dir)
|
88 |
-
self.input_image_dir = input_image_dir
|
89 |
-
self.target_image_dir = target_image_dir
|
90 |
-
|
91 |
-
self.input_view_num = input_view_num
|
92 |
-
self.target_view_num = target_view_num
|
93 |
-
self.total_view_n = total_view_n
|
94 |
-
self.fov = fov
|
95 |
-
self.camera_rotation = camera_rotation
|
96 |
-
|
97 |
-
with open(os.path.join(root_dir, meta_fname)) as f:
|
98 |
-
filtered_dict = json.load(f)
|
99 |
-
paths = filtered_dict['good_objs']
|
100 |
-
self.paths = paths
|
101 |
-
|
102 |
-
self.depth_scale = 4.0
|
103 |
-
|
104 |
-
total_objects = len(self.paths)
|
105 |
-
print('============= length of dataset %d =============' % len(self.paths))
|
106 |
-
|
107 |
-
def __len__(self):
|
108 |
-
return len(self.paths)
|
109 |
-
|
110 |
-
def load_im(self, path, color):
|
111 |
-
'''
|
112 |
-
replace background pixel with random color in rendering
|
113 |
-
'''
|
114 |
-
pil_img = Image.open(path)
|
115 |
-
|
116 |
-
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
117 |
-
alpha = image[:, :, 3:]
|
118 |
-
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
119 |
-
|
120 |
-
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
121 |
-
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
122 |
-
return image, alpha
|
123 |
-
|
124 |
-
def __getitem__(self, index):
|
125 |
-
# load data
|
126 |
-
while True:
|
127 |
-
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
|
128 |
-
target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
|
129 |
-
|
130 |
-
indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
|
131 |
-
input_indices = indices[:self.input_view_num]
|
132 |
-
target_indices = indices[self.input_view_num:]
|
133 |
-
|
134 |
-
'''background color, default: white'''
|
135 |
-
bg_white = [1., 1., 1.]
|
136 |
-
bg_black = [0., 0., 0.]
|
137 |
-
|
138 |
-
image_list = []
|
139 |
-
alpha_list = []
|
140 |
-
depth_list = []
|
141 |
-
normal_list = []
|
142 |
-
pose_list = []
|
143 |
-
|
144 |
-
try:
|
145 |
-
input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
|
146 |
-
for idx in input_indices:
|
147 |
-
image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
|
148 |
-
normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
|
149 |
-
depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
150 |
-
depth = torch.from_numpy(depth).unsqueeze(0)
|
151 |
-
pose = input_cameras[idx]
|
152 |
-
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
153 |
-
|
154 |
-
image_list.append(image)
|
155 |
-
alpha_list.append(alpha)
|
156 |
-
depth_list.append(depth)
|
157 |
-
normal_list.append(normal)
|
158 |
-
pose_list.append(pose)
|
159 |
-
|
160 |
-
target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
|
161 |
-
for idx in target_indices:
|
162 |
-
image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
|
163 |
-
normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
|
164 |
-
depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
165 |
-
depth = torch.from_numpy(depth).unsqueeze(0)
|
166 |
-
pose = target_cameras[idx]
|
167 |
-
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
168 |
-
|
169 |
-
image_list.append(image)
|
170 |
-
alpha_list.append(alpha)
|
171 |
-
depth_list.append(depth)
|
172 |
-
normal_list.append(normal)
|
173 |
-
pose_list.append(pose)
|
174 |
-
|
175 |
-
except Exception as e:
|
176 |
-
print(e)
|
177 |
-
index = np.random.randint(0, len(self.paths))
|
178 |
-
continue
|
179 |
-
|
180 |
-
break
|
181 |
-
|
182 |
-
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
183 |
-
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
184 |
-
depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
|
185 |
-
normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
|
186 |
-
w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
|
187 |
-
c2ws = torch.linalg.inv(w2cs).float()
|
188 |
-
|
189 |
-
normals = normals * 2.0 - 1.0
|
190 |
-
normals = F.normalize(normals, dim=1)
|
191 |
-
normals = (normals + 1.0) / 2.0
|
192 |
-
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
193 |
-
|
194 |
-
# random rotation along z axis
|
195 |
-
if self.camera_rotation:
|
196 |
-
degree = np.random.uniform(0, math.pi * 2)
|
197 |
-
rot = torch.tensor([
|
198 |
-
[np.cos(degree), -np.sin(degree), 0, 0],
|
199 |
-
[np.sin(degree), np.cos(degree), 0, 0],
|
200 |
-
[0, 0, 1, 0],
|
201 |
-
[0, 0, 0, 1],
|
202 |
-
]).unsqueeze(0).float()
|
203 |
-
c2ws = torch.matmul(rot, c2ws)
|
204 |
-
|
205 |
-
# rotate normals
|
206 |
-
N, _, H, W = normals.shape
|
207 |
-
normals = normals * 2.0 - 1.0
|
208 |
-
normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
|
209 |
-
normals = F.normalize(normals, dim=1)
|
210 |
-
normals = (normals + 1.0) / 2.0
|
211 |
-
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
212 |
-
|
213 |
-
# random scaling
|
214 |
-
if np.random.rand() < 0.5:
|
215 |
-
scale = np.random.uniform(0.8, 1.0)
|
216 |
-
c2ws[:, :3, 3] *= scale
|
217 |
-
depths *= scale
|
218 |
-
|
219 |
-
# instrinsics of perspective cameras
|
220 |
-
K = FOV_to_intrinsics(self.fov)
|
221 |
-
Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
|
222 |
-
|
223 |
-
data = {
|
224 |
-
'input_images': images[:self.input_view_num], # (6, 3, H, W)
|
225 |
-
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
|
226 |
-
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
|
227 |
-
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
|
228 |
-
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
|
229 |
-
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
|
230 |
-
|
231 |
-
# lrm generator input and supervision
|
232 |
-
'target_images': images[self.input_view_num:], # (V, 3, H, W)
|
233 |
-
'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
|
234 |
-
'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
|
235 |
-
'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
|
236 |
-
'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
|
237 |
-
'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
|
238 |
-
|
239 |
-
'depth_available': 1,
|
240 |
-
}
|
241 |
-
return data
|
242 |
-
|
243 |
-
|
244 |
-
class ValidationData(Dataset):
|
245 |
-
def __init__(self,
|
246 |
-
root_dir='objaverse/',
|
247 |
-
input_view_num=6,
|
248 |
-
input_image_size=256,
|
249 |
-
fov=50,
|
250 |
-
):
|
251 |
-
self.root_dir = Path(root_dir)
|
252 |
-
self.input_view_num = input_view_num
|
253 |
-
self.input_image_size = input_image_size
|
254 |
-
self.fov = fov
|
255 |
-
|
256 |
-
self.paths = sorted(os.listdir(self.root_dir))
|
257 |
-
print('============= length of dataset %d =============' % len(self.paths))
|
258 |
-
|
259 |
-
cam_distance = 2.5
|
260 |
-
azimuths = np.array([30, 90, 150, 210, 270, 330])
|
261 |
-
elevations = np.array([30, -20, 30, -20, 30, -20])
|
262 |
-
azimuths = np.deg2rad(azimuths)
|
263 |
-
elevations = np.deg2rad(elevations)
|
264 |
-
|
265 |
-
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
|
266 |
-
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
|
267 |
-
z = cam_distance * np.sin(elevations)
|
268 |
-
|
269 |
-
cam_locations = np.stack([x, y, z], axis=-1)
|
270 |
-
cam_locations = torch.from_numpy(cam_locations).float()
|
271 |
-
c2ws = center_looking_at_camera_pose(cam_locations)
|
272 |
-
self.c2ws = c2ws.float()
|
273 |
-
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
|
274 |
-
|
275 |
-
render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
|
276 |
-
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
|
277 |
-
self.render_c2ws = render_c2ws.float()
|
278 |
-
self.render_Ks = render_Ks.float()
|
279 |
-
|
280 |
-
def __len__(self):
|
281 |
-
return len(self.paths)
|
282 |
-
|
283 |
-
def load_im(self, path, color):
|
284 |
-
'''
|
285 |
-
replace background pixel with random color in rendering
|
286 |
-
'''
|
287 |
-
pil_img = Image.open(path)
|
288 |
-
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
289 |
-
|
290 |
-
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
291 |
-
if image.shape[-1] == 4:
|
292 |
-
alpha = image[:, :, 3:]
|
293 |
-
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
294 |
-
else:
|
295 |
-
alpha = np.ones_like(image[:, :, :1])
|
296 |
-
|
297 |
-
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
298 |
-
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
299 |
-
return image, alpha
|
300 |
-
|
301 |
-
def __getitem__(self, index):
|
302 |
-
# load data
|
303 |
-
input_image_path = os.path.join(self.root_dir, self.paths[index])
|
304 |
-
|
305 |
-
'''background color, default: white'''
|
306 |
-
# color = np.random.uniform(0.48, 0.52)
|
307 |
-
bkg_color = [1.0, 1.0, 1.0]
|
308 |
-
|
309 |
-
image_list = []
|
310 |
-
alpha_list = []
|
311 |
-
|
312 |
-
for idx in range(self.input_view_num):
|
313 |
-
image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
|
314 |
-
image_list.append(image)
|
315 |
-
alpha_list.append(alpha)
|
316 |
-
|
317 |
-
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
318 |
-
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
319 |
-
|
320 |
-
data = {
|
321 |
-
'input_images': images, # (6, 3, H, W)
|
322 |
-
'input_alphas': alphas, # (6, 1, H, W)
|
323 |
-
'input_c2ws': self.c2ws, # (6, 4, 4)
|
324 |
-
'input_Ks': self.Ks, # (6, 3, 3)
|
325 |
-
|
326 |
-
'render_c2ws': self.render_c2ws,
|
327 |
-
'render_Ks': self.render_Ks,
|
328 |
-
}
|
329 |
-
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/model.py
DELETED
@@ -1,310 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
from torchvision.transforms import v2
|
6 |
-
from torchvision.utils import make_grid, save_image
|
7 |
-
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
-
import pytorch_lightning as pl
|
9 |
-
from einops import rearrange, repeat
|
10 |
-
|
11 |
-
from src.utils.train_util import instantiate_from_config
|
12 |
-
|
13 |
-
|
14 |
-
class MVRecon(pl.LightningModule):
|
15 |
-
def __init__(
|
16 |
-
self,
|
17 |
-
lrm_generator_config,
|
18 |
-
lrm_path=None,
|
19 |
-
input_size=256,
|
20 |
-
render_size=192,
|
21 |
-
):
|
22 |
-
super(MVRecon, self).__init__()
|
23 |
-
|
24 |
-
self.input_size = input_size
|
25 |
-
self.render_size = render_size
|
26 |
-
|
27 |
-
# init modules
|
28 |
-
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
29 |
-
if lrm_path is not None:
|
30 |
-
lrm_ckpt = torch.load(lrm_path)
|
31 |
-
self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
|
32 |
-
|
33 |
-
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
34 |
-
|
35 |
-
self.validation_step_outputs = []
|
36 |
-
|
37 |
-
def on_fit_start(self):
|
38 |
-
if self.global_rank == 0:
|
39 |
-
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
40 |
-
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
41 |
-
|
42 |
-
def prepare_batch_data(self, batch):
|
43 |
-
lrm_generator_input = {}
|
44 |
-
render_gt = {} # for supervision
|
45 |
-
|
46 |
-
# input images
|
47 |
-
images = batch['input_images']
|
48 |
-
images = v2.functional.resize(
|
49 |
-
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
50 |
-
|
51 |
-
lrm_generator_input['images'] = images.to(self.device)
|
52 |
-
|
53 |
-
# input cameras and render cameras
|
54 |
-
input_c2ws = batch['input_c2ws'].flatten(-2)
|
55 |
-
input_Ks = batch['input_Ks'].flatten(-2)
|
56 |
-
target_c2ws = batch['target_c2ws'].flatten(-2)
|
57 |
-
target_Ks = batch['target_Ks'].flatten(-2)
|
58 |
-
render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
|
59 |
-
render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
|
60 |
-
render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
|
61 |
-
|
62 |
-
input_extrinsics = input_c2ws[:, :, :12]
|
63 |
-
input_intrinsics = torch.stack([
|
64 |
-
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
65 |
-
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
66 |
-
], dim=-1)
|
67 |
-
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
68 |
-
|
69 |
-
# add noise to input cameras
|
70 |
-
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
71 |
-
|
72 |
-
lrm_generator_input['cameras'] = cameras.to(self.device)
|
73 |
-
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
74 |
-
|
75 |
-
# target images
|
76 |
-
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
77 |
-
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
78 |
-
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
79 |
-
|
80 |
-
# random crop
|
81 |
-
render_size = np.random.randint(self.render_size, 513)
|
82 |
-
target_images = v2.functional.resize(
|
83 |
-
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
84 |
-
target_depths = v2.functional.resize(
|
85 |
-
target_depths, render_size, interpolation=0, antialias=True)
|
86 |
-
target_alphas = v2.functional.resize(
|
87 |
-
target_alphas, render_size, interpolation=0, antialias=True)
|
88 |
-
|
89 |
-
crop_params = v2.RandomCrop.get_params(
|
90 |
-
target_images, output_size=(self.render_size, self.render_size))
|
91 |
-
target_images = v2.functional.crop(target_images, *crop_params)
|
92 |
-
target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
|
93 |
-
target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
|
94 |
-
|
95 |
-
lrm_generator_input['render_size'] = render_size
|
96 |
-
lrm_generator_input['crop_params'] = crop_params
|
97 |
-
|
98 |
-
render_gt['target_images'] = target_images.to(self.device)
|
99 |
-
render_gt['target_depths'] = target_depths.to(self.device)
|
100 |
-
render_gt['target_alphas'] = target_alphas.to(self.device)
|
101 |
-
|
102 |
-
return lrm_generator_input, render_gt
|
103 |
-
|
104 |
-
def prepare_validation_batch_data(self, batch):
|
105 |
-
lrm_generator_input = {}
|
106 |
-
|
107 |
-
# input images
|
108 |
-
images = batch['input_images']
|
109 |
-
images = v2.functional.resize(
|
110 |
-
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
111 |
-
|
112 |
-
lrm_generator_input['images'] = images.to(self.device)
|
113 |
-
|
114 |
-
input_c2ws = batch['input_c2ws'].flatten(-2)
|
115 |
-
input_Ks = batch['input_Ks'].flatten(-2)
|
116 |
-
|
117 |
-
input_extrinsics = input_c2ws[:, :, :12]
|
118 |
-
input_intrinsics = torch.stack([
|
119 |
-
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
120 |
-
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
121 |
-
], dim=-1)
|
122 |
-
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
123 |
-
|
124 |
-
lrm_generator_input['cameras'] = cameras.to(self.device)
|
125 |
-
|
126 |
-
render_c2ws = batch['render_c2ws'].flatten(-2)
|
127 |
-
render_Ks = batch['render_Ks'].flatten(-2)
|
128 |
-
render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
|
129 |
-
|
130 |
-
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
131 |
-
lrm_generator_input['render_size'] = 384
|
132 |
-
lrm_generator_input['crop_params'] = None
|
133 |
-
|
134 |
-
return lrm_generator_input
|
135 |
-
|
136 |
-
def forward_lrm_generator(
|
137 |
-
self,
|
138 |
-
images,
|
139 |
-
cameras,
|
140 |
-
render_cameras,
|
141 |
-
render_size=192,
|
142 |
-
crop_params=None,
|
143 |
-
chunk_size=1,
|
144 |
-
):
|
145 |
-
planes = torch.utils.checkpoint.checkpoint(
|
146 |
-
self.lrm_generator.forward_planes,
|
147 |
-
images,
|
148 |
-
cameras,
|
149 |
-
use_reentrant=False,
|
150 |
-
)
|
151 |
-
frames = []
|
152 |
-
for i in range(0, render_cameras.shape[1], chunk_size):
|
153 |
-
frames.append(
|
154 |
-
torch.utils.checkpoint.checkpoint(
|
155 |
-
self.lrm_generator.synthesizer,
|
156 |
-
planes,
|
157 |
-
cameras=render_cameras[:, i:i+chunk_size],
|
158 |
-
render_size=render_size,
|
159 |
-
crop_params=crop_params,
|
160 |
-
use_reentrant=False
|
161 |
-
)
|
162 |
-
)
|
163 |
-
frames = {
|
164 |
-
k: torch.cat([r[k] for r in frames], dim=1)
|
165 |
-
for k in frames[0].keys()
|
166 |
-
}
|
167 |
-
return frames
|
168 |
-
|
169 |
-
def forward(self, lrm_generator_input):
|
170 |
-
images = lrm_generator_input['images']
|
171 |
-
cameras = lrm_generator_input['cameras']
|
172 |
-
render_cameras = lrm_generator_input['render_cameras']
|
173 |
-
render_size = lrm_generator_input['render_size']
|
174 |
-
crop_params = lrm_generator_input['crop_params']
|
175 |
-
|
176 |
-
out = self.forward_lrm_generator(
|
177 |
-
images,
|
178 |
-
cameras,
|
179 |
-
render_cameras,
|
180 |
-
render_size=render_size,
|
181 |
-
crop_params=crop_params,
|
182 |
-
chunk_size=1,
|
183 |
-
)
|
184 |
-
render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
|
185 |
-
render_depths = out['images_depth']
|
186 |
-
render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
|
187 |
-
|
188 |
-
out = {
|
189 |
-
'render_images': render_images,
|
190 |
-
'render_depths': render_depths,
|
191 |
-
'render_alphas': render_alphas,
|
192 |
-
}
|
193 |
-
return out
|
194 |
-
|
195 |
-
def training_step(self, batch, batch_idx):
|
196 |
-
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
197 |
-
|
198 |
-
render_out = self.forward(lrm_generator_input)
|
199 |
-
|
200 |
-
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
201 |
-
|
202 |
-
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
203 |
-
|
204 |
-
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
205 |
-
B, N, C, H, W = render_gt['target_images'].shape
|
206 |
-
N_in = lrm_generator_input['images'].shape[1]
|
207 |
-
|
208 |
-
input_images = v2.functional.resize(
|
209 |
-
lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
|
210 |
-
input_images = torch.cat(
|
211 |
-
[input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
|
212 |
-
|
213 |
-
input_images = rearrange(
|
214 |
-
input_images, 'b n c h w -> b c h (n w)')
|
215 |
-
target_images = rearrange(
|
216 |
-
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
217 |
-
render_images = rearrange(
|
218 |
-
render_out['render_images'], 'b n c h w -> b c h (n w)')
|
219 |
-
target_alphas = rearrange(
|
220 |
-
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
-
render_alphas = rearrange(
|
222 |
-
repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
223 |
-
target_depths = rearrange(
|
224 |
-
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
225 |
-
render_depths = rearrange(
|
226 |
-
repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
227 |
-
MAX_DEPTH = torch.max(target_depths)
|
228 |
-
target_depths = target_depths / MAX_DEPTH * target_alphas
|
229 |
-
render_depths = render_depths / MAX_DEPTH
|
230 |
-
|
231 |
-
grid = torch.cat([
|
232 |
-
input_images,
|
233 |
-
target_images, render_images,
|
234 |
-
target_alphas, render_alphas,
|
235 |
-
target_depths, render_depths,
|
236 |
-
], dim=-2)
|
237 |
-
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
238 |
-
|
239 |
-
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
|
240 |
-
|
241 |
-
return loss
|
242 |
-
|
243 |
-
def compute_loss(self, render_out, render_gt):
|
244 |
-
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
-
render_images = render_out['render_images']
|
246 |
-
target_images = render_gt['target_images'].to(render_images)
|
247 |
-
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
-
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
-
|
250 |
-
loss_mse = F.mse_loss(render_images, target_images)
|
251 |
-
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
252 |
-
|
253 |
-
render_alphas = render_out['render_alphas']
|
254 |
-
target_alphas = render_gt['target_alphas']
|
255 |
-
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
256 |
-
|
257 |
-
loss = loss_mse + loss_lpips + loss_mask
|
258 |
-
|
259 |
-
prefix = 'train'
|
260 |
-
loss_dict = {}
|
261 |
-
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
262 |
-
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
263 |
-
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
264 |
-
loss_dict.update({f'{prefix}/loss': loss})
|
265 |
-
|
266 |
-
return loss, loss_dict
|
267 |
-
|
268 |
-
@torch.no_grad()
|
269 |
-
def validation_step(self, batch, batch_idx):
|
270 |
-
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
271 |
-
|
272 |
-
render_out = self.forward(lrm_generator_input)
|
273 |
-
render_images = render_out['render_images']
|
274 |
-
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
275 |
-
|
276 |
-
self.validation_step_outputs.append(render_images)
|
277 |
-
|
278 |
-
def on_validation_epoch_end(self):
|
279 |
-
images = torch.cat(self.validation_step_outputs, dim=-1)
|
280 |
-
|
281 |
-
all_images = self.all_gather(images)
|
282 |
-
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
283 |
-
|
284 |
-
if self.global_rank == 0:
|
285 |
-
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
286 |
-
|
287 |
-
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
288 |
-
save_image(grid, image_path)
|
289 |
-
print(f"Saved image to {image_path}")
|
290 |
-
|
291 |
-
self.validation_step_outputs.clear()
|
292 |
-
|
293 |
-
def configure_optimizers(self):
|
294 |
-
lr = self.learning_rate
|
295 |
-
|
296 |
-
params = []
|
297 |
-
|
298 |
-
lrm_params_fast, lrm_params_slow = [], []
|
299 |
-
for n, p in self.lrm_generator.named_parameters():
|
300 |
-
if 'adaLN_modulation' in n or 'camera_embedder' in n:
|
301 |
-
lrm_params_fast.append(p)
|
302 |
-
else:
|
303 |
-
lrm_params_slow.append(p)
|
304 |
-
params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
|
305 |
-
params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
|
306 |
-
|
307 |
-
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
|
308 |
-
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
|
309 |
-
|
310 |
-
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/model_mesh.py
DELETED
@@ -1,325 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
from torchvision.transforms import v2
|
6 |
-
from torchvision.utils import make_grid, save_image
|
7 |
-
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
-
import pytorch_lightning as pl
|
9 |
-
from einops import rearrange, repeat
|
10 |
-
|
11 |
-
from src.utils.train_util import instantiate_from_config
|
12 |
-
|
13 |
-
|
14 |
-
# Regulrarization loss for FlexiCubes
|
15 |
-
def sdf_reg_loss_batch(sdf, all_edges):
|
16 |
-
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
|
17 |
-
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
18 |
-
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
19 |
-
sdf_diff = F.binary_cross_entropy_with_logits(
|
20 |
-
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
|
21 |
-
F.binary_cross_entropy_with_logits(
|
22 |
-
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
|
23 |
-
return sdf_diff
|
24 |
-
|
25 |
-
|
26 |
-
class MVRecon(pl.LightningModule):
|
27 |
-
def __init__(
|
28 |
-
self,
|
29 |
-
lrm_generator_config,
|
30 |
-
input_size=256,
|
31 |
-
render_size=512,
|
32 |
-
init_ckpt=None,
|
33 |
-
):
|
34 |
-
super(MVRecon, self).__init__()
|
35 |
-
|
36 |
-
self.input_size = input_size
|
37 |
-
self.render_size = render_size
|
38 |
-
|
39 |
-
# init modules
|
40 |
-
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
41 |
-
|
42 |
-
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
43 |
-
|
44 |
-
# Load weights from pretrained MVRecon model, and use the mlp
|
45 |
-
# weights to initialize the weights of sdf and rgb mlps.
|
46 |
-
if init_ckpt is not None:
|
47 |
-
sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
|
48 |
-
sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
|
49 |
-
sd_fc = {}
|
50 |
-
for k, v in sd.items():
|
51 |
-
if k.startswith('lrm_generator.synthesizer.decoder.net.'):
|
52 |
-
if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
|
53 |
-
# Here we assume the density filed's isosurface threshold is t,
|
54 |
-
# we reverse the sign of density filed to initialize SDF field.
|
55 |
-
# -(w*x + b - t) = (-w)*x + (t - b)
|
56 |
-
if 'weight' in k:
|
57 |
-
sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
|
58 |
-
else:
|
59 |
-
sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
|
60 |
-
sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
|
61 |
-
else:
|
62 |
-
sd_fc[k.replace('net.', 'net_sdf.')] = v
|
63 |
-
sd_fc[k.replace('net.', 'net_rgb.')] = v
|
64 |
-
else:
|
65 |
-
sd_fc[k] = v
|
66 |
-
sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
|
67 |
-
# missing `net_deformation` and `net_weight` parameters
|
68 |
-
self.lrm_generator.load_state_dict(sd_fc, strict=False)
|
69 |
-
print(f'Loaded weights from {init_ckpt}')
|
70 |
-
|
71 |
-
self.validation_step_outputs = []
|
72 |
-
|
73 |
-
def on_fit_start(self):
|
74 |
-
device = torch.device(f'cuda:{self.global_rank}')
|
75 |
-
self.lrm_generator.init_flexicubes_geometry(device)
|
76 |
-
if self.global_rank == 0:
|
77 |
-
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
78 |
-
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
79 |
-
|
80 |
-
def prepare_batch_data(self, batch):
|
81 |
-
lrm_generator_input = {}
|
82 |
-
render_gt = {}
|
83 |
-
|
84 |
-
# input images
|
85 |
-
images = batch['input_images']
|
86 |
-
images = v2.functional.resize(
|
87 |
-
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
88 |
-
|
89 |
-
lrm_generator_input['images'] = images.to(self.device)
|
90 |
-
|
91 |
-
# input cameras and render cameras
|
92 |
-
input_c2ws = batch['input_c2ws']
|
93 |
-
input_Ks = batch['input_Ks']
|
94 |
-
target_c2ws = batch['target_c2ws']
|
95 |
-
|
96 |
-
render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
|
97 |
-
render_w2cs = torch.linalg.inv(render_c2ws)
|
98 |
-
|
99 |
-
input_extrinsics = input_c2ws.flatten(-2)
|
100 |
-
input_extrinsics = input_extrinsics[:, :, :12]
|
101 |
-
input_intrinsics = input_Ks.flatten(-2)
|
102 |
-
input_intrinsics = torch.stack([
|
103 |
-
input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
|
104 |
-
input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
|
105 |
-
], dim=-1)
|
106 |
-
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
107 |
-
|
108 |
-
# add noise to input_cameras
|
109 |
-
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
110 |
-
|
111 |
-
lrm_generator_input['cameras'] = cameras.to(self.device)
|
112 |
-
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
113 |
-
|
114 |
-
# target images
|
115 |
-
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
116 |
-
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
117 |
-
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
118 |
-
target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
|
119 |
-
|
120 |
-
render_size = self.render_size
|
121 |
-
target_images = v2.functional.resize(
|
122 |
-
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
123 |
-
target_depths = v2.functional.resize(
|
124 |
-
target_depths, render_size, interpolation=0, antialias=True)
|
125 |
-
target_alphas = v2.functional.resize(
|
126 |
-
target_alphas, render_size, interpolation=0, antialias=True)
|
127 |
-
target_normals = v2.functional.resize(
|
128 |
-
target_normals, render_size, interpolation=3, antialias=True)
|
129 |
-
|
130 |
-
lrm_generator_input['render_size'] = render_size
|
131 |
-
|
132 |
-
render_gt['target_images'] = target_images.to(self.device)
|
133 |
-
render_gt['target_depths'] = target_depths.to(self.device)
|
134 |
-
render_gt['target_alphas'] = target_alphas.to(self.device)
|
135 |
-
render_gt['target_normals'] = target_normals.to(self.device)
|
136 |
-
|
137 |
-
return lrm_generator_input, render_gt
|
138 |
-
|
139 |
-
def prepare_validation_batch_data(self, batch):
|
140 |
-
lrm_generator_input = {}
|
141 |
-
|
142 |
-
# input images
|
143 |
-
images = batch['input_images']
|
144 |
-
images = v2.functional.resize(
|
145 |
-
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
146 |
-
|
147 |
-
lrm_generator_input['images'] = images.to(self.device)
|
148 |
-
|
149 |
-
# input cameras
|
150 |
-
input_c2ws = batch['input_c2ws'].flatten(-2)
|
151 |
-
input_Ks = batch['input_Ks'].flatten(-2)
|
152 |
-
|
153 |
-
input_extrinsics = input_c2ws[:, :, :12]
|
154 |
-
input_intrinsics = torch.stack([
|
155 |
-
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
156 |
-
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
157 |
-
], dim=-1)
|
158 |
-
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
159 |
-
|
160 |
-
lrm_generator_input['cameras'] = cameras.to(self.device)
|
161 |
-
|
162 |
-
# render cameras
|
163 |
-
render_c2ws = batch['render_c2ws']
|
164 |
-
render_w2cs = torch.linalg.inv(render_c2ws)
|
165 |
-
|
166 |
-
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
167 |
-
lrm_generator_input['render_size'] = 384
|
168 |
-
|
169 |
-
return lrm_generator_input
|
170 |
-
|
171 |
-
def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
|
172 |
-
planes = torch.utils.checkpoint.checkpoint(
|
173 |
-
self.lrm_generator.forward_planes,
|
174 |
-
images,
|
175 |
-
cameras,
|
176 |
-
use_reentrant=False,
|
177 |
-
)
|
178 |
-
out = self.lrm_generator.forward_geometry(
|
179 |
-
planes,
|
180 |
-
render_cameras,
|
181 |
-
render_size,
|
182 |
-
)
|
183 |
-
return out
|
184 |
-
|
185 |
-
def forward(self, lrm_generator_input):
|
186 |
-
images = lrm_generator_input['images']
|
187 |
-
cameras = lrm_generator_input['cameras']
|
188 |
-
render_cameras = lrm_generator_input['render_cameras']
|
189 |
-
render_size = lrm_generator_input['render_size']
|
190 |
-
|
191 |
-
out = self.forward_lrm_generator(
|
192 |
-
images, cameras, render_cameras, render_size=render_size)
|
193 |
-
|
194 |
-
return out
|
195 |
-
|
196 |
-
def training_step(self, batch, batch_idx):
|
197 |
-
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
198 |
-
|
199 |
-
render_out = self.forward(lrm_generator_input)
|
200 |
-
|
201 |
-
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
202 |
-
|
203 |
-
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
204 |
-
|
205 |
-
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
206 |
-
B, N, C, H, W = render_gt['target_images'].shape
|
207 |
-
N_in = lrm_generator_input['images'].shape[1]
|
208 |
-
|
209 |
-
target_images = rearrange(
|
210 |
-
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
211 |
-
render_images = rearrange(
|
212 |
-
render_out['img'], 'b n c h w -> b c h (n w)')
|
213 |
-
target_alphas = rearrange(
|
214 |
-
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
215 |
-
render_alphas = rearrange(
|
216 |
-
repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
217 |
-
target_depths = rearrange(
|
218 |
-
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
219 |
-
render_depths = rearrange(
|
220 |
-
repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
-
target_normals = rearrange(
|
222 |
-
render_gt['target_normals'], 'b n c h w -> b c h (n w)')
|
223 |
-
render_normals = rearrange(
|
224 |
-
render_out['normal'], 'b n c h w -> b c h (n w)')
|
225 |
-
MAX_DEPTH = torch.max(target_depths)
|
226 |
-
target_depths = target_depths / MAX_DEPTH * target_alphas
|
227 |
-
render_depths = render_depths / MAX_DEPTH
|
228 |
-
|
229 |
-
grid = torch.cat([
|
230 |
-
target_images, render_images,
|
231 |
-
target_alphas, render_alphas,
|
232 |
-
target_depths, render_depths,
|
233 |
-
target_normals, render_normals,
|
234 |
-
], dim=-2)
|
235 |
-
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
236 |
-
|
237 |
-
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
|
238 |
-
save_image(grid, image_path)
|
239 |
-
print(f"Saved image to {image_path}")
|
240 |
-
|
241 |
-
return loss
|
242 |
-
|
243 |
-
def compute_loss(self, render_out, render_gt):
|
244 |
-
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
-
render_images = render_out['img']
|
246 |
-
target_images = render_gt['target_images'].to(render_images)
|
247 |
-
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
-
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
-
loss_mse = F.mse_loss(render_images, target_images)
|
250 |
-
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
251 |
-
|
252 |
-
render_alphas = render_out['mask']
|
253 |
-
target_alphas = render_gt['target_alphas']
|
254 |
-
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
255 |
-
|
256 |
-
render_depths = render_out['depth']
|
257 |
-
target_depths = render_gt['target_depths']
|
258 |
-
loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
|
259 |
-
|
260 |
-
render_normals = render_out['normal'] * 2.0 - 1.0
|
261 |
-
target_normals = render_gt['target_normals'] * 2.0 - 1.0
|
262 |
-
similarity = (render_normals * target_normals).sum(dim=-3).abs()
|
263 |
-
normal_mask = target_alphas.squeeze(-3)
|
264 |
-
loss_normal = 1 - similarity[normal_mask>0].mean()
|
265 |
-
loss_normal = 0.2 * loss_normal
|
266 |
-
|
267 |
-
# flexicubes regularization loss
|
268 |
-
sdf = render_out['sdf']
|
269 |
-
sdf_reg_loss = render_out['sdf_reg_loss']
|
270 |
-
sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
|
271 |
-
_, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
|
272 |
-
flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
|
273 |
-
flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
|
274 |
-
|
275 |
-
loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
|
276 |
-
|
277 |
-
loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
|
278 |
-
|
279 |
-
prefix = 'train'
|
280 |
-
loss_dict = {}
|
281 |
-
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
282 |
-
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
283 |
-
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
284 |
-
loss_dict.update({f'{prefix}/loss_normal': loss_normal})
|
285 |
-
loss_dict.update({f'{prefix}/loss_depth': loss_depth})
|
286 |
-
loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
|
287 |
-
loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
|
288 |
-
loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
|
289 |
-
loss_dict.update({f'{prefix}/loss': loss})
|
290 |
-
|
291 |
-
return loss, loss_dict
|
292 |
-
|
293 |
-
@torch.no_grad()
|
294 |
-
def validation_step(self, batch, batch_idx):
|
295 |
-
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
296 |
-
|
297 |
-
render_out = self.forward(lrm_generator_input)
|
298 |
-
render_images = render_out['img']
|
299 |
-
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
300 |
-
|
301 |
-
self.validation_step_outputs.append(render_images)
|
302 |
-
|
303 |
-
def on_validation_epoch_end(self):
|
304 |
-
images = torch.cat(self.validation_step_outputs, dim=-1)
|
305 |
-
|
306 |
-
all_images = self.all_gather(images)
|
307 |
-
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
308 |
-
|
309 |
-
if self.global_rank == 0:
|
310 |
-
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
311 |
-
|
312 |
-
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
313 |
-
save_image(grid, image_path)
|
314 |
-
print(f"Saved image to {image_path}")
|
315 |
-
|
316 |
-
self.validation_step_outputs.clear()
|
317 |
-
|
318 |
-
def configure_optimizers(self):
|
319 |
-
lr = self.learning_rate
|
320 |
-
|
321 |
-
optimizer = torch.optim.AdamW(
|
322 |
-
self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
|
323 |
-
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
|
324 |
-
|
325 |
-
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/__init__.py
DELETED
File without changes
|
src/models/decoder/__init__.py
DELETED
File without changes
|
src/models/decoder/transformer.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Zexin He
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
|
16 |
-
import torch
|
17 |
-
import torch.nn as nn
|
18 |
-
|
19 |
-
|
20 |
-
class BasicTransformerBlock(nn.Module):
|
21 |
-
"""
|
22 |
-
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
23 |
-
"""
|
24 |
-
# use attention from torch.nn.MultiHeadAttention
|
25 |
-
# Block contains a cross-attention layer, a self-attention layer, and a MLP
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
inner_dim: int,
|
29 |
-
cond_dim: int,
|
30 |
-
num_heads: int,
|
31 |
-
eps: float,
|
32 |
-
attn_drop: float = 0.,
|
33 |
-
attn_bias: bool = False,
|
34 |
-
mlp_ratio: float = 4.,
|
35 |
-
mlp_drop: float = 0.,
|
36 |
-
):
|
37 |
-
super().__init__()
|
38 |
-
|
39 |
-
self.norm1 = nn.LayerNorm(inner_dim)
|
40 |
-
self.cross_attn = nn.MultiheadAttention(
|
41 |
-
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
42 |
-
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
43 |
-
self.norm2 = nn.LayerNorm(inner_dim)
|
44 |
-
self.self_attn = nn.MultiheadAttention(
|
45 |
-
embed_dim=inner_dim, num_heads=num_heads,
|
46 |
-
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
47 |
-
self.norm3 = nn.LayerNorm(inner_dim)
|
48 |
-
self.mlp = nn.Sequential(
|
49 |
-
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
50 |
-
nn.GELU(),
|
51 |
-
nn.Dropout(mlp_drop),
|
52 |
-
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
53 |
-
nn.Dropout(mlp_drop),
|
54 |
-
)
|
55 |
-
|
56 |
-
def forward(self, x, cond):
|
57 |
-
# x: [N, L, D]
|
58 |
-
# cond: [N, L_cond, D_cond]
|
59 |
-
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
|
60 |
-
before_sa = self.norm2(x)
|
61 |
-
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
|
62 |
-
x = x + self.mlp(self.norm3(x))
|
63 |
-
return x
|
64 |
-
|
65 |
-
|
66 |
-
class TriplaneTransformer(nn.Module):
|
67 |
-
"""
|
68 |
-
Transformer with condition that generates a triplane representation.
|
69 |
-
|
70 |
-
Reference:
|
71 |
-
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
|
72 |
-
"""
|
73 |
-
def __init__(
|
74 |
-
self,
|
75 |
-
inner_dim: int,
|
76 |
-
image_feat_dim: int,
|
77 |
-
triplane_low_res: int,
|
78 |
-
triplane_high_res: int,
|
79 |
-
triplane_dim: int,
|
80 |
-
num_layers: int,
|
81 |
-
num_heads: int,
|
82 |
-
eps: float = 1e-6,
|
83 |
-
):
|
84 |
-
super().__init__()
|
85 |
-
|
86 |
-
# attributes
|
87 |
-
self.triplane_low_res = triplane_low_res
|
88 |
-
self.triplane_high_res = triplane_high_res
|
89 |
-
self.triplane_dim = triplane_dim
|
90 |
-
|
91 |
-
# modules
|
92 |
-
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
93 |
-
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
|
94 |
-
self.layers = nn.ModuleList([
|
95 |
-
BasicTransformerBlock(
|
96 |
-
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
|
97 |
-
for _ in range(num_layers)
|
98 |
-
])
|
99 |
-
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
100 |
-
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
|
101 |
-
|
102 |
-
def forward(self, image_feats):
|
103 |
-
# image_feats: [N, L_cond, D_cond]
|
104 |
-
|
105 |
-
N = image_feats.shape[0]
|
106 |
-
H = W = self.triplane_low_res
|
107 |
-
L = 3 * H * W
|
108 |
-
|
109 |
-
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
110 |
-
for layer in self.layers:
|
111 |
-
x = layer(x, image_feats)
|
112 |
-
x = self.norm(x)
|
113 |
-
|
114 |
-
# separate each plane and apply deconv
|
115 |
-
x = x.view(N, 3, H, W, -1)
|
116 |
-
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
117 |
-
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
118 |
-
x = self.deconv(x) # [3*N, D', H', W']
|
119 |
-
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
120 |
-
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
121 |
-
x = x.contiguous()
|
122 |
-
|
123 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/encoder/__init__.py
DELETED
File without changes
|
src/models/encoder/dino.py
DELETED
@@ -1,550 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
""" PyTorch ViT model."""
|
16 |
-
|
17 |
-
|
18 |
-
import collections.abc
|
19 |
-
import math
|
20 |
-
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
-
|
22 |
-
import torch
|
23 |
-
from torch import nn
|
24 |
-
|
25 |
-
from transformers.activations import ACT2FN
|
26 |
-
from transformers.modeling_outputs import (
|
27 |
-
BaseModelOutput,
|
28 |
-
BaseModelOutputWithPooling,
|
29 |
-
)
|
30 |
-
from transformers import PreTrainedModel, ViTConfig
|
31 |
-
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
32 |
-
|
33 |
-
|
34 |
-
class ViTEmbeddings(nn.Module):
|
35 |
-
"""
|
36 |
-
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
37 |
-
"""
|
38 |
-
|
39 |
-
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
40 |
-
super().__init__()
|
41 |
-
|
42 |
-
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
43 |
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
44 |
-
self.patch_embeddings = ViTPatchEmbeddings(config)
|
45 |
-
num_patches = self.patch_embeddings.num_patches
|
46 |
-
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
47 |
-
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
48 |
-
self.config = config
|
49 |
-
|
50 |
-
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
51 |
-
"""
|
52 |
-
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
53 |
-
resolution images.
|
54 |
-
|
55 |
-
Source:
|
56 |
-
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
57 |
-
"""
|
58 |
-
|
59 |
-
num_patches = embeddings.shape[1] - 1
|
60 |
-
num_positions = self.position_embeddings.shape[1] - 1
|
61 |
-
if num_patches == num_positions and height == width:
|
62 |
-
return self.position_embeddings
|
63 |
-
class_pos_embed = self.position_embeddings[:, 0]
|
64 |
-
patch_pos_embed = self.position_embeddings[:, 1:]
|
65 |
-
dim = embeddings.shape[-1]
|
66 |
-
h0 = height // self.config.patch_size
|
67 |
-
w0 = width // self.config.patch_size
|
68 |
-
# we add a small number to avoid floating point error in the interpolation
|
69 |
-
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
70 |
-
h0, w0 = h0 + 0.1, w0 + 0.1
|
71 |
-
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
72 |
-
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
73 |
-
patch_pos_embed = nn.functional.interpolate(
|
74 |
-
patch_pos_embed,
|
75 |
-
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
76 |
-
mode="bicubic",
|
77 |
-
align_corners=False,
|
78 |
-
)
|
79 |
-
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
80 |
-
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
81 |
-
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
82 |
-
|
83 |
-
def forward(
|
84 |
-
self,
|
85 |
-
pixel_values: torch.Tensor,
|
86 |
-
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
87 |
-
interpolate_pos_encoding: bool = False,
|
88 |
-
) -> torch.Tensor:
|
89 |
-
batch_size, num_channels, height, width = pixel_values.shape
|
90 |
-
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
91 |
-
|
92 |
-
if bool_masked_pos is not None:
|
93 |
-
seq_length = embeddings.shape[1]
|
94 |
-
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
95 |
-
# replace the masked visual tokens by mask_tokens
|
96 |
-
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
97 |
-
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
98 |
-
|
99 |
-
# add the [CLS] token to the embedded patch tokens
|
100 |
-
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
101 |
-
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
102 |
-
|
103 |
-
# add positional encoding to each token
|
104 |
-
if interpolate_pos_encoding:
|
105 |
-
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
106 |
-
else:
|
107 |
-
embeddings = embeddings + self.position_embeddings
|
108 |
-
|
109 |
-
embeddings = self.dropout(embeddings)
|
110 |
-
|
111 |
-
return embeddings
|
112 |
-
|
113 |
-
|
114 |
-
class ViTPatchEmbeddings(nn.Module):
|
115 |
-
"""
|
116 |
-
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
117 |
-
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
118 |
-
Transformer.
|
119 |
-
"""
|
120 |
-
|
121 |
-
def __init__(self, config):
|
122 |
-
super().__init__()
|
123 |
-
image_size, patch_size = config.image_size, config.patch_size
|
124 |
-
num_channels, hidden_size = config.num_channels, config.hidden_size
|
125 |
-
|
126 |
-
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
127 |
-
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
128 |
-
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
129 |
-
self.image_size = image_size
|
130 |
-
self.patch_size = patch_size
|
131 |
-
self.num_channels = num_channels
|
132 |
-
self.num_patches = num_patches
|
133 |
-
|
134 |
-
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
135 |
-
|
136 |
-
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
137 |
-
batch_size, num_channels, height, width = pixel_values.shape
|
138 |
-
if num_channels != self.num_channels:
|
139 |
-
raise ValueError(
|
140 |
-
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
141 |
-
f" Expected {self.num_channels} but got {num_channels}."
|
142 |
-
)
|
143 |
-
if not interpolate_pos_encoding:
|
144 |
-
if height != self.image_size[0] or width != self.image_size[1]:
|
145 |
-
raise ValueError(
|
146 |
-
f"Input image size ({height}*{width}) doesn't match model"
|
147 |
-
f" ({self.image_size[0]}*{self.image_size[1]})."
|
148 |
-
)
|
149 |
-
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
150 |
-
return embeddings
|
151 |
-
|
152 |
-
|
153 |
-
class ViTSelfAttention(nn.Module):
|
154 |
-
def __init__(self, config: ViTConfig) -> None:
|
155 |
-
super().__init__()
|
156 |
-
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
157 |
-
raise ValueError(
|
158 |
-
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
159 |
-
f"heads {config.num_attention_heads}."
|
160 |
-
)
|
161 |
-
|
162 |
-
self.num_attention_heads = config.num_attention_heads
|
163 |
-
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
164 |
-
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
165 |
-
|
166 |
-
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
167 |
-
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
168 |
-
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
169 |
-
|
170 |
-
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
171 |
-
|
172 |
-
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
173 |
-
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
174 |
-
x = x.view(new_x_shape)
|
175 |
-
return x.permute(0, 2, 1, 3)
|
176 |
-
|
177 |
-
def forward(
|
178 |
-
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
179 |
-
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
180 |
-
mixed_query_layer = self.query(hidden_states)
|
181 |
-
|
182 |
-
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
183 |
-
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
184 |
-
query_layer = self.transpose_for_scores(mixed_query_layer)
|
185 |
-
|
186 |
-
# Take the dot product between "query" and "key" to get the raw attention scores.
|
187 |
-
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
188 |
-
|
189 |
-
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
190 |
-
|
191 |
-
# Normalize the attention scores to probabilities.
|
192 |
-
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
193 |
-
|
194 |
-
# This is actually dropping out entire tokens to attend to, which might
|
195 |
-
# seem a bit unusual, but is taken from the original Transformer paper.
|
196 |
-
attention_probs = self.dropout(attention_probs)
|
197 |
-
|
198 |
-
# Mask heads if we want to
|
199 |
-
if head_mask is not None:
|
200 |
-
attention_probs = attention_probs * head_mask
|
201 |
-
|
202 |
-
context_layer = torch.matmul(attention_probs, value_layer)
|
203 |
-
|
204 |
-
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
205 |
-
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
206 |
-
context_layer = context_layer.view(new_context_layer_shape)
|
207 |
-
|
208 |
-
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
209 |
-
|
210 |
-
return outputs
|
211 |
-
|
212 |
-
|
213 |
-
class ViTSelfOutput(nn.Module):
|
214 |
-
"""
|
215 |
-
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
|
216 |
-
layernorm applied before each block.
|
217 |
-
"""
|
218 |
-
|
219 |
-
def __init__(self, config: ViTConfig) -> None:
|
220 |
-
super().__init__()
|
221 |
-
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
222 |
-
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
223 |
-
|
224 |
-
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
225 |
-
hidden_states = self.dense(hidden_states)
|
226 |
-
hidden_states = self.dropout(hidden_states)
|
227 |
-
|
228 |
-
return hidden_states
|
229 |
-
|
230 |
-
|
231 |
-
class ViTAttention(nn.Module):
|
232 |
-
def __init__(self, config: ViTConfig) -> None:
|
233 |
-
super().__init__()
|
234 |
-
self.attention = ViTSelfAttention(config)
|
235 |
-
self.output = ViTSelfOutput(config)
|
236 |
-
self.pruned_heads = set()
|
237 |
-
|
238 |
-
def prune_heads(self, heads: Set[int]) -> None:
|
239 |
-
if len(heads) == 0:
|
240 |
-
return
|
241 |
-
heads, index = find_pruneable_heads_and_indices(
|
242 |
-
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
243 |
-
)
|
244 |
-
|
245 |
-
# Prune linear layers
|
246 |
-
self.attention.query = prune_linear_layer(self.attention.query, index)
|
247 |
-
self.attention.key = prune_linear_layer(self.attention.key, index)
|
248 |
-
self.attention.value = prune_linear_layer(self.attention.value, index)
|
249 |
-
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
250 |
-
|
251 |
-
# Update hyper params and store pruned heads
|
252 |
-
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
253 |
-
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
254 |
-
self.pruned_heads = self.pruned_heads.union(heads)
|
255 |
-
|
256 |
-
def forward(
|
257 |
-
self,
|
258 |
-
hidden_states: torch.Tensor,
|
259 |
-
head_mask: Optional[torch.Tensor] = None,
|
260 |
-
output_attentions: bool = False,
|
261 |
-
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
262 |
-
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
263 |
-
|
264 |
-
attention_output = self.output(self_outputs[0], hidden_states)
|
265 |
-
|
266 |
-
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
267 |
-
return outputs
|
268 |
-
|
269 |
-
|
270 |
-
class ViTIntermediate(nn.Module):
|
271 |
-
def __init__(self, config: ViTConfig) -> None:
|
272 |
-
super().__init__()
|
273 |
-
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
274 |
-
if isinstance(config.hidden_act, str):
|
275 |
-
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
276 |
-
else:
|
277 |
-
self.intermediate_act_fn = config.hidden_act
|
278 |
-
|
279 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
280 |
-
hidden_states = self.dense(hidden_states)
|
281 |
-
hidden_states = self.intermediate_act_fn(hidden_states)
|
282 |
-
|
283 |
-
return hidden_states
|
284 |
-
|
285 |
-
|
286 |
-
class ViTOutput(nn.Module):
|
287 |
-
def __init__(self, config: ViTConfig) -> None:
|
288 |
-
super().__init__()
|
289 |
-
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
290 |
-
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
291 |
-
|
292 |
-
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
293 |
-
hidden_states = self.dense(hidden_states)
|
294 |
-
hidden_states = self.dropout(hidden_states)
|
295 |
-
|
296 |
-
hidden_states = hidden_states + input_tensor
|
297 |
-
|
298 |
-
return hidden_states
|
299 |
-
|
300 |
-
|
301 |
-
def modulate(x, shift, scale):
|
302 |
-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
303 |
-
|
304 |
-
|
305 |
-
class ViTLayer(nn.Module):
|
306 |
-
"""This corresponds to the Block class in the timm implementation."""
|
307 |
-
|
308 |
-
def __init__(self, config: ViTConfig) -> None:
|
309 |
-
super().__init__()
|
310 |
-
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
311 |
-
self.seq_len_dim = 1
|
312 |
-
self.attention = ViTAttention(config)
|
313 |
-
self.intermediate = ViTIntermediate(config)
|
314 |
-
self.output = ViTOutput(config)
|
315 |
-
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
316 |
-
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
317 |
-
|
318 |
-
self.adaLN_modulation = nn.Sequential(
|
319 |
-
nn.SiLU(),
|
320 |
-
nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
|
321 |
-
)
|
322 |
-
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
323 |
-
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
324 |
-
|
325 |
-
def forward(
|
326 |
-
self,
|
327 |
-
hidden_states: torch.Tensor,
|
328 |
-
adaln_input: torch.Tensor = None,
|
329 |
-
head_mask: Optional[torch.Tensor] = None,
|
330 |
-
output_attentions: bool = False,
|
331 |
-
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
332 |
-
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
333 |
-
|
334 |
-
self_attention_outputs = self.attention(
|
335 |
-
modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
|
336 |
-
head_mask,
|
337 |
-
output_attentions=output_attentions,
|
338 |
-
)
|
339 |
-
attention_output = self_attention_outputs[0]
|
340 |
-
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
341 |
-
|
342 |
-
# first residual connection
|
343 |
-
hidden_states = attention_output + hidden_states
|
344 |
-
|
345 |
-
# in ViT, layernorm is also applied after self-attention
|
346 |
-
layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
|
347 |
-
layer_output = self.intermediate(layer_output)
|
348 |
-
|
349 |
-
# second residual connection is done here
|
350 |
-
layer_output = self.output(layer_output, hidden_states)
|
351 |
-
|
352 |
-
outputs = (layer_output,) + outputs
|
353 |
-
|
354 |
-
return outputs
|
355 |
-
|
356 |
-
|
357 |
-
class ViTEncoder(nn.Module):
|
358 |
-
def __init__(self, config: ViTConfig) -> None:
|
359 |
-
super().__init__()
|
360 |
-
self.config = config
|
361 |
-
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
362 |
-
self.gradient_checkpointing = False
|
363 |
-
|
364 |
-
def forward(
|
365 |
-
self,
|
366 |
-
hidden_states: torch.Tensor,
|
367 |
-
adaln_input: torch.Tensor = None,
|
368 |
-
head_mask: Optional[torch.Tensor] = None,
|
369 |
-
output_attentions: bool = False,
|
370 |
-
output_hidden_states: bool = False,
|
371 |
-
return_dict: bool = True,
|
372 |
-
) -> Union[tuple, BaseModelOutput]:
|
373 |
-
all_hidden_states = () if output_hidden_states else None
|
374 |
-
all_self_attentions = () if output_attentions else None
|
375 |
-
|
376 |
-
for i, layer_module in enumerate(self.layer):
|
377 |
-
if output_hidden_states:
|
378 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
379 |
-
|
380 |
-
layer_head_mask = head_mask[i] if head_mask is not None else None
|
381 |
-
|
382 |
-
if self.gradient_checkpointing and self.training:
|
383 |
-
layer_outputs = self._gradient_checkpointing_func(
|
384 |
-
layer_module.__call__,
|
385 |
-
hidden_states,
|
386 |
-
adaln_input,
|
387 |
-
layer_head_mask,
|
388 |
-
output_attentions,
|
389 |
-
)
|
390 |
-
else:
|
391 |
-
layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
|
392 |
-
|
393 |
-
hidden_states = layer_outputs[0]
|
394 |
-
|
395 |
-
if output_attentions:
|
396 |
-
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
397 |
-
|
398 |
-
if output_hidden_states:
|
399 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
400 |
-
|
401 |
-
if not return_dict:
|
402 |
-
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
403 |
-
return BaseModelOutput(
|
404 |
-
last_hidden_state=hidden_states,
|
405 |
-
hidden_states=all_hidden_states,
|
406 |
-
attentions=all_self_attentions,
|
407 |
-
)
|
408 |
-
|
409 |
-
|
410 |
-
class ViTPreTrainedModel(PreTrainedModel):
|
411 |
-
"""
|
412 |
-
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
413 |
-
models.
|
414 |
-
"""
|
415 |
-
|
416 |
-
config_class = ViTConfig
|
417 |
-
base_model_prefix = "vit"
|
418 |
-
main_input_name = "pixel_values"
|
419 |
-
supports_gradient_checkpointing = True
|
420 |
-
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
421 |
-
|
422 |
-
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
423 |
-
"""Initialize the weights"""
|
424 |
-
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
425 |
-
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
426 |
-
# `trunc_normal_cpu` not implemented in `half` issues
|
427 |
-
module.weight.data = nn.init.trunc_normal_(
|
428 |
-
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
429 |
-
).to(module.weight.dtype)
|
430 |
-
if module.bias is not None:
|
431 |
-
module.bias.data.zero_()
|
432 |
-
elif isinstance(module, nn.LayerNorm):
|
433 |
-
module.bias.data.zero_()
|
434 |
-
module.weight.data.fill_(1.0)
|
435 |
-
elif isinstance(module, ViTEmbeddings):
|
436 |
-
module.position_embeddings.data = nn.init.trunc_normal_(
|
437 |
-
module.position_embeddings.data.to(torch.float32),
|
438 |
-
mean=0.0,
|
439 |
-
std=self.config.initializer_range,
|
440 |
-
).to(module.position_embeddings.dtype)
|
441 |
-
|
442 |
-
module.cls_token.data = nn.init.trunc_normal_(
|
443 |
-
module.cls_token.data.to(torch.float32),
|
444 |
-
mean=0.0,
|
445 |
-
std=self.config.initializer_range,
|
446 |
-
).to(module.cls_token.dtype)
|
447 |
-
|
448 |
-
|
449 |
-
class ViTModel(ViTPreTrainedModel):
|
450 |
-
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
|
451 |
-
super().__init__(config)
|
452 |
-
self.config = config
|
453 |
-
|
454 |
-
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
|
455 |
-
self.encoder = ViTEncoder(config)
|
456 |
-
|
457 |
-
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
458 |
-
self.pooler = ViTPooler(config) if add_pooling_layer else None
|
459 |
-
|
460 |
-
# Initialize weights and apply final processing
|
461 |
-
self.post_init()
|
462 |
-
|
463 |
-
def get_input_embeddings(self) -> ViTPatchEmbeddings:
|
464 |
-
return self.embeddings.patch_embeddings
|
465 |
-
|
466 |
-
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
467 |
-
"""
|
468 |
-
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
469 |
-
class PreTrainedModel
|
470 |
-
"""
|
471 |
-
for layer, heads in heads_to_prune.items():
|
472 |
-
self.encoder.layer[layer].attention.prune_heads(heads)
|
473 |
-
|
474 |
-
def forward(
|
475 |
-
self,
|
476 |
-
pixel_values: Optional[torch.Tensor] = None,
|
477 |
-
adaln_input: Optional[torch.Tensor] = None,
|
478 |
-
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
479 |
-
head_mask: Optional[torch.Tensor] = None,
|
480 |
-
output_attentions: Optional[bool] = None,
|
481 |
-
output_hidden_states: Optional[bool] = None,
|
482 |
-
interpolate_pos_encoding: Optional[bool] = None,
|
483 |
-
return_dict: Optional[bool] = None,
|
484 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
485 |
-
r"""
|
486 |
-
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
487 |
-
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
488 |
-
"""
|
489 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
490 |
-
output_hidden_states = (
|
491 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
492 |
-
)
|
493 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
494 |
-
|
495 |
-
if pixel_values is None:
|
496 |
-
raise ValueError("You have to specify pixel_values")
|
497 |
-
|
498 |
-
# Prepare head mask if needed
|
499 |
-
# 1.0 in head_mask indicate we keep the head
|
500 |
-
# attention_probs has shape bsz x n_heads x N x N
|
501 |
-
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
502 |
-
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
503 |
-
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
504 |
-
|
505 |
-
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
|
506 |
-
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
507 |
-
if pixel_values.dtype != expected_dtype:
|
508 |
-
pixel_values = pixel_values.to(expected_dtype)
|
509 |
-
|
510 |
-
embedding_output = self.embeddings(
|
511 |
-
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
512 |
-
)
|
513 |
-
|
514 |
-
encoder_outputs = self.encoder(
|
515 |
-
embedding_output,
|
516 |
-
adaln_input=adaln_input,
|
517 |
-
head_mask=head_mask,
|
518 |
-
output_attentions=output_attentions,
|
519 |
-
output_hidden_states=output_hidden_states,
|
520 |
-
return_dict=return_dict,
|
521 |
-
)
|
522 |
-
sequence_output = encoder_outputs[0]
|
523 |
-
sequence_output = self.layernorm(sequence_output)
|
524 |
-
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
525 |
-
|
526 |
-
if not return_dict:
|
527 |
-
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
528 |
-
return head_outputs + encoder_outputs[1:]
|
529 |
-
|
530 |
-
return BaseModelOutputWithPooling(
|
531 |
-
last_hidden_state=sequence_output,
|
532 |
-
pooler_output=pooled_output,
|
533 |
-
hidden_states=encoder_outputs.hidden_states,
|
534 |
-
attentions=encoder_outputs.attentions,
|
535 |
-
)
|
536 |
-
|
537 |
-
|
538 |
-
class ViTPooler(nn.Module):
|
539 |
-
def __init__(self, config: ViTConfig):
|
540 |
-
super().__init__()
|
541 |
-
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
542 |
-
self.activation = nn.Tanh()
|
543 |
-
|
544 |
-
def forward(self, hidden_states):
|
545 |
-
# We "pool" the model by simply taking the hidden state corresponding
|
546 |
-
# to the first token.
|
547 |
-
first_token_tensor = hidden_states[:, 0]
|
548 |
-
pooled_output = self.dense(first_token_tensor)
|
549 |
-
pooled_output = self.activation(pooled_output)
|
550 |
-
return pooled_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/encoder/dino_wrapper.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Zexin He
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
|
16 |
-
import torch.nn as nn
|
17 |
-
from transformers import ViTImageProcessor
|
18 |
-
from einops import rearrange, repeat
|
19 |
-
from .dino import ViTModel
|
20 |
-
|
21 |
-
|
22 |
-
class DinoWrapper(nn.Module):
|
23 |
-
"""
|
24 |
-
Dino v1 wrapper using huggingface transformer implementation.
|
25 |
-
"""
|
26 |
-
def __init__(self, model_name: str, freeze: bool = True):
|
27 |
-
super().__init__()
|
28 |
-
self.model, self.processor = self._build_dino(model_name)
|
29 |
-
self.camera_embedder = nn.Sequential(
|
30 |
-
nn.Linear(16, self.model.config.hidden_size, bias=True),
|
31 |
-
nn.SiLU(),
|
32 |
-
nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
|
33 |
-
)
|
34 |
-
if freeze:
|
35 |
-
self._freeze()
|
36 |
-
|
37 |
-
def forward(self, image, camera):
|
38 |
-
# image: [B, N, C, H, W]
|
39 |
-
# camera: [B, N, D]
|
40 |
-
# RGB image with [0,1] scale and properly sized
|
41 |
-
if image.ndim == 5:
|
42 |
-
image = rearrange(image, 'b n c h w -> (b n) c h w')
|
43 |
-
dtype = image.dtype
|
44 |
-
inputs = self.processor(
|
45 |
-
images=image.float(),
|
46 |
-
return_tensors="pt",
|
47 |
-
do_rescale=False,
|
48 |
-
do_resize=False,
|
49 |
-
).to(self.model.device).to(dtype)
|
50 |
-
# embed camera
|
51 |
-
N = camera.shape[1]
|
52 |
-
camera_embeddings = self.camera_embedder(camera)
|
53 |
-
camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
|
54 |
-
embeddings = camera_embeddings
|
55 |
-
# This resampling of positional embedding uses bicubic interpolation
|
56 |
-
outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
|
57 |
-
last_hidden_states = outputs.last_hidden_state
|
58 |
-
return last_hidden_states
|
59 |
-
|
60 |
-
def _freeze(self):
|
61 |
-
print(f"======== Freezing DinoWrapper ========")
|
62 |
-
self.model.eval()
|
63 |
-
for name, param in self.model.named_parameters():
|
64 |
-
param.requires_grad = False
|
65 |
-
|
66 |
-
@staticmethod
|
67 |
-
def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
|
68 |
-
import requests
|
69 |
-
try:
|
70 |
-
model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
|
71 |
-
processor = ViTImageProcessor.from_pretrained(model_name)
|
72 |
-
return model, processor
|
73 |
-
except requests.exceptions.ProxyError as err:
|
74 |
-
if proxy_error_retries > 0:
|
75 |
-
print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
|
76 |
-
import time
|
77 |
-
time.sleep(proxy_error_cooldown)
|
78 |
-
return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
|
79 |
-
else:
|
80 |
-
raise err
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/__init__.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/camera/__init__.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
from torch import nn
|
11 |
-
|
12 |
-
|
13 |
-
class Camera(nn.Module):
|
14 |
-
def __init__(self):
|
15 |
-
super(Camera, self).__init__()
|
16 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/camera/perspective_camera.py
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
from . import Camera
|
11 |
-
import numpy as np
|
12 |
-
|
13 |
-
|
14 |
-
def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
|
15 |
-
if near_plane is None:
|
16 |
-
near_plane = n
|
17 |
-
return np.array(
|
18 |
-
[[n / x, 0, 0, 0],
|
19 |
-
[0, n / -x, 0, 0],
|
20 |
-
[0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
|
21 |
-
[0, 0, -1, 0]]).astype(np.float32)
|
22 |
-
|
23 |
-
|
24 |
-
class PerspectiveCamera(Camera):
|
25 |
-
def __init__(self, fovy=49.0, device='cuda'):
|
26 |
-
super(PerspectiveCamera, self).__init__()
|
27 |
-
self.device = device
|
28 |
-
focal = np.tan(fovy / 180.0 * np.pi * 0.5)
|
29 |
-
self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
|
30 |
-
|
31 |
-
def project(self, points_bxnx4):
|
32 |
-
out = torch.matmul(
|
33 |
-
points_bxnx4,
|
34 |
-
torch.transpose(self.proj_mtx, 1, 2))
|
35 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/render/__init__.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
class Renderer():
|
4 |
-
def __init__(self):
|
5 |
-
pass
|
6 |
-
|
7 |
-
def forward(self):
|
8 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/render/neural_render.py
DELETED
@@ -1,121 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn.functional as F
|
11 |
-
import nvdiffrast.torch as dr
|
12 |
-
from . import Renderer
|
13 |
-
|
14 |
-
_FG_LUT = None
|
15 |
-
|
16 |
-
|
17 |
-
def interpolate(attr, rast, attr_idx, rast_db=None):
|
18 |
-
return dr.interpolate(
|
19 |
-
attr.contiguous(), rast, attr_idx, rast_db=rast_db,
|
20 |
-
diff_attrs=None if rast_db is None else 'all')
|
21 |
-
|
22 |
-
|
23 |
-
def xfm_points(points, matrix, use_python=True):
|
24 |
-
'''Transform points.
|
25 |
-
Args:
|
26 |
-
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
27 |
-
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
28 |
-
use_python: Use PyTorch's torch.matmul (for validation)
|
29 |
-
Returns:
|
30 |
-
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
31 |
-
'''
|
32 |
-
out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
33 |
-
if torch.is_anomaly_enabled():
|
34 |
-
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
35 |
-
return out
|
36 |
-
|
37 |
-
|
38 |
-
def dot(x, y):
|
39 |
-
return torch.sum(x * y, -1, keepdim=True)
|
40 |
-
|
41 |
-
|
42 |
-
def compute_vertex_normal(v_pos, t_pos_idx):
|
43 |
-
i0 = t_pos_idx[:, 0]
|
44 |
-
i1 = t_pos_idx[:, 1]
|
45 |
-
i2 = t_pos_idx[:, 2]
|
46 |
-
|
47 |
-
v0 = v_pos[i0, :]
|
48 |
-
v1 = v_pos[i1, :]
|
49 |
-
v2 = v_pos[i2, :]
|
50 |
-
|
51 |
-
face_normals = torch.cross(v1 - v0, v2 - v0)
|
52 |
-
|
53 |
-
# Splat face normals to vertices
|
54 |
-
v_nrm = torch.zeros_like(v_pos)
|
55 |
-
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
56 |
-
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
57 |
-
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
58 |
-
|
59 |
-
# Normalize, replace zero (degenerated) normals with some default value
|
60 |
-
v_nrm = torch.where(
|
61 |
-
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
62 |
-
)
|
63 |
-
v_nrm = F.normalize(v_nrm, dim=1)
|
64 |
-
assert torch.all(torch.isfinite(v_nrm))
|
65 |
-
|
66 |
-
return v_nrm
|
67 |
-
|
68 |
-
|
69 |
-
class NeuralRender(Renderer):
|
70 |
-
def __init__(self, device='cuda', camera_model=None):
|
71 |
-
super(NeuralRender, self).__init__()
|
72 |
-
self.device = device
|
73 |
-
self.ctx = dr.RasterizeCudaContext(device=device)
|
74 |
-
self.projection_mtx = None
|
75 |
-
self.camera = camera_model
|
76 |
-
|
77 |
-
def render_mesh(
|
78 |
-
self,
|
79 |
-
mesh_v_pos_bxnx3,
|
80 |
-
mesh_t_pos_idx_fx3,
|
81 |
-
camera_mv_bx4x4,
|
82 |
-
mesh_v_feat_bxnxd,
|
83 |
-
resolution=256,
|
84 |
-
spp=1,
|
85 |
-
device='cuda',
|
86 |
-
hierarchical_mask=False
|
87 |
-
):
|
88 |
-
assert not hierarchical_mask
|
89 |
-
|
90 |
-
mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
|
91 |
-
v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
|
92 |
-
v_pos_clip = self.camera.project(v_pos) # Projection in the camera
|
93 |
-
|
94 |
-
v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
|
95 |
-
|
96 |
-
# Render the image,
|
97 |
-
# Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
|
98 |
-
num_layers = 1
|
99 |
-
mask_pyramid = None
|
100 |
-
assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
|
101 |
-
mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos
|
102 |
-
|
103 |
-
with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
|
104 |
-
for _ in range(num_layers):
|
105 |
-
rast, db = peeler.rasterize_next_layer()
|
106 |
-
gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
|
107 |
-
|
108 |
-
hard_mask = torch.clamp(rast[..., -1:], 0, 1)
|
109 |
-
antialias_mask = dr.antialias(
|
110 |
-
hard_mask.clone().contiguous(), rast, v_pos_clip,
|
111 |
-
mesh_t_pos_idx_fx3)
|
112 |
-
|
113 |
-
depth = gb_feat[..., -2:-1]
|
114 |
-
ori_mesh_feature = gb_feat[..., :-4]
|
115 |
-
|
116 |
-
normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
|
117 |
-
normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
|
118 |
-
normal = F.normalize(normal, dim=-1)
|
119 |
-
normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
|
120 |
-
|
121 |
-
return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/__init__.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import numpy as np
|
11 |
-
|
12 |
-
|
13 |
-
class Geometry():
|
14 |
-
def __init__(self):
|
15 |
-
pass
|
16 |
-
|
17 |
-
def forward(self):
|
18 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/dmtet.py
DELETED
@@ -1,504 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import numpy as np
|
11 |
-
import os
|
12 |
-
from . import Geometry
|
13 |
-
from .dmtet_utils import get_center_boundary_index
|
14 |
-
import torch.nn.functional as F
|
15 |
-
|
16 |
-
|
17 |
-
###############################################################################
|
18 |
-
# DMTet utility functions
|
19 |
-
###############################################################################
|
20 |
-
def create_mt_variable(device):
|
21 |
-
triangle_table = torch.tensor(
|
22 |
-
[
|
23 |
-
[-1, -1, -1, -1, -1, -1],
|
24 |
-
[1, 0, 2, -1, -1, -1],
|
25 |
-
[4, 0, 3, -1, -1, -1],
|
26 |
-
[1, 4, 2, 1, 3, 4],
|
27 |
-
[3, 1, 5, -1, -1, -1],
|
28 |
-
[2, 3, 0, 2, 5, 3],
|
29 |
-
[1, 4, 0, 1, 5, 4],
|
30 |
-
[4, 2, 5, -1, -1, -1],
|
31 |
-
[4, 5, 2, -1, -1, -1],
|
32 |
-
[4, 1, 0, 4, 5, 1],
|
33 |
-
[3, 2, 0, 3, 5, 2],
|
34 |
-
[1, 3, 5, -1, -1, -1],
|
35 |
-
[4, 1, 2, 4, 3, 1],
|
36 |
-
[3, 0, 4, -1, -1, -1],
|
37 |
-
[2, 0, 1, -1, -1, -1],
|
38 |
-
[-1, -1, -1, -1, -1, -1]
|
39 |
-
], dtype=torch.long, device=device)
|
40 |
-
|
41 |
-
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
42 |
-
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
43 |
-
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
44 |
-
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
45 |
-
|
46 |
-
|
47 |
-
def sort_edges(edges_ex2):
|
48 |
-
with torch.no_grad():
|
49 |
-
order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
|
50 |
-
order = order.unsqueeze(dim=1)
|
51 |
-
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
52 |
-
b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
|
53 |
-
return torch.stack([a, b], -1)
|
54 |
-
|
55 |
-
|
56 |
-
###############################################################################
|
57 |
-
# marching tetrahedrons (differentiable)
|
58 |
-
###############################################################################
|
59 |
-
|
60 |
-
def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
|
61 |
-
with torch.no_grad():
|
62 |
-
occ_n = sdf_n > 0
|
63 |
-
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
64 |
-
occ_sum = torch.sum(occ_fx4, -1)
|
65 |
-
valid_tets = (occ_sum > 0) & (occ_sum < 4)
|
66 |
-
occ_sum = occ_sum[valid_tets]
|
67 |
-
|
68 |
-
# find all vertices
|
69 |
-
all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
|
70 |
-
all_edges = sort_edges(all_edges)
|
71 |
-
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
72 |
-
|
73 |
-
unique_edges = unique_edges.long()
|
74 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
75 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
|
76 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
|
77 |
-
idx_map = mapping[idx_map] # map edges to verts
|
78 |
-
|
79 |
-
interp_v = unique_edges[mask_edges] # .long()
|
80 |
-
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
|
81 |
-
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
|
82 |
-
edges_to_interp_sdf[:, -1] *= -1
|
83 |
-
|
84 |
-
denominator = edges_to_interp_sdf.sum(1, keepdim=True)
|
85 |
-
|
86 |
-
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
|
87 |
-
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
88 |
-
|
89 |
-
idx_map = idx_map.reshape(-1, 6)
|
90 |
-
|
91 |
-
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
92 |
-
num_triangles = num_triangles_table[tetindex]
|
93 |
-
|
94 |
-
# Generate triangle indices
|
95 |
-
faces = torch.cat(
|
96 |
-
(
|
97 |
-
torch.gather(
|
98 |
-
input=idx_map[num_triangles == 1], dim=1,
|
99 |
-
index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
|
100 |
-
torch.gather(
|
101 |
-
input=idx_map[num_triangles == 2], dim=1,
|
102 |
-
index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
|
103 |
-
), dim=0)
|
104 |
-
return verts, faces
|
105 |
-
|
106 |
-
|
107 |
-
def create_tetmesh_variables(device='cuda'):
|
108 |
-
tet_table = torch.tensor(
|
109 |
-
[[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
110 |
-
[0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
|
111 |
-
[1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
|
112 |
-
[1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
|
113 |
-
[2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
|
114 |
-
[2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
|
115 |
-
[2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
|
116 |
-
[6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
|
117 |
-
[3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
|
118 |
-
[3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
|
119 |
-
[3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
|
120 |
-
[5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
|
121 |
-
[3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
|
122 |
-
[4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
|
123 |
-
[4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
|
124 |
-
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device)
|
125 |
-
num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)
|
126 |
-
return tet_table, num_tets_table
|
127 |
-
|
128 |
-
|
129 |
-
def marching_tets_tetmesh(
|
130 |
-
pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
131 |
-
return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
132 |
-
with torch.no_grad():
|
133 |
-
occ_n = sdf_n > 0
|
134 |
-
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
135 |
-
occ_sum = torch.sum(occ_fx4, -1)
|
136 |
-
valid_tets = (occ_sum > 0) & (occ_sum < 4)
|
137 |
-
occ_sum = occ_sum[valid_tets]
|
138 |
-
|
139 |
-
# find all vertices
|
140 |
-
all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
|
141 |
-
all_edges = sort_edges(all_edges)
|
142 |
-
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
143 |
-
|
144 |
-
unique_edges = unique_edges.long()
|
145 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
146 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
|
147 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
|
148 |
-
idx_map = mapping[idx_map] # map edges to verts
|
149 |
-
|
150 |
-
interp_v = unique_edges[mask_edges] # .long()
|
151 |
-
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
|
152 |
-
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
|
153 |
-
edges_to_interp_sdf[:, -1] *= -1
|
154 |
-
|
155 |
-
denominator = edges_to_interp_sdf.sum(1, keepdim=True)
|
156 |
-
|
157 |
-
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
|
158 |
-
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
159 |
-
|
160 |
-
idx_map = idx_map.reshape(-1, 6)
|
161 |
-
|
162 |
-
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
163 |
-
num_triangles = num_triangles_table[tetindex]
|
164 |
-
|
165 |
-
# Generate triangle indices
|
166 |
-
faces = torch.cat(
|
167 |
-
(
|
168 |
-
torch.gather(
|
169 |
-
input=idx_map[num_triangles == 1], dim=1,
|
170 |
-
index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
|
171 |
-
torch.gather(
|
172 |
-
input=idx_map[num_triangles == 2], dim=1,
|
173 |
-
index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
|
174 |
-
), dim=0)
|
175 |
-
if not return_tet_mesh:
|
176 |
-
return verts, faces
|
177 |
-
occupied_verts = ori_v[occ_n]
|
178 |
-
mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1
|
179 |
-
mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda")
|
180 |
-
tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))
|
181 |
-
|
182 |
-
idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10
|
183 |
-
tet_verts = torch.cat([verts, occupied_verts], 0)
|
184 |
-
num_tets = num_tets_table[tetindex]
|
185 |
-
|
186 |
-
tets = torch.cat(
|
187 |
-
(
|
188 |
-
torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(
|
189 |
-
-1,
|
190 |
-
4),
|
191 |
-
torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(
|
192 |
-
-1,
|
193 |
-
4),
|
194 |
-
), dim=0)
|
195 |
-
# add fully occupied tets
|
196 |
-
fully_occupied = occ_fx4.sum(-1) == 4
|
197 |
-
tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
|
198 |
-
tets = torch.cat([tets, tet_fully_occupied])
|
199 |
-
|
200 |
-
return verts, faces, tet_verts, tets
|
201 |
-
|
202 |
-
|
203 |
-
###############################################################################
|
204 |
-
# Compact tet grid
|
205 |
-
###############################################################################
|
206 |
-
|
207 |
-
def compact_tets(pos_nx3, sdf_n, tet_fx4):
|
208 |
-
with torch.no_grad():
|
209 |
-
# Find surface tets
|
210 |
-
occ_n = sdf_n > 0
|
211 |
-
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
212 |
-
occ_sum = torch.sum(occ_fx4, -1)
|
213 |
-
valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets
|
214 |
-
|
215 |
-
valid_vtx = tet_fx4[valid_tets].reshape(-1)
|
216 |
-
unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True)
|
217 |
-
new_pos = pos_nx3[unique_vtx]
|
218 |
-
new_sdf = sdf_n[unique_vtx]
|
219 |
-
new_tets = idx_map.reshape(-1, 4)
|
220 |
-
return new_pos, new_sdf, new_tets
|
221 |
-
|
222 |
-
|
223 |
-
###############################################################################
|
224 |
-
# Subdivide volume
|
225 |
-
###############################################################################
|
226 |
-
|
227 |
-
def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
|
228 |
-
device = tet_pos_bxnx3.device
|
229 |
-
# get new verts
|
230 |
-
tet_fx4 = tet_bxfx4[0]
|
231 |
-
edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
|
232 |
-
all_edges = tet_fx4[:, edges].reshape(-1, 2)
|
233 |
-
all_edges = sort_edges(all_edges)
|
234 |
-
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
235 |
-
idx_map = idx_map + tet_pos_bxnx3.shape[1]
|
236 |
-
all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
|
237 |
-
mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape(
|
238 |
-
all_values.shape[0], -1, 2,
|
239 |
-
all_values.shape[-1]).mean(2)
|
240 |
-
new_v = torch.cat([all_values, mid_points_pos], 1)
|
241 |
-
new_v, new_sdf = new_v[..., :3], new_v[..., 3]
|
242 |
-
|
243 |
-
# get new tets
|
244 |
-
|
245 |
-
idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
|
246 |
-
idx_ab = idx_map[0::6]
|
247 |
-
idx_ac = idx_map[1::6]
|
248 |
-
idx_ad = idx_map[2::6]
|
249 |
-
idx_bc = idx_map[3::6]
|
250 |
-
idx_bd = idx_map[4::6]
|
251 |
-
idx_cd = idx_map[5::6]
|
252 |
-
|
253 |
-
tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
|
254 |
-
tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
|
255 |
-
tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
|
256 |
-
tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
|
257 |
-
tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
|
258 |
-
tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
|
259 |
-
tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
|
260 |
-
tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)
|
261 |
-
|
262 |
-
tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0)
|
263 |
-
tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1)
|
264 |
-
tet = tet_np.long().to(device)
|
265 |
-
|
266 |
-
return new_v, tet, new_sdf
|
267 |
-
|
268 |
-
|
269 |
-
###############################################################################
|
270 |
-
# Adjacency
|
271 |
-
###############################################################################
|
272 |
-
def tet_to_tet_adj_sparse(tet_tx4):
|
273 |
-
# include self connection!!!!!!!!!!!!!!!!!!!
|
274 |
-
with torch.no_grad():
|
275 |
-
t = tet_tx4.shape[0]
|
276 |
-
device = tet_tx4.device
|
277 |
-
idx_array = torch.LongTensor(
|
278 |
-
[0, 1, 2,
|
279 |
-
1, 0, 3,
|
280 |
-
2, 3, 0,
|
281 |
-
3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3)
|
282 |
-
|
283 |
-
# get all faces
|
284 |
-
all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape(
|
285 |
-
-1,
|
286 |
-
3) # (tx4, 3)
|
287 |
-
all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)
|
288 |
-
# sort and group
|
289 |
-
all_faces_sorted, _ = torch.sort(all_faces, dim=1)
|
290 |
-
|
291 |
-
all_faces_unique, inverse_indices, counts = torch.unique(
|
292 |
-
all_faces_sorted, dim=0, return_counts=True,
|
293 |
-
return_inverse=True)
|
294 |
-
tet_face_fx3 = all_faces_unique[counts == 2]
|
295 |
-
counts = counts[inverse_indices] # tx4
|
296 |
-
valid = (counts == 2)
|
297 |
-
|
298 |
-
group = inverse_indices[valid]
|
299 |
-
# print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape)
|
300 |
-
_, indices = torch.sort(group)
|
301 |
-
all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices]
|
302 |
-
tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1)
|
303 |
-
|
304 |
-
tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])])
|
305 |
-
adj_self = torch.arange(t, device=tet_tx4.device)
|
306 |
-
adj_self = torch.stack([adj_self, adj_self], -1)
|
307 |
-
tet_adj_idx = torch.cat([tet_adj_idx, adj_self])
|
308 |
-
|
309 |
-
tet_adj_idx = torch.unique(tet_adj_idx, dim=0)
|
310 |
-
values = torch.ones(
|
311 |
-
tet_adj_idx.shape[0], device=tet_tx4.device).float()
|
312 |
-
adj_sparse = torch.sparse.FloatTensor(
|
313 |
-
tet_adj_idx.t(), values, torch.Size([t, t]))
|
314 |
-
|
315 |
-
# normalization
|
316 |
-
neighbor_num = 1.0 / torch.sparse.sum(
|
317 |
-
adj_sparse, dim=1).to_dense()
|
318 |
-
values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0])
|
319 |
-
adj_sparse = torch.sparse.FloatTensor(
|
320 |
-
tet_adj_idx.t(), values, torch.Size([t, t]))
|
321 |
-
return adj_sparse
|
322 |
-
|
323 |
-
|
324 |
-
###############################################################################
|
325 |
-
# Compact grid
|
326 |
-
###############################################################################
|
327 |
-
|
328 |
-
def get_tet_bxfx4x3(bxnxz, bxfx4):
|
329 |
-
n_batch, z = bxnxz.shape[0], bxnxz.shape[2]
|
330 |
-
gather_input = bxnxz.unsqueeze(2).expand(
|
331 |
-
n_batch, bxnxz.shape[1], 4, z)
|
332 |
-
gather_index = bxfx4.unsqueeze(-1).expand(
|
333 |
-
n_batch, bxfx4.shape[1], 4, z).long()
|
334 |
-
tet_bxfx4xz = torch.gather(
|
335 |
-
input=gather_input, dim=1, index=gather_index)
|
336 |
-
|
337 |
-
return tet_bxfx4xz
|
338 |
-
|
339 |
-
|
340 |
-
def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
|
341 |
-
with torch.no_grad():
|
342 |
-
assert tet_pos_bxnx3.shape[0] == 1
|
343 |
-
|
344 |
-
occ = grid_sdf[0] > 0
|
345 |
-
occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1)
|
346 |
-
mask = (occ_sum > 0) & (occ_sum < 4)
|
347 |
-
|
348 |
-
# build connectivity graph
|
349 |
-
adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0])
|
350 |
-
mask = mask.float().unsqueeze(-1)
|
351 |
-
|
352 |
-
# Include a one ring of neighbors
|
353 |
-
for i in range(1):
|
354 |
-
mask = torch.sparse.mm(adj_matrix, mask)
|
355 |
-
mask = mask.squeeze(-1) > 0
|
356 |
-
|
357 |
-
mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long)
|
358 |
-
new_tet_bxfx4 = tet_bxfx4[:, mask].long()
|
359 |
-
selected_verts_idx = torch.unique(new_tet_bxfx4)
|
360 |
-
new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx]
|
361 |
-
mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device)
|
362 |
-
new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape)
|
363 |
-
new_grid_sdf = grid_sdf[:, selected_verts_idx]
|
364 |
-
return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
|
365 |
-
|
366 |
-
|
367 |
-
###############################################################################
|
368 |
-
# Regularizer
|
369 |
-
###############################################################################
|
370 |
-
|
371 |
-
def sdf_reg_loss(sdf, all_edges):
|
372 |
-
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2)
|
373 |
-
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
374 |
-
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
375 |
-
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(
|
376 |
-
sdf_f1x6x2[..., 0],
|
377 |
-
(sdf_f1x6x2[..., 1] > 0).float()) + \
|
378 |
-
torch.nn.functional.binary_cross_entropy_with_logits(
|
379 |
-
sdf_f1x6x2[..., 1],
|
380 |
-
(sdf_f1x6x2[..., 0] > 0).float())
|
381 |
-
return sdf_diff
|
382 |
-
|
383 |
-
|
384 |
-
def sdf_reg_loss_batch(sdf, all_edges):
|
385 |
-
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
|
386 |
-
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
387 |
-
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
388 |
-
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
|
389 |
-
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
|
390 |
-
return sdf_diff
|
391 |
-
|
392 |
-
|
393 |
-
###############################################################################
|
394 |
-
# Geometry interface
|
395 |
-
###############################################################################
|
396 |
-
class DMTetGeometry(Geometry):
|
397 |
-
def __init__(
|
398 |
-
self, grid_res=64, scale=2.0, device='cuda', renderer=None,
|
399 |
-
render_type='neural_render', args=None):
|
400 |
-
super(DMTetGeometry, self).__init__()
|
401 |
-
self.grid_res = grid_res
|
402 |
-
self.device = device
|
403 |
-
self.args = args
|
404 |
-
tets = np.load('data/tets/%d_compress.npz' % (grid_res))
|
405 |
-
self.verts = torch.from_numpy(tets['vertices']).float().to(self.device)
|
406 |
-
# Make sure the tet is zero-centered and length is equal to 1
|
407 |
-
length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
|
408 |
-
length = length.max()
|
409 |
-
mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0
|
410 |
-
self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
|
411 |
-
if isinstance(scale, list):
|
412 |
-
self.verts[:, 0] = self.verts[:, 0] * scale[0]
|
413 |
-
self.verts[:, 1] = self.verts[:, 1] * scale[1]
|
414 |
-
self.verts[:, 2] = self.verts[:, 2] * scale[1]
|
415 |
-
else:
|
416 |
-
self.verts = self.verts * scale
|
417 |
-
self.indices = torch.from_numpy(tets['tets']).long().to(self.device)
|
418 |
-
self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
|
419 |
-
self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
|
420 |
-
# Parameters for regularization computation
|
421 |
-
edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
|
422 |
-
all_edges = self.indices[:, edges].reshape(-1, 2)
|
423 |
-
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
|
424 |
-
self.all_edges = torch.unique(all_edges_sorted, dim=0)
|
425 |
-
|
426 |
-
# Parameters used for fix boundary sdf
|
427 |
-
self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
|
428 |
-
self.renderer = renderer
|
429 |
-
self.render_type = render_type
|
430 |
-
|
431 |
-
def getAABB(self):
|
432 |
-
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
433 |
-
|
434 |
-
def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
|
435 |
-
if indices is None:
|
436 |
-
indices = self.indices
|
437 |
-
verts, faces = marching_tets(
|
438 |
-
v_deformed_nx3, sdf_n, indices, self.triangle_table,
|
439 |
-
self.num_triangles_table, self.base_tet_edges, self.v_id)
|
440 |
-
faces = torch.cat(
|
441 |
-
[faces[:, 0:1],
|
442 |
-
faces[:, 2:3],
|
443 |
-
faces[:, 1:2], ], dim=-1)
|
444 |
-
return verts, faces
|
445 |
-
|
446 |
-
def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
|
447 |
-
if indices is None:
|
448 |
-
indices = self.indices
|
449 |
-
verts, faces, tet_verts, tets = marching_tets_tetmesh(
|
450 |
-
v_deformed_nx3, sdf_n, indices, self.triangle_table,
|
451 |
-
self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
|
452 |
-
num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
|
453 |
-
faces = torch.cat(
|
454 |
-
[faces[:, 0:1],
|
455 |
-
faces[:, 2:3],
|
456 |
-
faces[:, 1:2], ], dim=-1)
|
457 |
-
return verts, faces, tet_verts, tets
|
458 |
-
|
459 |
-
def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
|
460 |
-
return_value = dict()
|
461 |
-
if self.render_type == 'neural_render':
|
462 |
-
tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
|
463 |
-
mesh_v_nx3.unsqueeze(dim=0),
|
464 |
-
mesh_f_fx3.int(),
|
465 |
-
camera_mv_bx4x4,
|
466 |
-
mesh_v_nx3.unsqueeze(dim=0),
|
467 |
-
resolution=resolution,
|
468 |
-
device=self.device,
|
469 |
-
hierarchical_mask=hierarchical_mask
|
470 |
-
)
|
471 |
-
|
472 |
-
return_value['tex_pos'] = tex_pos
|
473 |
-
return_value['mask'] = mask
|
474 |
-
return_value['hard_mask'] = hard_mask
|
475 |
-
return_value['rast'] = rast
|
476 |
-
return_value['v_pos_clip'] = v_pos_clip
|
477 |
-
return_value['mask_pyramid'] = mask_pyramid
|
478 |
-
return_value['depth'] = depth
|
479 |
-
else:
|
480 |
-
raise NotImplementedError
|
481 |
-
|
482 |
-
return return_value
|
483 |
-
|
484 |
-
def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
|
485 |
-
# Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
|
486 |
-
v_list = []
|
487 |
-
f_list = []
|
488 |
-
n_batch = v_deformed_bxnx3.shape[0]
|
489 |
-
all_render_output = []
|
490 |
-
for i_batch in range(n_batch):
|
491 |
-
verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
|
492 |
-
v_list.append(verts_nx3)
|
493 |
-
f_list.append(faces_fx3)
|
494 |
-
render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
|
495 |
-
all_render_output.append(render_output)
|
496 |
-
|
497 |
-
# Concatenate all render output
|
498 |
-
return_keys = all_render_output[0].keys()
|
499 |
-
return_value = dict()
|
500 |
-
for k in return_keys:
|
501 |
-
value = [v[k] for v in all_render_output]
|
502 |
-
return_value[k] = value
|
503 |
-
# We can do concatenation outside of the render
|
504 |
-
return return_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/dmtet_utils.py
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
|
11 |
-
|
12 |
-
def get_center_boundary_index(verts):
|
13 |
-
length_ = torch.sum(verts ** 2, dim=-1)
|
14 |
-
center_idx = torch.argmin(length_)
|
15 |
-
boundary_neg = verts == verts.max()
|
16 |
-
boundary_pos = verts == verts.min()
|
17 |
-
boundary = torch.bitwise_or(boundary_pos, boundary_neg)
|
18 |
-
boundary = torch.sum(boundary.float(), dim=-1)
|
19 |
-
boundary_idx = torch.nonzero(boundary)
|
20 |
-
return center_idx, boundary_idx.squeeze(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/extract_texture_map.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import xatlas
|
11 |
-
import numpy as np
|
12 |
-
import nvdiffrast.torch as dr
|
13 |
-
|
14 |
-
|
15 |
-
# ==============================================================================================
|
16 |
-
def interpolate(attr, rast, attr_idx, rast_db=None):
|
17 |
-
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
|
18 |
-
|
19 |
-
|
20 |
-
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
|
21 |
-
vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
|
22 |
-
|
23 |
-
# Convert to tensors
|
24 |
-
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
|
25 |
-
|
26 |
-
uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
|
27 |
-
mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
|
28 |
-
# mesh_v_tex. ture
|
29 |
-
uv_clip = uvs[None, ...] * 2.0 - 1.0
|
30 |
-
|
31 |
-
# pad to four component coordinate
|
32 |
-
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
|
33 |
-
|
34 |
-
# rasterize
|
35 |
-
rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
|
36 |
-
|
37 |
-
# Interpolate world space position
|
38 |
-
gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
|
39 |
-
mask = rast[..., 3:4] > 0
|
40 |
-
return uvs, mesh_tex_idx, gb_pos, mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/flexicubes.py
DELETED
@@ -1,579 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
import torch
|
9 |
-
from .tables import *
|
10 |
-
|
11 |
-
__all__ = [
|
12 |
-
'FlexiCubes'
|
13 |
-
]
|
14 |
-
|
15 |
-
|
16 |
-
class FlexiCubes:
|
17 |
-
"""
|
18 |
-
This class implements the FlexiCubes method for extracting meshes from scalar fields.
|
19 |
-
It maintains a series of lookup tables and indices to support the mesh extraction process.
|
20 |
-
FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
|
21 |
-
the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
|
22 |
-
the surface representation through gradient-based optimization.
|
23 |
-
|
24 |
-
During instantiation, the class loads DMC tables from a file and transforms them into
|
25 |
-
PyTorch tensors on the specified device.
|
26 |
-
|
27 |
-
Attributes:
|
28 |
-
device (str): Specifies the computational device (default is "cuda").
|
29 |
-
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
-
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
-
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
32 |
-
the 256 MC configurations.
|
33 |
-
check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
|
34 |
-
of the DMC configurations.
|
35 |
-
tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
|
36 |
-
quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
|
37 |
-
along one diagonal.
|
38 |
-
quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
|
39 |
-
two triangles along the other diagonal.
|
40 |
-
quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
|
41 |
-
during training by connecting all edges to their midpoints.
|
42 |
-
cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
|
43 |
-
eight corners in 3D space, ordered starting from the origin (0,0,0),
|
44 |
-
moving along the x-axis, then y-axis, and finally z-axis.
|
45 |
-
Used as a blueprint for generating a voxel grid.
|
46 |
-
cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
|
47 |
-
to retrieve the case id.
|
48 |
-
cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
|
49 |
-
Used to retrieve edge vertices in DMC.
|
50 |
-
edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
|
51 |
-
their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
|
52 |
-
first edge is oriented along the x-axis.
|
53 |
-
dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
|
54 |
-
across four adjacent cubes to the shared faces of these cubes. For instance,
|
55 |
-
dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
|
56 |
-
the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
|
57 |
-
This tensor is only utilized during isosurface tetrahedralization.
|
58 |
-
adj_pairs (torch.Tensor):
|
59 |
-
A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
|
60 |
-
qef_reg_scale (float):
|
61 |
-
The scaling factor applied to the regularization loss to prevent issues with singularity
|
62 |
-
when solving the QEF. This parameter is only used when a 'grad_func' is specified.
|
63 |
-
weight_scale (float):
|
64 |
-
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
-
"""
|
66 |
-
|
67 |
-
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
|
68 |
-
|
69 |
-
self.device = device
|
70 |
-
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
71 |
-
self.num_vd_table = torch.tensor(num_vd_table,
|
72 |
-
dtype=torch.long, device=device, requires_grad=False)
|
73 |
-
self.check_table = torch.tensor(
|
74 |
-
check_table,
|
75 |
-
dtype=torch.long, device=device, requires_grad=False)
|
76 |
-
|
77 |
-
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
|
78 |
-
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
79 |
-
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
80 |
-
self.quad_split_train = torch.tensor(
|
81 |
-
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
|
82 |
-
|
83 |
-
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
|
84 |
-
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
|
85 |
-
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
|
86 |
-
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
|
87 |
-
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
|
88 |
-
|
89 |
-
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
|
90 |
-
dtype=torch.long, device=device)
|
91 |
-
self.dir_faces_table = torch.tensor([
|
92 |
-
[[5, 4], [3, 2], [4, 5], [2, 3]],
|
93 |
-
[[5, 4], [1, 0], [4, 5], [0, 1]],
|
94 |
-
[[3, 2], [1, 0], [2, 3], [0, 1]]
|
95 |
-
], dtype=torch.long, device=device)
|
96 |
-
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
|
97 |
-
self.qef_reg_scale = qef_reg_scale
|
98 |
-
self.weight_scale = weight_scale
|
99 |
-
|
100 |
-
def construct_voxel_grid(self, res):
|
101 |
-
"""
|
102 |
-
Generates a voxel grid based on the specified resolution.
|
103 |
-
|
104 |
-
Args:
|
105 |
-
res (int or list[int]): The resolution of the voxel grid. If an integer
|
106 |
-
is provided, it is used for all three dimensions. If a list or tuple
|
107 |
-
of 3 integers is provided, they define the resolution for the x,
|
108 |
-
y, and z dimensions respectively.
|
109 |
-
|
110 |
-
Returns:
|
111 |
-
(torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
|
112 |
-
cube corners (index into vertices) of the constructed voxel grid.
|
113 |
-
The vertices are centered at the origin, with the length of each
|
114 |
-
dimension in the grid being one.
|
115 |
-
"""
|
116 |
-
base_cube_f = torch.arange(8).to(self.device)
|
117 |
-
if isinstance(res, int):
|
118 |
-
res = (res, res, res)
|
119 |
-
voxel_grid_template = torch.ones(res, device=self.device)
|
120 |
-
|
121 |
-
res = torch.tensor([res], dtype=torch.float, device=self.device)
|
122 |
-
coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
|
123 |
-
verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
|
124 |
-
cubes = (base_cube_f.unsqueeze(0) +
|
125 |
-
torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
|
126 |
-
|
127 |
-
verts_rounded = torch.round(verts * 10**5) / (10**5)
|
128 |
-
verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
|
129 |
-
cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
|
130 |
-
|
131 |
-
return verts_unique - 0.5, cubes
|
132 |
-
|
133 |
-
def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
|
134 |
-
gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
|
135 |
-
r"""
|
136 |
-
Main function for mesh extraction from scalar field using FlexiCubes. This function converts
|
137 |
-
discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
|
138 |
-
to triangle or tetrahedral meshes using a differentiable operation as described in
|
139 |
-
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
|
140 |
-
mesh quality and geometric fidelity by adjusting the surface representation based on gradient
|
141 |
-
optimization. The output surface is differentiable with respect to the input vertex positions,
|
142 |
-
scalar field values, and weight parameters.
|
143 |
-
|
144 |
-
If you intend to extract a surface mesh from a fixed Signed Distance Field without the
|
145 |
-
optimization of parameters, it is suggested to provide the "grad_func" which should
|
146 |
-
return the surface gradient at any given 3D position. When grad_func is provided, the process
|
147 |
-
to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
|
148 |
-
described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
|
149 |
-
Please note, this approach is non-differentiable.
|
150 |
-
|
151 |
-
For more details and example usage in optimization, refer to the
|
152 |
-
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
|
153 |
-
|
154 |
-
Args:
|
155 |
-
x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
|
156 |
-
s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
|
157 |
-
denote that the corresponding vertex resides inside the isosurface. This affects
|
158 |
-
the directions of the extracted triangle faces and volume to be tetrahedralized.
|
159 |
-
cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
|
160 |
-
res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
|
161 |
-
is used for all three dimensions. If a list or tuple of 3 integers is provided, they
|
162 |
-
specify the resolution for the x, y, and z dimensions respectively.
|
163 |
-
beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
|
164 |
-
vertices positioning. Defaults to uniform value for all edges.
|
165 |
-
alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
|
166 |
-
vertices positioning. Defaults to uniform value for all vertices.
|
167 |
-
gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
|
168 |
-
quadrilaterals into triangles. Defaults to uniform value for all cubes.
|
169 |
-
training (bool, optional): If set to True, applies differentiable quad splitting for
|
170 |
-
training. Defaults to False.
|
171 |
-
output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
|
172 |
-
outputs a triangular mesh. Defaults to False.
|
173 |
-
grad_func (callable, optional): A function to compute the surface gradient at specified
|
174 |
-
3D positions (input: Nx3 positions). The function should return gradients as an Nx3
|
175 |
-
tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
|
176 |
-
|
177 |
-
Returns:
|
178 |
-
(torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
|
179 |
-
- Vertices for the extracted triangular/tetrahedral mesh.
|
180 |
-
- Faces for the extracted triangular/tetrahedral mesh.
|
181 |
-
- Regularizer L_dev, computed per dual vertex.
|
182 |
-
|
183 |
-
.. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
|
184 |
-
https://research.nvidia.com/labs/toronto-ai/flexicubes/
|
185 |
-
.. _Manifold Dual Contouring:
|
186 |
-
https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
|
187 |
-
"""
|
188 |
-
|
189 |
-
surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
|
190 |
-
if surf_cubes.sum() == 0:
|
191 |
-
return torch.zeros(
|
192 |
-
(0, 3),
|
193 |
-
device=self.device), torch.zeros(
|
194 |
-
(0, 4),
|
195 |
-
dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
|
196 |
-
(0, 3),
|
197 |
-
dtype=torch.long, device=self.device), torch.zeros(
|
198 |
-
(0),
|
199 |
-
device=self.device)
|
200 |
-
beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
|
201 |
-
|
202 |
-
case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
|
203 |
-
|
204 |
-
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
|
205 |
-
|
206 |
-
vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
|
207 |
-
x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
|
208 |
-
vertices, faces, s_edges, edge_indices = self._triangulate(
|
209 |
-
s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
|
210 |
-
if not output_tetmesh:
|
211 |
-
return vertices, faces, L_dev
|
212 |
-
else:
|
213 |
-
vertices, tets = self._tetrahedralize(
|
214 |
-
x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
215 |
-
surf_cubes, training)
|
216 |
-
return vertices, tets, L_dev
|
217 |
-
|
218 |
-
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
219 |
-
"""
|
220 |
-
Regularizer L_dev as in Equation 8
|
221 |
-
"""
|
222 |
-
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
223 |
-
mean_l2 = torch.zeros_like(vd[:, 0])
|
224 |
-
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
225 |
-
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
226 |
-
return mad
|
227 |
-
|
228 |
-
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
|
229 |
-
"""
|
230 |
-
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
231 |
-
"""
|
232 |
-
n_cubes = surf_cubes.shape[0]
|
233 |
-
|
234 |
-
if beta_fx12 is not None:
|
235 |
-
beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
|
236 |
-
else:
|
237 |
-
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
238 |
-
|
239 |
-
if alpha_fx8 is not None:
|
240 |
-
alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
|
241 |
-
else:
|
242 |
-
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
243 |
-
|
244 |
-
if gamma_f is not None:
|
245 |
-
gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
|
246 |
-
else:
|
247 |
-
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
248 |
-
|
249 |
-
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
|
250 |
-
|
251 |
-
@torch.no_grad()
|
252 |
-
def _get_case_id(self, occ_fx8, surf_cubes, res):
|
253 |
-
"""
|
254 |
-
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
|
255 |
-
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
|
256 |
-
supplementary material. It should be noted that this function assumes a regular grid.
|
257 |
-
"""
|
258 |
-
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
|
259 |
-
|
260 |
-
problem_config = self.check_table.to(self.device)[case_ids]
|
261 |
-
to_check = problem_config[..., 0] == 1
|
262 |
-
problem_config = problem_config[to_check]
|
263 |
-
if not isinstance(res, (list, tuple)):
|
264 |
-
res = [res, res, res]
|
265 |
-
|
266 |
-
# The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
|
267 |
-
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
|
268 |
-
# This allows efficient checking on adjacent cubes.
|
269 |
-
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
|
270 |
-
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
|
271 |
-
vol_idx_problem = vol_idx[surf_cubes][to_check]
|
272 |
-
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
|
273 |
-
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
|
274 |
-
|
275 |
-
within_range = (
|
276 |
-
vol_idx_problem_adj[..., 0] >= 0) & (
|
277 |
-
vol_idx_problem_adj[..., 0] < res[0]) & (
|
278 |
-
vol_idx_problem_adj[..., 1] >= 0) & (
|
279 |
-
vol_idx_problem_adj[..., 1] < res[1]) & (
|
280 |
-
vol_idx_problem_adj[..., 2] >= 0) & (
|
281 |
-
vol_idx_problem_adj[..., 2] < res[2])
|
282 |
-
|
283 |
-
vol_idx_problem = vol_idx_problem[within_range]
|
284 |
-
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
|
285 |
-
problem_config = problem_config[within_range]
|
286 |
-
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
|
287 |
-
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
|
288 |
-
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
|
289 |
-
to_invert = (problem_config_adj[..., 0] == 1)
|
290 |
-
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
|
291 |
-
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
|
292 |
-
return case_ids
|
293 |
-
|
294 |
-
@torch.no_grad()
|
295 |
-
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
|
296 |
-
"""
|
297 |
-
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
298 |
-
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
299 |
-
and marks the cube edges with this index.
|
300 |
-
"""
|
301 |
-
occ_n = s_n < 0
|
302 |
-
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
303 |
-
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
304 |
-
|
305 |
-
unique_edges = unique_edges.long()
|
306 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
307 |
-
|
308 |
-
surf_edges_mask = mask_edges[_idx_map]
|
309 |
-
counts = counts[_idx_map]
|
310 |
-
|
311 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
|
312 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
|
313 |
-
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
314 |
-
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
315 |
-
idx_map = mapping[_idx_map]
|
316 |
-
surf_edges = unique_edges[mask_edges]
|
317 |
-
return surf_edges, idx_map, counts, surf_edges_mask
|
318 |
-
|
319 |
-
@torch.no_grad()
|
320 |
-
def _identify_surf_cubes(self, s_n, cube_fx8):
|
321 |
-
"""
|
322 |
-
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
323 |
-
all corners are not identical.
|
324 |
-
"""
|
325 |
-
occ_n = s_n < 0
|
326 |
-
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
327 |
-
_occ_sum = torch.sum(occ_fx8, -1)
|
328 |
-
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
329 |
-
return surf_cubes, occ_fx8
|
330 |
-
|
331 |
-
def _linear_interp(self, edges_weight, edges_x):
|
332 |
-
"""
|
333 |
-
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
334 |
-
"""
|
335 |
-
edge_dim = edges_weight.dim() - 2
|
336 |
-
assert edges_weight.shape[edge_dim] == 2
|
337 |
-
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
|
338 |
-
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
|
339 |
-
denominator = edges_weight.sum(edge_dim)
|
340 |
-
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
341 |
-
return ue
|
342 |
-
|
343 |
-
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
|
344 |
-
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
345 |
-
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
346 |
-
c_bx3 = c_bx3.reshape(-1, 3)
|
347 |
-
A = norm_bxnx3
|
348 |
-
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
349 |
-
|
350 |
-
A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
351 |
-
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
|
352 |
-
A = torch.cat([A, A_reg], 1)
|
353 |
-
B = torch.cat([B, B_reg], 1)
|
354 |
-
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
355 |
-
return dual_verts
|
356 |
-
|
357 |
-
def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
|
358 |
-
"""
|
359 |
-
Computes the location of dual vertices as described in Section 4.2
|
360 |
-
"""
|
361 |
-
alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
|
362 |
-
surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
|
363 |
-
surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
|
364 |
-
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
|
365 |
-
|
366 |
-
idx_map = idx_map.reshape(-1, 12)
|
367 |
-
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
|
368 |
-
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
|
369 |
-
|
370 |
-
total_num_vd = 0
|
371 |
-
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
|
372 |
-
if grad_func is not None:
|
373 |
-
normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
|
374 |
-
vd = []
|
375 |
-
for num in torch.unique(num_vd):
|
376 |
-
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
|
377 |
-
curr_num_vd = cur_cubes.sum() * num
|
378 |
-
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
|
379 |
-
curr_edge_group_to_vd = torch.arange(
|
380 |
-
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
|
381 |
-
total_num_vd += curr_num_vd
|
382 |
-
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
|
383 |
-
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
|
384 |
-
|
385 |
-
curr_mask = (curr_edge_group != -1)
|
386 |
-
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
|
387 |
-
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
|
388 |
-
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
|
389 |
-
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
|
390 |
-
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
|
391 |
-
|
392 |
-
if grad_func is not None:
|
393 |
-
with torch.no_grad():
|
394 |
-
cube_e_verts_idx = idx_map[cur_cubes]
|
395 |
-
curr_edge_group[~curr_mask] = 0
|
396 |
-
|
397 |
-
verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
|
398 |
-
verts_group_idx[verts_group_idx == -1] = 0
|
399 |
-
verts_group_pos = torch.index_select(
|
400 |
-
input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
|
401 |
-
v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
|
402 |
-
curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
|
403 |
-
verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
|
404 |
-
|
405 |
-
normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
|
406 |
-
-1, num.item(), 7,
|
407 |
-
3)
|
408 |
-
curr_mask = curr_mask.squeeze(2)
|
409 |
-
vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
|
410 |
-
verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
|
411 |
-
edge_group = torch.cat(edge_group)
|
412 |
-
edge_group_to_vd = torch.cat(edge_group_to_vd)
|
413 |
-
edge_group_to_cube = torch.cat(edge_group_to_cube)
|
414 |
-
vd_num_edges = torch.cat(vd_num_edges)
|
415 |
-
vd_gamma = torch.cat(vd_gamma)
|
416 |
-
|
417 |
-
if grad_func is not None:
|
418 |
-
vd = torch.cat(vd)
|
419 |
-
L_dev = torch.zeros([1], device=self.device)
|
420 |
-
else:
|
421 |
-
vd = torch.zeros((total_num_vd, 3), device=self.device)
|
422 |
-
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
|
423 |
-
|
424 |
-
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
|
425 |
-
|
426 |
-
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
|
427 |
-
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
|
428 |
-
|
429 |
-
zero_crossing_group = torch.index_select(
|
430 |
-
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
|
431 |
-
|
432 |
-
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
|
433 |
-
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
|
434 |
-
ue_group = self._linear_interp(s_group * alpha_group, x_group)
|
435 |
-
|
436 |
-
beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
|
437 |
-
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
|
438 |
-
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
|
439 |
-
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
|
440 |
-
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
|
441 |
-
|
442 |
-
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
|
443 |
-
|
444 |
-
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
|
445 |
-
12 + edge_group, src=v_idx[edge_group_to_vd])
|
446 |
-
|
447 |
-
return vd, L_dev, vd_gamma, vd_idx_map
|
448 |
-
|
449 |
-
def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
|
450 |
-
"""
|
451 |
-
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
|
452 |
-
triangles based on the gamma parameter, as described in Section 4.3.
|
453 |
-
"""
|
454 |
-
with torch.no_grad():
|
455 |
-
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
|
456 |
-
group = idx_map.reshape(-1)[group_mask]
|
457 |
-
vd_idx = vd_idx_map[group_mask]
|
458 |
-
edge_indices, indices = torch.sort(group, stable=True)
|
459 |
-
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
|
460 |
-
|
461 |
-
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
|
462 |
-
s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
|
463 |
-
flip_mask = s_edges[:, 0] > 0
|
464 |
-
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
|
465 |
-
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
|
466 |
-
if grad_func is not None:
|
467 |
-
# when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
|
468 |
-
with torch.no_grad():
|
469 |
-
vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
|
470 |
-
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
471 |
-
gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
|
472 |
-
gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
|
473 |
-
else:
|
474 |
-
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
|
475 |
-
gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
476 |
-
0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
|
477 |
-
gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
478 |
-
1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
|
479 |
-
if not training:
|
480 |
-
mask = (gamma_02 > gamma_13).squeeze(1)
|
481 |
-
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
|
482 |
-
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
|
483 |
-
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
|
484 |
-
faces = faces.reshape(-1, 3)
|
485 |
-
else:
|
486 |
-
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
487 |
-
vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
|
488 |
-
torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
|
489 |
-
vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
|
490 |
-
torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
|
491 |
-
weight_sum = (gamma_02 + gamma_13) + 1e-8
|
492 |
-
vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
|
493 |
-
weight_sum.unsqueeze(-1)).squeeze(1)
|
494 |
-
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
|
495 |
-
vd = torch.cat([vd, vd_center])
|
496 |
-
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
|
497 |
-
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
498 |
-
return vd, faces, s_edges, edge_indices
|
499 |
-
|
500 |
-
def _tetrahedralize(
|
501 |
-
self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
502 |
-
surf_cubes, training):
|
503 |
-
"""
|
504 |
-
Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
|
505 |
-
"""
|
506 |
-
occ_n = s_n < 0
|
507 |
-
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
508 |
-
occ_sum = torch.sum(occ_fx8, -1)
|
509 |
-
|
510 |
-
inside_verts = x_nx3[occ_n]
|
511 |
-
mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
|
512 |
-
mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
|
513 |
-
"""
|
514 |
-
For each grid edge connecting two grid vertices with different
|
515 |
-
signs, we first form a four-sided pyramid by connecting one
|
516 |
-
of the grid vertices with four mesh vertices that correspond
|
517 |
-
to the grid edge and then subdivide the pyramid into two tetrahedra
|
518 |
-
"""
|
519 |
-
inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
|
520 |
-
s_edges < 0]]
|
521 |
-
if not training:
|
522 |
-
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
|
523 |
-
else:
|
524 |
-
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
|
525 |
-
|
526 |
-
tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
|
527 |
-
"""
|
528 |
-
For each grid edge connecting two grid vertices with the
|
529 |
-
same sign, the tetrahedron is formed by the two grid vertices
|
530 |
-
and two vertices in consecutive adjacent cells
|
531 |
-
"""
|
532 |
-
inside_cubes = (occ_sum == 8)
|
533 |
-
inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
|
534 |
-
inside_cubes_center_idx = torch.arange(
|
535 |
-
inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
|
536 |
-
|
537 |
-
surface_n_inside_cubes = surf_cubes | inside_cubes
|
538 |
-
edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
|
539 |
-
dtype=torch.long, device=x_nx3.device) * -1
|
540 |
-
surf_cubes = surf_cubes[surface_n_inside_cubes]
|
541 |
-
inside_cubes = inside_cubes[surface_n_inside_cubes]
|
542 |
-
edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
|
543 |
-
edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
|
544 |
-
|
545 |
-
all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
|
546 |
-
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
547 |
-
unique_edges = unique_edges.long()
|
548 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
|
549 |
-
mask = mask_edges[_idx_map]
|
550 |
-
counts = counts[_idx_map]
|
551 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
|
552 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
|
553 |
-
idx_map = mapping[_idx_map]
|
554 |
-
|
555 |
-
group_mask = (counts == 4) & mask
|
556 |
-
group = idx_map.reshape(-1)[group_mask]
|
557 |
-
edge_indices, indices = torch.sort(group)
|
558 |
-
cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
|
559 |
-
device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
|
560 |
-
edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
|
561 |
-
0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
|
562 |
-
# Identify the face shared by the adjacent cells.
|
563 |
-
cube_idx_4 = cube_idx[indices].reshape(-1, 4)
|
564 |
-
edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
|
565 |
-
shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
|
566 |
-
cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
|
567 |
-
# Identify an edge of the face with different signs and
|
568 |
-
# select the mesh vertex corresponding to the identified edge.
|
569 |
-
case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
|
570 |
-
case_ids_expand[surf_cubes] = case_ids
|
571 |
-
cases = case_ids_expand[cube_idx_4x2]
|
572 |
-
quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
|
573 |
-
mask = (quad_edge == -1).sum(-1) == 0
|
574 |
-
inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
|
575 |
-
tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
|
576 |
-
|
577 |
-
tets = torch.cat([tets_surface, tets_inside])
|
578 |
-
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
-
return vertices, tets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/flexicubes_geometry.py
DELETED
@@ -1,120 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import numpy as np
|
11 |
-
import os
|
12 |
-
from . import Geometry
|
13 |
-
from .flexicubes import FlexiCubes # replace later
|
14 |
-
from .dmtet import sdf_reg_loss_batch
|
15 |
-
import torch.nn.functional as F
|
16 |
-
|
17 |
-
def get_center_boundary_index(grid_res, device):
|
18 |
-
v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
|
19 |
-
v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
|
20 |
-
center_indices = torch.nonzero(v.reshape(-1))
|
21 |
-
|
22 |
-
v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
|
23 |
-
v[:2, ...] = True
|
24 |
-
v[-2:, ...] = True
|
25 |
-
v[:, :2, ...] = True
|
26 |
-
v[:, -2:, ...] = True
|
27 |
-
v[:, :, :2] = True
|
28 |
-
v[:, :, -2:] = True
|
29 |
-
boundary_indices = torch.nonzero(v.reshape(-1))
|
30 |
-
return center_indices, boundary_indices
|
31 |
-
|
32 |
-
###############################################################################
|
33 |
-
# Geometry interface
|
34 |
-
###############################################################################
|
35 |
-
class FlexiCubesGeometry(Geometry):
|
36 |
-
def __init__(
|
37 |
-
self, grid_res=64, scale=2.0, device='cuda', renderer=None,
|
38 |
-
render_type='neural_render', args=None):
|
39 |
-
super(FlexiCubesGeometry, self).__init__()
|
40 |
-
self.grid_res = grid_res
|
41 |
-
self.device = device
|
42 |
-
self.args = args
|
43 |
-
self.fc = FlexiCubes(device, weight_scale=0.5)
|
44 |
-
self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
|
45 |
-
if isinstance(scale, list):
|
46 |
-
self.verts[:, 0] = self.verts[:, 0] * scale[0]
|
47 |
-
self.verts[:, 1] = self.verts[:, 1] * scale[1]
|
48 |
-
self.verts[:, 2] = self.verts[:, 2] * scale[1]
|
49 |
-
else:
|
50 |
-
self.verts = self.verts * scale
|
51 |
-
|
52 |
-
all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
|
53 |
-
self.all_edges = torch.unique(all_edges, dim=0)
|
54 |
-
|
55 |
-
# Parameters used for fix boundary sdf
|
56 |
-
self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
|
57 |
-
self.renderer = renderer
|
58 |
-
self.render_type = render_type
|
59 |
-
|
60 |
-
def getAABB(self):
|
61 |
-
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
62 |
-
|
63 |
-
def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
|
64 |
-
if indices is None:
|
65 |
-
indices = self.indices
|
66 |
-
|
67 |
-
verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
|
68 |
-
beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
|
69 |
-
gamma_f=weight_n[:, 20], training=is_training
|
70 |
-
)
|
71 |
-
return verts, faces, v_reg_loss
|
72 |
-
|
73 |
-
|
74 |
-
def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
|
75 |
-
return_value = dict()
|
76 |
-
if self.render_type == 'neural_render':
|
77 |
-
tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
|
78 |
-
mesh_v_nx3.unsqueeze(dim=0),
|
79 |
-
mesh_f_fx3.int(),
|
80 |
-
camera_mv_bx4x4,
|
81 |
-
mesh_v_nx3.unsqueeze(dim=0),
|
82 |
-
resolution=resolution,
|
83 |
-
device=self.device,
|
84 |
-
hierarchical_mask=hierarchical_mask
|
85 |
-
)
|
86 |
-
|
87 |
-
return_value['tex_pos'] = tex_pos
|
88 |
-
return_value['mask'] = mask
|
89 |
-
return_value['hard_mask'] = hard_mask
|
90 |
-
return_value['rast'] = rast
|
91 |
-
return_value['v_pos_clip'] = v_pos_clip
|
92 |
-
return_value['mask_pyramid'] = mask_pyramid
|
93 |
-
return_value['depth'] = depth
|
94 |
-
return_value['normal'] = normal
|
95 |
-
else:
|
96 |
-
raise NotImplementedError
|
97 |
-
|
98 |
-
return return_value
|
99 |
-
|
100 |
-
def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
|
101 |
-
# Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
|
102 |
-
v_list = []
|
103 |
-
f_list = []
|
104 |
-
n_batch = v_deformed_bxnx3.shape[0]
|
105 |
-
all_render_output = []
|
106 |
-
for i_batch in range(n_batch):
|
107 |
-
verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
|
108 |
-
v_list.append(verts_nx3)
|
109 |
-
f_list.append(faces_fx3)
|
110 |
-
render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
|
111 |
-
all_render_output.append(render_output)
|
112 |
-
|
113 |
-
# Concatenate all render output
|
114 |
-
return_keys = all_render_output[0].keys()
|
115 |
-
return_value = dict()
|
116 |
-
for k in return_keys:
|
117 |
-
value = [v[k] for v in all_render_output]
|
118 |
-
return_value[k] = value
|
119 |
-
# We can do concatenation outside of the render
|
120 |
-
return return_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/geometry/rep_3d/tables.py
DELETED
@@ -1,791 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
dmc_table = [
|
9 |
-
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
10 |
-
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
11 |
-
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
12 |
-
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
13 |
-
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
14 |
-
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
15 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
16 |
-
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
17 |
-
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
18 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
19 |
-
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
20 |
-
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
21 |
-
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
22 |
-
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
23 |
-
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
24 |
-
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
25 |
-
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
26 |
-
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
27 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
28 |
-
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
29 |
-
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
30 |
-
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
31 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
32 |
-
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
33 |
-
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
34 |
-
[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
35 |
-
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
36 |
-
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
37 |
-
[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
38 |
-
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
39 |
-
[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
40 |
-
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
41 |
-
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
42 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
43 |
-
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
44 |
-
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
45 |
-
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
46 |
-
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
47 |
-
[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
48 |
-
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
49 |
-
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
50 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
51 |
-
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
52 |
-
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
53 |
-
[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
54 |
-
[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
55 |
-
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
56 |
-
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
57 |
-
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
58 |
-
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
59 |
-
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
60 |
-
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
61 |
-
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
62 |
-
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
63 |
-
[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
64 |
-
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
65 |
-
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
66 |
-
[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
67 |
-
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
68 |
-
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
69 |
-
[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
70 |
-
[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
71 |
-
[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
72 |
-
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
73 |
-
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
74 |
-
[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
75 |
-
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
76 |
-
[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
77 |
-
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
78 |
-
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
79 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
80 |
-
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
81 |
-
[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
82 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
83 |
-
[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
84 |
-
[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
85 |
-
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
86 |
-
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
87 |
-
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
88 |
-
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
89 |
-
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
90 |
-
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
91 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
92 |
-
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
93 |
-
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
94 |
-
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
95 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
96 |
-
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
97 |
-
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
98 |
-
[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
99 |
-
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
100 |
-
[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
101 |
-
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
102 |
-
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
103 |
-
[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
104 |
-
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
105 |
-
[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
106 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
107 |
-
[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
108 |
-
[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
109 |
-
[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
110 |
-
[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
111 |
-
[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
112 |
-
[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
113 |
-
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
114 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
|
115 |
-
[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
116 |
-
[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
117 |
-
[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
118 |
-
[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
119 |
-
[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
120 |
-
[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
121 |
-
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
122 |
-
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
123 |
-
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
124 |
-
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
125 |
-
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
126 |
-
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
127 |
-
[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
128 |
-
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
129 |
-
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
130 |
-
[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
131 |
-
[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
132 |
-
[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
133 |
-
[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
134 |
-
[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
135 |
-
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
136 |
-
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
137 |
-
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
138 |
-
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
139 |
-
[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
140 |
-
[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
141 |
-
[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
142 |
-
[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
143 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
144 |
-
[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
145 |
-
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
146 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
147 |
-
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
148 |
-
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
149 |
-
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
150 |
-
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
151 |
-
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
152 |
-
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
153 |
-
[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
154 |
-
[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
155 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
156 |
-
[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
157 |
-
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
158 |
-
[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
159 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
|
160 |
-
[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
161 |
-
[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
162 |
-
[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
163 |
-
[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
164 |
-
[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
165 |
-
[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
166 |
-
[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
167 |
-
[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
168 |
-
[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
169 |
-
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
170 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
171 |
-
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
172 |
-
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
173 |
-
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
174 |
-
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
175 |
-
[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
176 |
-
[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
177 |
-
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
178 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
179 |
-
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
180 |
-
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
181 |
-
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
182 |
-
[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
183 |
-
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
184 |
-
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
185 |
-
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
186 |
-
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
187 |
-
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
188 |
-
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
189 |
-
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
190 |
-
[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
191 |
-
[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
192 |
-
[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
193 |
-
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
194 |
-
[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
195 |
-
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
196 |
-
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
197 |
-
[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
198 |
-
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
199 |
-
[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
200 |
-
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
201 |
-
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
202 |
-
[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
203 |
-
[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
204 |
-
[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
205 |
-
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
206 |
-
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
207 |
-
[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
208 |
-
[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
209 |
-
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
210 |
-
[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
211 |
-
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
212 |
-
[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
213 |
-
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
214 |
-
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
215 |
-
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
216 |
-
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
217 |
-
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
218 |
-
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
219 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
220 |
-
[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
221 |
-
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
222 |
-
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
223 |
-
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
224 |
-
[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
225 |
-
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
226 |
-
[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
227 |
-
[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
228 |
-
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
229 |
-
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
230 |
-
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
231 |
-
[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
232 |
-
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
233 |
-
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
234 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
235 |
-
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
236 |
-
[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
237 |
-
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
238 |
-
[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
239 |
-
[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
240 |
-
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
241 |
-
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
242 |
-
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
243 |
-
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
244 |
-
[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
245 |
-
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
246 |
-
[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
247 |
-
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
248 |
-
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
249 |
-
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
250 |
-
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
251 |
-
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
252 |
-
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
253 |
-
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
254 |
-
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
255 |
-
[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
256 |
-
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
257 |
-
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
258 |
-
[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
259 |
-
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
260 |
-
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
261 |
-
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
262 |
-
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
263 |
-
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
264 |
-
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
|
265 |
-
]
|
266 |
-
num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
|
267 |
-
2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
|
268 |
-
1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
|
269 |
-
1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
|
270 |
-
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
|
271 |
-
3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
|
272 |
-
2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
|
273 |
-
1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
|
274 |
-
1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
|
275 |
-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
|
276 |
-
check_table = [
|
277 |
-
[0, 0, 0, 0, 0],
|
278 |
-
[0, 0, 0, 0, 0],
|
279 |
-
[0, 0, 0, 0, 0],
|
280 |
-
[0, 0, 0, 0, 0],
|
281 |
-
[0, 0, 0, 0, 0],
|
282 |
-
[0, 0, 0, 0, 0],
|
283 |
-
[0, 0, 0, 0, 0],
|
284 |
-
[0, 0, 0, 0, 0],
|
285 |
-
[0, 0, 0, 0, 0],
|
286 |
-
[0, 0, 0, 0, 0],
|
287 |
-
[0, 0, 0, 0, 0],
|
288 |
-
[0, 0, 0, 0, 0],
|
289 |
-
[0, 0, 0, 0, 0],
|
290 |
-
[0, 0, 0, 0, 0],
|
291 |
-
[0, 0, 0, 0, 0],
|
292 |
-
[0, 0, 0, 0, 0],
|
293 |
-
[0, 0, 0, 0, 0],
|
294 |
-
[0, 0, 0, 0, 0],
|
295 |
-
[0, 0, 0, 0, 0],
|
296 |
-
[0, 0, 0, 0, 0],
|
297 |
-
[0, 0, 0, 0, 0],
|
298 |
-
[0, 0, 0, 0, 0],
|
299 |
-
[0, 0, 0, 0, 0],
|
300 |
-
[0, 0, 0, 0, 0],
|
301 |
-
[0, 0, 0, 0, 0],
|
302 |
-
[0, 0, 0, 0, 0],
|
303 |
-
[0, 0, 0, 0, 0],
|
304 |
-
[0, 0, 0, 0, 0],
|
305 |
-
[0, 0, 0, 0, 0],
|
306 |
-
[0, 0, 0, 0, 0],
|
307 |
-
[0, 0, 0, 0, 0],
|
308 |
-
[0, 0, 0, 0, 0],
|
309 |
-
[0, 0, 0, 0, 0],
|
310 |
-
[0, 0, 0, 0, 0],
|
311 |
-
[0, 0, 0, 0, 0],
|
312 |
-
[0, 0, 0, 0, 0],
|
313 |
-
[0, 0, 0, 0, 0],
|
314 |
-
[0, 0, 0, 0, 0],
|
315 |
-
[0, 0, 0, 0, 0],
|
316 |
-
[0, 0, 0, 0, 0],
|
317 |
-
[0, 0, 0, 0, 0],
|
318 |
-
[0, 0, 0, 0, 0],
|
319 |
-
[0, 0, 0, 0, 0],
|
320 |
-
[0, 0, 0, 0, 0],
|
321 |
-
[0, 0, 0, 0, 0],
|
322 |
-
[0, 0, 0, 0, 0],
|
323 |
-
[0, 0, 0, 0, 0],
|
324 |
-
[0, 0, 0, 0, 0],
|
325 |
-
[0, 0, 0, 0, 0],
|
326 |
-
[0, 0, 0, 0, 0],
|
327 |
-
[0, 0, 0, 0, 0],
|
328 |
-
[0, 0, 0, 0, 0],
|
329 |
-
[0, 0, 0, 0, 0],
|
330 |
-
[0, 0, 0, 0, 0],
|
331 |
-
[0, 0, 0, 0, 0],
|
332 |
-
[0, 0, 0, 0, 0],
|
333 |
-
[0, 0, 0, 0, 0],
|
334 |
-
[0, 0, 0, 0, 0],
|
335 |
-
[0, 0, 0, 0, 0],
|
336 |
-
[0, 0, 0, 0, 0],
|
337 |
-
[0, 0, 0, 0, 0],
|
338 |
-
[1, 1, 0, 0, 194],
|
339 |
-
[1, -1, 0, 0, 193],
|
340 |
-
[0, 0, 0, 0, 0],
|
341 |
-
[0, 0, 0, 0, 0],
|
342 |
-
[0, 0, 0, 0, 0],
|
343 |
-
[0, 0, 0, 0, 0],
|
344 |
-
[0, 0, 0, 0, 0],
|
345 |
-
[0, 0, 0, 0, 0],
|
346 |
-
[0, 0, 0, 0, 0],
|
347 |
-
[0, 0, 0, 0, 0],
|
348 |
-
[0, 0, 0, 0, 0],
|
349 |
-
[0, 0, 0, 0, 0],
|
350 |
-
[0, 0, 0, 0, 0],
|
351 |
-
[0, 0, 0, 0, 0],
|
352 |
-
[0, 0, 0, 0, 0],
|
353 |
-
[0, 0, 0, 0, 0],
|
354 |
-
[0, 0, 0, 0, 0],
|
355 |
-
[0, 0, 0, 0, 0],
|
356 |
-
[0, 0, 0, 0, 0],
|
357 |
-
[0, 0, 0, 0, 0],
|
358 |
-
[0, 0, 0, 0, 0],
|
359 |
-
[0, 0, 0, 0, 0],
|
360 |
-
[0, 0, 0, 0, 0],
|
361 |
-
[0, 0, 0, 0, 0],
|
362 |
-
[0, 0, 0, 0, 0],
|
363 |
-
[0, 0, 0, 0, 0],
|
364 |
-
[0, 0, 0, 0, 0],
|
365 |
-
[0, 0, 0, 0, 0],
|
366 |
-
[0, 0, 0, 0, 0],
|
367 |
-
[0, 0, 0, 0, 0],
|
368 |
-
[1, 0, 1, 0, 164],
|
369 |
-
[0, 0, 0, 0, 0],
|
370 |
-
[0, 0, 0, 0, 0],
|
371 |
-
[1, 0, -1, 0, 161],
|
372 |
-
[0, 0, 0, 0, 0],
|
373 |
-
[0, 0, 0, 0, 0],
|
374 |
-
[0, 0, 0, 0, 0],
|
375 |
-
[0, 0, 0, 0, 0],
|
376 |
-
[0, 0, 0, 0, 0],
|
377 |
-
[0, 0, 0, 0, 0],
|
378 |
-
[0, 0, 0, 0, 0],
|
379 |
-
[0, 0, 0, 0, 0],
|
380 |
-
[1, 0, 0, 1, 152],
|
381 |
-
[0, 0, 0, 0, 0],
|
382 |
-
[0, 0, 0, 0, 0],
|
383 |
-
[0, 0, 0, 0, 0],
|
384 |
-
[0, 0, 0, 0, 0],
|
385 |
-
[0, 0, 0, 0, 0],
|
386 |
-
[0, 0, 0, 0, 0],
|
387 |
-
[1, 0, 0, 1, 145],
|
388 |
-
[1, 0, 0, 1, 144],
|
389 |
-
[0, 0, 0, 0, 0],
|
390 |
-
[0, 0, 0, 0, 0],
|
391 |
-
[0, 0, 0, 0, 0],
|
392 |
-
[0, 0, 0, 0, 0],
|
393 |
-
[0, 0, 0, 0, 0],
|
394 |
-
[0, 0, 0, 0, 0],
|
395 |
-
[1, 0, 0, -1, 137],
|
396 |
-
[0, 0, 0, 0, 0],
|
397 |
-
[0, 0, 0, 0, 0],
|
398 |
-
[0, 0, 0, 0, 0],
|
399 |
-
[1, 0, 1, 0, 133],
|
400 |
-
[1, 0, 1, 0, 132],
|
401 |
-
[1, 1, 0, 0, 131],
|
402 |
-
[1, 1, 0, 0, 130],
|
403 |
-
[0, 0, 0, 0, 0],
|
404 |
-
[0, 0, 0, 0, 0],
|
405 |
-
[0, 0, 0, 0, 0],
|
406 |
-
[0, 0, 0, 0, 0],
|
407 |
-
[0, 0, 0, 0, 0],
|
408 |
-
[0, 0, 0, 0, 0],
|
409 |
-
[0, 0, 0, 0, 0],
|
410 |
-
[0, 0, 0, 0, 0],
|
411 |
-
[0, 0, 0, 0, 0],
|
412 |
-
[0, 0, 0, 0, 0],
|
413 |
-
[0, 0, 0, 0, 0],
|
414 |
-
[0, 0, 0, 0, 0],
|
415 |
-
[0, 0, 0, 0, 0],
|
416 |
-
[0, 0, 0, 0, 0],
|
417 |
-
[0, 0, 0, 0, 0],
|
418 |
-
[0, 0, 0, 0, 0],
|
419 |
-
[0, 0, 0, 0, 0],
|
420 |
-
[0, 0, 0, 0, 0],
|
421 |
-
[0, 0, 0, 0, 0],
|
422 |
-
[0, 0, 0, 0, 0],
|
423 |
-
[0, 0, 0, 0, 0],
|
424 |
-
[0, 0, 0, 0, 0],
|
425 |
-
[0, 0, 0, 0, 0],
|
426 |
-
[0, 0, 0, 0, 0],
|
427 |
-
[0, 0, 0, 0, 0],
|
428 |
-
[0, 0, 0, 0, 0],
|
429 |
-
[0, 0, 0, 0, 0],
|
430 |
-
[0, 0, 0, 0, 0],
|
431 |
-
[0, 0, 0, 0, 0],
|
432 |
-
[1, 0, 0, 1, 100],
|
433 |
-
[0, 0, 0, 0, 0],
|
434 |
-
[1, 0, 0, 1, 98],
|
435 |
-
[0, 0, 0, 0, 0],
|
436 |
-
[1, 0, 0, 1, 96],
|
437 |
-
[0, 0, 0, 0, 0],
|
438 |
-
[0, 0, 0, 0, 0],
|
439 |
-
[0, 0, 0, 0, 0],
|
440 |
-
[0, 0, 0, 0, 0],
|
441 |
-
[0, 0, 0, 0, 0],
|
442 |
-
[0, 0, 0, 0, 0],
|
443 |
-
[0, 0, 0, 0, 0],
|
444 |
-
[1, 0, 1, 0, 88],
|
445 |
-
[0, 0, 0, 0, 0],
|
446 |
-
[0, 0, 0, 0, 0],
|
447 |
-
[0, 0, 0, 0, 0],
|
448 |
-
[0, 0, 0, 0, 0],
|
449 |
-
[0, 0, 0, 0, 0],
|
450 |
-
[1, 0, -1, 0, 82],
|
451 |
-
[0, 0, 0, 0, 0],
|
452 |
-
[0, 0, 0, 0, 0],
|
453 |
-
[0, 0, 0, 0, 0],
|
454 |
-
[0, 0, 0, 0, 0],
|
455 |
-
[0, 0, 0, 0, 0],
|
456 |
-
[0, 0, 0, 0, 0],
|
457 |
-
[0, 0, 0, 0, 0],
|
458 |
-
[1, 0, 1, 0, 74],
|
459 |
-
[0, 0, 0, 0, 0],
|
460 |
-
[1, 0, 1, 0, 72],
|
461 |
-
[0, 0, 0, 0, 0],
|
462 |
-
[1, 0, 0, -1, 70],
|
463 |
-
[0, 0, 0, 0, 0],
|
464 |
-
[0, 0, 0, 0, 0],
|
465 |
-
[1, -1, 0, 0, 67],
|
466 |
-
[0, 0, 0, 0, 0],
|
467 |
-
[1, -1, 0, 0, 65],
|
468 |
-
[0, 0, 0, 0, 0],
|
469 |
-
[0, 0, 0, 0, 0],
|
470 |
-
[0, 0, 0, 0, 0],
|
471 |
-
[0, 0, 0, 0, 0],
|
472 |
-
[0, 0, 0, 0, 0],
|
473 |
-
[0, 0, 0, 0, 0],
|
474 |
-
[0, 0, 0, 0, 0],
|
475 |
-
[0, 0, 0, 0, 0],
|
476 |
-
[1, 1, 0, 0, 56],
|
477 |
-
[0, 0, 0, 0, 0],
|
478 |
-
[0, 0, 0, 0, 0],
|
479 |
-
[0, 0, 0, 0, 0],
|
480 |
-
[1, -1, 0, 0, 52],
|
481 |
-
[0, 0, 0, 0, 0],
|
482 |
-
[0, 0, 0, 0, 0],
|
483 |
-
[0, 0, 0, 0, 0],
|
484 |
-
[0, 0, 0, 0, 0],
|
485 |
-
[0, 0, 0, 0, 0],
|
486 |
-
[0, 0, 0, 0, 0],
|
487 |
-
[0, 0, 0, 0, 0],
|
488 |
-
[1, 1, 0, 0, 44],
|
489 |
-
[0, 0, 0, 0, 0],
|
490 |
-
[0, 0, 0, 0, 0],
|
491 |
-
[0, 0, 0, 0, 0],
|
492 |
-
[1, 1, 0, 0, 40],
|
493 |
-
[0, 0, 0, 0, 0],
|
494 |
-
[1, 0, 0, -1, 38],
|
495 |
-
[1, 0, -1, 0, 37],
|
496 |
-
[0, 0, 0, 0, 0],
|
497 |
-
[0, 0, 0, 0, 0],
|
498 |
-
[0, 0, 0, 0, 0],
|
499 |
-
[1, 0, -1, 0, 33],
|
500 |
-
[0, 0, 0, 0, 0],
|
501 |
-
[0, 0, 0, 0, 0],
|
502 |
-
[0, 0, 0, 0, 0],
|
503 |
-
[0, 0, 0, 0, 0],
|
504 |
-
[1, -1, 0, 0, 28],
|
505 |
-
[0, 0, 0, 0, 0],
|
506 |
-
[1, 0, -1, 0, 26],
|
507 |
-
[1, 0, 0, -1, 25],
|
508 |
-
[0, 0, 0, 0, 0],
|
509 |
-
[0, 0, 0, 0, 0],
|
510 |
-
[0, 0, 0, 0, 0],
|
511 |
-
[0, 0, 0, 0, 0],
|
512 |
-
[1, -1, 0, 0, 20],
|
513 |
-
[0, 0, 0, 0, 0],
|
514 |
-
[1, 0, -1, 0, 18],
|
515 |
-
[0, 0, 0, 0, 0],
|
516 |
-
[0, 0, 0, 0, 0],
|
517 |
-
[0, 0, 0, 0, 0],
|
518 |
-
[0, 0, 0, 0, 0],
|
519 |
-
[0, 0, 0, 0, 0],
|
520 |
-
[0, 0, 0, 0, 0],
|
521 |
-
[0, 0, 0, 0, 0],
|
522 |
-
[0, 0, 0, 0, 0],
|
523 |
-
[1, 0, 0, -1, 9],
|
524 |
-
[0, 0, 0, 0, 0],
|
525 |
-
[0, 0, 0, 0, 0],
|
526 |
-
[1, 0, 0, -1, 6],
|
527 |
-
[0, 0, 0, 0, 0],
|
528 |
-
[0, 0, 0, 0, 0],
|
529 |
-
[0, 0, 0, 0, 0],
|
530 |
-
[0, 0, 0, 0, 0],
|
531 |
-
[0, 0, 0, 0, 0],
|
532 |
-
[0, 0, 0, 0, 0]
|
533 |
-
]
|
534 |
-
tet_table = [
|
535 |
-
[-1, -1, -1, -1, -1, -1],
|
536 |
-
[0, 0, 0, 0, 0, 0],
|
537 |
-
[0, 0, 0, 0, 0, 0],
|
538 |
-
[1, 1, 1, 1, 1, 1],
|
539 |
-
[4, 4, 4, 4, 4, 4],
|
540 |
-
[0, 0, 0, 0, 0, 0],
|
541 |
-
[4, 0, 0, 4, 4, -1],
|
542 |
-
[1, 1, 1, 1, 1, 1],
|
543 |
-
[4, 4, 4, 4, 4, 4],
|
544 |
-
[0, 4, 0, 4, 4, -1],
|
545 |
-
[0, 0, 0, 0, 0, 0],
|
546 |
-
[1, 1, 1, 1, 1, 1],
|
547 |
-
[5, 5, 5, 5, 5, 5],
|
548 |
-
[0, 0, 0, 0, 0, 0],
|
549 |
-
[0, 0, 0, 0, 0, 0],
|
550 |
-
[1, 1, 1, 1, 1, 1],
|
551 |
-
[2, 2, 2, 2, 2, 2],
|
552 |
-
[0, 0, 0, 0, 0, 0],
|
553 |
-
[2, 0, 2, -1, 0, 2],
|
554 |
-
[1, 1, 1, 1, 1, 1],
|
555 |
-
[2, -1, 2, 4, 4, 2],
|
556 |
-
[0, 0, 0, 0, 0, 0],
|
557 |
-
[2, 0, 2, 4, 4, 2],
|
558 |
-
[1, 1, 1, 1, 1, 1],
|
559 |
-
[2, 4, 2, 4, 4, 2],
|
560 |
-
[0, 4, 0, 4, 4, 0],
|
561 |
-
[2, 0, 2, 0, 0, 2],
|
562 |
-
[1, 1, 1, 1, 1, 1],
|
563 |
-
[2, 5, 2, 5, 5, 2],
|
564 |
-
[0, 0, 0, 0, 0, 0],
|
565 |
-
[2, 0, 2, 0, 0, 2],
|
566 |
-
[1, 1, 1, 1, 1, 1],
|
567 |
-
[1, 1, 1, 1, 1, 1],
|
568 |
-
[0, 1, 1, -1, 0, 1],
|
569 |
-
[0, 0, 0, 0, 0, 0],
|
570 |
-
[2, 2, 2, 2, 2, 2],
|
571 |
-
[4, 1, 1, 4, 4, 1],
|
572 |
-
[0, 1, 1, 0, 0, 1],
|
573 |
-
[4, 0, 0, 4, 4, 0],
|
574 |
-
[2, 2, 2, 2, 2, 2],
|
575 |
-
[-1, 1, 1, 4, 4, 1],
|
576 |
-
[0, 1, 1, 4, 4, 1],
|
577 |
-
[0, 0, 0, 0, 0, 0],
|
578 |
-
[2, 2, 2, 2, 2, 2],
|
579 |
-
[5, 1, 1, 5, 5, 1],
|
580 |
-
[0, 1, 1, 0, 0, 1],
|
581 |
-
[0, 0, 0, 0, 0, 0],
|
582 |
-
[2, 2, 2, 2, 2, 2],
|
583 |
-
[1, 1, 1, 1, 1, 1],
|
584 |
-
[0, 0, 0, 0, 0, 0],
|
585 |
-
[0, 0, 0, 0, 0, 0],
|
586 |
-
[8, 8, 8, 8, 8, 8],
|
587 |
-
[1, 1, 1, 4, 4, 1],
|
588 |
-
[0, 0, 0, 0, 0, 0],
|
589 |
-
[4, 0, 0, 4, 4, 0],
|
590 |
-
[4, 4, 4, 4, 4, 4],
|
591 |
-
[1, 1, 1, 4, 4, 1],
|
592 |
-
[0, 4, 0, 4, 4, 0],
|
593 |
-
[0, 0, 0, 0, 0, 0],
|
594 |
-
[4, 4, 4, 4, 4, 4],
|
595 |
-
[1, 1, 1, 5, 5, 1],
|
596 |
-
[0, 0, 0, 0, 0, 0],
|
597 |
-
[0, 0, 0, 0, 0, 0],
|
598 |
-
[5, 5, 5, 5, 5, 5],
|
599 |
-
[6, 6, 6, 6, 6, 6],
|
600 |
-
[6, -1, 0, 6, 0, 6],
|
601 |
-
[6, 0, 0, 6, 0, 6],
|
602 |
-
[6, 1, 1, 6, 1, 6],
|
603 |
-
[4, 4, 4, 4, 4, 4],
|
604 |
-
[0, 0, 0, 0, 0, 0],
|
605 |
-
[4, 0, 0, 4, 4, 4],
|
606 |
-
[1, 1, 1, 1, 1, 1],
|
607 |
-
[6, 4, -1, 6, 4, 6],
|
608 |
-
[6, 4, 0, 6, 4, 6],
|
609 |
-
[6, 0, 0, 6, 0, 6],
|
610 |
-
[6, 1, 1, 6, 1, 6],
|
611 |
-
[5, 5, 5, 5, 5, 5],
|
612 |
-
[0, 0, 0, 0, 0, 0],
|
613 |
-
[0, 0, 0, 0, 0, 0],
|
614 |
-
[1, 1, 1, 1, 1, 1],
|
615 |
-
[2, 2, 2, 2, 2, 2],
|
616 |
-
[0, 0, 0, 0, 0, 0],
|
617 |
-
[2, 0, 2, 2, 0, 2],
|
618 |
-
[1, 1, 1, 1, 1, 1],
|
619 |
-
[2, 2, 2, 2, 2, 2],
|
620 |
-
[0, 0, 0, 0, 0, 0],
|
621 |
-
[2, 0, 2, 2, 2, 2],
|
622 |
-
[1, 1, 1, 1, 1, 1],
|
623 |
-
[2, 4, 2, 2, 4, 2],
|
624 |
-
[0, 4, 0, 4, 4, 0],
|
625 |
-
[2, 0, 2, 2, 0, 2],
|
626 |
-
[1, 1, 1, 1, 1, 1],
|
627 |
-
[2, 2, 2, 2, 2, 2],
|
628 |
-
[0, 0, 0, 0, 0, 0],
|
629 |
-
[0, 0, 0, 0, 0, 0],
|
630 |
-
[1, 1, 1, 1, 1, 1],
|
631 |
-
[6, 1, 1, 6, -1, 6],
|
632 |
-
[6, 1, 1, 6, 0, 6],
|
633 |
-
[6, 0, 0, 6, 0, 6],
|
634 |
-
[6, 2, 2, 6, 2, 6],
|
635 |
-
[4, 1, 1, 4, 4, 1],
|
636 |
-
[0, 1, 1, 0, 0, 1],
|
637 |
-
[4, 0, 0, 4, 4, 4],
|
638 |
-
[2, 2, 2, 2, 2, 2],
|
639 |
-
[6, 1, 1, 6, 4, 6],
|
640 |
-
[6, 1, 1, 6, 4, 6],
|
641 |
-
[6, 0, 0, 6, 0, 6],
|
642 |
-
[6, 2, 2, 6, 2, 6],
|
643 |
-
[5, 1, 1, 5, 5, 1],
|
644 |
-
[0, 1, 1, 0, 0, 1],
|
645 |
-
[0, 0, 0, 0, 0, 0],
|
646 |
-
[2, 2, 2, 2, 2, 2],
|
647 |
-
[1, 1, 1, 1, 1, 1],
|
648 |
-
[0, 0, 0, 0, 0, 0],
|
649 |
-
[0, 0, 0, 0, 0, 0],
|
650 |
-
[6, 6, 6, 6, 6, 6],
|
651 |
-
[1, 1, 1, 1, 1, 1],
|
652 |
-
[0, 0, 0, 0, 0, 0],
|
653 |
-
[0, 0, 0, 0, 0, 0],
|
654 |
-
[4, 4, 4, 4, 4, 4],
|
655 |
-
[1, 1, 1, 1, 4, 1],
|
656 |
-
[0, 4, 0, 4, 4, 0],
|
657 |
-
[0, 0, 0, 0, 0, 0],
|
658 |
-
[4, 4, 4, 4, 4, 4],
|
659 |
-
[1, 1, 1, 1, 1, 1],
|
660 |
-
[0, 0, 0, 0, 0, 0],
|
661 |
-
[0, 5, 0, 5, 0, 5],
|
662 |
-
[5, 5, 5, 5, 5, 5],
|
663 |
-
[5, 5, 5, 5, 5, 5],
|
664 |
-
[0, 5, 0, 5, 0, 5],
|
665 |
-
[-1, 5, 0, 5, 0, 5],
|
666 |
-
[1, 5, 1, 5, 1, 5],
|
667 |
-
[4, 5, -1, 5, 4, 5],
|
668 |
-
[0, 5, 0, 5, 0, 5],
|
669 |
-
[4, 5, 0, 5, 4, 5],
|
670 |
-
[1, 5, 1, 5, 1, 5],
|
671 |
-
[4, 4, 4, 4, 4, 4],
|
672 |
-
[0, 4, 0, 4, 4, 4],
|
673 |
-
[0, 0, 0, 0, 0, 0],
|
674 |
-
[1, 1, 1, 1, 1, 1],
|
675 |
-
[6, 6, 6, 6, 6, 6],
|
676 |
-
[0, 0, 0, 0, 0, 0],
|
677 |
-
[0, 0, 0, 0, 0, 0],
|
678 |
-
[1, 1, 1, 1, 1, 1],
|
679 |
-
[2, 5, 2, 5, -1, 5],
|
680 |
-
[0, 5, 0, 5, 0, 5],
|
681 |
-
[2, 5, 2, 5, 0, 5],
|
682 |
-
[1, 5, 1, 5, 1, 5],
|
683 |
-
[2, 5, 2, 5, 4, 5],
|
684 |
-
[0, 5, 0, 5, 0, 5],
|
685 |
-
[2, 5, 2, 5, 4, 5],
|
686 |
-
[1, 5, 1, 5, 1, 5],
|
687 |
-
[2, 4, 2, 4, 4, 2],
|
688 |
-
[0, 4, 0, 4, 4, 4],
|
689 |
-
[2, 0, 2, 0, 0, 2],
|
690 |
-
[1, 1, 1, 1, 1, 1],
|
691 |
-
[2, 6, 2, 6, 6, 2],
|
692 |
-
[0, 0, 0, 0, 0, 0],
|
693 |
-
[2, 0, 2, 0, 0, 2],
|
694 |
-
[1, 1, 1, 1, 1, 1],
|
695 |
-
[1, 1, 1, 1, 1, 1],
|
696 |
-
[0, 1, 1, 1, 0, 1],
|
697 |
-
[0, 0, 0, 0, 0, 0],
|
698 |
-
[2, 2, 2, 2, 2, 2],
|
699 |
-
[4, 1, 1, 1, 4, 1],
|
700 |
-
[0, 1, 1, 1, 0, 1],
|
701 |
-
[4, 0, 0, 4, 4, 0],
|
702 |
-
[2, 2, 2, 2, 2, 2],
|
703 |
-
[1, 1, 1, 1, 1, 1],
|
704 |
-
[0, 1, 1, 1, 1, 1],
|
705 |
-
[0, 0, 0, 0, 0, 0],
|
706 |
-
[2, 2, 2, 2, 2, 2],
|
707 |
-
[1, 1, 1, 1, 1, 1],
|
708 |
-
[0, 0, 0, 0, 0, 0],
|
709 |
-
[0, 0, 0, 0, 0, 0],
|
710 |
-
[2, 2, 2, 2, 2, 2],
|
711 |
-
[1, 1, 1, 1, 1, 1],
|
712 |
-
[0, 0, 0, 0, 0, 0],
|
713 |
-
[0, 0, 0, 0, 0, 0],
|
714 |
-
[5, 5, 5, 5, 5, 5],
|
715 |
-
[1, 1, 1, 1, 4, 1],
|
716 |
-
[0, 0, 0, 0, 0, 0],
|
717 |
-
[4, 0, 0, 4, 4, 0],
|
718 |
-
[4, 4, 4, 4, 4, 4],
|
719 |
-
[1, 1, 1, 1, 1, 1],
|
720 |
-
[0, 0, 0, 0, 0, 0],
|
721 |
-
[0, 0, 0, 0, 0, 0],
|
722 |
-
[4, 4, 4, 4, 4, 4],
|
723 |
-
[1, 1, 1, 1, 1, 1],
|
724 |
-
[6, 0, 0, 6, 0, 6],
|
725 |
-
[0, 0, 0, 0, 0, 0],
|
726 |
-
[6, 6, 6, 6, 6, 6],
|
727 |
-
[5, 5, 5, 5, 5, 5],
|
728 |
-
[5, 5, 0, 5, 0, 5],
|
729 |
-
[5, 5, 0, 5, 0, 5],
|
730 |
-
[5, 5, 1, 5, 1, 5],
|
731 |
-
[4, 4, 4, 4, 4, 4],
|
732 |
-
[0, 0, 0, 0, 0, 0],
|
733 |
-
[4, 4, 0, 4, 4, 4],
|
734 |
-
[1, 1, 1, 1, 1, 1],
|
735 |
-
[4, 4, 4, 4, 4, 4],
|
736 |
-
[4, 4, 0, 4, 4, 4],
|
737 |
-
[0, 0, 0, 0, 0, 0],
|
738 |
-
[1, 1, 1, 1, 1, 1],
|
739 |
-
[8, 8, 8, 8, 8, 8],
|
740 |
-
[0, 0, 0, 0, 0, 0],
|
741 |
-
[0, 0, 0, 0, 0, 0],
|
742 |
-
[1, 1, 1, 1, 1, 1],
|
743 |
-
[2, 2, 2, 2, 2, 2],
|
744 |
-
[0, 0, 0, 0, 0, 0],
|
745 |
-
[2, 2, 2, 2, 0, 2],
|
746 |
-
[1, 1, 1, 1, 1, 1],
|
747 |
-
[2, 2, 2, 2, 2, 2],
|
748 |
-
[0, 0, 0, 0, 0, 0],
|
749 |
-
[2, 2, 2, 2, 2, 2],
|
750 |
-
[1, 1, 1, 1, 1, 1],
|
751 |
-
[2, 2, 2, 2, 2, 2],
|
752 |
-
[0, 0, 0, 0, 0, 0],
|
753 |
-
[0, 0, 0, 0, 0, 0],
|
754 |
-
[4, 1, 1, 4, 4, 1],
|
755 |
-
[2, 2, 2, 2, 2, 2],
|
756 |
-
[0, 0, 0, 0, 0, 0],
|
757 |
-
[0, 0, 0, 0, 0, 0],
|
758 |
-
[1, 1, 1, 1, 1, 1],
|
759 |
-
[1, 1, 1, 1, 1, 1],
|
760 |
-
[1, 1, 1, 1, 0, 1],
|
761 |
-
[0, 0, 0, 0, 0, 0],
|
762 |
-
[2, 2, 2, 2, 2, 2],
|
763 |
-
[1, 1, 1, 1, 1, 1],
|
764 |
-
[0, 0, 0, 0, 0, 0],
|
765 |
-
[0, 0, 0, 0, 0, 0],
|
766 |
-
[2, 4, 2, 4, 4, 2],
|
767 |
-
[1, 1, 1, 1, 1, 1],
|
768 |
-
[1, 1, 1, 1, 1, 1],
|
769 |
-
[0, 0, 0, 0, 0, 0],
|
770 |
-
[2, 2, 2, 2, 2, 2],
|
771 |
-
[1, 1, 1, 1, 1, 1],
|
772 |
-
[0, 0, 0, 0, 0, 0],
|
773 |
-
[0, 0, 0, 0, 0, 0],
|
774 |
-
[2, 2, 2, 2, 2, 2],
|
775 |
-
[1, 1, 1, 1, 1, 1],
|
776 |
-
[0, 0, 0, 0, 0, 0],
|
777 |
-
[0, 0, 0, 0, 0, 0],
|
778 |
-
[5, 5, 5, 5, 5, 5],
|
779 |
-
[1, 1, 1, 1, 1, 1],
|
780 |
-
[0, 0, 0, 0, 0, 0],
|
781 |
-
[0, 0, 0, 0, 0, 0],
|
782 |
-
[4, 4, 4, 4, 4, 4],
|
783 |
-
[1, 1, 1, 1, 1, 1],
|
784 |
-
[0, 0, 0, 0, 0, 0],
|
785 |
-
[0, 0, 0, 0, 0, 0],
|
786 |
-
[4, 4, 4, 4, 4, 4],
|
787 |
-
[1, 1, 1, 1, 1, 1],
|
788 |
-
[0, 0, 0, 0, 0, 0],
|
789 |
-
[0, 0, 0, 0, 0, 0],
|
790 |
-
[12, 12, 12, 12, 12, 12]
|
791 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/lrm.py
DELETED
@@ -1,196 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Zexin He
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
import torch.nn as nn
|
18 |
-
import mcubes
|
19 |
-
import nvdiffrast.torch as dr
|
20 |
-
from einops import rearrange, repeat
|
21 |
-
|
22 |
-
from .encoder.dino_wrapper import DinoWrapper
|
23 |
-
from .decoder.transformer import TriplaneTransformer
|
24 |
-
from .renderer.synthesizer import TriplaneSynthesizer
|
25 |
-
from ..utils.mesh_util import xatlas_uvmap
|
26 |
-
|
27 |
-
|
28 |
-
class InstantNeRF(nn.Module):
|
29 |
-
"""
|
30 |
-
Full model of the large reconstruction model.
|
31 |
-
"""
|
32 |
-
def __init__(
|
33 |
-
self,
|
34 |
-
encoder_freeze: bool = False,
|
35 |
-
encoder_model_name: str = 'facebook/dino-vitb16',
|
36 |
-
encoder_feat_dim: int = 768,
|
37 |
-
transformer_dim: int = 1024,
|
38 |
-
transformer_layers: int = 16,
|
39 |
-
transformer_heads: int = 16,
|
40 |
-
triplane_low_res: int = 32,
|
41 |
-
triplane_high_res: int = 64,
|
42 |
-
triplane_dim: int = 80,
|
43 |
-
rendering_samples_per_ray: int = 128,
|
44 |
-
):
|
45 |
-
super().__init__()
|
46 |
-
|
47 |
-
# modules
|
48 |
-
self.encoder = DinoWrapper(
|
49 |
-
model_name=encoder_model_name,
|
50 |
-
freeze=encoder_freeze,
|
51 |
-
)
|
52 |
-
|
53 |
-
self.transformer = TriplaneTransformer(
|
54 |
-
inner_dim=transformer_dim,
|
55 |
-
num_layers=transformer_layers,
|
56 |
-
num_heads=transformer_heads,
|
57 |
-
image_feat_dim=encoder_feat_dim,
|
58 |
-
triplane_low_res=triplane_low_res,
|
59 |
-
triplane_high_res=triplane_high_res,
|
60 |
-
triplane_dim=triplane_dim,
|
61 |
-
)
|
62 |
-
|
63 |
-
self.synthesizer = TriplaneSynthesizer(
|
64 |
-
triplane_dim=triplane_dim,
|
65 |
-
samples_per_ray=rendering_samples_per_ray,
|
66 |
-
)
|
67 |
-
|
68 |
-
def forward_planes(self, images, cameras):
|
69 |
-
# images: [B, V, C_img, H_img, W_img]
|
70 |
-
# cameras: [B, V, 16]
|
71 |
-
B = images.shape[0]
|
72 |
-
|
73 |
-
# encode images
|
74 |
-
image_feats = self.encoder(images, cameras)
|
75 |
-
image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
|
76 |
-
|
77 |
-
# transformer generating planes
|
78 |
-
planes = self.transformer(image_feats)
|
79 |
-
|
80 |
-
return planes
|
81 |
-
|
82 |
-
def forward(self, images, cameras, render_cameras, render_size: int):
|
83 |
-
# images: [B, V, C_img, H_img, W_img]
|
84 |
-
# cameras: [B, V, 16]
|
85 |
-
# render_cameras: [B, M, D_cam_render]
|
86 |
-
# render_size: int
|
87 |
-
B, M = render_cameras.shape[:2]
|
88 |
-
|
89 |
-
planes = self.forward_planes(images, cameras)
|
90 |
-
|
91 |
-
# render target views
|
92 |
-
render_results = self.synthesizer(planes, render_cameras, render_size)
|
93 |
-
|
94 |
-
return {
|
95 |
-
'planes': planes,
|
96 |
-
**render_results,
|
97 |
-
}
|
98 |
-
|
99 |
-
def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
|
100 |
-
'''
|
101 |
-
Predict Texture given triplanes
|
102 |
-
:param planes: the triplane feature map
|
103 |
-
:param tex_pos: Position we want to query the texture field
|
104 |
-
:param hard_mask: 2D silhoueete of the rendered image
|
105 |
-
'''
|
106 |
-
tex_pos = torch.cat(tex_pos, dim=0)
|
107 |
-
if not hard_mask is None:
|
108 |
-
tex_pos = tex_pos * hard_mask.float()
|
109 |
-
batch_size = tex_pos.shape[0]
|
110 |
-
tex_pos = tex_pos.reshape(batch_size, -1, 3)
|
111 |
-
###################
|
112 |
-
# We use mask to get the texture location (to save the memory)
|
113 |
-
if hard_mask is not None:
|
114 |
-
n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
|
115 |
-
sample_tex_pose_list = []
|
116 |
-
max_point = n_point_list.max()
|
117 |
-
expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
|
118 |
-
for i in range(tex_pos.shape[0]):
|
119 |
-
tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
|
120 |
-
if tex_pos_one_shape.shape[1] < max_point:
|
121 |
-
tex_pos_one_shape = torch.cat(
|
122 |
-
[tex_pos_one_shape, torch.zeros(
|
123 |
-
1, max_point - tex_pos_one_shape.shape[1], 3,
|
124 |
-
device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
|
125 |
-
sample_tex_pose_list.append(tex_pos_one_shape)
|
126 |
-
tex_pos = torch.cat(sample_tex_pose_list, dim=0)
|
127 |
-
|
128 |
-
tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb']
|
129 |
-
|
130 |
-
if hard_mask is not None:
|
131 |
-
final_tex_feat = torch.zeros(
|
132 |
-
planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
|
133 |
-
expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
|
134 |
-
for i in range(planes.shape[0]):
|
135 |
-
final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
|
136 |
-
tex_feat = final_tex_feat
|
137 |
-
|
138 |
-
return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
|
139 |
-
|
140 |
-
def extract_mesh(
|
141 |
-
self,
|
142 |
-
planes: torch.Tensor,
|
143 |
-
mesh_resolution: int = 256,
|
144 |
-
mesh_threshold: int = 10.0,
|
145 |
-
use_texture_map: bool = False,
|
146 |
-
texture_resolution: int = 1024,
|
147 |
-
**kwargs,
|
148 |
-
):
|
149 |
-
'''
|
150 |
-
Extract a 3D mesh from triplane nerf. Only support batch_size 1.
|
151 |
-
:param planes: triplane features
|
152 |
-
:param mesh_resolution: marching cubes resolution
|
153 |
-
:param mesh_threshold: iso-surface threshold
|
154 |
-
:param use_texture_map: use texture map or vertex color
|
155 |
-
:param texture_resolution: the resolution of texture map
|
156 |
-
'''
|
157 |
-
assert planes.shape[0] == 1
|
158 |
-
device = planes.device
|
159 |
-
|
160 |
-
grid_out = self.synthesizer.forward_grid(
|
161 |
-
planes=planes,
|
162 |
-
grid_size=mesh_resolution,
|
163 |
-
)
|
164 |
-
|
165 |
-
vertices, faces = mcubes.marching_cubes(
|
166 |
-
grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
|
167 |
-
mesh_threshold,
|
168 |
-
)
|
169 |
-
vertices = vertices / (mesh_resolution - 1) * 2 - 1
|
170 |
-
|
171 |
-
if not use_texture_map:
|
172 |
-
# query vertex colors
|
173 |
-
vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
|
174 |
-
vertices_colors = self.synthesizer.forward_points(
|
175 |
-
planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy()
|
176 |
-
vertices_colors = (vertices_colors * 255).astype(np.uint8)
|
177 |
-
|
178 |
-
return vertices, faces, vertices_colors
|
179 |
-
|
180 |
-
# use x-atlas to get uv mapping for the mesh
|
181 |
-
vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
|
182 |
-
faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)
|
183 |
-
|
184 |
-
ctx = dr.RasterizeCudaContext(device=device)
|
185 |
-
uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
|
186 |
-
ctx, vertices, faces, resolution=texture_resolution)
|
187 |
-
tex_hard_mask = tex_hard_mask.float()
|
188 |
-
|
189 |
-
# query the texture field to get the RGB color for texture map
|
190 |
-
tex_feat = self.get_texture_prediction(
|
191 |
-
planes, [gb_pos], tex_hard_mask)
|
192 |
-
background_feature = torch.zeros_like(tex_feat)
|
193 |
-
img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
|
194 |
-
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
|
195 |
-
|
196 |
-
return vertices, faces, uvs, mesh_tex_idx, texture_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/lrm_mesh.py
DELETED
@@ -1,385 +0,0 @@
|
|
1 |
-
# Copyright (c) 2023, Tencent Inc
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
import torch.nn as nn
|
18 |
-
import nvdiffrast.torch as dr
|
19 |
-
from einops import rearrange, repeat
|
20 |
-
|
21 |
-
from .encoder.dino_wrapper import DinoWrapper
|
22 |
-
from .decoder.transformer import TriplaneTransformer
|
23 |
-
from .renderer.synthesizer_mesh import TriplaneSynthesizer
|
24 |
-
from .geometry.camera.perspective_camera import PerspectiveCamera
|
25 |
-
from .geometry.render.neural_render import NeuralRender
|
26 |
-
from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
|
27 |
-
from ..utils.mesh_util import xatlas_uvmap
|
28 |
-
|
29 |
-
|
30 |
-
class InstantMesh(nn.Module):
|
31 |
-
"""
|
32 |
-
Full model of the large reconstruction model.
|
33 |
-
"""
|
34 |
-
def __init__(
|
35 |
-
self,
|
36 |
-
encoder_freeze: bool = False,
|
37 |
-
encoder_model_name: str = 'facebook/dino-vitb16',
|
38 |
-
encoder_feat_dim: int = 768,
|
39 |
-
transformer_dim: int = 1024,
|
40 |
-
transformer_layers: int = 16,
|
41 |
-
transformer_heads: int = 16,
|
42 |
-
triplane_low_res: int = 32,
|
43 |
-
triplane_high_res: int = 64,
|
44 |
-
triplane_dim: int = 80,
|
45 |
-
rendering_samples_per_ray: int = 128,
|
46 |
-
grid_res: int = 128,
|
47 |
-
grid_scale: float = 2.0,
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
|
51 |
-
# attributes
|
52 |
-
self.grid_res = grid_res
|
53 |
-
self.grid_scale = grid_scale
|
54 |
-
self.deformation_multiplier = 4.0
|
55 |
-
|
56 |
-
# modules
|
57 |
-
self.encoder = DinoWrapper(
|
58 |
-
model_name=encoder_model_name,
|
59 |
-
freeze=encoder_freeze,
|
60 |
-
)
|
61 |
-
|
62 |
-
self.transformer = TriplaneTransformer(
|
63 |
-
inner_dim=transformer_dim,
|
64 |
-
num_layers=transformer_layers,
|
65 |
-
num_heads=transformer_heads,
|
66 |
-
image_feat_dim=encoder_feat_dim,
|
67 |
-
triplane_low_res=triplane_low_res,
|
68 |
-
triplane_high_res=triplane_high_res,
|
69 |
-
triplane_dim=triplane_dim,
|
70 |
-
)
|
71 |
-
|
72 |
-
self.synthesizer = TriplaneSynthesizer(
|
73 |
-
triplane_dim=triplane_dim,
|
74 |
-
samples_per_ray=rendering_samples_per_ray,
|
75 |
-
)
|
76 |
-
|
77 |
-
def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True):
|
78 |
-
camera = PerspectiveCamera(fovy=fovy, device=device)
|
79 |
-
if use_renderer:
|
80 |
-
renderer = NeuralRender(device, camera_model=camera)
|
81 |
-
else:
|
82 |
-
renderer = None
|
83 |
-
self.geometry = FlexiCubesGeometry(
|
84 |
-
grid_res=self.grid_res,
|
85 |
-
scale=self.grid_scale,
|
86 |
-
renderer=renderer,
|
87 |
-
render_type='neural_render',
|
88 |
-
device=device,
|
89 |
-
)
|
90 |
-
|
91 |
-
def forward_planes(self, images, cameras):
|
92 |
-
# images: [B, V, C_img, H_img, W_img]
|
93 |
-
# cameras: [B, V, 16]
|
94 |
-
B = images.shape[0]
|
95 |
-
|
96 |
-
# encode images
|
97 |
-
image_feats = self.encoder(images, cameras)
|
98 |
-
image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
|
99 |
-
|
100 |
-
# decode triplanes
|
101 |
-
planes = self.transformer(image_feats)
|
102 |
-
|
103 |
-
return planes
|
104 |
-
|
105 |
-
def get_sdf_deformation_prediction(self, planes):
|
106 |
-
'''
|
107 |
-
Predict SDF and deformation for tetrahedron vertices
|
108 |
-
:param planes: triplane feature map for the geometry
|
109 |
-
'''
|
110 |
-
init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1)
|
111 |
-
|
112 |
-
# Step 1: predict the SDF and deformation
|
113 |
-
sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
|
114 |
-
self.synthesizer.get_geometry_prediction,
|
115 |
-
planes,
|
116 |
-
init_position,
|
117 |
-
self.geometry.indices,
|
118 |
-
use_reentrant=False,
|
119 |
-
)
|
120 |
-
|
121 |
-
# Step 2: Normalize the deformation to avoid the flipped triangles.
|
122 |
-
deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
|
123 |
-
sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
|
124 |
-
|
125 |
-
####
|
126 |
-
# Step 3: Fix some sdf if we observe empty shape (full positive or full negative)
|
127 |
-
sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
|
128 |
-
sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
|
129 |
-
pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
|
130 |
-
neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
|
131 |
-
zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
|
132 |
-
if torch.sum(zero_surface).item() > 0:
|
133 |
-
update_sdf = torch.zeros_like(sdf[0:1])
|
134 |
-
max_sdf = sdf.max()
|
135 |
-
min_sdf = sdf.min()
|
136 |
-
update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero
|
137 |
-
update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero
|
138 |
-
new_sdf = torch.zeros_like(sdf)
|
139 |
-
for i_batch in range(zero_surface.shape[0]):
|
140 |
-
if zero_surface[i_batch]:
|
141 |
-
new_sdf[i_batch:i_batch + 1] += update_sdf
|
142 |
-
update_mask = (new_sdf == 0).float()
|
143 |
-
# Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
|
144 |
-
sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
|
145 |
-
sdf_reg_loss = sdf_reg_loss * zero_surface.float()
|
146 |
-
sdf = sdf * update_mask + new_sdf * (1 - update_mask)
|
147 |
-
|
148 |
-
# Step 4: Here we remove the gradient for the bad sdf (full positive or full negative)
|
149 |
-
final_sdf = []
|
150 |
-
final_def = []
|
151 |
-
for i_batch in range(zero_surface.shape[0]):
|
152 |
-
if zero_surface[i_batch]:
|
153 |
-
final_sdf.append(sdf[i_batch: i_batch + 1].detach())
|
154 |
-
final_def.append(deformation[i_batch: i_batch + 1].detach())
|
155 |
-
else:
|
156 |
-
final_sdf.append(sdf[i_batch: i_batch + 1])
|
157 |
-
final_def.append(deformation[i_batch: i_batch + 1])
|
158 |
-
sdf = torch.cat(final_sdf, dim=0)
|
159 |
-
deformation = torch.cat(final_def, dim=0)
|
160 |
-
return sdf, deformation, sdf_reg_loss, weight
|
161 |
-
|
162 |
-
def get_geometry_prediction(self, planes=None):
|
163 |
-
'''
|
164 |
-
Function to generate mesh with give triplanes
|
165 |
-
:param planes: triplane features
|
166 |
-
'''
|
167 |
-
# Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid.
|
168 |
-
sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
|
169 |
-
v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
|
170 |
-
tets = self.geometry.indices
|
171 |
-
n_batch = planes.shape[0]
|
172 |
-
v_list = []
|
173 |
-
f_list = []
|
174 |
-
flexicubes_surface_reg_list = []
|
175 |
-
|
176 |
-
# Step 2: Using marching tet to obtain the mesh
|
177 |
-
for i_batch in range(n_batch):
|
178 |
-
verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
|
179 |
-
v_deformed[i_batch],
|
180 |
-
sdf[i_batch].squeeze(dim=-1),
|
181 |
-
with_uv=False,
|
182 |
-
indices=tets,
|
183 |
-
weight_n=weight[i_batch].squeeze(dim=-1),
|
184 |
-
is_training=self.training,
|
185 |
-
)
|
186 |
-
flexicubes_surface_reg_list.append(flexicubes_surface_reg)
|
187 |
-
v_list.append(verts)
|
188 |
-
f_list.append(faces)
|
189 |
-
|
190 |
-
flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
|
191 |
-
flexicubes_weight_reg = (weight ** 2).mean()
|
192 |
-
|
193 |
-
return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
|
194 |
-
|
195 |
-
def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
|
196 |
-
'''
|
197 |
-
Predict Texture given triplanes
|
198 |
-
:param planes: the triplane feature map
|
199 |
-
:param tex_pos: Position we want to query the texture field
|
200 |
-
:param hard_mask: 2D silhoueete of the rendered image
|
201 |
-
'''
|
202 |
-
tex_pos = torch.cat(tex_pos, dim=0)
|
203 |
-
if not hard_mask is None:
|
204 |
-
tex_pos = tex_pos * hard_mask.float()
|
205 |
-
batch_size = tex_pos.shape[0]
|
206 |
-
tex_pos = tex_pos.reshape(batch_size, -1, 3)
|
207 |
-
###################
|
208 |
-
# We use mask to get the texture location (to save the memory)
|
209 |
-
if hard_mask is not None:
|
210 |
-
n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
|
211 |
-
sample_tex_pose_list = []
|
212 |
-
max_point = n_point_list.max()
|
213 |
-
expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
|
214 |
-
for i in range(tex_pos.shape[0]):
|
215 |
-
tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
|
216 |
-
if tex_pos_one_shape.shape[1] < max_point:
|
217 |
-
tex_pos_one_shape = torch.cat(
|
218 |
-
[tex_pos_one_shape, torch.zeros(
|
219 |
-
1, max_point - tex_pos_one_shape.shape[1], 3,
|
220 |
-
device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
|
221 |
-
sample_tex_pose_list.append(tex_pos_one_shape)
|
222 |
-
tex_pos = torch.cat(sample_tex_pose_list, dim=0)
|
223 |
-
|
224 |
-
tex_feat = torch.utils.checkpoint.checkpoint(
|
225 |
-
self.synthesizer.get_texture_prediction,
|
226 |
-
planes,
|
227 |
-
tex_pos,
|
228 |
-
use_reentrant=False,
|
229 |
-
)
|
230 |
-
|
231 |
-
if hard_mask is not None:
|
232 |
-
final_tex_feat = torch.zeros(
|
233 |
-
planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
|
234 |
-
expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
|
235 |
-
for i in range(planes.shape[0]):
|
236 |
-
final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
|
237 |
-
tex_feat = final_tex_feat
|
238 |
-
|
239 |
-
return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
|
240 |
-
|
241 |
-
def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
|
242 |
-
'''
|
243 |
-
Function to render a generated mesh with nvdiffrast
|
244 |
-
:param mesh_v: List of vertices for the mesh
|
245 |
-
:param mesh_f: List of faces for the mesh
|
246 |
-
:param cam_mv: 4x4 rotation matrix
|
247 |
-
:return:
|
248 |
-
'''
|
249 |
-
return_value_list = []
|
250 |
-
for i_mesh in range(len(mesh_v)):
|
251 |
-
return_value = self.geometry.render_mesh(
|
252 |
-
mesh_v[i_mesh],
|
253 |
-
mesh_f[i_mesh].int(),
|
254 |
-
cam_mv[i_mesh],
|
255 |
-
resolution=render_size,
|
256 |
-
hierarchical_mask=False
|
257 |
-
)
|
258 |
-
return_value_list.append(return_value)
|
259 |
-
|
260 |
-
return_keys = return_value_list[0].keys()
|
261 |
-
return_value = dict()
|
262 |
-
for k in return_keys:
|
263 |
-
value = [v[k] for v in return_value_list]
|
264 |
-
return_value[k] = value
|
265 |
-
|
266 |
-
mask = torch.cat(return_value['mask'], dim=0)
|
267 |
-
hard_mask = torch.cat(return_value['hard_mask'], dim=0)
|
268 |
-
tex_pos = return_value['tex_pos']
|
269 |
-
depth = torch.cat(return_value['depth'], dim=0)
|
270 |
-
normal = torch.cat(return_value['normal'], dim=0)
|
271 |
-
return mask, hard_mask, tex_pos, depth, normal
|
272 |
-
|
273 |
-
def forward_geometry(self, planes, render_cameras, render_size=256):
|
274 |
-
'''
|
275 |
-
Main function of our Generator. It first generate 3D mesh, then render it into 2D image
|
276 |
-
with given `render_cameras`.
|
277 |
-
:param planes: triplane features
|
278 |
-
:param render_cameras: cameras to render generated 3D shape
|
279 |
-
'''
|
280 |
-
B, NV = render_cameras.shape[:2]
|
281 |
-
|
282 |
-
# Generate 3D mesh first
|
283 |
-
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
|
284 |
-
|
285 |
-
# Render the mesh into 2D image (get 3d position of each image plane)
|
286 |
-
cam_mv = render_cameras
|
287 |
-
run_n_view = cam_mv.shape[1]
|
288 |
-
antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)
|
289 |
-
|
290 |
-
tex_hard_mask = hard_mask
|
291 |
-
tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
|
292 |
-
tex_hard_mask = torch.cat(
|
293 |
-
[torch.cat(
|
294 |
-
[tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
|
295 |
-
for i_view in range(run_n_view)], dim=2)
|
296 |
-
for i in range(planes.shape[0])], dim=0)
|
297 |
-
|
298 |
-
# Querying the texture field to predict the texture feature for each pixel on the image
|
299 |
-
tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
|
300 |
-
background_feature = torch.ones_like(tex_feat) # white background
|
301 |
-
|
302 |
-
# Merge them together
|
303 |
-
img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
|
304 |
-
|
305 |
-
# We should split it back to the original image shape
|
306 |
-
img_feat = torch.cat(
|
307 |
-
[torch.cat(
|
308 |
-
[img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
|
309 |
-
for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)
|
310 |
-
|
311 |
-
img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
|
312 |
-
antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
|
313 |
-
depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive
|
314 |
-
normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
|
315 |
-
|
316 |
-
out = {
|
317 |
-
'img': img,
|
318 |
-
'mask': antilias_mask,
|
319 |
-
'depth': depth,
|
320 |
-
'normal': normal,
|
321 |
-
'sdf': sdf,
|
322 |
-
'mesh_v': mesh_v,
|
323 |
-
'mesh_f': mesh_f,
|
324 |
-
'sdf_reg_loss': sdf_reg_loss,
|
325 |
-
}
|
326 |
-
return out
|
327 |
-
|
328 |
-
def forward(self, images, cameras, render_cameras, render_size: int):
|
329 |
-
# images: [B, V, C_img, H_img, W_img]
|
330 |
-
# cameras: [B, V, 16]
|
331 |
-
# render_cameras: [B, M, D_cam_render]
|
332 |
-
# render_size: int
|
333 |
-
B, M = render_cameras.shape[:2]
|
334 |
-
|
335 |
-
planes = self.forward_planes(images, cameras)
|
336 |
-
out = self.forward_geometry(planes, render_cameras, render_size=render_size)
|
337 |
-
|
338 |
-
return {
|
339 |
-
'planes': planes,
|
340 |
-
**out
|
341 |
-
}
|
342 |
-
|
343 |
-
def extract_mesh(
|
344 |
-
self,
|
345 |
-
planes: torch.Tensor,
|
346 |
-
use_texture_map: bool = False,
|
347 |
-
texture_resolution: int = 1024,
|
348 |
-
**kwargs,
|
349 |
-
):
|
350 |
-
'''
|
351 |
-
Extract a 3D mesh from FlexiCubes. Only support batch_size 1.
|
352 |
-
:param planes: triplane features
|
353 |
-
:param use_texture_map: use texture map or vertex color
|
354 |
-
:param texture_resolution: the resolution of texure map
|
355 |
-
'''
|
356 |
-
assert planes.shape[0] == 1
|
357 |
-
device = planes.device
|
358 |
-
|
359 |
-
# predict geometry first
|
360 |
-
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
|
361 |
-
vertices, faces = mesh_v[0], mesh_f[0]
|
362 |
-
|
363 |
-
if not use_texture_map:
|
364 |
-
# query vertex colors
|
365 |
-
vertices_tensor = vertices.unsqueeze(0)
|
366 |
-
vertices_colors = self.synthesizer.get_texture_prediction(
|
367 |
-
planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy()
|
368 |
-
vertices_colors = (vertices_colors * 255).astype(np.uint8)
|
369 |
-
|
370 |
-
return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors
|
371 |
-
|
372 |
-
# use x-atlas to get uv mapping for the mesh
|
373 |
-
ctx = dr.RasterizeCudaContext(device=device)
|
374 |
-
uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
|
375 |
-
self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
|
376 |
-
tex_hard_mask = tex_hard_mask.float()
|
377 |
-
|
378 |
-
# query the texture field to get the RGB color for texture map
|
379 |
-
tex_feat = self.get_texture_prediction(
|
380 |
-
planes, [gb_pos], tex_hard_mask)
|
381 |
-
background_feature = torch.zeros_like(tex_feat)
|
382 |
-
img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
|
383 |
-
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
|
384 |
-
|
385 |
-
return vertices, faces, uvs, mesh_tex_idx, texture_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/__init__.py
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
-
#
|
4 |
-
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
-
# property and proprietary rights in and to this material, related
|
6 |
-
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
-
# disclosure or distribution of this material and related documentation
|
8 |
-
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
-
# its affiliates is strictly prohibited.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/synthesizer.py
DELETED
@@ -1,203 +0,0 @@
|
|
1 |
-
# ORIGINAL LICENSE
|
2 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
4 |
-
#
|
5 |
-
# Modified by Jiale Xu
|
6 |
-
# The modifications are subject to the same license as the original.
|
7 |
-
|
8 |
-
|
9 |
-
import itertools
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
|
13 |
-
from .utils.renderer import ImportanceRenderer
|
14 |
-
from .utils.ray_sampler import RaySampler
|
15 |
-
|
16 |
-
|
17 |
-
class OSGDecoder(nn.Module):
|
18 |
-
"""
|
19 |
-
Triplane decoder that gives RGB and sigma values from sampled features.
|
20 |
-
Using ReLU here instead of Softplus in the original implementation.
|
21 |
-
|
22 |
-
Reference:
|
23 |
-
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
|
24 |
-
"""
|
25 |
-
def __init__(self, n_features: int,
|
26 |
-
hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
|
27 |
-
super().__init__()
|
28 |
-
self.net = nn.Sequential(
|
29 |
-
nn.Linear(3 * n_features, hidden_dim),
|
30 |
-
activation(),
|
31 |
-
*itertools.chain(*[[
|
32 |
-
nn.Linear(hidden_dim, hidden_dim),
|
33 |
-
activation(),
|
34 |
-
] for _ in range(num_layers - 2)]),
|
35 |
-
nn.Linear(hidden_dim, 1 + 3),
|
36 |
-
)
|
37 |
-
# init all bias to zero
|
38 |
-
for m in self.modules():
|
39 |
-
if isinstance(m, nn.Linear):
|
40 |
-
nn.init.zeros_(m.bias)
|
41 |
-
|
42 |
-
def forward(self, sampled_features, ray_directions):
|
43 |
-
# Aggregate features by mean
|
44 |
-
# sampled_features = sampled_features.mean(1)
|
45 |
-
# Aggregate features by concatenation
|
46 |
-
_N, n_planes, _M, _C = sampled_features.shape
|
47 |
-
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
|
48 |
-
x = sampled_features
|
49 |
-
|
50 |
-
N, M, C = x.shape
|
51 |
-
x = x.contiguous().view(N*M, C)
|
52 |
-
|
53 |
-
x = self.net(x)
|
54 |
-
x = x.view(N, M, -1)
|
55 |
-
rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
|
56 |
-
sigma = x[..., 0:1]
|
57 |
-
|
58 |
-
return {'rgb': rgb, 'sigma': sigma}
|
59 |
-
|
60 |
-
|
61 |
-
class TriplaneSynthesizer(nn.Module):
|
62 |
-
"""
|
63 |
-
Synthesizer that renders a triplane volume with planes and a camera.
|
64 |
-
|
65 |
-
Reference:
|
66 |
-
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
|
67 |
-
"""
|
68 |
-
|
69 |
-
DEFAULT_RENDERING_KWARGS = {
|
70 |
-
'ray_start': 'auto',
|
71 |
-
'ray_end': 'auto',
|
72 |
-
'box_warp': 2.,
|
73 |
-
'white_back': True,
|
74 |
-
'disparity_space_sampling': False,
|
75 |
-
'clamp_mode': 'softplus',
|
76 |
-
'sampler_bbox_min': -1.,
|
77 |
-
'sampler_bbox_max': 1.,
|
78 |
-
}
|
79 |
-
|
80 |
-
def __init__(self, triplane_dim: int, samples_per_ray: int):
|
81 |
-
super().__init__()
|
82 |
-
|
83 |
-
# attributes
|
84 |
-
self.triplane_dim = triplane_dim
|
85 |
-
self.rendering_kwargs = {
|
86 |
-
**self.DEFAULT_RENDERING_KWARGS,
|
87 |
-
'depth_resolution': samples_per_ray // 2,
|
88 |
-
'depth_resolution_importance': samples_per_ray // 2,
|
89 |
-
}
|
90 |
-
|
91 |
-
# renderings
|
92 |
-
self.renderer = ImportanceRenderer()
|
93 |
-
self.ray_sampler = RaySampler()
|
94 |
-
|
95 |
-
# modules
|
96 |
-
self.decoder = OSGDecoder(n_features=triplane_dim)
|
97 |
-
|
98 |
-
def forward(self, planes, cameras, render_size=128, crop_params=None):
|
99 |
-
# planes: (N, 3, D', H', W')
|
100 |
-
# cameras: (N, M, D_cam)
|
101 |
-
# render_size: int
|
102 |
-
assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
|
103 |
-
N, M = cameras.shape[:2]
|
104 |
-
|
105 |
-
cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
|
106 |
-
intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
|
107 |
-
|
108 |
-
# Create a batch of rays for volume rendering
|
109 |
-
ray_origins, ray_directions = self.ray_sampler(
|
110 |
-
cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
|
111 |
-
intrinsics=intrinsics.reshape(-1, 3, 3),
|
112 |
-
render_size=render_size,
|
113 |
-
)
|
114 |
-
assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
|
115 |
-
assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
|
116 |
-
|
117 |
-
# Crop rays if crop_params is available
|
118 |
-
if crop_params is not None:
|
119 |
-
ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3)
|
120 |
-
ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3)
|
121 |
-
i, j, h, w = crop_params
|
122 |
-
ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
|
123 |
-
ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
|
124 |
-
|
125 |
-
# Perform volume rendering
|
126 |
-
rgb_samples, depth_samples, weights_samples = self.renderer(
|
127 |
-
planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
|
128 |
-
)
|
129 |
-
|
130 |
-
# Reshape into 'raw' neural-rendered image
|
131 |
-
if crop_params is not None:
|
132 |
-
Himg, Wimg = crop_params[2:]
|
133 |
-
else:
|
134 |
-
Himg = Wimg = render_size
|
135 |
-
rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
|
136 |
-
depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
|
137 |
-
weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
|
138 |
-
|
139 |
-
out = {
|
140 |
-
'images_rgb': rgb_images,
|
141 |
-
'images_depth': depth_images,
|
142 |
-
'images_weight': weight_images,
|
143 |
-
}
|
144 |
-
return out
|
145 |
-
|
146 |
-
def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
|
147 |
-
# planes: (N, 3, D', H', W')
|
148 |
-
# grid_size: int
|
149 |
-
# aabb: (N, 2, 3)
|
150 |
-
if aabb is None:
|
151 |
-
aabb = torch.tensor([
|
152 |
-
[self.rendering_kwargs['sampler_bbox_min']] * 3,
|
153 |
-
[self.rendering_kwargs['sampler_bbox_max']] * 3,
|
154 |
-
], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
|
155 |
-
assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
|
156 |
-
N = planes.shape[0]
|
157 |
-
|
158 |
-
# create grid points for triplane query
|
159 |
-
grid_points = []
|
160 |
-
for i in range(N):
|
161 |
-
grid_points.append(torch.stack(torch.meshgrid(
|
162 |
-
torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
|
163 |
-
torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
|
164 |
-
torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
|
165 |
-
indexing='ij',
|
166 |
-
), dim=-1).reshape(-1, 3))
|
167 |
-
cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
|
168 |
-
|
169 |
-
features = self.forward_points(planes, cube_grid)
|
170 |
-
|
171 |
-
# reshape into grid
|
172 |
-
features = {
|
173 |
-
k: v.reshape(N, grid_size, grid_size, grid_size, -1)
|
174 |
-
for k, v in features.items()
|
175 |
-
}
|
176 |
-
return features
|
177 |
-
|
178 |
-
def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
|
179 |
-
# planes: (N, 3, D', H', W')
|
180 |
-
# points: (N, P, 3)
|
181 |
-
N, P = points.shape[:2]
|
182 |
-
|
183 |
-
# query triplane in chunks
|
184 |
-
outs = []
|
185 |
-
for i in range(0, points.shape[1], chunk_size):
|
186 |
-
chunk_points = points[:, i:i+chunk_size]
|
187 |
-
|
188 |
-
# query triplane
|
189 |
-
chunk_out = self.renderer.run_model_activated(
|
190 |
-
planes=planes,
|
191 |
-
decoder=self.decoder,
|
192 |
-
sample_coordinates=chunk_points,
|
193 |
-
sample_directions=torch.zeros_like(chunk_points),
|
194 |
-
options=self.rendering_kwargs,
|
195 |
-
)
|
196 |
-
outs.append(chunk_out)
|
197 |
-
|
198 |
-
# concatenate the outputs
|
199 |
-
point_features = {
|
200 |
-
k: torch.cat([out[k] for out in outs], dim=1)
|
201 |
-
for k in outs[0].keys()
|
202 |
-
}
|
203 |
-
return point_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/synthesizer_mesh.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
# ORIGINAL LICENSE
|
2 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
4 |
-
#
|
5 |
-
# Modified by Jiale Xu
|
6 |
-
# The modifications are subject to the same license as the original.
|
7 |
-
|
8 |
-
import itertools
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
|
12 |
-
from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes
|
13 |
-
|
14 |
-
|
15 |
-
class OSGDecoder(nn.Module):
|
16 |
-
"""
|
17 |
-
Triplane decoder that gives RGB and sigma values from sampled features.
|
18 |
-
Using ReLU here instead of Softplus in the original implementation.
|
19 |
-
|
20 |
-
Reference:
|
21 |
-
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
|
22 |
-
"""
|
23 |
-
def __init__(self, n_features: int,
|
24 |
-
hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
|
25 |
-
super().__init__()
|
26 |
-
|
27 |
-
self.net_sdf = nn.Sequential(
|
28 |
-
nn.Linear(3 * n_features, hidden_dim),
|
29 |
-
activation(),
|
30 |
-
*itertools.chain(*[[
|
31 |
-
nn.Linear(hidden_dim, hidden_dim),
|
32 |
-
activation(),
|
33 |
-
] for _ in range(num_layers - 2)]),
|
34 |
-
nn.Linear(hidden_dim, 1),
|
35 |
-
)
|
36 |
-
self.net_rgb = nn.Sequential(
|
37 |
-
nn.Linear(3 * n_features, hidden_dim),
|
38 |
-
activation(),
|
39 |
-
*itertools.chain(*[[
|
40 |
-
nn.Linear(hidden_dim, hidden_dim),
|
41 |
-
activation(),
|
42 |
-
] for _ in range(num_layers - 2)]),
|
43 |
-
nn.Linear(hidden_dim, 3),
|
44 |
-
)
|
45 |
-
self.net_deformation = nn.Sequential(
|
46 |
-
nn.Linear(3 * n_features, hidden_dim),
|
47 |
-
activation(),
|
48 |
-
*itertools.chain(*[[
|
49 |
-
nn.Linear(hidden_dim, hidden_dim),
|
50 |
-
activation(),
|
51 |
-
] for _ in range(num_layers - 2)]),
|
52 |
-
nn.Linear(hidden_dim, 3),
|
53 |
-
)
|
54 |
-
self.net_weight = nn.Sequential(
|
55 |
-
nn.Linear(8 * 3 * n_features, hidden_dim),
|
56 |
-
activation(),
|
57 |
-
*itertools.chain(*[[
|
58 |
-
nn.Linear(hidden_dim, hidden_dim),
|
59 |
-
activation(),
|
60 |
-
] for _ in range(num_layers - 2)]),
|
61 |
-
nn.Linear(hidden_dim, 21),
|
62 |
-
)
|
63 |
-
|
64 |
-
# init all bias to zero
|
65 |
-
for m in self.modules():
|
66 |
-
if isinstance(m, nn.Linear):
|
67 |
-
nn.init.zeros_(m.bias)
|
68 |
-
|
69 |
-
def get_geometry_prediction(self, sampled_features, flexicubes_indices):
|
70 |
-
_N, n_planes, _M, _C = sampled_features.shape
|
71 |
-
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
|
72 |
-
|
73 |
-
sdf = self.net_sdf(sampled_features)
|
74 |
-
deformation = self.net_deformation(sampled_features)
|
75 |
-
|
76 |
-
grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1)
|
77 |
-
grid_features = grid_features.reshape(
|
78 |
-
sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1])
|
79 |
-
weight = self.net_weight(grid_features) * 0.1
|
80 |
-
|
81 |
-
return sdf, deformation, weight
|
82 |
-
|
83 |
-
def get_texture_prediction(self, sampled_features):
|
84 |
-
_N, n_planes, _M, _C = sampled_features.shape
|
85 |
-
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
|
86 |
-
|
87 |
-
rgb = self.net_rgb(sampled_features)
|
88 |
-
rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
|
89 |
-
|
90 |
-
return rgb
|
91 |
-
|
92 |
-
|
93 |
-
class TriplaneSynthesizer(nn.Module):
|
94 |
-
"""
|
95 |
-
Synthesizer that renders a triplane volume with planes and a camera.
|
96 |
-
|
97 |
-
Reference:
|
98 |
-
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
|
99 |
-
"""
|
100 |
-
|
101 |
-
DEFAULT_RENDERING_KWARGS = {
|
102 |
-
'ray_start': 'auto',
|
103 |
-
'ray_end': 'auto',
|
104 |
-
'box_warp': 2.,
|
105 |
-
'white_back': True,
|
106 |
-
'disparity_space_sampling': False,
|
107 |
-
'clamp_mode': 'softplus',
|
108 |
-
'sampler_bbox_min': -1.,
|
109 |
-
'sampler_bbox_max': 1.,
|
110 |
-
}
|
111 |
-
|
112 |
-
def __init__(self, triplane_dim: int, samples_per_ray: int):
|
113 |
-
super().__init__()
|
114 |
-
|
115 |
-
# attributes
|
116 |
-
self.triplane_dim = triplane_dim
|
117 |
-
self.rendering_kwargs = {
|
118 |
-
**self.DEFAULT_RENDERING_KWARGS,
|
119 |
-
'depth_resolution': samples_per_ray // 2,
|
120 |
-
'depth_resolution_importance': samples_per_ray // 2,
|
121 |
-
}
|
122 |
-
|
123 |
-
# modules
|
124 |
-
self.plane_axes = generate_planes()
|
125 |
-
self.decoder = OSGDecoder(n_features=triplane_dim)
|
126 |
-
|
127 |
-
def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
|
128 |
-
plane_axes = self.plane_axes.to(planes.device)
|
129 |
-
sampled_features = sample_from_planes(
|
130 |
-
plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
|
131 |
-
|
132 |
-
sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
|
133 |
-
return sdf, deformation, weight
|
134 |
-
|
135 |
-
def get_texture_prediction(self, planes, sample_coordinates):
|
136 |
-
plane_axes = self.plane_axes.to(planes.device)
|
137 |
-
sampled_features = sample_from_planes(
|
138 |
-
plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
|
139 |
-
|
140 |
-
rgb = self.decoder.get_texture_prediction(sampled_features)
|
141 |
-
return rgb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/utils/__init__.py
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
-
#
|
4 |
-
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
-
# property and proprietary rights in and to this material, related
|
6 |
-
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
-
# disclosure or distribution of this material and related documentation
|
8 |
-
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
-
# its affiliates is strictly prohibited.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/utils/math_utils.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
# MIT License
|
2 |
-
|
3 |
-
# Copyright (c) 2022 Petr Kellnhofer
|
4 |
-
|
5 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
# of this software and associated documentation files (the "Software"), to deal
|
7 |
-
# in the Software without restriction, including without limitation the rights
|
8 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
# copies of the Software, and to permit persons to whom the Software is
|
10 |
-
# furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
# The above copyright notice and this permission notice shall be included in all
|
13 |
-
# copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
# SOFTWARE.
|
22 |
-
|
23 |
-
import torch
|
24 |
-
|
25 |
-
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
|
26 |
-
"""
|
27 |
-
Left-multiplies MxM @ NxM. Returns NxM.
|
28 |
-
"""
|
29 |
-
res = torch.matmul(vectors4, matrix.T)
|
30 |
-
return res
|
31 |
-
|
32 |
-
|
33 |
-
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
|
34 |
-
"""
|
35 |
-
Normalize vector lengths.
|
36 |
-
"""
|
37 |
-
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
|
38 |
-
|
39 |
-
def torch_dot(x: torch.Tensor, y: torch.Tensor):
|
40 |
-
"""
|
41 |
-
Dot product of two tensors.
|
42 |
-
"""
|
43 |
-
return (x * y).sum(-1)
|
44 |
-
|
45 |
-
|
46 |
-
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
|
47 |
-
"""
|
48 |
-
Author: Petr Kellnhofer
|
49 |
-
Intersects rays with the [-1, 1] NDC volume.
|
50 |
-
Returns min and max distance of entry.
|
51 |
-
Returns -1 for no intersection.
|
52 |
-
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
|
53 |
-
"""
|
54 |
-
o_shape = rays_o.shape
|
55 |
-
rays_o = rays_o.detach().reshape(-1, 3)
|
56 |
-
rays_d = rays_d.detach().reshape(-1, 3)
|
57 |
-
|
58 |
-
|
59 |
-
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
|
60 |
-
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
|
61 |
-
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
|
62 |
-
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
|
63 |
-
|
64 |
-
# Precompute inverse for stability.
|
65 |
-
invdir = 1 / rays_d
|
66 |
-
sign = (invdir < 0).long()
|
67 |
-
|
68 |
-
# Intersect with YZ plane.
|
69 |
-
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
70 |
-
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
71 |
-
|
72 |
-
# Intersect with XZ plane.
|
73 |
-
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
74 |
-
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
75 |
-
|
76 |
-
# Resolve parallel rays.
|
77 |
-
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
|
78 |
-
|
79 |
-
# Use the shortest intersection.
|
80 |
-
tmin = torch.max(tmin, tymin)
|
81 |
-
tmax = torch.min(tmax, tymax)
|
82 |
-
|
83 |
-
# Intersect with XY plane.
|
84 |
-
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
85 |
-
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
86 |
-
|
87 |
-
# Resolve parallel rays.
|
88 |
-
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
|
89 |
-
|
90 |
-
# Use the shortest intersection.
|
91 |
-
tmin = torch.max(tmin, tzmin)
|
92 |
-
tmax = torch.min(tmax, tzmax)
|
93 |
-
|
94 |
-
# Mark invalid.
|
95 |
-
tmin[torch.logical_not(is_valid)] = -1
|
96 |
-
tmax[torch.logical_not(is_valid)] = -2
|
97 |
-
|
98 |
-
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
99 |
-
|
100 |
-
|
101 |
-
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
|
102 |
-
"""
|
103 |
-
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
|
104 |
-
Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
|
105 |
-
"""
|
106 |
-
# create a tensor of 'num' steps from 0 to 1
|
107 |
-
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
|
108 |
-
|
109 |
-
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
|
110 |
-
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
|
111 |
-
# "cannot statically infer the expected size of a list in this contex", hence the code below
|
112 |
-
for i in range(start.ndim):
|
113 |
-
steps = steps.unsqueeze(-1)
|
114 |
-
|
115 |
-
# the output starts at 'start' and increments until 'stop' in each dimension
|
116 |
-
out = start[None] + steps * (stop - start)[None]
|
117 |
-
|
118 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/utils/ray_marcher.py
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
-
#
|
4 |
-
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
-
# property and proprietary rights in and to this material, related
|
6 |
-
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
-
# disclosure or distribution of this material and related documentation
|
8 |
-
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
-
# its affiliates is strictly prohibited.
|
10 |
-
#
|
11 |
-
# Modified by Jiale Xu
|
12 |
-
# The modifications are subject to the same license as the original.
|
13 |
-
|
14 |
-
|
15 |
-
"""
|
16 |
-
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
|
17 |
-
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
|
18 |
-
"""
|
19 |
-
|
20 |
-
import torch
|
21 |
-
import torch.nn as nn
|
22 |
-
import torch.nn.functional as F
|
23 |
-
|
24 |
-
|
25 |
-
class MipRayMarcher2(nn.Module):
|
26 |
-
def __init__(self, activation_factory):
|
27 |
-
super().__init__()
|
28 |
-
self.activation_factory = activation_factory
|
29 |
-
|
30 |
-
def run_forward(self, colors, densities, depths, rendering_options, normals=None):
|
31 |
-
dtype = colors.dtype
|
32 |
-
deltas = depths[:, :, 1:] - depths[:, :, :-1]
|
33 |
-
colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
|
34 |
-
densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
|
35 |
-
depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
|
36 |
-
|
37 |
-
# using factory mode for better usability
|
38 |
-
densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
|
39 |
-
|
40 |
-
density_delta = densities_mid * deltas
|
41 |
-
|
42 |
-
alpha = 1 - torch.exp(-density_delta).to(dtype)
|
43 |
-
|
44 |
-
alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
|
45 |
-
weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
|
46 |
-
weights = weights.to(dtype)
|
47 |
-
|
48 |
-
composite_rgb = torch.sum(weights * colors_mid, -2)
|
49 |
-
weight_total = weights.sum(2)
|
50 |
-
# composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
|
51 |
-
composite_depth = torch.sum(weights * depths_mid, -2)
|
52 |
-
|
53 |
-
# clip the composite to min/max range of depths
|
54 |
-
composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
|
55 |
-
composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
|
56 |
-
|
57 |
-
if rendering_options.get('white_back', False):
|
58 |
-
composite_rgb = composite_rgb + 1 - weight_total
|
59 |
-
|
60 |
-
# rendered value scale is 0-1, comment out original mipnerf scaling
|
61 |
-
# composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
|
62 |
-
|
63 |
-
return composite_rgb, composite_depth, weights
|
64 |
-
|
65 |
-
|
66 |
-
def forward(self, colors, densities, depths, rendering_options, normals=None):
|
67 |
-
if normals is not None:
|
68 |
-
composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals)
|
69 |
-
return composite_rgb, composite_depth, composite_normals, weights
|
70 |
-
|
71 |
-
composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
|
72 |
-
return composite_rgb, composite_depth, weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/utils/ray_sampler.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
-
#
|
4 |
-
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
-
# property and proprietary rights in and to this material, related
|
6 |
-
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
-
# disclosure or distribution of this material and related documentation
|
8 |
-
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
-
# its affiliates is strictly prohibited.
|
10 |
-
#
|
11 |
-
# Modified by Jiale Xu
|
12 |
-
# The modifications are subject to the same license as the original.
|
13 |
-
|
14 |
-
|
15 |
-
"""
|
16 |
-
The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
|
17 |
-
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
|
18 |
-
"""
|
19 |
-
|
20 |
-
import torch
|
21 |
-
|
22 |
-
class RaySampler(torch.nn.Module):
|
23 |
-
def __init__(self):
|
24 |
-
super().__init__()
|
25 |
-
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
|
26 |
-
|
27 |
-
|
28 |
-
def forward(self, cam2world_matrix, intrinsics, render_size):
|
29 |
-
"""
|
30 |
-
Create batches of rays and return origins and directions.
|
31 |
-
|
32 |
-
cam2world_matrix: (N, 4, 4)
|
33 |
-
intrinsics: (N, 3, 3)
|
34 |
-
render_size: int
|
35 |
-
|
36 |
-
ray_origins: (N, M, 3)
|
37 |
-
ray_dirs: (N, M, 2)
|
38 |
-
"""
|
39 |
-
|
40 |
-
dtype = cam2world_matrix.dtype
|
41 |
-
device = cam2world_matrix.device
|
42 |
-
N, M = cam2world_matrix.shape[0], render_size**2
|
43 |
-
cam_locs_world = cam2world_matrix[:, :3, 3]
|
44 |
-
fx = intrinsics[:, 0, 0]
|
45 |
-
fy = intrinsics[:, 1, 1]
|
46 |
-
cx = intrinsics[:, 0, 2]
|
47 |
-
cy = intrinsics[:, 1, 2]
|
48 |
-
sk = intrinsics[:, 0, 1]
|
49 |
-
|
50 |
-
uv = torch.stack(torch.meshgrid(
|
51 |
-
torch.arange(render_size, dtype=dtype, device=device),
|
52 |
-
torch.arange(render_size, dtype=dtype, device=device),
|
53 |
-
indexing='ij',
|
54 |
-
))
|
55 |
-
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
|
56 |
-
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
|
57 |
-
|
58 |
-
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
|
59 |
-
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
|
60 |
-
z_cam = torch.ones((N, M), dtype=dtype, device=device)
|
61 |
-
|
62 |
-
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
|
63 |
-
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
|
64 |
-
|
65 |
-
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype)
|
66 |
-
|
67 |
-
_opencv2blender = torch.tensor([
|
68 |
-
[1, 0, 0, 0],
|
69 |
-
[0, -1, 0, 0],
|
70 |
-
[0, 0, -1, 0],
|
71 |
-
[0, 0, 0, 1],
|
72 |
-
], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1)
|
73 |
-
|
74 |
-
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
|
75 |
-
|
76 |
-
world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
|
77 |
-
|
78 |
-
ray_dirs = world_rel_points - cam_locs_world[:, None, :]
|
79 |
-
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype)
|
80 |
-
|
81 |
-
ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
|
82 |
-
|
83 |
-
return ray_origins, ray_dirs
|
84 |
-
|
85 |
-
|
86 |
-
class OrthoRaySampler(torch.nn.Module):
|
87 |
-
def __init__(self):
|
88 |
-
super().__init__()
|
89 |
-
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
|
90 |
-
|
91 |
-
|
92 |
-
def forward(self, cam2world_matrix, ortho_scale, render_size):
|
93 |
-
"""
|
94 |
-
Create batches of rays and return origins and directions.
|
95 |
-
|
96 |
-
cam2world_matrix: (N, 4, 4)
|
97 |
-
ortho_scale: float
|
98 |
-
render_size: int
|
99 |
-
|
100 |
-
ray_origins: (N, M, 3)
|
101 |
-
ray_dirs: (N, M, 3)
|
102 |
-
"""
|
103 |
-
|
104 |
-
N, M = cam2world_matrix.shape[0], render_size**2
|
105 |
-
|
106 |
-
uv = torch.stack(torch.meshgrid(
|
107 |
-
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
108 |
-
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
109 |
-
indexing='ij',
|
110 |
-
))
|
111 |
-
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
|
112 |
-
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
|
113 |
-
|
114 |
-
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
|
115 |
-
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
|
116 |
-
z_cam = torch.zeros((N, M), device=cam2world_matrix.device)
|
117 |
-
|
118 |
-
x_lift = (x_cam - 0.5) * ortho_scale
|
119 |
-
y_lift = (y_cam - 0.5) * ortho_scale
|
120 |
-
|
121 |
-
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
|
122 |
-
|
123 |
-
_opencv2blender = torch.tensor([
|
124 |
-
[1, 0, 0, 0],
|
125 |
-
[0, -1, 0, 0],
|
126 |
-
[0, 0, -1, 0],
|
127 |
-
[0, 0, 0, 1],
|
128 |
-
], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
|
129 |
-
|
130 |
-
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
|
131 |
-
|
132 |
-
ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
|
133 |
-
|
134 |
-
ray_dirs_cam = torch.stack([
|
135 |
-
torch.zeros((N, M), device=cam2world_matrix.device),
|
136 |
-
torch.zeros((N, M), device=cam2world_matrix.device),
|
137 |
-
torch.ones((N, M), device=cam2world_matrix.device),
|
138 |
-
], dim=-1)
|
139 |
-
ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1)
|
140 |
-
|
141 |
-
return ray_origins, ray_dirs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/renderer/utils/renderer.py
DELETED
@@ -1,323 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
3 |
-
#
|
4 |
-
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
-
# property and proprietary rights in and to this material, related
|
6 |
-
# documentation and any modifications thereto. Any use, reproduction,
|
7 |
-
# disclosure or distribution of this material and related documentation
|
8 |
-
# without an express license agreement from NVIDIA CORPORATION or
|
9 |
-
# its affiliates is strictly prohibited.
|
10 |
-
#
|
11 |
-
# Modified by Jiale Xu
|
12 |
-
# The modifications are subject to the same license as the original.
|
13 |
-
|
14 |
-
|
15 |
-
"""
|
16 |
-
The renderer is a module that takes in rays, decides where to sample along each
|
17 |
-
ray, and computes pixel colors using the volume rendering equation.
|
18 |
-
"""
|
19 |
-
|
20 |
-
import torch
|
21 |
-
import torch.nn as nn
|
22 |
-
import torch.nn.functional as F
|
23 |
-
|
24 |
-
from .ray_marcher import MipRayMarcher2
|
25 |
-
from . import math_utils
|
26 |
-
|
27 |
-
|
28 |
-
def generate_planes():
|
29 |
-
"""
|
30 |
-
Defines planes by the three vectors that form the "axes" of the
|
31 |
-
plane. Should work with arbitrary number of planes and planes of
|
32 |
-
arbitrary orientation.
|
33 |
-
|
34 |
-
Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
|
35 |
-
"""
|
36 |
-
return torch.tensor([[[1, 0, 0],
|
37 |
-
[0, 1, 0],
|
38 |
-
[0, 0, 1]],
|
39 |
-
[[1, 0, 0],
|
40 |
-
[0, 0, 1],
|
41 |
-
[0, 1, 0]],
|
42 |
-
[[0, 0, 1],
|
43 |
-
[0, 1, 0],
|
44 |
-
[1, 0, 0]]], dtype=torch.float32)
|
45 |
-
|
46 |
-
def project_onto_planes(planes, coordinates):
|
47 |
-
"""
|
48 |
-
Does a projection of a 3D point onto a batch of 2D planes,
|
49 |
-
returning 2D plane coordinates.
|
50 |
-
|
51 |
-
Takes plane axes of shape n_planes, 3, 3
|
52 |
-
# Takes coordinates of shape N, M, 3
|
53 |
-
# returns projections of shape N*n_planes, M, 2
|
54 |
-
"""
|
55 |
-
N, M, C = coordinates.shape
|
56 |
-
n_planes, _, _ = planes.shape
|
57 |
-
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
|
58 |
-
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
|
59 |
-
projections = torch.bmm(coordinates, inv_planes)
|
60 |
-
return projections[..., :2]
|
61 |
-
|
62 |
-
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
63 |
-
assert padding_mode == 'zeros'
|
64 |
-
N, n_planes, C, H, W = plane_features.shape
|
65 |
-
_, M, _ = coordinates.shape
|
66 |
-
plane_features = plane_features.view(N*n_planes, C, H, W)
|
67 |
-
dtype = plane_features.dtype
|
68 |
-
|
69 |
-
coordinates = (2/box_warp) * coordinates # add specific box bounds
|
70 |
-
|
71 |
-
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
|
72 |
-
output_features = torch.nn.functional.grid_sample(
|
73 |
-
plane_features,
|
74 |
-
projected_coordinates.to(dtype),
|
75 |
-
mode=mode,
|
76 |
-
padding_mode=padding_mode,
|
77 |
-
align_corners=False,
|
78 |
-
).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
79 |
-
return output_features
|
80 |
-
|
81 |
-
def sample_from_3dgrid(grid, coordinates):
|
82 |
-
"""
|
83 |
-
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
|
84 |
-
Expects grid in shape (1, channels, H, W, D)
|
85 |
-
(Also works if grid has batch size)
|
86 |
-
Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
|
87 |
-
"""
|
88 |
-
batch_size, n_coords, n_dims = coordinates.shape
|
89 |
-
sampled_features = torch.nn.functional.grid_sample(
|
90 |
-
grid.expand(batch_size, -1, -1, -1, -1),
|
91 |
-
coordinates.reshape(batch_size, 1, 1, -1, n_dims),
|
92 |
-
mode='bilinear',
|
93 |
-
padding_mode='zeros',
|
94 |
-
align_corners=False,
|
95 |
-
)
|
96 |
-
N, C, H, W, D = sampled_features.shape
|
97 |
-
sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
|
98 |
-
return sampled_features
|
99 |
-
|
100 |
-
class ImportanceRenderer(torch.nn.Module):
|
101 |
-
"""
|
102 |
-
Modified original version to filter out-of-box samples as TensoRF does.
|
103 |
-
|
104 |
-
Reference:
|
105 |
-
TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
|
106 |
-
"""
|
107 |
-
def __init__(self):
|
108 |
-
super().__init__()
|
109 |
-
self.activation_factory = self._build_activation_factory()
|
110 |
-
self.ray_marcher = MipRayMarcher2(self.activation_factory)
|
111 |
-
self.plane_axes = generate_planes()
|
112 |
-
|
113 |
-
def _build_activation_factory(self):
|
114 |
-
def activation_factory(options: dict):
|
115 |
-
if options['clamp_mode'] == 'softplus':
|
116 |
-
return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
|
117 |
-
else:
|
118 |
-
assert False, "Renderer only supports `clamp_mode`=`softplus`!"
|
119 |
-
return activation_factory
|
120 |
-
|
121 |
-
def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
|
122 |
-
planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
|
123 |
-
"""
|
124 |
-
Additional filtering is applied to filter out-of-box samples.
|
125 |
-
Modifications made by Zexin He.
|
126 |
-
"""
|
127 |
-
|
128 |
-
# context related variables
|
129 |
-
batch_size, num_rays, samples_per_ray, _ = depths.shape
|
130 |
-
device = depths.device
|
131 |
-
|
132 |
-
# define sample points with depths
|
133 |
-
sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
|
134 |
-
sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
|
135 |
-
|
136 |
-
# filter out-of-box samples
|
137 |
-
mask_inbox = \
|
138 |
-
(rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
|
139 |
-
(sample_coordinates <= rendering_options['sampler_bbox_max'])
|
140 |
-
mask_inbox = mask_inbox.all(-1)
|
141 |
-
|
142 |
-
# forward model according to all samples
|
143 |
-
_out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
|
144 |
-
|
145 |
-
# set out-of-box samples to zeros(rgb) & -inf(sigma)
|
146 |
-
SAFE_GUARD = 3
|
147 |
-
DATA_TYPE = _out['sigma'].dtype
|
148 |
-
colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
|
149 |
-
densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
|
150 |
-
colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
|
151 |
-
|
152 |
-
# reshape back
|
153 |
-
colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
|
154 |
-
densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
|
155 |
-
|
156 |
-
return colors_pass, densities_pass
|
157 |
-
|
158 |
-
def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
|
159 |
-
# self.plane_axes = self.plane_axes.to(ray_origins.device)
|
160 |
-
|
161 |
-
if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
|
162 |
-
ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
|
163 |
-
is_ray_valid = ray_end > ray_start
|
164 |
-
if torch.any(is_ray_valid).item():
|
165 |
-
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
|
166 |
-
ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
|
167 |
-
depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
|
168 |
-
else:
|
169 |
-
# Create stratified depth samples
|
170 |
-
depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
|
171 |
-
|
172 |
-
# Coarse Pass
|
173 |
-
colors_coarse, densities_coarse = self._forward_pass(
|
174 |
-
depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
|
175 |
-
planes=planes, decoder=decoder, rendering_options=rendering_options)
|
176 |
-
|
177 |
-
# Fine Pass
|
178 |
-
N_importance = rendering_options['depth_resolution_importance']
|
179 |
-
if N_importance > 0:
|
180 |
-
_, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
|
181 |
-
|
182 |
-
depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
|
183 |
-
|
184 |
-
colors_fine, densities_fine = self._forward_pass(
|
185 |
-
depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
|
186 |
-
planes=planes, decoder=decoder, rendering_options=rendering_options)
|
187 |
-
|
188 |
-
all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
|
189 |
-
depths_fine, colors_fine, densities_fine)
|
190 |
-
|
191 |
-
rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
|
192 |
-
else:
|
193 |
-
rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
|
194 |
-
|
195 |
-
return rgb_final, depth_final, weights.sum(2)
|
196 |
-
|
197 |
-
def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
|
198 |
-
plane_axes = self.plane_axes.to(planes.device)
|
199 |
-
sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
|
200 |
-
|
201 |
-
out = decoder(sampled_features, sample_directions)
|
202 |
-
if options.get('density_noise', 0) > 0:
|
203 |
-
out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
|
204 |
-
return out
|
205 |
-
|
206 |
-
def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
|
207 |
-
out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
|
208 |
-
out['sigma'] = self.activation_factory(options)(out['sigma'])
|
209 |
-
return out
|
210 |
-
|
211 |
-
def sort_samples(self, all_depths, all_colors, all_densities):
|
212 |
-
_, indices = torch.sort(all_depths, dim=-2)
|
213 |
-
all_depths = torch.gather(all_depths, -2, indices)
|
214 |
-
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
215 |
-
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
216 |
-
return all_depths, all_colors, all_densities
|
217 |
-
|
218 |
-
def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None):
|
219 |
-
all_depths = torch.cat([depths1, depths2], dim = -2)
|
220 |
-
all_colors = torch.cat([colors1, colors2], dim = -2)
|
221 |
-
all_densities = torch.cat([densities1, densities2], dim = -2)
|
222 |
-
|
223 |
-
if normals1 is not None and normals2 is not None:
|
224 |
-
all_normals = torch.cat([normals1, normals2], dim = -2)
|
225 |
-
else:
|
226 |
-
all_normals = None
|
227 |
-
|
228 |
-
_, indices = torch.sort(all_depths, dim=-2)
|
229 |
-
all_depths = torch.gather(all_depths, -2, indices)
|
230 |
-
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
231 |
-
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
232 |
-
|
233 |
-
if all_normals is not None:
|
234 |
-
all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1]))
|
235 |
-
return all_depths, all_colors, all_normals, all_densities
|
236 |
-
|
237 |
-
return all_depths, all_colors, all_densities
|
238 |
-
|
239 |
-
def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
|
240 |
-
"""
|
241 |
-
Return depths of approximately uniformly spaced samples along rays.
|
242 |
-
"""
|
243 |
-
N, M, _ = ray_origins.shape
|
244 |
-
if disparity_space_sampling:
|
245 |
-
depths_coarse = torch.linspace(0,
|
246 |
-
1,
|
247 |
-
depth_resolution,
|
248 |
-
device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
249 |
-
depth_delta = 1/(depth_resolution - 1)
|
250 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
251 |
-
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
|
252 |
-
else:
|
253 |
-
if type(ray_start) == torch.Tensor:
|
254 |
-
depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
|
255 |
-
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
|
256 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
|
257 |
-
else:
|
258 |
-
depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
259 |
-
depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
|
260 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
261 |
-
|
262 |
-
return depths_coarse
|
263 |
-
|
264 |
-
def sample_importance(self, z_vals, weights, N_importance):
|
265 |
-
"""
|
266 |
-
Return depths of importance sampled points along rays. See NeRF importance sampling for more.
|
267 |
-
"""
|
268 |
-
with torch.no_grad():
|
269 |
-
batch_size, num_rays, samples_per_ray, _ = z_vals.shape
|
270 |
-
|
271 |
-
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
|
272 |
-
weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
|
273 |
-
|
274 |
-
# smooth weights
|
275 |
-
weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1)
|
276 |
-
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
|
277 |
-
weights = weights + 0.01
|
278 |
-
|
279 |
-
z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
|
280 |
-
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
|
281 |
-
N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
|
282 |
-
return importance_z_vals
|
283 |
-
|
284 |
-
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
|
285 |
-
"""
|
286 |
-
Sample @N_importance samples from @bins with distribution defined by @weights.
|
287 |
-
Inputs:
|
288 |
-
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
|
289 |
-
weights: (N_rays, N_samples_)
|
290 |
-
N_importance: the number of samples to draw from the distribution
|
291 |
-
det: deterministic or not
|
292 |
-
eps: a small number to prevent division by zero
|
293 |
-
Outputs:
|
294 |
-
samples: the sampled samples
|
295 |
-
"""
|
296 |
-
N_rays, N_samples_ = weights.shape
|
297 |
-
weights = weights + eps # prevent division by zero (don't do inplace op!)
|
298 |
-
pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
|
299 |
-
cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
|
300 |
-
cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
|
301 |
-
# padded to 0~1 inclusive
|
302 |
-
|
303 |
-
if det:
|
304 |
-
u = torch.linspace(0, 1, N_importance, device=bins.device)
|
305 |
-
u = u.expand(N_rays, N_importance)
|
306 |
-
else:
|
307 |
-
u = torch.rand(N_rays, N_importance, device=bins.device)
|
308 |
-
u = u.contiguous()
|
309 |
-
|
310 |
-
inds = torch.searchsorted(cdf, u, right=True)
|
311 |
-
below = torch.clamp_min(inds-1, 0)
|
312 |
-
above = torch.clamp_max(inds, N_samples_)
|
313 |
-
|
314 |
-
inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
|
315 |
-
cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
|
316 |
-
bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
|
317 |
-
|
318 |
-
denom = cdf_g[...,1]-cdf_g[...,0]
|
319 |
-
denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
|
320 |
-
# anyway, therefore any value for it is fine (set to 1 here)
|
321 |
-
|
322 |
-
samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
|
323 |
-
return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/__init__.py
DELETED
File without changes
|
src/utils/camera_util.py
DELETED
@@ -1,111 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
|
6 |
-
def pad_camera_extrinsics_4x4(extrinsics):
|
7 |
-
if extrinsics.shape[-2] == 4:
|
8 |
-
return extrinsics
|
9 |
-
padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
|
10 |
-
if extrinsics.ndim == 3:
|
11 |
-
padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
|
12 |
-
extrinsics = torch.cat([extrinsics, padding], dim=-2)
|
13 |
-
return extrinsics
|
14 |
-
|
15 |
-
|
16 |
-
def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
|
17 |
-
"""
|
18 |
-
Create OpenGL camera extrinsics from camera locations and look-at position.
|
19 |
-
|
20 |
-
camera_position: (M, 3) or (3,)
|
21 |
-
look_at: (3)
|
22 |
-
up_world: (3)
|
23 |
-
return: (M, 3, 4) or (3, 4)
|
24 |
-
"""
|
25 |
-
# by default, looking at the origin and world up is z-axis
|
26 |
-
if look_at is None:
|
27 |
-
look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
|
28 |
-
if up_world is None:
|
29 |
-
up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
|
30 |
-
if camera_position.ndim == 2:
|
31 |
-
look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
|
32 |
-
up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
|
33 |
-
|
34 |
-
# OpenGL camera: z-backward, x-right, y-up
|
35 |
-
z_axis = camera_position - look_at
|
36 |
-
z_axis = F.normalize(z_axis, dim=-1).float()
|
37 |
-
x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
|
38 |
-
x_axis = F.normalize(x_axis, dim=-1).float()
|
39 |
-
y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
|
40 |
-
y_axis = F.normalize(y_axis, dim=-1).float()
|
41 |
-
|
42 |
-
extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
|
43 |
-
extrinsics = pad_camera_extrinsics_4x4(extrinsics)
|
44 |
-
return extrinsics
|
45 |
-
|
46 |
-
|
47 |
-
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
|
48 |
-
azimuths = np.deg2rad(azimuths)
|
49 |
-
elevations = np.deg2rad(elevations)
|
50 |
-
|
51 |
-
xs = radius * np.cos(elevations) * np.cos(azimuths)
|
52 |
-
ys = radius * np.cos(elevations) * np.sin(azimuths)
|
53 |
-
zs = radius * np.sin(elevations)
|
54 |
-
|
55 |
-
cam_locations = np.stack([xs, ys, zs], axis=-1)
|
56 |
-
cam_locations = torch.from_numpy(cam_locations).float()
|
57 |
-
|
58 |
-
c2ws = center_looking_at_camera_pose(cam_locations)
|
59 |
-
return c2ws
|
60 |
-
|
61 |
-
|
62 |
-
def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
|
63 |
-
# M: number of circular views
|
64 |
-
# radius: camera dist to center
|
65 |
-
# elevation: elevation degrees of the camera
|
66 |
-
# return: (M, 4, 4)
|
67 |
-
assert M > 0 and radius > 0
|
68 |
-
|
69 |
-
elevation = np.deg2rad(elevation)
|
70 |
-
|
71 |
-
camera_positions = []
|
72 |
-
for i in range(M):
|
73 |
-
azimuth = 2 * np.pi * i / M
|
74 |
-
x = radius * np.cos(elevation) * np.cos(azimuth)
|
75 |
-
y = radius * np.cos(elevation) * np.sin(azimuth)
|
76 |
-
z = radius * np.sin(elevation)
|
77 |
-
camera_positions.append([x, y, z])
|
78 |
-
camera_positions = np.array(camera_positions)
|
79 |
-
camera_positions = torch.from_numpy(camera_positions).float()
|
80 |
-
extrinsics = center_looking_at_camera_pose(camera_positions)
|
81 |
-
return extrinsics
|
82 |
-
|
83 |
-
|
84 |
-
def FOV_to_intrinsics(fov, device='cpu'):
|
85 |
-
"""
|
86 |
-
Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
|
87 |
-
Note the intrinsics are returned as normalized by image size, rather than in pixel units.
|
88 |
-
Assumes principal point is at image center.
|
89 |
-
"""
|
90 |
-
focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
|
91 |
-
intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
|
92 |
-
return intrinsics
|
93 |
-
|
94 |
-
|
95 |
-
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
|
96 |
-
"""
|
97 |
-
Get the input camera parameters.
|
98 |
-
"""
|
99 |
-
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
|
100 |
-
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
|
101 |
-
|
102 |
-
c2ws = spherical_camera_pose(azimuths, elevations, radius)
|
103 |
-
c2ws = c2ws.float().flatten(-2)
|
104 |
-
|
105 |
-
Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
|
106 |
-
|
107 |
-
extrinsics = c2ws[:, :12]
|
108 |
-
intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
|
109 |
-
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
|
110 |
-
|
111 |
-
return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/infer_util.py
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import imageio
|
3 |
-
import rembg
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import PIL.Image
|
7 |
-
from PIL import Image
|
8 |
-
from typing import Any
|
9 |
-
|
10 |
-
|
11 |
-
def remove_background(image: PIL.Image.Image,
|
12 |
-
rembg_session: Any = None,
|
13 |
-
force: bool = False,
|
14 |
-
**rembg_kwargs,
|
15 |
-
) -> PIL.Image.Image:
|
16 |
-
do_remove = True
|
17 |
-
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
18 |
-
do_remove = False
|
19 |
-
do_remove = do_remove or force
|
20 |
-
if do_remove:
|
21 |
-
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
22 |
-
return image
|
23 |
-
|
24 |
-
|
25 |
-
def resize_foreground(
|
26 |
-
image: PIL.Image.Image,
|
27 |
-
ratio: float,
|
28 |
-
) -> PIL.Image.Image:
|
29 |
-
image = np.array(image)
|
30 |
-
assert image.shape[-1] == 4
|
31 |
-
alpha = np.where(image[..., 3] > 0)
|
32 |
-
y1, y2, x1, x2 = (
|
33 |
-
alpha[0].min(),
|
34 |
-
alpha[0].max(),
|
35 |
-
alpha[1].min(),
|
36 |
-
alpha[1].max(),
|
37 |
-
)
|
38 |
-
# crop the foreground
|
39 |
-
fg = image[y1:y2, x1:x2]
|
40 |
-
# pad to square
|
41 |
-
size = max(fg.shape[0], fg.shape[1])
|
42 |
-
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
43 |
-
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
44 |
-
new_image = np.pad(
|
45 |
-
fg,
|
46 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
47 |
-
mode="constant",
|
48 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
49 |
-
)
|
50 |
-
|
51 |
-
# compute padding according to the ratio
|
52 |
-
new_size = int(new_image.shape[0] / ratio)
|
53 |
-
# pad to size, double side
|
54 |
-
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
55 |
-
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
56 |
-
new_image = np.pad(
|
57 |
-
new_image,
|
58 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
59 |
-
mode="constant",
|
60 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
61 |
-
)
|
62 |
-
new_image = PIL.Image.fromarray(new_image)
|
63 |
-
return new_image
|
64 |
-
|
65 |
-
|
66 |
-
def images_to_video(
|
67 |
-
images: torch.Tensor,
|
68 |
-
output_path: str,
|
69 |
-
fps: int = 30,
|
70 |
-
) -> None:
|
71 |
-
# images: (N, C, H, W)
|
72 |
-
video_dir = os.path.dirname(output_path)
|
73 |
-
video_name = os.path.basename(output_path)
|
74 |
-
os.makedirs(video_dir, exist_ok=True)
|
75 |
-
|
76 |
-
frames = []
|
77 |
-
for i in range(len(images)):
|
78 |
-
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
79 |
-
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
|
80 |
-
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
|
81 |
-
assert frame.min() >= 0 and frame.max() <= 255, \
|
82 |
-
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
|
83 |
-
frames.append(frame)
|
84 |
-
imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/mesh_util.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import xatlas
|
11 |
-
import trimesh
|
12 |
-
import cv2
|
13 |
-
import numpy as np
|
14 |
-
import nvdiffrast.torch as dr
|
15 |
-
from PIL import Image
|
16 |
-
|
17 |
-
|
18 |
-
def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
|
19 |
-
|
20 |
-
pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
|
21 |
-
facenp_fx3 = facenp_fx3[:, [2, 1, 0]]
|
22 |
-
|
23 |
-
mesh = trimesh.Trimesh(
|
24 |
-
vertices=pointnp_px3,
|
25 |
-
faces=facenp_fx3,
|
26 |
-
vertex_colors=colornp_px3,
|
27 |
-
)
|
28 |
-
mesh.export(fpath, 'obj')
|
29 |
-
|
30 |
-
|
31 |
-
def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
|
32 |
-
|
33 |
-
pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
|
34 |
-
|
35 |
-
mesh = trimesh.Trimesh(
|
36 |
-
vertices=pointnp_px3,
|
37 |
-
faces=facenp_fx3,
|
38 |
-
vertex_colors=colornp_px3,
|
39 |
-
)
|
40 |
-
mesh.export(fpath, 'glb')
|
41 |
-
|
42 |
-
|
43 |
-
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
|
44 |
-
import os
|
45 |
-
fol, na = os.path.split(fname)
|
46 |
-
na, _ = os.path.splitext(na)
|
47 |
-
|
48 |
-
matname = '%s/%s.mtl' % (fol, na)
|
49 |
-
fid = open(matname, 'w')
|
50 |
-
fid.write('newmtl material_0\n')
|
51 |
-
fid.write('Kd 1 1 1\n')
|
52 |
-
fid.write('Ka 0 0 0\n')
|
53 |
-
fid.write('Ks 0.4 0.4 0.4\n')
|
54 |
-
fid.write('Ns 10\n')
|
55 |
-
fid.write('illum 2\n')
|
56 |
-
fid.write('map_Kd %s.png\n' % na)
|
57 |
-
fid.close()
|
58 |
-
####
|
59 |
-
|
60 |
-
fid = open(fname, 'w')
|
61 |
-
fid.write('mtllib %s.mtl\n' % na)
|
62 |
-
|
63 |
-
for pidx, p in enumerate(pointnp_px3):
|
64 |
-
pp = p
|
65 |
-
fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
|
66 |
-
|
67 |
-
for pidx, p in enumerate(tcoords_px2):
|
68 |
-
pp = p
|
69 |
-
fid.write('vt %f %f\n' % (pp[0], pp[1]))
|
70 |
-
|
71 |
-
fid.write('usemtl material_0\n')
|
72 |
-
for i, f in enumerate(facenp_fx3):
|
73 |
-
f1 = f + 1
|
74 |
-
f2 = facetex_fx3[i] + 1
|
75 |
-
fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
|
76 |
-
fid.close()
|
77 |
-
|
78 |
-
# save texture map
|
79 |
-
lo, hi = 0, 1
|
80 |
-
img = np.asarray(texmap_hxwx3, dtype=np.float32)
|
81 |
-
img = (img - lo) * (255 / (hi - lo))
|
82 |
-
img = img.clip(0, 255)
|
83 |
-
mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
|
84 |
-
mask = (mask <= 3.0).astype(np.float32)
|
85 |
-
kernel = np.ones((3, 3), 'uint8')
|
86 |
-
dilate_img = cv2.dilate(img, kernel, iterations=1)
|
87 |
-
img = img * (1 - mask) + dilate_img * mask
|
88 |
-
img = img.clip(0, 255).astype(np.uint8)
|
89 |
-
Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png')
|
90 |
-
|
91 |
-
|
92 |
-
def loadobj(meshfile):
|
93 |
-
v = []
|
94 |
-
f = []
|
95 |
-
meshfp = open(meshfile, 'r')
|
96 |
-
for line in meshfp.readlines():
|
97 |
-
data = line.strip().split(' ')
|
98 |
-
data = [da for da in data if len(da) > 0]
|
99 |
-
if len(data) != 4:
|
100 |
-
continue
|
101 |
-
if data[0] == 'v':
|
102 |
-
v.append([float(d) for d in data[1:]])
|
103 |
-
if data[0] == 'f':
|
104 |
-
data = [da.split('/')[0] for da in data]
|
105 |
-
f.append([int(d) for d in data[1:]])
|
106 |
-
meshfp.close()
|
107 |
-
|
108 |
-
# torch need int64
|
109 |
-
facenp_fx3 = np.array(f, dtype=np.int64) - 1
|
110 |
-
pointnp_px3 = np.array(v, dtype=np.float32)
|
111 |
-
return pointnp_px3, facenp_fx3
|
112 |
-
|
113 |
-
|
114 |
-
def loadobjtex(meshfile):
|
115 |
-
v = []
|
116 |
-
vt = []
|
117 |
-
f = []
|
118 |
-
ft = []
|
119 |
-
meshfp = open(meshfile, 'r')
|
120 |
-
for line in meshfp.readlines():
|
121 |
-
data = line.strip().split(' ')
|
122 |
-
data = [da for da in data if len(da) > 0]
|
123 |
-
if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)):
|
124 |
-
continue
|
125 |
-
if data[0] == 'v':
|
126 |
-
assert len(data) == 4
|
127 |
-
|
128 |
-
v.append([float(d) for d in data[1:]])
|
129 |
-
if data[0] == 'vt':
|
130 |
-
if len(data) == 3 or len(data) == 4:
|
131 |
-
vt.append([float(d) for d in data[1:3]])
|
132 |
-
if data[0] == 'f':
|
133 |
-
data = [da.split('/') for da in data]
|
134 |
-
if len(data) == 4:
|
135 |
-
f.append([int(d[0]) for d in data[1:]])
|
136 |
-
ft.append([int(d[1]) for d in data[1:]])
|
137 |
-
elif len(data) == 5:
|
138 |
-
idx1 = [1, 2, 3]
|
139 |
-
data1 = [data[i] for i in idx1]
|
140 |
-
f.append([int(d[0]) for d in data1])
|
141 |
-
ft.append([int(d[1]) for d in data1])
|
142 |
-
idx2 = [1, 3, 4]
|
143 |
-
data2 = [data[i] for i in idx2]
|
144 |
-
f.append([int(d[0]) for d in data2])
|
145 |
-
ft.append([int(d[1]) for d in data2])
|
146 |
-
meshfp.close()
|
147 |
-
|
148 |
-
# torch need int64
|
149 |
-
facenp_fx3 = np.array(f, dtype=np.int64) - 1
|
150 |
-
ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
|
151 |
-
pointnp_px3 = np.array(v, dtype=np.float32)
|
152 |
-
uvs = np.array(vt, dtype=np.float32)
|
153 |
-
return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
|
154 |
-
|
155 |
-
|
156 |
-
# ==============================================================================================
|
157 |
-
def interpolate(attr, rast, attr_idx, rast_db=None):
|
158 |
-
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
|
159 |
-
|
160 |
-
|
161 |
-
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
|
162 |
-
vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
|
163 |
-
|
164 |
-
# Convert to tensors
|
165 |
-
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
|
166 |
-
|
167 |
-
uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
|
168 |
-
mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
|
169 |
-
# mesh_v_tex. ture
|
170 |
-
uv_clip = uvs[None, ...] * 2.0 - 1.0
|
171 |
-
|
172 |
-
# pad to four component coordinate
|
173 |
-
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
|
174 |
-
|
175 |
-
# rasterize
|
176 |
-
rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
|
177 |
-
|
178 |
-
# Interpolate world space position
|
179 |
-
gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
|
180 |
-
mask = rast[..., 3:4] > 0
|
181 |
-
return uvs, mesh_tex_idx, gb_pos, mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/train_util.py
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
|
3 |
-
|
4 |
-
def count_params(model, verbose=False):
|
5 |
-
total_params = sum(p.numel() for p in model.parameters())
|
6 |
-
if verbose:
|
7 |
-
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
8 |
-
return total_params
|
9 |
-
|
10 |
-
|
11 |
-
def instantiate_from_config(config):
|
12 |
-
if not "target" in config:
|
13 |
-
if config == '__is_first_stage__':
|
14 |
-
return None
|
15 |
-
elif config == "__is_unconditional__":
|
16 |
-
return None
|
17 |
-
raise KeyError("Expected key `target` to instantiate.")
|
18 |
-
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
19 |
-
|
20 |
-
|
21 |
-
def get_obj_from_str(string, reload=False):
|
22 |
-
module, cls = string.rsplit(".", 1)
|
23 |
-
if reload:
|
24 |
-
module_imp = importlib.import_module(module)
|
25 |
-
importlib.reload(module_imp)
|
26 |
-
return getattr(importlib.import_module(module, package=None), cls)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
zero123plus/pipeline.py
DELETED
@@ -1,406 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, Optional
|
2 |
-
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
3 |
-
from diffusers.schedulers import KarrasDiffusionSchedulers
|
4 |
-
|
5 |
-
import numpy
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.utils.checkpoint
|
9 |
-
import torch.distributed
|
10 |
-
import transformers
|
11 |
-
from collections import OrderedDict
|
12 |
-
from PIL import Image
|
13 |
-
from torchvision import transforms
|
14 |
-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
15 |
-
|
16 |
-
import diffusers
|
17 |
-
from diffusers import (
|
18 |
-
AutoencoderKL,
|
19 |
-
DDPMScheduler,
|
20 |
-
DiffusionPipeline,
|
21 |
-
EulerAncestralDiscreteScheduler,
|
22 |
-
UNet2DConditionModel,
|
23 |
-
ImagePipelineOutput
|
24 |
-
)
|
25 |
-
from diffusers.image_processor import VaeImageProcessor
|
26 |
-
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
|
27 |
-
from diffusers.utils.import_utils import is_xformers_available
|
28 |
-
|
29 |
-
|
30 |
-
def to_rgb_image(maybe_rgba: Image.Image):
|
31 |
-
if maybe_rgba.mode == 'RGB':
|
32 |
-
return maybe_rgba
|
33 |
-
elif maybe_rgba.mode == 'RGBA':
|
34 |
-
rgba = maybe_rgba
|
35 |
-
img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
|
36 |
-
img = Image.fromarray(img, 'RGB')
|
37 |
-
img.paste(rgba, mask=rgba.getchannel('A'))
|
38 |
-
return img
|
39 |
-
else:
|
40 |
-
raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
41 |
-
|
42 |
-
|
43 |
-
class ReferenceOnlyAttnProc(torch.nn.Module):
|
44 |
-
def __init__(
|
45 |
-
self,
|
46 |
-
chained_proc,
|
47 |
-
enabled=False,
|
48 |
-
name=None
|
49 |
-
) -> None:
|
50 |
-
super().__init__()
|
51 |
-
self.enabled = enabled
|
52 |
-
self.chained_proc = chained_proc
|
53 |
-
self.name = name
|
54 |
-
|
55 |
-
def __call__(
|
56 |
-
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
57 |
-
mode="w", ref_dict: dict = None, is_cfg_guidance = False
|
58 |
-
) -> Any:
|
59 |
-
if encoder_hidden_states is None:
|
60 |
-
encoder_hidden_states = hidden_states
|
61 |
-
if self.enabled and is_cfg_guidance:
|
62 |
-
res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
|
63 |
-
hidden_states = hidden_states[1:]
|
64 |
-
encoder_hidden_states = encoder_hidden_states[1:]
|
65 |
-
if self.enabled:
|
66 |
-
if mode == 'w':
|
67 |
-
ref_dict[self.name] = encoder_hidden_states
|
68 |
-
elif mode == 'r':
|
69 |
-
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
|
70 |
-
elif mode == 'm':
|
71 |
-
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
|
72 |
-
else:
|
73 |
-
assert False, mode
|
74 |
-
res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
|
75 |
-
if self.enabled and is_cfg_guidance:
|
76 |
-
res = torch.cat([res0, res])
|
77 |
-
return res
|
78 |
-
|
79 |
-
|
80 |
-
class RefOnlyNoisedUNet(torch.nn.Module):
|
81 |
-
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
|
82 |
-
super().__init__()
|
83 |
-
self.unet = unet
|
84 |
-
self.train_sched = train_sched
|
85 |
-
self.val_sched = val_sched
|
86 |
-
|
87 |
-
unet_lora_attn_procs = dict()
|
88 |
-
for name, _ in unet.attn_processors.items():
|
89 |
-
if torch.__version__ >= '2.0':
|
90 |
-
default_attn_proc = AttnProcessor2_0()
|
91 |
-
elif is_xformers_available():
|
92 |
-
default_attn_proc = XFormersAttnProcessor()
|
93 |
-
else:
|
94 |
-
default_attn_proc = AttnProcessor()
|
95 |
-
unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
|
96 |
-
default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
|
97 |
-
)
|
98 |
-
unet.set_attn_processor(unet_lora_attn_procs)
|
99 |
-
|
100 |
-
def __getattr__(self, name: str):
|
101 |
-
try:
|
102 |
-
return super().__getattr__(name)
|
103 |
-
except AttributeError:
|
104 |
-
return getattr(self.unet, name)
|
105 |
-
|
106 |
-
def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
|
107 |
-
if is_cfg_guidance:
|
108 |
-
encoder_hidden_states = encoder_hidden_states[1:]
|
109 |
-
class_labels = class_labels[1:]
|
110 |
-
self.unet(
|
111 |
-
noisy_cond_lat, timestep,
|
112 |
-
encoder_hidden_states=encoder_hidden_states,
|
113 |
-
class_labels=class_labels,
|
114 |
-
cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
|
115 |
-
**kwargs
|
116 |
-
)
|
117 |
-
|
118 |
-
def forward(
|
119 |
-
self, sample, timestep, encoder_hidden_states, class_labels=None,
|
120 |
-
*args, cross_attention_kwargs,
|
121 |
-
down_block_res_samples=None, mid_block_res_sample=None,
|
122 |
-
**kwargs
|
123 |
-
):
|
124 |
-
cond_lat = cross_attention_kwargs['cond_lat']
|
125 |
-
is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
|
126 |
-
noise = torch.randn_like(cond_lat)
|
127 |
-
if self.training:
|
128 |
-
noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
|
129 |
-
noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
|
130 |
-
else:
|
131 |
-
noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
|
132 |
-
noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
|
133 |
-
ref_dict = {}
|
134 |
-
self.forward_cond(
|
135 |
-
noisy_cond_lat, timestep,
|
136 |
-
encoder_hidden_states, class_labels,
|
137 |
-
ref_dict, is_cfg_guidance, **kwargs
|
138 |
-
)
|
139 |
-
weight_dtype = self.unet.dtype
|
140 |
-
return self.unet(
|
141 |
-
sample, timestep,
|
142 |
-
encoder_hidden_states, *args,
|
143 |
-
class_labels=class_labels,
|
144 |
-
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
|
145 |
-
down_block_additional_residuals=[
|
146 |
-
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
147 |
-
] if down_block_res_samples is not None else None,
|
148 |
-
mid_block_additional_residual=(
|
149 |
-
mid_block_res_sample.to(dtype=weight_dtype)
|
150 |
-
if mid_block_res_sample is not None else None
|
151 |
-
),
|
152 |
-
**kwargs
|
153 |
-
)
|
154 |
-
|
155 |
-
|
156 |
-
def scale_latents(latents):
|
157 |
-
latents = (latents - 0.22) * 0.75
|
158 |
-
return latents
|
159 |
-
|
160 |
-
|
161 |
-
def unscale_latents(latents):
|
162 |
-
latents = latents / 0.75 + 0.22
|
163 |
-
return latents
|
164 |
-
|
165 |
-
|
166 |
-
def scale_image(image):
|
167 |
-
image = image * 0.5 / 0.8
|
168 |
-
return image
|
169 |
-
|
170 |
-
|
171 |
-
def unscale_image(image):
|
172 |
-
image = image / 0.5 * 0.8
|
173 |
-
return image
|
174 |
-
|
175 |
-
|
176 |
-
class DepthControlUNet(torch.nn.Module):
|
177 |
-
def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
|
178 |
-
super().__init__()
|
179 |
-
self.unet = unet
|
180 |
-
if controlnet is None:
|
181 |
-
self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
|
182 |
-
else:
|
183 |
-
self.controlnet = controlnet
|
184 |
-
DefaultAttnProc = AttnProcessor2_0
|
185 |
-
if is_xformers_available():
|
186 |
-
DefaultAttnProc = XFormersAttnProcessor
|
187 |
-
self.controlnet.set_attn_processor(DefaultAttnProc())
|
188 |
-
self.conditioning_scale = conditioning_scale
|
189 |
-
|
190 |
-
def __getattr__(self, name: str):
|
191 |
-
try:
|
192 |
-
return super().__getattr__(name)
|
193 |
-
except AttributeError:
|
194 |
-
return getattr(self.unet, name)
|
195 |
-
|
196 |
-
def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
|
197 |
-
cross_attention_kwargs = dict(cross_attention_kwargs)
|
198 |
-
control_depth = cross_attention_kwargs.pop('control_depth')
|
199 |
-
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
200 |
-
sample,
|
201 |
-
timestep,
|
202 |
-
encoder_hidden_states=encoder_hidden_states,
|
203 |
-
controlnet_cond=control_depth,
|
204 |
-
conditioning_scale=self.conditioning_scale,
|
205 |
-
return_dict=False,
|
206 |
-
)
|
207 |
-
return self.unet(
|
208 |
-
sample,
|
209 |
-
timestep,
|
210 |
-
encoder_hidden_states=encoder_hidden_states,
|
211 |
-
down_block_res_samples=down_block_res_samples,
|
212 |
-
mid_block_res_sample=mid_block_res_sample,
|
213 |
-
cross_attention_kwargs=cross_attention_kwargs
|
214 |
-
)
|
215 |
-
|
216 |
-
|
217 |
-
class ModuleListDict(torch.nn.Module):
|
218 |
-
def __init__(self, procs: dict) -> None:
|
219 |
-
super().__init__()
|
220 |
-
self.keys = sorted(procs.keys())
|
221 |
-
self.values = torch.nn.ModuleList(procs[k] for k in self.keys)
|
222 |
-
|
223 |
-
def __getitem__(self, key):
|
224 |
-
return self.values[self.keys.index(key)]
|
225 |
-
|
226 |
-
|
227 |
-
class SuperNet(torch.nn.Module):
|
228 |
-
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
229 |
-
super().__init__()
|
230 |
-
state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
|
231 |
-
self.layers = torch.nn.ModuleList(state_dict.values())
|
232 |
-
self.mapping = dict(enumerate(state_dict.keys()))
|
233 |
-
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
|
234 |
-
|
235 |
-
# .processor for unet, .self_attn for text encoder
|
236 |
-
self.split_keys = [".processor", ".self_attn"]
|
237 |
-
|
238 |
-
# we add a hook to state_dict() and load_state_dict() so that the
|
239 |
-
# naming fits with `unet.attn_processors`
|
240 |
-
def map_to(module, state_dict, *args, **kwargs):
|
241 |
-
new_state_dict = {}
|
242 |
-
for key, value in state_dict.items():
|
243 |
-
num = int(key.split(".")[1]) # 0 is always "layers"
|
244 |
-
new_key = key.replace(f"layers.{num}", module.mapping[num])
|
245 |
-
new_state_dict[new_key] = value
|
246 |
-
|
247 |
-
return new_state_dict
|
248 |
-
|
249 |
-
def remap_key(key, state_dict):
|
250 |
-
for k in self.split_keys:
|
251 |
-
if k in key:
|
252 |
-
return key.split(k)[0] + k
|
253 |
-
return key.split('.')[0]
|
254 |
-
|
255 |
-
def map_from(module, state_dict, *args, **kwargs):
|
256 |
-
all_keys = list(state_dict.keys())
|
257 |
-
for key in all_keys:
|
258 |
-
replace_key = remap_key(key, state_dict)
|
259 |
-
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
|
260 |
-
state_dict[new_key] = state_dict[key]
|
261 |
-
del state_dict[key]
|
262 |
-
|
263 |
-
self._register_state_dict_hook(map_to)
|
264 |
-
self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
265 |
-
|
266 |
-
|
267 |
-
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
|
268 |
-
tokenizer: transformers.CLIPTokenizer
|
269 |
-
text_encoder: transformers.CLIPTextModel
|
270 |
-
vision_encoder: transformers.CLIPVisionModelWithProjection
|
271 |
-
|
272 |
-
feature_extractor_clip: transformers.CLIPImageProcessor
|
273 |
-
unet: UNet2DConditionModel
|
274 |
-
scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
|
275 |
-
|
276 |
-
vae: AutoencoderKL
|
277 |
-
ramping: nn.Linear
|
278 |
-
|
279 |
-
feature_extractor_vae: transformers.CLIPImageProcessor
|
280 |
-
|
281 |
-
depth_transforms_multi = transforms.Compose([
|
282 |
-
transforms.ToTensor(),
|
283 |
-
transforms.Normalize([0.5], [0.5])
|
284 |
-
])
|
285 |
-
|
286 |
-
def __init__(
|
287 |
-
self,
|
288 |
-
vae: AutoencoderKL,
|
289 |
-
text_encoder: CLIPTextModel,
|
290 |
-
tokenizer: CLIPTokenizer,
|
291 |
-
unet: UNet2DConditionModel,
|
292 |
-
scheduler: KarrasDiffusionSchedulers,
|
293 |
-
vision_encoder: transformers.CLIPVisionModelWithProjection,
|
294 |
-
feature_extractor_clip: CLIPImageProcessor,
|
295 |
-
feature_extractor_vae: CLIPImageProcessor,
|
296 |
-
ramping_coefficients: Optional[list] = None,
|
297 |
-
safety_checker=None,
|
298 |
-
):
|
299 |
-
DiffusionPipeline.__init__(self)
|
300 |
-
|
301 |
-
self.register_modules(
|
302 |
-
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
|
303 |
-
unet=unet, scheduler=scheduler, safety_checker=None,
|
304 |
-
vision_encoder=vision_encoder,
|
305 |
-
feature_extractor_clip=feature_extractor_clip,
|
306 |
-
feature_extractor_vae=feature_extractor_vae
|
307 |
-
)
|
308 |
-
self.register_to_config(ramping_coefficients=ramping_coefficients)
|
309 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
310 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
311 |
-
|
312 |
-
def prepare(self):
|
313 |
-
train_sched = DDPMScheduler.from_config(self.scheduler.config)
|
314 |
-
if isinstance(self.unet, UNet2DConditionModel):
|
315 |
-
self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
|
316 |
-
|
317 |
-
def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
|
318 |
-
self.prepare()
|
319 |
-
self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
|
320 |
-
return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
|
321 |
-
|
322 |
-
def encode_condition_image(self, image: torch.Tensor):
|
323 |
-
image = self.vae.encode(image).latent_dist.sample()
|
324 |
-
return image
|
325 |
-
|
326 |
-
@torch.no_grad()
|
327 |
-
def __call__(
|
328 |
-
self,
|
329 |
-
image: Image.Image = None,
|
330 |
-
prompt = "",
|
331 |
-
*args,
|
332 |
-
num_images_per_prompt: Optional[int] = 1,
|
333 |
-
guidance_scale=4.0,
|
334 |
-
depth_image: Image.Image = None,
|
335 |
-
output_type: Optional[str] = "pil",
|
336 |
-
width=640,
|
337 |
-
height=960,
|
338 |
-
num_inference_steps=28,
|
339 |
-
return_dict=True,
|
340 |
-
**kwargs
|
341 |
-
):
|
342 |
-
self.prepare()
|
343 |
-
if image is None:
|
344 |
-
raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
|
345 |
-
assert not isinstance(image, torch.Tensor)
|
346 |
-
image = to_rgb_image(image)
|
347 |
-
image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
|
348 |
-
image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
|
349 |
-
if depth_image is not None and hasattr(self.unet, "controlnet"):
|
350 |
-
depth_image = to_rgb_image(depth_image)
|
351 |
-
depth_image = self.depth_transforms_multi(depth_image).to(
|
352 |
-
device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
|
353 |
-
)
|
354 |
-
image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
|
355 |
-
image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
|
356 |
-
cond_lat = self.encode_condition_image(image)
|
357 |
-
if guidance_scale > 1:
|
358 |
-
negative_lat = self.encode_condition_image(torch.zeros_like(image))
|
359 |
-
cond_lat = torch.cat([negative_lat, cond_lat])
|
360 |
-
encoded = self.vision_encoder(image_2, output_hidden_states=False)
|
361 |
-
global_embeds = encoded.image_embeds
|
362 |
-
global_embeds = global_embeds.unsqueeze(-2)
|
363 |
-
|
364 |
-
if hasattr(self, "encode_prompt"):
|
365 |
-
encoder_hidden_states = self.encode_prompt(
|
366 |
-
prompt,
|
367 |
-
self.device,
|
368 |
-
num_images_per_prompt,
|
369 |
-
False
|
370 |
-
)[0]
|
371 |
-
else:
|
372 |
-
encoder_hidden_states = self._encode_prompt(
|
373 |
-
prompt,
|
374 |
-
self.device,
|
375 |
-
num_images_per_prompt,
|
376 |
-
False
|
377 |
-
)
|
378 |
-
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
|
379 |
-
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
|
380 |
-
cak = dict(cond_lat=cond_lat)
|
381 |
-
if hasattr(self.unet, "controlnet"):
|
382 |
-
cak['control_depth'] = depth_image
|
383 |
-
latents: torch.Tensor = super().__call__(
|
384 |
-
None,
|
385 |
-
*args,
|
386 |
-
cross_attention_kwargs=cak,
|
387 |
-
guidance_scale=guidance_scale,
|
388 |
-
num_images_per_prompt=num_images_per_prompt,
|
389 |
-
prompt_embeds=encoder_hidden_states,
|
390 |
-
num_inference_steps=num_inference_steps,
|
391 |
-
output_type='latent',
|
392 |
-
width=width,
|
393 |
-
height=height,
|
394 |
-
**kwargs
|
395 |
-
).images
|
396 |
-
latents = unscale_latents(latents)
|
397 |
-
if not output_type == "latent":
|
398 |
-
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
|
399 |
-
else:
|
400 |
-
image = latents
|
401 |
-
|
402 |
-
image = self.image_processor.postprocess(image, output_type=output_type)
|
403 |
-
if not return_dict:
|
404 |
-
return (image,)
|
405 |
-
|
406 |
-
return ImagePipelineOutput(images=image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|