Spaces: Running on Zero
cavargas10 committed
Commit 1f30907
Parent(s): 69b6a88
Upload 56 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +4 -0
- assets/teaser.jpg +0 -0
- custum_3d_diffusion/custum_modules/attention_processors.py +385 -0
- custum_3d_diffusion/custum_modules/unifield_processor.py +459 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py +298 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py +296 -0
- custum_3d_diffusion/modules.py +14 -0
- custum_3d_diffusion/trainings/__init__.py +0 -0
- custum_3d_diffusion/trainings/base.py +208 -0
- custum_3d_diffusion/trainings/config_classes.py +35 -0
- custum_3d_diffusion/trainings/image2image_trainer.py +86 -0
- custum_3d_diffusion/trainings/image2mvimage_trainer.py +139 -0
- custum_3d_diffusion/trainings/utils.py +25 -0
- gradio_app/__init__.py +0 -0
- gradio_app/all_models.py +22 -0
- gradio_app/custom_models/image2mvimage.yaml +63 -0
- gradio_app/custom_models/image2normal.yaml +61 -0
- gradio_app/custom_models/mvimg_prediction.py +59 -0
- gradio_app/custom_models/normal_prediction.py +28 -0
- gradio_app/custom_models/utils.py +75 -0
- gradio_app/examples/Groot.png +0 -0
- gradio_app/examples/aaa.png +0 -0
- gradio_app/examples/abma.png +0 -0
- gradio_app/examples/akun.png +0 -0
- gradio_app/examples/anya.png +0 -0
- gradio_app/examples/bag.png +3 -0
- gradio_app/examples/ex1.png +3 -0
- gradio_app/examples/ex2.png +0 -0
- gradio_app/examples/ex3.jpg +0 -0
- gradio_app/examples/ex4.png +0 -0
- gradio_app/examples/generated_1715761545_frame0.png +0 -0
- gradio_app/examples/generated_1715762357_frame0.png +0 -0
- gradio_app/examples/generated_1715763329_frame0.png +0 -0
- gradio_app/examples/hatsune_miku.png +0 -0
- gradio_app/examples/princess-large.png +0 -0
- gradio_app/gradio_3dgen.py +85 -0
- gradio_app/gradio_3dgen_steps.py +87 -0
- gradio_app/gradio_local.py +76 -0
- gradio_app/utils.py +112 -0
- mesh_reconstruction/func.py +133 -0
- mesh_reconstruction/opt.py +190 -0
- mesh_reconstruction/recon.py +59 -0
- mesh_reconstruction/refine.py +80 -0
- mesh_reconstruction/remesh.py +361 -0
- mesh_reconstruction/render.py +159 -0
- package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl +3 -0
- package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl +3 -0
- scripts/all_typing.py +42 -0
- scripts/load_onnx.py +48 -0
- scripts/mesh_init.py +132 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+gradio_app/examples/bag.png filter=lfs diff=lfs merge=lfs -text
+gradio_app/examples/ex1.png filter=lfs diff=lfs merge=lfs -text
+package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
assets/teaser.jpg
ADDED
custum_3d_diffusion/custum_modules/attention_processors.py
ADDED
@@ -0,0 +1,385 @@
from typing import Any, Dict, Optional
import torch
from diffusers.models.attention_processor import Attention

def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
    if norm_type == "layernorm":
        norm = torch.nn.LayerNorm(hidden_states_dim)
    else:
        norm = torch.nn.Identity()
    attention = Attention(
        query_dim=hidden_states_dim,
        heads=8,
        dim_head=hidden_states_dim // 8,
        bias=True,
    )
    # NOTE: xformers 0.22 does not support batchsize >= 4096
    attention.xformers_not_supported = True # hacky solution
    return norm, attention

class ExtraAttnProc(torch.nn.Module):
    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None,
        mode='extract',
        with_proj_in=False,
        proj_in_dim=768,
        target_dim=None,
        pixel_wise_crosspond=False,
        norm_type="none", # none or layernorm
        crosspond_effect_on="all", # all or first
        crosspond_chain_pos="parralle", # before or parralle or after
        simple_3d=False,
        views=4,
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name
        self.mode = mode
        self.with_proj_in=with_proj_in
        self.proj_in_dim = proj_in_dim
        self.target_dim = target_dim or proj_in_dim
        self.hidden_states_dim = self.target_dim
        self.pixel_wise_crosspond = pixel_wise_crosspond
        self.crosspond_effect_on = crosspond_effect_on
        self.crosspond_chain_pos = crosspond_chain_pos
        self.views = views
        self.simple_3d = simple_3d
        if self.with_proj_in and self.enabled:
            self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
            if self.target_dim == self.proj_in_dim:
                self.in_linear.weight.data = torch.eye(proj_in_dim)
        else:
            self.in_linear = None
        if self.pixel_wise_crosspond and self.enabled:
            self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)

    def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
        hidden_states = self.crosspond_norm(hidden_states)

        batch, L, D = hidden_states.shape
        assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
        # to -> batch * L, 1, D
        hidden_states = hidden_states.reshape(batch * L, 1, D)
        other_states = other_states.reshape(batch * L, 1, D)
        hidden_states_catted = other_states
        hidden_states = self.crosspond_attention(
            hidden_states,
            encoder_hidden_states=hidden_states_catted,
        )
        return hidden_states.reshape(batch, L, D)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
        ref_dict: dict = None, mode=None, **kwargs
    ) -> Any:
        if not self.enabled:
            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        assert ref_dict is not None
        if (mode or self.mode) == 'extract':
            ref_dict[self.name] = hidden_states
            hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
            if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
                ref_dict[self.name] = hidden_states1
            return hidden_states1
        elif (mode or self.mode) == 'inject':
            ref_state = ref_dict.pop(self.name)
            if self.with_proj_in:
                ref_state = self.in_linear(ref_state)

            B, L, D = ref_state.shape
            if hidden_states.shape[0] == B:
                modalities = 1
                views = 1
            else:
                modalities = hidden_states.shape[0] // B // self.views
                views = self.views
            if self.pixel_wise_crosspond:
                if self.crosspond_effect_on == "all":
                    ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])

                    if self.crosspond_chain_pos == "before":
                        hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)

                    hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)

                    if self.crosspond_chain_pos == "parralle":
                        hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)

                    if self.crosspond_chain_pos == "after":
                        hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
                    return hidden_states1
                else:
                    assert self.crosspond_effect_on == "first"
                    # hidden_states [B * modalities * views, L, D]
                    # ref_state [B, L, D]
                    ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D]

                    def do_paritial_crosspond(hidden_states, ref_state):
                        first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D]
                        hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
                        hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
                        hidden_states2_padded[:, 0] = hidden_states2
                        hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
                        return hidden_states2_padded

                    if self.crosspond_chain_pos == "before":
                        hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)

                    hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D]
                    if self.crosspond_chain_pos == "parralle":
                        hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
                    if self.crosspond_chain_pos == "after":
                        hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
                    return hidden_states1
            elif self.simple_3d:
                B, L, C = encoder_hidden_states.shape
                mv = self.views
                encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
                ref_state = ref_state[:, None]
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
                encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
                encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
                return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
            else:
                ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
                return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        else:
            raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")

def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
    return_dict = torch.nn.ModuleDict()
    proj_in_dim = kwargs.get('proj_in_dim', False)
    kwargs.pop('proj_in_dim', None)

    def recursive_add_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            if "ref_unet" not in (sub_name + name):
                recursive_add_processors(f"{name}.{sub_name}", child)

        if isinstance(module, Attention):
            new_processor = ExtraAttnProc(
                chained_proc=module.get_processor(),
                enabled=enable_filter(f"{name}.processor"),
                name=f"{name}.processor",
                proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
                target_dim=module.cross_attention_dim,
                **kwargs
            )
            module.set_processor(new_processor)
            return_dict[f"{name}.processor".replace(".", "__")] = new_processor

    for name, module in model.named_children():
        recursive_add_processors(name, module)
    return return_dict

def switch_extra_processor(model, enable_filter=lambda x:True):
    def recursive_add_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            recursive_add_processors(f"{name}.{sub_name}", child)

        if isinstance(module, ExtraAttnProc):
            module.enabled = enable_filter(name)

    for name, module in model.named_children():
        recursive_add_processors(name, module)

class multiviewAttnProc(torch.nn.Module):
    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None,
        hidden_states_dim=None,
        chain_pos="parralle", # before or parralle or after
        num_modalities=1,
        views=4,
        base_img_size=64,
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name
        self.hidden_states_dim = hidden_states_dim
        self.num_modalities = num_modalities
        self.views = views
        self.base_img_size = base_img_size
        self.chain_pos = chain_pos
        self.diff_joint_attn = True

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs
    ) -> torch.Tensor:
        if not self.enabled:
            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)

        B, L, C = hidden_states.shape
        mv = self.views
        hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
        hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)

def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
    return_dict = torch.nn.ModuleDict()
    def recursive_add_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            if "ref_unet" not in (sub_name + name):
                recursive_add_processors(f"{name}.{sub_name}", child)

        if isinstance(module, Attention):
            new_processor = multiviewAttnProc(
                chained_proc=module.get_processor(),
                enabled=enable_filter(f"{name}.processor"),
                name=f"{name}.processor",
                hidden_states_dim=module.inner_dim,
                **kwargs
            )
            module.set_processor(new_processor)
            return_dict[f"{name}.processor".replace(".", "__")] = new_processor

    for name, module in model.named_children():
        recursive_add_processors(name, module)

    return return_dict

def switch_multiview_processor(model, enable_filter=lambda x:True):
    def recursive_add_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            recursive_add_processors(f"{name}.{sub_name}", child)

        if isinstance(module, Attention):
            processor = module.get_processor()
            if isinstance(processor, multiviewAttnProc):
                processor.enabled = enable_filter(f"{name}.processor")

    for name, module in model.named_children():
        recursive_add_processors(name, module)

class NNModuleWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

class AttnProcessorSwitch(torch.nn.Module):
    def __init__(
        self,
        proc_dict: dict,
        enabled_proc="default",
        name=None,
        switch_name="default_switch",
    ):
        super().__init__()
        self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
        self.enabled_proc = enabled_proc
        self.name = name
        self.switch_name = switch_name
        self.choose_module(enabled_proc)

    def choose_module(self, enabled_proc):
        self.enabled_proc = enabled_proc
        assert enabled_proc in self.proc_dict.keys()

    def __call__(
        self,
        *args,
        **kwargs
    ) -> torch.FloatTensor:
        used_proc = self.proc_dict[self.enabled_proc]
        return used_proc(*args, **kwargs)

def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
    return_dict = torch.nn.ModuleDict()
    def recursive_add_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            if "ref_unet" not in (sub_name + name):
                recursive_add_processors(f"{name}.{sub_name}", child)

        if isinstance(module, Attention):
            processor = module.get_processor()
            if module_filter(processor):
                proc_dict = switch_dict_fn(processor)
                new_processor = AttnProcessorSwitch(
                    proc_dict=proc_dict,
                    enabled_proc=enabled_proc,
                    name=f"{name}.processor",
                    switch_name=switch_name,
                )
                module.set_processor(new_processor)
                return_dict[f"{name}.processor".replace(".", "__")] = new_processor

    for name, module in model.named_children():
        recursive_add_processors(name, module)

    return return_dict

def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
    def recursive_change_processors(name: str, module: torch.nn.Module):
        for sub_name, child in module.named_children():
            recursive_change_processors(f"{name}.{sub_name}", child)

        if isinstance(module, Attention):
            processor = module.get_processor()
            if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
                processor.choose_module(enabled_proc)

    for name, module in model.named_children():
        recursive_change_processors(name, module)

########## Hack: Attention fix #############
from diffusers.models.attention import Attention

def forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    **cross_attention_kwargs,
) -> torch.Tensor:
    r"""
    The forward method of the `Attention` class.

    Args:
        hidden_states (`torch.Tensor`):
            The hidden states of the query.
        encoder_hidden_states (`torch.Tensor`, *optional*):
            The hidden states of the encoder.
        attention_mask (`torch.Tensor`, *optional*):
            The attention mask to use. If `None`, no mask is applied.
        **cross_attention_kwargs:
            Additional keyword arguments to pass along to the cross attention.

    Returns:
        `torch.Tensor`: The output of the attention layer.
    """
    # The `Attention` class can call different attention processors / attention functions
    # here we simply pass along all tensors to the selected processor class
    # For standard processors that are defined here, `**cross_attention_kwargs` is empty
    return self.processor(
        self,
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )

Attention.forward = forward
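Editor's note: the processors above hook into diffusers' set_processor/get_processor mechanism: mode='extract' stashes reference self-attention states in a shared ref_dict, and mode='inject' consumes them in a second UNet. The sketch below (not part of this commit) shows how that wiring could look; the checkpoint name, the attn1 filter, and the commented forward calls are illustrative assumptions, not the Space's actual setup.

import torch
from diffusers import UNet2DConditionModel
from custum_3d_diffusion.custum_modules.attention_processors import (
    add_extra_processor, switch_extra_processor,
)

# Two UNets: one encodes the conditioning image, one denoises the target views.
# (Checkpoint name is an assumption for illustration.)
ref_unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
gen_unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Self-attention (attn1) states of the reference UNet are written into ref_dict ...
add_extra_processor(ref_unet, enable_filter=lambda n: n.endswith("attn1.processor"), mode="extract")
# ... and read back as extra key/value context inside the generator UNet.
add_extra_processor(gen_unet, enable_filter=lambda n: n.endswith("attn1.processor"), mode="inject")

ref_dict = {}
# During a denoising step (real tensors omitted here), the dict is threaded
# through cross_attention_kwargs, mirroring unet_forward_hook in unifield_processor.py:
#   ref_unet(cond_latents, t, cond_emb, cross_attention_kwargs=dict(ref_dict=ref_dict))
#   gen_unet(latents, t, emb, cross_attention_kwargs=dict(ref_dict=ref_dict, mode="inject"))

# Processors can later be toggled without rebuilding the UNet:
switch_extra_processor(gen_unet, enable_filter=lambda name: False)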
custum_3d_diffusion/custum_modules/unifield_processor.py
ADDED
@@ -0,0 +1,459 @@
from types import FunctionType
from typing import Any, Dict, List
from diffusers import UNet2DConditionModel
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
from dataclasses import dataclass, field
from diffusers.loaders import IPAdapterMixin
from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch

@dataclass
class AttnConfig:
    """
    * CrossAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), IPAdapter module (achieves conceptual control).
    * SelfAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), Reference Attention module (achieves pixel-level control).
    * Multiview Attention module: Multiview Attention module (achieves multi-view consistency).
    * Cross Modality Attention module: Cross Modality Attention module (achieves multi-modality consistency).

    For setups:
        train_xxx_lr is implemented in the U-Net architecture.
        enable_xxx_lora is implemented in the U-Net architecture.
        enable_xxx_ip is implemented in the processor and U-Net architecture.
        enable_xxx_ref_proj_in is implemented in the processor.
    """
    latent_size: int = 64

    train_lr: float = 0
    # for cross attention
    # 0 learning rate for not training
    train_cross_attn_lr: float = 0
    train_cross_attn_lora_lr: float = 0
    train_cross_attn_ip_lr: float = 0 # 0 for not trained
    init_cross_attn_lora: bool = False
    enable_cross_attn_lora: bool = False
    init_cross_attn_ip: bool = False
    enable_cross_attn_ip: bool = False
    cross_attn_lora_rank: int = 64 # 0 for not enabled
    cross_attn_lora_only_kv: bool = False
    ipadapter_pretrained_name: str = "h94/IP-Adapter"
    ipadapter_subfolder_name: str = "models"
    ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
    ipadapter_effect_on: str = "all" # all, first

    # for self attention
    train_self_attn_lr: float = 0
    train_self_attn_lora_lr: float = 0
    init_self_attn_lora: bool = False
    enable_self_attn_lora: bool = False
    self_attn_lora_rank: int = 64
    self_attn_lora_only_kv: bool = False

    train_self_attn_ref_lr: float = 0
    train_ref_unet_lr: float = 0
    init_self_attn_ref: bool = False
    enable_self_attn_ref: bool = False
    self_attn_ref_other_model_name: str = ""
    self_attn_ref_position: str = "attn1"
    self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn
    self_attn_ref_chain_pos: str = "parralle" # before or parralle or after
    self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn
    self_attn_ref_zero_init: bool = True
    use_simple3d_attn: bool = False

    # for multiview attention
    init_multiview_attn: bool = False
    enable_multiview_attn: bool = False
    multiview_attn_position: str = "attn1"
    multiview_chain_pose: str = "parralle" # before or parralle or after
    num_modalities: int = 1
    use_mv_joint_attn: bool = False

    # for unet
    init_unet_path: str = "runwayml/stable-diffusion-v1-5"
    init_num_cls_label: int = 0 # for initialize
    cls_labels: List[int] = field(default_factory=lambda: [])
    cls_label_type: str = "embedding"
    cat_condition: bool = False # cat condition to input

class Configurable:
    attn_config: AttnConfig

    def set_config(self, attn_config: AttnConfig):
        raise NotImplementedError()

    def update_config(self, attn_config: AttnConfig):
        self.attn_config = attn_config

    def do_set_config(self, attn_config: AttnConfig):
        self.set_config(attn_config)
        for name, module in self.named_modules():
            if isinstance(module, Configurable):
                if hasattr(module, "do_set_config"):
                    module.do_set_config(attn_config)
                else:
                    print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
                    module.attn_config = attn_config

    def do_update_config(self, attn_config: AttnConfig):
        self.update_config(attn_config)
        for name, module in self.named_modules():
            if isinstance(module, Configurable):
                if hasattr(module, "do_update_config"):
                    module.do_update_config(attn_config)
                else:
                    print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
                    module.attn_config = attn_config

from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet
class UnifieldWrappedUNet(UNet2DConditionModel):
    forward_hook: FunctionType

    def forward(self, *args, **kwargs):
        if hasattr(self, 'forward_hook'):
            return self.forward_hook(super().forward, *args, **kwargs)
        return super().forward(*args, **kwargs)


class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
    unet: UNet2DConditionModel

    cls_embedding_param_dict = {}
    cross_attn_lora_param_dict = {}
    self_attn_lora_param_dict = {}
    cross_attn_param_dict = {}
    self_attn_param_dict = {}
    ipadapter_param_dict = {}
    ref_attn_param_dict = {}
    ref_unet_param_dict = {}
    multiview_attn_param_dict = {}
    other_param_dict = {}

    rev_param_name_mapping = {}

    class_labels = []
    def set_class_labels(self, class_labels: torch.Tensor):
        if self.attn_config.init_num_cls_label != 0:
            self.class_labels = class_labels.to(self.unet.device).long()

    def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
        super().__init__()
        self.weight_dtype = weight_dtype
        self.set_config(init_config)

    def enable_xformers_memory_efficient_attention(self):
        self.unet.enable_xformers_memory_efficient_attention
        def recursive_add_processors(name: str, module: torch.nn.Module):
            for sub_name, child in module.named_children():
                recursive_add_processors(f"{name}.{sub_name}", child)

            if isinstance(module, Attention):
                if hasattr(module, 'xformers_not_supported'):
                    return
                old_processor = module.get_processor()
                if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
                    module.set_use_memory_efficient_attention_xformers(True)

        for name, module in self.unet.named_children():
            recursive_add_processors(name, module)

    def __getattr__(self, name: str) -> Any:
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    # --- for IPAdapterMixin

    def register_modules(self, **kwargs):
        for name, module in kwargs.items():
            # set models
            setattr(self, name, module)

    def register_to_config(self, **kwargs):
        pass

    def unload_ip_adapter(self):
        raise NotImplementedError()

    # --- for Configurable

    def get_refunet(self):
        if self.attn_config.self_attn_ref_other_model_name == "self":
            return self.unet
        else:
            return self.unet.ref_unet

    def set_config(self, attn_config: AttnConfig):
        self.attn_config = attn_config

        unet_type = UnifieldWrappedUNet
        # class_embed_type = "projection" for 'camera'
        # class_embed_type = None for 'embedding'
        unet_kwargs = {}
        if attn_config.init_num_cls_label > 0:
            if attn_config.cls_label_type == "embedding":
                unet_kwargs = {
                    "num_class_embeds": attn_config.init_num_cls_label,
                    "device_map": None,
                    "low_cpu_mem_usage": False,
                    "class_embed_type": None,
                }
            else:
                raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")

        self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
            attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
            **unet_kwargs
        )
        assert isinstance(self.unet, UnifieldWrappedUNet)
        self.unet.forward_hook = self.unet_forward_hook

        if self.attn_config.cat_condition:
            # double in_channels
            if self.unet.config.in_channels != 8:
                self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
                # repeate unet.conv_in weight twice
                doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
                doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
                doubled_conv_in.bias.data = self.unet.conv_in.bias.data
                self.unet.conv_in = doubled_conv_in

        used_param_ids = set()

        if attn_config.init_cross_attn_lora:
            # setup lora
            from peft import LoraConfig
            from peft.utils import get_peft_model_state_dict
            if attn_config.cross_attn_lora_only_kv:
                target_modules=["attn2.to_k", "attn2.to_v"]
            else:
                target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
            lora_config: LoraConfig = LoraConfig(
                r=attn_config.cross_attn_lora_rank,
                lora_alpha=attn_config.cross_attn_lora_rank,
                init_lora_weights="gaussian",
                target_modules=target_modules,
            )
            adapter_name="cross_attn_lora"
            self.unet.add_adapter(lora_config, adapter_name=adapter_name)
            # update cross_attn_lora_param_dict
            self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
            used_param_ids.update(self.cross_attn_lora_param_dict.keys())

        if attn_config.init_self_attn_lora:
            # setup lora
            from peft import LoraConfig
            if attn_config.self_attn_lora_only_kv:
                target_modules=["attn1.to_k", "attn1.to_v"]
            else:
                target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
            lora_config: LoraConfig = LoraConfig(
                r=attn_config.self_attn_lora_rank,
                lora_alpha=attn_config.self_attn_lora_rank,
                init_lora_weights="gaussian",
                target_modules=target_modules,
            )
            adapter_name="self_attn_lora"
            self.unet.add_adapter(lora_config, adapter_name=adapter_name)
            # update cross_self_lora_param_dict
            self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
            used_param_ids.update(self.self_attn_lora_param_dict.keys())

        if attn_config.init_num_cls_label != 0:
            self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
            used_param_ids.update(self.cls_embedding_param_dict.keys())
            self.set_class_labels(torch.tensor(attn_config.cls_labels).long())

        if attn_config.init_cross_attn_ip:
            self.image_encoder = None
            # setup ipadapter
            self.load_ip_adapter(
                attn_config.ipadapter_pretrained_name,
                subfolder=attn_config.ipadapter_subfolder_name,
                weight_name=attn_config.ipadapter_weight_name
            )
            # warp ip_adapter_attn_proc with switch
            from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
            add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
            # update ipadapter_param_dict
            # weights are in attention processors and unet.encoder_hid_proj
            self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.ipadapter_param_dict.keys())
            print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
            for name, processor in self.unet.attn_processors.items():
                if hasattr(processor, "to_k_ip"):
                    self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
            print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))

        ref_unet = None
        if attn_config.init_self_attn_ref:
            # setup reference attention processor
            if attn_config.self_attn_ref_other_model_name == "self":
                raise NotImplementedError("self reference is not fully implemented")
            else:
                ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
                    attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
                )
                ref_unet.to(self.unet.device)
                if self.attn_config.train_ref_unet_lr == 0:
                    ref_unet.eval()
                    ref_unet.requires_grad_(False)
                else:
                    ref_unet.train()

            add_extra_processor(
                model=ref_unet,
                enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
                mode='extract',
                with_proj_in=False,
                pixel_wise_crosspond=False,
            )
            # NOTE: Here require cross_attention_dim in two unet's self attention should be the same
            processor_dict = add_extra_processor(
                model=self.unet,
                enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
                mode='inject',
                with_proj_in=False,
                pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
                crosspond_effect_on=attn_config.self_attn_ref_effect_on,
                crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
                simple_3d=attn_config.use_simple3d_attn,
            )
            self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
            if attn_config.self_attn_ref_chain_pos != "after":
                # pop untrainable paramters
                for name, param in ref_unet.named_parameters():
                    if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
                        self.ref_unet_param_dict.pop(id(param))
            used_param_ids.update(self.ref_unet_param_dict.keys())
            # update ref_attn_param_dict
            self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.ref_attn_param_dict.keys())

        if attn_config.init_multiview_attn:
            processor_dict = add_multiview_processor(
                model = self.unet,
                enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
                num_modalities = attn_config.num_modalities,
                base_img_size = attn_config.latent_size,
                chain_pos = attn_config.multiview_chain_pose,
            )
            # update multiview_attn_param_dict
            self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
            used_param_ids.update(self.multiview_attn_param_dict.keys())

        # initialize cross_attn_param_dict parameters
        self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
        used_param_ids.update(self.cross_attn_param_dict.keys())

        # initialize self_attn_param_dict parameters
        self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
        used_param_ids.update(self.self_attn_param_dict.keys())

        # initialize other_param_dict parameters
        self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}

        if ref_unet is not None:
            self.unet.ref_unet = ref_unet

        self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}

        self.update_config(attn_config, force_update=True)
        return self.unet

    _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]

    def update_config(self, attn_config: AttnConfig, force_update=False):
        assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"

        need_to_update = False
        # update cls_labels
        for key in self._attn_keys_to_update:
            if getattr(self.attn_config, key) != getattr(attn_config, key):
                need_to_update = True
                break
        if not force_update and not need_to_update:
            return

        self.set_class_labels(torch.tensor(attn_config.cls_labels).long())

        # setup loras
        if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
            if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
                cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0
                self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0
                self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
            else:
                self.unet.disable_adapters()

        # setup ipadapter
        if self.attn_config.init_cross_attn_ip:
            if attn_config.enable_cross_attn_ip:
                change_switch(self.unet, "ipadapter_switch", "ipadapter")
            else:
                change_switch(self.unet, "ipadapter_switch", "default")

        # setup reference attention processor
        if self.attn_config.init_self_attn_ref:
            if attn_config.enable_self_attn_ref:
                switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
            else:
                switch_extra_processor(self.unet, enable_filter=lambda name: False)

        # setup multiview attention processor
        if self.attn_config.init_multiview_attn:
            if attn_config.enable_multiview_attn:
                switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
            else:
                switch_multiview_processor(self.unet, enable_filter=lambda name: False)

        # update cls_labels
        for key in self._attn_keys_to_update:
            setattr(self.attn_config, key, getattr(attn_config, key))

    def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
        if class_labels is None and len(self.class_labels) > 0:
            class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
        elif self.attn_config.init_num_cls_label != 0:
            assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
        if class_labels is not None:
            if self.attn_config.cls_label_type == "embedding":
                pass
            else:
                raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
        if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
            # NOTE: extra step, extract condition
            ref_dict = {}
            ref_unet = self.get_refunet().to(sample.device)
            assert condition_latents is not None
            if self.attn_config.self_attn_ref_other_model_name == "self":
                raise NotImplementedError()
            else:
                with torch.no_grad():
                    cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
                    if timestep.dim() == 0:
                        cond_timestep = timestep
                    else:
                        cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
                    ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
            # NOTE: extra step, inject condition
            # Predict the noise residual and compute loss
            if cross_attention_kwargs is None:
                cross_attention_kwargs = {}
            cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
        elif condition_latents is not None:
            if not hasattr(self, 'condition_latents_raised'):
                print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
                self.condition_latents_raised = True

        if self.attn_config.init_cross_attn_ip:
            raise NotImplementedError()

        if self.attn_config.cat_condition:
            assert condition_latents is not None
            B = condition_latents.shape[0]
            cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
            sample = torch.cat([sample, cat_latents], dim=1)

        return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
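Editor's note: a minimal construction sketch (not part of this commit) for the AttnConfig / ConfigurableUNet2DConditionModel pair defined above. Every field value, checkpoint name, and label list below is an illustrative assumption rather than the Space's actual configuration.

import torch
from dataclasses import replace
from custum_3d_diffusion.custum_modules.unifield_processor import (
    AttnConfig, ConfigurableUNet2DConditionModel,
)

# Illustrative config: multiview self-attention plus a frozen reference UNet.
config = AttnConfig(
    init_unet_path="lambdalabs/sd-image-variations-diffusers",          # assumed checkpoint
    latent_size=64,
    init_multiview_attn=True, enable_multiview_attn=True,
    init_self_attn_ref=True, enable_self_attn_ref=True,
    self_attn_ref_other_model_name="lambdalabs/sd-image-variations-diffusers",  # assumed
    init_num_cls_label=8, cls_labels=[0, 1, 2, 3],                      # illustrative class labels
)

wrapper = ConfigurableUNet2DConditionModel(config, weight_dtype=torch.float16)
unet = wrapper.unet  # UnifieldWrappedUNet with unet_forward_hook installed

# Flags listed in _attn_keys_to_update can be flipped later without reloading weights;
# update_config compares against the stored config, so pass a modified copy:
wrapper.update_config(replace(config, enable_multiview_attn=False))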
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py
ADDED
@@ -0,0 +1,298 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modified by Wuvin


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection



class StableDiffusionImageCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset=None,
        noisy_cond_latents=False,
    ):
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker
        )
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        self.noisy_cond_latents = noisy_cond_latents
        self.register_to_config(noisy_cond_latents=noisy_cond_latents)

    def encode_latents(self, image, device, dtype, height, width):
        # support batchsize > 1
        if isinstance(image, Image.Image):
            image = [image]
        image = [img.convert("RGB") for img in image]
        images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
        latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
        if self.latents_offset is not None:
            return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
        else:
            return latents

    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # NOTE: the same as original code
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        upper_left_feature: bool = False,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.

        Examples:

        ```py
        from diffusers import StableDiffusionImageVariationPipeline
        from PIL import Image
        from io import BytesIO
        import requests

        pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
        )
        pipe = pipe.to("cuda")

        url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"

        response = requests.get(url)
        image = Image.open(BytesIO(response.content)).convert("RGB")

        out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
        out["images"][0].save("result.jpg")
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        if isinstance(image, Image.Image) and upper_left_feature:
            # only use the first one of four images
            emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
        else:
            emb_image = image

        image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
        cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.out_channels
+
latents = self.prepare_latents(
|
226 |
+
batch_size * num_images_per_prompt,
|
227 |
+
num_channels_latents,
|
228 |
+
height,
|
229 |
+
width,
|
230 |
+
image_embeddings.dtype,
|
231 |
+
device,
|
232 |
+
generator,
|
233 |
+
latents,
|
234 |
+
)
|
235 |
+
|
236 |
+
# 6. Prepare extra step kwargs.
|
237 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
238 |
+
|
239 |
+
# 7. Denoising loop
|
240 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
241 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
242 |
+
for i, t in enumerate(timesteps):
|
243 |
+
if self.noisy_cond_latents:
|
244 |
+
raise ValueError("Noisy condition latents is not recommended.")
|
245 |
+
else:
|
246 |
+
noisy_cond_latents = cond_latents
|
247 |
+
|
248 |
+
noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
|
249 |
+
# expand the latents if we are doing classifier free guidance
|
250 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
251 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
252 |
+
|
253 |
+
# predict the noise residual
|
254 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
|
255 |
+
|
256 |
+
# perform guidance
|
257 |
+
if do_classifier_free_guidance:
|
258 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
259 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
260 |
+
|
261 |
+
# compute the previous noisy sample x_t -> x_t-1
|
262 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
263 |
+
|
264 |
+
# call the callback, if provided
|
265 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
266 |
+
progress_bar.update()
|
267 |
+
if callback is not None and i % callback_steps == 0:
|
268 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
269 |
+
callback(step_idx, t, latents)
|
270 |
+
|
271 |
+
self.maybe_free_model_hooks()
|
272 |
+
|
273 |
+
if self.latents_offset is not None:
|
274 |
+
latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
275 |
+
|
276 |
+
if not output_type == "latent":
|
277 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
278 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
|
279 |
+
else:
|
280 |
+
image = latents
|
281 |
+
has_nsfw_concept = None
|
282 |
+
|
283 |
+
if has_nsfw_concept is None:
|
284 |
+
do_denormalize = [True] * image.shape[0]
|
285 |
+
else:
|
286 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
287 |
+
|
288 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
289 |
+
|
290 |
+
self.maybe_free_model_hooks()
|
291 |
+
|
292 |
+
if not return_dict:
|
293 |
+
return (image, has_nsfw_concept)
|
294 |
+
|
295 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
296 |
+
|
297 |
+
if __name__ == "__main__":
|
298 |
+
pass
|
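As a usage note: the trainers later in this commit build this `StableDiffusionImageCustomPipeline` via `construct_pipeline`, and `BasicTrainer.save_model` can persist it with `save_pretrained`. Below is a minimal sketch of reloading and calling such a saved pipeline; the directory name, input file names and image sizes are assumptions for illustration, not paths shipped in this commit.

```python
import torch
from PIL import Image
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline

# Hypothetical directory previously written by pipeline.save_pretrained(...) / BasicTrainer.save_model.
pipe = StableDiffusionImageCustomPipeline.from_pretrained(
    "./ckpt/saved_img2img_pipeline", safety_checker=None, torch_dtype=torch.bfloat16
).to("cuda")

cond = Image.open("input.png").convert("RGB")  # any RGB conditioning image
out = pipe(cond, height=512, width=512, height_cond=512, width_cond=512, guidance_scale=2.0)
out.images[0].save("result.png")
```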
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py
ADDED
@@ -0,0 +1,296 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modified by Wuvin


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection



class StableDiffusionImage2MVCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset=None,
        noisy_cond_latents=False,
        condition_offset=True,
    ):
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker
        )
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        if noisy_cond_latents:
            raise NotImplementedError("Noisy condition latents not supported Now.")
        self.condition_offset = condition_offset
        self.register_to_config(condition_offset=condition_offset)

    def encode_latents(self, image: Image.Image, device, dtype, height, width):
        images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
        # NOTE: .mode() for condition
        latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
        if self.latents_offset is not None and self.condition_offset:
            return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
        else:
            return latents

    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # NOTE: the same as original code
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.

        Examples:

        ```py
        from diffusers import StableDiffusionImageVariationPipeline
        from PIL import Image
        from io import BytesIO
        import requests

        pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
        )
        pipe = pipe.to("cuda")

        url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"

        response = requests.get(url)
        image = Image.open(BytesIO(response.content)).convert("RGB")

        out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
        out["images"][0].save("result.jpg")
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, Image.Image):
            batch_size = 1
        elif len(image) == 1:
            image = image[0]
            batch_size = 1
        else:
            raise NotImplementedError()
        # elif isinstance(image, list):
        #     batch_size = len(image)
        # else:
        #     batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        emb_image = image

        image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
        cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
        cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
        image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
        if do_classifier_free_guidance:
            image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.out_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )


        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        self.maybe_free_model_hooks()

        if self.latents_offset is not None:
            latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


if __name__ == "__main__":
    pass
custum_3d_diffusion/modules.py
ADDED
@@ -0,0 +1,14 @@
__modules__ = {}

def register(name):
    def decorator(cls):
        __modules__[name] = cls
        return cls

    return decorator


def find(name):
    return __modules__[name]

from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
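A quick illustration of the registry pattern above, using a made-up key (the real keys registered in this commit are "image2mvimage_trainer" and "image2image_trainer"):

```python
from custum_3d_diffusion.modules import register, find

@register("toy_trainer")   # hypothetical key, for illustration only
class ToyTrainer:
    pass

# find() returns the class stored under that key, as gradio_app/custom_models/utils.py
# does with the trainer_type strings from the YAML configs.
assert find("toy_trainer") is ToyTrainer
```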
custum_3d_diffusion/trainings/__init__.py
ADDED
File without changes
custum_3d_diffusion/trainings/base.py
ADDED
@@ -0,0 +1,208 @@
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb

from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig

class BasicTrainer(torch.nn.Module, abc.ABC):
    accelerator: Accelerator
    logger: MultiProcessAdapter
    unet: ConfigurableUNet2DConditionModel
    train_dataloader: torch.utils.data.DataLoader
    test_dataset: torch.utils.data.Dataset
    attn_config: AttnConfig

    @dataclass
    class TrainerConfig:
        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""

        attn_config: dict = field(default_factory=dict)
        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        resolution: str = "1024"
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4

        max_train_steps: int = -1   # -1 means infinity, otherwise [0, max_train_steps)
        training_step_interval: int = 1  # train on step i*interval, stop at max_train_steps
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None  # For dataset related operations and validation stuff
        train_batch_size: int = 1

        validation_interval: int = 5000
        debug: bool = False

    cfg: TrainerConfig  # only enable_xxx is used

    def __init__(
        self,
        accelerator: Accelerator,
        logger: MultiProcessAdapter,
        unet: ConfigurableUNet2DConditionModel,
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        super().__init__()
        self.index = index  # index in all trainers
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        self.ext_logs = {}
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None
        self.validate_trainer_config()
        self.configure()

    def get_HW(self):
        resolution = json.loads(self.cfg.resolution)
        if isinstance(resolution, int):
            H = W = resolution
        elif isinstance(resolution, list):
            H, W = resolution
        return H, W

    def unet_update(self):
        self.unet.update_config(self.attn_config)

    def validate_trainer_config(self):
        pass

    def is_train_finished(self, current_step):
        assert isinstance(self.cfg.max_train_steps, int)
        return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps

    def next_train_step(self, current_step):
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        pass

    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        pass

    def load_dataset(self):
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
        pass

    @abc.abstractmethod
    def forward_step(
        self,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        input a batch
        return a loss
        """
        self.unet_update()
        pass

    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        pass

    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """
        For inference time forward.
        """
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        pass

    def do_validation(
        self,
        shared_modules,
        unet,
        global_step,
    ):
        self.unet_update()
        self.logger.info("Running validation... ")
        pipeline = self.construct_pipeline(shared_modules, unet)
        pipeline.set_progress_bar_config(disable=True)
        titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
            elif tracker.name == "wandb":
                [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title]  # inplace operation
                tracker.log({"validation": [
                    wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
                    for i, image in enumerate(images)]})
            else:
                self.logger.warn(f"image logging not implemented for {tracker.name}")
        del pipeline
        torch.cuda.empty_cache()
        return images


    @torch.no_grad()
    def log_validation(
        self,
        shared_modules,
        unet,
        global_step,
        force=False
    ):
        if self.accelerator.is_main_process:
            for tracker in self.accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(self.ext_logs)
                    self.ext_logs = {}
        if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
            self.unet_update()
            if self.accelerator.is_main_process:
                self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)

    def save_model(self, unwrap_unet, shared_modules, save_dir):
        if self.accelerator.is_main_process:
            pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
            pipeline.save_pretrained(save_dir)
            self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")

    def save_debug_info(self, save_name="debug", **kwargs):
        if self.cfg.debug:
            to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
            import pickle
            import os
            if os.path.exists(f"{save_name}.pkl"):
                for i in range(100):
                    if not os.path.exists(f"{save_name}_v{i}.pkl"):
                        save_name = f"{save_name}_v{i}"
                        break
            with open(f"{save_name}.pkl", "wb") as f:
                pickle.dump(to_saves, f)
custum_3d_diffusion/trainings/config_classes.py
ADDED
@@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    trainer_type: str = ""
    trainer: dict = field(default_factory=dict)


@dataclass
class ExprimentConfig:
    trainers: List[dict] = field(default_factory=lambda: [])
    init_config: dict = field(default_factory=dict)
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""
    # expriments related parameters
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None
    seed: Optional[int] = None
    max_train_steps: int = 1000000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    use_8bit_adam: bool = False
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    mixed_precision: Optional[str] = None  # ["no", "fp16", "bf16", "fp8"]
    skip_training: bool = False
    debug: bool = False
custum_3d_diffusion/trainings/image2image_trainer.py
ADDED
@@ -0,0 +1,86 @@
import json
import torch
from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
from dataclasses import dataclass

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def get_HW(resolution):
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    return H, W


@register("image2image_trainer")
class Image2ImageTrainer(Image2MVImageTrainer):
    """
    Trainer for simple image to multiview images.
    """
    @dataclass
    class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
        trainer_name: str = "image2image"

    cfg: TrainerConfig

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImageCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')

        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        forward_args = dict(
            num_images_per_prompt=1,
            num_inference_steps=20,
            height=H,
            width=W,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
custum_3d_diffusion/trainings/image2mvimage_trainer.py
ADDED
@@ -0,0 +1,139 @@
import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature

import json
from dataclasses import dataclass
from typing import List, Optional

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

def get_HW(resolution):
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    return H, W

@register("image2mvimage_trainer")
class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for simple image to multiview images.
    """
    @dataclass
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.
        condition_image_resolution: str = "512"
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1
        max_loss_drop: float = 0.
        snr_gamma: float = 5.0
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None
        input_perturbation: float = 0.
        noisy_condition_input: bool = False  # whether to add noise for ref unet input
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False
        linear_beta_schedule: bool = False

    cfg: TrainerConfig

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules

    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')

        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        sub_img_H = H // 2
        num_imgs = H // sub_img_H * W // sub_img_H

        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
custum_3d_diffusion/trainings/utils.py
ADDED
@@ -0,0 +1,25 @@
from omegaconf import DictConfig, OmegaConf


def parse_structured(fields, cfg) -> DictConfig:
    scfg = OmegaConf.structured(fields(**cfg))
    return scfg


def load_config(fields, config, extras=None):
    if extras is not None:
        print("Warning! extra parameter in cli is not verified, may cause erros.")
    if isinstance(config, str):
        cfg = OmegaConf.load(config)
    elif isinstance(config, dict):
        cfg = OmegaConf.create(config)
    elif isinstance(config, DictConfig):
        cfg = config
    else:
        raise NotImplementedError(f"Unsupported config type {type(config)}")
    if extras is not None:
        cli_conf = OmegaConf.from_cli(extras)
        cfg = OmegaConf.merge(cfg, cli_conf)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return parse_structured(fields, cfg)
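`load_config` is how every trainer and attention config in this commit is parsed: a YAML path, plain dict, or DictConfig is resolved and validated against a dataclass via `OmegaConf.structured`. A small sketch with a throwaway dataclass (the dataclass and values are illustrative only, not part of the repo):

```python
from dataclasses import dataclass
from custum_3d_diffusion.trainings.utils import load_config

@dataclass
class ToyConfig:            # hypothetical dataclass, for illustration only
    resolution: str = "256"
    seed: int = 0

cfg = load_config(ToyConfig, {"resolution": "512"})
print(cfg.resolution, cfg.seed)  # -> 512 0  (a DictConfig with the dataclass defaults filled in)
```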
gradio_app/__init__.py
ADDED
File without changes
gradio_app/all_models.py
ADDED
@@ -0,0 +1,22 @@
import torch
from scripts.sd_model_zoo import load_common_sd15_pipe
from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline


class MyModelZoo:
    _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None

    base_model = "runwayml/stable-diffusion-v1-5"

    def __init__(self, base_model=None) -> None:
        if base_model is not None:
            self.base_model = base_model

    @property
    def pipe_disney_controlnet_tile_ipadapter_i2i(self):
        return self._pipe_disney_controlnet_lineart_ipadapter_i2i

    def init_models(self):
        self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)

model_zoo = MyModelZoo()
gradio_app/custom_models/image2mvimage.yaml
ADDED
@@ -0,0 +1,63 @@
pretrained_model_name_or_path: "./ckpt/img2mvimg"
mixed_precision: "bf16"

init_config:
  # enable controls
  enable_cross_attn_lora: False
  enable_cross_attn_ip: False
  enable_self_attn_lora: False
  enable_self_attn_ref: False
  enable_multiview_attn: True

  # for cross attention
  init_cross_attn_lora: False
  init_cross_attn_ip: False
  cross_attn_lora_rank: 256 # 0 for not enabled
  cross_attn_lora_only_kv: False
  ipadapter_pretrained_name: "h94/IP-Adapter"
  ipadapter_subfolder_name: "models"
  ipadapter_weight_name: "ip-adapter_sd15.safetensors"
  ipadapter_effect_on: "all" # all, first

  # for self attention
  init_self_attn_lora: False
  self_attn_lora_rank: 256
  self_attn_lora_only_kv: False

  # for self attention ref
  init_self_attn_ref: False
  self_attn_ref_position: "attn1"
  self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
  self_attn_ref_pixel_wise_crosspond: False
  self_attn_ref_effect_on: "all"

  # for multiview attention
  init_multiview_attn: True
  multiview_attn_position: "attn1"
  use_mv_joint_attn: True
  num_modalities: 1

  # for unet
  init_unet_path: "${pretrained_model_name_or_path}"
  cat_condition: True # cat condition to input

  # for cls embedding
  init_num_cls_label: 8 # for initialize
  cls_labels: [0, 1, 2, 3] # for current task

trainers:
  - trainer_type: "image2mvimage_trainer"
    trainer:
      pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
      attn_config:
        cls_labels: [0, 1, 2, 3] # for current task
        enable_cross_attn_lora: False
        enable_cross_attn_ip: False
        enable_self_attn_lora: False
        enable_self_attn_ref: False
        enable_multiview_attn: True
      resolution: "256"
      condition_image_resolution: "256"
      normal_cls_offset: 4
      condition_image_column_name: "conditioning_image"
      image_column_name: "image"
gradio_app/custom_models/image2normal.yaml
ADDED
@@ -0,0 +1,61 @@
pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
mixed_precision: "bf16"

init_config:
  # enable controls
  enable_cross_attn_lora: False
  enable_cross_attn_ip: False
  enable_self_attn_lora: False
  enable_self_attn_ref: True
  enable_multiview_attn: False

  # for cross attention
  init_cross_attn_lora: False
  init_cross_attn_ip: False
  cross_attn_lora_rank: 512 # 0 for not enabled
  cross_attn_lora_only_kv: False
  ipadapter_pretrained_name: "h94/IP-Adapter"
  ipadapter_subfolder_name: "models"
  ipadapter_weight_name: "ip-adapter_sd15.safetensors"
  ipadapter_effect_on: "all" # all, first

  # for self attention
  init_self_attn_lora: False
  self_attn_lora_rank: 512
  self_attn_lora_only_kv: False

  # for self attention ref
  init_self_attn_ref: True
  self_attn_ref_position: "attn1"
  self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
  self_attn_ref_pixel_wise_crosspond: True
  self_attn_ref_effect_on: "all"

  # for multiview attention
  init_multiview_attn: False
  multiview_attn_position: "attn1"
  num_modalities: 1

  # for unet
  init_unet_path: "${pretrained_model_name_or_path}"
  init_num_cls_label: 0 # for initialize
  cls_labels: [] # for current task

trainers:
  - trainer_type: "image2image_trainer"
    trainer:
      pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
      attn_config:
        cls_labels: [] # for current task
        enable_cross_attn_lora: False
        enable_cross_attn_ip: False
        enable_self_attn_lora: False
        enable_self_attn_ref: True
        enable_multiview_attn: False
      resolution: "512"
      condition_image_resolution: "512"
      condition_image_column_name: "conditioning_image"
      image_column_name: "image"
gradio_app/custom_models/mvimg_prediction.py
ADDED
@@ -0,0 +1,59 @@
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from rembg import remove
from gradio_app.utils import change_rgba_bg, rgba_to_rgb
from gradio_app.custom_models.utils import load_pipeline
from scripts.all_typing import *
from scripts.utils import session, simple_preprocess

training_config = "gradio_app/custom_models/image2mvimage.yaml"
checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"

trainer, pipeline = load_pipeline(training_config, checkpoint_path)

def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
    global pipeline
    pipeline = pipeline.to("cuda")
    if isinstance(img_list, Image.Image):
        img_list = [img_list]
    img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
    ret = []
    for img in img_list:
        images = trainer.pipeline_forward(
            pipeline=pipeline,
            image=img,
            guidance_scale=guidance_scale,
            **kwargs
        ).images
        ret.extend(images)
    return ret


def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
    if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
        # still do remove using rembg, since simple_preprocess requires RGBA image
        print("RGB image not RGBA! still remove bg!")
        remove_bg = True

    if remove_bg:
        input_image = remove(input_image, session=session)

    # make front_pil RGBA with white bg
    input_image = change_rgba_bg(input_image, "white")
    single_image = simple_preprocess(input_image)

    generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None

    rgb_pils = predict(
        single_image,
        generator=generator,
        guidance_scale=guidance_scale,
        width=256,
        height=256,
        num_inference_steps=30,
    )

    return rgb_pils, single_image
gradio_app/custom_models/normal_prediction.py
ADDED
@@ -0,0 +1,28 @@
import sys
from PIL import Image
from gradio_app.utils import rgba_to_rgb, simple_remove
from gradio_app.custom_models.utils import load_pipeline
from scripts.utils import rotate_normals_torch
from scripts.all_typing import *

training_config = "gradio_app/custom_models/image2normal.yaml"
checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
trainer, pipeline = load_pipeline(training_config, checkpoint_path)

def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
    global pipeline
    pipeline = pipeline.to("cuda")

    img_list = image if isinstance(image, list) else [image]
    img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
    images = trainer.pipeline_forward(
        pipeline=pipeline,
        image=img_list,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        **kwargs
    ).images
    images = simple_remove(images)
    if do_rotate and len(images) > 1:
        images = rotate_normals_torch(images, return_types='pil')
    return images
gradio_app/custom_models/utils.py
ADDED
@@ -0,0 +1,75 @@
import torch
from typing import List
from dataclasses import dataclass
from gradio_app.utils import rgba_to_rgb
from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
from custum_3d_diffusion import modules
from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.trainings.utils import load_config


@dataclass
class FakeAccelerator:
    device: torch.device = torch.device("cuda")


def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
    accelerator = FakeAccelerator()
    cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
    init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
    configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
    configurable_unet.enable_xformers_memory_efficient_attention()
    trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
    trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
    return trainers, configurable_unet


from gradio_app.utils import make_image_grid, split_image
def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
    from rembg import remove
    if remove_bg:
        img = remove(img)
    img = rgba_to_rgb(img)
    if merged_image:
        img = split_image(img, rows=2)
    images = function(
        image=img,
        guidance_scale=guidance_scale,
    )
    if len(images) > 1:
        return make_image_grid(images, rows=2)
    else:
        return images[0]


def process_text(trainer, pipeline, img, guidance_scale=2.):
    pipeline.cfg.validation_prompts = [img]
    titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
    return images[0]


def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
    training_config = config_path
    load_from_checkpoint = ckpt_path
    extras = []
    device = "cuda"
    trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
    shared_modules = dict()
    for trainer in trainers:
        shared_modules = trainer.init_shared_modules(shared_modules)

    if load_from_checkpoint is not None:
        state_dict = torch.load(load_from_checkpoint, map_location="cpu")
        configurable_unet.unet.load_state_dict(state_dict, strict=False)
    # Move unet, vae and text_encoder to device and cast to weight_dtype
    configurable_unet.unet.to(device, dtype=weight_dtype)

    pipeline = None
    trainer_out = None
    for trainer in trainers:
        if pipeline_filter(trainer.cfg.trainer_name):
            pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
            pipeline.set_progress_bar_config(disable=False)
            trainer_out = trainer
    pipeline = pipeline.to(device, dtype=weight_dtype)
    return trainer_out, pipeline
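A sketch of how `load_pipeline` is meant to be called (this mirrors the two predictor modules in this commit; the explicit `pipeline_filter` below is just the default shown for illustration):

# hypothetical usage sketch
import torch
from gradio_app.custom_models.utils import load_pipeline

trainer, pipeline = load_pipeline(
    "gradio_app/custom_models/image2normal.yaml",   # config shipped in this commit
    "ckpt/image2normal/unet_state_dict.pth",        # checkpoint path used by normal_prediction.py
    pipeline_filter=lambda name: True,              # last matching trainer wins; True keeps the last one
    weight_dtype=torch.bfloat16,
)
# trainer.pipeline_forward(pipeline=pipeline, image=..., guidance_scale=...) then runs inference,
# exactly as normal_prediction.py does above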
gradio_app/examples/Groot.png
ADDED
gradio_app/examples/aaa.png
ADDED
gradio_app/examples/abma.png
ADDED
gradio_app/examples/akun.png
ADDED
gradio_app/examples/anya.png
ADDED
gradio_app/examples/bag.png
ADDED
gradio_app/examples/ex1.png
ADDED
gradio_app/examples/ex2.png
ADDED
gradio_app/examples/ex3.jpg
ADDED
gradio_app/examples/ex4.png
ADDED
gradio_app/examples/generated_1715761545_frame0.png
ADDED
gradio_app/examples/generated_1715762357_frame0.png
ADDED
gradio_app/examples/generated_1715763329_frame0.png
ADDED
gradio_app/examples/hatsune_miku.png
ADDED
gradio_app/examples/princess-large.png
ADDED
gradio_app/gradio_3dgen.py
ADDED
@@ -0,0 +1,85 @@
import spaces
import os
import gradio as gr
from PIL import Image
from pytorch3d.structures import Meshes
from gradio_app.utils import clean_up
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.refine_lr_to_sr import run_sr_fast
from scripts.utils import save_glb_and_video
# from scripts.multiview_inference import geo_reconstruct
from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3

@spaces.GPU(duration=100)
def run_mv(preview_img, input_processing, seed):
    if preview_img.size[0] <= 512:
        preview_img = run_sr_fast([preview_img])[0]
    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
    return rgb_pils, front_pil

@spaces.GPU(duration=100) # splitting this across multiple GPU calls seems to lead to a `RuntimeError`; until that is fixed, keep the initialization here
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
    if preview_img is None:
        raise gr.Error("The input image is none!")
    if isinstance(preview_img, str):
        preview_img = Image.open(preview_img)

    rgb_pils, front_pil = run_mv(preview_img, input_processing, seed)

    vertices, faces, img_list = geo_reconstruct_part1(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)

    meshes = geo_reconstruct_part2(vertices, faces)

    new_meshes = geo_reconstruct_part3(meshes, img_list)

    vertices = new_meshes.verts_packed()
    vertices = vertices / 2 * 1.35
    vertices[..., [0, 2]] = - vertices[..., [0, 2]]
    new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)

    ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
    return ret_mesh, video

#######################################
def create_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')

            example_folder = os.path.join(os.path.dirname(__file__), "./examples")
            example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
            gr.Examples(
                examples=example_fns,
                inputs=[input_image],
                cache_examples=False,
                label='Examples',
                examples_per_page=12
            )


        with gr.Column(scale=1):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320, camera_position=(90, 90, 2))
            output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)

            input_processing = gr.Checkbox(
                value=True,
                label='Remove Background',
                visible=True,
            )
            do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
            expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
            setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
            render_video = gr.Checkbox(value=False, visible=False, label="generate video")
            fullrunv2_btn = gr.Button('Generate 3D', variant = "primary", interactive=True)

    fullrunv2_btn.click(
        fn = generate3dv2,
        inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
        outputs=[output_mesh, output_video],
        concurrency_id=concurrency_id,
        api_name="generate3dv2",
    ).success(clean_up, api_name=False)
    return input_image
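For completeness, a sketch of driving `generate3dv2` outside the Gradio UI (assumes `model_zoo.init_models()` has been called, as `gradio_local.py` below does before serving; the input path is a placeholder):

# hypothetical usage sketch
from gradio_app.all_models import model_zoo
from gradio_app.gradio_3dgen import generate3dv2

model_zoo.init_models()
mesh_file, video_file = generate3dv2("input.png", True, seed=-1, render_video=False)
# mesh_file / video_file are whatever save_glb_and_video wrote under /tmp/gradio/generated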
gradio_app/gradio_3dgen_steps.py
ADDED
@@ -0,0 +1,87 @@
import gradio as gr
from PIL import Image

from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.utils import make_image_grid, split_image
from scripts.utils import save_glb_and_video

def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
    seed = int(seed)
    if preview_img is None:
        raise gr.Error("preview_img is none.")
    if isinstance(preview_img, str):
        preview_img = Image.open(preview_img)

    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
    rgb_pil = make_image_grid(rgb_pils, rows=2)
    return rgb_pil, front_pil

def concept_to_multiview_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=2):
            preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
            input_processing = gr.Checkbox(
                value=True,
                label='Remove Background',
            )
            seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
            guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
            run_btn = gr.Button('Generate Multiview', interactive=True)
        with gr.Column(scale=3):
            # export mesh display
            output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
            output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
    run_btn.click(
        fn = concept_to_multiview,
        inputs=[preview_img, input_processing, seed, guidance],
        outputs=[output_rgb, output_front],
        concurrency_id=concurrency_id,
        api_name=False,
    )
    return output_rgb, output_front

from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.multiview_inference import geo_reconstruct
def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
    rgb_pils = split_image(rgb_pil, rows=2)
    if normal_pil is not None:
        normal_pil = split_image(normal_pil, rows=2)
    if front_pil is None:
        front_pil = rgb_pils[0]
    new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
    ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
    return ret_mesh

def new_multiview_to_mesh_ui(concurrency_id="wkl"):
    with gr.Row():
        with gr.Column(scale=2):
            rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
            front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)')
            normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)')
            do_refine = gr.Checkbox(
                value=False,
                label='Refine rgb',
                visible=False,
            )
            expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
            run_btn = gr.Button('Generate 3D', interactive=True)
        with gr.Column(scale=3):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
    run_btn.click(
        fn = multiview_to_mesh_v2,
        inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
        outputs=[output_mesh],
        concurrency_id=concurrency_id,
        api_name="multiview_to_mesh",
    )
    return rgb_pil, front_pil, output_mesh


#######################################
def create_step_ui(concurrency_id="wkl"):
    with gr.Tab(label="3D:concept_to_multiview"):
        concept_to_multiview_ui(concurrency_id)
    with gr.Tab(label="3D:new_multiview_to_mesh"):
        new_multiview_to_mesh_ui(concurrency_id)
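The same two steps can be chained programmatically; a sketch (placeholder input path; the normals are predicted internally because `normal_pil` is None):

# hypothetical usage sketch
from PIL import Image
from gradio_app.gradio_3dgen_steps import concept_to_multiview, multiview_to_mesh_v2

grid, front = concept_to_multiview(Image.open("input.png"), True, seed=-1, guidance=1.)
mesh_file = multiview_to_mesh_v2(grid, None, front)   # splits the 2x2 grid back into views, then reconstructs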
gradio_app/gradio_local.py
ADDED
@@ -0,0 +1,76 @@
if __name__ == "__main__":
    import os
    import sys
    sys.path.append(os.curdir)
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ['TRANSFORMERS_OFFLINE']='0'
    os.environ['DIFFUSERS_OFFLINE']='0'
    os.environ['HF_HUB_OFFLINE']='0'
    os.environ['GRADIO_ANALYTICS_ENABLED']='False'
    os.environ['HF_ENDPOINT']='https://hf-mirror.com'
    import torch
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_grad_enabled(False)

import gradio as gr
import argparse

from gradio_app.gradio_3dgen import create_ui as create_3d_ui
# from app.gradio_3dgen_steps import create_step_ui
from gradio_app.all_models import model_zoo


_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
_DESCRIPTION = '''
[Project page](https://wukailu.github.io/Unique3D/)

* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.

* The demo is still under construction, and more features are expected to be implemented soon.
'''

def launch(
    port,
    listen=False,
    share=False,
    gradio_root="",
):
    model_zoo.init_models()

    with gr.Blocks(
        title=_TITLE,
        theme=gr.themes.Monochrome(),
    ) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)
        create_3d_ui("wkl")

    launch_args = {}
    if listen:
        launch_args["server_name"] = "0.0.0.0"

    demo.queue(default_concurrency_limit=1).launch(
        server_port=None if port == 0 else port,
        share=share,
        root_path=gradio_root if gradio_root != "" else None,  # "/myapp"
        **launch_args,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args, extra = parser.parse_known_args()
    parser.add_argument("--listen", action="store_true")
    parser.add_argument("--port", type=int, default=0)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--gradio_root", default="")
    args = parser.parse_args()
    launch(
        args.port,
        listen=args.listen,
        share=args.share,
        gradio_root=args.gradio_root,
    )
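With the models and checkpoints in place, this entry point starts the demo locally with, for example, `python gradio_app/gradio_local.py --port 7860 --listen`; `--share` additionally requests a public Gradio link, and `--gradio_root` sets the root path when the app is served behind a reverse proxy.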
gradio_app/utils.py
ADDED
@@ -0,0 +1,112 @@
import torch
import numpy as np
from PIL import Image
import gc
import numpy as np
import numpy as np
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast

GRADIO_CACHE = "/tmp/gradio/"

def clean_up():
    torch.cuda.empty_cache()
    gc.collect()

def remove_color(arr):
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # calc diffs
    base = arr[0, 0]
    diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
    alpha = (diffs <= 80)

    arr[alpha] = 255
    alpha = ~alpha
    arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
    return arr

def simple_remove(imgs, run_sr=True):
    """Only works for normal"""
    if not isinstance(imgs, list):
        imgs = [imgs]
        single_input = True
    else:
        single_input = False
    if run_sr:
        imgs = run_sr_fast(imgs)
    rets = []
    for img in imgs:
        arr = np.array(img)
        arr = remove_color(arr)
        rets.append(Image.fromarray(arr.astype(np.uint8)))
    if single_input:
        return rets[0]
    return rets

def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    new_image = Image.new("RGBA", rgba.size, bkgd)
    new_image.paste(rgba, (0, 0), rgba)
    new_image = new_image.convert('RGB')
    return new_image

def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    rgb_white = rgba_to_rgb(rgba, bkgd)
    new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
    return new_rgba

def split_image(image, rows=None, cols=None):
    """
    inverse function of make_image_grid
    """
    # image is in square
    if rows is None and cols is None:
        # image.size [W, H]
        rows = 1
        cols = image.size[0] // image.size[1]
        assert cols * image.size[1] == image.size[0]
        subimg_size = image.size[1]
    elif rows is None:
        subimg_size = image.size[0] // cols
        rows = image.size[1] // subimg_size
        assert rows * subimg_size == image.size[1]
    elif cols is None:
        subimg_size = image.size[1] // rows
        cols = image.size[0] // subimg_size
        assert cols * subimg_size == image.size[0]
    else:
        subimg_size = image.size[1] // rows
        assert cols * subimg_size == image.size[0]
    subimgs = []
    for i in range(rows):
        for j in range(cols):
            subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
            subimgs.append(subimg)
    return subimgs

def make_image_grid(images, rows=None, cols=None, resize=None):
    if rows is None and cols is None:
        rows = 1
        cols = len(images)
    if rows is None:
        rows = len(images) // cols
        if len(images) % cols != 0:
            rows += 1
    if cols is None:
        cols = len(images) // rows
        if len(images) % rows != 0:
            cols += 1
    total_imgs = rows * cols
    if total_imgs > len(images):
        images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
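A small sketch of the grid helpers' round-trip behaviour (tile sizes here are illustrative):

# hypothetical usage sketch: make_image_grid and split_image invert each other for equally sized tiles
from PIL import Image
from gradio_app.utils import make_image_grid, split_image

tiles = [Image.new("RGB", (256, 256), color) for color in ("red", "green", "blue", "white")]
grid = make_image_grid(tiles, rows=2)   # 2x2 grid -> one 512x512 image
back = split_image(grid, rows=2)        # four 256x256 tiles again, in row-major order
assert [t.size for t in back] == [(256, 256)] * 4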
mesh_reconstruction/func.py
ADDED
@@ -0,0 +1,133 @@
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import numpy as np
import trimesh
from typing import Tuple

def to_numpy(*args):
    def convert(a):
        if isinstance(a,torch.Tensor):
            return a.detach().cpu().numpy()
        assert a is None or isinstance(a,np.ndarray)
        return a

    return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args)

def laplacian(
        num_verts:int,
        edges: torch.Tensor #E,2
        ) -> torch.Tensor: #sparse V,V
    """create sparse Laplacian matrix"""
    V = num_verts
    E = edges.shape[0]

    #adjacency matrix,
    idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E)
    ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
    A = torch.sparse.FloatTensor(idx, ones, (V, V))

    #degree matrix
    deg = torch.sparse.sum(A, dim=1).to_dense()
    idx = torch.arange(V, device=edges.device)
    idx = torch.stack([idx, idx], dim=0)
    D = torch.sparse.FloatTensor(idx, deg, (V, V))

    return D - A

def _translation(x, y, z, device):
    return torch.tensor([[1., 0, 0, x],
                         [0, 1, 0, y],
                         [0, 0, 1, z],
                         [0, 0, 0, 1]],device=device) #4,4

def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    """
    see https://blog.csdn.net/wodownload2/article/details/85069240/
    """
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    p = torch.zeros([4,4],device=device)
    p[0,0] = 2*n/(r-l)
    p[0,2] = (r+l)/(r-l)
    p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
    p[1,2] = (t+b)/(t-b)
    p[2,2] = -(f+n)/(f-n)
    p[2,3] = -(2*f*n)/(f-n)
    p[3,2] = -1
    return p #4,4

def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    o = torch.zeros([4,4],device=device)
    o[0,0] = 2/(r-l)
    o[0,3] = -(r+l)/(r-l)
    o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
    o[1,3] = -(t+b)/(t-b)
    o[2,2] = -2/(f-n)
    o[2,3] = -(f+n)/(f-n)
    o[3,3] = 1
    return o #4,4

def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
    if r is None:
        r = 1/distance
    A = az_count
    P = pol_count
    C = A * P

    phi = torch.arange(0,A) * (2*torch.pi/A)
    phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
    phi_rot[:,0,2,2] = phi.cos()
    phi_rot[:,0,2,0] = -phi.sin()
    phi_rot[:,0,0,2] = phi.sin()
    phi_rot[:,0,0,0] = phi.cos()

    theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
    theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
    theta_rot[0,:,1,1] = theta.cos()
    theta_rot[0,:,1,2] = -theta.sin()
    theta_rot[0,:,2,1] = theta.sin()
    theta_rot[0,:,2,2] = theta.cos()

    mv = torch.empty((C,4,4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
    mv = _translation(0, 0, -distance, device) @ mv

    return mv, _projection(r,device)

def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
    mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
    if r is None:
        r = 1
    return mv, _orthographic(r,device)

def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]:
    sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
    vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
    faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
    return vertices,faces

from pytorch3d.renderer import (
    FoVOrthographicCameras,
    look_at_view_transform,
)

def get_camera(R, T, focal_length=1 / (2**0.5)):
    focal_length = 1 / focal_length
    camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
    return camera

def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
    R, T = look_at_view_transform(dist, 0, azim_list)
    focal_length = 1 / focal
    return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
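A sketch of how these camera helpers are used by the reconstruction code later in this commit (four azimuths, one polar ring, orthographic projection; assumes a CUDA device, since `device='cuda'` is the default):

# hypothetical usage sketch
from mesh_reconstruction.func import make_star_cameras_orthographic, make_sphere

mv, proj = make_star_cameras_orthographic(4, 1)      # mv: (4,4,4) view matrices, proj: (4,4) orthographic matrix
vertices, faces = make_sphere(level=2, radius=0.5)   # icosphere used as a simple initial mesh
print(mv.shape, proj.shape, vertices.shape, faces.shape)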
mesh_reconstruction/opt.py
ADDED
@@ -0,0 +1,190 @@
# modified from https://github.com/Profactor/continuous-remeshing
import time
import torch
import torch_scatter
from typing import Tuple
from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges

@torch.no_grad()
def remesh(
        vertices_etc:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
        min_edgelen:torch.Tensor, #V
        max_edgelen:torch.Tensor, #V
        flip:bool,
        max_vertices=1e6
        ):

    # dummies
    vertices_etc,faces = prepend_dummies(vertices_etc,faces)
    vertices = vertices_etc[:,:3] #V,3
    nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
    min_edgelen = torch.concat((nan_tensor,min_edgelen))
    max_edgelen = torch.concat((nan_tensor,max_edgelen))

    # collapse
    edges,face_to_edge = calc_edges(faces) #E,2 F,3
    edge_length = calc_edge_length(vertices,edges) #E
    face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
    vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
    face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
    priority = face_collapse.float() + shortness
    vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)

    # split
    if vertices.shape[0]<max_vertices:
        edges,face_to_edge = calc_edges(faces) #E,2 F,3
        vertices = vertices_etc[:,:3] #V,3
        edge_length = calc_edge_length(vertices,edges) #E
        splits = edge_length > max_edgelen[edges].mean(dim=-1)
        vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)

    vertices_etc,faces = pack(vertices_etc,faces)
    vertices = vertices_etc[:,:3]

    if flip:
        edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
        flip_edges(vertices,faces,edges,edge_to_face,with_border=False)

    return remove_dummies(vertices_etc,faces)

def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
    """lerp with adam's bias correction"""
    c_prev = 1-weight**(step-1)
    c = 1-weight**step
    a_weight = weight*c_prev/c
    b_weight = (1-weight)/c
    a.mul_(a_weight).add_(b, alpha=b_weight)


class MeshOptimizer:
    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""

    def __init__(self,
            vertices:torch.Tensor, #V,3
            faces:torch.Tensor, #F,3
            lr=0.3, #learning rate
            betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
            gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
            nu_ref=0.3, #reference velocity for edge length controller
            edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
            edge_len_tol=.5, #edge length tolerance for split and collapse
            gain=.2, #gain value for edge length controller
            laplacian_weight=.02, #for laplacian smoothing/regularization
            ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
            grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
            remesh_interval=1, #larger intervals are faster but with worse mesh quality
            local_edgelen=True, #set to False to use a global scalar reference edge length instead
            ):
        self._vertices = vertices
        self._faces = faces
        self._lr = lr
        self._betas = betas
        self._gammas = gammas
        self._nu_ref = nu_ref
        self._edge_len_lims = edge_len_lims
        self._edge_len_tol = edge_len_tol
        self._gain = gain
        self._laplacian_weight = laplacian_weight
        self._ramp = ramp
        self._grad_lim = grad_lim
        self._remesh_interval = remesh_interval
        self._local_edgelen = local_edgelen
        self._step = 0

        V = self._vertices.shape[0]
        # prepare continuous tensor for all vertex-based data
        self._vertices_etc = torch.zeros([V,9],device=vertices.device)
        self._split_vertices_etc()
        self.vertices.copy_(vertices) #initialize vertices
        self._vertices.requires_grad_()
        self._ref_len.fill_(edge_len_lims[1])

    @property
    def vertices(self):
        return self._vertices

    @property
    def faces(self):
        return self._faces

    def _split_vertices_etc(self):
        self._vertices = self._vertices_etc[:,:3]
        self._m2 = self._vertices_etc[:,3]
        self._nu = self._vertices_etc[:,4]
        self._m1 = self._vertices_etc[:,5:8]
        self._ref_len = self._vertices_etc[:,8]

        with_gammas = any(g!=0 for g in self._gammas)
        self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]

    def zero_grad(self):
        self._vertices.grad = None

    @torch.no_grad()
    def step(self):

        eps = 1e-8

        self._step += 1

        # spatial smoothing
        edges,_ = calc_edges(self._faces) #E,2
        E = edges.shape[0]
        edge_smooth = self._smooth[edges] #E,2,S
        neighbor_smooth = torch.zeros_like(self._smooth) #V,S
        torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)

        #apply optional smoothing of m1,m2,nu
        if self._gammas[0]:
            self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
        if self._gammas[1]:
            self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
        if self._gammas[2]:
            self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])

        #add laplace smoothing to gradients
        laplace = self._vertices - neighbor_smooth[:,:3]
        grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)

        #gradient clipping
        if self._step>1:
            grad_lim = self._m1.abs().mul_(self._grad_lim)
            grad.clamp_(min=-grad_lim,max=grad_lim)

        # moment updates
        lerp_unbiased(self._m1, grad, self._betas[0], self._step)
        lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)

        velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
        speed = velocity.norm(dim=-1) #V

        if self._betas[2]:
            lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
        else:
            self._nu.copy_(speed) #V

        # update vertices
        ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
        self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)

        # update target edge length
        if self._step % self._remesh_interval == 0:
            if self._local_edgelen:
                len_change = (1 + (self._nu - self._nu_ref) * self._gain)
            else:
                len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
            self._ref_len *= len_change
            self._ref_len.clamp_(*self._edge_len_lims)

    def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
        min_edge_len = self._ref_len * (1 - self._edge_len_tol)
        max_edge_len = self._ref_len * (1 + self._edge_len_tol)

        self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)

        self._split_vertices_etc()
        self._vertices.requires_grad_()

        return self._vertices, self._faces
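As the `MeshOptimizer` docstring says, it is driven like a torch optimizer plus a remesh call after each step; a minimal sketch with a stand-in objective (the real losses live in recon.py and refine.py below):

# hypothetical usage sketch
from mesh_reconstruction.func import make_sphere
from mesh_reconstruction.opt import MeshOptimizer

vertices, faces = make_sphere(level=2, radius=1.0)
opt = MeshOptimizer(vertices, faces, edge_len_lims=(0.01, 0.15))
vertices = opt.vertices
for _ in range(50):
    opt.zero_grad()
    loss = (vertices ** 2).sum(dim=-1).mean()   # placeholder loss: pull the vertices towards the origin
    loss.backward()
    opt.step()
    vertices, faces = opt.remesh()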
mesh_reconstruction/recon.py
ADDED
@@ -0,0 +1,59 @@
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from typing import List
from mesh_reconstruction.remesh import calc_vertex_normals
from mesh_reconstruction.opt import MeshOptimizer
from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from scripts.utils import to_py3d_mesh, init_target

def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
    vertices, faces = vertices.to("cuda"), faces.to("cuda")
    assert len(pils) == 4
    mv,proj = make_star_cameras_orthographic(4, 1)
    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
    # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
    # 1. no rotate
    target_images = target_images[[0, 3, 2, 1]]

    # 2. init from coarse mesh
    opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))

    vertices = opt.vertices

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)

        loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()

        t_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight

        # out of box
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        vertices,faces = opt.remesh(poisson=False)

    vertices, faces = vertices.detach().cpu(), faces.detach().cpu()

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
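A sketch of invoking the coarse stage directly (the four views and their paths are placeholders, and the sphere initialization is an assumption; in the full pipeline both come from earlier steps):

# hypothetical usage sketch
from PIL import Image
from mesh_reconstruction.recon import reconstruct_stage1
from mesh_reconstruction.func import make_sphere

pils = [Image.open(f"view_{i}.png") for i in range(4)]   # exactly four views (placeholder paths)
vertices, faces = make_sphere(level=2, radius=0.5)       # stand-in for the coarse initial mesh
mesh = reconstruct_stage1(pils, steps=100, vertices=vertices, faces=faces, return_mesh=True)
# returns a pytorch3d Meshes object when return_mesh=True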
mesh_reconstruction/refine.py
ADDED
@@ -0,0 +1,80 @@
from tqdm import tqdm
from PIL import Image
import torch
from typing import List
from mesh_reconstruction.remesh import calc_vertex_normals
from mesh_reconstruction.opt import MeshOptimizer
from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from scripts.project_mesh import multiview_color_projection, get_cameras_list
from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target

def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True):
    vertices, faces = vertices.to("cuda"), faces.to("cuda")
    if process_inputs:
        vertices = vertices * 2 / 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    poission_steps = []

    assert len(pils) == 4
    mv,proj = make_star_cameras_orthographic(4, 1)
    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
    # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
    # 1. no rotate
    target_images = target_images[[0, 3, 2, 1]]

    # 2. init from coarse mesh
    opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)

    vertices = opt.vertices
    alpha_init = None

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)
        if alpha_init is None:
            alpha_init = images.detach()

        if i < update_warmup or i % update_normal_interval == 0:
            with torch.no_grad():
                py3d_mesh = to_py3d_mesh(vertices, faces, normals)
                cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.)
                _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
                target_normal = target_normal * 2 - 1
                target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
                debug_images = renderer.render(vertices,target_normal,faces)

        d_mask = images[..., -1] > 0.5
        loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_debug_l2 + loss_alpha_target_mask_l2

        # out of box
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        vertices,faces = opt.remesh(poisson=(i in poission_steps))

    vertices, faces = vertices.detach().cpu(), faces.detach().cpu()

    if process_outputs:
        vertices = vertices / 2 * 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
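And a hedged sketch of the refinement stage on its own (placeholder inputs; `process_inputs`/`process_outputs` apply and undo the scale-and-flip between the exported mesh convention and the renderer's convention):

# hypothetical usage sketch
from PIL import Image
from mesh_reconstruction.refine import run_mesh_refine
from mesh_reconstruction.func import make_sphere

pils = [Image.open(f"view_{i}.png") for i in range(4)]          # four target views (placeholder paths)
coarse_vertices, coarse_faces = make_sphere(level=3, radius=0.5)  # stand-in for the stage-1 output tensors
refined = run_mesh_refine(coarse_vertices, coarse_faces, pils, steps=100, return_mesh=True)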
mesh_reconstruction/remesh.py
ADDED
@@ -0,0 +1,361 @@
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import torch.nn.functional as tfunc
import torch_scatter
from typing import Tuple

def prepend_dummies(
        vertices:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
        )->Tuple[torch.Tensor,torch.Tensor]:
    """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
    V,D = vertices.shape
    vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
    faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
    return vertices,faces

def remove_dummies(
        vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
        faces:torch.Tensor, #F,3 long - first face all zeros
        )->Tuple[torch.Tensor,torch.Tensor]:
    """remove dummy elements added with prepend_dummies()"""
    return vertices[1:],faces[1:]-1


def calc_edges(
        faces: torch.Tensor,  # F,3 long - first face may be dummy with all zeros
        with_edge_to_face: bool = False
        ) -> Tuple[torch.Tensor, ...]:
    """
    returns Tuple of
    - edges E,2 long, 0 for unused, lower vertex index first
    - face_to_edge F,3 long
    - (optional) edge_to_face shape=E,[left,right],[face,side]

    o-<-----e1     e0,e1...edge, e0<e1
    |      /A      L,R....left and right face
    | L   /  |     both triangles ordered counter clockwise
    |    / R |     normals pointing out of screen
    V/       |
    e0---->-o
    """

    F = faces.shape[0]

    # make full edges, lower vertex index first
    face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
    full_edges = face_edges.reshape(F*3,2)
    sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2

    # make unique edges
    edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
    E = edges.shape[0]
    face_to_edge = full_to_unique.reshape(F,3) #F,3

    if not with_edge_to_face:
        return edges, face_to_edge

    is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
    edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
    scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
    edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
    edge_to_face[0] = 0
    return edges, face_to_edge, edge_to_face

def calc_edge_length(
        vertices:torch.Tensor, #V,3 first may be dummy
        edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
        )->torch.Tensor: #E

    full_vertices = vertices[edges] #E,2,3
    a,b = full_vertices.unbind(dim=1) #E,3
    return torch.norm(a-b,p=2,dim=-1)

def calc_face_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """
         n
         |
         c0     corners ordered counterclockwise when
        /  \    looking onto surface (in neg normal direction)
      c1----c2
    """
    full_vertices = vertices[faces] #F,C=3,3
    v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
    face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
    if normalize:
        face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
    return face_normals #F,3

def calc_vertex_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        face_normals:torch.Tensor=None, #F,3, not normalized
        )->torch.Tensor: #F,3

    F = faces.shape[0]

    if face_normals is None:
        face_normals = calc_face_normals(vertices,faces)

    vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
    vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
    vertex_normals = vertex_normals.sum(dim=1) #V,3
    return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)

def calc_face_ref_normals(
        faces:torch.Tensor, #F,3 long, 0 for unused
        vertex_normals:torch.Tensor, #V,3 first unused
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """calculate reference normals for face flip detection"""
    full_normals = vertex_normals[faces] #F,C=3,3
    ref_normals = full_normals.sum(dim=1) #F,3
    if normalize:
        ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
    return ref_normals

def pack(
        vertices:torch.Tensor, #V,3 first unused and nan
        faces:torch.Tensor, #F,3 long, 0 for unused
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
    """removes unused elements in vertices and faces"""
    V = vertices.shape[0]

    # remove unused faces
    used_faces = faces[:,0]!=0
    used_faces[0] = True
    faces = faces[used_faces] #sync

    # remove unused vertices
    used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
    used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
    used_vertices = used_vertices.any(dim=1)
    used_vertices[0] = True
    vertices = vertices[used_vertices] #sync

    # update used faces
    ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
    V1 = used_vertices.sum()
    ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
    faces = ind[faces]

    return vertices,faces

def split_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        splits, #E bool
        pack_faces:bool=True,
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    #   c2                    c2               c...corners = faces
    #    . .                   . .             s...side_vert, 0 means no split
    #    .   .                 .N2 .           S...shrunk_face
    #    .     .               .     .         Ni...new_faces
    #   s2      s1           s2|c2...s1|c1
    #    .        .            .     .  .
    #    .          .          . S .      .
    #    .            .        . .     N1  .
    #   c0...(s0=0)....c1    s0|c0...........c1
    #
    # pseudo-code:
    #   S = [s0|c0,s1|c1,s2|c2]        example:[c0,s1,s2]
    #   split = side_vert!=0           example:[False,True,True]
    #   N0 = split[0]*[c0,s0,s2|c2]    example:[0,0,0]
    #   N1 = split[1]*[c1,s1,s0|c0]    example:[c1,s1,c0]
    #   N2 = split[2]*[c2,s2,s1|c1]    example:[c2,s2,s1]

    V = vertices.shape[0]
    F = faces.shape[0]
    S = splits.sum().item() #sync

    if S==0:
        return vertices,faces

    edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
    edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
    side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
    split_edges = edges[splits] #S sync

    #vertices
    split_vertices = vertices[split_edges].mean(dim=1) #S,3
    vertices = torch.concat((vertices,split_vertices),dim=0)

    #faces
    side_split = side_vert!=0 #F,3
    shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
    new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
    faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
    if pack_faces:
        mask = faces[:,0]!=0
        mask[0] = True
        faces = faces[mask] #F',3 sync

    return vertices,faces

def collapse_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        priorities:torch.Tensor, #E float
        stable:bool=False, #only for unit testing
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    V = vertices.shape[0]

    # check spacing
    _,order = priorities.sort(stable=stable) #E
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    edge_rank = rank #E
    for i in range(3):
        torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
        edge_rank,_ = vert_rank[edges].max(dim=-1) #E
    candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2

    # check connectivity
    vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    vert_connections[candidates[:,0]] = 1 #start
    edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
    vert_connections[candidates] = 0 #clear start and end
    edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
    collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end

    # mean vertices
    vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)

    # update faces
    dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
    dest[collapses[:,1]] = dest[collapses[:,0]]
    faces = dest[faces] #F,3
    c0,c1,c2 = faces.unbind(dim=-1)
    collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
    faces[collapsed] = 0

    return vertices,faces

def calc_face_collapses(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        edge_length:torch.Tensor, #E
        face_normals:torch.Tensor, #F,3
        vertex_normals:torch.Tensor, #V,3 first unused
        min_edge_length:torch.Tensor=None, #V
        area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
        shortest_probability = 0.8
        )->torch.Tensor: #E edges to collapse

    E = edges.shape[0]
    F = faces.shape[0]

    # face flips
    ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
    face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F

    # small faces
    if min_edge_length is not None:
        min_face_length = min_edge_length[faces].mean(dim=-1) #F
        min_area = min_face_length**2 * area_ratio #F
        face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
        face_collapses[0] = False

    # faces to edges
    face_length = edge_length[face_to_edge] #F,3

    if shortest_probability<1:
        #select shortest edge with shortest_probability chance
        randlim = round(2/(1-shortest_probability))
        rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
        sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
        local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
    else:
        local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face

    edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
    edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
    edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())

    return edge_collapses.bool()

def flip_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
        edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
        edge_to_face:torch.Tensor, #E,[left,right],[face,side]
        with_border:bool=True, #handle border edges (D=4 instead of D=6)
        with_normal_check:bool=True, #check face normal flips
        stable:bool=False, #only for unit testing
        ):
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
    edge_is_inside = neighbors.all(dim=-1) #E

    if with_border:
        # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
        # need to use float for masks in order to use scatter(reduce='multiply')
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
        vertex_degree -= 2 * vertex_is_inside #V long

    neighbor_degrees = vertex_degree[neighbors] #E,LR=2
    edge_degrees = vertex_degree[edges] #E,2
    #
    # loss = Sum_over_affected_vertices((new_degree-6)**2)
    # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
    #             + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
    #             = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
    #
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
    candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
    loss_change = loss_change[candidates] #E'
    if loss_change.shape[0]==0:
        return

    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
    _,order = loss_change.sort(descending=True, stable=stable) #E'
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1) #V
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
    flip = rank==neighborhood_rank #E'

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   | L   /  |     both triangles ordered counter clockwise
        #   |    / R |     normals pointing out of screen
|
347 |
+
# V/ |
|
348 |
+
# e0---->-cr
|
349 |
+
v = vertices[edges_neighbors] #E",4,3
|
350 |
+
v = v - v[:,0:1] #make relative to e0
|
351 |
+
e1 = v[:,1]
|
352 |
+
cl = v[:,2]
|
353 |
+
cr = v[:,3]
|
354 |
+
n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
|
355 |
+
flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
|
356 |
+
flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
|
357 |
+
|
358 |
+
flip_edges_neighbors = edges_neighbors[flip] #E",4
|
359 |
+
flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
|
360 |
+
flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
|
361 |
+
faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
|
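The two routines above are meant to be driven by an outer remeshing loop. For orientation only (this note and sketch are not part of the uploaded file), the snippet below shows one way they could be chained; the helpers calc_edges, calc_edge_length, calc_face_normals, calc_vertex_normals and collapse_edges and their signatures are assumptions taken from the earlier part of this file and the upstream continuous-remeshing code, so treat it as an illustrative sketch rather than the project's actual loop (see mesh_reconstruction/opt.py for that).

# Illustrative sketch only; every helper call below is an assumption, not verified API.
def one_remesh_pass(vertices, faces):
    edges, face_to_edge, edge_to_face = calc_edges(faces, with_edge_to_face=True)  # assumed helper
    edge_length = calc_edge_length(vertices, edges)                                # assumed helper
    face_normals = calc_face_normals(vertices, faces, normalize=True)              # assumed helper
    vertex_normals = calc_vertex_normals(vertices, faces, face_normals)            # assumed helper

    # boolean mask of edges whose adjacent faces are flipped or degenerate
    collapse_mask = calc_face_collapses(vertices, faces, edges, face_to_edge,
                                        edge_length, face_normals, vertex_normals)
    # collapse the flagged edges (a full implementation would also weight priorities by edge length)
    vertices, faces = collapse_edges(vertices, faces, edges, collapse_mask.float())  # assumed helper

    # recompute connectivity, then flip edges to push vertex degrees toward 6
    edges, face_to_edge, edge_to_face = calc_edges(faces, with_edge_to_face=True)    # assumed helper
    flip_edges(vertices, faces, edges, edge_to_face)
    return vertices, faces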
mesh_reconstruction/render.py
ADDED
@@ -0,0 +1,159 @@
# modified from https://github.com/Profactor/continuous-remeshing
import nvdiffrast.torch as dr
import torch
from typing import Tuple

def _warmup(glctx, device=None):
    device = 'cuda' if device is None else device
    #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
    def tensor(*args, **kwargs):
        return torch.tensor(*args, device=device, **kwargs)
    pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
    tri = tensor([[0, 1, 2]], dtype=torch.int32)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])

class NormalsRenderer:

    _glctx:dr.RasterizeCudaContext = None

    def __init__(
            self,
            mv: torch.Tensor, #C,4,4
            proj: torch.Tensor, #C,4,4
            image_size: Tuple[int,int],
            mvp = None,
            device=None,
            ):
        if mvp is None:
            self._mvp = proj @ mv #C,4,4
        else:
            self._mvp = mvp
        self._image_size = image_size
        self._glctx = dr.RasterizeCudaContext(device=device)
        _warmup(self._glctx, device)

    def render(self,
            vertices: torch.Tensor, #V,3 float
            normals: torch.Tensor, #V,3 float in [-1, 1]
            faces: torch.Tensor, #F,3 long
            ) ->torch.Tensor: #C,H,W,4

        V = vertices.shape[0]
        faces = faces.type(torch.int32)
        vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
        vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
        rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
        vert_col = (normals+1)/2 #V,3
        col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
        alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
        col = torch.concat((col,alpha),dim=-1) #C,H,W,4
        col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
        return col #C,H,W,4

from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRendererWithFragments,
    TexturesVertex,
    MeshRasterizer,
    BlendParams,
    FoVOrthographicCameras,
    look_at_view_transform,
    hard_rgb_blend,
)

class VertexColorShader(ShaderBase):
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        return hard_rgb_blend(texels, fragments, blend_params)

def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
    if len(mesh) != len(cameras):
        if len(cameras) % len(mesh) == 0:
            mesh = mesh.extend(len(cameras))
        else:
            raise NotImplementedError()

    # render requires everything in float16 or float32
    input_dtype = dtype
    blend_params = BlendParams(1e-4, 1e-4, bkgd)

    # Define the settings for rasterization and shading
    raster_settings = RasterizationSettings(
        image_size=(H, W),
        blur_radius=blur_radius,
        faces_per_pixel=faces_per_pixel,
        clip_barycentric_coords=True,
        bin_size=None,
        max_faces_per_bin=500000,
    )

    # Create a renderer by composing a rasterizer and a shader
    # We simply render vertex colors through the custom VertexColorShader (no lighting or materials are used)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=VertexColorShader(
            device=device,
            cameras=cameras,
            blend_params=blend_params
        )
    )

    # render RGB and depth, get mask
    with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
        images, _ = renderer(mesh)
    return images # BHW4

class Pytorch3DNormalsRenderer:
    def __init__(self, cameras, image_size, device):
        self.cameras = cameras.to(device)
        self._image_size = image_size
        self.device = device

    def render(self,
               vertices: torch.Tensor, #V,3 float
               normals: torch.Tensor, #V,3 float in [-1, 1]
               faces: torch.Tensor, #F,3 long
               ) ->torch.Tensor: #C,H,W,4
        mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
        return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)

def save_tensor_to_img(tensor, save_dir):
    from PIL import Image
    import numpy as np
    for idx, img in enumerate(tensor):
        img = img[..., :3].cpu().numpy()
        img = (img * 255).astype(np.uint8)
        img = Image.fromarray(img)
        img.save(save_dir + f"{idx}.png")

if __name__ == "__main__":
    import sys
    import os
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
    cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
    mv,proj = make_star_cameras_orthographic(4, 1)
    resolution = 1024
    renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
    renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
    vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
    normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
    faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)

    import time
    t0 = time.time()
    r1 = renderer1.render(vertices, normals, faces)
    print("time r1:", time.time() - t0)

    t0 = time.time()
    r2 = renderer2.render(vertices, normals, faces)
    print("time r2:", time.time() - t0)

    for i in range(4):
        print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
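As a usage illustration (not part of the uploaded file), render_mesh_vertex_color can also be driven by a standard PyTorch3D camera instead of the project's star-camera helpers. The geometry below is random and a CUDA device is assumed; the sketch only exercises the function's interface as defined above:

# Illustrative sketch only: random geometry, one orthographic view, CUDA assumed.
import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer import FoVOrthographicCameras, TexturesVertex, look_at_view_transform

R, T = look_at_view_transform(dist=4.0, elev=0, azim=0)              # a single view
cameras = FoVOrthographicCameras(R=R, T=T, device="cuda")
verts = torch.rand(100, 3, device="cuda") - 0.5                      # random point cloud in a unit box
faces = torch.randint(0, 100, (50, 3), device="cuda")                # random triangles over those points
colors = torch.rand(100, 3, device="cuda")                           # per-vertex RGB in [0, 1]
mesh = Meshes(verts=[verts], faces=[faces],
              textures=TexturesVertex(verts_features=[colors])).to("cuda")
images = render_mesh_vertex_color(mesh, cameras, H=256, W=256)       # (1, 256, 256, 4) RGBA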
package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff4a35615ed42148c8579622bee6dca88f7f3be683671524a282fafaf7589682
size 3079614
package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11e7f7f781fef16c09ec8d03bfb6da84cf61c54fc59e8a4ea047a90c4a24e88f
size 162720703
scripts/all_typing.py
ADDED
@@ -0,0 +1,42 @@
# code from https://github.com/threestudio-project

"""
This module contains type annotations for the project, using
1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors

Two types of typing checking can be used:
1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
"""

# Basic types
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    NamedTuple,
    NewType,
    Optional,
    Sized,
    Tuple,
    Type,
    TypeVar,
    Union,
)

# Tensor dtype
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt

# Config type
from omegaconf import DictConfig

# PyTorch Tensor type
from torch import Tensor

# Runtime type checking decorator
from typeguard import typechecked as typechecker
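As an illustration (not part of the uploaded file), the names re-exported above are intended to be used like this, with jaxtyping expressing tensor shapes directly in the annotation:

# Illustrative sketch only, using the Float and Tensor names re-exported by this module.
def normalize(v: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 3"]:
    # unit-length vectors along the last dimension
    return v / v.norm(dim=-1, keepdim=True).clamp_min(1e-8)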
scripts/load_onnx.py
ADDED
@@ -0,0 +1,48 @@
import onnxruntime
import torch

providers = [
    # ('TensorrtExecutionProvider', {
    #     'device_id': 0,
    #     'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
    #     'trt_fp16_enable': True,
    #     'trt_engine_cache_enable': True,
    # }),
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kSameAsRequested',
        'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'HEURISTIC',
    })
]

def load_onnx(file_path: str):
    assert file_path.endswith(".onnx")
    sess_opt = onnxruntime.SessionOptions()
    ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
    return ort_session


def load_onnx_caller(file_path: str, single_output=False):
    ort_session = load_onnx(file_path)
    def caller(*args):
        torch_input = isinstance(args[0], torch.Tensor)
        if torch_input:
            torch_input_dtype = args[0].dtype
            torch_input_device = args[0].device
            # check all are torch.Tensor and have same dtype and device
            assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor"
            assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor"
            assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor"
            args = [arg.cpu().float().numpy() for arg in args]

        ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
        ort_outs = ort_session.run(None, ort_inputs)

        if torch_input:
            ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]

        if single_output:
            return ort_outs[0]
        return ort_outs
    return caller
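For context (not part of the uploaded file), a minimal usage sketch of load_onnx_caller; the model path is a placeholder and a CUDA float16 input is assumed:

# Illustrative sketch only: "some_model.onnx" is a placeholder path, not a real checkpoint.
import torch
caller = load_onnx_caller("some_model.onnx", single_output=True)
x = torch.rand(1, 3, 256, 256, device="cuda", dtype=torch.float16)
y = caller(x)  # inputs are moved to CPU float32 NumPy for ONNX Runtime;
               # the output comes back with x's dtype and device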
scripts/mesh_init.py
ADDED
@@ -0,0 +1,132 @@
from PIL import Image
import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex
from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
import pymeshlab

_MAX_THREAD = 8

# rgb and depth to mesh
def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
    pixel_center = 0.5 if use_pixel_centers else 0
    i, j = np.meshgrid(
        np.arange(W, dtype=np.float32) + pixel_center,
        np.arange(H, dtype=np.float32) + pixel_center,
        indexing='xy'
    )
    i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)

    origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
    directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3

    return origins, directions

def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
    if valid_HWC is None:
        valid_HWC = torch.ones_like(pred_HWC).bool()
    H, W = rgb_BCHW.shape[-2:]
    rgb_BCHW = rgb_BCHW.flip(-2)
    pred_HWC = pred_HWC.flip(0)
    valid_HWC = valid_HWC.flip(0)
    rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
    verts = rays_o + rays_d * pred_HWC # [H, W, 3]
    verts = verts.reshape(-1, 3) # [V, 3]
    indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
    faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
    # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
    faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
    faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
    # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
    faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
    faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3)
    colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
    if is_back:
        verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)

    used_verts = faces.unique()
    old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
    old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
    new_faces = old_to_new_mapping[faces]
    mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
    return mesh

def normalmap_to_depthmap(normal_np):
    from scripts.normal_to_height_map import estimate_height_map
    height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
    return height

def transform_back_normal_to_front(normal_pil):
    arr = np.array(normal_pil)  # in [0, 255]
    arr[..., 0] = 255-arr[..., 0]
    arr[..., 2] = 255-arr[..., 2]
    return Image.fromarray(arr.astype(np.uint8))

def calc_w_over_h(normal_pil):
    if isinstance(normal_pil, Image.Image):
        arr = np.array(normal_pil)
    else:
        assert isinstance(normal_pil, np.ndarray)
        arr = normal_pil
    if arr.shape[-1] == 4:
        alpha = arr[..., -1] / 255.
        alpha[alpha >= 0.5] = 1
        alpha[alpha < 0.5] = 0
    else:
        alpha = ~(arr.min(axis=-1) >= 250)
    h_min, w_min = np.min(np.where(alpha), axis=1)
    h_max, w_max = np.max(np.where(alpha), axis=1)
    return (w_max - w_min) / (h_max - h_min)

def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
    if is_back:
        normal_pil = transform_back_normal_to_front(normal_pil)
    normal_img = np.array(normal_pil)
    rgb_img = np.array(rgb_pil)
    if normal_img.shape[-1] == 4:
        valid_HWC = normal_img[..., [3]] / 255
    elif rgb_img.shape[-1] == 4:
        valid_HWC = rgb_img[..., [3]] / 255
    else:
        raise ValueError("invalid input, either normal or rgb should have alpha channel")

    real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])

    heights = normalmap_to_depthmap(normal_img)
    rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
    valid_HWC[valid_HWC < 0.5] = 0
    valid_HWC[valid_HWC >= 0.5] = 1
    valid_HWC = torch.from_numpy(valid_HWC).bool()
    if init_type == "std":
        # accurate but not stable
        pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
    elif init_type == "thin":
        heights = heights - heights.min()
        heights = (heights / heights.max() * 0.2)
        pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
    else:
        # stable but not accurate
        heights = heights - heights.min()
        heights = (heights / heights.max() * (1-offset)) + offset # to [offset, 1]
        pred_HWC = torch.from_numpy(heights * scale).float()[..., None]

    # set the border pixels to 0 height
    import cv2
    # edge filter
    edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
    edge = torch.from_numpy(edge).bool()[..., None]
    pred_HWC[edge] = 0

    valid_HWC[pred_HWC < clamp_min] = False
    return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)

def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
    ms = pymeshlab.MeshSet()
    ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
    if simplification > 0:
        ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
    ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
    if simplification > 0:
        ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
    return meshlab_mesh_to_py3dmesh(ms.current_mesh())
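For context (not part of the uploaded file), a minimal sketch of how build_mesh and fix_border_with_pymeshlab_fast fit together. The image paths are placeholders, the inputs are assumed to be RGBA PIL images (alpha acting as the foreground mask), and a CUDA device is assumed because build_mesh moves its tensors to the GPU:

# Illustrative sketch only: placeholder file names; RGBA inputs and CUDA assumed.
from PIL import Image
front_rgb = Image.open("front_rgb.png")        # RGBA color view
front_normal = Image.open("front_normal.png")  # RGBA normal map with values in [0, 255]
mesh = build_mesh(front_normal, front_rgb, is_back=False, scale=0.3, init_type="std")
mesh = fix_border_with_pymeshlab_fast(mesh, poissson_depth=6)  # close the border via screened Poisson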