cavargas10 committed on
Commit
1f30907
1 Parent(s): 69b6a88

Upload 56 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +4 -0
  2. assets/teaser.jpg +0 -0
  3. custum_3d_diffusion/custum_modules/attention_processors.py +385 -0
  4. custum_3d_diffusion/custum_modules/unifield_processor.py +459 -0
  5. custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py +298 -0
  6. custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py +296 -0
  7. custum_3d_diffusion/modules.py +14 -0
  8. custum_3d_diffusion/trainings/__init__.py +0 -0
  9. custum_3d_diffusion/trainings/base.py +208 -0
  10. custum_3d_diffusion/trainings/config_classes.py +35 -0
  11. custum_3d_diffusion/trainings/image2image_trainer.py +86 -0
  12. custum_3d_diffusion/trainings/image2mvimage_trainer.py +139 -0
  13. custum_3d_diffusion/trainings/utils.py +25 -0
  14. gradio_app/__init__.py +0 -0
  15. gradio_app/all_models.py +22 -0
  16. gradio_app/custom_models/image2mvimage.yaml +63 -0
  17. gradio_app/custom_models/image2normal.yaml +61 -0
  18. gradio_app/custom_models/mvimg_prediction.py +59 -0
  19. gradio_app/custom_models/normal_prediction.py +28 -0
  20. gradio_app/custom_models/utils.py +75 -0
  21. gradio_app/examples/Groot.png +0 -0
  22. gradio_app/examples/aaa.png +0 -0
  23. gradio_app/examples/abma.png +0 -0
  24. gradio_app/examples/akun.png +0 -0
  25. gradio_app/examples/anya.png +0 -0
  26. gradio_app/examples/bag.png +3 -0
  27. gradio_app/examples/ex1.png +3 -0
  28. gradio_app/examples/ex2.png +0 -0
  29. gradio_app/examples/ex3.jpg +0 -0
  30. gradio_app/examples/ex4.png +0 -0
  31. gradio_app/examples/generated_1715761545_frame0.png +0 -0
  32. gradio_app/examples/generated_1715762357_frame0.png +0 -0
  33. gradio_app/examples/generated_1715763329_frame0.png +0 -0
  34. gradio_app/examples/hatsune_miku.png +0 -0
  35. gradio_app/examples/princess-large.png +0 -0
  36. gradio_app/gradio_3dgen.py +85 -0
  37. gradio_app/gradio_3dgen_steps.py +87 -0
  38. gradio_app/gradio_local.py +76 -0
  39. gradio_app/utils.py +112 -0
  40. mesh_reconstruction/func.py +133 -0
  41. mesh_reconstruction/opt.py +190 -0
  42. mesh_reconstruction/recon.py +59 -0
  43. mesh_reconstruction/refine.py +80 -0
  44. mesh_reconstruction/remesh.py +361 -0
  45. mesh_reconstruction/render.py +159 -0
  46. package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl +3 -0
  47. package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl +3 -0
  48. scripts/all_typing.py +42 -0
  49. scripts/load_onnx.py +48 -0
  50. scripts/mesh_init.py +132 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ gradio_app/examples/bag.png filter=lfs diff=lfs merge=lfs -text
37
+ gradio_app/examples/ex1.png filter=lfs diff=lfs merge=lfs -text
38
+ package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
39
+ package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
assets/teaser.jpg ADDED
custum_3d_diffusion/custum_modules/attention_processors.py ADDED
@@ -0,0 +1,385 @@
1
+ from typing import Any, Dict, Optional
2
+ import torch
3
+ from diffusers.models.attention_processor import Attention
4
+
5
+ def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
6
+ if norm_type == "layernorm":
7
+ norm = torch.nn.LayerNorm(hidden_states_dim)
8
+ else:
9
+ norm = torch.nn.Identity()
10
+ attention = Attention(
11
+ query_dim=hidden_states_dim,
12
+ heads=8,
13
+ dim_head=hidden_states_dim // 8,
14
+ bias=True,
15
+ )
16
+ # NOTE: xformers 0.22 does not support batchsize >= 4096
17
+ attention.xformers_not_supported = True # hacky solution
18
+ return norm, attention
19
+
20
+ class ExtraAttnProc(torch.nn.Module):
21
+ def __init__(
22
+ self,
23
+ chained_proc,
24
+ enabled=False,
25
+ name=None,
26
+ mode='extract',
27
+ with_proj_in=False,
28
+ proj_in_dim=768,
29
+ target_dim=None,
30
+ pixel_wise_crosspond=False,
31
+ norm_type="none", # none or layernorm
32
+ crosspond_effect_on="all", # all or first
33
+ crosspond_chain_pos="parralle", # before or parralle or after
34
+ simple_3d=False,
35
+ views=4,
36
+ ) -> None:
37
+ super().__init__()
38
+ self.enabled = enabled
39
+ self.chained_proc = chained_proc
40
+ self.name = name
41
+ self.mode = mode
42
+ self.with_proj_in=with_proj_in
43
+ self.proj_in_dim = proj_in_dim
44
+ self.target_dim = target_dim or proj_in_dim
45
+ self.hidden_states_dim = self.target_dim
46
+ self.pixel_wise_crosspond = pixel_wise_crosspond
47
+ self.crosspond_effect_on = crosspond_effect_on
48
+ self.crosspond_chain_pos = crosspond_chain_pos
49
+ self.views = views
50
+ self.simple_3d = simple_3d
51
+ if self.with_proj_in and self.enabled:
52
+ self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
53
+ if self.target_dim == self.proj_in_dim:
54
+ self.in_linear.weight.data = torch.eye(proj_in_dim)
55
+ else:
56
+ self.in_linear = None
57
+ if self.pixel_wise_crosspond and self.enabled:
58
+ self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
59
+
60
+ def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
61
+ hidden_states = self.crosspond_norm(hidden_states)
62
+
63
+ batch, L, D = hidden_states.shape
64
+ assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
65
+ # to -> batch * L, 1, D
66
+ hidden_states = hidden_states.reshape(batch * L, 1, D)
67
+ other_states = other_states.reshape(batch * L, 1, D)
68
+ hidden_states_catted = other_states
69
+ hidden_states = self.crosspond_attention(
70
+ hidden_states,
71
+ encoder_hidden_states=hidden_states_catted,
72
+ )
73
+ return hidden_states.reshape(batch, L, D)
74
+
75
+ def __call__(
76
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
77
+ ref_dict: dict = None, mode=None, **kwargs
78
+ ) -> Any:
79
+ if not self.enabled:
80
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
81
+ if encoder_hidden_states is None:
82
+ encoder_hidden_states = hidden_states
83
+ assert ref_dict is not None
84
+ if (mode or self.mode) == 'extract':
85
+ ref_dict[self.name] = hidden_states
86
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
87
+ if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
88
+ ref_dict[self.name] = hidden_states1
89
+ return hidden_states1
90
+ elif (mode or self.mode) == 'inject':
91
+ ref_state = ref_dict.pop(self.name)
92
+ if self.with_proj_in:
93
+ ref_state = self.in_linear(ref_state)
94
+
95
+ B, L, D = ref_state.shape
96
+ if hidden_states.shape[0] == B:
97
+ modalities = 1
98
+ views = 1
99
+ else:
100
+ modalities = hidden_states.shape[0] // B // self.views
101
+ views = self.views
102
+ if self.pixel_wise_crosspond:
103
+ if self.crosspond_effect_on == "all":
104
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
105
+
106
+ if self.crosspond_chain_pos == "before":
107
+ hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
108
+
109
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
110
+
111
+ if self.crosspond_chain_pos == "parralle":
112
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
113
+
114
+ if self.crosspond_chain_pos == "after":
115
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
116
+ return hidden_states1
117
+ else:
118
+ assert self.crosspond_effect_on == "first"
119
+ # hidden_states [B * modalities * views, L, D]
120
+ # ref_state [B, L, D]
121
+ ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D]
122
+
123
+ def do_paritial_crosspond(hidden_states, ref_state):
124
+ first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D]
125
+ hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
126
+ hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
127
+ hidden_states2_padded[:, 0] = hidden_states2
128
+ hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
129
+ return hidden_states2_padded
130
+
131
+ if self.crosspond_chain_pos == "before":
132
+ hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
133
+
134
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D]
135
+ if self.crosspond_chain_pos == "parralle":
136
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
137
+ if self.crosspond_chain_pos == "after":
138
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
139
+ return hidden_states1
140
+ elif self.simple_3d:
141
+ B, L, C = encoder_hidden_states.shape
142
+ mv = self.views
143
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
144
+ ref_state = ref_state[:, None]
145
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
146
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
147
+ encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
148
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
149
+ else:
150
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
151
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
152
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
153
+ else:
154
+ raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
155
+
156
+ def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
157
+ return_dict = torch.nn.ModuleDict()
158
+ proj_in_dim = kwargs.get('proj_in_dim', False)
159
+ kwargs.pop('proj_in_dim', None)
160
+
161
+ def recursive_add_processors(name: str, module: torch.nn.Module):
162
+ for sub_name, child in module.named_children():
163
+ if "ref_unet" not in (sub_name + name):
164
+ recursive_add_processors(f"{name}.{sub_name}", child)
165
+
166
+ if isinstance(module, Attention):
167
+ new_processor = ExtraAttnProc(
168
+ chained_proc=module.get_processor(),
169
+ enabled=enable_filter(f"{name}.processor"),
170
+ name=f"{name}.processor",
171
+ proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
172
+ target_dim=module.cross_attention_dim,
173
+ **kwargs
174
+ )
175
+ module.set_processor(new_processor)
176
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
177
+
178
+ for name, module in model.named_children():
179
+ recursive_add_processors(name, module)
180
+ return return_dict
181
+
182
+ def switch_extra_processor(model, enable_filter=lambda x:True):
183
+ def recursive_add_processors(name: str, module: torch.nn.Module):
184
+ for sub_name, child in module.named_children():
185
+ recursive_add_processors(f"{name}.{sub_name}", child)
186
+
187
+ if isinstance(module, ExtraAttnProc):
188
+ module.enabled = enable_filter(name)
189
+
190
+ for name, module in model.named_children():
191
+ recursive_add_processors(name, module)
192
+
193
+ class multiviewAttnProc(torch.nn.Module):
194
+ def __init__(
195
+ self,
196
+ chained_proc,
197
+ enabled=False,
198
+ name=None,
199
+ hidden_states_dim=None,
200
+ chain_pos="parralle", # before or parralle or after
201
+ num_modalities=1,
202
+ views=4,
203
+ base_img_size=64,
204
+ ) -> None:
205
+ super().__init__()
206
+ self.enabled = enabled
207
+ self.chained_proc = chained_proc
208
+ self.name = name
209
+ self.hidden_states_dim = hidden_states_dim
210
+ self.num_modalities = num_modalities
211
+ self.views = views
212
+ self.base_img_size = base_img_size
213
+ self.chain_pos = chain_pos
214
+ self.diff_joint_attn = True
215
+
216
+ def __call__(
217
+ self,
218
+ attn: Attention,
219
+ hidden_states: torch.FloatTensor,
220
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
221
+ attention_mask: Optional[torch.FloatTensor] = None,
222
+ **kwargs
223
+ ) -> torch.Tensor:
224
+ if not self.enabled:
225
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
226
+
227
+ B, L, C = hidden_states.shape
228
+ mv = self.views
229
+ hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
230
+ hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
231
+ return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
232
+
233
+ def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
234
+ return_dict = torch.nn.ModuleDict()
235
+ def recursive_add_processors(name: str, module: torch.nn.Module):
236
+ for sub_name, child in module.named_children():
237
+ if "ref_unet" not in (sub_name + name):
238
+ recursive_add_processors(f"{name}.{sub_name}", child)
239
+
240
+ if isinstance(module, Attention):
241
+ new_processor = multiviewAttnProc(
242
+ chained_proc=module.get_processor(),
243
+ enabled=enable_filter(f"{name}.processor"),
244
+ name=f"{name}.processor",
245
+ hidden_states_dim=module.inner_dim,
246
+ **kwargs
247
+ )
248
+ module.set_processor(new_processor)
249
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
250
+
251
+ for name, module in model.named_children():
252
+ recursive_add_processors(name, module)
253
+
254
+ return return_dict
255
+
256
+ def switch_multiview_processor(model, enable_filter=lambda x:True):
257
+ def recursive_add_processors(name: str, module: torch.nn.Module):
258
+ for sub_name, child in module.named_children():
259
+ recursive_add_processors(f"{name}.{sub_name}", child)
260
+
261
+ if isinstance(module, Attention):
262
+ processor = module.get_processor()
263
+ if isinstance(processor, multiviewAttnProc):
264
+ processor.enabled = enable_filter(f"{name}.processor")
265
+
266
+ for name, module in model.named_children():
267
+ recursive_add_processors(name, module)
268
+
269
+ class NNModuleWrapper(torch.nn.Module):
270
+ def __init__(self, module):
271
+ super().__init__()
272
+ self.module = module
273
+
274
+ def forward(self, *args, **kwargs):
275
+ return self.module(*args, **kwargs)
276
+
277
+ def __getattr__(self, name: str):
278
+ try:
279
+ return super().__getattr__(name)
280
+ except AttributeError:
281
+ return getattr(self.module, name)
282
+
283
+ class AttnProcessorSwitch(torch.nn.Module):
284
+ def __init__(
285
+ self,
286
+ proc_dict: dict,
287
+ enabled_proc="default",
288
+ name=None,
289
+ switch_name="default_switch",
290
+ ):
291
+ super().__init__()
292
+ self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
293
+ self.enabled_proc = enabled_proc
294
+ self.name = name
295
+ self.switch_name = switch_name
296
+ self.choose_module(enabled_proc)
297
+
298
+ def choose_module(self, enabled_proc):
299
+ self.enabled_proc = enabled_proc
300
+ assert enabled_proc in self.proc_dict.keys()
301
+
302
+ def __call__(
303
+ self,
304
+ *args,
305
+ **kwargs
306
+ ) -> torch.FloatTensor:
307
+ used_proc = self.proc_dict[self.enabled_proc]
308
+ return used_proc(*args, **kwargs)
309
+
310
+ def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
311
+ return_dict = torch.nn.ModuleDict()
312
+ def recursive_add_processors(name: str, module: torch.nn.Module):
313
+ for sub_name, child in module.named_children():
314
+ if "ref_unet" not in (sub_name + name):
315
+ recursive_add_processors(f"{name}.{sub_name}", child)
316
+
317
+ if isinstance(module, Attention):
318
+ processor = module.get_processor()
319
+ if module_filter(processor):
320
+ proc_dict = switch_dict_fn(processor)
321
+ new_processor = AttnProcessorSwitch(
322
+ proc_dict=proc_dict,
323
+ enabled_proc=enabled_proc,
324
+ name=f"{name}.processor",
325
+ switch_name=switch_name,
326
+ )
327
+ module.set_processor(new_processor)
328
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
329
+
330
+ for name, module in model.named_children():
331
+ recursive_add_processors(name, module)
332
+
333
+ return return_dict
334
+
335
+ def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
336
+ def recursive_change_processors(name: str, module: torch.nn.Module):
337
+ for sub_name, child in module.named_children():
338
+ recursive_change_processors(f"{name}.{sub_name}", child)
339
+
340
+ if isinstance(module, Attention):
341
+ processor = module.get_processor()
342
+ if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
343
+ processor.choose_module(enabled_proc)
344
+
345
+ for name, module in model.named_children():
346
+ recursive_change_processors(name, module)
347
+
348
+ ########## Hack: Attention fix #############
349
+ from diffusers.models.attention import Attention
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states: torch.FloatTensor,
354
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
355
+ attention_mask: Optional[torch.FloatTensor] = None,
356
+ **cross_attention_kwargs,
357
+ ) -> torch.Tensor:
358
+ r"""
359
+ The forward method of the `Attention` class.
360
+
361
+ Args:
362
+ hidden_states (`torch.Tensor`):
363
+ The hidden states of the query.
364
+ encoder_hidden_states (`torch.Tensor`, *optional*):
365
+ The hidden states of the encoder.
366
+ attention_mask (`torch.Tensor`, *optional*):
367
+ The attention mask to use. If `None`, no mask is applied.
368
+ **cross_attention_kwargs:
369
+ Additional keyword arguments to pass along to the cross attention.
370
+
371
+ Returns:
372
+ `torch.Tensor`: The output of the attention layer.
373
+ """
374
+ # The `Attention` class can call different attention processors / attention functions
375
+ # here we simply pass along all tensors to the selected processor class
376
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
377
+ return self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ Attention.forward = forward
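A minimal, hypothetical usage sketch of the processors defined above (not part of this commit): it attaches `ExtraAttnProc` in inject mode to the self-attention blocks of a stock diffusers UNet and then toggles it off. The model name, the `attn1.processor` filter, and the flag values are illustrative assumptions based on the signatures shown in the diff.

```py
# Hypothetical sketch; model name and filter are illustrative assumptions.
from diffusers import UNet2DConditionModel

from custum_3d_diffusion.custum_modules.attention_processors import (
    add_extra_processor,
    switch_extra_processor,
)

# Load a plain UNet; any UNet2DConditionModel checkpoint works for this sketch.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Wrap every self-attention ("attn1") block with an ExtraAttnProc in inject mode.
processors = add_extra_processor(
    model=unet,
    enable_filter=lambda name: name.endswith("attn1.processor"),
    mode="inject",
    with_proj_in=False,
    pixel_wise_crosspond=True,
)

# The processors can later be toggled in place without rebuilding the UNet.
switch_extra_processor(unet, enable_filter=lambda name: False)  # disable all
```

At denoising time, matching `extract` processors on a reference UNet fill a `ref_dict`, which is then passed through `cross_attention_kwargs` to these `inject` processors, as wired up in `unifield_processor.py` below.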
custum_3d_diffusion/custum_modules/unifield_processor.py ADDED
@@ -0,0 +1,459 @@
1
+ from types import FunctionType
2
+ from typing import Any, Dict, List
3
+ from diffusers import UNet2DConditionModel
4
+ import torch
5
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
6
+ from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
7
+ from dataclasses import dataclass, field
8
+ from diffusers.loaders import IPAdapterMixin
9
+ from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch
10
+
11
+ @dataclass
12
+ class AttnConfig:
13
+ """
14
+ * CrossAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), IPAdapter module (achieves conceptual control).
15
+ * SelfAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), Reference Attention module (achieves pixel-level control).
16
+ * Multiview Attention module: Multiview Attention module (achieves multi-view consistency).
17
+ * Cross Modality Attention module: Cross Modality Attention module (achieves multi-modality consistency).
18
+
19
+ For setups:
20
+ train_xxx_lr is implemented in the U-Net architecture.
21
+ enable_xxx_lora is implemented in the U-Net architecture.
22
+ enable_xxx_ip is implemented in the processor and U-Net architecture.
23
+ enable_xxx_ref_proj_in is implemented in the processor.
24
+ """
25
+ latent_size: int = 64
26
+
27
+ train_lr: float = 0
28
+ # for cross attention
29
+ # 0 learning rate for not training
30
+ train_cross_attn_lr: float = 0
31
+ train_cross_attn_lora_lr: float = 0
32
+ train_cross_attn_ip_lr: float = 0 # 0 for not trained
33
+ init_cross_attn_lora: bool = False
34
+ enable_cross_attn_lora: bool = False
35
+ init_cross_attn_ip: bool = False
36
+ enable_cross_attn_ip: bool = False
37
+ cross_attn_lora_rank: int = 64 # 0 for not enabled
38
+ cross_attn_lora_only_kv: bool = False
39
+ ipadapter_pretrained_name: str = "h94/IP-Adapter"
40
+ ipadapter_subfolder_name: str = "models"
41
+ ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
42
+ ipadapter_effect_on: str = "all" # all, first
43
+
44
+ # for self attention
45
+ train_self_attn_lr: float = 0
46
+ train_self_attn_lora_lr: float = 0
47
+ init_self_attn_lora: bool = False
48
+ enable_self_attn_lora: bool = False
49
+ self_attn_lora_rank: int = 64
50
+ self_attn_lora_only_kv: bool = False
51
+
52
+ train_self_attn_ref_lr: float = 0
53
+ train_ref_unet_lr: float = 0
54
+ init_self_attn_ref: bool = False
55
+ enable_self_attn_ref: bool = False
56
+ self_attn_ref_other_model_name: str = ""
57
+ self_attn_ref_position: str = "attn1"
58
+ self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn
59
+ self_attn_ref_chain_pos: str = "parralle" # before or parralle or after
60
+ self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn
61
+ self_attn_ref_zero_init: bool = True
62
+ use_simple3d_attn: bool = False
63
+
64
+ # for multiview attention
65
+ init_multiview_attn: bool = False
66
+ enable_multiview_attn: bool = False
67
+ multiview_attn_position: str = "attn1"
68
+ multiview_chain_pose: str = "parralle" # before or parralle or after
69
+ num_modalities: int = 1
70
+ use_mv_joint_attn: bool = False
71
+
72
+ # for unet
73
+ init_unet_path: str = "runwayml/stable-diffusion-v1-5"
74
+ init_num_cls_label: int = 0 # for initialize
75
+ cls_labels: List[int] = field(default_factory=lambda: [])
76
+ cls_label_type: str = "embedding"
77
+ cat_condition: bool = False # cat condition to input
78
+
79
+ class Configurable:
80
+ attn_config: AttnConfig
81
+
82
+ def set_config(self, attn_config: AttnConfig):
83
+ raise NotImplementedError()
84
+
85
+ def update_config(self, attn_config: AttnConfig):
86
+ self.attn_config = attn_config
87
+
88
+ def do_set_config(self, attn_config: AttnConfig):
89
+ self.set_config(attn_config)
90
+ for name, module in self.named_modules():
91
+ if isinstance(module, Configurable):
92
+ if hasattr(module, "do_set_config"):
93
+ module.do_set_config(attn_config)
94
+ else:
95
+ print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
96
+ module.attn_config = attn_config
97
+
98
+ def do_update_config(self, attn_config: AttnConfig):
99
+ self.update_config(attn_config)
100
+ for name, module in self.named_modules():
101
+ if isinstance(module, Configurable):
102
+ if hasattr(module, "do_update_config"):
103
+ module.do_update_config(attn_config)
104
+ else:
105
+ print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
106
+ module.attn_config = attn_config
107
+
108
+ from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet
109
+ class UnifieldWrappedUNet(UNet2DConditionModel):
110
+ forward_hook: FunctionType
111
+
112
+ def forward(self, *args, **kwargs):
113
+ if hasattr(self, 'forward_hook'):
114
+ return self.forward_hook(super().forward, *args, **kwargs)
115
+ return super().forward(*args, **kwargs)
116
+
117
+
118
+ class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
119
+ unet: UNet2DConditionModel
120
+
121
+ cls_embedding_param_dict = {}
122
+ cross_attn_lora_param_dict = {}
123
+ self_attn_lora_param_dict = {}
124
+ cross_attn_param_dict = {}
125
+ self_attn_param_dict = {}
126
+ ipadapter_param_dict = {}
127
+ ref_attn_param_dict = {}
128
+ ref_unet_param_dict = {}
129
+ multiview_attn_param_dict = {}
130
+ other_param_dict = {}
131
+
132
+ rev_param_name_mapping = {}
133
+
134
+ class_labels = []
135
+ def set_class_labels(self, class_labels: torch.Tensor):
136
+ if self.attn_config.init_num_cls_label != 0:
137
+ self.class_labels = class_labels.to(self.unet.device).long()
138
+
139
+ def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
140
+ super().__init__()
141
+ self.weight_dtype = weight_dtype
142
+ self.set_config(init_config)
143
+
144
+ def enable_xformers_memory_efficient_attention(self):
145
+ self.unet.enable_xformers_memory_efficient_attention
146
+ def recursive_add_processors(name: str, module: torch.nn.Module):
147
+ for sub_name, child in module.named_children():
148
+ recursive_add_processors(f"{name}.{sub_name}", child)
149
+
150
+ if isinstance(module, Attention):
151
+ if hasattr(module, 'xformers_not_supported'):
152
+ return
153
+ old_processor = module.get_processor()
154
+ if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
155
+ module.set_use_memory_efficient_attention_xformers(True)
156
+
157
+ for name, module in self.unet.named_children():
158
+ recursive_add_processors(name, module)
159
+
160
+ def __getattr__(self, name: str) -> Any:
161
+ try:
162
+ return super().__getattr__(name)
163
+ except AttributeError:
164
+ return getattr(self.unet, name)
165
+
166
+ # --- for IPAdapterMixin
167
+
168
+ def register_modules(self, **kwargs):
169
+ for name, module in kwargs.items():
170
+ # set models
171
+ setattr(self, name, module)
172
+
173
+ def register_to_config(self, **kwargs):
174
+ pass
175
+
176
+ def unload_ip_adapter(self):
177
+ raise NotImplementedError()
178
+
179
+ # --- for Configurable
180
+
181
+ def get_refunet(self):
182
+ if self.attn_config.self_attn_ref_other_model_name == "self":
183
+ return self.unet
184
+ else:
185
+ return self.unet.ref_unet
186
+
187
+ def set_config(self, attn_config: AttnConfig):
188
+ self.attn_config = attn_config
189
+
190
+ unet_type = UnifieldWrappedUNet
191
+ # class_embed_type = "projection" for 'camera'
192
+ # class_embed_type = None for 'embedding'
193
+ unet_kwargs = {}
194
+ if attn_config.init_num_cls_label > 0:
195
+ if attn_config.cls_label_type == "embedding":
196
+ unet_kwargs = {
197
+ "num_class_embeds": attn_config.init_num_cls_label,
198
+ "device_map": None,
199
+ "low_cpu_mem_usage": False,
200
+ "class_embed_type": None,
201
+ }
202
+ else:
203
+ raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")
204
+
205
+ self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
206
+ attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
207
+ **unet_kwargs
208
+ )
209
+ assert isinstance(self.unet, UnifieldWrappedUNet)
210
+ self.unet.forward_hook = self.unet_forward_hook
211
+
212
+ if self.attn_config.cat_condition:
213
+ # double in_channels
214
+ if self.unet.config.in_channels != 8:
215
+ self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
216
+ # widen conv_in: copy the original weights and zero-initialize the new condition input channels
217
+ doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
218
+ doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
219
+ doubled_conv_in.bias.data = self.unet.conv_in.bias.data
220
+ self.unet.conv_in = doubled_conv_in
221
+
222
+ used_param_ids = set()
223
+
224
+ if attn_config.init_cross_attn_lora:
225
+ # setup lora
226
+ from peft import LoraConfig
227
+ from peft.utils import get_peft_model_state_dict
228
+ if attn_config.cross_attn_lora_only_kv:
229
+ target_modules=["attn2.to_k", "attn2.to_v"]
230
+ else:
231
+ target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
232
+ lora_config: LoraConfig = LoraConfig(
233
+ r=attn_config.cross_attn_lora_rank,
234
+ lora_alpha=attn_config.cross_attn_lora_rank,
235
+ init_lora_weights="gaussian",
236
+ target_modules=target_modules,
237
+ )
238
+ adapter_name="cross_attn_lora"
239
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
240
+ # update cross_attn_lora_param_dict
241
+ self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
242
+ used_param_ids.update(self.cross_attn_lora_param_dict.keys())
243
+
244
+ if attn_config.init_self_attn_lora:
245
+ # setup lora
246
+ from peft import LoraConfig
247
+ if attn_config.self_attn_lora_only_kv:
248
+ target_modules=["attn1.to_k", "attn1.to_v"]
249
+ else:
250
+ target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
251
+ lora_config: LoraConfig = LoraConfig(
252
+ r=attn_config.self_attn_lora_rank,
253
+ lora_alpha=attn_config.self_attn_lora_rank,
254
+ init_lora_weights="gaussian",
255
+ target_modules=target_modules,
256
+ )
257
+ adapter_name="self_attn_lora"
258
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
259
+ # update cross_self_lora_param_dict
260
+ self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
261
+ used_param_ids.update(self.self_attn_lora_param_dict.keys())
262
+
263
+ if attn_config.init_num_cls_label != 0:
264
+ self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
265
+ used_param_ids.update(self.cls_embedding_param_dict.keys())
266
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
267
+
268
+ if attn_config.init_cross_attn_ip:
269
+ self.image_encoder = None
270
+ # setup ipadapter
271
+ self.load_ip_adapter(
272
+ attn_config.ipadapter_pretrained_name,
273
+ subfolder=attn_config.ipadapter_subfolder_name,
274
+ weight_name=attn_config.ipadapter_weight_name
275
+ )
276
+ # wrap ip_adapter_attn_proc with a switch
277
+ from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
278
+ add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
279
+ # update ipadapter_param_dict
280
+ # weights are in attention processors and unet.encoder_hid_proj
281
+ self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
282
+ used_param_ids.update(self.ipadapter_param_dict.keys())
283
+ print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
284
+ for name, processor in self.unet.attn_processors.items():
285
+ if hasattr(processor, "to_k_ip"):
286
+ self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
287
+ print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))
288
+
289
+ ref_unet = None
290
+ if attn_config.init_self_attn_ref:
291
+ # setup reference attention processor
292
+ if attn_config.self_attn_ref_other_model_name == "self":
293
+ raise NotImplementedError("self reference is not fully implemented")
294
+ else:
295
+ ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
296
+ attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
297
+ )
298
+ ref_unet.to(self.unet.device)
299
+ if self.attn_config.train_ref_unet_lr == 0:
300
+ ref_unet.eval()
301
+ ref_unet.requires_grad_(False)
302
+ else:
303
+ ref_unet.train()
304
+
305
+ add_extra_processor(
306
+ model=ref_unet,
307
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
308
+ mode='extract',
309
+ with_proj_in=False,
310
+ pixel_wise_crosspond=False,
311
+ )
312
+ # NOTE: the self-attention cross_attention_dim must match between the two UNets here
313
+ processor_dict = add_extra_processor(
314
+ model=self.unet,
315
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
316
+ mode='inject',
317
+ with_proj_in=False,
318
+ pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
319
+ crosspond_effect_on=attn_config.self_attn_ref_effect_on,
320
+ crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
321
+ simple_3d=attn_config.use_simple3d_attn,
322
+ )
323
+ self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
324
+ if attn_config.self_attn_ref_chain_pos != "after":
325
+ # pop untrainable parameters
326
+ for name, param in ref_unet.named_parameters():
327
+ if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
328
+ self.ref_unet_param_dict.pop(id(param))
329
+ used_param_ids.update(self.ref_unet_param_dict.keys())
330
+ # update ref_attn_param_dict
331
+ self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
332
+ used_param_ids.update(self.ref_attn_param_dict.keys())
333
+
334
+ if attn_config.init_multiview_attn:
335
+ processor_dict = add_multiview_processor(
336
+ model = self.unet,
337
+ enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
338
+ num_modalities = attn_config.num_modalities,
339
+ base_img_size = attn_config.latent_size,
340
+ chain_pos = attn_config.multiview_chain_pose,
341
+ )
342
+ # update multiview_attn_param_dict
343
+ self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
344
+ used_param_ids.update(self.multiview_attn_param_dict.keys())
345
+
346
+ # initialize cross_attn_param_dict parameters
347
+ self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
348
+ used_param_ids.update(self.cross_attn_param_dict.keys())
349
+
350
+ # initialize self_attn_param_dict parameters
351
+ self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
352
+ used_param_ids.update(self.self_attn_param_dict.keys())
353
+
354
+ # initialize other_param_dict parameters
355
+ self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}
356
+
357
+ if ref_unet is not None:
358
+ self.unet.ref_unet = ref_unet
359
+
360
+ self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}
361
+
362
+ self.update_config(attn_config, force_update=True)
363
+ return self.unet
364
+
365
+ _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]
366
+
367
+ def update_config(self, attn_config: AttnConfig, force_update=False):
368
+ assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"
369
+
370
+ need_to_update = False
371
+ # update cls_labels
372
+ for key in self._attn_keys_to_update:
373
+ if getattr(self.attn_config, key) != getattr(attn_config, key):
374
+ need_to_update = True
375
+ break
376
+ if not force_update and not need_to_update:
377
+ return
378
+
379
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
380
+
381
+ # setup loras
382
+ if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
383
+ if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
384
+ cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0
385
+ self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0
386
+ self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
387
+ else:
388
+ self.unet.disable_adapters()
389
+
390
+ # setup ipadapter
391
+ if self.attn_config.init_cross_attn_ip:
392
+ if attn_config.enable_cross_attn_ip:
393
+ change_switch(self.unet, "ipadapter_switch", "ipadapter")
394
+ else:
395
+ change_switch(self.unet, "ipadapter_switch", "default")
396
+
397
+ # setup reference attention processor
398
+ if self.attn_config.init_self_attn_ref:
399
+ if attn_config.enable_self_attn_ref:
400
+ switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
401
+ else:
402
+ switch_extra_processor(self.unet, enable_filter=lambda name: False)
403
+
404
+ # setup multiview attention processor
405
+ if self.attn_config.init_multiview_attn:
406
+ if attn_config.enable_multiview_attn:
407
+ switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
408
+ else:
409
+ switch_multiview_processor(self.unet, enable_filter=lambda name: False)
410
+
411
+ # update cls_labels
412
+ for key in self._attn_keys_to_update:
413
+ setattr(self.attn_config, key, getattr(attn_config, key))
414
+
415
+ def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
416
+ if class_labels is None and len(self.class_labels) > 0:
417
+ class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
418
+ elif self.attn_config.init_num_cls_label != 0:
419
+ assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
420
+ if class_labels is not None:
421
+ if self.attn_config.cls_label_type == "embedding":
422
+ pass
423
+ else:
424
+ raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
425
+ if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
426
+ # NOTE: extra step, extract condition
427
+ ref_dict = {}
428
+ ref_unet = self.get_refunet().to(sample.device)
429
+ assert condition_latents is not None
430
+ if self.attn_config.self_attn_ref_other_model_name == "self":
431
+ raise NotImplementedError()
432
+ else:
433
+ with torch.no_grad():
434
+ cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
435
+ if timestep.dim() == 0:
436
+ cond_timestep = timestep
437
+ else:
438
+ cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
439
+ ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
440
+ # NOTE: extra step, inject condition
441
+ # Predict the noise residual and compute loss
442
+ if cross_attention_kwargs is None:
443
+ cross_attention_kwargs = {}
444
+ cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
445
+ elif condition_latents is not None:
446
+ if not hasattr(self, 'condition_latents_raised'):
447
+ print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
448
+ self.condition_latents_raised = True
449
+
450
+ if self.attn_config.init_cross_attn_ip:
451
+ raise NotImplementedError()
452
+
453
+ if self.attn_config.cat_condition:
454
+ assert condition_latents is not None
455
+ B = condition_latents.shape[0]
456
+ cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
457
+ sample = torch.cat([sample, cat_latents], dim=1)
458
+
459
+ return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
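A minimal, hypothetical sketch of driving the wrapper above (not part of this commit): the configuration values are illustrative, and a modified copy of the `AttnConfig` is used for runtime switching because `update_config` compares the incoming config against the one stored at `set_config` time.

```py
# Hypothetical sketch; all field values are illustrative assumptions.
from dataclasses import replace

import torch

from custum_3d_diffusion.custum_modules.unifield_processor import (
    AttnConfig,
    ConfigurableUNet2DConditionModel,
)

config = AttnConfig(
    init_unet_path="runwayml/stable-diffusion-v1-5",
    init_multiview_attn=True,
    enable_multiview_attn=True,
    multiview_attn_position="attn1",
)

# __init__ calls set_config(), which loads the UNet and installs the processors.
wrapper = ConfigurableUNet2DConditionModel(config, weight_dtype=torch.float16)

# Runtime toggles go through update_config() with a modified copy of the config.
wrapper.update_config(replace(config, enable_multiview_attn=False))
```

Parameter groups such as `multiview_attn_param_dict` and `cross_attn_param_dict` are populated during `set_config`, so per-family optimizers can be assembled from them afterwards.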
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # modified by Wuvin
15
+
16
+
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26
+ from PIL import Image
27
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28
+
29
+
30
+
31
+ class StableDiffusionImageCustomPipeline(
32
+ StableDiffusionImageVariationPipeline
33
+ ):
34
+ def __init__(
35
+ self,
36
+ vae: AutoencoderKL,
37
+ image_encoder: CLIPVisionModelWithProjection,
38
+ unet: UNet2DConditionModel,
39
+ scheduler: KarrasDiffusionSchedulers,
40
+ safety_checker: StableDiffusionSafetyChecker,
41
+ feature_extractor: CLIPImageProcessor,
42
+ requires_safety_checker: bool = True,
43
+ latents_offset=None,
44
+ noisy_cond_latents=False,
45
+ ):
46
+ super().__init__(
47
+ vae=vae,
48
+ image_encoder=image_encoder,
49
+ unet=unet,
50
+ scheduler=scheduler,
51
+ safety_checker=safety_checker,
52
+ feature_extractor=feature_extractor,
53
+ requires_safety_checker=requires_safety_checker
54
+ )
55
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
56
+ self.latents_offset = latents_offset
57
+ if latents_offset is not None:
58
+ self.register_to_config(latents_offset=latents_offset)
59
+ self.noisy_cond_latents = noisy_cond_latents
60
+ self.register_to_config(noisy_cond_latents=noisy_cond_latents)
61
+
62
+ def encode_latents(self, image, device, dtype, height, width):
63
+ # support batchsize > 1
64
+ if isinstance(image, Image.Image):
65
+ image = [image]
66
+ image = [img.convert("RGB") for img in image]
67
+ images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
68
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69
+ if self.latents_offset is not None:
70
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71
+ else:
72
+ return latents
73
+
74
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75
+ dtype = next(self.image_encoder.parameters()).dtype
76
+
77
+ if not isinstance(image, torch.Tensor):
78
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79
+
80
+ image = image.to(device=device, dtype=dtype)
81
+ image_embeddings = self.image_encoder(image).image_embeds
82
+ image_embeddings = image_embeddings.unsqueeze(1)
83
+
84
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
85
+ bs_embed, seq_len, _ = image_embeddings.shape
86
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88
+
89
+ if do_classifier_free_guidance:
90
+ # NOTE: the same as original code
91
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
92
+ # For classifier free guidance, we need to do two forward passes.
93
+ # Here we concatenate the unconditional and text embeddings into a single batch
94
+ # to avoid doing two forward passes
95
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96
+
97
+ return image_embeddings
98
+
99
+ @torch.no_grad()
100
+ def __call__(
101
+ self,
102
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103
+ height: Optional[int] = 1024,
104
+ width: Optional[int] = 1024,
105
+ height_cond: Optional[int] = 512,
106
+ width_cond: Optional[int] = 512,
107
+ num_inference_steps: int = 50,
108
+ guidance_scale: float = 7.5,
109
+ num_images_per_prompt: Optional[int] = 1,
110
+ eta: float = 0.0,
111
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112
+ latents: Optional[torch.FloatTensor] = None,
113
+ output_type: Optional[str] = "pil",
114
+ return_dict: bool = True,
115
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116
+ callback_steps: int = 1,
117
+ upper_left_feature: bool = False,
118
+ ):
119
+ r"""
120
+ The call function to the pipeline for generation.
121
+
122
+ Args:
123
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
124
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
125
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
126
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
127
+ The height in pixels of the generated image.
128
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
129
+ The width in pixels of the generated image.
130
+ num_inference_steps (`int`, *optional*, defaults to 50):
131
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
132
+ expense of slower inference. This parameter is modulated by `strength`.
133
+ guidance_scale (`float`, *optional*, defaults to 7.5):
134
+ A higher guidance scale value encourages the model to generate images closely linked to the text
135
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
136
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
137
+ The number of images to generate per prompt.
138
+ eta (`float`, *optional*, defaults to 0.0):
139
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
140
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
141
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
142
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
143
+ generation deterministic.
144
+ latents (`torch.FloatTensor`, *optional*):
145
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
146
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
147
+ tensor is generated by sampling using the supplied random `generator`.
148
+ output_type (`str`, *optional*, defaults to `"pil"`):
149
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
150
+ return_dict (`bool`, *optional*, defaults to `True`):
151
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
152
+ plain tuple.
153
+ callback (`Callable`, *optional*):
154
+ A function that calls every `callback_steps` steps during inference. The function is called with the
155
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
156
+ callback_steps (`int`, *optional*, defaults to 1):
157
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
158
+ every step.
159
+
160
+ Returns:
161
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
162
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
163
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
164
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
165
+ "not-safe-for-work" (nsfw) content.
166
+
167
+ Examples:
168
+
169
+ ```py
170
+ from diffusers import StableDiffusionImageVariationPipeline
171
+ from PIL import Image
172
+ from io import BytesIO
173
+ import requests
174
+
175
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
176
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
177
+ )
178
+ pipe = pipe.to("cuda")
179
+
180
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
181
+
182
+ response = requests.get(url)
183
+ image = Image.open(BytesIO(response.content)).convert("RGB")
184
+
185
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
186
+ out["images"][0].save("result.jpg")
187
+ ```
188
+ """
189
+ # 0. Default height and width to unet
190
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
191
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
192
+
193
+ # 1. Check inputs. Raise error if not correct
194
+ self.check_inputs(image, height, width, callback_steps)
195
+
196
+ # 2. Define call parameters
197
+ if isinstance(image, Image.Image):
198
+ batch_size = 1
199
+ elif isinstance(image, list):
200
+ batch_size = len(image)
201
+ else:
202
+ batch_size = image.shape[0]
203
+ device = self._execution_device
204
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
205
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
206
+ # corresponds to doing no classifier free guidance.
207
+ do_classifier_free_guidance = guidance_scale > 1.0
208
+
209
+ # 3. Encode input image
210
+ if isinstance(image, Image.Image) and upper_left_feature:
211
+ # use only the upper-left quadrant (the first of the four tiled views)
212
+ emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
213
+ else:
214
+ emb_image = image
215
+
216
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218
+
219
+ # 4. Prepare timesteps
220
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
221
+ timesteps = self.scheduler.timesteps
222
+
223
+ # 5. Prepare latent variables
224
+ num_channels_latents = self.unet.config.out_channels
225
+ latents = self.prepare_latents(
226
+ batch_size * num_images_per_prompt,
227
+ num_channels_latents,
228
+ height,
229
+ width,
230
+ image_embeddings.dtype,
231
+ device,
232
+ generator,
233
+ latents,
234
+ )
235
+
236
+ # 6. Prepare extra step kwargs.
237
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
238
+
239
+ # 7. Denoising loop
240
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
241
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
242
+ for i, t in enumerate(timesteps):
243
+ if self.noisy_cond_latents:
244
+ raise ValueError("Noisy condition latents is not recommended.")
245
+ else:
246
+ noisy_cond_latents = cond_latents
247
+
248
+ noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
249
+ # expand the latents if we are doing classifier free guidance
250
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
251
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
252
+
253
+ # predict the noise residual
254
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
255
+
256
+ # perform guidance
257
+ if do_classifier_free_guidance:
258
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
259
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
260
+
261
+ # compute the previous noisy sample x_t -> x_t-1
262
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
263
+
264
+ # call the callback, if provided
265
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
266
+ progress_bar.update()
267
+ if callback is not None and i % callback_steps == 0:
268
+ step_idx = i // getattr(self.scheduler, "order", 1)
269
+ callback(step_idx, t, latents)
270
+
271
+ self.maybe_free_model_hooks()
272
+
273
+ if self.latents_offset is not None:
274
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
275
+
276
+ if not output_type == "latent":
277
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
278
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
279
+ else:
280
+ image = latents
281
+ has_nsfw_concept = None
282
+
283
+ if has_nsfw_concept is None:
284
+ do_denormalize = [True] * image.shape[0]
285
+ else:
286
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
287
+
288
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
289
+
290
+ self.maybe_free_model_hooks()
291
+
292
+ if not return_dict:
293
+ return (image, has_nsfw_concept)
294
+
295
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
296
+
297
+ if __name__ == "__main__":
298
+ pass
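For orientation, here is a minimal usage sketch of the image-to-image pipeline defined above. It is an illustration only: in this upload the pipeline is normally built by `Image2ImageTrainer.construct_pipeline` (see custum_3d_diffusion/trainings/image2image_trainer.py below), and the input file name is a placeholder.

```python
# Hypothetical usage sketch; `pipeline` is assumed to be a StableDiffusionImageCustomPipeline
# instance already constructed by Image2ImageTrainer.construct_pipeline.
from PIL import Image

cond = Image.open("multiview_rgb.png").convert("RGB")  # placeholder input image
out = pipeline(
    image=cond,
    height=512, width=512,            # resolution of the generated image
    height_cond=512, width_cond=512,  # resolution at which the condition image is VAE-encoded
    guidance_scale=2.0,
    num_inference_steps=30,
)
out.images[0].save("predicted.png")
```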
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # modified by Wuvin
15
+
16
+
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26
+ from PIL import Image
27
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28
+
29
+
30
+
31
+ class StableDiffusionImage2MVCustomPipeline(
32
+ StableDiffusionImageVariationPipeline
33
+ ):
34
+ def __init__(
35
+ self,
36
+ vae: AutoencoderKL,
37
+ image_encoder: CLIPVisionModelWithProjection,
38
+ unet: UNet2DConditionModel,
39
+ scheduler: KarrasDiffusionSchedulers,
40
+ safety_checker: StableDiffusionSafetyChecker,
41
+ feature_extractor: CLIPImageProcessor,
42
+ requires_safety_checker: bool = True,
43
+ latents_offset=None,
44
+ noisy_cond_latents=False,
45
+ condition_offset=True,
46
+ ):
47
+ super().__init__(
48
+ vae=vae,
49
+ image_encoder=image_encoder,
50
+ unet=unet,
51
+ scheduler=scheduler,
52
+ safety_checker=safety_checker,
53
+ feature_extractor=feature_extractor,
54
+ requires_safety_checker=requires_safety_checker
55
+ )
56
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
57
+ self.latents_offset = latents_offset
58
+ if latents_offset is not None:
59
+ self.register_to_config(latents_offset=latents_offset)
60
+ if noisy_cond_latents:
61
+ raise NotImplementedError("Noisy condition latents are not supported yet.")
62
+ self.condition_offset = condition_offset
63
+ self.register_to_config(condition_offset=condition_offset)
64
+
65
+ def encode_latents(self, image: Image.Image, device, dtype, height, width):
66
+ images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
67
+ # NOTE: .mode() for condition
68
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69
+ if self.latents_offset is not None and self.condition_offset:
70
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71
+ else:
72
+ return latents
73
+
74
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75
+ dtype = next(self.image_encoder.parameters()).dtype
76
+
77
+ if not isinstance(image, torch.Tensor):
78
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79
+
80
+ image = image.to(device=device, dtype=dtype)
81
+ image_embeddings = self.image_encoder(image).image_embeds
82
+ image_embeddings = image_embeddings.unsqueeze(1)
83
+
84
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
85
+ bs_embed, seq_len, _ = image_embeddings.shape
86
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88
+
89
+ if do_classifier_free_guidance:
90
+ # NOTE: the same as original code
91
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
92
+ # For classifier free guidance, we need to do two forward passes.
93
+ # Here we concatenate the unconditional and text embeddings into a single batch
94
+ # to avoid doing two forward passes
95
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96
+
97
+ return image_embeddings
98
+
99
+ @torch.no_grad()
100
+ def __call__(
101
+ self,
102
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103
+ height: Optional[int] = 1024,
104
+ width: Optional[int] = 1024,
105
+ height_cond: Optional[int] = 512,
106
+ width_cond: Optional[int] = 512,
107
+ num_inference_steps: int = 50,
108
+ guidance_scale: float = 7.5,
109
+ num_images_per_prompt: Optional[int] = 1,
110
+ eta: float = 0.0,
111
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112
+ latents: Optional[torch.FloatTensor] = None,
113
+ output_type: Optional[str] = "pil",
114
+ return_dict: bool = True,
115
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116
+ callback_steps: int = 1,
117
+ ):
118
+ r"""
119
+ The call function to the pipeline for generation.
120
+
121
+ Args:
122
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
123
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
124
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
125
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
126
+ The height in pixels of the generated image.
127
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
128
+ The width in pixels of the generated image.
129
+ num_inference_steps (`int`, *optional*, defaults to 50):
130
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
131
+ expense of slower inference. This parameter is modulated by `strength`.
132
+ guidance_scale (`float`, *optional*, defaults to 7.5):
133
+ A higher guidance scale value encourages the model to generate images closely linked to the text
134
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
135
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
136
+ The number of images to generate per prompt.
137
+ eta (`float`, *optional*, defaults to 0.0):
138
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
139
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
140
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
141
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
142
+ generation deterministic.
143
+ latents (`torch.FloatTensor`, *optional*):
144
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
145
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
146
+ tensor is generated by sampling using the supplied random `generator`.
147
+ output_type (`str`, *optional*, defaults to `"pil"`):
148
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
149
+ return_dict (`bool`, *optional*, defaults to `True`):
150
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
151
+ plain tuple.
152
+ callback (`Callable`, *optional*):
153
+ A function that calls every `callback_steps` steps during inference. The function is called with the
154
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
155
+ callback_steps (`int`, *optional*, defaults to 1):
156
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
157
+ every step.
158
+
159
+ Returns:
160
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
161
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
162
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
163
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
164
+ "not-safe-for-work" (nsfw) content.
165
+
166
+ Examples:
167
+
168
+ ```py
169
+ from diffusers import StableDiffusionImageVariationPipeline
170
+ from PIL import Image
171
+ from io import BytesIO
172
+ import requests
173
+
174
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
175
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
176
+ )
177
+ pipe = pipe.to("cuda")
178
+
179
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
180
+
181
+ response = requests.get(url)
182
+ image = Image.open(BytesIO(response.content)).convert("RGB")
183
+
184
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
185
+ out["images"][0].save("result.jpg")
186
+ ```
187
+ """
188
+ # 0. Default height and width to unet
189
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
190
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
191
+
192
+ # 1. Check inputs. Raise error if not correct
193
+ self.check_inputs(image, height, width, callback_steps)
194
+
195
+ # 2. Define call parameters
196
+ if isinstance(image, Image.Image):
197
+ batch_size = 1
198
+ elif len(image) == 1:
199
+ image = image[0]
200
+ batch_size = 1
201
+ else:
202
+ raise NotImplementedError()
203
+ # elif isinstance(image, list):
204
+ # batch_size = len(image)
205
+ # else:
206
+ # batch_size = image.shape[0]
207
+ device = self._execution_device
208
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
209
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
210
+ # corresponds to doing no classifier free guidance.
211
+ do_classifier_free_guidance = guidance_scale > 1.0
212
+
213
+ # 3. Encode input image
214
+ emb_image = image
215
+
216
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218
+ cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
219
+ image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
220
+ if do_classifier_free_guidance:
221
+ image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
222
+
223
+ # 4. Prepare timesteps
224
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
225
+ timesteps = self.scheduler.timesteps
226
+
227
+ # 5. Prepare latent variables
228
+ num_channels_latents = self.unet.config.out_channels
229
+ latents = self.prepare_latents(
230
+ batch_size * num_images_per_prompt,
231
+ num_channels_latents,
232
+ height,
233
+ width,
234
+ image_embeddings.dtype,
235
+ device,
236
+ generator,
237
+ latents,
238
+ )
239
+
240
+
241
+ # 6. Prepare extra step kwargs.
242
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
243
+ # 7. Denoising loop
244
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
245
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
246
+ for i, t in enumerate(timesteps):
247
+ # expand the latents if we are doing classifier free guidance
248
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
249
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
250
+
251
+ # predict the noise residual
252
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample
253
+
254
+ # perform guidance
255
+ if do_classifier_free_guidance:
256
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
257
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
258
+
259
+ # compute the previous noisy sample x_t -> x_t-1
260
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
261
+
262
+ # call the callback, if provided
263
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
264
+ progress_bar.update()
265
+ if callback is not None and i % callback_steps == 0:
266
+ step_idx = i // getattr(self.scheduler, "order", 1)
267
+ callback(step_idx, t, latents)
268
+
269
+ self.maybe_free_model_hooks()
270
+
271
+ if self.latents_offset is not None:
272
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
273
+
274
+ if not output_type == "latent":
275
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
276
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
277
+ else:
278
+ image = latents
279
+ has_nsfw_concept = None
280
+
281
+ if has_nsfw_concept is None:
282
+ do_denormalize = [True] * image.shape[0]
283
+ else:
284
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
285
+
286
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
287
+
288
+ self.maybe_free_model_hooks()
289
+
290
+ if not return_dict:
291
+ return (image, has_nsfw_concept)
292
+
293
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
294
+
295
+ if __name__ == "__main__":
296
+ pass
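A corresponding usage sketch for the multiview pipeline above. Again an illustration: the Gradio demo builds it through `Image2MVImageTrainer.construct_pipeline` and drives it via `pipeline_forward`; the file name and the view count are assumptions.

```python
# Hypothetical usage sketch; `pipeline` is assumed to be a StableDiffusionImage2MVCustomPipeline.
from PIL import Image

front = Image.open("front_view.png").convert("RGB")  # placeholder single front view
views = pipeline(
    image=front,
    num_images_per_prompt=4,          # one latent per generated view
    height=256, width=256,            # per-view output resolution (as in image2mvimage.yaml)
    height_cond=256, width_cond=256,  # resolution of the VAE-encoded condition latents
    guidance_scale=1.5,
    num_inference_steps=30,
).images
```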
custum_3d_diffusion/modules.py ADDED
@@ -0,0 +1,14 @@
1
+ __modules__ = {}
2
+
3
+ def register(name):
4
+ def decorator(cls):
5
+ __modules__[name] = cls
6
+ return cls
7
+
8
+ return decorator
9
+
10
+
11
+ def find(name):
12
+ return __modules__[name]
13
+
14
+ from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
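The two functions above implement a tiny string-to-class registry; the trailing import is what triggers registration of the trainer classes as a side effect. A short sketch of how it is consumed, mirroring `init_trainers` in gradio_app/custom_models/utils.py further below:

```python
from custum_3d_diffusion import modules

# "image2mvimage_trainer" is the key used by @register(...) in image2mvimage_trainer.py
TrainerCls = modules.find("image2mvimage_trainer")
# trainer = TrainerCls(accelerator, logger, configurable_unet, trainer_cfg_dict, weight_dtype, index)
```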
custum_3d_diffusion/trainings/__init__.py ADDED
File without changes
custum_3d_diffusion/trainings/base.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ from accelerate import Accelerator
3
+ from accelerate.logging import MultiProcessAdapter
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional, Union
6
+ from datasets import load_dataset
7
+ import json
8
+ import abc
9
+ from diffusers.utils import make_image_grid
10
+ import numpy as np
11
+ import wandb
12
+
13
+ from custum_3d_diffusion.trainings.utils import load_config
14
+ from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
15
+
16
+ class BasicTrainer(torch.nn.Module, abc.ABC):
17
+ accelerator: Accelerator
18
+ logger: MultiProcessAdapter
19
+ unet: ConfigurableUNet2DConditionModel
20
+ train_dataloader: torch.utils.data.DataLoader
21
+ test_dataset: torch.utils.data.Dataset
22
+ attn_config: AttnConfig
23
+
24
+ @dataclass
25
+ class TrainerConfig:
26
+ trainer_name: str = "basic"
27
+ pretrained_model_name_or_path: str = ""
28
+
29
+ attn_config: dict = field(default_factory=dict)
30
+ dataset_name: str = ""
31
+ dataset_config_name: Optional[str] = None
32
+ resolution: str = "1024"
33
+ dataloader_num_workers: int = 4
34
+ pair_sampler_group_size: int = 1
35
+ num_views: int = 4
36
+
37
+ max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
38
+ training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
39
+ max_train_samples: Optional[int] = None
40
+ seed: Optional[int] = None # For dataset related operations and validation stuff
41
+ train_batch_size: int = 1
42
+
43
+ validation_interval: int = 5000
44
+ debug: bool = False
45
+
46
+ cfg: TrainerConfig # only enable_xxx is used
47
+
48
+ def __init__(
49
+ self,
50
+ accelerator: Accelerator,
51
+ logger: MultiProcessAdapter,
52
+ unet: ConfigurableUNet2DConditionModel,
53
+ config: Union[dict, str],
54
+ weight_dtype: torch.dtype,
55
+ index: int,
56
+ ):
57
+ super().__init__()
58
+ self.index = index # index in all trainers
59
+ self.accelerator = accelerator
60
+ self.logger = logger
61
+ self.unet = unet
62
+ self.weight_dtype = weight_dtype
63
+ self.ext_logs = {}
64
+ self.cfg = load_config(self.TrainerConfig, config)
65
+ self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
66
+ self.test_dataset = None
67
+ self.validate_trainer_config()
68
+ self.configure()
69
+
70
+ def get_HW(self):
71
+ resolution = json.loads(self.cfg.resolution)
72
+ if isinstance(resolution, int):
73
+ H = W = resolution
74
+ elif isinstance(resolution, list):
75
+ H, W = resolution
76
+ return H, W
77
+
78
+ def unet_update(self):
79
+ self.unet.update_config(self.attn_config)
80
+
81
+ def validate_trainer_config(self):
82
+ pass
83
+
84
+ def is_train_finished(self, current_step):
85
+ assert isinstance(self.cfg.max_train_steps, int)
86
+ return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
87
+
88
+ def next_train_step(self, current_step):
89
+ if self.is_train_finished(current_step):
90
+ return None
91
+ return current_step + self.cfg.training_step_interval
92
+
93
+ @classmethod
94
+ def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
95
+ catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
96
+ return make_image_grid(catted, rows=1, cols=len(catted))
97
+
98
+ def configure(self) -> None:
99
+ pass
100
+
101
+ @abc.abstractmethod
102
+ def init_shared_modules(self, shared_modules: dict) -> dict:
103
+ pass
104
+
105
+ def load_dataset(self):
106
+ dataset = load_dataset(
107
+ self.cfg.dataset_name,
108
+ self.cfg.dataset_config_name,
109
+ trust_remote_code=True
110
+ )
111
+ return dataset
112
+
113
+ @abc.abstractmethod
114
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
115
+ """Initialize both train_dataloader and test_dataset, but return only train_dataloader."""
116
+ pass
117
+
118
+ @abc.abstractmethod
119
+ def forward_step(
120
+ self,
121
+ *args,
122
+ **kwargs
123
+ ) -> torch.Tensor:
124
+ """
125
+ input a batch
126
+ return a loss
127
+ """
128
+ self.unet_update()
129
+ pass
130
+
131
+ @abc.abstractmethod
132
+ def construct_pipeline(self, shared_modules, unet):
133
+ pass
134
+
135
+ @abc.abstractmethod
136
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
137
+ """
138
+ For inference time forward.
139
+ """
140
+ pass
141
+
142
+ @abc.abstractmethod
143
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
144
+ pass
145
+
146
+ def do_validation(
147
+ self,
148
+ shared_modules,
149
+ unet,
150
+ global_step,
151
+ ):
152
+ self.unet_update()
153
+ self.logger.info("Running validation... ")
154
+ pipeline = self.construct_pipeline(shared_modules, unet)
155
+ pipeline.set_progress_bar_config(disable=True)
156
+ titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
157
+ for tracker in self.accelerator.trackers:
158
+ if tracker.name == "tensorboard":
159
+ np_images = np.stack([np.asarray(img) for img in images])
160
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
161
+ elif tracker.name == "wandb":
162
+ [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
163
+ tracker.log({"validation": [
164
+ wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
165
+ for i, image in enumerate(images)]})
166
+ else:
167
+ self.logger.warning(f"image logging not implemented for {tracker.name}")
168
+ del pipeline
169
+ torch.cuda.empty_cache()
170
+ return images
171
+
172
+
173
+ @torch.no_grad()
174
+ def log_validation(
175
+ self,
176
+ shared_modules,
177
+ unet,
178
+ global_step,
179
+ force=False
180
+ ):
181
+ if self.accelerator.is_main_process:
182
+ for tracker in self.accelerator.trackers:
183
+ if tracker.name == "wandb":
184
+ tracker.log(self.ext_logs)
185
+ self.ext_logs = {}
186
+ if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
187
+ self.unet_update()
188
+ if self.accelerator.is_main_process:
189
+ self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
190
+
191
+ def save_model(self, unwrap_unet, shared_modules, save_dir):
192
+ if self.accelerator.is_main_process:
193
+ pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
194
+ pipeline.save_pretrained(save_dir)
195
+ self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
196
+
197
+ def save_debug_info(self, save_name="debug", **kwargs):
198
+ if self.cfg.debug:
199
+ to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
200
+ import pickle
201
+ import os
202
+ if os.path.exists(f"{save_name}.pkl"):
203
+ for i in range(100):
204
+ if not os.path.exists(f"{save_name}_v{i}.pkl"):
205
+ save_name = f"{save_name}_v{i}"
206
+ break
207
+ with open(f"{save_name}.pkl", "wb") as f:
208
+ pickle.dump(to_saves, f)
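To make the contract of `BasicTrainer` explicit, here is an illustrative skeleton of a subclass. Names are hypothetical; the real subclasses are `Image2MVImageTrainer` and `Image2ImageTrainer` below, which leave the training-only hooks unimplemented since this upload is inference-oriented.

```python
from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer

@register("my_trainer")  # hypothetical registry key
class MyTrainer(BasicTrainer):
    def init_shared_modules(self, shared_modules: dict) -> dict:
        return shared_modules                # load/share frozen VAE, encoders, ...
    def init_train_dataloader(self, shared_modules: dict):
        raise NotImplementedError()          # training not needed for inference-only use
    def forward_step(self, *args, **kwargs):
        raise NotImplementedError()
    def construct_pipeline(self, shared_modules, unet):
        raise NotImplementedError()          # build a diffusers pipeline for validation/inference
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs):
        raise NotImplementedError()
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs):
        raise NotImplementedError()
```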
custum_3d_diffusion/trainings/config_classes.py ADDED
@@ -0,0 +1,35 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional
3
+
4
+
5
+ @dataclass
6
+ class TrainerSubConfig:
7
+ trainer_type: str = ""
8
+ trainer: dict = field(default_factory=dict)
9
+
10
+
11
+ @dataclass
12
+ class ExprimentConfig:
13
+ trainers: List[dict] = field(default_factory=lambda: [])
14
+ init_config: dict = field(default_factory=dict)
15
+ pretrained_model_name_or_path: str = ""
16
+ pretrained_unet_state_dict_path: str = ""
17
+ # experiment-related parameters
18
+ linear_beta_schedule: bool = False
19
+ zero_snr: bool = False
20
+ prediction_type: Optional[str] = None
21
+ seed: Optional[int] = None
22
+ max_train_steps: int = 1000000
23
+ gradient_accumulation_steps: int = 1
24
+ learning_rate: float = 1e-4
25
+ lr_scheduler: str = "constant"
26
+ lr_warmup_steps: int = 500
27
+ use_8bit_adam: bool = False
28
+ adam_beta1: float = 0.9
29
+ adam_beta2: float = 0.999
30
+ adam_weight_decay: float = 1e-2
31
+ adam_epsilon: float = 1e-08
32
+ max_grad_norm: float = 1.0
33
+ mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"]
34
+ skip_training: bool = False
35
+ debug: bool = False
custum_3d_diffusion/trainings/image2image_trainer.py ADDED
@@ -0,0 +1,86 @@
1
+ import json
2
+ import torch
3
+ from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
4
+ from dataclasses import dataclass
5
+
6
+ from custum_3d_diffusion.modules import register
7
+ from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
8
+ from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10
+
11
+ def get_HW(resolution):
12
+ if isinstance(resolution, str):
13
+ resolution = json.loads(resolution)
14
+ if isinstance(resolution, int):
15
+ H = W = resolution
16
+ elif isinstance(resolution, list):
17
+ H, W = resolution
18
+ return H, W
19
+
20
+
21
+ @register("image2image_trainer")
22
+ class Image2ImageTrainer(Image2MVImageTrainer):
23
+ """
24
+ Trainer for simple image-to-image generation (used here for normal-map prediction).
25
+ """
26
+ @dataclass
27
+ class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
28
+ trainer_name: str = "image2image"
29
+
30
+ cfg: TrainerConfig
31
+
32
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
33
+ raise NotImplementedError()
34
+
35
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
36
+ MyPipeline = StableDiffusionImageCustomPipeline
37
+ pipeline = MyPipeline.from_pretrained(
38
+ self.cfg.pretrained_model_name_or_path,
39
+ vae=shared_modules['vae'],
40
+ image_encoder=shared_modules['image_encoder'],
41
+ feature_extractor=shared_modules['feature_extractor'],
42
+ unet=unet,
43
+ safety_checker=None,
44
+ torch_dtype=self.weight_dtype,
45
+ latents_offset=self.cfg.latents_offset,
46
+ noisy_cond_latents=self.cfg.noisy_condition_input,
47
+ )
48
+ pipeline.set_progress_bar_config(disable=True)
49
+ scheduler_dict = {}
50
+ if self.cfg.zero_snr:
51
+ scheduler_dict.update(rescale_betas_zero_snr=True)
52
+ if self.cfg.linear_beta_schedule:
53
+ scheduler_dict.update(beta_schedule='linear')
54
+
55
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
56
+ return pipeline
57
+
58
+ def get_forward_args(self):
59
+ if self.cfg.seed is None:
60
+ generator = None
61
+ else:
62
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
63
+
64
+ H, W = get_HW(self.cfg.resolution)
65
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
66
+
67
+ forward_args = dict(
68
+ num_images_per_prompt=1,
69
+ num_inference_steps=20,
70
+ height=H,
71
+ width=W,
72
+ height_cond=H_cond,
73
+ width_cond=W_cond,
74
+ generator=generator,
75
+ )
76
+ if self.cfg.zero_snr:
77
+ forward_args.update(guidance_rescale=0.7)
78
+ return forward_args
79
+
80
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
81
+ forward_args = self.get_forward_args()
82
+ forward_args.update(pipeline_call_kwargs)
83
+ return pipeline(**forward_args)
84
+
85
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
86
+ raise NotImplementedError()
custum_3d_diffusion/trainings/image2mvimage_trainer.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
3
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional
8
+
9
+ from custum_3d_diffusion.modules import register
10
+ from custum_3d_diffusion.trainings.base import BasicTrainer
11
+ from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
12
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
+
14
+ def get_HW(resolution):
15
+ if isinstance(resolution, str):
16
+ resolution = json.loads(resolution)
17
+ if isinstance(resolution, int):
18
+ H = W = resolution
19
+ elif isinstance(resolution, list):
20
+ H, W = resolution
21
+ return H, W
22
+
23
+ @register("image2mvimage_trainer")
24
+ class Image2MVImageTrainer(BasicTrainer):
25
+ """
26
+ Trainer for simple image-to-multiview-image generation.
27
+ """
28
+ @dataclass
29
+ class TrainerConfig(BasicTrainer.TrainerConfig):
30
+ trainer_name: str = "image2mvimage"
31
+ condition_image_column_name: str = "conditioning_image"
32
+ image_column_name: str = "image"
33
+ condition_dropout: float = 0.
34
+ condition_image_resolution: str = "512"
35
+ validation_images: Optional[List[str]] = None
36
+ noise_offset: float = 0.1
37
+ max_loss_drop: float = 0.
38
+ snr_gamma: float = 5.0
39
+ log_distribution: bool = False
40
+ latents_offset: Optional[List[float]] = None
41
+ input_perturbation: float = 0.
42
+ noisy_condition_input: bool = False # whether to add noise to the reference UNet input
43
+ normal_cls_offset: int = 0
44
+ condition_offset: bool = True
45
+ zero_snr: bool = False
46
+ linear_beta_schedule: bool = False
47
+
48
+ cfg: TrainerConfig
49
+
50
+ def configure(self) -> None:
51
+ return super().configure()
52
+
53
+ def init_shared_modules(self, shared_modules: dict) -> dict:
54
+ if 'vae' not in shared_modules:
55
+ vae = AutoencoderKL.from_pretrained(
56
+ self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
57
+ )
58
+ vae.requires_grad_(False)
59
+ vae.to(self.accelerator.device, dtype=self.weight_dtype)
60
+ shared_modules['vae'] = vae
61
+ if 'image_encoder' not in shared_modules:
62
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
63
+ self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
64
+ )
65
+ image_encoder.requires_grad_(False)
66
+ image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
67
+ shared_modules['image_encoder'] = image_encoder
68
+ if 'feature_extractor' not in shared_modules:
69
+ feature_extractor = CLIPImageProcessor.from_pretrained(
70
+ self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
71
+ )
72
+ shared_modules['feature_extractor'] = feature_extractor
73
+ return shared_modules
74
+
75
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
76
+ raise NotImplementedError()
77
+
78
+ def loss_rescale(self, loss, timesteps=None):
79
+ raise NotImplementedError()
80
+
81
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
82
+ raise NotImplementedError()
83
+
84
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
85
+ MyPipeline = StableDiffusionImage2MVCustomPipeline
86
+ pipeline = MyPipeline.from_pretrained(
87
+ self.cfg.pretrained_model_name_or_path,
88
+ vae=shared_modules['vae'],
89
+ image_encoder=shared_modules['image_encoder'],
90
+ feature_extractor=shared_modules['feature_extractor'],
91
+ unet=unet,
92
+ safety_checker=None,
93
+ torch_dtype=self.weight_dtype,
94
+ latents_offset=self.cfg.latents_offset,
95
+ noisy_cond_latents=self.cfg.noisy_condition_input,
96
+ condition_offset=self.cfg.condition_offset,
97
+ )
98
+ pipeline.set_progress_bar_config(disable=True)
99
+ scheduler_dict = {}
100
+ if self.cfg.zero_snr:
101
+ scheduler_dict.update(rescale_betas_zero_snr=True)
102
+ if self.cfg.linear_beta_schedule:
103
+ scheduler_dict.update(beta_schedule='linear')
104
+
105
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
106
+ return pipeline
107
+
108
+ def get_forward_args(self):
109
+ if self.cfg.seed is None:
110
+ generator = None
111
+ else:
112
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
113
+
114
+ H, W = get_HW(self.cfg.resolution)
115
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
116
+
117
+ sub_img_H = H // 2
118
+ num_imgs = H // sub_img_H * W // sub_img_H
119
+
120
+ forward_args = dict(
121
+ num_images_per_prompt=num_imgs,
122
+ num_inference_steps=50,
123
+ height=sub_img_H,
124
+ width=sub_img_H,
125
+ height_cond=H_cond,
126
+ width_cond=W_cond,
127
+ generator=generator,
128
+ )
129
+ if self.cfg.zero_snr:
130
+ forward_args.update(guidance_rescale=0.7)
131
+ return forward_args
132
+
133
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
134
+ forward_args = self.get_forward_args()
135
+ forward_args.update(pipeline_call_kwargs)
136
+ return pipeline(**forward_args)
137
+
138
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
139
+ raise NotImplementedError()
custum_3d_diffusion/trainings/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ from omegaconf import DictConfig, OmegaConf
2
+
3
+
4
+ def parse_structured(fields, cfg) -> DictConfig:
5
+ scfg = OmegaConf.structured(fields(**cfg))
6
+ return scfg
7
+
8
+
9
+ def load_config(fields, config, extras=None):
10
+ if extras is not None:
11
+ print("Warning! Extra CLI parameters are not verified and may cause errors.")
12
+ if isinstance(config, str):
13
+ cfg = OmegaConf.load(config)
14
+ elif isinstance(config, dict):
15
+ cfg = OmegaConf.create(config)
16
+ elif isinstance(config, DictConfig):
17
+ cfg = config
18
+ else:
19
+ raise NotImplementedError(f"Unsupported config type {type(config)}")
20
+ if extras is not None:
21
+ cli_conf = OmegaConf.from_cli(extras)
22
+ cfg = OmegaConf.merge(cfg, cli_conf)
23
+ OmegaConf.resolve(cfg)
24
+ assert isinstance(cfg, DictConfig)
25
+ return parse_structured(fields, cfg)
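A minimal sketch of how `load_config` is used in this upload (see gradio_app/custom_models/utils.py): a dataclass type plus a YAML path or plain dict yields a resolved, structured DictConfig, with unspecified fields filled from the dataclass defaults.

```python
from custum_3d_diffusion.trainings.config_classes import ExprimentConfig
from custum_3d_diffusion.trainings.utils import load_config

cfg = load_config(ExprimentConfig, {"pretrained_model_name_or_path": "./ckpt/img2mvimg"})
print(cfg.learning_rate)    # 0.0001, filled in from the dataclass default
print(cfg.mixed_precision)  # None
```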
gradio_app/__init__.py ADDED
File without changes
gradio_app/all_models.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ from scripts.sd_model_zoo import load_common_sd15_pipe
3
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
4
+
5
+
6
+ class MyModelZoo:
7
+ _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None
8
+
9
+ base_model = "runwayml/stable-diffusion-v1-5"
10
+
11
+ def __init__(self, base_model=None) -> None:
12
+ if base_model is not None:
13
+ self.base_model = base_model
14
+
15
+ @property
16
+ def pipe_disney_controlnet_tile_ipadapter_i2i(self):
17
+ return self._pipe_disney_controlnet_lineart_ipadapter_i2i
18
+
19
+ def init_models(self):
20
+ self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)
21
+
22
+ model_zoo = MyModelZoo()
gradio_app/custom_models/image2mvimage.yaml ADDED
@@ -0,0 +1,63 @@
1
+ pretrained_model_name_or_path: "./ckpt/img2mvimg"
2
+ mixed_precision: "bf16"
3
+
4
+ init_config:
5
+ # enable controls
6
+ enable_cross_attn_lora: False
7
+ enable_cross_attn_ip: False
8
+ enable_self_attn_lora: False
9
+ enable_self_attn_ref: False
10
+ enable_multiview_attn: True
11
+
12
+ # for cross attention
13
+ init_cross_attn_lora: False
14
+ init_cross_attn_ip: False
15
+ cross_attn_lora_rank: 256 # 0 for not enabled
16
+ cross_attn_lora_only_kv: False
17
+ ipadapter_pretrained_name: "h94/IP-Adapter"
18
+ ipadapter_subfolder_name: "models"
19
+ ipadapter_weight_name: "ip-adapter_sd15.safetensors"
20
+ ipadapter_effect_on: "all" # all, first
21
+
22
+ # for self attention
23
+ init_self_attn_lora: False
24
+ self_attn_lora_rank: 256
25
+ self_attn_lora_only_kv: False
26
+
27
+ # for self attention ref
28
+ init_self_attn_ref: False
29
+ self_attn_ref_position: "attn1"
30
+ self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
31
+ self_attn_ref_pixel_wise_crosspond: False
32
+ self_attn_ref_effect_on: "all"
33
+
34
+ # for multiview attention
35
+ init_multiview_attn: True
36
+ multiview_attn_position: "attn1"
37
+ use_mv_joint_attn: True
38
+ num_modalities: 1
39
+
40
+ # for unet
41
+ init_unet_path: "${pretrained_model_name_or_path}"
42
+ cat_condition: True # cat condition to input
43
+
44
+ # for cls embedding
45
+ init_num_cls_label: 8 # used at initialization
46
+ cls_labels: [0, 1, 2, 3] # for current task
47
+
48
+ trainers:
49
+ - trainer_type: "image2mvimage_trainer"
50
+ trainer:
51
+ pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
52
+ attn_config:
53
+ cls_labels: [0, 1, 2, 3] # for current task
54
+ enable_cross_attn_lora: False
55
+ enable_cross_attn_ip: False
56
+ enable_self_attn_lora: False
57
+ enable_self_attn_ref: False
58
+ enable_multiview_attn: True
59
+ resolution: "256"
60
+ condition_image_resolution: "256"
61
+ normal_cls_offset: 4
62
+ condition_image_column_name: "conditioning_image"
63
+ image_column_name: "image"
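A note on the `${pretrained_model_name_or_path}` references above: they are OmegaConf interpolations. `load_config` (custum_3d_diffusion/trainings/utils.py) calls `OmegaConf.resolve(cfg)` on the whole file, so `init_unet_path` and the trainer's `pretrained_model_name_or_path` both resolve to "./ckpt/img2mvimg" without any extra handling.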
gradio_app/custom_models/image2normal.yaml ADDED
@@ -0,0 +1,61 @@
1
+ pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
2
+ mixed_precision: "bf16"
3
+
4
+ init_config:
5
+ # enable controls
6
+ enable_cross_attn_lora: False
7
+ enable_cross_attn_ip: False
8
+ enable_self_attn_lora: False
9
+ enable_self_attn_ref: True
10
+ enable_multiview_attn: False
11
+
12
+ # for cross attention
13
+ init_cross_attn_lora: False
14
+ init_cross_attn_ip: False
15
+ cross_attn_lora_rank: 512 # 0 for not enabled
16
+ cross_attn_lora_only_kv: False
17
+ ipadapter_pretrained_name: "h94/IP-Adapter"
18
+ ipadapter_subfolder_name: "models"
19
+ ipadapter_weight_name: "ip-adapter_sd15.safetensors"
20
+ ipadapter_effect_on: "all" # all, first
21
+
22
+ # for self attention
23
+ init_self_attn_lora: False
24
+ self_attn_lora_rank: 512
25
+ self_attn_lora_only_kv: False
26
+
27
+ # for self attention ref
28
+ init_self_attn_ref: True
29
+ self_attn_ref_position: "attn1"
30
+ self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
31
+ self_attn_ref_pixel_wise_crosspond: True
32
+ self_attn_ref_effect_on: "all"
33
+
34
+ # for multiview attention
35
+ init_multiview_attn: False
36
+ multiview_attn_position: "attn1"
37
+ num_modalities: 1
38
+
39
+ # for unet
40
+ init_unet_path: "${pretrained_model_name_or_path}"
41
+ init_num_cls_label: 0 # used at initialization
42
+ cls_labels: [] # for current task
43
+
44
+ trainers:
45
+ - trainer_type: "image2image_trainer"
46
+ trainer:
47
+ pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
48
+ attn_config:
49
+ cls_labels: [] # for current task
50
+ enable_cross_attn_lora: False
51
+ enable_cross_attn_ip: False
52
+ enable_self_attn_lora: False
53
+ enable_self_attn_ref: True
54
+ enable_multiview_attn: False
55
+ resolution: "512"
56
+ condition_image_resolution: "512"
57
+ condition_image_column_name: "conditioning_image"
58
+ image_column_name: "image"
59
+
60
+
61
+
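Compared with image2mvimage.yaml, this config swaps multiview attention for a reference self-attention branch (`init_self_attn_ref: True` with `self_attn_ref_pixel_wise_crosspond: True`), uses no class-label embedding (`init_num_cls_label: 0`, empty `cls_labels`), and runs at 512x512. It drives the `image2image_trainer`, which the Gradio app uses for normal-map prediction (see gradio_app/custom_models/normal_prediction.py).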
gradio_app/custom_models/mvimg_prediction.py ADDED
@@ -0,0 +1,59 @@
1
+ import sys
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import numpy as np
6
+ from rembg import remove
7
+ from gradio_app.utils import change_rgba_bg, rgba_to_rgb
8
+ from gradio_app.custom_models.utils import load_pipeline
9
+ from scripts.all_typing import *
10
+ from scripts.utils import session, simple_preprocess
11
+
12
+ training_config = "gradio_app/custom_models/image2mvimage.yaml"
13
+ checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
14
+
15
+ trainer, pipeline = load_pipeline(training_config, checkpoint_path)
16
+
17
+ def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
18
+ global pipeline
19
+ pipeline = pipeline.to("cuda")
20
+ if isinstance(img_list, Image.Image):
21
+ img_list = [img_list]
22
+ img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
23
+ ret = []
24
+ for img in img_list:
25
+ images = trainer.pipeline_forward(
26
+ pipeline=pipeline,
27
+ image=img,
28
+ guidance_scale=guidance_scale,
29
+ **kwargs
30
+ ).images
31
+ ret.extend(images)
32
+ return ret
33
+
34
+
35
+ def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
36
+ if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
37
+ # still remove the background with rembg, since simple_preprocess requires an RGBA image
38
+ print("Input is RGB, not RGBA; removing the background anyway.")
39
+ remove_bg = True
40
+
41
+ if remove_bg:
42
+ input_image = remove(input_image, session=session)
43
+
44
+ # make front_pil RGBA with white bg
45
+ input_image = change_rgba_bg(input_image, "white")
46
+ single_image = simple_preprocess(input_image)
47
+
48
+ generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None
49
+
50
+ rgb_pils = predict(
51
+ single_image,
52
+ generator=generator,
53
+ guidance_scale=guidance_scale,
54
+ width=256,
55
+ height=256,
56
+ num_inference_steps=30,
57
+ )
58
+
59
+ return rgb_pils, single_image
gradio_app/custom_models/normal_prediction.py ADDED
@@ -0,0 +1,28 @@
1
+ import sys
2
+ from PIL import Image
3
+ from gradio_app.utils import rgba_to_rgb, simple_remove
4
+ from gradio_app.custom_models.utils import load_pipeline
5
+ from scripts.utils import rotate_normals_torch
6
+ from scripts.all_typing import *
7
+
8
+ training_config = "gradio_app/custom_models/image2normal.yaml"
9
+ checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
10
+ trainer, pipeline = load_pipeline(training_config, checkpoint_path)
11
+
12
+ def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
13
+ global pipeline
14
+ pipeline = pipeline.to("cuda")
15
+
16
+ img_list = image if isinstance(image, list) else [image]
17
+ img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
18
+ images = trainer.pipeline_forward(
19
+ pipeline=pipeline,
20
+ image=img_list,
21
+ num_inference_steps=num_inference_steps,
22
+ guidance_scale=guidance_scale,
23
+ **kwargs
24
+ ).images
25
+ images = simple_remove(images)
26
+ if do_rotate and len(images) > 1:
27
+ images = rotate_normals_torch(images, return_types='pil')
28
+ return images
gradio_app/custom_models/utils.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ from typing import List
3
+ from dataclasses import dataclass
4
+ from gradio_app.utils import rgba_to_rgb
5
+ from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
6
+ from custum_3d_diffusion import modules
7
+ from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
8
+ from custum_3d_diffusion.trainings.base import BasicTrainer
9
+ from custum_3d_diffusion.trainings.utils import load_config
10
+
11
+
12
+ @dataclass
13
+ class FakeAccelerator:
14
+ device: torch.device = torch.device("cuda")
15
+
16
+
17
+ def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
18
+ accelerator = FakeAccelerator()
19
+ cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
20
+ init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
21
+ configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
22
+ configurable_unet.enable_xformers_memory_efficient_attention()
23
+ trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
24
+ trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
25
+ return trainers, configurable_unet
26
+
27
+ from gradio_app.utils import make_image_grid, split_image
28
+ def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
29
+ from rembg import remove
30
+ if remove_bg:
31
+ img = remove(img)
32
+ img = rgba_to_rgb(img)
33
+ if merged_image:
34
+ img = split_image(img, rows=2)
35
+ images = function(
36
+ image=img,
37
+ guidance_scale=guidance_scale,
38
+ )
39
+ if len(images) > 1:
40
+ return make_image_grid(images, rows=2)
41
+ else:
42
+ return images[0]
43
+
44
+
45
+ def process_text(trainer, pipeline, img, guidance_scale=2.):
46
+ pipeline.cfg.validation_prompts = [img]
47
+ titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
48
+ return images[0]
49
+
50
+
51
+ def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
52
+ training_config = config_path
53
+ load_from_checkpoint = ckpt_path
54
+ extras = []
55
+ device = "cuda"
56
+ trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
57
+ shared_modules = dict()
58
+ for trainer in trainers:
59
+ shared_modules = trainer.init_shared_modules(shared_modules)
60
+
61
+ if load_from_checkpoint is not None:
62
+ state_dict = torch.load(load_from_checkpoint, map_location="cpu")
63
+ configurable_unet.unet.load_state_dict(state_dict, strict=False)
64
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
65
+ configurable_unet.unet.to(device, dtype=weight_dtype)
66
+
67
+ pipeline = None
68
+ trainer_out = None
69
+ for trainer in trainers:
70
+ if pipeline_filter(trainer.cfg.trainer_name):
71
+ pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
72
+ pipeline.set_progress_bar_config(disable=False)
73
+ trainer_out = trainer
74
+ pipeline = pipeline.to(device, dtype=weight_dtype)
75
+ return trainer_out, pipeline
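`load_pipeline` keeps the last trainer whose `trainer_name` passes `pipeline_filter`; the configs in this upload each define a single trainer, so the permissive default is what mvimg_prediction.py and normal_prediction.py rely on. If a config ever lists several trainers, the filter can pin a specific one, for example:

```python
# Sketch: select a specific trainer by its trainer_name when a config defines several.
trainer, pipeline = load_pipeline(
    "gradio_app/custom_models/image2mvimage.yaml",
    "ckpt/img2mvimg/unet_state_dict.pth",
    pipeline_filter=lambda name: name == "image2mvimage",
)
```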
gradio_app/examples/Groot.png ADDED
gradio_app/examples/aaa.png ADDED
gradio_app/examples/abma.png ADDED
gradio_app/examples/akun.png ADDED
gradio_app/examples/anya.png ADDED
gradio_app/examples/bag.png ADDED

Git LFS Details

  • SHA256: ac798ea1f112091c04f5bdfa47c490806fb433a02fe17758aa1f8c55cd64b66e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
gradio_app/examples/ex1.png ADDED

Git LFS Details

  • SHA256: d49ccccd40fe0317c2886b0d36a11667003d17a49cc49d9244208d250de9fe31
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
gradio_app/examples/ex2.png ADDED
gradio_app/examples/ex3.jpg ADDED
gradio_app/examples/ex4.png ADDED
gradio_app/examples/generated_1715761545_frame0.png ADDED
gradio_app/examples/generated_1715762357_frame0.png ADDED
gradio_app/examples/generated_1715763329_frame0.png ADDED
gradio_app/examples/hatsune_miku.png ADDED
gradio_app/examples/princess-large.png ADDED
gradio_app/gradio_3dgen.py ADDED
@@ -0,0 +1,85 @@
1
+ import spaces
2
+ import os
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from pytorch3d.structures import Meshes
6
+ from gradio_app.utils import clean_up
7
+ from gradio_app.custom_models.mvimg_prediction import run_mvprediction
8
+ from gradio_app.custom_models.normal_prediction import predict_normals
9
+ from scripts.refine_lr_to_sr import run_sr_fast
10
+ from scripts.utils import save_glb_and_video
11
+ # from scripts.multiview_inference import geo_reconstruct
12
+ from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3
13
+
14
+ @spaces.GPU(duration=100)
15
+ def run_mv(preview_img, input_processing, seed):
16
+ if preview_img.size[0] <= 512:
17
+ preview_img = run_sr_fast([preview_img])[0]
18
+ rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
19
+ return rgb_pils, front_pil
20
+
21
+ @spaces.GPU(duration=100) # splitting this into multiple GPU calls seems to lead to a `RuntimeError`; until that is fixed, keep everything initialized in a single call here
22
+ def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
23
+ if preview_img is None:
24
+ raise gr.Error("The input image is none!")
25
+ if isinstance(preview_img, str):
26
+ preview_img = Image.open(preview_img)
27
+
28
+ rgb_pils, front_pil = run_mv(preview_img, input_processing, seed)
29
+
30
+ vertices, faces, img_list = geo_reconstruct_part1(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
31
+
32
+ meshes = geo_reconstruct_part2(vertices, faces)
33
+
34
+ new_meshes = geo_reconstruct_part3(meshes, img_list)
35
+
36
+ vertices = new_meshes.verts_packed()
37
+ vertices = vertices / 2 * 1.35
38
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
39
+ new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)
40
+
41
+ ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
42
+ return ret_mesh, video
43
+
44
+ #######################################
45
+ def create_ui(concurrency_id="wkl"):
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
49
+
50
+ example_folder = os.path.join(os.path.dirname(__file__), "./examples")
51
+ example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
52
+ gr.Examples(
53
+ examples=example_fns,
54
+ inputs=[input_image],
55
+ cache_examples=False,
56
+ label='Examples',
57
+ examples_per_page=12
58
+ )
59
+
60
+
61
+ with gr.Column(scale=1):
62
+ # export mesh display
63
+ output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320, camera_position=(90, 90, 2))
64
+ output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
65
+
66
+ input_processing = gr.Checkbox(
67
+ value=True,
68
+ label='Remove Background',
69
+ visible=True,
70
+ )
71
+ do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
72
+ expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
73
+ init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
74
+ setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
75
+ render_video = gr.Checkbox(value=False, visible=False, label="generate video")
76
+ fullrunv2_btn = gr.Button('Generate 3D', variant = "primary", interactive=True)
77
+
78
+ fullrunv2_btn.click(
79
+ fn = generate3dv2,
80
+ inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
81
+ outputs=[output_mesh, output_video],
82
+ concurrency_id=concurrency_id,
83
+ api_name="generate3dv2",
84
+ ).success(clean_up, api_name=False)
85
+ return input_image
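Because the click handler registers `api_name="generate3dv2"`, the endpoint can also be called programmatically once the app is running. A hedged client-side sketch with `gradio_client` (server address and image path are placeholders, and depending on the installed gradio_client version the image argument may need to be wrapped with `handle_file`):

```python
# Hypothetical client-side sketch; assumes the Gradio app is already serving locally.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
mesh_path, video = client.predict(
    "gradio_app/examples/Groot.png",  # input_image
    True,                             # input_processing: remove background
    -1,                               # seed
    False,                            # render_video
    True,                             # do_refine
    0.1,                              # expansion_weight
    "std",                            # init_type
    api_name="/generate3dv2",
)
```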
gradio_app/gradio_3dgen_steps.py ADDED
@@ -0,0 +1,87 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ from gradio_app.custom_models.mvimg_prediction import run_mvprediction
5
+ from gradio_app.utils import make_image_grid, split_image
6
+ from scripts.utils import save_glb_and_video
7
+
8
+ def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
9
+ seed = int(seed)
10
+ if preview_img is None:
11
+ raise gr.Error("preview_img is none.")
12
+ if isinstance(preview_img, str):
13
+ preview_img = Image.open(preview_img)
14
+
15
+ rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
16
+ rgb_pil = make_image_grid(rgb_pils, rows=2)
17
+ return rgb_pil, front_pil
18
+
19
+ def concept_to_multiview_ui(concurrency_id="wkl"):
20
+ with gr.Row():
21
+ with gr.Column(scale=2):
22
+ preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
23
+ input_processing = gr.Checkbox(
24
+ value=True,
25
+ label='Remove Background',
26
+ )
27
+ seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
28
+ guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
29
+ run_btn = gr.Button('Generate Multiview', interactive=True)
30
+ with gr.Column(scale=3):
31
+ # export mesh display
32
+ output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
33
+ output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
34
+ run_btn.click(
35
+ fn = concept_to_multiview,
36
+ inputs=[preview_img, input_processing, seed, guidance],
37
+ outputs=[output_rgb, output_front],
38
+ concurrency_id=concurrency_id,
39
+ api_name=False,
40
+ )
41
+ return output_rgb, output_front
42
+
43
+ from gradio_app.custom_models.normal_prediction import predict_normals
44
+ from scripts.multiview_inference import geo_reconstruct
45
+ def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
46
+ rgb_pils = split_image(rgb_pil, rows=2)
47
+ if normal_pil is not None:
48
+ normal_pil = split_image(normal_pil, rows=2)
49
+ if front_pil is None:
50
+ front_pil = rgb_pils[0]
51
+ new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
52
+ ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
53
+ return ret_mesh
54
+
55
+ def new_multiview_to_mesh_ui(concurrency_id="wkl"):
56
+ with gr.Row():
57
+ with gr.Column(scale=2):
58
+ rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
59
+ front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)')
60
+ normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)')
61
+ do_refine = gr.Checkbox(
62
+ value=False,
63
+ label='Refine rgb',
64
+ visible=False,
65
+ )
66
+ expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
67
+ init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
68
+ run_btn = gr.Button('Generate 3D', interactive=True)
69
+ with gr.Column(scale=3):
70
+ # export mesh display
71
+ output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
72
+ run_btn.click(
73
+ fn = multiview_to_mesh_v2,
74
+ inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
75
+ outputs=[output_mesh],
76
+ concurrency_id=concurrency_id,
77
+ api_name="multiview_to_mesh",
78
+ )
79
+ return rgb_pil, front_pil, output_mesh
80
+
81
+
82
+ #######################################
83
+ def create_step_ui(concurrency_id="wkl"):
84
+ with gr.Tab(label="3D:concept_to_multiview"):
85
+ concept_to_multiview_ui(concurrency_id)
86
+ with gr.Tab(label="3D:new_multiview_to_mesh"):
87
+ new_multiview_to_mesh_ui(concurrency_id)
gradio_app/gradio_local.py ADDED
@@ -0,0 +1,76 @@
1
+ if __name__ == "__main__":
2
+ import os
3
+ import sys
4
+ sys.path.append(os.curdir)
5
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ:
6
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7
+ os.environ['TRANSFORMERS_OFFLINE']='0'
8
+ os.environ['DIFFUSERS_OFFLINE']='0'
9
+ os.environ['HF_HUB_OFFLINE']='0'
10
+ os.environ['GRADIO_ANALYTICS_ENABLED']='False'
11
+ os.environ['HF_ENDPOINT']='https://hf-mirror.com'
12
+ import torch
13
+ torch.set_float32_matmul_precision('medium')
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.set_grad_enabled(False)
16
+
17
+ import gradio as gr
18
+ import argparse
19
+
20
+ from gradio_app.gradio_3dgen import create_ui as create_3d_ui
21
+ # from app.gradio_3dgen_steps import create_step_ui
22
+ from gradio_app.all_models import model_zoo
23
+
24
+
25
+ _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
26
+ _DESCRIPTION = '''
27
+ [Project page](https://wukailu.github.io/Unique3D/)
28
+
29
+ * High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
30
+
31
+ * The demo is still under construction, and more features are expected to be implemented soon.
32
+ '''
33
+
34
+ def launch(
35
+ port,
36
+ listen=False,
37
+ share=False,
38
+ gradio_root="",
39
+ ):
40
+ model_zoo.init_models()
41
+
42
+ with gr.Blocks(
43
+ title=_TITLE,
44
+ theme=gr.themes.Monochrome(),
45
+ ) as demo:
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.Markdown('# ' + _TITLE)
49
+ gr.Markdown(_DESCRIPTION)
50
+ create_3d_ui("wkl")
51
+
52
+ launch_args = {}
53
+ if listen:
54
+ launch_args["server_name"] = "0.0.0.0"
55
+
56
+ demo.queue(default_concurrency_limit=1).launch(
57
+ server_port=None if port == 0 else port,
58
+ share=share,
59
+ root_path=gradio_root if gradio_root != "" else None, # "/myapp"
60
+ **launch_args,
61
+ )
62
+
63
+ if __name__ == "__main__":
64
+ parser = argparse.ArgumentParser()
65
+ args, extra = parser.parse_known_args()
66
+ parser.add_argument("--listen", action="store_true")
67
+ parser.add_argument("--port", type=int, default=0)
68
+ parser.add_argument("--share", action="store_true")
69
+ parser.add_argument("--gradio_root", default="")
70
+ args = parser.parse_args()
71
+ launch(
72
+ args.port,
73
+ listen=args.listen,
74
+ share=args.share,
75
+ gradio_root=args.gradio_root,
76
+ )
gradio_app/utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import gc
8
+ from scripts.refine_lr_to_sr import run_sr_fast
9
+
10
+ GRADIO_CACHE = "/tmp/gradio/"
11
+
12
+ def clean_up():
13
+ torch.cuda.empty_cache()
14
+ gc.collect()
15
+
16
+ def remove_color(arr):
17
+ if arr.shape[-1] == 4:
18
+ arr = arr[..., :3]
19
+ # calc diffs
20
+ base = arr[0, 0]
21
+ diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
22
+ alpha = (diffs <= 80)
23
+
24
+ arr[alpha] = 255
25
+ alpha = ~alpha
26
+ arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
27
+ return arr
28
+
29
+ def simple_remove(imgs, run_sr=True):
30
+ """Only works for normal maps"""
31
+ if not isinstance(imgs, list):
32
+ imgs = [imgs]
33
+ single_input = True
34
+ else:
35
+ single_input = False
36
+ if run_sr:
37
+ imgs = run_sr_fast(imgs)
38
+ rets = []
39
+ for img in imgs:
40
+ arr = np.array(img)
41
+ arr = remove_color(arr)
42
+ rets.append(Image.fromarray(arr.astype(np.uint8)))
43
+ if single_input:
44
+ return rets[0]
45
+ return rets
46
+
47
+ def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
48
+ new_image = Image.new("RGBA", rgba.size, bkgd)
49
+ new_image.paste(rgba, (0, 0), rgba)
50
+ new_image = new_image.convert('RGB')
51
+ return new_image
52
+
53
+ def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
54
+ rgb_white = rgba_to_rgb(rgba, bkgd)
55
+ new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
56
+ return new_rgba
57
+
58
+ def split_image(image, rows=None, cols=None):
59
+ """
60
+ inverse function of make_image_grid
61
+ """
62
+ # image is in square
63
+ if rows is None and cols is None:
64
+ # image.size [W, H]
65
+ rows = 1
66
+ cols = image.size[0] // image.size[1]
67
+ assert cols * image.size[1] == image.size[0]
68
+ subimg_size = image.size[1]
69
+ elif rows is None:
70
+ subimg_size = image.size[0] // cols
71
+ rows = image.size[1] // subimg_size
72
+ assert rows * subimg_size == image.size[1]
73
+ elif cols is None:
74
+ subimg_size = image.size[1] // rows
75
+ cols = image.size[0] // subimg_size
76
+ assert cols * subimg_size == image.size[0]
77
+ else:
78
+ subimg_size = image.size[1] // rows
79
+ assert cols * subimg_size == image.size[0]
80
+ subimgs = []
81
+ for i in range(rows):
82
+ for j in range(cols):
83
+ subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
84
+ subimgs.append(subimg)
85
+ return subimgs
86
+
87
+ def make_image_grid(images, rows=None, cols=None, resize=None):
88
+ if rows is None and cols is None:
89
+ rows = 1
90
+ cols = len(images)
91
+ if rows is None:
92
+ rows = len(images) // cols
93
+ if len(images) % cols != 0:
94
+ rows += 1
95
+ if cols is None:
96
+ cols = len(images) // rows
97
+ if len(images) % rows != 0:
98
+ cols += 1
99
+ total_imgs = rows * cols
100
+ if total_imgs > len(images):
101
+ images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
102
+
103
+ if resize is not None:
104
+ images = [img.resize((resize, resize)) for img in images]
105
+
106
+ w, h = images[0].size
107
+ grid = Image.new(images[0].mode, size=(cols * w, rows * h))
108
+
109
+ for i, img in enumerate(images):
110
+ grid.paste(img, box=(i % cols * w, i // cols * h))
111
+ return grid
112
+
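
A minimal round-trip sketch for the split_image / make_image_grid pair above (the tiles are synthetic placeholders), illustrating the "inverse function" relationship stated in the docstring:

from PIL import Image
from gradio_app.utils import make_image_grid, split_image

tiles = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
grid = make_image_grid(tiles, rows=2, cols=2)   # one 512x512 grid image
back = split_image(grid, rows=2)                # inverse: four 256x256 tiles again
assert len(back) == 4 and back[0].size == (256, 256)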
mesh_reconstruction/func.py ADDED
@@ -0,0 +1,133 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import torch
3
+ import numpy as np
4
+ import trimesh
5
+ from typing import Tuple
6
+
7
+ def to_numpy(*args):
8
+ def convert(a):
9
+ if isinstance(a,torch.Tensor):
10
+ return a.detach().cpu().numpy()
11
+ assert a is None or isinstance(a,np.ndarray)
12
+ return a
13
+
14
+ return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args)
15
+
16
+ def laplacian(
17
+ num_verts:int,
18
+ edges: torch.Tensor #E,2
19
+ ) -> torch.Tensor: #sparse V,V
20
+ """create sparse Laplacian matrix"""
21
+ V = num_verts
22
+ E = edges.shape[0]
23
+
24
+ #adjacency matrix,
25
+ idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E)
26
+ ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
27
+ A = torch.sparse.FloatTensor(idx, ones, (V, V))
28
+
29
+ #degree matrix
30
+ deg = torch.sparse.sum(A, dim=1).to_dense()
31
+ idx = torch.arange(V, device=edges.device)
32
+ idx = torch.stack([idx, idx], dim=0)
33
+ D = torch.sparse.FloatTensor(idx, deg, (V, V))
34
+
35
+ return D - A
36
+
37
+ def _translation(x, y, z, device):
38
+ return torch.tensor([[1., 0, 0, x],
39
+ [0, 1, 0, y],
40
+ [0, 0, 1, z],
41
+ [0, 0, 0, 1]],device=device) #4,4
42
+
43
+ def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
44
+ """
45
+ see https://blog.csdn.net/wodownload2/article/details/85069240/
46
+ """
47
+ if l is None:
48
+ l = -r
49
+ if t is None:
50
+ t = r
51
+ if b is None:
52
+ b = -t
53
+ p = torch.zeros([4,4],device=device)
54
+ p[0,0] = 2*n/(r-l)
55
+ p[0,2] = (r+l)/(r-l)
56
+ p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
57
+ p[1,2] = (t+b)/(t-b)
58
+ p[2,2] = -(f+n)/(f-n)
59
+ p[2,3] = -(2*f*n)/(f-n)
60
+ p[3,2] = -1
61
+ return p #4,4
62
+
63
+ def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
64
+ if l is None:
65
+ l = -r
66
+ if t is None:
67
+ t = r
68
+ if b is None:
69
+ b = -t
70
+ o = torch.zeros([4,4],device=device)
71
+ o[0,0] = 2/(r-l)
72
+ o[0,3] = -(r+l)/(r-l)
73
+ o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
74
+ o[1,3] = -(t+b)/(t-b)
75
+ o[2,2] = -2/(f-n)
76
+ o[2,3] = -(f+n)/(f-n)
77
+ o[3,3] = 1
78
+ return o #4,4
79
+
80
+ def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
81
+ if r is None:
82
+ r = 1/distance
83
+ A = az_count
84
+ P = pol_count
85
+ C = A * P
86
+
87
+ phi = torch.arange(0,A) * (2*torch.pi/A)
88
+ phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
89
+ phi_rot[:,0,2,2] = phi.cos()
90
+ phi_rot[:,0,2,0] = -phi.sin()
91
+ phi_rot[:,0,0,2] = phi.sin()
92
+ phi_rot[:,0,0,0] = phi.cos()
93
+
94
+ theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
95
+ theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
96
+ theta_rot[0,:,1,1] = theta.cos()
97
+ theta_rot[0,:,1,2] = -theta.sin()
98
+ theta_rot[0,:,2,1] = theta.sin()
99
+ theta_rot[0,:,2,2] = theta.cos()
100
+
101
+ mv = torch.empty((C,4,4), device=device)
102
+ mv[:] = torch.eye(4, device=device)
103
+ mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
104
+ mv = _translation(0, 0, -distance, device) @ mv
105
+
106
+ return mv, _projection(r,device)
107
+
108
+ def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
109
+ mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
110
+ if r is None:
111
+ r = 1
112
+ return mv, _orthographic(r,device)
113
+
114
+ def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]:
115
+ sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
116
+ vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
117
+ faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
118
+ return vertices,faces
119
+
120
+ from pytorch3d.renderer import (
121
+ FoVOrthographicCameras,
122
+ look_at_view_transform,
123
+ )
124
+
125
+ def get_camera(R, T, focal_length=1 / (2**0.5)):
126
+ focal_length = 1 / focal_length
127
+ camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
128
+ return camera
129
+
130
+ def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
131
+ R, T = look_at_view_transform(dist, 0, azim_list)
132
+ focal_length = 1 / focal
133
+ return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
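
A minimal sketch of the orthographic star-camera setup consumed by the renderers later in this upload (a CUDA device and the trimesh/pytorch3d dependencies are assumed available):

from mesh_reconstruction.func import make_star_cameras_orthographic

mv, proj = make_star_cameras_orthographic(4, 1)   # 4 azimuths x 1 polar ring -> 4 views
print(mv.shape, proj.shape)                       # torch.Size([4, 4, 4]), torch.Size([4, 4])
mvp = proj @ mv                                   # per-view model-view-projection, as NormalsRenderer builds it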
mesh_reconstruction/opt.py ADDED
@@ -0,0 +1,190 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import time
3
+ import torch
4
+ import torch_scatter
5
+ from typing import Tuple
6
+ from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
7
+
8
+ @torch.no_grad()
9
+ def remesh(
10
+ vertices_etc:torch.Tensor, #V,D
11
+ faces:torch.Tensor, #F,3 long
12
+ min_edgelen:torch.Tensor, #V
13
+ max_edgelen:torch.Tensor, #V
14
+ flip:bool,
15
+ max_vertices=1e6
16
+ ):
17
+
18
+ # dummies
19
+ vertices_etc,faces = prepend_dummies(vertices_etc,faces)
20
+ vertices = vertices_etc[:,:3] #V,3
21
+ nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
22
+ min_edgelen = torch.concat((nan_tensor,min_edgelen))
23
+ max_edgelen = torch.concat((nan_tensor,max_edgelen))
24
+
25
+ # collapse
26
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
27
+ edge_length = calc_edge_length(vertices,edges) #E
28
+ face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
29
+ vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
30
+ face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
31
+ shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
32
+ priority = face_collapse.float() + shortness
33
+ vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
34
+
35
+ # split
36
+ if vertices.shape[0]<max_vertices:
37
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
38
+ vertices = vertices_etc[:,:3] #V,3
39
+ edge_length = calc_edge_length(vertices,edges) #E
40
+ splits = edge_length > max_edgelen[edges].mean(dim=-1)
41
+ vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
42
+
43
+ vertices_etc,faces = pack(vertices_etc,faces)
44
+ vertices = vertices_etc[:,:3]
45
+
46
+ if flip:
47
+ edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
48
+ flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
49
+
50
+ return remove_dummies(vertices_etc,faces)
51
+
52
+ def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
53
+ """lerp with adam's bias correction"""
54
+ c_prev = 1-weight**(step-1)
55
+ c = 1-weight**step
56
+ a_weight = weight*c_prev/c
57
+ b_weight = (1-weight)/c
58
+ a.mul_(a_weight).add_(b, alpha=b_weight)
59
+
60
+
61
+ class MeshOptimizer:
62
+ """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
63
+
64
+ def __init__(self,
65
+ vertices:torch.Tensor, #V,3
66
+ faces:torch.Tensor, #F,3
67
+ lr=0.3, #learning rate
68
+ betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
69
+ gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
70
+ nu_ref=0.3, #reference velocity for edge length controller
71
+ edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
72
+ edge_len_tol=.5, #edge length tolerance for split and collapse
73
+ gain=.2, #gain value for edge length controller
74
+ laplacian_weight=.02, #for laplacian smoothing/regularization
75
+ ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
76
+ grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
77
+ remesh_interval=1, #larger intervals are faster but with worse mesh quality
78
+ local_edgelen=True, #set to False to use a global scalar reference edge length instead
79
+ ):
80
+ self._vertices = vertices
81
+ self._faces = faces
82
+ self._lr = lr
83
+ self._betas = betas
84
+ self._gammas = gammas
85
+ self._nu_ref = nu_ref
86
+ self._edge_len_lims = edge_len_lims
87
+ self._edge_len_tol = edge_len_tol
88
+ self._gain = gain
89
+ self._laplacian_weight = laplacian_weight
90
+ self._ramp = ramp
91
+ self._grad_lim = grad_lim
92
+ self._remesh_interval = remesh_interval
93
+ self._local_edgelen = local_edgelen
94
+ self._step = 0
95
+
96
+ V = self._vertices.shape[0]
97
+ # prepare continuous tensor for all vertex-based data
98
+ self._vertices_etc = torch.zeros([V,9],device=vertices.device)
99
+ self._split_vertices_etc()
100
+ self.vertices.copy_(vertices) #initialize vertices
101
+ self._vertices.requires_grad_()
102
+ self._ref_len.fill_(edge_len_lims[1])
103
+
104
+ @property
105
+ def vertices(self):
106
+ return self._vertices
107
+
108
+ @property
109
+ def faces(self):
110
+ return self._faces
111
+
112
+ def _split_vertices_etc(self):
113
+ self._vertices = self._vertices_etc[:,:3]
114
+ self._m2 = self._vertices_etc[:,3]
115
+ self._nu = self._vertices_etc[:,4]
116
+ self._m1 = self._vertices_etc[:,5:8]
117
+ self._ref_len = self._vertices_etc[:,8]
118
+
119
+ with_gammas = any(g!=0 for g in self._gammas)
120
+ self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]
121
+
122
+ def zero_grad(self):
123
+ self._vertices.grad = None
124
+
125
+ @torch.no_grad()
126
+ def step(self):
127
+
128
+ eps = 1e-8
129
+
130
+ self._step += 1
131
+
132
+ # spatial smoothing
133
+ edges,_ = calc_edges(self._faces) #E,2
134
+ E = edges.shape[0]
135
+ edge_smooth = self._smooth[edges] #E,2,S
136
+ neighbor_smooth = torch.zeros_like(self._smooth) #V,S
137
+ torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
138
+
139
+ #apply optional smoothing of m1,m2,nu
140
+ if self._gammas[0]:
141
+ self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
142
+ if self._gammas[1]:
143
+ self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
144
+ if self._gammas[2]:
145
+ self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
146
+
147
+ #add laplace smoothing to gradients
148
+ laplace = self._vertices - neighbor_smooth[:,:3]
149
+ grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
150
+
151
+ #gradient clipping
152
+ if self._step>1:
153
+ grad_lim = self._m1.abs().mul_(self._grad_lim)
154
+ grad.clamp_(min=-grad_lim,max=grad_lim)
155
+
156
+ # moment updates
157
+ lerp_unbiased(self._m1, grad, self._betas[0], self._step)
158
+ lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
159
+
160
+ velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
161
+ speed = velocity.norm(dim=-1) #V
162
+
163
+ if self._betas[2]:
164
+ lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
165
+ else:
166
+ self._nu.copy_(speed) #V
167
+
168
+ # update vertices
169
+ ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
170
+ self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
171
+
172
+ # update target edge length
173
+ if self._step % self._remesh_interval == 0:
174
+ if self._local_edgelen:
175
+ len_change = (1 + (self._nu - self._nu_ref) * self._gain)
176
+ else:
177
+ len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
178
+ self._ref_len *= len_change
179
+ self._ref_len.clamp_(*self._edge_len_lims)
180
+
181
+ def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
182
+ min_edge_len = self._ref_len * (1 - self._edge_len_tol)
183
+ max_edge_len = self._ref_len * (1 + self._edge_len_tol)
184
+
185
+ self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)
186
+
187
+ self._split_vertices_etc()
188
+ self._vertices.requires_grad_()
189
+
190
+ return self._vertices, self._faces
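
A minimal optimization-loop sketch following the class docstring above. The loss here is a toy placeholder (it just pulls a sphere toward the origin), not the project's objective; CUDA plus the trimesh and torch_scatter dependencies are assumed available.

from mesh_reconstruction.func import make_sphere
from mesh_reconstruction.opt import MeshOptimizer

vertices, faces = make_sphere(level=2, radius=0.5)     # toy input mesh on CUDA
opt = MeshOptimizer(vertices, faces)
vertices = opt.vertices
for _ in range(20):
    opt.zero_grad()
    loss = vertices.pow(2).sum(dim=-1).mean()          # placeholder loss, not the pipeline's objective
    loss.backward()
    opt.step()
    vertices, faces = opt.remesh()                     # as the docstring prescribes
print(vertices.shape, faces.shape)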
mesh_reconstruction/recon.py ADDED
@@ -0,0 +1,59 @@
1
+ from tqdm import tqdm
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ from typing import List
6
+ from mesh_reconstruction.remesh import calc_vertex_normals
7
+ from mesh_reconstruction.opt import MeshOptimizer
8
+ from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
9
+ from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
10
+ from scripts.utils import to_py3d_mesh, init_target
11
+
12
+ def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
13
+ vertices, faces = vertices.to("cuda"), faces.to("cuda")
14
+ assert len(pils) == 4
15
+ mv,proj = make_star_cameras_orthographic(4, 1)
16
+ renderer = NormalsRenderer(mv,proj,list(pils[0].size))
17
+ # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
18
+ # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")
19
+
20
+ target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
21
+ # 1. no rotate
22
+ target_images = target_images[[0, 3, 2, 1]]
23
+
24
+ # 2. init from coarse mesh
25
+ opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))
26
+
27
+ vertices = opt.vertices
28
+
29
+ mask = target_images[..., -1] < 0.5
30
+
31
+ for i in tqdm(range(steps)):
32
+ opt.zero_grad()
33
+ opt._lr *= decay
34
+ normals = calc_vertex_normals(vertices,faces)
35
+ images = renderer.render(vertices,normals,faces)
36
+
37
+ loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()
38
+
39
+ t_mask = images[..., -1] > 0.5
40
+ loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
41
+ loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
42
+
43
+ loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight
44
+
45
+ # out of box
46
+ loss_oob = (vertices.abs() > 0.99).float().mean() * 10
47
+ loss = loss + loss_oob
48
+
49
+ loss.backward()
50
+ opt.step()
51
+
52
+ vertices,faces = opt.remesh(poisson=False)
53
+
54
+ vertices, faces = vertices.detach().cpu(), faces.detach().cpu()
55
+
56
+ if return_mesh:
57
+ return to_py3d_mesh(vertices, faces)
58
+ else:
59
+ return vertices, faces
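
A minimal call sketch for reconstruct_stage1. The four normal-map paths are hypothetical placeholders (in the full pipeline they come from the multiview normal predictor), and a sphere stands in here for the coarse mesh that scripts/mesh_init.py would normally provide:

from PIL import Image
from mesh_reconstruction.func import make_sphere
from mesh_reconstruction.recon import reconstruct_stage1

normal_pils = [Image.open(f"./normal_{i}.png") for i in range(4)]   # hypothetical RGBA normal maps
verts, faces = make_sphere(level=2, radius=0.5)                     # stand-in coarse mesh
mesh = reconstruct_stage1(normal_pils, steps=100, vertices=verts, faces=faces,
                          start_edge_len=0.1, end_edge_len=0.02, return_mesh=True)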
mesh_reconstruction/refine.py ADDED
@@ -0,0 +1,80 @@
1
+ from tqdm import tqdm
2
+ from PIL import Image
3
+ import torch
4
+ from typing import List
5
+ from mesh_reconstruction.remesh import calc_vertex_normals
6
+ from mesh_reconstruction.opt import MeshOptimizer
7
+ from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
8
+ from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
9
+ from scripts.project_mesh import multiview_color_projection, get_cameras_list
10
+ from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
11
+
12
+ def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True):
13
+ vertices, faces = vertices.to("cuda"), faces.to("cuda")
14
+ if process_inputs:
15
+ vertices = vertices * 2 / 1.35
16
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
17
+
18
+ poission_steps = []
19
+
20
+ assert len(pils) == 4
21
+ mv,proj = make_star_cameras_orthographic(4, 1)
22
+ renderer = NormalsRenderer(mv,proj,list(pils[0].size))
23
+ # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
24
+ # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")
25
+
26
+ target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
27
+ # 1. no rotate
28
+ target_images = target_images[[0, 3, 2, 1]]
29
+
30
+ # 2. init from coarse mesh
31
+ opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)
32
+
33
+ vertices = opt.vertices
34
+ alpha_init = None
35
+
36
+ mask = target_images[..., -1] < 0.5
37
+
38
+ for i in tqdm(range(steps)):
39
+ opt.zero_grad()
40
+ opt._lr *= decay
41
+ normals = calc_vertex_normals(vertices,faces)
42
+ images = renderer.render(vertices,normals,faces)
43
+ if alpha_init is None:
44
+ alpha_init = images.detach()
45
+
46
+ if i < update_warmup or i % update_normal_interval == 0:
47
+ with torch.no_grad():
48
+ py3d_mesh = to_py3d_mesh(vertices, faces, normals)
49
+ cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.)
50
+ _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
51
+ target_normal = target_normal * 2 - 1
52
+ target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
53
+ debug_images = renderer.render(vertices,target_normal,faces)
54
+
55
+ d_mask = images[..., -1] > 0.5
56
+ loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()
57
+
58
+ loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
59
+
60
+ loss = loss_debug_l2 + loss_alpha_target_mask_l2
61
+
62
+ # out of box
63
+ loss_oob = (vertices.abs() > 0.99).float().mean() * 10
64
+ loss = loss + loss_oob
65
+
66
+ loss.backward()
67
+ opt.step()
68
+
69
+ vertices,faces = opt.remesh(poisson=(i in poission_steps))
70
+
71
+ vertices, faces = vertices.detach().cpu(), faces.detach().cpu()
72
+
73
+ if process_outputs:
74
+ vertices = vertices / 2 * 1.35
75
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
76
+
77
+ if return_mesh:
78
+ return to_py3d_mesh(vertices, faces)
79
+ else:
80
+ return vertices, faces
mesh_reconstruction/remesh.py ADDED
@@ -0,0 +1,361 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import torch
3
+ import torch.nn.functional as tfunc
4
+ import torch_scatter
5
+ from typing import Tuple
6
+
7
+ def prepend_dummies(
8
+ vertices:torch.Tensor, #V,D
9
+ faces:torch.Tensor, #F,3 long
10
+ )->Tuple[torch.Tensor,torch.Tensor]:
11
+ """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
12
+ V,D = vertices.shape
13
+ vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
14
+ faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
15
+ return vertices,faces
16
+
17
+ def remove_dummies(
18
+ vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
19
+ faces:torch.Tensor, #F,3 long - first face all zeros
20
+ )->Tuple[torch.Tensor,torch.Tensor]:
21
+ """remove dummy elements added with prepend_dummies()"""
22
+ return vertices[1:],faces[1:]-1
23
+
24
+
25
+ def calc_edges(
26
+ faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
27
+ with_edge_to_face: bool = False
28
+ ) -> Tuple[torch.Tensor, ...]:
29
+ """
30
+ returns Tuple of
31
+ - edges E,2 long, 0 for unused, lower vertex index first
32
+ - face_to_edge F,3 long
33
+ - (optional) edge_to_face shape=E,[left,right],[face,side]
34
+
35
+ o-<-----e1 e0,e1...edge, e0<e1
36
+ | /A L,R....left and right face
37
+ | L / | both triangles ordered counter clockwise
38
+ | / R | normals pointing out of screen
39
+ V/ |
40
+ e0---->-o
41
+ """
42
+
43
+ F = faces.shape[0]
44
+
45
+ # make full edges, lower vertex index first
46
+ face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
47
+ full_edges = face_edges.reshape(F*3,2)
48
+ sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2
49
+
50
+ # make unique edges
51
+ edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
52
+ E = edges.shape[0]
53
+ face_to_edge = full_to_unique.reshape(F,3) #F,3
54
+
55
+ if not with_edge_to_face:
56
+ return edges, face_to_edge
57
+
58
+ is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
59
+ edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
60
+ scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
61
+ edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
62
+ edge_to_face[0] = 0
63
+ return edges, face_to_edge, edge_to_face
64
+
65
+ def calc_edge_length(
66
+ vertices:torch.Tensor, #V,3 first may be dummy
67
+ edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
68
+ )->torch.Tensor: #E
69
+
70
+ full_vertices = vertices[edges] #E,2,3
71
+ a,b = full_vertices.unbind(dim=1) #E,3
72
+ return torch.norm(a-b,p=2,dim=-1)
73
+
74
+ def calc_face_normals(
75
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
76
+ faces:torch.Tensor, #F,3 long, first face may be all zero
77
+ normalize:bool=False,
78
+ )->torch.Tensor: #F,3
79
+ """
80
+ n
81
+ |
82
+ c0 corners ordered counterclockwise when
83
+ / \ looking onto surface (in neg normal direction)
84
+ c1---c2
85
+ """
86
+ full_vertices = vertices[faces] #F,C=3,3
87
+ v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
88
+ face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
89
+ if normalize:
90
+ face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
91
+ return face_normals #F,3
92
+
93
+ def calc_vertex_normals(
94
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
95
+ faces:torch.Tensor, #F,3 long, first face may be all zero
96
+ face_normals:torch.Tensor=None, #F,3, not normalized
97
+ )->torch.Tensor: #F,3
98
+
99
+ F = faces.shape[0]
100
+
101
+ if face_normals is None:
102
+ face_normals = calc_face_normals(vertices,faces)
103
+
104
+ vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
105
+ vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
106
+ vertex_normals = vertex_normals.sum(dim=1) #V,3
107
+ return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
108
+
109
+ def calc_face_ref_normals(
110
+ faces:torch.Tensor, #F,3 long, 0 for unused
111
+ vertex_normals:torch.Tensor, #V,3 first unused
112
+ normalize:bool=False,
113
+ )->torch.Tensor: #F,3
114
+ """calculate reference normals for face flip detection"""
115
+ full_normals = vertex_normals[faces] #F,C=3,3
116
+ ref_normals = full_normals.sum(dim=1) #F,3
117
+ if normalize:
118
+ ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
119
+ return ref_normals
120
+
121
+ def pack(
122
+ vertices:torch.Tensor, #V,3 first unused and nan
123
+ faces:torch.Tensor, #F,3 long, 0 for unused
124
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
125
+ """removes unused elements in vertices and faces"""
126
+ V = vertices.shape[0]
127
+
128
+ # remove unused faces
129
+ used_faces = faces[:,0]!=0
130
+ used_faces[0] = True
131
+ faces = faces[used_faces] #sync
132
+
133
+ # remove unused vertices
134
+ used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
135
+ used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
136
+ used_vertices = used_vertices.any(dim=1)
137
+ used_vertices[0] = True
138
+ vertices = vertices[used_vertices] #sync
139
+
140
+ # update used faces
141
+ ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
142
+ V1 = used_vertices.sum()
143
+ ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
144
+ faces = ind[faces]
145
+
146
+ return vertices,faces
147
+
148
+ def split_edges(
149
+ vertices:torch.Tensor, #V,3 first unused
150
+ faces:torch.Tensor, #F,3 long, 0 for unused
151
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
152
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
153
+ splits, #E bool
154
+ pack_faces:bool=True,
155
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
156
+
157
+ # c2 c2 c...corners = faces
158
+ # . . . . s...side_vert, 0 means no split
159
+ # . . .N2 . S...shrunk_face
160
+ # . . . . Ni...new_faces
161
+ # s2 s1 s2|c2...s1|c1
162
+ # . . . . .
163
+ # . . . S . .
164
+ # . . . . N1 .
165
+ # c0...(s0=0)....c1 s0|c0...........c1
166
+ #
167
+ # pseudo-code:
168
+ # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
169
+ # split = side_vert!=0 example:[False,True,True]
170
+ # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
171
+ # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
172
+ # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
173
+
174
+ V = vertices.shape[0]
175
+ F = faces.shape[0]
176
+ S = splits.sum().item() #sync
177
+
178
+ if S==0:
179
+ return vertices,faces
180
+
181
+ edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
182
+ edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
183
+ side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
184
+ split_edges = edges[splits] #S sync
185
+
186
+ #vertices
187
+ split_vertices = vertices[split_edges].mean(dim=1) #S,3
188
+ vertices = torch.concat((vertices,split_vertices),dim=0)
189
+
190
+ #faces
191
+ side_split = side_vert!=0 #F,3
192
+ shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
193
+ new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
194
+ faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
195
+ if pack_faces:
196
+ mask = faces[:,0]!=0
197
+ mask[0] = True
198
+ faces = faces[mask] #F',3 sync
199
+
200
+ return vertices,faces
201
+
202
+ def collapse_edges(
203
+ vertices:torch.Tensor, #V,3 first unused
204
+ faces:torch.Tensor, #F,3 long 0 for unused
205
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
206
+ priorities:torch.Tensor, #E float
207
+ stable:bool=False, #only for unit testing
208
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
209
+
210
+ V = vertices.shape[0]
211
+
212
+ # check spacing
213
+ _,order = priorities.sort(stable=stable) #E
214
+ rank = torch.zeros_like(order)
215
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
216
+ vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
217
+ edge_rank = rank #E
218
+ for i in range(3):
219
+ torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
220
+ edge_rank,_ = vert_rank[edges].max(dim=-1) #E
221
+ candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
222
+
223
+ # check connectivity
224
+ vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
225
+ vert_connections[candidates[:,0]] = 1 #start
226
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
227
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
228
+ vert_connections[candidates] = 0 #clear start and end
229
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
230
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
231
+ collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
232
+
233
+ # mean vertices
234
+ vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)
235
+
236
+ # update faces
237
+ dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
238
+ dest[collapses[:,1]] = dest[collapses[:,0]]
239
+ faces = dest[faces] #F,3
240
+ c0,c1,c2 = faces.unbind(dim=-1)
241
+ collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
242
+ faces[collapsed] = 0
243
+
244
+ return vertices,faces
245
+
246
+ def calc_face_collapses(
247
+ vertices:torch.Tensor, #V,3 first unused
248
+ faces:torch.Tensor, #F,3 long, 0 for unused
249
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
250
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
251
+ edge_length:torch.Tensor, #E
252
+ face_normals:torch.Tensor, #F,3
253
+ vertex_normals:torch.Tensor, #V,3 first unused
254
+ min_edge_length:torch.Tensor=None, #V
255
+ area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
256
+ shortest_probability = 0.8
257
+ )->torch.Tensor: #E edges to collapse
258
+
259
+ E = edges.shape[0]
260
+ F = faces.shape[0]
261
+
262
+ # face flips
263
+ ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
264
+ face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
265
+
266
+ # small faces
267
+ if min_edge_length is not None:
268
+ min_face_length = min_edge_length[faces].mean(dim=-1) #F
269
+ min_area = min_face_length**2 * area_ratio #F
270
+ face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
271
+ face_collapses[0] = False
272
+
273
+ # faces to edges
274
+ face_length = edge_length[face_to_edge] #F,3
275
+
276
+ if shortest_probability<1:
277
+ #select shortest edge with shortest_probability chance
278
+ randlim = round(2/(1-shortest_probability))
279
+ rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
280
+ sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
281
+ local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
282
+ else:
283
+ local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
284
+
285
+ edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
286
+ edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
287
+ edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())
288
+
289
+ return edge_collapses.bool()
290
+
291
+ def flip_edges(
292
+ vertices:torch.Tensor, #V,3 first unused
293
+ faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
294
+ edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
295
+ edge_to_face:torch.Tensor, #E,[left,right],[face,side]
296
+ with_border:bool=True, #handle border edges (D=4 instead of D=6)
297
+ with_normal_check:bool=True, #check face normal flips
298
+ stable:bool=False, #only for unit testing
299
+ ):
300
+ V = vertices.shape[0]
301
+ E = edges.shape[0]
302
+ device=vertices.device
303
+ vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
304
+ vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
305
+ neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
306
+ neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
307
+ edge_is_inside = neighbors.all(dim=-1) #E
308
+
309
+ if with_border:
310
+ # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
311
+ # need to use float for masks in order to use scatter(reduce='multiply')
312
+ vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
313
+ src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
314
+ vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
315
+ vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
316
+ vertex_degree -= 2 * vertex_is_inside #V long
317
+
318
+ neighbor_degrees = vertex_degree[neighbors] #E,LR=2
319
+ edge_degrees = vertex_degree[edges] #E,2
320
+ #
321
+ # loss = Sum_over_affected_vertices((new_degree-6)**2)
322
+ # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
323
+ # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
324
+ # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
325
+ #
326
+ loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
327
+ candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
328
+ loss_change = loss_change[candidates] #E'
329
+ if loss_change.shape[0]==0:
330
+ return
331
+
332
+ edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
333
+ _,order = loss_change.sort(descending=True, stable=stable) #E'
334
+ rank = torch.zeros_like(order)
335
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
336
+ vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
337
+ torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
338
+ vertex_rank,_ = vertex_rank.max(dim=-1) #V
339
+ neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
340
+ flip = rank==neighborhood_rank #E'
341
+
342
+ if with_normal_check:
343
+ # cl-<-----e1 e0,e1...edge, e0<e1
344
+ # | /A L,R....left and right face
345
+ # | L / | both triangles ordered counter clockwise
346
+ # | / R | normals pointing out of screen
347
+ # V/ |
348
+ # e0---->-cr
349
+ v = vertices[edges_neighbors] #E",4,3
350
+ v = v - v[:,0:1] #make relative to e0
351
+ e1 = v[:,1]
352
+ cl = v[:,2]
353
+ cr = v[:,3]
354
+ n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
355
+ flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
356
+ flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
357
+
358
+ flip_edges_neighbors = edges_neighbors[flip] #E",4
359
+ flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
360
+ flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
361
+ faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
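
A minimal sketch of the low-level helpers above on a single triangle (CPU tensors are sufficient for these two functions, which can also be used standalone for inspection):

import torch
from mesh_reconstruction.remesh import calc_edges, calc_vertex_normals

vertices = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = torch.tensor([[0, 1, 2]], dtype=torch.long)
edges, face_to_edge = calc_edges(faces)          # E,2 unique edges and F,3 face-to-edge map
normals = calc_vertex_normals(vertices, faces)   # V,3 unit normals, all (0, 0, 1) here
print(edges, face_to_edge, normals)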
mesh_reconstruction/render.py ADDED
@@ -0,0 +1,159 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import nvdiffrast.torch as dr
3
+ import torch
4
+ from typing import Tuple
5
+
6
+ def _warmup(glctx, device=None):
7
+ device = 'cuda' if device is None else device
8
+ #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
9
+ def tensor(*args, **kwargs):
10
+ return torch.tensor(*args, device=device, **kwargs)
11
+ pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
12
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
13
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
14
+
15
+ class NormalsRenderer:
16
+
17
+ _glctx:dr.RasterizeCudaContext = None
18
+
19
+ def __init__(
20
+ self,
21
+ mv: torch.Tensor, #C,4,4
22
+ proj: torch.Tensor, #C,4,4
23
+ image_size: Tuple[int,int],
24
+ mvp = None,
25
+ device=None,
26
+ ):
27
+ if mvp is None:
28
+ self._mvp = proj @ mv #C,4,4
29
+ else:
30
+ self._mvp = mvp
31
+ self._image_size = image_size
32
+ self._glctx = dr.RasterizeCudaContext(device=device)
33
+ _warmup(self._glctx, device)
34
+
35
+ def render(self,
36
+ vertices: torch.Tensor, #V,3 float
37
+ normals: torch.Tensor, #V,3 float in [-1, 1]
38
+ faces: torch.Tensor, #F,3 long
39
+ ) ->torch.Tensor: #C,H,W,4
40
+
41
+ V = vertices.shape[0]
42
+ faces = faces.type(torch.int32)
43
+ vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
44
+ vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
45
+ rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
46
+ vert_col = (normals+1)/2 #V,3
47
+ col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
48
+ alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
49
+ col = torch.concat((col,alpha),dim=-1) #C,H,W,4
50
+ col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
51
+ return col #C,H,W,4
52
+
53
+ from pytorch3d.structures import Meshes
54
+ from pytorch3d.renderer.mesh.shader import ShaderBase
55
+ from pytorch3d.renderer import (
56
+ RasterizationSettings,
57
+ MeshRendererWithFragments,
58
+ TexturesVertex,
59
+ MeshRasterizer,
60
+ BlendParams,
61
+ FoVOrthographicCameras,
62
+ look_at_view_transform,
63
+ hard_rgb_blend,
64
+ )
65
+
66
+ class VertexColorShader(ShaderBase):
67
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
68
+ blend_params = kwargs.get("blend_params", self.blend_params)
69
+ texels = meshes.sample_textures(fragments)
70
+ return hard_rgb_blend(texels, fragments, blend_params)
71
+
72
+ def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
73
+ if len(mesh) != len(cameras):
74
+ if len(cameras) % len(mesh) == 0:
75
+ mesh = mesh.extend(len(cameras))
76
+ else:
77
+ raise NotImplementedError()
78
+
79
+ # render requires everything in float16 or float32
80
+ input_dtype = dtype
81
+ blend_params = BlendParams(1e-4, 1e-4, bkgd)
82
+
83
+ # Define the settings for rasterization and shading
84
+ raster_settings = RasterizationSettings(
85
+ image_size=(H, W),
86
+ blur_radius=blur_radius,
87
+ faces_per_pixel=faces_per_pixel,
88
+ clip_barycentric_coords=True,
89
+ bin_size=None,
90
+ max_faces_per_bin=500000,
91
+ )
92
+
93
+ # Create a renderer by composing a rasterizer and a shader
94
+ # We simply render vertex colors through the custom VertexColorShader (no lighting or materials are used)
95
+ renderer = MeshRendererWithFragments(
96
+ rasterizer=MeshRasterizer(
97
+ cameras=cameras,
98
+ raster_settings=raster_settings
99
+ ),
100
+ shader=VertexColorShader(
101
+ device=device,
102
+ cameras=cameras,
103
+ blend_params=blend_params
104
+ )
105
+ )
106
+
107
+ # render RGB and depth, get mask
108
+ with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
109
+ images, _ = renderer(mesh)
110
+ return images # BHW4
111
+
112
+ class Pytorch3DNormalsRenderer:
113
+ def __init__(self, cameras, image_size, device):
114
+ self.cameras = cameras.to(device)
115
+ self._image_size = image_size
116
+ self.device = device
117
+
118
+ def render(self,
119
+ vertices: torch.Tensor, #V,3 float
120
+ normals: torch.Tensor, #V,3 float in [-1, 1]
121
+ faces: torch.Tensor, #F,3 long
122
+ ) ->torch.Tensor: #C,H,W,4
123
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
124
+ return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
125
+
126
+ def save_tensor_to_img(tensor, save_dir):
127
+ from PIL import Image
128
+ import numpy as np
129
+ for idx, img in enumerate(tensor):
130
+ img = img[..., :3].cpu().numpy()
131
+ img = (img * 255).astype(np.uint8)
132
+ img = Image.fromarray(img)
133
+ img.save(save_dir + f"{idx}.png")
134
+
135
+ if __name__ == "__main__":
136
+ import sys
137
+ import os
138
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
139
+ from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
140
+ cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
141
+ mv,proj = make_star_cameras_orthographic(4, 1)
142
+ resolution = 1024
143
+ renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
144
+ renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
145
+ vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
146
+ normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
147
+ faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)
148
+
149
+ import time
150
+ t0 = time.time()
151
+ r1 = renderer1.render(vertices, normals, faces)
152
+ print("time r1:", time.time() - t0)
153
+
154
+ t0 = time.time()
155
+ r2 = renderer2.render(vertices, normals, faces)
156
+ print("time r2:", time.time() - t0)
157
+
158
+ for i in range(4):
159
+ print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff4a35615ed42148c8579622bee6dca88f7f3be683671524a282fafaf7589682
3
+ size 3079614
package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11e7f7f781fef16c09ec8d03bfb6da84cf61c54fc59e8a4ea047a90c4a24e88f
3
+ size 162720703
scripts/all_typing.py ADDED
@@ -0,0 +1,42 @@
1
+ # code from https://github.com/threestudio-project
2
+
3
+ """
4
+ This module contains type annotations for the project, using
5
+ 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
6
+ 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
7
+
8
+ Two types of type checking can be used:
9
+ 1. Static type checking with mypy (install with pip and enable it as the default linter in VSCode)
10
+ 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
11
+ """
12
+
13
+ # Basic types
14
+ from typing import (
15
+ Any,
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ Literal,
21
+ NamedTuple,
22
+ NewType,
23
+ Optional,
24
+ Sized,
25
+ Tuple,
26
+ Type,
27
+ TypeVar,
28
+ Union,
29
+ )
30
+
31
+ # Tensor dtype
32
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
33
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
34
+
35
+ # Config type
36
+ from omegaconf import DictConfig
37
+
38
+ # PyTorch Tensor type
39
+ from torch import Tensor
40
+
41
+ # Runtime type checking decorator
42
+ from typeguard import typechecked as typechecker
scripts/load_onnx.py ADDED
@@ -0,0 +1,48 @@
1
+ import onnxruntime
2
+ import torch
3
+
4
+ providers = [
5
+ # ('TensorrtExecutionProvider', {
6
+ # 'device_id': 0,
7
+ # 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
8
+ # 'trt_fp16_enable': True,
9
+ # 'trt_engine_cache_enable': True,
10
+ # }),
11
+ ('CUDAExecutionProvider', {
12
+ 'device_id': 0,
13
+ 'arena_extend_strategy': 'kSameAsRequested',
14
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
15
+ 'cudnn_conv_algo_search': 'HEURISTIC',
16
+ })
17
+ ]
18
+
19
+ def load_onnx(file_path: str):
20
+ assert file_path.endswith(".onnx")
21
+ sess_opt = onnxruntime.SessionOptions()
22
+ ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
23
+ return ort_session
24
+
25
+
26
+ def load_onnx_caller(file_path: str, single_output=False):
27
+ ort_session = load_onnx(file_path)
28
+ def caller(*args):
29
+ torch_input = isinstance(args[0], torch.Tensor)
30
+ if torch_input:
31
+ torch_input_dtype = args[0].dtype
32
+ torch_input_device = args[0].device
33
+ # check all are torch.Tensor and have same dtype and device
34
+ assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor"
35
+ assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor"
36
+ assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor"
37
+ args = [arg.cpu().float().numpy() for arg in args]
38
+
39
+ ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
40
+ ort_outs = ort_session.run(None, ort_inputs)
41
+
42
+ if torch_input:
43
+ ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]
44
+
45
+ if single_output:
46
+ return ort_outs[0]
47
+ return ort_outs
48
+ return caller
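
A minimal usage sketch for load_onnx_caller. The .onnx path and tensor shape are hypothetical placeholders, and onnxruntime-gpu with a working CUDAExecutionProvider is assumed:

import torch
from scripts.load_onnx import load_onnx_caller

model = load_onnx_caller("./ckpt/some_model.onnx", single_output=True)  # hypothetical path
x = torch.rand(1, 3, 64, 64, device="cuda")
y = model(x)      # output is returned on the same device and dtype as the input
print(y.shape)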
scripts/mesh_init.py ADDED
@@ -0,0 +1,132 @@
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ from pytorch3d.structures import Meshes
5
+ from pytorch3d.renderer import TexturesVertex
6
+ from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
7
+ import pymeshlab
8
+
9
+ _MAX_THREAD = 8
10
+
11
+ # rgb and depth to mesh
12
+ def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
13
+ pixel_center = 0.5 if use_pixel_centers else 0
14
+ i, j = np.meshgrid(
15
+ np.arange(W, dtype=np.float32) + pixel_center,
16
+ np.arange(H, dtype=np.float32) + pixel_center,
17
+ indexing='xy'
18
+ )
19
+ i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
20
+
21
+ origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
22
+ directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
23
+
24
+ return origins, directions
25
+
26
+ def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
27
+ if valid_HWC is None:
28
+ valid_HWC = torch.ones_like(pred_HWC).bool()
29
+ H, W = rgb_BCHW.shape[-2:]
30
+ rgb_BCHW = rgb_BCHW.flip(-2)
31
+ pred_HWC = pred_HWC.flip(0)
32
+ valid_HWC = valid_HWC.flip(0)
33
+ rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
34
+ verts = rays_o + rays_d * pred_HWC # [H, W, 3]
35
+ verts = verts.reshape(-1, 3) # [V, 3]
36
+ indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
37
+ faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
38
+ # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
39
+ faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
40
+ faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
41
+ # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
42
+ faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
43
+ faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3)
44
+ colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
45
+ if is_back:
46
+ verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
47
+
48
+ used_verts = faces.unique()
49
+ old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
50
+ old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
51
+ new_faces = old_to_new_mapping[faces]
52
+ mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
53
+ return mesh
54
+
55
+ def normalmap_to_depthmap(normal_np):
56
+ from scripts.normal_to_height_map import estimate_height_map
57
+ height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
58
+ return height
59
+
60
+ def transform_back_normal_to_front(normal_pil):
61
+ arr = np.array(normal_pil) # in [0, 255]
62
+ arr[..., 0] = 255-arr[..., 0]
63
+ arr[..., 2] = 255-arr[..., 2]
64
+ return Image.fromarray(arr.astype(np.uint8))
65
+
66
+ def calc_w_over_h(normal_pil):
67
+ if isinstance(normal_pil, Image.Image):
68
+ arr = np.array(normal_pil)
69
+ else:
70
+ assert isinstance(normal_pil, np.ndarray)
71
+ arr = normal_pil
72
+ if arr.shape[-1] == 4:
73
+ alpha = arr[..., -1] / 255.
74
+ alpha[alpha >= 0.5] = 1
75
+ alpha[alpha < 0.5] = 0
76
+ else:
77
+ alpha = ~(arr.min(axis=-1) >= 250)
78
+ h_min, w_min = np.min(np.where(alpha), axis=1)
79
+ h_max, w_max = np.max(np.where(alpha), axis=1)
80
+ return (w_max - w_min) / (h_max - h_min)
81
+
82
+ def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
83
+ if is_back:
84
+ normal_pil = transform_back_normal_to_front(normal_pil)
85
+ normal_img = np.array(normal_pil)
86
+ rgb_img = np.array(rgb_pil)
87
+ if normal_img.shape[-1] == 4:
88
+ valid_HWC = normal_img[..., [3]] / 255
89
+ elif rgb_img.shape[-1] == 4:
90
+ valid_HWC = rgb_img[..., [3]] / 255
91
+ else:
92
+ raise ValueError("invalid input, either normal or rgb should have alpha channel")
93
+
94
+ real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])
95
+
96
+ heights = normalmap_to_depthmap(normal_img)
97
+ rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
98
+ valid_HWC[valid_HWC < 0.5] = 0
99
+ valid_HWC[valid_HWC >= 0.5] = 1
100
+ valid_HWC = torch.from_numpy(valid_HWC).bool()
101
+ if init_type == "std":
102
+ # accurate but not stable
103
+ pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
104
+ elif init_type == "thin":
105
+ heights = heights - heights.min()
106
+ heights = (heights / heights.max() * 0.2)
107
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
108
+ else:
109
+ # stable but not accurate
110
+ heights = heights - heights.min()
111
+ heights = (heights / heights.max() * (1-offset)) + offset # to [offset, 1]
112
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
113
+
114
+ # set the border pixels to 0 height
115
+ import cv2
116
+ # edge filter
117
+ edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
118
+ edge = torch.from_numpy(edge).bool()[..., None]
119
+ pred_HWC[edge] = 0
120
+
121
+ valid_HWC[pred_HWC < clamp_min] = False
122
+ return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
123
+
124
+ def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
125
+ ms = pymeshlab.MeshSet()
126
+ ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
127
+ if simplification > 0:
128
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
129
+ ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
130
+ if simplification > 0:
131
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
132
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
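
A minimal sketch tying build_mesh and fix_border_with_pymeshlab_fast together. The image paths are hypothetical placeholders, and CUDA, cv2 and pymeshlab (plus scripts.normal_to_height_map) are assumed available:

from PIL import Image
from scripts.mesh_init import build_mesh, fix_border_with_pymeshlab_fast

front_normal = Image.open("./front_normal.png")   # hypothetical RGBA normal map
front_rgb = Image.open("./front_rgb.png")         # hypothetical RGBA color image
coarse = build_mesh(front_normal, front_rgb, is_back=False, scale=0.3, init_type="std")
coarse = fix_border_with_pymeshlab_fast(coarse)   # screened Poisson pass, as defined above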