wanghaofan commited on
Commit
f2b7487
·
verified ·
1 Parent(s): 06398f1

Upload 11 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ teasers/0.png filter=lfs diff=lfs merge=lfs -text
37
+ teasers/1.png filter=lfs diff=lfs merge=lfs -text
assets/1.jpg ADDED
assets/2.jpg ADDED
infer_sd35_large_ipa.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+
4
+ from models.transformer_sd3 import SD3Transformer2DModel
5
+ from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
6
+
7
+
8
+ if __name__ == '__main__':
9
+
10
+ model_path = 'stabilityai/stable-diffusion-3.5-large'
11
+ ip_adapter_path = './ip-adapter.bin'
12
+ image_encoder_path = "google/siglip-so400m-patch14-384"
13
+
14
+ transformer = SD3Transformer2DModel.from_pretrained(
15
+ model_path, subfolder="transformer", torch_dtype=torch.bfloat16
16
+ )
17
+
18
+ pipe = StableDiffusion3Pipeline.from_pretrained(
19
+ model_path, transformer=transformer, torch_dtype=torch.bfloat16
20
+ ).to("cuda")
21
+
22
+ pipe.init_ipadapter(
23
+ ip_adapter_path=ip_adapter_path,
24
+ image_encoder_path=image_encoder_path,
25
+ nb_token=64,
26
+ )
27
+
28
+ ref_img = Image.open('./assets/1.jpg').convert('RGB')
29
+ image = pipe(
30
+ width=1024,
31
+ height=1024,
32
+ prompt='a cat',
33
+ negative_prompt="lowres, low quality, worst quality",
34
+ num_inference_steps=24,
35
+ guidance_scale=5.0,
36
+ generator=torch.Generator("cuda").manual_seed(42),
37
+ clip_image=ref_img,
38
+ ipadapter_scale=0.5,
39
+ ).images[0]
40
+ image.save('./result.jpg')
ip-adapter.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fe54774aa528e712d9145ff6a59dd93b1fcf1d5935304feffd980ae6d42ae03
3
+ size 1595970439
models/__init__.py ADDED
File without changes
models/attention.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ num_attention_heads: int,
107
+ attention_head_dim: int,
108
+ context_pre_only: bool = False,
109
+ qk_norm: Optional[str] = None,
110
+ use_dual_attention: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.use_dual_attention = use_dual_attention
115
+ self.context_pre_only = context_pre_only
116
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
+
118
+ if use_dual_attention:
119
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
120
+ else:
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ if context_norm_type == "ada_norm_continous":
124
+ self.norm1_context = AdaLayerNormContinuous(
125
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
+ )
127
+ elif context_norm_type == "ada_norm_zero":
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+ else:
130
+ raise ValueError(
131
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
132
+ )
133
+
134
+ if hasattr(F, "scaled_dot_product_attention"):
135
+ processor = JointAttnProcessor2_0()
136
+ else:
137
+ raise ValueError(
138
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
+ )
140
+
141
+ self.attn = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=None,
144
+ added_kv_proj_dim=dim,
145
+ dim_head=attention_head_dim,
146
+ heads=num_attention_heads,
147
+ out_dim=dim,
148
+ context_pre_only=context_pre_only,
149
+ bias=True,
150
+ processor=processor,
151
+ qk_norm=qk_norm,
152
+ eps=1e-6,
153
+ )
154
+
155
+ if use_dual_attention:
156
+ self.attn2 = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ out_dim=dim,
162
+ bias=True,
163
+ processor=processor,
164
+ qk_norm=qk_norm,
165
+ eps=1e-6,
166
+ )
167
+ else:
168
+ self.attn2 = None
169
+
170
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
+
173
+ if not context_pre_only:
174
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
+ else:
177
+ self.norm2_context = None
178
+ self.ff_context = None
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
+ # Sets chunk feed-forward
187
+ self._chunk_size = chunk_size
188
+ self._chunk_dim = dim
189
+
190
+ def forward(
191
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
192
+ joint_attention_kwargs=None,
193
+ ):
194
+ if self.use_dual_attention:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
196
+ hidden_states, emb=temb
197
+ )
198
+ else:
199
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
200
+
201
+ if self.context_pre_only:
202
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
203
+ else:
204
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
205
+ encoder_hidden_states, emb=temb
206
+ )
207
+
208
+ # Attention.
209
+ attn_output, context_attn_output = self.attn(
210
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
211
+ **({} if joint_attention_kwargs is None else joint_attention_kwargs),
212
+ )
213
+
214
+ # Process attention outputs for the `hidden_states`.
215
+ attn_output = gate_msa.unsqueeze(1) * attn_output
216
+ hidden_states = hidden_states + attn_output
217
+
218
+ if self.use_dual_attention:
219
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
220
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
221
+ hidden_states = hidden_states + attn_output2
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
+ if self._chunk_size is not None:
226
+ # "feed_forward_chunk_size" can be used to save memory
227
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
228
+ else:
229
+ ff_output = self.ff(norm_hidden_states)
230
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
231
+
232
+ hidden_states = hidden_states + ff_output
233
+
234
+ # Process attention outputs for the `encoder_hidden_states`.
235
+ if self.context_pre_only:
236
+ encoder_hidden_states = None
237
+ else:
238
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
239
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
240
+
241
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
242
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
243
+ if self._chunk_size is not None:
244
+ # "feed_forward_chunk_size" can be used to save memory
245
+ context_ff_output = _chunked_feed_forward(
246
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
247
+ )
248
+ else:
249
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
250
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
251
+
252
+ return encoder_hidden_states, hidden_states
253
+
254
+
255
+ @maybe_allow_in_graph
256
+ class BasicTransformerBlock(nn.Module):
257
+ r"""
258
+ A basic Transformer block.
259
+
260
+ Parameters:
261
+ dim (`int`): The number of channels in the input and output.
262
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
263
+ attention_head_dim (`int`): The number of channels in each head.
264
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
265
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
266
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
267
+ num_embeds_ada_norm (:
268
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
269
+ attention_bias (:
270
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
271
+ only_cross_attention (`bool`, *optional*):
272
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
273
+ double_self_attention (`bool`, *optional*):
274
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
275
+ upcast_attention (`bool`, *optional*):
276
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
277
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
278
+ Whether to use learnable elementwise affine parameters for normalization.
279
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
280
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
281
+ final_dropout (`bool` *optional*, defaults to False):
282
+ Whether to apply a final dropout after the last feed-forward layer.
283
+ attention_type (`str`, *optional*, defaults to `"default"`):
284
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
285
+ positional_embeddings (`str`, *optional*, defaults to `None`):
286
+ The type of positional embeddings to apply to.
287
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
288
+ The maximum number of positional embeddings to apply.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ dim: int,
294
+ num_attention_heads: int,
295
+ attention_head_dim: int,
296
+ dropout=0.0,
297
+ cross_attention_dim: Optional[int] = None,
298
+ activation_fn: str = "geglu",
299
+ num_embeds_ada_norm: Optional[int] = None,
300
+ attention_bias: bool = False,
301
+ only_cross_attention: bool = False,
302
+ double_self_attention: bool = False,
303
+ upcast_attention: bool = False,
304
+ norm_elementwise_affine: bool = True,
305
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
306
+ norm_eps: float = 1e-5,
307
+ final_dropout: bool = False,
308
+ attention_type: str = "default",
309
+ positional_embeddings: Optional[str] = None,
310
+ num_positional_embeddings: Optional[int] = None,
311
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
312
+ ada_norm_bias: Optional[int] = None,
313
+ ff_inner_dim: Optional[int] = None,
314
+ ff_bias: bool = True,
315
+ attention_out_bias: bool = True,
316
+ ):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.num_attention_heads = num_attention_heads
320
+ self.attention_head_dim = attention_head_dim
321
+ self.dropout = dropout
322
+ self.cross_attention_dim = cross_attention_dim
323
+ self.activation_fn = activation_fn
324
+ self.attention_bias = attention_bias
325
+ self.double_self_attention = double_self_attention
326
+ self.norm_elementwise_affine = norm_elementwise_affine
327
+ self.positional_embeddings = positional_embeddings
328
+ self.num_positional_embeddings = num_positional_embeddings
329
+ self.only_cross_attention = only_cross_attention
330
+
331
+ # We keep these boolean flags for backward-compatibility.
332
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
333
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
334
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
335
+ self.use_layer_norm = norm_type == "layer_norm"
336
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
337
+
338
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
339
+ raise ValueError(
340
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
341
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
342
+ )
343
+
344
+ self.norm_type = norm_type
345
+ self.num_embeds_ada_norm = num_embeds_ada_norm
346
+
347
+ if positional_embeddings and (num_positional_embeddings is None):
348
+ raise ValueError(
349
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
350
+ )
351
+
352
+ if positional_embeddings == "sinusoidal":
353
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
354
+ else:
355
+ self.pos_embed = None
356
+
357
+ # Define 3 blocks. Each block has its own normalization layer.
358
+ # 1. Self-Attn
359
+ if norm_type == "ada_norm":
360
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
361
+ elif norm_type == "ada_norm_zero":
362
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
363
+ elif norm_type == "ada_norm_continuous":
364
+ self.norm1 = AdaLayerNormContinuous(
365
+ dim,
366
+ ada_norm_continous_conditioning_embedding_dim,
367
+ norm_elementwise_affine,
368
+ norm_eps,
369
+ ada_norm_bias,
370
+ "rms_norm",
371
+ )
372
+ else:
373
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
374
+
375
+ self.attn1 = Attention(
376
+ query_dim=dim,
377
+ heads=num_attention_heads,
378
+ dim_head=attention_head_dim,
379
+ dropout=dropout,
380
+ bias=attention_bias,
381
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
382
+ upcast_attention=upcast_attention,
383
+ out_bias=attention_out_bias,
384
+ )
385
+
386
+ # 2. Cross-Attn
387
+ if cross_attention_dim is not None or double_self_attention:
388
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
389
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
390
+ # the second cross attention block.
391
+ if norm_type == "ada_norm":
392
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
393
+ elif norm_type == "ada_norm_continuous":
394
+ self.norm2 = AdaLayerNormContinuous(
395
+ dim,
396
+ ada_norm_continous_conditioning_embedding_dim,
397
+ norm_elementwise_affine,
398
+ norm_eps,
399
+ ada_norm_bias,
400
+ "rms_norm",
401
+ )
402
+ else:
403
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
404
+
405
+ self.attn2 = Attention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ out_bias=attention_out_bias,
414
+ ) # is self-attn if encoder_hidden_states is none
415
+ else:
416
+ if norm_type == "ada_norm_single": # For Latte
417
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
418
+ else:
419
+ self.norm2 = None
420
+ self.attn2 = None
421
+
422
+ # 3. Feed-forward
423
+ if norm_type == "ada_norm_continuous":
424
+ self.norm3 = AdaLayerNormContinuous(
425
+ dim,
426
+ ada_norm_continous_conditioning_embedding_dim,
427
+ norm_elementwise_affine,
428
+ norm_eps,
429
+ ada_norm_bias,
430
+ "layer_norm",
431
+ )
432
+
433
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
434
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
435
+ elif norm_type == "layer_norm_i2vgen":
436
+ self.norm3 = None
437
+
438
+ self.ff = FeedForward(
439
+ dim,
440
+ dropout=dropout,
441
+ activation_fn=activation_fn,
442
+ final_dropout=final_dropout,
443
+ inner_dim=ff_inner_dim,
444
+ bias=ff_bias,
445
+ )
446
+
447
+ # 4. Fuser
448
+ if attention_type == "gated" or attention_type == "gated-text-image":
449
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
450
+
451
+ # 5. Scale-shift for PixArt-Alpha.
452
+ if norm_type == "ada_norm_single":
453
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
454
+
455
+ # let chunk size default to None
456
+ self._chunk_size = None
457
+ self._chunk_dim = 0
458
+
459
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
460
+ # Sets chunk feed-forward
461
+ self._chunk_size = chunk_size
462
+ self._chunk_dim = dim
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ encoder_hidden_states: Optional[torch.Tensor] = None,
469
+ encoder_attention_mask: Optional[torch.Tensor] = None,
470
+ timestep: Optional[torch.LongTensor] = None,
471
+ cross_attention_kwargs: Dict[str, Any] = None,
472
+ class_labels: Optional[torch.LongTensor] = None,
473
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
474
+ ) -> torch.Tensor:
475
+ if cross_attention_kwargs is not None:
476
+ if cross_attention_kwargs.get("scale", None) is not None:
477
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
478
+
479
+ # Notice that normalization is always applied before the real computation in the following blocks.
480
+ # 0. Self-Attention
481
+ batch_size = hidden_states.shape[0]
482
+
483
+ if self.norm_type == "ada_norm":
484
+ norm_hidden_states = self.norm1(hidden_states, timestep)
485
+ elif self.norm_type == "ada_norm_zero":
486
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
487
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
488
+ )
489
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
490
+ norm_hidden_states = self.norm1(hidden_states)
491
+ elif self.norm_type == "ada_norm_continuous":
492
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
493
+ elif self.norm_type == "ada_norm_single":
494
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
495
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
496
+ ).chunk(6, dim=1)
497
+ norm_hidden_states = self.norm1(hidden_states)
498
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
499
+ else:
500
+ raise ValueError("Incorrect norm used")
501
+
502
+ if self.pos_embed is not None:
503
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
504
+
505
+ # 1. Prepare GLIGEN inputs
506
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
507
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
508
+
509
+ attn_output = self.attn1(
510
+ norm_hidden_states,
511
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
512
+ attention_mask=attention_mask,
513
+ **cross_attention_kwargs,
514
+ )
515
+
516
+ if self.norm_type == "ada_norm_zero":
517
+ attn_output = gate_msa.unsqueeze(1) * attn_output
518
+ elif self.norm_type == "ada_norm_single":
519
+ attn_output = gate_msa * attn_output
520
+
521
+ hidden_states = attn_output + hidden_states
522
+ if hidden_states.ndim == 4:
523
+ hidden_states = hidden_states.squeeze(1)
524
+
525
+ # 1.2 GLIGEN Control
526
+ if gligen_kwargs is not None:
527
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
528
+
529
+ # 3. Cross-Attention
530
+ if self.attn2 is not None:
531
+ if self.norm_type == "ada_norm":
532
+ norm_hidden_states = self.norm2(hidden_states, timestep)
533
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
534
+ norm_hidden_states = self.norm2(hidden_states)
535
+ elif self.norm_type == "ada_norm_single":
536
+ # For PixArt norm2 isn't applied here:
537
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
538
+ norm_hidden_states = hidden_states
539
+ elif self.norm_type == "ada_norm_continuous":
540
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
541
+ else:
542
+ raise ValueError("Incorrect norm")
543
+
544
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
545
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
546
+
547
+ attn_output = self.attn2(
548
+ norm_hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=encoder_attention_mask,
551
+ **cross_attention_kwargs,
552
+ )
553
+ hidden_states = attn_output + hidden_states
554
+
555
+ # 4. Feed-forward
556
+ # i2vgen doesn't have this norm 🤷‍♂️
557
+ if self.norm_type == "ada_norm_continuous":
558
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
559
+ elif not self.norm_type == "ada_norm_single":
560
+ norm_hidden_states = self.norm3(hidden_states)
561
+
562
+ if self.norm_type == "ada_norm_zero":
563
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
564
+
565
+ if self.norm_type == "ada_norm_single":
566
+ norm_hidden_states = self.norm2(hidden_states)
567
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
568
+
569
+ if self._chunk_size is not None:
570
+ # "feed_forward_chunk_size" can be used to save memory
571
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
572
+ else:
573
+ ff_output = self.ff(norm_hidden_states)
574
+
575
+ if self.norm_type == "ada_norm_zero":
576
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
577
+ elif self.norm_type == "ada_norm_single":
578
+ ff_output = gate_mlp * ff_output
579
+
580
+ hidden_states = ff_output + hidden_states
581
+ if hidden_states.ndim == 4:
582
+ hidden_states = hidden_states.squeeze(1)
583
+
584
+ return hidden_states
585
+
586
+
587
+ class LuminaFeedForward(nn.Module):
588
+ r"""
589
+ A feed-forward layer.
590
+
591
+ Parameters:
592
+ hidden_size (`int`):
593
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
594
+ hidden representations.
595
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
596
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
597
+ of this value.
598
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
599
+ dimension. Defaults to None.
600
+ """
601
+
602
+ def __init__(
603
+ self,
604
+ dim: int,
605
+ inner_dim: int,
606
+ multiple_of: Optional[int] = 256,
607
+ ffn_dim_multiplier: Optional[float] = None,
608
+ ):
609
+ super().__init__()
610
+ inner_dim = int(2 * inner_dim / 3)
611
+ # custom hidden_size factor multiplier
612
+ if ffn_dim_multiplier is not None:
613
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
614
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
615
+
616
+ self.linear_1 = nn.Linear(
617
+ dim,
618
+ inner_dim,
619
+ bias=False,
620
+ )
621
+ self.linear_2 = nn.Linear(
622
+ inner_dim,
623
+ dim,
624
+ bias=False,
625
+ )
626
+ self.linear_3 = nn.Linear(
627
+ dim,
628
+ inner_dim,
629
+ bias=False,
630
+ )
631
+ self.silu = FP32SiLU()
632
+
633
+ def forward(self, x):
634
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
635
+
636
+
637
+ @maybe_allow_in_graph
638
+ class TemporalBasicTransformerBlock(nn.Module):
639
+ r"""
640
+ A basic Transformer block for video like data.
641
+
642
+ Parameters:
643
+ dim (`int`): The number of channels in the input and output.
644
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
645
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
646
+ attention_head_dim (`int`): The number of channels in each head.
647
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
648
+ """
649
+
650
+ def __init__(
651
+ self,
652
+ dim: int,
653
+ time_mix_inner_dim: int,
654
+ num_attention_heads: int,
655
+ attention_head_dim: int,
656
+ cross_attention_dim: Optional[int] = None,
657
+ ):
658
+ super().__init__()
659
+ self.is_res = dim == time_mix_inner_dim
660
+
661
+ self.norm_in = nn.LayerNorm(dim)
662
+
663
+ # Define 3 blocks. Each block has its own normalization layer.
664
+ # 1. Self-Attn
665
+ self.ff_in = FeedForward(
666
+ dim,
667
+ dim_out=time_mix_inner_dim,
668
+ activation_fn="geglu",
669
+ )
670
+
671
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
672
+ self.attn1 = Attention(
673
+ query_dim=time_mix_inner_dim,
674
+ heads=num_attention_heads,
675
+ dim_head=attention_head_dim,
676
+ cross_attention_dim=None,
677
+ )
678
+
679
+ # 2. Cross-Attn
680
+ if cross_attention_dim is not None:
681
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
682
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
683
+ # the second cross attention block.
684
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
685
+ self.attn2 = Attention(
686
+ query_dim=time_mix_inner_dim,
687
+ cross_attention_dim=cross_attention_dim,
688
+ heads=num_attention_heads,
689
+ dim_head=attention_head_dim,
690
+ ) # is self-attn if encoder_hidden_states is none
691
+ else:
692
+ self.norm2 = None
693
+ self.attn2 = None
694
+
695
+ # 3. Feed-forward
696
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
697
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
698
+
699
+ # let chunk size default to None
700
+ self._chunk_size = None
701
+ self._chunk_dim = None
702
+
703
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
704
+ # Sets chunk feed-forward
705
+ self._chunk_size = chunk_size
706
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
707
+ self._chunk_dim = 1
708
+
709
+ def forward(
710
+ self,
711
+ hidden_states: torch.Tensor,
712
+ num_frames: int,
713
+ encoder_hidden_states: Optional[torch.Tensor] = None,
714
+ ) -> torch.Tensor:
715
+ # Notice that normalization is always applied before the real computation in the following blocks.
716
+ # 0. Self-Attention
717
+ batch_size = hidden_states.shape[0]
718
+
719
+ batch_frames, seq_length, channels = hidden_states.shape
720
+ batch_size = batch_frames // num_frames
721
+
722
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
723
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
724
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
725
+
726
+ residual = hidden_states
727
+ hidden_states = self.norm_in(hidden_states)
728
+
729
+ if self._chunk_size is not None:
730
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
731
+ else:
732
+ hidden_states = self.ff_in(hidden_states)
733
+
734
+ if self.is_res:
735
+ hidden_states = hidden_states + residual
736
+
737
+ norm_hidden_states = self.norm1(hidden_states)
738
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
739
+ hidden_states = attn_output + hidden_states
740
+
741
+ # 3. Cross-Attention
742
+ if self.attn2 is not None:
743
+ norm_hidden_states = self.norm2(hidden_states)
744
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
745
+ hidden_states = attn_output + hidden_states
746
+
747
+ # 4. Feed-forward
748
+ norm_hidden_states = self.norm3(hidden_states)
749
+
750
+ if self._chunk_size is not None:
751
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
752
+ else:
753
+ ff_output = self.ff(norm_hidden_states)
754
+
755
+ if self.is_res:
756
+ hidden_states = ff_output + hidden_states
757
+ else:
758
+ hidden_states = ff_output
759
+
760
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
761
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
762
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
763
+
764
+ return hidden_states
765
+
766
+
767
+ class SkipFFTransformerBlock(nn.Module):
768
+ def __init__(
769
+ self,
770
+ dim: int,
771
+ num_attention_heads: int,
772
+ attention_head_dim: int,
773
+ kv_input_dim: int,
774
+ kv_input_dim_proj_use_bias: bool,
775
+ dropout=0.0,
776
+ cross_attention_dim: Optional[int] = None,
777
+ attention_bias: bool = False,
778
+ attention_out_bias: bool = True,
779
+ ):
780
+ super().__init__()
781
+ if kv_input_dim != dim:
782
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
783
+ else:
784
+ self.kv_mapper = None
785
+
786
+ self.norm1 = RMSNorm(dim, 1e-06)
787
+
788
+ self.attn1 = Attention(
789
+ query_dim=dim,
790
+ heads=num_attention_heads,
791
+ dim_head=attention_head_dim,
792
+ dropout=dropout,
793
+ bias=attention_bias,
794
+ cross_attention_dim=cross_attention_dim,
795
+ out_bias=attention_out_bias,
796
+ )
797
+
798
+ self.norm2 = RMSNorm(dim, 1e-06)
799
+
800
+ self.attn2 = Attention(
801
+ query_dim=dim,
802
+ cross_attention_dim=cross_attention_dim,
803
+ heads=num_attention_heads,
804
+ dim_head=attention_head_dim,
805
+ dropout=dropout,
806
+ bias=attention_bias,
807
+ out_bias=attention_out_bias,
808
+ )
809
+
810
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
811
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
812
+
813
+ if self.kv_mapper is not None:
814
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
815
+
816
+ norm_hidden_states = self.norm1(hidden_states)
817
+
818
+ attn_output = self.attn1(
819
+ norm_hidden_states,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ **cross_attention_kwargs,
822
+ )
823
+
824
+ hidden_states = attn_output + hidden_states
825
+
826
+ norm_hidden_states = self.norm2(hidden_states)
827
+
828
+ attn_output = self.attn2(
829
+ norm_hidden_states,
830
+ encoder_hidden_states=encoder_hidden_states,
831
+ **cross_attention_kwargs,
832
+ )
833
+
834
+ hidden_states = attn_output + hidden_states
835
+
836
+ return hidden_states
837
+
838
+
839
+ @maybe_allow_in_graph
840
+ class FreeNoiseTransformerBlock(nn.Module):
841
+ r"""
842
+ A FreeNoise Transformer block.
843
+
844
+ Parameters:
845
+ dim (`int`):
846
+ The number of channels in the input and output.
847
+ num_attention_heads (`int`):
848
+ The number of heads to use for multi-head attention.
849
+ attention_head_dim (`int`):
850
+ The number of channels in each head.
851
+ dropout (`float`, *optional*, defaults to 0.0):
852
+ The dropout probability to use.
853
+ cross_attention_dim (`int`, *optional*):
854
+ The size of the encoder_hidden_states vector for cross attention.
855
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
856
+ Activation function to be used in feed-forward.
857
+ num_embeds_ada_norm (`int`, *optional*):
858
+ The number of diffusion steps used during training. See `Transformer2DModel`.
859
+ attention_bias (`bool`, defaults to `False`):
860
+ Configure if the attentions should contain a bias parameter.
861
+ only_cross_attention (`bool`, defaults to `False`):
862
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
863
+ double_self_attention (`bool`, defaults to `False`):
864
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
865
+ upcast_attention (`bool`, defaults to `False`):
866
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
867
+ norm_elementwise_affine (`bool`, defaults to `True`):
868
+ Whether to use learnable elementwise affine parameters for normalization.
869
+ norm_type (`str`, defaults to `"layer_norm"`):
870
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
871
+ final_dropout (`bool` defaults to `False`):
872
+ Whether to apply a final dropout after the last feed-forward layer.
873
+ attention_type (`str`, defaults to `"default"`):
874
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
875
+ positional_embeddings (`str`, *optional*):
876
+ The type of positional embeddings to apply to.
877
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
878
+ The maximum number of positional embeddings to apply.
879
+ ff_inner_dim (`int`, *optional*):
880
+ Hidden dimension of feed-forward MLP.
881
+ ff_bias (`bool`, defaults to `True`):
882
+ Whether or not to use bias in feed-forward MLP.
883
+ attention_out_bias (`bool`, defaults to `True`):
884
+ Whether or not to use bias in attention output project layer.
885
+ context_length (`int`, defaults to `16`):
886
+ The maximum number of frames that the FreeNoise block processes at once.
887
+ context_stride (`int`, defaults to `4`):
888
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
889
+ weighting_scheme (`str`, defaults to `"pyramid"`):
890
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
891
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
892
+ used.
893
+ """
894
+
895
+ def __init__(
896
+ self,
897
+ dim: int,
898
+ num_attention_heads: int,
899
+ attention_head_dim: int,
900
+ dropout: float = 0.0,
901
+ cross_attention_dim: Optional[int] = None,
902
+ activation_fn: str = "geglu",
903
+ num_embeds_ada_norm: Optional[int] = None,
904
+ attention_bias: bool = False,
905
+ only_cross_attention: bool = False,
906
+ double_self_attention: bool = False,
907
+ upcast_attention: bool = False,
908
+ norm_elementwise_affine: bool = True,
909
+ norm_type: str = "layer_norm",
910
+ norm_eps: float = 1e-5,
911
+ final_dropout: bool = False,
912
+ positional_embeddings: Optional[str] = None,
913
+ num_positional_embeddings: Optional[int] = None,
914
+ ff_inner_dim: Optional[int] = None,
915
+ ff_bias: bool = True,
916
+ attention_out_bias: bool = True,
917
+ context_length: int = 16,
918
+ context_stride: int = 4,
919
+ weighting_scheme: str = "pyramid",
920
+ ):
921
+ super().__init__()
922
+ self.dim = dim
923
+ self.num_attention_heads = num_attention_heads
924
+ self.attention_head_dim = attention_head_dim
925
+ self.dropout = dropout
926
+ self.cross_attention_dim = cross_attention_dim
927
+ self.activation_fn = activation_fn
928
+ self.attention_bias = attention_bias
929
+ self.double_self_attention = double_self_attention
930
+ self.norm_elementwise_affine = norm_elementwise_affine
931
+ self.positional_embeddings = positional_embeddings
932
+ self.num_positional_embeddings = num_positional_embeddings
933
+ self.only_cross_attention = only_cross_attention
934
+
935
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
936
+
937
+ # We keep these boolean flags for backward-compatibility.
938
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
939
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
940
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
941
+ self.use_layer_norm = norm_type == "layer_norm"
942
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
943
+
944
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
945
+ raise ValueError(
946
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
947
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
948
+ )
949
+
950
+ self.norm_type = norm_type
951
+ self.num_embeds_ada_norm = num_embeds_ada_norm
952
+
953
+ if positional_embeddings and (num_positional_embeddings is None):
954
+ raise ValueError(
955
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
956
+ )
957
+
958
+ if positional_embeddings == "sinusoidal":
959
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
960
+ else:
961
+ self.pos_embed = None
962
+
963
+ # Define 3 blocks. Each block has its own normalization layer.
964
+ # 1. Self-Attn
965
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
966
+
967
+ self.attn1 = Attention(
968
+ query_dim=dim,
969
+ heads=num_attention_heads,
970
+ dim_head=attention_head_dim,
971
+ dropout=dropout,
972
+ bias=attention_bias,
973
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
974
+ upcast_attention=upcast_attention,
975
+ out_bias=attention_out_bias,
976
+ )
977
+
978
+ # 2. Cross-Attn
979
+ if cross_attention_dim is not None or double_self_attention:
980
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
981
+
982
+ self.attn2 = Attention(
983
+ query_dim=dim,
984
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
985
+ heads=num_attention_heads,
986
+ dim_head=attention_head_dim,
987
+ dropout=dropout,
988
+ bias=attention_bias,
989
+ upcast_attention=upcast_attention,
990
+ out_bias=attention_out_bias,
991
+ ) # is self-attn if encoder_hidden_states is none
992
+
993
+ # 3. Feed-forward
994
+ self.ff = FeedForward(
995
+ dim,
996
+ dropout=dropout,
997
+ activation_fn=activation_fn,
998
+ final_dropout=final_dropout,
999
+ inner_dim=ff_inner_dim,
1000
+ bias=ff_bias,
1001
+ )
1002
+
1003
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1004
+
1005
+ # let chunk size default to None
1006
+ self._chunk_size = None
1007
+ self._chunk_dim = 0
1008
+
1009
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1010
+ frame_indices = []
1011
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1012
+ window_start = i
1013
+ window_end = min(num_frames, i + self.context_length)
1014
+ frame_indices.append((window_start, window_end))
1015
+ return frame_indices
1016
+
1017
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1018
+ if weighting_scheme == "flat":
1019
+ weights = [1.0] * num_frames
1020
+
1021
+ elif weighting_scheme == "pyramid":
1022
+ if num_frames % 2 == 0:
1023
+ # num_frames = 4 => [1, 2, 2, 1]
1024
+ mid = num_frames // 2
1025
+ weights = list(range(1, mid + 1))
1026
+ weights = weights + weights[::-1]
1027
+ else:
1028
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1029
+ mid = (num_frames + 1) // 2
1030
+ weights = list(range(1, mid))
1031
+ weights = weights + [mid] + weights[::-1]
1032
+
1033
+ elif weighting_scheme == "delayed_reverse_sawtooth":
1034
+ if num_frames % 2 == 0:
1035
+ # num_frames = 4 => [0.01, 2, 2, 1]
1036
+ mid = num_frames // 2
1037
+ weights = [0.01] * (mid - 1) + [mid]
1038
+ weights = weights + list(range(mid, 0, -1))
1039
+ else:
1040
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1041
+ mid = (num_frames + 1) // 2
1042
+ weights = [0.01] * mid
1043
+ weights = weights + list(range(mid, 0, -1))
1044
+ else:
1045
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1046
+
1047
+ return weights
1048
+
1049
+ def set_free_noise_properties(
1050
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1051
+ ) -> None:
1052
+ self.context_length = context_length
1053
+ self.context_stride = context_stride
1054
+ self.weighting_scheme = weighting_scheme
1055
+
1056
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1057
+ # Sets chunk feed-forward
1058
+ self._chunk_size = chunk_size
1059
+ self._chunk_dim = dim
1060
+
1061
+ def forward(
1062
+ self,
1063
+ hidden_states: torch.Tensor,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1066
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1067
+ cross_attention_kwargs: Dict[str, Any] = None,
1068
+ *args,
1069
+ **kwargs,
1070
+ ) -> torch.Tensor:
1071
+ if cross_attention_kwargs is not None:
1072
+ if cross_attention_kwargs.get("scale", None) is not None:
1073
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1074
+
1075
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1076
+
1077
+ # hidden_states: [B x H x W, F, C]
1078
+ device = hidden_states.device
1079
+ dtype = hidden_states.dtype
1080
+
1081
+ num_frames = hidden_states.size(1)
1082
+ frame_indices = self._get_frame_indices(num_frames)
1083
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1084
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1085
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1086
+
1087
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1088
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1089
+ # [(0, 16), (4, 20), (8, 24), (10, 26)]
1090
+ if not is_last_frame_batch_complete:
1091
+ if num_frames < self.context_length:
1092
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1093
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1094
+ frame_indices.append((num_frames - self.context_length, num_frames))
1095
+
1096
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1097
+ accumulated_values = torch.zeros_like(hidden_states)
1098
+
1099
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1100
+ # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
1101
+ # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
1102
+ # essentially a non-multiple of `context_length`.
1103
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1104
+ weights *= frame_weights
1105
+
1106
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1107
+
1108
+ # Notice that normalization is always applied before the real computation in the following blocks.
1109
+ # 1. Self-Attention
1110
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1111
+
1112
+ if self.pos_embed is not None:
1113
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1114
+
1115
+ attn_output = self.attn1(
1116
+ norm_hidden_states,
1117
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1118
+ attention_mask=attention_mask,
1119
+ **cross_attention_kwargs,
1120
+ )
1121
+
1122
+ hidden_states_chunk = attn_output + hidden_states_chunk
1123
+ if hidden_states_chunk.ndim == 4:
1124
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1125
+
1126
+ # 2. Cross-Attention
1127
+ if self.attn2 is not None:
1128
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1129
+
1130
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1131
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
+
1133
+ attn_output = self.attn2(
1134
+ norm_hidden_states,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ attention_mask=encoder_attention_mask,
1137
+ **cross_attention_kwargs,
1138
+ )
1139
+ hidden_states_chunk = attn_output + hidden_states_chunk
1140
+
1141
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1142
+ accumulated_values[:, -last_frame_batch_length:] += (
1143
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1144
+ )
1145
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
1146
+ else:
1147
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1148
+ num_times_accumulated[:, frame_start:frame_end] += weights
1149
+
1150
+ # TODO(aryan): Maybe this could be done in a better way.
1151
+ #
1152
+ # Previously, this was:
1153
+ # hidden_states = torch.where(
1154
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1155
+ # )
1156
+ #
1157
+ # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1158
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1159
+ # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
1160
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1161
+ hidden_states = torch.cat(
1162
+ [
1163
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1164
+ for accumulated_split, num_times_split in zip(
1165
+ accumulated_values.split(self.context_length, dim=1),
1166
+ num_times_accumulated.split(self.context_length, dim=1),
1167
+ )
1168
+ ],
1169
+ dim=1,
1170
+ ).to(dtype)
1171
+
1172
+ # 3. Feed-forward
1173
+ norm_hidden_states = self.norm3(hidden_states)
1174
+
1175
+ if self._chunk_size is not None:
1176
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1177
+ else:
1178
+ ff_output = self.ff(norm_hidden_states)
1179
+
1180
+ hidden_states = ff_output + hidden_states
1181
+ if hidden_states.ndim == 4:
1182
+ hidden_states = hidden_states.squeeze(1)
1183
+
1184
+ return hidden_states
1185
+
1186
+
1187
+ class FeedForward(nn.Module):
1188
+ r"""
1189
+ A feed-forward layer.
1190
+
1191
+ Parameters:
1192
+ dim (`int`): The number of channels in the input.
1193
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1194
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1195
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1196
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1197
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1198
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1199
+ """
1200
+
1201
+ def __init__(
1202
+ self,
1203
+ dim: int,
1204
+ dim_out: Optional[int] = None,
1205
+ mult: int = 4,
1206
+ dropout: float = 0.0,
1207
+ activation_fn: str = "geglu",
1208
+ final_dropout: bool = False,
1209
+ inner_dim=None,
1210
+ bias: bool = True,
1211
+ ):
1212
+ super().__init__()
1213
+ if inner_dim is None:
1214
+ inner_dim = int(dim * mult)
1215
+ dim_out = dim_out if dim_out is not None else dim
1216
+
1217
+ if activation_fn == "gelu":
1218
+ act_fn = GELU(dim, inner_dim, bias=bias)
1219
+ if activation_fn == "gelu-approximate":
1220
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1221
+ elif activation_fn == "geglu":
1222
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1223
+ elif activation_fn == "geglu-approximate":
1224
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1225
+ elif activation_fn == "swiglu":
1226
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1227
+
1228
+ self.net = nn.ModuleList([])
1229
+ # project in
1230
+ self.net.append(act_fn)
1231
+ # project dropout
1232
+ self.net.append(nn.Dropout(dropout))
1233
+ # project out
1234
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1235
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1236
+ if final_dropout:
1237
+ self.net.append(nn.Dropout(dropout))
1238
+
1239
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1240
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1241
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1242
+ deprecate("scale", "1.0.0", deprecation_message)
1243
+ for module in self.net:
1244
+ hidden_states = module(hidden_states)
1245
+ return hidden_states
models/resampler.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
8
+
9
+ def get_timestep_embedding(
10
+ timesteps: torch.Tensor,
11
+ embedding_dim: int,
12
+ flip_sin_to_cos: bool = False,
13
+ downscale_freq_shift: float = 1,
14
+ scale: float = 1,
15
+ max_period: int = 10000,
16
+ ):
17
+ """
18
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
+
20
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
21
+ These may be fractional.
22
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
23
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
24
+ """
25
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
26
+
27
+ half_dim = embedding_dim // 2
28
+ exponent = -math.log(max_period) * torch.arange(
29
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
30
+ )
31
+ exponent = exponent / (half_dim - downscale_freq_shift)
32
+
33
+ emb = torch.exp(exponent)
34
+ emb = timesteps[:, None].float() * emb[None, :]
35
+
36
+ # scale embeddings
37
+ emb = scale * emb
38
+
39
+ # concat sine and cosine embeddings
40
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
41
+
42
+ # flip sine and cosine embeddings
43
+ if flip_sin_to_cos:
44
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
45
+
46
+ # zero pad
47
+ if embedding_dim % 2 == 1:
48
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
49
+ return emb
50
+
51
+
52
+ # FFN
53
+ def FeedForward(dim, mult=4):
54
+ inner_dim = int(dim * mult)
55
+ return nn.Sequential(
56
+ nn.LayerNorm(dim),
57
+ nn.Linear(dim, inner_dim, bias=False),
58
+ nn.GELU(),
59
+ nn.Linear(inner_dim, dim, bias=False),
60
+ )
61
+
62
+
63
+ def reshape_tensor(x, heads):
64
+ bs, length, width = x.shape
65
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
66
+ x = x.view(bs, length, heads, -1)
67
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
68
+ x = x.transpose(1, 2)
69
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
70
+ x = x.reshape(bs, heads, length, -1)
71
+ return x
72
+
73
+
74
+ class PerceiverAttention(nn.Module):
75
+ def __init__(self, *, dim, dim_head=64, heads=8):
76
+ super().__init__()
77
+ self.scale = dim_head**-0.5
78
+ self.dim_head = dim_head
79
+ self.heads = heads
80
+ inner_dim = dim_head * heads
81
+
82
+ self.norm1 = nn.LayerNorm(dim)
83
+ self.norm2 = nn.LayerNorm(dim)
84
+
85
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
86
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
87
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
88
+
89
+
90
+ def forward(self, x, latents, shift=None, scale=None):
91
+ """
92
+ Args:
93
+ x (torch.Tensor): image features
94
+ shape (b, n1, D)
95
+ latent (torch.Tensor): latent features
96
+ shape (b, n2, D)
97
+ """
98
+ x = self.norm1(x)
99
+ latents = self.norm2(latents)
100
+
101
+ if shift is not None and scale is not None:
102
+ latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
103
+
104
+ b, l, _ = latents.shape
105
+
106
+ q = self.to_q(latents)
107
+ kv_input = torch.cat((x, latents), dim=-2)
108
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
109
+
110
+ q = reshape_tensor(q, self.heads)
111
+ k = reshape_tensor(k, self.heads)
112
+ v = reshape_tensor(v, self.heads)
113
+
114
+ # attention
115
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
116
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
117
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
118
+ out = weight @ v
119
+
120
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
121
+
122
+ return self.to_out(out)
123
+
124
+
125
+ class Resampler(nn.Module):
126
+ def __init__(
127
+ self,
128
+ dim=1024,
129
+ depth=8,
130
+ dim_head=64,
131
+ heads=16,
132
+ num_queries=8,
133
+ embedding_dim=768,
134
+ output_dim=1024,
135
+ ff_mult=4,
136
+ *args,
137
+ **kwargs,
138
+ ):
139
+ super().__init__()
140
+
141
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
142
+
143
+ self.proj_in = nn.Linear(embedding_dim, dim)
144
+
145
+ self.proj_out = nn.Linear(dim, output_dim)
146
+ self.norm_out = nn.LayerNorm(output_dim)
147
+
148
+ self.layers = nn.ModuleList([])
149
+ for _ in range(depth):
150
+ self.layers.append(
151
+ nn.ModuleList(
152
+ [
153
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
154
+ FeedForward(dim=dim, mult=ff_mult),
155
+ ]
156
+ )
157
+ )
158
+
159
+ def forward(self, x):
160
+
161
+ latents = self.latents.repeat(x.size(0), 1, 1)
162
+
163
+ x = self.proj_in(x)
164
+
165
+ for attn, ff in self.layers:
166
+ latents = attn(x, latents) + latents
167
+ latents = ff(latents) + latents
168
+
169
+ latents = self.proj_out(latents)
170
+ return self.norm_out(latents)
171
+
172
+
173
+ class TimeResampler(nn.Module):
174
+ def __init__(
175
+ self,
176
+ dim=1024,
177
+ depth=8,
178
+ dim_head=64,
179
+ heads=16,
180
+ num_queries=8,
181
+ embedding_dim=768,
182
+ output_dim=1024,
183
+ ff_mult=4,
184
+ timestep_in_dim=320,
185
+ timestep_flip_sin_to_cos=True,
186
+ timestep_freq_shift=0,
187
+ ):
188
+ super().__init__()
189
+
190
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
191
+
192
+ self.proj_in = nn.Linear(embedding_dim, dim)
193
+
194
+ self.proj_out = nn.Linear(dim, output_dim)
195
+ self.norm_out = nn.LayerNorm(output_dim)
196
+
197
+ self.layers = nn.ModuleList([])
198
+ for _ in range(depth):
199
+ self.layers.append(
200
+ nn.ModuleList(
201
+ [
202
+ # msa
203
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
204
+ # ff
205
+ FeedForward(dim=dim, mult=ff_mult),
206
+ # adaLN
207
+ nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
208
+ ]
209
+ )
210
+ )
211
+
212
+ # time
213
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
214
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
215
+
216
+ # adaLN
217
+ # self.adaLN_modulation = nn.Sequential(
218
+ # nn.SiLU(),
219
+ # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
220
+ # )
221
+
222
+
223
+ def forward(self, x, timestep, need_temb=False):
224
+ timestep_emb = self.embedding_time(x, timestep) # bs, dim
225
+
226
+ latents = self.latents.repeat(x.size(0), 1, 1)
227
+
228
+ x = self.proj_in(x)
229
+ x = x + timestep_emb[:, None]
230
+
231
+ for attn, ff, adaLN_modulation in self.layers:
232
+ shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
233
+ latents = attn(x, latents, shift_msa, scale_msa) + latents
234
+
235
+ res = latents
236
+ for idx_ff in range(len(ff)):
237
+ layer_ff = ff[idx_ff]
238
+ latents = layer_ff(latents)
239
+ if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
240
+ latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
241
+ latents = latents + res
242
+
243
+ # latents = ff(latents) + latents
244
+
245
+ latents = self.proj_out(latents)
246
+ latents = self.norm_out(latents)
247
+
248
+ if need_temb:
249
+ return latents, timestep_emb
250
+ else:
251
+ return latents
252
+
253
+
254
+
255
+ def embedding_time(self, sample, timestep):
256
+
257
+ # 1. time
258
+ timesteps = timestep
259
+ if not torch.is_tensor(timesteps):
260
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
261
+ # This would be a good case for the `match` statement (Python 3.10+)
262
+ is_mps = sample.device.type == "mps"
263
+ if isinstance(timestep, float):
264
+ dtype = torch.float32 if is_mps else torch.float64
265
+ else:
266
+ dtype = torch.int32 if is_mps else torch.int64
267
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
268
+ elif len(timesteps.shape) == 0:
269
+ timesteps = timesteps[None].to(sample.device)
270
+
271
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
272
+ timesteps = timesteps.expand(sample.shape[0])
273
+
274
+ t_emb = self.time_proj(timesteps)
275
+
276
+ # timesteps does not contain any weights and will always return f32 tensors
277
+ # but time_embedding might actually be running in fp16. so we need to cast here.
278
+ # there might be better ways to encapsulate this.
279
+ t_emb = t_emb.to(dtype=sample.dtype)
280
+
281
+ emb = self.time_embedding(t_emb, None)
282
+ return emb
283
+
284
+
285
+
286
+
287
+
288
+ if __name__ == '__main__':
289
+ model = TimeResampler(
290
+ dim=1280,
291
+ depth=4,
292
+ dim_head=64,
293
+ heads=20,
294
+ num_queries=16,
295
+ embedding_dim=512,
296
+ output_dim=2048,
297
+ ff_mult=4,
298
+ timestep_in_dim=320,
299
+ timestep_flip_sin_to_cos=True,
300
+ timestep_freq_shift=0,
301
+ in_channel_extra_emb=2048,
302
+ )
303
+
304
+
models/transformer_sd3.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from .attention import JointTransformerBlock
24
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormContinuous
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
36
+ """
37
+ The Transformer model introduced in Stable Diffusion 3.
38
+
39
+ Reference: https://arxiv.org/abs/2403.03206
40
+
41
+ Parameters:
42
+ sample_size (`int`): The width of the latent images. This is fixed during training since
43
+ it is used to learn a number of position embeddings.
44
+ patch_size (`int`): Patch size to turn the input data into small patches.
45
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
46
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
47
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
48
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
49
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
51
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
52
+ out_channels (`int`, defaults to 16): Number of output channels.
53
+
54
+ """
55
+
56
+ _supports_gradient_checkpointing = True
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ sample_size: int = 128,
62
+ patch_size: int = 2,
63
+ in_channels: int = 16,
64
+ num_layers: int = 18,
65
+ attention_head_dim: int = 64,
66
+ num_attention_heads: int = 18,
67
+ joint_attention_dim: int = 4096,
68
+ caption_projection_dim: int = 1152,
69
+ pooled_projection_dim: int = 2048,
70
+ out_channels: int = 16,
71
+ pos_embed_max_size: int = 96,
72
+ dual_attention_layers: Tuple[
73
+ int, ...
74
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
75
+ qk_norm: Optional[str] = None,
76
+ ):
77
+ super().__init__()
78
+ default_out_channels = in_channels
79
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
80
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
81
+
82
+ self.pos_embed = PatchEmbed(
83
+ height=self.config.sample_size,
84
+ width=self.config.sample_size,
85
+ patch_size=self.config.patch_size,
86
+ in_channels=self.config.in_channels,
87
+ embed_dim=self.inner_dim,
88
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
89
+ )
90
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
91
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
92
+ )
93
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
94
+
95
+ # `attention_head_dim` is doubled to account for the mixing.
96
+ # It needs to crafted when we get the actual checkpoints.
97
+ self.transformer_blocks = nn.ModuleList(
98
+ [
99
+ JointTransformerBlock(
100
+ dim=self.inner_dim,
101
+ num_attention_heads=self.config.num_attention_heads,
102
+ attention_head_dim=self.config.attention_head_dim,
103
+ context_pre_only=i == num_layers - 1,
104
+ qk_norm=qk_norm,
105
+ use_dual_attention=True if i in dual_attention_layers else False,
106
+ )
107
+ for i in range(self.config.num_layers)
108
+ ]
109
+ )
110
+
111
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
112
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
113
+
114
+ self.gradient_checkpointing = False
115
+
116
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
117
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
118
+ """
119
+ Sets the attention processor to use [feed forward
120
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
121
+
122
+ Parameters:
123
+ chunk_size (`int`, *optional*):
124
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
125
+ over each tensor of dim=`dim`.
126
+ dim (`int`, *optional*, defaults to `0`):
127
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
128
+ or dim=1 (sequence length).
129
+ """
130
+ if dim not in [0, 1]:
131
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
132
+
133
+ # By default chunk size is 1
134
+ chunk_size = chunk_size or 1
135
+
136
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
137
+ if hasattr(module, "set_chunk_feed_forward"):
138
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
139
+
140
+ for child in module.children():
141
+ fn_recursive_feed_forward(child, chunk_size, dim)
142
+
143
+ for module in self.children():
144
+ fn_recursive_feed_forward(module, chunk_size, dim)
145
+
146
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
147
+ def disable_forward_chunking(self):
148
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
149
+ if hasattr(module, "set_chunk_feed_forward"):
150
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
151
+
152
+ for child in module.children():
153
+ fn_recursive_feed_forward(child, chunk_size, dim)
154
+
155
+ for module in self.children():
156
+ fn_recursive_feed_forward(module, None, 0)
157
+
158
+ @property
159
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
160
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
+ r"""
162
+ Returns:
163
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
164
+ indexed by its weight name.
165
+ """
166
+ # set recursively
167
+ processors = {}
168
+
169
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
+ if hasattr(module, "get_processor"):
171
+ processors[f"{name}.processor"] = module.get_processor()
172
+
173
+ for sub_name, child in module.named_children():
174
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
+
176
+ return processors
177
+
178
+ for name, module in self.named_children():
179
+ fn_recursive_add_processors(name, module, processors)
180
+
181
+ return processors
182
+
183
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
185
+ r"""
186
+ Sets the attention processor to use to compute attention.
187
+
188
+ Parameters:
189
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
190
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
191
+ for **all** `Attention` layers.
192
+
193
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
194
+ processor. This is strongly recommended when setting trainable attention processors.
195
+
196
+ """
197
+ count = len(self.attn_processors.keys())
198
+
199
+ if isinstance(processor, dict) and len(processor) != count:
200
+ raise ValueError(
201
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
202
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
203
+ )
204
+
205
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
206
+ if hasattr(module, "set_processor"):
207
+ if not isinstance(processor, dict):
208
+ module.set_processor(processor)
209
+ else:
210
+ module.set_processor(processor.pop(f"{name}.processor"))
211
+
212
+ for sub_name, child in module.named_children():
213
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
214
+
215
+ for name, module in self.named_children():
216
+ fn_recursive_attn_processor(name, module, processor)
217
+
218
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
219
+ def fuse_qkv_projections(self):
220
+ """
221
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
222
+ are fused. For cross-attention modules, key and value projection matrices are fused.
223
+
224
+ <Tip warning={true}>
225
+
226
+ This API is 🧪 experimental.
227
+
228
+ </Tip>
229
+ """
230
+ self.original_attn_processors = None
231
+
232
+ for _, attn_processor in self.attn_processors.items():
233
+ if "Added" in str(attn_processor.__class__.__name__):
234
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
235
+
236
+ self.original_attn_processors = self.attn_processors
237
+
238
+ for module in self.modules():
239
+ if isinstance(module, Attention):
240
+ module.fuse_projections(fuse=True)
241
+
242
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
243
+
244
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
245
+ def unfuse_qkv_projections(self):
246
+ """Disables the fused QKV projection if enabled.
247
+
248
+ <Tip warning={true}>
249
+
250
+ This API is 🧪 experimental.
251
+
252
+ </Tip>
253
+
254
+ """
255
+ if self.original_attn_processors is not None:
256
+ self.set_attn_processor(self.original_attn_processors)
257
+
258
+ def _set_gradient_checkpointing(self, module, value=False):
259
+ if hasattr(module, "gradient_checkpointing"):
260
+ module.gradient_checkpointing = value
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.FloatTensor,
265
+ encoder_hidden_states: torch.FloatTensor = None,
266
+ pooled_projections: torch.FloatTensor = None,
267
+ timestep: torch.LongTensor = None,
268
+ block_controlnet_hidden_states: List = None,
269
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270
+ return_dict: bool = True,
271
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272
+ """
273
+ The [`SD3Transformer2DModel`] forward method.
274
+
275
+ Args:
276
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
277
+ Input `hidden_states`.
278
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
279
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281
+ from the embeddings of input conditions.
282
+ timestep ( `torch.LongTensor`):
283
+ Used to indicate denoising step.
284
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285
+ A list of tensors that if specified are added to the residuals of transformer blocks.
286
+ joint_attention_kwargs (`dict`, *optional*):
287
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
288
+ `self.processor` in
289
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
290
+ return_dict (`bool`, *optional*, defaults to `True`):
291
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292
+ tuple.
293
+
294
+ Returns:
295
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
+ `tuple` where the first element is the sample tensor.
297
+ """
298
+ if joint_attention_kwargs is not None:
299
+ joint_attention_kwargs = joint_attention_kwargs.copy()
300
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
301
+ else:
302
+ lora_scale = 1.0
303
+
304
+ if USE_PEFT_BACKEND:
305
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
306
+ scale_lora_layers(self, lora_scale)
307
+ else:
308
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
309
+ logger.warning(
310
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
311
+ )
312
+
313
+ height, width = hidden_states.shape[-2:]
314
+
315
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
316
+ temb = self.time_text_embed(timestep, pooled_projections)
317
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318
+
319
+ for index_block, block in enumerate(self.transformer_blocks):
320
+ if self.training and self.gradient_checkpointing:
321
+
322
+ def create_custom_forward(module, return_dict=None):
323
+ def custom_forward(*inputs):
324
+ if return_dict is not None:
325
+ return module(*inputs, return_dict=return_dict)
326
+ else:
327
+ return module(*inputs)
328
+
329
+ return custom_forward
330
+
331
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
332
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
333
+ create_custom_forward(block),
334
+ hidden_states,
335
+ encoder_hidden_states,
336
+ temb,
337
+ joint_attention_kwargs,
338
+ **ckpt_kwargs,
339
+ )
340
+
341
+ else:
342
+ encoder_hidden_states, hidden_states = block(
343
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
344
+ joint_attention_kwargs=joint_attention_kwargs,
345
+ )
346
+
347
+ # controlnet residual
348
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
349
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
350
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
351
+
352
+ hidden_states = self.norm_out(hidden_states, temb)
353
+ hidden_states = self.proj_out(hidden_states)
354
+
355
+ # unpatchify
356
+ patch_size = self.config.patch_size
357
+ height = height // patch_size
358
+ width = width // patch_size
359
+
360
+ hidden_states = hidden_states.reshape(
361
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
362
+ )
363
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
+ output = hidden_states.reshape(
365
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
366
+ )
367
+
368
+ if USE_PEFT_BACKEND:
369
+ # remove `lora_scale` from each PEFT layer
370
+ unscale_lora_layers(self, lora_scale)
371
+
372
+ if not return_dict:
373
+ return (output,)
374
+
375
+ return Transformer2DModelOutput(sample=output)
pipeline_stable_diffusion_3_ipa.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import (
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
+ from diffusers.utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import randn_tensor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
43
+
44
+ from models.resampler import TimeResampler
45
+ from models.transformer_sd3 import SD3Transformer2DModel
46
+ from diffusers.models.normalization import RMSNorm
47
+ from einops import rearrange
48
+
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ EXAMPLE_DOC_STRING = """
61
+ Examples:
62
+ ```py
63
+ >>> import torch
64
+ >>> from diffusers import StableDiffusion3Pipeline
65
+
66
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
67
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
68
+ ... )
69
+ >>> pipe.to("cuda")
70
+ >>> prompt = "A cat holding a sign that says hello world"
71
+ >>> image = pipe(prompt).images[0]
72
+ >>> image.save("sd3.png")
73
+ ```
74
+ """
75
+
76
+
77
+ class AdaLayerNorm(nn.Module):
78
+ """
79
+ Norm layer adaptive layer norm zero (adaLN-Zero).
80
+
81
+ Parameters:
82
+ embedding_dim (`int`): The size of each embedding vector.
83
+ num_embeddings (`int`): The size of the embeddings dictionary.
84
+ """
85
+
86
+ def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
87
+ super().__init__()
88
+
89
+ self.silu = nn.SiLU()
90
+ num_params_dict = dict(
91
+ zero=6,
92
+ normal=2,
93
+ )
94
+ num_params = num_params_dict[mode]
95
+ self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
96
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
97
+ self.mode = mode
98
+
99
+ def forward(
100
+ self,
101
+ x,
102
+ hidden_dtype = None,
103
+ emb = None,
104
+ ):
105
+ emb = self.linear(self.silu(emb))
106
+ if self.mode == 'normal':
107
+ shift_msa, scale_msa = emb.chunk(2, dim=1)
108
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
109
+ return x
110
+
111
+ elif self.mode == 'zero':
112
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
113
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
114
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
115
+
116
+
117
+ class JointIPAttnProcessor(torch.nn.Module):
118
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
119
+
120
+ def __init__(
121
+ self,
122
+ hidden_size=None,
123
+ cross_attention_dim=None,
124
+ ip_hidden_states_dim=None,
125
+ ip_encoder_hidden_states_dim=None,
126
+ head_dim=None,
127
+ timesteps_emb_dim=1280,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim)
132
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
133
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
134
+ self.norm_q = RMSNorm(head_dim, 1e-6)
135
+ self.norm_k = RMSNorm(head_dim, 1e-6)
136
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
137
+
138
+
139
+ def __call__(
140
+ self,
141
+ attn,
142
+ hidden_states: torch.FloatTensor,
143
+ encoder_hidden_states: torch.FloatTensor = None,
144
+ attention_mask: Optional[torch.FloatTensor] = None,
145
+ emb_dict=None,
146
+ *args,
147
+ **kwargs,
148
+ ) -> torch.FloatTensor:
149
+ residual = hidden_states
150
+
151
+ batch_size = hidden_states.shape[0]
152
+
153
+ # `sample` projections.
154
+ query = attn.to_q(hidden_states)
155
+ key = attn.to_k(hidden_states)
156
+ value = attn.to_v(hidden_states)
157
+ img_query = query
158
+ img_key = key
159
+ img_value = value
160
+
161
+ inner_dim = key.shape[-1]
162
+ head_dim = inner_dim // attn.heads
163
+
164
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
165
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
166
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
167
+
168
+ if attn.norm_q is not None:
169
+ query = attn.norm_q(query)
170
+ if attn.norm_k is not None:
171
+ key = attn.norm_k(key)
172
+
173
+ # `context` projections.
174
+ if encoder_hidden_states is not None:
175
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
176
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
177
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
178
+
179
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
180
+ batch_size, -1, attn.heads, head_dim
181
+ ).transpose(1, 2)
182
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
183
+ batch_size, -1, attn.heads, head_dim
184
+ ).transpose(1, 2)
185
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
186
+ batch_size, -1, attn.heads, head_dim
187
+ ).transpose(1, 2)
188
+
189
+ if attn.norm_added_q is not None:
190
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
191
+ if attn.norm_added_k is not None:
192
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
193
+
194
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
195
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
196
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
197
+
198
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
199
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
200
+ hidden_states = hidden_states.to(query.dtype)
201
+
202
+ if encoder_hidden_states is not None:
203
+ # Split the attention outputs.
204
+ hidden_states, encoder_hidden_states = (
205
+ hidden_states[:, : residual.shape[1]],
206
+ hidden_states[:, residual.shape[1] :],
207
+ )
208
+ if not attn.context_pre_only:
209
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
210
+
211
+
212
+ # IPadapter
213
+ ip_hidden_states = emb_dict.get('ip_hidden_states', None)
214
+ ip_hidden_states = self.get_ip_hidden_states(
215
+ attn,
216
+ img_query,
217
+ ip_hidden_states,
218
+ img_key,
219
+ img_value,
220
+ None,
221
+ None,
222
+ emb_dict['temb'],
223
+ )
224
+ if ip_hidden_states is not None:
225
+ hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
226
+
227
+
228
+ # linear proj
229
+ hidden_states = attn.to_out[0](hidden_states)
230
+ # dropout
231
+ hidden_states = attn.to_out[1](hidden_states)
232
+
233
+ if encoder_hidden_states is not None:
234
+ return hidden_states, encoder_hidden_states
235
+ else:
236
+ return hidden_states
237
+
238
+
239
+ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
240
+ if ip_hidden_states is None:
241
+ return None
242
+
243
+ if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
244
+ return None
245
+
246
+ # norm ip input
247
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb)
248
+
249
+ # to k and v
250
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
251
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
252
+
253
+ # reshape
254
+ query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
255
+ img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
256
+ img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
257
+ ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
258
+ ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
259
+
260
+ # norm
261
+ query = self.norm_q(query)
262
+ img_key = self.norm_k(img_key)
263
+ ip_key = self.norm_ip_k(ip_key)
264
+
265
+ # cat img
266
+ key = torch.cat([img_key, ip_key], dim=2)
267
+ value = torch.cat([img_value, ip_value], dim=2)
268
+
269
+ #
270
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
271
+ ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
272
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
273
+ return ip_hidden_states
274
+
275
+
276
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
277
+ def retrieve_timesteps(
278
+ scheduler,
279
+ num_inference_steps: Optional[int] = None,
280
+ device: Optional[Union[str, torch.device]] = None,
281
+ timesteps: Optional[List[int]] = None,
282
+ sigmas: Optional[List[float]] = None,
283
+ **kwargs,
284
+ ):
285
+ """
286
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
287
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
288
+
289
+ Args:
290
+ scheduler (`SchedulerMixin`):
291
+ The scheduler to get timesteps from.
292
+ num_inference_steps (`int`):
293
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
294
+ must be `None`.
295
+ device (`str` or `torch.device`, *optional*):
296
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
297
+ timesteps (`List[int]`, *optional*):
298
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
299
+ `num_inference_steps` and `sigmas` must be `None`.
300
+ sigmas (`List[float]`, *optional*):
301
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
302
+ `num_inference_steps` and `timesteps` must be `None`.
303
+
304
+ Returns:
305
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
306
+ second element is the number of inference steps.
307
+ """
308
+ if timesteps is not None and sigmas is not None:
309
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
310
+ if timesteps is not None:
311
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
312
+ if not accepts_timesteps:
313
+ raise ValueError(
314
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
315
+ f" timestep schedules. Please check whether you are using the correct scheduler."
316
+ )
317
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
318
+ timesteps = scheduler.timesteps
319
+ num_inference_steps = len(timesteps)
320
+ elif sigmas is not None:
321
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
322
+ if not accept_sigmas:
323
+ raise ValueError(
324
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
325
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
326
+ )
327
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
328
+ timesteps = scheduler.timesteps
329
+ num_inference_steps = len(timesteps)
330
+ else:
331
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
332
+ timesteps = scheduler.timesteps
333
+ return timesteps, num_inference_steps
334
+
335
+
336
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
337
+ r"""
338
+ Args:
339
+ transformer ([`SD3Transformer2DModel`]):
340
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
341
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
342
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
343
+ vae ([`AutoencoderKL`]):
344
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
345
+ text_encoder ([`CLIPTextModelWithProjection`]):
346
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
347
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
348
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
349
+ as its dimension.
350
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
351
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
352
+ specifically the
353
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
354
+ variant.
355
+ text_encoder_3 ([`T5EncoderModel`]):
356
+ Frozen text-encoder. Stable Diffusion 3 uses
357
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
358
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
359
+ tokenizer (`CLIPTokenizer`):
360
+ Tokenizer of class
361
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
362
+ tokenizer_2 (`CLIPTokenizer`):
363
+ Second Tokenizer of class
364
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
365
+ tokenizer_3 (`T5TokenizerFast`):
366
+ Tokenizer of class
367
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
368
+ """
369
+
370
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
371
+ _optional_components = []
372
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
373
+
374
+ def __init__(
375
+ self,
376
+ transformer: SD3Transformer2DModel,
377
+ scheduler: FlowMatchEulerDiscreteScheduler,
378
+ vae: AutoencoderKL,
379
+ text_encoder: CLIPTextModelWithProjection,
380
+ tokenizer: CLIPTokenizer,
381
+ text_encoder_2: CLIPTextModelWithProjection,
382
+ tokenizer_2: CLIPTokenizer,
383
+ text_encoder_3: T5EncoderModel,
384
+ tokenizer_3: T5TokenizerFast,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.register_modules(
389
+ vae=vae,
390
+ text_encoder=text_encoder,
391
+ text_encoder_2=text_encoder_2,
392
+ text_encoder_3=text_encoder_3,
393
+ tokenizer=tokenizer,
394
+ tokenizer_2=tokenizer_2,
395
+ tokenizer_3=tokenizer_3,
396
+ transformer=transformer,
397
+ scheduler=scheduler,
398
+ )
399
+ self.vae_scale_factor = (
400
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
401
+ )
402
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
403
+ self.tokenizer_max_length = (
404
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
405
+ )
406
+ self.default_sample_size = (
407
+ self.transformer.config.sample_size
408
+ if hasattr(self, "transformer") and self.transformer is not None
409
+ else 128
410
+ )
411
+
412
+ def _get_t5_prompt_embeds(
413
+ self,
414
+ prompt: Union[str, List[str]] = None,
415
+ num_images_per_prompt: int = 1,
416
+ max_sequence_length: int = 256,
417
+ device: Optional[torch.device] = None,
418
+ dtype: Optional[torch.dtype] = None,
419
+ ):
420
+ device = device or self._execution_device
421
+ dtype = dtype or self.text_encoder.dtype
422
+
423
+ prompt = [prompt] if isinstance(prompt, str) else prompt
424
+ batch_size = len(prompt)
425
+
426
+ if self.text_encoder_3 is None:
427
+ return torch.zeros(
428
+ (
429
+ batch_size * num_images_per_prompt,
430
+ self.tokenizer_max_length,
431
+ self.transformer.config.joint_attention_dim,
432
+ ),
433
+ device=device,
434
+ dtype=dtype,
435
+ )
436
+
437
+ text_inputs = self.tokenizer_3(
438
+ prompt,
439
+ padding="max_length",
440
+ max_length=max_sequence_length,
441
+ truncation=True,
442
+ add_special_tokens=True,
443
+ return_tensors="pt",
444
+ )
445
+ text_input_ids = text_inputs.input_ids
446
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
447
+
448
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
449
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
450
+ logger.warning(
451
+ "The following part of your input was truncated because `max_sequence_length` is set to "
452
+ f" {max_sequence_length} tokens: {removed_text}"
453
+ )
454
+
455
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
456
+
457
+ dtype = self.text_encoder_3.dtype
458
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
459
+
460
+ _, seq_len, _ = prompt_embeds.shape
461
+
462
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
463
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
464
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
465
+
466
+ return prompt_embeds
467
+
468
+ def _get_clip_prompt_embeds(
469
+ self,
470
+ prompt: Union[str, List[str]],
471
+ num_images_per_prompt: int = 1,
472
+ device: Optional[torch.device] = None,
473
+ clip_skip: Optional[int] = None,
474
+ clip_model_index: int = 0,
475
+ ):
476
+ device = device or self._execution_device
477
+
478
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
479
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
480
+
481
+ tokenizer = clip_tokenizers[clip_model_index]
482
+ text_encoder = clip_text_encoders[clip_model_index]
483
+
484
+ prompt = [prompt] if isinstance(prompt, str) else prompt
485
+ batch_size = len(prompt)
486
+
487
+ text_inputs = tokenizer(
488
+ prompt,
489
+ padding="max_length",
490
+ max_length=self.tokenizer_max_length,
491
+ truncation=True,
492
+ return_tensors="pt",
493
+ )
494
+
495
+ text_input_ids = text_inputs.input_ids
496
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
497
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
498
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
499
+ logger.warning(
500
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
501
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
502
+ )
503
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
504
+ pooled_prompt_embeds = prompt_embeds[0]
505
+
506
+ if clip_skip is None:
507
+ prompt_embeds = prompt_embeds.hidden_states[-2]
508
+ else:
509
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
510
+
511
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
512
+
513
+ _, seq_len, _ = prompt_embeds.shape
514
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
515
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
516
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
517
+
518
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
519
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
520
+
521
+ return prompt_embeds, pooled_prompt_embeds
522
+
523
+ def encode_prompt(
524
+ self,
525
+ prompt: Union[str, List[str]],
526
+ prompt_2: Union[str, List[str]],
527
+ prompt_3: Union[str, List[str]],
528
+ device: Optional[torch.device] = None,
529
+ num_images_per_prompt: int = 1,
530
+ do_classifier_free_guidance: bool = True,
531
+ negative_prompt: Optional[Union[str, List[str]]] = None,
532
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
533
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
534
+ prompt_embeds: Optional[torch.FloatTensor] = None,
535
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
536
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
537
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
538
+ clip_skip: Optional[int] = None,
539
+ max_sequence_length: int = 256,
540
+ lora_scale: Optional[float] = None,
541
+ ):
542
+ r"""
543
+
544
+ Args:
545
+ prompt (`str` or `List[str]`, *optional*):
546
+ prompt to be encoded
547
+ prompt_2 (`str` or `List[str]`, *optional*):
548
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
549
+ used in all text-encoders
550
+ prompt_3 (`str` or `List[str]`, *optional*):
551
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
552
+ used in all text-encoders
553
+ device: (`torch.device`):
554
+ torch device
555
+ num_images_per_prompt (`int`):
556
+ number of images that should be generated per prompt
557
+ do_classifier_free_guidance (`bool`):
558
+ whether to use classifier free guidance or not
559
+ negative_prompt (`str` or `List[str]`, *optional*):
560
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
561
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
562
+ less than `1`).
563
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
564
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
565
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
566
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
567
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
568
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
569
+ prompt_embeds (`torch.FloatTensor`, *optional*):
570
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
571
+ provided, text embeddings will be generated from `prompt` input argument.
572
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
573
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
574
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
575
+ argument.
576
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
577
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
578
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
579
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
580
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
581
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
582
+ input argument.
583
+ clip_skip (`int`, *optional*):
584
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
585
+ the output of the pre-final layer will be used for computing the prompt embeddings.
586
+ lora_scale (`float`, *optional*):
587
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
588
+ """
589
+ device = device or self._execution_device
590
+
591
+ # set lora scale so that monkey patched LoRA
592
+ # function of text encoder can correctly access it
593
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
594
+ self._lora_scale = lora_scale
595
+
596
+ # dynamically adjust the LoRA scale
597
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
598
+ scale_lora_layers(self.text_encoder, lora_scale)
599
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
600
+ scale_lora_layers(self.text_encoder_2, lora_scale)
601
+
602
+ prompt = [prompt] if isinstance(prompt, str) else prompt
603
+ if prompt is not None:
604
+ batch_size = len(prompt)
605
+ else:
606
+ batch_size = prompt_embeds.shape[0]
607
+
608
+ if prompt_embeds is None:
609
+ prompt_2 = prompt_2 or prompt
610
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
611
+
612
+ prompt_3 = prompt_3 or prompt
613
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
614
+
615
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
616
+ prompt=prompt,
617
+ device=device,
618
+ num_images_per_prompt=num_images_per_prompt,
619
+ clip_skip=clip_skip,
620
+ clip_model_index=0,
621
+ )
622
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
623
+ prompt=prompt_2,
624
+ device=device,
625
+ num_images_per_prompt=num_images_per_prompt,
626
+ clip_skip=clip_skip,
627
+ clip_model_index=1,
628
+ )
629
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
630
+
631
+ t5_prompt_embed = self._get_t5_prompt_embeds(
632
+ prompt=prompt_3,
633
+ num_images_per_prompt=num_images_per_prompt,
634
+ max_sequence_length=max_sequence_length,
635
+ device=device,
636
+ )
637
+
638
+ clip_prompt_embeds = torch.nn.functional.pad(
639
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
640
+ )
641
+
642
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
643
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
644
+
645
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
646
+ negative_prompt = negative_prompt or ""
647
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
648
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
649
+
650
+ # normalize str to list
651
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
652
+ negative_prompt_2 = (
653
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
654
+ )
655
+ negative_prompt_3 = (
656
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
657
+ )
658
+
659
+ if prompt is not None and type(prompt) is not type(negative_prompt):
660
+ raise TypeError(
661
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
662
+ f" {type(prompt)}."
663
+ )
664
+ elif batch_size != len(negative_prompt):
665
+ raise ValueError(
666
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
667
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
668
+ " the batch size of `prompt`."
669
+ )
670
+
671
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
672
+ negative_prompt,
673
+ device=device,
674
+ num_images_per_prompt=num_images_per_prompt,
675
+ clip_skip=None,
676
+ clip_model_index=0,
677
+ )
678
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
679
+ negative_prompt_2,
680
+ device=device,
681
+ num_images_per_prompt=num_images_per_prompt,
682
+ clip_skip=None,
683
+ clip_model_index=1,
684
+ )
685
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
686
+
687
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
688
+ prompt=negative_prompt_3,
689
+ num_images_per_prompt=num_images_per_prompt,
690
+ max_sequence_length=max_sequence_length,
691
+ device=device,
692
+ )
693
+
694
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
695
+ negative_clip_prompt_embeds,
696
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
697
+ )
698
+
699
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
700
+ negative_pooled_prompt_embeds = torch.cat(
701
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
702
+ )
703
+
704
+ if self.text_encoder is not None:
705
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
706
+ # Retrieve the original scale by scaling back the LoRA layers
707
+ unscale_lora_layers(self.text_encoder, lora_scale)
708
+
709
+ if self.text_encoder_2 is not None:
710
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
711
+ # Retrieve the original scale by scaling back the LoRA layers
712
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
713
+
714
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
715
+
716
+ def check_inputs(
717
+ self,
718
+ prompt,
719
+ prompt_2,
720
+ prompt_3,
721
+ height,
722
+ width,
723
+ negative_prompt=None,
724
+ negative_prompt_2=None,
725
+ negative_prompt_3=None,
726
+ prompt_embeds=None,
727
+ negative_prompt_embeds=None,
728
+ pooled_prompt_embeds=None,
729
+ negative_pooled_prompt_embeds=None,
730
+ callback_on_step_end_tensor_inputs=None,
731
+ max_sequence_length=None,
732
+ ):
733
+ if height % 8 != 0 or width % 8 != 0:
734
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
735
+
736
+ if callback_on_step_end_tensor_inputs is not None and not all(
737
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
738
+ ):
739
+ raise ValueError(
740
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
741
+ )
742
+
743
+ if prompt is not None and prompt_embeds is not None:
744
+ raise ValueError(
745
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
746
+ " only forward one of the two."
747
+ )
748
+ elif prompt_2 is not None and prompt_embeds is not None:
749
+ raise ValueError(
750
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
751
+ " only forward one of the two."
752
+ )
753
+ elif prompt_3 is not None and prompt_embeds is not None:
754
+ raise ValueError(
755
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
756
+ " only forward one of the two."
757
+ )
758
+ elif prompt is None and prompt_embeds is None:
759
+ raise ValueError(
760
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
761
+ )
762
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
763
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
764
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
765
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
766
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
767
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
768
+
769
+ if negative_prompt is not None and negative_prompt_embeds is not None:
770
+ raise ValueError(
771
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
772
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
773
+ )
774
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
775
+ raise ValueError(
776
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
777
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
778
+ )
779
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
780
+ raise ValueError(
781
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
782
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
783
+ )
784
+
785
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
786
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
787
+ raise ValueError(
788
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
789
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
790
+ f" {negative_prompt_embeds.shape}."
791
+ )
792
+
793
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
794
+ raise ValueError(
795
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
796
+ )
797
+
798
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
799
+ raise ValueError(
800
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
801
+ )
802
+
803
+ if max_sequence_length is not None and max_sequence_length > 512:
804
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
805
+
806
+ def prepare_latents(
807
+ self,
808
+ batch_size,
809
+ num_channels_latents,
810
+ height,
811
+ width,
812
+ dtype,
813
+ device,
814
+ generator,
815
+ latents=None,
816
+ ):
817
+ if latents is not None:
818
+ return latents.to(device=device, dtype=dtype)
819
+
820
+ shape = (
821
+ batch_size,
822
+ num_channels_latents,
823
+ int(height) // self.vae_scale_factor,
824
+ int(width) // self.vae_scale_factor,
825
+ )
826
+
827
+ if isinstance(generator, list) and len(generator) != batch_size:
828
+ raise ValueError(
829
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
830
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
831
+ )
832
+
833
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
834
+
835
+ return latents
836
+
837
+ @property
838
+ def guidance_scale(self):
839
+ return self._guidance_scale
840
+
841
+ @property
842
+ def clip_skip(self):
843
+ return self._clip_skip
844
+
845
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
846
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
847
+ # corresponds to doing no classifier free guidance.
848
+ @property
849
+ def do_classifier_free_guidance(self):
850
+ return self._guidance_scale > 1
851
+
852
+ @property
853
+ def joint_attention_kwargs(self):
854
+ return self._joint_attention_kwargs
855
+
856
+ @property
857
+ def num_timesteps(self):
858
+ return self._num_timesteps
859
+
860
+ @property
861
+ def interrupt(self):
862
+ return self._interrupt
863
+
864
+
865
+ @torch.inference_mode()
866
+ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
867
+ from transformers import SiglipVisionModel, SiglipImageProcessor
868
+ state_dict = torch.load(ip_adapter_path, map_location="cpu")
869
+
870
+ device, dtype = self.transformer.device, self.transformer.dtype
871
+ image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
872
+ image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path)
873
+ image_encoder.eval()
874
+ image_encoder.to(device, dtype=dtype)
875
+ self.image_encoder = image_encoder
876
+ self.clip_image_processor = image_processor
877
+
878
+ sample_class = TimeResampler
879
+ image_proj_model = sample_class(
880
+ dim=1280,
881
+ depth=4,
882
+ dim_head=64,
883
+ heads=20,
884
+ num_queries=nb_token,
885
+ embedding_dim=1152,
886
+ output_dim=output_dim,
887
+ ff_mult=4,
888
+ timestep_in_dim=320,
889
+ timestep_flip_sin_to_cos=True,
890
+ timestep_freq_shift=0,
891
+ )
892
+ image_proj_model.eval()
893
+ image_proj_model.to(device, dtype=dtype)
894
+ key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
895
+ print(f"=> loading image_proj_model: {key_name}")
896
+
897
+ self.image_proj_model = image_proj_model
898
+
899
+
900
+ attn_procs = {}
901
+ transformer = self.transformer
902
+ for idx_name, name in enumerate(transformer.attn_processors.keys()):
903
+ hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
904
+ ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
905
+ ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
906
+
907
+ attn_procs[name] = JointIPAttnProcessor(
908
+ hidden_size=hidden_size,
909
+ cross_attention_dim=transformer.config.caption_projection_dim,
910
+ ip_hidden_states_dim=ip_hidden_states_dim,
911
+ ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim,
912
+ head_dim=transformer.config.attention_head_dim,
913
+ timesteps_emb_dim=1280,
914
+ ).to(device, dtype=dtype)
915
+
916
+ self.transformer.set_attn_processor(attn_procs)
917
+ tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values())
918
+
919
+ key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
920
+ print(f"=> loading ip_adapter: {key_name}")
921
+
922
+
923
+ @torch.inference_mode()
924
+ def encode_clip_image_emb(self, clip_image, device, dtype):
925
+
926
+ # clip
927
+ clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
928
+ clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
929
+ clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
930
+ clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
931
+
932
+ return clip_image_embeds
933
+
934
+
935
+
936
+ @torch.no_grad()
937
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
938
+ def __call__(
939
+ self,
940
+ prompt: Union[str, List[str]] = None,
941
+ prompt_2: Optional[Union[str, List[str]]] = None,
942
+ prompt_3: Optional[Union[str, List[str]]] = None,
943
+ height: Optional[int] = None,
944
+ width: Optional[int] = None,
945
+ num_inference_steps: int = 28,
946
+ timesteps: List[int] = None,
947
+ guidance_scale: float = 7.0,
948
+ negative_prompt: Optional[Union[str, List[str]]] = None,
949
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
950
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
951
+ num_images_per_prompt: Optional[int] = 1,
952
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
953
+ latents: Optional[torch.FloatTensor] = None,
954
+ prompt_embeds: Optional[torch.FloatTensor] = None,
955
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
956
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
957
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
958
+ output_type: Optional[str] = "pil",
959
+ return_dict: bool = True,
960
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
961
+ clip_skip: Optional[int] = None,
962
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
963
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
964
+ max_sequence_length: int = 256,
965
+
966
+ # ipa
967
+ clip_image=None,
968
+ ipadapter_scale=1.0,
969
+ ):
970
+ r"""
971
+ Function invoked when calling the pipeline for generation.
972
+
973
+ Args:
974
+ prompt (`str` or `List[str]`, *optional*):
975
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
976
+ instead.
977
+ prompt_2 (`str` or `List[str]`, *optional*):
978
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
979
+ will be used instead
980
+ prompt_3 (`str` or `List[str]`, *optional*):
981
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
982
+ will be used instead
983
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
984
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
985
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
986
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
987
+ num_inference_steps (`int`, *optional*, defaults to 50):
988
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
989
+ expense of slower inference.
990
+ timesteps (`List[int]`, *optional*):
991
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
992
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
993
+ passed will be used. Must be in descending order.
994
+ guidance_scale (`float`, *optional*, defaults to 7.0):
995
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
996
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
997
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
998
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
999
+ usually at the expense of lower image quality.
1000
+ negative_prompt (`str` or `List[str]`, *optional*):
1001
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1002
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1003
+ less than `1`).
1004
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1005
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1006
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
1007
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
1008
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
1009
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
1010
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1011
+ The number of images to generate per prompt.
1012
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1013
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1014
+ to make generation deterministic.
1015
+ latents (`torch.FloatTensor`, *optional*):
1016
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1017
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1018
+ tensor will ge generated by sampling using the supplied random `generator`.
1019
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1020
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1021
+ provided, text embeddings will be generated from `prompt` input argument.
1022
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1023
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1024
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1025
+ argument.
1026
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1027
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1028
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1029
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1030
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1031
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1032
+ input argument.
1033
+ output_type (`str`, *optional*, defaults to `"pil"`):
1034
+ The output format of the generate image. Choose between
1035
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1036
+ return_dict (`bool`, *optional*, defaults to `True`):
1037
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1038
+ of a plain tuple.
1039
+ joint_attention_kwargs (`dict`, *optional*):
1040
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1041
+ `self.processor` in
1042
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1043
+ callback_on_step_end (`Callable`, *optional*):
1044
+ A function that calls at the end of each denoising steps during the inference. The function is called
1045
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1046
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1047
+ `callback_on_step_end_tensor_inputs`.
1048
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1049
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1050
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1051
+ `._callback_tensor_inputs` attribute of your pipeline class.
1052
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
1053
+
1054
+ Examples:
1055
+
1056
+ Returns:
1057
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
1058
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
1059
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1060
+ """
1061
+
1062
+ height = height or self.default_sample_size * self.vae_scale_factor
1063
+ width = width or self.default_sample_size * self.vae_scale_factor
1064
+
1065
+ # 1. Check inputs. Raise error if not correct
1066
+ self.check_inputs(
1067
+ prompt,
1068
+ prompt_2,
1069
+ prompt_3,
1070
+ height,
1071
+ width,
1072
+ negative_prompt=negative_prompt,
1073
+ negative_prompt_2=negative_prompt_2,
1074
+ negative_prompt_3=negative_prompt_3,
1075
+ prompt_embeds=prompt_embeds,
1076
+ negative_prompt_embeds=negative_prompt_embeds,
1077
+ pooled_prompt_embeds=pooled_prompt_embeds,
1078
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1079
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1080
+ max_sequence_length=max_sequence_length,
1081
+ )
1082
+
1083
+ self._guidance_scale = guidance_scale
1084
+ self._clip_skip = clip_skip
1085
+ self._joint_attention_kwargs = joint_attention_kwargs
1086
+ self._interrupt = False
1087
+
1088
+ # 2. Define call parameters
1089
+ if prompt is not None and isinstance(prompt, str):
1090
+ batch_size = 1
1091
+ elif prompt is not None and isinstance(prompt, list):
1092
+ batch_size = len(prompt)
1093
+ else:
1094
+ batch_size = prompt_embeds.shape[0]
1095
+
1096
+ device = self._execution_device
1097
+ dtype = self.transformer.dtype
1098
+
1099
+ lora_scale = (
1100
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1101
+ )
1102
+ (
1103
+ prompt_embeds,
1104
+ negative_prompt_embeds,
1105
+ pooled_prompt_embeds,
1106
+ negative_pooled_prompt_embeds,
1107
+ ) = self.encode_prompt(
1108
+ prompt=prompt,
1109
+ prompt_2=prompt_2,
1110
+ prompt_3=prompt_3,
1111
+ negative_prompt=negative_prompt,
1112
+ negative_prompt_2=negative_prompt_2,
1113
+ negative_prompt_3=negative_prompt_3,
1114
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1115
+ prompt_embeds=prompt_embeds,
1116
+ negative_prompt_embeds=negative_prompt_embeds,
1117
+ pooled_prompt_embeds=pooled_prompt_embeds,
1118
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1119
+ device=device,
1120
+ clip_skip=self.clip_skip,
1121
+ num_images_per_prompt=num_images_per_prompt,
1122
+ max_sequence_length=max_sequence_length,
1123
+ lora_scale=lora_scale,
1124
+ )
1125
+
1126
+ if self.do_classifier_free_guidance:
1127
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1128
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1129
+
1130
+ # 3. prepare clip emb
1131
+ clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1132
+ clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
1133
+
1134
+ # 4. Prepare timesteps
1135
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1136
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1137
+ self._num_timesteps = len(timesteps)
1138
+
1139
+ # 5. Prepare latent variables
1140
+ num_channels_latents = self.transformer.config.in_channels
1141
+ latents = self.prepare_latents(
1142
+ batch_size * num_images_per_prompt,
1143
+ num_channels_latents,
1144
+ height,
1145
+ width,
1146
+ prompt_embeds.dtype,
1147
+ device,
1148
+ generator,
1149
+ latents,
1150
+ )
1151
+
1152
+ # 6. Denoising loop
1153
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1154
+ for i, t in enumerate(timesteps):
1155
+ if self.interrupt:
1156
+ continue
1157
+
1158
+ # expand the latents if we are doing classifier free guidance
1159
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1160
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1161
+ timestep = t.expand(latent_model_input.shape[0])
1162
+
1163
+ image_prompt_embeds, timestep_emb = self.image_proj_model(
1164
+ clip_image_embeds,
1165
+ timestep.to(dtype=latents.dtype),
1166
+ need_temb=True
1167
+ )
1168
+
1169
+ joint_attention_kwargs = dict(
1170
+ emb_dict=dict(
1171
+ ip_hidden_states=image_prompt_embeds,
1172
+ temb=timestep_emb,
1173
+ scale=ipadapter_scale,
1174
+ )
1175
+ )
1176
+
1177
+ noise_pred = self.transformer(
1178
+ hidden_states=latent_model_input,
1179
+ timestep=timestep,
1180
+ encoder_hidden_states=prompt_embeds,
1181
+ pooled_projections=pooled_prompt_embeds,
1182
+ joint_attention_kwargs=joint_attention_kwargs,
1183
+ return_dict=False,
1184
+ )[0]
1185
+
1186
+ # perform guidance
1187
+ if self.do_classifier_free_guidance:
1188
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1189
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1190
+
1191
+ # compute the previous noisy sample x_t -> x_t-1
1192
+ latents_dtype = latents.dtype
1193
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1194
+
1195
+ if latents.dtype != latents_dtype:
1196
+ if torch.backends.mps.is_available():
1197
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1198
+ latents = latents.to(latents_dtype)
1199
+
1200
+ if callback_on_step_end is not None:
1201
+ callback_kwargs = {}
1202
+ for k in callback_on_step_end_tensor_inputs:
1203
+ callback_kwargs[k] = locals()[k]
1204
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1205
+
1206
+ latents = callback_outputs.pop("latents", latents)
1207
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1208
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1209
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1210
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1211
+ )
1212
+
1213
+ # call the callback, if provided
1214
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1215
+ progress_bar.update()
1216
+
1217
+ if XLA_AVAILABLE:
1218
+ xm.mark_step()
1219
+
1220
+ if output_type == "latent":
1221
+ image = latents
1222
+
1223
+ else:
1224
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1225
+
1226
+ image = self.vae.decode(latents, return_dict=False)[0]
1227
+ image = self.image_processor.postprocess(image, output_type=output_type)
1228
+
1229
+ # Offload all models
1230
+ self.maybe_free_model_hooks()
1231
+
1232
+ if not return_dict:
1233
+ return (image,)
1234
+
1235
+ return StableDiffusion3PipelineOutput(images=image)
teasers/0.png ADDED

Git LFS Details

  • SHA256: 6325e12735c57a61449fc94330d6e1e744977994bedff1fe6a2f37588d0a448e
  • Pointer size: 132 Bytes
  • Size of remote file: 5.2 MB
teasers/1.png ADDED

Git LFS Details

  • SHA256: 6bdca1eae51d34f587bea5cc218e861f1c88678c9f69d12ba1931fbcc567e9db
  • Pointer size: 132 Bytes
  • Size of remote file: 5.32 MB