dylanebert HF staff commited on
Commit
c505505
1 Parent(s): 25c21f3

add multi-view-diffusion

Browse files
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ },
28
+ "use_square_size": false
29
+ }
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.35.2"
23
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a56cfd4ffcf40be097c430324ec184cc37187f6dafef128ef9225438a3c03c4
3
+ size 1261595704
model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVDreamPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "CLIPVisionModel"
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "scheduler": [
14
+ "diffusers",
15
+ "DDIMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "mv_unet",
27
+ "MultiViewUNetModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
pipeline.py ADDED
@@ -0,0 +1,1592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # require xformers!
12
+ import xformers
13
+ import xformers.ops
14
+ from diffusers import AutoencoderKL, DiffusionPipeline
15
+ from diffusers.configuration_utils import ConfigMixin, FrozenDict
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.schedulers import DDIMScheduler
18
+ from diffusers.utils import (
19
+ deprecate,
20
+ is_accelerate_available,
21
+ is_accelerate_version,
22
+ logging,
23
+ )
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from einops import rearrange, repeat
26
+ from kiui.cam import orbit_camera
27
+ from transformers import (
28
+ CLIPImageProcessor,
29
+ CLIPTextModel,
30
+ CLIPTokenizer,
31
+ CLIPVisionModel,
32
+ )
33
+
34
+
35
+ def get_camera(
36
+ num_frames,
37
+ elevation=15,
38
+ azimuth_start=0,
39
+ azimuth_span=360,
40
+ blender_coord=True,
41
+ extra_view=False,
42
+ ):
43
+ angle_gap = azimuth_span / num_frames
44
+ cameras = []
45
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
46
+
47
+ pose = orbit_camera(
48
+ -elevation, azimuth, radius=1
49
+ ) # kiui's elevation is negated, [4, 4]
50
+
51
+ # opengl to blender
52
+ if blender_coord:
53
+ pose[2] *= -1
54
+ pose[[1, 2]] = pose[[2, 1]]
55
+
56
+ cameras.append(pose.flatten())
57
+
58
+ if extra_view:
59
+ cameras.append(np.zeros_like(cameras[0]))
60
+
61
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
62
+
63
+
64
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
65
+ """
66
+ Create sinusoidal timestep embeddings.
67
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
68
+ These may be fractional.
69
+ :param dim: the dimension of the output.
70
+ :param max_period: controls the minimum frequency of the embeddings.
71
+ :return: an [N x dim] Tensor of positional embeddings.
72
+ """
73
+ if not repeat_only:
74
+ half = dim // 2
75
+ freqs = torch.exp(
76
+ -math.log(max_period)
77
+ * torch.arange(start=0, end=half, dtype=torch.float32)
78
+ / half
79
+ ).to(device=timesteps.device)
80
+ args = timesteps[:, None] * freqs[None]
81
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
82
+ if dim % 2:
83
+ embedding = torch.cat(
84
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
85
+ )
86
+ else:
87
+ embedding = repeat(timesteps, "b -> b d", d=dim)
88
+ # import pdb; pdb.set_trace()
89
+ return embedding
90
+
91
+
92
+ def zero_module(module):
93
+ """
94
+ Zero out the parameters of a module and return it.
95
+ """
96
+ for p in module.parameters():
97
+ p.detach().zero_()
98
+ return module
99
+
100
+
101
+ def conv_nd(dims, *args, **kwargs):
102
+ """
103
+ Create a 1D, 2D, or 3D convolution module.
104
+ """
105
+ if dims == 1:
106
+ return nn.Conv1d(*args, **kwargs)
107
+ elif dims == 2:
108
+ return nn.Conv2d(*args, **kwargs)
109
+ elif dims == 3:
110
+ return nn.Conv3d(*args, **kwargs)
111
+ raise ValueError(f"unsupported dimensions: {dims}")
112
+
113
+
114
+ def avg_pool_nd(dims, *args, **kwargs):
115
+ """
116
+ Create a 1D, 2D, or 3D average pooling module.
117
+ """
118
+ if dims == 1:
119
+ return nn.AvgPool1d(*args, **kwargs)
120
+ elif dims == 2:
121
+ return nn.AvgPool2d(*args, **kwargs)
122
+ elif dims == 3:
123
+ return nn.AvgPool3d(*args, **kwargs)
124
+ raise ValueError(f"unsupported dimensions: {dims}")
125
+
126
+
127
+ def default(val, d):
128
+ if val is not None:
129
+ return val
130
+ return d() if isfunction(d) else d
131
+
132
+
133
+ class GEGLU(nn.Module):
134
+ def __init__(self, dim_in, dim_out):
135
+ super().__init__()
136
+ self.proj = nn.Linear(dim_in, dim_out * 2)
137
+
138
+ def forward(self, x):
139
+ x, gate = self.proj(x).chunk(2, dim=-1)
140
+ return x * F.gelu(gate)
141
+
142
+
143
+ class FeedForward(nn.Module):
144
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
145
+ super().__init__()
146
+ inner_dim = int(dim * mult)
147
+ dim_out = default(dim_out, dim)
148
+ project_in = (
149
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
150
+ if not glu
151
+ else GEGLU(dim, inner_dim)
152
+ )
153
+
154
+ self.net = nn.Sequential(
155
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
156
+ )
157
+
158
+ def forward(self, x):
159
+ return self.net(x)
160
+
161
+
162
+ class MemoryEfficientCrossAttention(nn.Module):
163
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
164
+ def __init__(
165
+ self,
166
+ query_dim,
167
+ context_dim=None,
168
+ heads=8,
169
+ dim_head=64,
170
+ dropout=0.0,
171
+ ip_dim=0,
172
+ ip_weight=1,
173
+ ):
174
+ super().__init__()
175
+
176
+ inner_dim = dim_head * heads
177
+ context_dim = default(context_dim, query_dim)
178
+
179
+ self.heads = heads
180
+ self.dim_head = dim_head
181
+
182
+ self.ip_dim = ip_dim
183
+ self.ip_weight = ip_weight
184
+
185
+ if self.ip_dim > 0:
186
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
187
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
188
+
189
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
190
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
191
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
192
+
193
+ self.to_out = nn.Sequential(
194
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
195
+ )
196
+ self.attention_op: Optional[Any] = None
197
+
198
+ def forward(self, x, context=None):
199
+ q = self.to_q(x)
200
+ context = default(context, x)
201
+
202
+ if self.ip_dim > 0:
203
+ # context: [B, 77 + 16(ip), 1024]
204
+ token_len = context.shape[1]
205
+ context_ip = context[:, -self.ip_dim :, :]
206
+ k_ip = self.to_k_ip(context_ip)
207
+ v_ip = self.to_v_ip(context_ip)
208
+ context = context[:, : (token_len - self.ip_dim), :]
209
+
210
+ k = self.to_k(context)
211
+ v = self.to_v(context)
212
+
213
+ b, _, _ = q.shape
214
+ q, k, v = map(
215
+ lambda t: t.unsqueeze(3)
216
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
217
+ .permute(0, 2, 1, 3)
218
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
219
+ .contiguous(),
220
+ (q, k, v),
221
+ )
222
+
223
+ # actually compute the attention, what we cannot get enough of
224
+ out = xformers.ops.memory_efficient_attention(
225
+ q, k, v, attn_bias=None, op=self.attention_op
226
+ )
227
+
228
+ if self.ip_dim > 0:
229
+ k_ip, v_ip = map(
230
+ lambda t: t.unsqueeze(3)
231
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
232
+ .permute(0, 2, 1, 3)
233
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
234
+ .contiguous(),
235
+ (k_ip, v_ip),
236
+ )
237
+ # actually compute the attention, what we cannot get enough of
238
+ out_ip = xformers.ops.memory_efficient_attention(
239
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
240
+ )
241
+ out = out + self.ip_weight * out_ip
242
+
243
+ out = (
244
+ out.unsqueeze(0)
245
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
246
+ .permute(0, 2, 1, 3)
247
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
248
+ )
249
+ return self.to_out(out)
250
+
251
+
252
+ class BasicTransformerBlock3D(nn.Module):
253
+
254
+ def __init__(
255
+ self,
256
+ dim,
257
+ n_heads,
258
+ d_head,
259
+ context_dim,
260
+ dropout=0.0,
261
+ gated_ff=True,
262
+ ip_dim=0,
263
+ ip_weight=1,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.attn1 = MemoryEfficientCrossAttention(
268
+ query_dim=dim,
269
+ context_dim=None, # self-attention
270
+ heads=n_heads,
271
+ dim_head=d_head,
272
+ dropout=dropout,
273
+ )
274
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
275
+ self.attn2 = MemoryEfficientCrossAttention(
276
+ query_dim=dim,
277
+ context_dim=context_dim,
278
+ heads=n_heads,
279
+ dim_head=d_head,
280
+ dropout=dropout,
281
+ # ip only applies to cross-attention
282
+ ip_dim=ip_dim,
283
+ ip_weight=ip_weight,
284
+ )
285
+ self.norm1 = nn.LayerNorm(dim)
286
+ self.norm2 = nn.LayerNorm(dim)
287
+ self.norm3 = nn.LayerNorm(dim)
288
+
289
+ def forward(self, x, context=None, num_frames=1):
290
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
291
+ x = self.attn1(self.norm1(x), context=None) + x
292
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
293
+ x = self.attn2(self.norm2(x), context=context) + x
294
+ x = self.ff(self.norm3(x)) + x
295
+ return x
296
+
297
+
298
+ class SpatialTransformer3D(nn.Module):
299
+
300
+ def __init__(
301
+ self,
302
+ in_channels,
303
+ n_heads,
304
+ d_head,
305
+ context_dim, # cross attention input dim
306
+ depth=1,
307
+ dropout=0.0,
308
+ ip_dim=0,
309
+ ip_weight=1,
310
+ ):
311
+ super().__init__()
312
+
313
+ if not isinstance(context_dim, list):
314
+ context_dim = [context_dim]
315
+
316
+ self.in_channels = in_channels
317
+
318
+ inner_dim = n_heads * d_head
319
+ self.norm = nn.GroupNorm(
320
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
321
+ )
322
+ self.proj_in = nn.Linear(in_channels, inner_dim)
323
+
324
+ self.transformer_blocks = nn.ModuleList(
325
+ [
326
+ BasicTransformerBlock3D(
327
+ inner_dim,
328
+ n_heads,
329
+ d_head,
330
+ context_dim=context_dim[d],
331
+ dropout=dropout,
332
+ ip_dim=ip_dim,
333
+ ip_weight=ip_weight,
334
+ )
335
+ for d in range(depth)
336
+ ]
337
+ )
338
+
339
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
340
+
341
+ def forward(self, x, context=None, num_frames=1):
342
+ # note: if no context is given, cross-attention defaults to self-attention
343
+ if not isinstance(context, list):
344
+ context = [context]
345
+ b, c, h, w = x.shape
346
+ x_in = x
347
+ x = self.norm(x)
348
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
349
+ x = self.proj_in(x)
350
+ for i, block in enumerate(self.transformer_blocks):
351
+ x = block(x, context=context[i], num_frames=num_frames)
352
+ x = self.proj_out(x)
353
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
354
+
355
+ return x + x_in
356
+
357
+
358
+ class PerceiverAttention(nn.Module):
359
+ def __init__(self, *, dim, dim_head=64, heads=8):
360
+ super().__init__()
361
+ self.scale = dim_head**-0.5
362
+ self.dim_head = dim_head
363
+ self.heads = heads
364
+ inner_dim = dim_head * heads
365
+
366
+ self.norm1 = nn.LayerNorm(dim)
367
+ self.norm2 = nn.LayerNorm(dim)
368
+
369
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
370
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
371
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
372
+
373
+ def forward(self, x, latents):
374
+ """
375
+ Args:
376
+ x (torch.Tensor): image features
377
+ shape (b, n1, D)
378
+ latent (torch.Tensor): latent features
379
+ shape (b, n2, D)
380
+ """
381
+ x = self.norm1(x)
382
+ latents = self.norm2(latents)
383
+
384
+ b, h, _ = latents.shape
385
+
386
+ q = self.to_q(latents)
387
+ kv_input = torch.cat((x, latents), dim=-2)
388
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
389
+
390
+ q, k, v = map(
391
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
392
+ .transpose(1, 2)
393
+ .reshape(b, self.heads, t.shape[1], -1)
394
+ .contiguous(),
395
+ (q, k, v),
396
+ )
397
+
398
+ # attention
399
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
400
+ weight = (q * scale) @ (k * scale).transpose(
401
+ -2, -1
402
+ ) # More stable with f16 than dividing afterwards
403
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
404
+ out = weight @ v
405
+
406
+ out = out.permute(0, 2, 1, 3).reshape(b, h, -1)
407
+
408
+ return self.to_out(out)
409
+
410
+
411
+ class Resampler(nn.Module):
412
+ def __init__(
413
+ self,
414
+ dim=1024,
415
+ depth=8,
416
+ dim_head=64,
417
+ heads=16,
418
+ num_queries=8,
419
+ embedding_dim=768,
420
+ output_dim=1024,
421
+ ff_mult=4,
422
+ ):
423
+ super().__init__()
424
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
425
+ self.proj_in = nn.Linear(embedding_dim, dim)
426
+ self.proj_out = nn.Linear(dim, output_dim)
427
+ self.norm_out = nn.LayerNorm(output_dim)
428
+
429
+ self.layers = nn.ModuleList([])
430
+ for _ in range(depth):
431
+ self.layers.append(
432
+ nn.ModuleList(
433
+ [
434
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
435
+ nn.Sequential(
436
+ nn.LayerNorm(dim),
437
+ nn.Linear(dim, dim * ff_mult, bias=False),
438
+ nn.GELU(),
439
+ nn.Linear(dim * ff_mult, dim, bias=False),
440
+ ),
441
+ ]
442
+ )
443
+ )
444
+
445
+ def forward(self, x):
446
+ latents = self.latents.repeat(x.size(0), 1, 1)
447
+ x = self.proj_in(x)
448
+ for attn, ff in self.layers:
449
+ latents = attn(x, latents) + latents
450
+ latents = ff(latents) + latents
451
+
452
+ latents = self.proj_out(latents)
453
+ return self.norm_out(latents)
454
+
455
+
456
+ class CondSequential(nn.Sequential):
457
+ """
458
+ A sequential module that passes timestep embeddings to the children that
459
+ support it as an extra input.
460
+ """
461
+
462
+ def forward(self, x, emb, context=None, num_frames=1):
463
+ for layer in self:
464
+ if isinstance(layer, ResBlock):
465
+ x = layer(x, emb)
466
+ elif isinstance(layer, SpatialTransformer3D):
467
+ x = layer(x, context, num_frames=num_frames)
468
+ else:
469
+ x = layer(x)
470
+ return x
471
+
472
+
473
+ class Upsample(nn.Module):
474
+ """
475
+ An upsampling layer with an optional convolution.
476
+ :param channels: channels in the inputs and outputs.
477
+ :param use_conv: a bool determining if a convolution is applied.
478
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
479
+ upsampling occurs in the inner-two dimensions.
480
+ """
481
+
482
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
483
+ super().__init__()
484
+ self.channels = channels
485
+ self.out_channels = out_channels or channels
486
+ self.use_conv = use_conv
487
+ self.dims = dims
488
+ if use_conv:
489
+ self.conv = conv_nd(
490
+ dims, self.channels, self.out_channels, 3, padding=padding
491
+ )
492
+
493
+ def forward(self, x):
494
+ assert x.shape[1] == self.channels
495
+ if self.dims == 3:
496
+ x = F.interpolate(
497
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
498
+ )
499
+ else:
500
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
501
+ if self.use_conv:
502
+ x = self.conv(x)
503
+ return x
504
+
505
+
506
+ class Downsample(nn.Module):
507
+ """
508
+ A downsampling layer with an optional convolution.
509
+ :param channels: channels in the inputs and outputs.
510
+ :param use_conv: a bool determining if a convolution is applied.
511
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
512
+ downsampling occurs in the inner-two dimensions.
513
+ """
514
+
515
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
516
+ super().__init__()
517
+ self.channels = channels
518
+ self.out_channels = out_channels or channels
519
+ self.use_conv = use_conv
520
+ self.dims = dims
521
+ stride = 2 if dims != 3 else (1, 2, 2)
522
+ if use_conv:
523
+ self.op = conv_nd(
524
+ dims,
525
+ self.channels,
526
+ self.out_channels,
527
+ 3,
528
+ stride=stride,
529
+ padding=padding,
530
+ )
531
+ else:
532
+ assert self.channels == self.out_channels
533
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
534
+
535
+ def forward(self, x):
536
+ assert x.shape[1] == self.channels
537
+ return self.op(x)
538
+
539
+
540
+ class ResBlock(nn.Module):
541
+ """
542
+ A residual block that can optionally change the number of channels.
543
+ :param channels: the number of input channels.
544
+ :param emb_channels: the number of timestep embedding channels.
545
+ :param dropout: the rate of dropout.
546
+ :param out_channels: if specified, the number of out channels.
547
+ :param use_conv: if True and out_channels is specified, use a spatial
548
+ convolution instead of a smaller 1x1 convolution to change the
549
+ channels in the skip connection.
550
+ :param dims: determines if the signal is 1D, 2D, or 3D.
551
+ :param up: if True, use this block for upsampling.
552
+ :param down: if True, use this block for downsampling.
553
+ """
554
+
555
+ def __init__(
556
+ self,
557
+ channels,
558
+ emb_channels,
559
+ dropout,
560
+ out_channels=None,
561
+ use_conv=False,
562
+ use_scale_shift_norm=False,
563
+ dims=2,
564
+ up=False,
565
+ down=False,
566
+ ):
567
+ super().__init__()
568
+ self.channels = channels
569
+ self.emb_channels = emb_channels
570
+ self.dropout = dropout
571
+ self.out_channels = out_channels or channels
572
+ self.use_conv = use_conv
573
+ self.use_scale_shift_norm = use_scale_shift_norm
574
+
575
+ self.in_layers = nn.Sequential(
576
+ nn.GroupNorm(32, channels),
577
+ nn.SiLU(),
578
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
579
+ )
580
+
581
+ self.updown = up or down
582
+
583
+ if up:
584
+ self.h_upd = Upsample(channels, False, dims)
585
+ self.x_upd = Upsample(channels, False, dims)
586
+ elif down:
587
+ self.h_upd = Downsample(channels, False, dims)
588
+ self.x_upd = Downsample(channels, False, dims)
589
+ else:
590
+ self.h_upd = self.x_upd = nn.Identity()
591
+
592
+ self.emb_layers = nn.Sequential(
593
+ nn.SiLU(),
594
+ nn.Linear(
595
+ emb_channels,
596
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
597
+ ),
598
+ )
599
+ self.out_layers = nn.Sequential(
600
+ nn.GroupNorm(32, self.out_channels),
601
+ nn.SiLU(),
602
+ nn.Dropout(p=dropout),
603
+ zero_module(
604
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
605
+ ),
606
+ )
607
+
608
+ if self.out_channels == channels:
609
+ self.skip_connection = nn.Identity()
610
+ elif use_conv:
611
+ self.skip_connection = conv_nd(
612
+ dims, channels, self.out_channels, 3, padding=1
613
+ )
614
+ else:
615
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
616
+
617
+ def forward(self, x, emb):
618
+ if self.updown:
619
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
620
+ h = in_rest(x)
621
+ h = self.h_upd(h)
622
+ x = self.x_upd(x)
623
+ h = in_conv(h)
624
+ else:
625
+ h = self.in_layers(x)
626
+ emb_out = self.emb_layers(emb).type(h.dtype)
627
+ while len(emb_out.shape) < len(h.shape):
628
+ emb_out = emb_out[..., None]
629
+ if self.use_scale_shift_norm:
630
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
631
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
632
+ h = out_norm(h) * (1 + scale) + shift
633
+ h = out_rest(h)
634
+ else:
635
+ h = h + emb_out
636
+ h = self.out_layers(h)
637
+ return self.skip_connection(x) + h
638
+
639
+
640
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
641
+ """
642
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
643
+ :param in_channels: channels in the input Tensor.
644
+ :param model_channels: base channel count for the model.
645
+ :param out_channels: channels in the output Tensor.
646
+ :param num_res_blocks: number of residual blocks per downsample.
647
+ :param attention_resolutions: a collection of downsample rates at which
648
+ attention will take place. May be a set, list, or tuple.
649
+ For example, if this contains 4, then at 4x downsampling, attention
650
+ will be used.
651
+ :param dropout: the dropout probability.
652
+ :param channel_mult: channel multiplier for each level of the UNet.
653
+ :param conv_resample: if True, use learned convolutions for upsampling and
654
+ downsampling.
655
+ :param dims: determines if the signal is 1D, 2D, or 3D.
656
+ :param num_classes: if specified (as an int), then this model will be
657
+ class-conditional with `num_classes` classes.
658
+ :param num_heads: the number of attention heads in each attention layer.
659
+ :param num_heads_channels: if specified, ignore num_heads and instead use
660
+ a fixed channel width per attention head.
661
+ :param num_heads_upsample: works with num_heads to set a different number
662
+ of heads for upsampling. Deprecated.
663
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
664
+ :param resblock_updown: use residual blocks for up/downsampling.
665
+ :param use_new_attention_order: use a different attention pattern for potentially
666
+ increased efficiency.
667
+ :param camera_dim: dimensionality of camera input.
668
+ """
669
+
670
+ def __init__(
671
+ self,
672
+ image_size,
673
+ in_channels,
674
+ model_channels,
675
+ out_channels,
676
+ num_res_blocks,
677
+ attention_resolutions,
678
+ dropout=0,
679
+ channel_mult=(1, 2, 4, 8),
680
+ conv_resample=True,
681
+ dims=2,
682
+ num_classes=None,
683
+ num_heads=-1,
684
+ num_head_channels=-1,
685
+ num_heads_upsample=-1,
686
+ use_scale_shift_norm=False,
687
+ resblock_updown=False,
688
+ transformer_depth=1,
689
+ context_dim=None,
690
+ n_embed=None,
691
+ num_attention_blocks=None,
692
+ adm_in_channels=None,
693
+ camera_dim=None,
694
+ ip_dim=0, # imagedream uses ip_dim > 0
695
+ ip_weight=1.0,
696
+ **kwargs,
697
+ ):
698
+ super().__init__()
699
+ assert context_dim is not None
700
+
701
+ if num_heads_upsample == -1:
702
+ num_heads_upsample = num_heads
703
+
704
+ if num_heads == -1:
705
+ assert (
706
+ num_head_channels != -1
707
+ ), "Either num_heads or num_head_channels has to be set"
708
+
709
+ if num_head_channels == -1:
710
+ assert (
711
+ num_heads != -1
712
+ ), "Either num_heads or num_head_channels has to be set"
713
+
714
+ self.image_size = image_size
715
+ self.in_channels = in_channels
716
+ self.model_channels = model_channels
717
+ self.out_channels = out_channels
718
+ if isinstance(num_res_blocks, int):
719
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
720
+ else:
721
+ if len(num_res_blocks) != len(channel_mult):
722
+ raise ValueError(
723
+ "provide num_res_blocks either as an int (globally constant) or "
724
+ "as a list/tuple (per-level) with the same length as channel_mult"
725
+ )
726
+ self.num_res_blocks = num_res_blocks
727
+
728
+ if num_attention_blocks is not None:
729
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
730
+ assert all(
731
+ map(
732
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
733
+ range(len(num_attention_blocks)),
734
+ )
735
+ )
736
+ print(
737
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
738
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
739
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
740
+ f"attention will still not be set."
741
+ )
742
+
743
+ self.attention_resolutions = attention_resolutions
744
+ self.dropout = dropout
745
+ self.channel_mult = channel_mult
746
+ self.conv_resample = conv_resample
747
+ self.num_classes = num_classes
748
+ self.num_heads = num_heads
749
+ self.num_head_channels = num_head_channels
750
+ self.num_heads_upsample = num_heads_upsample
751
+ self.predict_codebook_ids = n_embed is not None
752
+
753
+ self.ip_dim = ip_dim
754
+ self.ip_weight = ip_weight
755
+
756
+ if self.ip_dim > 0:
757
+ self.image_embed = Resampler(
758
+ dim=context_dim,
759
+ depth=4,
760
+ dim_head=64,
761
+ heads=12,
762
+ num_queries=ip_dim, # num token
763
+ embedding_dim=1280,
764
+ output_dim=context_dim,
765
+ ff_mult=4,
766
+ )
767
+
768
+ time_embed_dim = model_channels * 4
769
+ self.time_embed = nn.Sequential(
770
+ nn.Linear(model_channels, time_embed_dim),
771
+ nn.SiLU(),
772
+ nn.Linear(time_embed_dim, time_embed_dim),
773
+ )
774
+
775
+ if camera_dim is not None:
776
+ time_embed_dim = model_channels * 4
777
+ self.camera_embed = nn.Sequential(
778
+ nn.Linear(camera_dim, time_embed_dim),
779
+ nn.SiLU(),
780
+ nn.Linear(time_embed_dim, time_embed_dim),
781
+ )
782
+
783
+ if self.num_classes is not None:
784
+ if isinstance(self.num_classes, int):
785
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
786
+ elif self.num_classes == "continuous":
787
+ # print("setting up linear c_adm embedding layer")
788
+ self.label_emb = nn.Linear(1, time_embed_dim)
789
+ elif self.num_classes == "sequential":
790
+ assert adm_in_channels is not None
791
+ self.label_emb = nn.Sequential(
792
+ nn.Sequential(
793
+ nn.Linear(adm_in_channels, time_embed_dim),
794
+ nn.SiLU(),
795
+ nn.Linear(time_embed_dim, time_embed_dim),
796
+ )
797
+ )
798
+ else:
799
+ raise ValueError()
800
+
801
+ self.input_blocks = nn.ModuleList(
802
+ [CondSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
803
+ )
804
+ self._feature_size = model_channels
805
+ input_block_chans = [model_channels]
806
+ ch = model_channels
807
+ ds = 1
808
+ for level, mult in enumerate(channel_mult):
809
+ for nr in range(self.num_res_blocks[level]):
810
+ layers: List[Any] = [
811
+ ResBlock(
812
+ ch,
813
+ time_embed_dim,
814
+ dropout,
815
+ out_channels=mult * model_channels,
816
+ dims=dims,
817
+ use_scale_shift_norm=use_scale_shift_norm,
818
+ )
819
+ ]
820
+ ch = mult * model_channels
821
+ if ds in attention_resolutions:
822
+ if num_head_channels == -1:
823
+ dim_head = ch // num_heads
824
+ else:
825
+ num_heads = ch // num_head_channels
826
+ dim_head = num_head_channels
827
+
828
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
829
+ layers.append(
830
+ SpatialTransformer3D(
831
+ ch,
832
+ num_heads,
833
+ dim_head,
834
+ context_dim=context_dim,
835
+ depth=transformer_depth,
836
+ ip_dim=self.ip_dim,
837
+ ip_weight=self.ip_weight,
838
+ )
839
+ )
840
+ self.input_blocks.append(CondSequential(*layers))
841
+ self._feature_size += ch
842
+ input_block_chans.append(ch)
843
+ if level != len(channel_mult) - 1:
844
+ out_ch = ch
845
+ self.input_blocks.append(
846
+ CondSequential(
847
+ ResBlock(
848
+ ch,
849
+ time_embed_dim,
850
+ dropout,
851
+ out_channels=out_ch,
852
+ dims=dims,
853
+ use_scale_shift_norm=use_scale_shift_norm,
854
+ down=True,
855
+ )
856
+ if resblock_updown
857
+ else Downsample(
858
+ ch, conv_resample, dims=dims, out_channels=out_ch
859
+ )
860
+ )
861
+ )
862
+ ch = out_ch
863
+ input_block_chans.append(ch)
864
+ ds *= 2
865
+ self._feature_size += ch
866
+
867
+ if num_head_channels == -1:
868
+ dim_head = ch // num_heads
869
+ else:
870
+ num_heads = ch // num_head_channels
871
+ dim_head = num_head_channels
872
+
873
+ self.middle_block = CondSequential(
874
+ ResBlock(
875
+ ch,
876
+ time_embed_dim,
877
+ dropout,
878
+ dims=dims,
879
+ use_scale_shift_norm=use_scale_shift_norm,
880
+ ),
881
+ SpatialTransformer3D(
882
+ ch,
883
+ num_heads,
884
+ dim_head,
885
+ context_dim=context_dim,
886
+ depth=transformer_depth,
887
+ ip_dim=self.ip_dim,
888
+ ip_weight=self.ip_weight,
889
+ ),
890
+ ResBlock(
891
+ ch,
892
+ time_embed_dim,
893
+ dropout,
894
+ dims=dims,
895
+ use_scale_shift_norm=use_scale_shift_norm,
896
+ ),
897
+ )
898
+ self._feature_size += ch
899
+
900
+ self.output_blocks = nn.ModuleList([])
901
+ for level, mult in list(enumerate(channel_mult))[::-1]:
902
+ for i in range(self.num_res_blocks[level] + 1):
903
+ ich = input_block_chans.pop()
904
+ layers = [
905
+ ResBlock(
906
+ ch + ich,
907
+ time_embed_dim,
908
+ dropout,
909
+ out_channels=model_channels * mult,
910
+ dims=dims,
911
+ use_scale_shift_norm=use_scale_shift_norm,
912
+ )
913
+ ]
914
+ ch = model_channels * mult
915
+ if ds in attention_resolutions:
916
+ if num_head_channels == -1:
917
+ dim_head = ch // num_heads
918
+ else:
919
+ num_heads = ch // num_head_channels
920
+ dim_head = num_head_channels
921
+
922
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
923
+ layers.append(
924
+ SpatialTransformer3D(
925
+ ch,
926
+ num_heads,
927
+ dim_head,
928
+ context_dim=context_dim,
929
+ depth=transformer_depth,
930
+ ip_dim=self.ip_dim,
931
+ ip_weight=self.ip_weight,
932
+ )
933
+ )
934
+ if level and i == self.num_res_blocks[level]:
935
+ out_ch = ch
936
+ layers.append(
937
+ ResBlock(
938
+ ch,
939
+ time_embed_dim,
940
+ dropout,
941
+ out_channels=out_ch,
942
+ dims=dims,
943
+ use_scale_shift_norm=use_scale_shift_norm,
944
+ up=True,
945
+ )
946
+ if resblock_updown
947
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
948
+ )
949
+ ds //= 2
950
+ self.output_blocks.append(CondSequential(*layers))
951
+ self._feature_size += ch
952
+
953
+ self.out = nn.Sequential(
954
+ nn.GroupNorm(32, ch),
955
+ nn.SiLU(),
956
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
957
+ )
958
+ if self.predict_codebook_ids:
959
+ self.id_predictor = nn.Sequential(
960
+ nn.GroupNorm(32, ch),
961
+ conv_nd(dims, model_channels, n_embed, 1),
962
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
963
+ )
964
+
965
+ def forward(
966
+ self,
967
+ x,
968
+ timesteps=None,
969
+ context=None,
970
+ y=None,
971
+ camera=None,
972
+ num_frames=1,
973
+ ip=None,
974
+ ip_img=None,
975
+ **kwargs,
976
+ ):
977
+ """
978
+ Apply the model to an input batch.
979
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
980
+ :param timesteps: a 1-D batch of timesteps.
981
+ :param context: conditioning plugged in via crossattn
982
+ :param y: an [N] Tensor of labels, if class-conditional.
983
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
984
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
985
+ """
986
+ assert (
987
+ x.shape[0] % num_frames == 0
988
+ ), "input batch size must be dividable by num_frames!"
989
+ assert (y is not None) == (
990
+ self.num_classes is not None
991
+ ), "must specify y if and only if the model is class-conditional"
992
+
993
+ hs = []
994
+
995
+ t_emb = timestep_embedding(
996
+ timesteps, self.model_channels, repeat_only=False
997
+ ).to(x.dtype)
998
+
999
+ emb = self.time_embed(t_emb)
1000
+
1001
+ if self.num_classes is not None:
1002
+ assert y is not None
1003
+ assert y.shape[0] == x.shape[0]
1004
+ emb = emb + self.label_emb(y)
1005
+
1006
+ # Add camera embeddings
1007
+ if camera is not None:
1008
+ emb = emb + self.camera_embed(camera)
1009
+
1010
+ # imagedream variant
1011
+ if self.ip_dim > 0:
1012
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
1013
+ ip_emb = self.image_embed(ip)
1014
+ context = torch.cat((context, ip_emb), 1)
1015
+
1016
+ h = x
1017
+ for module in self.input_blocks:
1018
+ h = module(h, emb, context, num_frames=num_frames)
1019
+ hs.append(h)
1020
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1021
+ for module in self.output_blocks:
1022
+ h = torch.cat([h, hs.pop()], dim=1)
1023
+ h = module(h, emb, context, num_frames=num_frames)
1024
+ h = h.type(x.dtype)
1025
+ if self.predict_codebook_ids:
1026
+ return self.id_predictor(h)
1027
+ else:
1028
+ return self.out(h)
1029
+
1030
+
1031
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1032
+
1033
+
1034
+ class MVDreamPipeline(DiffusionPipeline):
1035
+
1036
+ _optional_components = ["feature_extractor", "image_encoder"]
1037
+
1038
+ def __init__(
1039
+ self,
1040
+ vae: AutoencoderKL,
1041
+ unet: MultiViewUNetModel,
1042
+ tokenizer: CLIPTokenizer,
1043
+ text_encoder: CLIPTextModel,
1044
+ scheduler: DDIMScheduler,
1045
+ # imagedream variant
1046
+ feature_extractor: CLIPImageProcessor,
1047
+ image_encoder: CLIPVisionModel,
1048
+ requires_safety_checker: bool = False,
1049
+ ):
1050
+ super().__init__()
1051
+
1052
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
1053
+ deprecation_message = (
1054
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
1055
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
1056
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
1057
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
1058
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
1059
+ " file"
1060
+ )
1061
+ deprecate(
1062
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
1063
+ )
1064
+ new_config = dict(scheduler.config)
1065
+ new_config["steps_offset"] = 1
1066
+ scheduler._internal_dict = FrozenDict(new_config)
1067
+
1068
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
1069
+ deprecation_message = (
1070
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
1071
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
1072
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
1073
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
1074
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
1075
+ )
1076
+ deprecate(
1077
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
1078
+ )
1079
+ new_config = dict(scheduler.config)
1080
+ new_config["clip_sample"] = False
1081
+ scheduler._internal_dict = FrozenDict(new_config)
1082
+
1083
+ self.register_modules(
1084
+ vae=vae,
1085
+ unet=unet,
1086
+ scheduler=scheduler,
1087
+ tokenizer=tokenizer,
1088
+ text_encoder=text_encoder,
1089
+ feature_extractor=feature_extractor,
1090
+ image_encoder=image_encoder,
1091
+ )
1092
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
1093
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
1094
+
1095
+ def enable_vae_slicing(self):
1096
+ r"""
1097
+ Enable sliced VAE decoding.
1098
+
1099
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
1100
+ steps. This is useful to save some memory and allow larger batch sizes.
1101
+ """
1102
+ self.vae.enable_slicing()
1103
+
1104
+ def disable_vae_slicing(self):
1105
+ r"""
1106
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
1107
+ computing decoding in one step.
1108
+ """
1109
+ self.vae.disable_slicing()
1110
+
1111
+ def enable_vae_tiling(self):
1112
+ r"""
1113
+ Enable tiled VAE decoding.
1114
+
1115
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
1116
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
1117
+ """
1118
+ self.vae.enable_tiling()
1119
+
1120
+ def disable_vae_tiling(self):
1121
+ r"""
1122
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
1123
+ computing decoding in one step.
1124
+ """
1125
+ self.vae.disable_tiling()
1126
+
1127
+ def enable_sequential_cpu_offload(self, gpu_id=0):
1128
+ r"""
1129
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
1130
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
1131
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
1132
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
1133
+ `enable_model_cpu_offload`, but performance is lower.
1134
+ """
1135
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
1136
+ from accelerate import cpu_offload
1137
+ else:
1138
+ raise ImportError(
1139
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
1140
+ )
1141
+
1142
+ device = torch.device(f"cuda:{gpu_id}")
1143
+
1144
+ if self.device.type != "cpu":
1145
+ self.to("cpu", silence_dtype_warnings=True)
1146
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1147
+
1148
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
1149
+ cpu_offload(cpu_offloaded_model, device)
1150
+
1151
+ def enable_model_cpu_offload(self, gpu_id=0):
1152
+ r"""
1153
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
1154
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
1155
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
1156
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
1157
+ """
1158
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1159
+ from accelerate import cpu_offload_with_hook
1160
+ else:
1161
+ raise ImportError(
1162
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
1163
+ )
1164
+
1165
+ device = torch.device(f"cuda:{gpu_id}")
1166
+
1167
+ if self.device.type != "cpu":
1168
+ self.to("cpu", silence_dtype_warnings=True)
1169
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1170
+
1171
+ hook = None
1172
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
1173
+ _, hook = cpu_offload_with_hook(
1174
+ cpu_offloaded_model, device, prev_module_hook=hook
1175
+ )
1176
+
1177
+ # We'll offload the last model manually.
1178
+ self.final_offload_hook = hook
1179
+
1180
+ @property
1181
+ def _execution_device(self):
1182
+ r"""
1183
+ Returns the device on which the pipeline's models will be executed. After calling
1184
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
1185
+ hooks.
1186
+ """
1187
+ if not hasattr(self.unet, "_hf_hook"):
1188
+ return self.device
1189
+ for module in self.unet.modules():
1190
+ if (
1191
+ hasattr(module, "_hf_hook")
1192
+ and hasattr(module._hf_hook, "execution_device")
1193
+ and module._hf_hook.execution_device is not None
1194
+ ):
1195
+ return torch.device(module._hf_hook.execution_device)
1196
+ return self.device
1197
+
1198
+ def _encode_prompt(
1199
+ self,
1200
+ prompt,
1201
+ device,
1202
+ num_images_per_prompt,
1203
+ do_classifier_free_guidance: bool,
1204
+ negative_prompt=None,
1205
+ ):
1206
+ r"""
1207
+ Encodes the prompt into text encoder hidden states.
1208
+
1209
+ Args:
1210
+ prompt (`str` or `List[str]`, *optional*):
1211
+ prompt to be encoded
1212
+ device: (`torch.device`):
1213
+ torch device
1214
+ num_images_per_prompt (`int`):
1215
+ number of images that should be generated per prompt
1216
+ do_classifier_free_guidance (`bool`):
1217
+ whether to use classifier free guidance or not
1218
+ negative_prompt (`str` or `List[str]`, *optional*):
1219
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1220
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
1221
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
1222
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1223
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1224
+ provided, text embeddings will be generated from `prompt` input argument.
1225
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1226
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1227
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1228
+ argument.
1229
+ """
1230
+ if prompt is not None and isinstance(prompt, str):
1231
+ batch_size = 1
1232
+ elif prompt is not None and isinstance(prompt, list):
1233
+ batch_size = len(prompt)
1234
+ else:
1235
+ raise ValueError(
1236
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
1237
+ )
1238
+
1239
+ text_inputs = self.tokenizer(
1240
+ prompt,
1241
+ padding="max_length",
1242
+ max_length=self.tokenizer.model_max_length,
1243
+ truncation=True,
1244
+ return_tensors="pt",
1245
+ )
1246
+ text_input_ids = text_inputs.input_ids
1247
+ untruncated_ids = self.tokenizer(
1248
+ prompt, padding="longest", return_tensors="pt"
1249
+ ).input_ids
1250
+
1251
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
1252
+ text_input_ids, untruncated_ids
1253
+ ):
1254
+ removed_text = self.tokenizer.batch_decode(
1255
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
1256
+ )
1257
+ logger.warning(
1258
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
1259
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
1260
+ )
1261
+
1262
+ if (
1263
+ hasattr(self.text_encoder.config, "use_attention_mask")
1264
+ and self.text_encoder.config.use_attention_mask
1265
+ ):
1266
+ attention_mask = text_inputs.attention_mask.to(device)
1267
+ else:
1268
+ attention_mask = None
1269
+
1270
+ prompt_embeds = self.text_encoder(
1271
+ text_input_ids.to(device),
1272
+ attention_mask=attention_mask,
1273
+ )
1274
+ prompt_embeds = prompt_embeds[0]
1275
+
1276
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
1277
+
1278
+ bs_embed, seq_len, _ = prompt_embeds.shape
1279
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
1280
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1281
+ prompt_embeds = prompt_embeds.view(
1282
+ bs_embed * num_images_per_prompt, seq_len, -1
1283
+ )
1284
+
1285
+ # get unconditional embeddings for classifier free guidance
1286
+ if do_classifier_free_guidance:
1287
+ uncond_tokens: List[str]
1288
+ if negative_prompt is None:
1289
+ uncond_tokens = [""] * batch_size
1290
+ elif type(prompt) is not type(negative_prompt):
1291
+ raise TypeError(
1292
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
1293
+ f" {type(prompt)}."
1294
+ )
1295
+ elif isinstance(negative_prompt, str):
1296
+ uncond_tokens = [negative_prompt]
1297
+ elif batch_size != len(negative_prompt):
1298
+ raise ValueError(
1299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
1300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
1301
+ " the batch size of `prompt`."
1302
+ )
1303
+ else:
1304
+ uncond_tokens = negative_prompt
1305
+
1306
+ max_length = prompt_embeds.shape[1]
1307
+ uncond_input = self.tokenizer(
1308
+ uncond_tokens,
1309
+ padding="max_length",
1310
+ max_length=max_length,
1311
+ truncation=True,
1312
+ return_tensors="pt",
1313
+ )
1314
+
1315
+ if (
1316
+ hasattr(self.text_encoder.config, "use_attention_mask")
1317
+ and self.text_encoder.config.use_attention_mask
1318
+ ):
1319
+ attention_mask = uncond_input.attention_mask.to(device)
1320
+ else:
1321
+ attention_mask = None
1322
+
1323
+ negative_prompt_embeds = self.text_encoder(
1324
+ uncond_input.input_ids.to(device),
1325
+ attention_mask=attention_mask,
1326
+ )
1327
+ negative_prompt_embeds = negative_prompt_embeds[0]
1328
+
1329
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
1330
+ seq_len = negative_prompt_embeds.shape[1]
1331
+
1332
+ negative_prompt_embeds = negative_prompt_embeds.to(
1333
+ dtype=self.text_encoder.dtype, device=device
1334
+ )
1335
+
1336
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
1337
+ 1, num_images_per_prompt, 1
1338
+ )
1339
+ negative_prompt_embeds = negative_prompt_embeds.view(
1340
+ batch_size * num_images_per_prompt, seq_len, -1
1341
+ )
1342
+
1343
+ # For classifier free guidance, we need to do two forward passes.
1344
+ # Here we concatenate the unconditional and text embeddings into a single batch
1345
+ # to avoid doing two forward passes
1346
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1347
+
1348
+ return prompt_embeds
1349
+
1350
+ def decode_latents(self, latents):
1351
+ latents = 1 / self.vae.config.scaling_factor * latents
1352
+ image = self.vae.decode(latents).sample
1353
+ image = (image / 2 + 0.5).clamp(0, 1)
1354
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1355
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1356
+ return image
1357
+
1358
+ def prepare_extra_step_kwargs(self, generator, eta):
1359
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1360
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1361
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1362
+ # and should be between [0, 1]
1363
+
1364
+ accepts_eta = "eta" in set(
1365
+ inspect.signature(self.scheduler.step).parameters.keys()
1366
+ )
1367
+ extra_step_kwargs = {}
1368
+ if accepts_eta:
1369
+ extra_step_kwargs["eta"] = eta
1370
+
1371
+ # check if the scheduler accepts generator
1372
+ accepts_generator = "generator" in set(
1373
+ inspect.signature(self.scheduler.step).parameters.keys()
1374
+ )
1375
+ if accepts_generator:
1376
+ extra_step_kwargs["generator"] = generator
1377
+ return extra_step_kwargs
1378
+
1379
+ def prepare_latents(
1380
+ self,
1381
+ batch_size,
1382
+ num_channels_latents,
1383
+ height,
1384
+ width,
1385
+ dtype,
1386
+ device,
1387
+ generator,
1388
+ latents=None,
1389
+ ):
1390
+ shape = (
1391
+ batch_size,
1392
+ num_channels_latents,
1393
+ height // self.vae_scale_factor,
1394
+ width // self.vae_scale_factor,
1395
+ )
1396
+ if isinstance(generator, list) and len(generator) != batch_size:
1397
+ raise ValueError(
1398
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1399
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1400
+ )
1401
+
1402
+ if latents is None:
1403
+ latents = randn_tensor(
1404
+ shape, generator=generator, device=device, dtype=dtype
1405
+ )
1406
+ else:
1407
+ latents = latents.to(device)
1408
+
1409
+ # scale the initial noise by the standard deviation required by the scheduler
1410
+ latents = latents * self.scheduler.init_noise_sigma
1411
+ return latents
1412
+
1413
+ def encode_image(self, image, device, num_images_per_prompt):
1414
+ dtype = next(self.image_encoder.parameters()).dtype
1415
+
1416
+ if image.dtype == np.float32:
1417
+ image = (image * 255).astype(np.uint8)
1418
+
1419
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
1420
+ image = image.to(device=device, dtype=dtype)
1421
+
1422
+ image_embeds = self.image_encoder(
1423
+ image, output_hidden_states=True
1424
+ ).hidden_states[-2]
1425
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1426
+
1427
+ return torch.zeros_like(image_embeds), image_embeds
1428
+
1429
+ def encode_image_latents(self, image, device, num_images_per_prompt):
1430
+
1431
+ dtype = next(self.image_encoder.parameters()).dtype
1432
+
1433
+ image = (
1434
+ torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
1435
+ ) # [1, 3, H, W]
1436
+ image = 2 * image - 1
1437
+ image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
1438
+ image = image.to(dtype=dtype)
1439
+
1440
+ posterior = self.vae.encode(image).latent_dist
1441
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
1442
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
1443
+
1444
+ return torch.zeros_like(latents), latents
1445
+
1446
+ @torch.no_grad()
1447
+ def __call__(
1448
+ self,
1449
+ prompt: str = "",
1450
+ image: Optional[np.ndarray] = None,
1451
+ height: int = 256,
1452
+ width: int = 256,
1453
+ elevation: float = 0,
1454
+ num_inference_steps: int = 50,
1455
+ guidance_scale: float = 7.0,
1456
+ negative_prompt: str = "",
1457
+ num_images_per_prompt: int = 1,
1458
+ eta: float = 0.0,
1459
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1460
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
1461
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1462
+ callback_steps: int = 1,
1463
+ num_frames: int = 4,
1464
+ device=torch.device("cuda:0"),
1465
+ ):
1466
+ self.unet = self.unet.to(device=device)
1467
+ self.vae = self.vae.to(device=device)
1468
+ self.text_encoder = self.text_encoder.to(device=device)
1469
+
1470
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1471
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1472
+ # corresponds to doing no classifier free guidance.
1473
+ do_classifier_free_guidance = guidance_scale > 1.0
1474
+
1475
+ # Prepare timesteps
1476
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1477
+ timesteps = self.scheduler.timesteps
1478
+
1479
+ # imagedream variant
1480
+ if image is not None:
1481
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
1482
+ self.image_encoder = self.image_encoder.to(device=device)
1483
+ image_embeds_neg, image_embeds_pos = self.encode_image(
1484
+ image, device, num_images_per_prompt
1485
+ )
1486
+ image_latents_neg, image_latents_pos = self.encode_image_latents(
1487
+ image, device, num_images_per_prompt
1488
+ )
1489
+
1490
+ _prompt_embeds = self._encode_prompt(
1491
+ prompt=prompt,
1492
+ device=device,
1493
+ num_images_per_prompt=num_images_per_prompt,
1494
+ do_classifier_free_guidance=do_classifier_free_guidance,
1495
+ negative_prompt=negative_prompt,
1496
+ ) # type: ignore
1497
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
1498
+
1499
+ # Prepare latent variables
1500
+ actual_num_frames = num_frames if image is None else num_frames + 1
1501
+ latents: torch.Tensor = self.prepare_latents(
1502
+ actual_num_frames * num_images_per_prompt,
1503
+ 4,
1504
+ height,
1505
+ width,
1506
+ prompt_embeds_pos.dtype,
1507
+ device,
1508
+ generator,
1509
+ None,
1510
+ )
1511
+
1512
+ # Get camera
1513
+ camera = get_camera(
1514
+ num_frames, elevation=elevation, extra_view=(image is not None)
1515
+ ).to(dtype=latents.dtype, device=device)
1516
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
1517
+
1518
+ # Prepare extra step kwargs.
1519
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1520
+
1521
+ # Denoising loop
1522
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1523
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1524
+ for i, t in enumerate(timesteps):
1525
+ # expand the latents if we are doing classifier free guidance
1526
+ multiplier = 2 if do_classifier_free_guidance else 1
1527
+ latent_model_input = torch.cat([latents] * multiplier)
1528
+ latent_model_input = self.scheduler.scale_model_input(
1529
+ latent_model_input, t
1530
+ )
1531
+
1532
+ unet_inputs = {
1533
+ "x": latent_model_input,
1534
+ "timesteps": torch.tensor(
1535
+ [t] * actual_num_frames * multiplier,
1536
+ dtype=latent_model_input.dtype,
1537
+ device=device,
1538
+ ),
1539
+ "context": torch.cat(
1540
+ [prompt_embeds_neg] * actual_num_frames
1541
+ + [prompt_embeds_pos] * actual_num_frames
1542
+ ),
1543
+ "num_frames": actual_num_frames,
1544
+ "camera": torch.cat([camera] * multiplier),
1545
+ }
1546
+
1547
+ if image is not None:
1548
+ unet_inputs["ip"] = torch.cat(
1549
+ [image_embeds_neg] * actual_num_frames
1550
+ + [image_embeds_pos] * actual_num_frames
1551
+ )
1552
+ unet_inputs["ip_img"] = torch.cat(
1553
+ [image_latents_neg] + [image_latents_pos]
1554
+ ) # no repeat
1555
+
1556
+ # predict the noise residual
1557
+ noise_pred = self.unet.forward(**unet_inputs)
1558
+
1559
+ # perform guidance
1560
+ if do_classifier_free_guidance:
1561
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1562
+ noise_pred = noise_pred_uncond + guidance_scale * (
1563
+ noise_pred_text - noise_pred_uncond
1564
+ )
1565
+
1566
+ # compute the previous noisy sample x_t -> x_t-1
1567
+ latents: torch.Tensor = self.scheduler.step(
1568
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1569
+ )[0]
1570
+
1571
+ # call the callback, if provided
1572
+ if i == len(timesteps) - 1 or (
1573
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1574
+ ):
1575
+ progress_bar.update()
1576
+ if callback is not None and i % callback_steps == 0:
1577
+ callback(i, t, latents) # type: ignore
1578
+
1579
+ # Post-processing
1580
+ if output_type == "latent":
1581
+ image = latents
1582
+ elif output_type == "pil":
1583
+ image = self.decode_latents(latents)
1584
+ image = self.numpy_to_pil(image)
1585
+ else: # numpy
1586
+ image = self.decode_latents(latents)
1587
+
1588
+ # Offload last model to CPU
1589
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1590
+ self.final_offload_hook.offload()
1591
+
1592
+ return image
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wheel
2
+ numpy
3
+ tyro
4
+ diffusers
5
+ dearpygui
6
+ einops
7
+ accelerate
8
+ gradio
9
+ imageio
10
+ imageio-ffmpeg
11
+ lpips
12
+ matplotlib
13
+ packaging
14
+ Pillow
15
+ pygltflib
16
+ rembg[gpu,cli]
17
+ rich
18
+ safetensors
19
+ scikit-image
20
+ scikit-learn
21
+ scipy
22
+ spaces
23
+ tqdm
24
+ transformers
25
+ trimesh
26
+ kiui >= 0.2.3
27
+ xatlas
28
+ roma
29
+ plyfile
30
+ torch == 2.2.0
31
+ torchvision == 0.17.0
32
+ torchaudio == 2.2.0
33
+ xformers
34
+ ushlex
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "ip_dim": 16,
20
+ "model_channels": 320,
21
+ "num_head_channels": 64,
22
+ "num_res_blocks": 2,
23
+ "out_channels": 4,
24
+ "transformer_depth": 1
25
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28d8b241a54125fa0a041c1818a5dcdb717e6f5270eea1268172acd3ab0238e0
3
+ size 1883435904
unet/mv_unet.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an [N x dim] Tensor of positional embeddings.
50
+ """
51
+ if not repeat_only:
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period)
55
+ * torch.arange(start=0, end=half, dtype=torch.float32)
56
+ / half
57
+ ).to(device=timesteps.device)
58
+ args = timesteps[:, None] * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat(
62
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
+ )
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ # import pdb; pdb.set_trace()
67
+ return embedding
68
+
69
+
70
+ def zero_module(module):
71
+ """
72
+ Zero out the parameters of a module and return it.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ def conv_nd(dims, *args, **kwargs):
80
+ """
81
+ Create a 1D, 2D, or 3D convolution module.
82
+ """
83
+ if dims == 1:
84
+ return nn.Conv1d(*args, **kwargs)
85
+ elif dims == 2:
86
+ return nn.Conv2d(*args, **kwargs)
87
+ elif dims == 3:
88
+ return nn.Conv3d(*args, **kwargs)
89
+ raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
+ def avg_pool_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D average pooling module.
95
+ """
96
+ if dims == 1:
97
+ return nn.AvgPool1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.AvgPool2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.AvgPool3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def default(val, d):
106
+ if val is not None:
107
+ return val
108
+ return d() if isfunction(d) else d
109
+
110
+
111
+ class GEGLU(nn.Module):
112
+ def __init__(self, dim_in, dim_out):
113
+ super().__init__()
114
+ self.proj = nn.Linear(dim_in, dim_out * 2)
115
+
116
+ def forward(self, x):
117
+ x, gate = self.proj(x).chunk(2, dim=-1)
118
+ return x * F.gelu(gate)
119
+
120
+
121
+ class FeedForward(nn.Module):
122
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
+ super().__init__()
124
+ inner_dim = int(dim * mult)
125
+ dim_out = default(dim_out, dim)
126
+ project_in = (
127
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
+ if not glu
129
+ else GEGLU(dim, inner_dim)
130
+ )
131
+
132
+ self.net = nn.Sequential(
133
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
+ )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
+
139
+
140
+ class MemoryEfficientCrossAttention(nn.Module):
141
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
+ def __init__(
143
+ self,
144
+ query_dim,
145
+ context_dim=None,
146
+ heads=8,
147
+ dim_head=64,
148
+ dropout=0.0,
149
+ ip_dim=0,
150
+ ip_weight=1,
151
+ ):
152
+ super().__init__()
153
+
154
+ inner_dim = dim_head * heads
155
+ context_dim = default(context_dim, query_dim)
156
+
157
+ self.heads = heads
158
+ self.dim_head = dim_head
159
+
160
+ self.ip_dim = ip_dim
161
+ self.ip_weight = ip_weight
162
+
163
+ if self.ip_dim > 0:
164
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
+
167
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
+
171
+ self.to_out = nn.Sequential(
172
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
+ )
174
+ self.attention_op: Optional[Any] = None
175
+
176
+ def forward(self, x, context=None):
177
+ q = self.to_q(x)
178
+ context = default(context, x)
179
+
180
+ if self.ip_dim > 0:
181
+ # context: [B, 77 + 16(ip), 1024]
182
+ token_len = context.shape[1]
183
+ context_ip = context[:, -self.ip_dim :, :]
184
+ k_ip = self.to_k_ip(context_ip)
185
+ v_ip = self.to_v_ip(context_ip)
186
+ context = context[:, : (token_len - self.ip_dim), :]
187
+
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+
191
+ b, _, _ = q.shape
192
+ q, k, v = map(
193
+ lambda t: t.unsqueeze(3)
194
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
195
+ .permute(0, 2, 1, 3)
196
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
197
+ .contiguous(),
198
+ (q, k, v),
199
+ )
200
+
201
+ # actually compute the attention, what we cannot get enough of
202
+ out = xformers.ops.memory_efficient_attention(
203
+ q, k, v, attn_bias=None, op=self.attention_op
204
+ )
205
+
206
+ if self.ip_dim > 0:
207
+ k_ip, v_ip = map(
208
+ lambda t: t.unsqueeze(3)
209
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
210
+ .permute(0, 2, 1, 3)
211
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
212
+ .contiguous(),
213
+ (k_ip, v_ip),
214
+ )
215
+ # actually compute the attention, what we cannot get enough of
216
+ out_ip = xformers.ops.memory_efficient_attention(
217
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
+ )
219
+ out = out + self.ip_weight * out_ip
220
+
221
+ out = (
222
+ out.unsqueeze(0)
223
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
224
+ .permute(0, 2, 1, 3)
225
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
226
+ )
227
+ return self.to_out(out)
228
+
229
+
230
+ class BasicTransformerBlock3D(nn.Module):
231
+
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ n_heads,
236
+ d_head,
237
+ context_dim,
238
+ dropout=0.0,
239
+ gated_ff=True,
240
+ ip_dim=0,
241
+ ip_weight=1,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.attn1 = MemoryEfficientCrossAttention(
246
+ query_dim=dim,
247
+ context_dim=None, # self-attention
248
+ heads=n_heads,
249
+ dim_head=d_head,
250
+ dropout=dropout,
251
+ )
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = MemoryEfficientCrossAttention(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ # ip only applies to cross-attention
260
+ ip_dim=ip_dim,
261
+ ip_weight=ip_weight,
262
+ )
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+
267
+ def forward(self, x, context=None, num_frames=1):
268
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
+ x = self.attn1(self.norm1(x), context=None) + x
270
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
+ x = self.attn2(self.norm2(x), context=context) + x
272
+ x = self.ff(self.norm3(x)) + x
273
+ return x
274
+
275
+
276
+ class SpatialTransformer3D(nn.Module):
277
+
278
+ def __init__(
279
+ self,
280
+ in_channels,
281
+ n_heads,
282
+ d_head,
283
+ context_dim, # cross attention input dim
284
+ depth=1,
285
+ dropout=0.0,
286
+ ip_dim=0,
287
+ ip_weight=1,
288
+ ):
289
+ super().__init__()
290
+
291
+ if not isinstance(context_dim, list):
292
+ context_dim = [context_dim]
293
+
294
+ self.in_channels = in_channels
295
+
296
+ inner_dim = n_heads * d_head
297
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Linear(in_channels, inner_dim)
299
+
300
+ self.transformer_blocks = nn.ModuleList(
301
+ [
302
+ BasicTransformerBlock3D(
303
+ inner_dim,
304
+ n_heads,
305
+ d_head,
306
+ context_dim=context_dim[d],
307
+ dropout=dropout,
308
+ ip_dim=ip_dim,
309
+ ip_weight=ip_weight,
310
+ )
311
+ for d in range(depth)
312
+ ]
313
+ )
314
+
315
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
+
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ if not isinstance(context, list):
321
+ context = [context]
322
+ b, c, h, w = x.shape
323
+ x_in = x
324
+ x = self.norm(x)
325
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
+ x = self.proj_in(x)
327
+ for i, block in enumerate(self.transformer_blocks):
328
+ x = block(x, context=context[i], num_frames=num_frames)
329
+ x = self.proj_out(x)
330
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
+
332
+ return x + x_in
333
+
334
+
335
+ class PerceiverAttention(nn.Module):
336
+ def __init__(self, *, dim, dim_head=64, heads=8):
337
+ super().__init__()
338
+ self.scale = dim_head ** -0.5
339
+ self.dim_head = dim_head
340
+ self.heads = heads
341
+ inner_dim = dim_head * heads
342
+
343
+ self.norm1 = nn.LayerNorm(dim)
344
+ self.norm2 = nn.LayerNorm(dim)
345
+
346
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
+
350
+ def forward(self, x, latents):
351
+ """
352
+ Args:
353
+ x (torch.Tensor): image features
354
+ shape (b, n1, D)
355
+ latent (torch.Tensor): latent features
356
+ shape (b, n2, D)
357
+ """
358
+ x = self.norm1(x)
359
+ latents = self.norm2(latents)
360
+
361
+ b, l, _ = latents.shape
362
+
363
+ q = self.to_q(latents)
364
+ kv_input = torch.cat((x, latents), dim=-2)
365
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
+
367
+ q, k, v = map(
368
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
+ .transpose(1, 2)
370
+ .reshape(b, self.heads, t.shape[1], -1)
371
+ .contiguous(),
372
+ (q, k, v),
373
+ )
374
+
375
+ # attention
376
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
+ out = weight @ v
380
+
381
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
+
383
+ return self.to_out(out)
384
+
385
+
386
+ class Resampler(nn.Module):
387
+ def __init__(
388
+ self,
389
+ dim=1024,
390
+ depth=8,
391
+ dim_head=64,
392
+ heads=16,
393
+ num_queries=8,
394
+ embedding_dim=768,
395
+ output_dim=1024,
396
+ ff_mult=4,
397
+ ):
398
+ super().__init__()
399
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
+ self.proj_in = nn.Linear(embedding_dim, dim)
401
+ self.proj_out = nn.Linear(dim, output_dim)
402
+ self.norm_out = nn.LayerNorm(output_dim)
403
+
404
+ self.layers = nn.ModuleList([])
405
+ for _ in range(depth):
406
+ self.layers.append(
407
+ nn.ModuleList(
408
+ [
409
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
+ nn.Sequential(
411
+ nn.LayerNorm(dim),
412
+ nn.Linear(dim, dim * ff_mult, bias=False),
413
+ nn.GELU(),
414
+ nn.Linear(dim * ff_mult, dim, bias=False),
415
+ )
416
+ ]
417
+ )
418
+ )
419
+
420
+ def forward(self, x):
421
+ latents = self.latents.repeat(x.size(0), 1, 1)
422
+ x = self.proj_in(x)
423
+ for attn, ff in self.layers:
424
+ latents = attn(x, latents) + latents
425
+ latents = ff(latents) + latents
426
+
427
+ latents = self.proj_out(latents)
428
+ return self.norm_out(latents)
429
+
430
+
431
+ class CondSequential(nn.Sequential):
432
+ """
433
+ A sequential module that passes timestep embeddings to the children that
434
+ support it as an extra input.
435
+ """
436
+
437
+ def forward(self, x, emb, context=None, num_frames=1):
438
+ for layer in self:
439
+ if isinstance(layer, ResBlock):
440
+ x = layer(x, emb)
441
+ elif isinstance(layer, SpatialTransformer3D):
442
+ x = layer(x, context, num_frames=num_frames)
443
+ else:
444
+ x = layer(x)
445
+ return x
446
+
447
+
448
+ class Upsample(nn.Module):
449
+ """
450
+ An upsampling layer with an optional convolution.
451
+ :param channels: channels in the inputs and outputs.
452
+ :param use_conv: a bool determining if a convolution is applied.
453
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
+ upsampling occurs in the inner-two dimensions.
455
+ """
456
+
457
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
+ super().__init__()
459
+ self.channels = channels
460
+ self.out_channels = out_channels or channels
461
+ self.use_conv = use_conv
462
+ self.dims = dims
463
+ if use_conv:
464
+ self.conv = conv_nd(
465
+ dims, self.channels, self.out_channels, 3, padding=padding
466
+ )
467
+
468
+ def forward(self, x):
469
+ assert x.shape[1] == self.channels
470
+ if self.dims == 3:
471
+ x = F.interpolate(
472
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
+ )
474
+ else:
475
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
476
+ if self.use_conv:
477
+ x = self.conv(x)
478
+ return x
479
+
480
+
481
+ class Downsample(nn.Module):
482
+ """
483
+ A downsampling layer with an optional convolution.
484
+ :param channels: channels in the inputs and outputs.
485
+ :param use_conv: a bool determining if a convolution is applied.
486
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
+ downsampling occurs in the inner-two dimensions.
488
+ """
489
+
490
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
+ super().__init__()
492
+ self.channels = channels
493
+ self.out_channels = out_channels or channels
494
+ self.use_conv = use_conv
495
+ self.dims = dims
496
+ stride = 2 if dims != 3 else (1, 2, 2)
497
+ if use_conv:
498
+ self.op = conv_nd(
499
+ dims,
500
+ self.channels,
501
+ self.out_channels,
502
+ 3,
503
+ stride=stride,
504
+ padding=padding,
505
+ )
506
+ else:
507
+ assert self.channels == self.out_channels
508
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
+
510
+ def forward(self, x):
511
+ assert x.shape[1] == self.channels
512
+ return self.op(x)
513
+
514
+
515
+ class ResBlock(nn.Module):
516
+ """
517
+ A residual block that can optionally change the number of channels.
518
+ :param channels: the number of input channels.
519
+ :param emb_channels: the number of timestep embedding channels.
520
+ :param dropout: the rate of dropout.
521
+ :param out_channels: if specified, the number of out channels.
522
+ :param use_conv: if True and out_channels is specified, use a spatial
523
+ convolution instead of a smaller 1x1 convolution to change the
524
+ channels in the skip connection.
525
+ :param dims: determines if the signal is 1D, 2D, or 3D.
526
+ :param up: if True, use this block for upsampling.
527
+ :param down: if True, use this block for downsampling.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ channels,
533
+ emb_channels,
534
+ dropout,
535
+ out_channels=None,
536
+ use_conv=False,
537
+ use_scale_shift_norm=False,
538
+ dims=2,
539
+ up=False,
540
+ down=False,
541
+ ):
542
+ super().__init__()
543
+ self.channels = channels
544
+ self.emb_channels = emb_channels
545
+ self.dropout = dropout
546
+ self.out_channels = out_channels or channels
547
+ self.use_conv = use_conv
548
+ self.use_scale_shift_norm = use_scale_shift_norm
549
+
550
+ self.in_layers = nn.Sequential(
551
+ nn.GroupNorm(32, channels),
552
+ nn.SiLU(),
553
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
+ )
555
+
556
+ self.updown = up or down
557
+
558
+ if up:
559
+ self.h_upd = Upsample(channels, False, dims)
560
+ self.x_upd = Upsample(channels, False, dims)
561
+ elif down:
562
+ self.h_upd = Downsample(channels, False, dims)
563
+ self.x_upd = Downsample(channels, False, dims)
564
+ else:
565
+ self.h_upd = self.x_upd = nn.Identity()
566
+
567
+ self.emb_layers = nn.Sequential(
568
+ nn.SiLU(),
569
+ nn.Linear(
570
+ emb_channels,
571
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
+ ),
573
+ )
574
+ self.out_layers = nn.Sequential(
575
+ nn.GroupNorm(32, self.out_channels),
576
+ nn.SiLU(),
577
+ nn.Dropout(p=dropout),
578
+ zero_module(
579
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
+ ),
581
+ )
582
+
583
+ if self.out_channels == channels:
584
+ self.skip_connection = nn.Identity()
585
+ elif use_conv:
586
+ self.skip_connection = conv_nd(
587
+ dims, channels, self.out_channels, 3, padding=1
588
+ )
589
+ else:
590
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
+
592
+ def forward(self, x, emb):
593
+ if self.updown:
594
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
+ h = in_rest(x)
596
+ h = self.h_upd(h)
597
+ x = self.x_upd(x)
598
+ h = in_conv(h)
599
+ else:
600
+ h = self.in_layers(x)
601
+ emb_out = self.emb_layers(emb).type(h.dtype)
602
+ while len(emb_out.shape) < len(h.shape):
603
+ emb_out = emb_out[..., None]
604
+ if self.use_scale_shift_norm:
605
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
607
+ h = out_norm(h) * (1 + scale) + shift
608
+ h = out_rest(h)
609
+ else:
610
+ h = h + emb_out
611
+ h = self.out_layers(h)
612
+ return self.skip_connection(x) + h
613
+
614
+
615
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
+ """
617
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
+ :param in_channels: channels in the input Tensor.
619
+ :param model_channels: base channel count for the model.
620
+ :param out_channels: channels in the output Tensor.
621
+ :param num_res_blocks: number of residual blocks per downsample.
622
+ :param attention_resolutions: a collection of downsample rates at which
623
+ attention will take place. May be a set, list, or tuple.
624
+ For example, if this contains 4, then at 4x downsampling, attention
625
+ will be used.
626
+ :param dropout: the dropout probability.
627
+ :param channel_mult: channel multiplier for each level of the UNet.
628
+ :param conv_resample: if True, use learned convolutions for upsampling and
629
+ downsampling.
630
+ :param dims: determines if the signal is 1D, 2D, or 3D.
631
+ :param num_classes: if specified (as an int), then this model will be
632
+ class-conditional with `num_classes` classes.
633
+ :param num_heads: the number of attention heads in each attention layer.
634
+ :param num_heads_channels: if specified, ignore num_heads and instead use
635
+ a fixed channel width per attention head.
636
+ :param num_heads_upsample: works with num_heads to set a different number
637
+ of heads for upsampling. Deprecated.
638
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
+ :param resblock_updown: use residual blocks for up/downsampling.
640
+ :param use_new_attention_order: use a different attention pattern for potentially
641
+ increased efficiency.
642
+ :param camera_dim: dimensionality of camera input.
643
+ """
644
+
645
+ def __init__(
646
+ self,
647
+ image_size,
648
+ in_channels,
649
+ model_channels,
650
+ out_channels,
651
+ num_res_blocks,
652
+ attention_resolutions,
653
+ dropout=0,
654
+ channel_mult=(1, 2, 4, 8),
655
+ conv_resample=True,
656
+ dims=2,
657
+ num_classes=None,
658
+ num_heads=-1,
659
+ num_head_channels=-1,
660
+ num_heads_upsample=-1,
661
+ use_scale_shift_norm=False,
662
+ resblock_updown=False,
663
+ transformer_depth=1,
664
+ context_dim=None,
665
+ n_embed=None,
666
+ num_attention_blocks=None,
667
+ adm_in_channels=None,
668
+ camera_dim=None,
669
+ ip_dim=0, # imagedream uses ip_dim > 0
670
+ ip_weight=1.0,
671
+ **kwargs,
672
+ ):
673
+ super().__init__()
674
+ assert context_dim is not None
675
+
676
+ if num_heads_upsample == -1:
677
+ num_heads_upsample = num_heads
678
+
679
+ if num_heads == -1:
680
+ assert (
681
+ num_head_channels != -1
682
+ ), "Either num_heads or num_head_channels has to be set"
683
+
684
+ if num_head_channels == -1:
685
+ assert (
686
+ num_heads != -1
687
+ ), "Either num_heads or num_head_channels has to be set"
688
+
689
+ self.image_size = image_size
690
+ self.in_channels = in_channels
691
+ self.model_channels = model_channels
692
+ self.out_channels = out_channels
693
+ if isinstance(num_res_blocks, int):
694
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
+ else:
696
+ if len(num_res_blocks) != len(channel_mult):
697
+ raise ValueError(
698
+ "provide num_res_blocks either as an int (globally constant) or "
699
+ "as a list/tuple (per-level) with the same length as channel_mult"
700
+ )
701
+ self.num_res_blocks = num_res_blocks
702
+
703
+ if num_attention_blocks is not None:
704
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
705
+ assert all(
706
+ map(
707
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
+ range(len(num_attention_blocks)),
709
+ )
710
+ )
711
+ print(
712
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
+ f"attention will still not be set."
716
+ )
717
+
718
+ self.attention_resolutions = attention_resolutions
719
+ self.dropout = dropout
720
+ self.channel_mult = channel_mult
721
+ self.conv_resample = conv_resample
722
+ self.num_classes = num_classes
723
+ self.num_heads = num_heads
724
+ self.num_head_channels = num_head_channels
725
+ self.num_heads_upsample = num_heads_upsample
726
+ self.predict_codebook_ids = n_embed is not None
727
+
728
+ self.ip_dim = ip_dim
729
+ self.ip_weight = ip_weight
730
+
731
+ if self.ip_dim > 0:
732
+ self.image_embed = Resampler(
733
+ dim=context_dim,
734
+ depth=4,
735
+ dim_head=64,
736
+ heads=12,
737
+ num_queries=ip_dim, # num token
738
+ embedding_dim=1280,
739
+ output_dim=context_dim,
740
+ ff_mult=4,
741
+ )
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ nn.Linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ nn.Linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ if camera_dim is not None:
751
+ time_embed_dim = model_channels * 4
752
+ self.camera_embed = nn.Sequential(
753
+ nn.Linear(camera_dim, time_embed_dim),
754
+ nn.SiLU(),
755
+ nn.Linear(time_embed_dim, time_embed_dim),
756
+ )
757
+
758
+ if self.num_classes is not None:
759
+ if isinstance(self.num_classes, int):
760
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
+ elif self.num_classes == "continuous":
762
+ # print("setting up linear c_adm embedding layer")
763
+ self.label_emb = nn.Linear(1, time_embed_dim)
764
+ elif self.num_classes == "sequential":
765
+ assert adm_in_channels is not None
766
+ self.label_emb = nn.Sequential(
767
+ nn.Sequential(
768
+ nn.Linear(adm_in_channels, time_embed_dim),
769
+ nn.SiLU(),
770
+ nn.Linear(time_embed_dim, time_embed_dim),
771
+ )
772
+ )
773
+ else:
774
+ raise ValueError()
775
+
776
+ self.input_blocks = nn.ModuleList(
777
+ [
778
+ CondSequential(
779
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
+ )
781
+ ]
782
+ )
783
+ self._feature_size = model_channels
784
+ input_block_chans = [model_channels]
785
+ ch = model_channels
786
+ ds = 1
787
+ for level, mult in enumerate(channel_mult):
788
+ for nr in range(self.num_res_blocks[level]):
789
+ layers: List[Any] = [
790
+ ResBlock(
791
+ ch,
792
+ time_embed_dim,
793
+ dropout,
794
+ out_channels=mult * model_channels,
795
+ dims=dims,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = mult * model_channels
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+
807
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
+ layers.append(
809
+ SpatialTransformer3D(
810
+ ch,
811
+ num_heads,
812
+ dim_head,
813
+ context_dim=context_dim,
814
+ depth=transformer_depth,
815
+ ip_dim=self.ip_dim,
816
+ ip_weight=self.ip_weight,
817
+ )
818
+ )
819
+ self.input_blocks.append(CondSequential(*layers))
820
+ self._feature_size += ch
821
+ input_block_chans.append(ch)
822
+ if level != len(channel_mult) - 1:
823
+ out_ch = ch
824
+ self.input_blocks.append(
825
+ CondSequential(
826
+ ResBlock(
827
+ ch,
828
+ time_embed_dim,
829
+ dropout,
830
+ out_channels=out_ch,
831
+ dims=dims,
832
+ use_scale_shift_norm=use_scale_shift_norm,
833
+ down=True,
834
+ )
835
+ if resblock_updown
836
+ else Downsample(
837
+ ch, conv_resample, dims=dims, out_channels=out_ch
838
+ )
839
+ )
840
+ )
841
+ ch = out_ch
842
+ input_block_chans.append(ch)
843
+ ds *= 2
844
+ self._feature_size += ch
845
+
846
+ if num_head_channels == -1:
847
+ dim_head = ch // num_heads
848
+ else:
849
+ num_heads = ch // num_head_channels
850
+ dim_head = num_head_channels
851
+
852
+ self.middle_block = CondSequential(
853
+ ResBlock(
854
+ ch,
855
+ time_embed_dim,
856
+ dropout,
857
+ dims=dims,
858
+ use_scale_shift_norm=use_scale_shift_norm,
859
+ ),
860
+ SpatialTransformer3D(
861
+ ch,
862
+ num_heads,
863
+ dim_head,
864
+ context_dim=context_dim,
865
+ depth=transformer_depth,
866
+ ip_dim=self.ip_dim,
867
+ ip_weight=self.ip_weight,
868
+ ),
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ ),
876
+ )
877
+ self._feature_size += ch
878
+
879
+ self.output_blocks = nn.ModuleList([])
880
+ for level, mult in list(enumerate(channel_mult))[::-1]:
881
+ for i in range(self.num_res_blocks[level] + 1):
882
+ ich = input_block_chans.pop()
883
+ layers = [
884
+ ResBlock(
885
+ ch + ich,
886
+ time_embed_dim,
887
+ dropout,
888
+ out_channels=model_channels * mult,
889
+ dims=dims,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ )
892
+ ]
893
+ ch = model_channels * mult
894
+ if ds in attention_resolutions:
895
+ if num_head_channels == -1:
896
+ dim_head = ch // num_heads
897
+ else:
898
+ num_heads = ch // num_head_channels
899
+ dim_head = num_head_channels
900
+
901
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
902
+ layers.append(
903
+ SpatialTransformer3D(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ context_dim=context_dim,
908
+ depth=transformer_depth,
909
+ ip_dim=self.ip_dim,
910
+ ip_weight=self.ip_weight,
911
+ )
912
+ )
913
+ if level and i == self.num_res_blocks[level]:
914
+ out_ch = ch
915
+ layers.append(
916
+ ResBlock(
917
+ ch,
918
+ time_embed_dim,
919
+ dropout,
920
+ out_channels=out_ch,
921
+ dims=dims,
922
+ use_scale_shift_norm=use_scale_shift_norm,
923
+ up=True,
924
+ )
925
+ if resblock_updown
926
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
+ )
928
+ ds //= 2
929
+ self.output_blocks.append(CondSequential(*layers))
930
+ self._feature_size += ch
931
+
932
+ self.out = nn.Sequential(
933
+ nn.GroupNorm(32, ch),
934
+ nn.SiLU(),
935
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
+ )
937
+ if self.predict_codebook_ids:
938
+ self.id_predictor = nn.Sequential(
939
+ nn.GroupNorm(32, ch),
940
+ conv_nd(dims, model_channels, n_embed, 1),
941
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
+ )
943
+
944
+ def forward(
945
+ self,
946
+ x,
947
+ timesteps=None,
948
+ context=None,
949
+ y=None,
950
+ camera=None,
951
+ num_frames=1,
952
+ ip=None,
953
+ ip_img=None,
954
+ **kwargs,
955
+ ):
956
+ """
957
+ Apply the model to an input batch.
958
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
+ :param timesteps: a 1-D batch of timesteps.
960
+ :param context: conditioning plugged in via crossattn
961
+ :param y: an [N] Tensor of labels, if class-conditional.
962
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
963
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
+ """
965
+ assert (
966
+ x.shape[0] % num_frames == 0
967
+ ), "input batch size must be dividable by num_frames!"
968
+ assert (y is not None) == (
969
+ self.num_classes is not None
970
+ ), "must specify y if and only if the model is class-conditional"
971
+
972
+ hs = []
973
+
974
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
+
976
+ emb = self.time_embed(t_emb)
977
+
978
+ if self.num_classes is not None:
979
+ assert y is not None
980
+ assert y.shape[0] == x.shape[0]
981
+ emb = emb + self.label_emb(y)
982
+
983
+ # Add camera embeddings
984
+ if camera is not None:
985
+ emb = emb + self.camera_embed(camera)
986
+
987
+ # imagedream variant
988
+ if self.ip_dim > 0:
989
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
990
+ ip_emb = self.image_embed(ip)
991
+ context = torch.cat((context, ip_emb), 1)
992
+
993
+ h = x
994
+ for module in self.input_blocks:
995
+ h = module(h, emb, context, num_frames=num_frames)
996
+ hs.append(h)
997
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
998
+ for module in self.output_blocks:
999
+ h = torch.cat([h, hs.pop()], dim=1)
1000
+ h = module(h, emb, context, num_frames=num_frames)
1001
+ h = h.type(x.dtype)
1002
+ if self.predict_codebook_ids:
1003
+ return self.id_predictor(h)
1004
+ else:
1005
+ return self.out(h)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342