wooyeolbaek committed
Commit c71e173 · verified · 1 Parent(s): 5cb0966

Create modules.py

Files changed (1):
  1. modules.py +1757 -0
modules.py ADDED
@@ -0,0 +1,1757 @@
1
+ import math
2
+ import inspect
3
+ import numpy as np
4
+ from typing import Any, Dict, Optional, Tuple, Union, List, Callable
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ from diffusers.models.attention import _chunked_feed_forward
11
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
12
+ from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
13
+ from diffusers.pipelines.flux.pipeline_flux import (
14
+ retrieve_timesteps,
15
+ replace_example_docstring,
16
+ EXAMPLE_DOC_STRING,
17
+ calculate_shift,
18
+ XLA_AVAILABLE,
19
+ FluxPipelineOutput
20
+ )
21
+ # from diffusers.models.transformers import FLUXTransformer2DModel
22
+ from diffusers.utils import (
23
+ deprecate,
24
+ BaseOutput,
25
+ is_torch_version,
26
+ logging,
27
+ USE_PEFT_BACKEND,
28
+ scale_lora_layers,
29
+ unscale_lora_layers,
30
+ )
31
+ from diffusers.models.attention_processor import (
32
+ Attention,
33
+ AttnProcessor,
34
+ AttnProcessor2_0,
35
+ )
+ 
+ # `xm.mark_step()` in the denoising loop needs torch_xla; import it only when available.
+ if XLA_AVAILABLE:
+ import torch_xla.core.xla_model as xm
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
+ attn_maps = {}
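+ # Global store for attention maps. The patched attention processors used alongside
+ # this module are expected to write into this dict during sampling (e.g. keyed by
+ # denoising timestep and layer name) so the maps can be inspected after generation.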
42
+
43
+
44
+ @torch.no_grad()
45
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
46
+ def FluxPipeline_call(
47
+ self,
48
+ prompt: Union[str, List[str]] = None,
49
+ prompt_2: Optional[Union[str, List[str]]] = None,
50
+ height: Optional[int] = None,
51
+ width: Optional[int] = None,
52
+ num_inference_steps: int = 28,
53
+ timesteps: List[int] = None,
54
+ guidance_scale: float = 3.5,
55
+ num_images_per_prompt: Optional[int] = 1,
56
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
57
+ latents: Optional[torch.FloatTensor] = None,
58
+ prompt_embeds: Optional[torch.FloatTensor] = None,
59
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
60
+ output_type: Optional[str] = "pil",
61
+ return_dict: bool = True,
62
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
63
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
64
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
65
+ max_sequence_length: int = 512,
66
+ ):
67
+ r"""
68
+ Function invoked when calling the pipeline for generation.
69
+
70
+ Args:
71
+ prompt (`str` or `List[str]`, *optional*):
72
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
73
+ instead.
74
+ prompt_2 (`str` or `List[str]`, *optional*):
75
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
76
+ will be used instead
77
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
78
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
79
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
80
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
81
+ num_inference_steps (`int`, *optional*, defaults to 28):
82
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
83
+ expense of slower inference.
84
+ timesteps (`List[int]`, *optional*):
85
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
86
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
87
+ passed will be used. Must be in descending order.
88
+ guidance_scale (`float`, *optional*, defaults to 3.5):
89
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
90
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
91
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
92
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
93
+ usually at the expense of lower image quality.
94
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
95
+ The number of images to generate per prompt.
96
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
97
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
98
+ to make generation deterministic.
99
+ latents (`torch.FloatTensor`, *optional*):
100
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
101
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
102
+ tensor will be generated by sampling using the supplied random `generator`.
103
+ prompt_embeds (`torch.FloatTensor`, *optional*):
104
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
105
+ provided, text embeddings will be generated from `prompt` input argument.
106
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
107
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
108
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
109
+ output_type (`str`, *optional*, defaults to `"pil"`):
110
+ The output format of the generated image. Choose between
111
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
112
+ return_dict (`bool`, *optional*, defaults to `True`):
113
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
114
+ joint_attention_kwargs (`dict`, *optional*):
115
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
116
+ `self.processor` in
117
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
118
+ callback_on_step_end (`Callable`, *optional*):
119
+ A function that is called at the end of each denoising step during inference. The function is called
120
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
121
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
122
+ `callback_on_step_end_tensor_inputs`.
123
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
124
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
125
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
126
+ `._callback_tensor_inputs` attribute of your pipeline class.
127
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
128
+
129
+ Examples:
130
+
131
+ Returns:
132
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
133
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
134
+ images.
135
+ """
136
+
137
+ height = height or self.default_sample_size * self.vae_scale_factor
138
+ width = width or self.default_sample_size * self.vae_scale_factor
139
+
140
+ # 1. Check inputs. Raise error if not correct
141
+ self.check_inputs(
142
+ prompt,
143
+ prompt_2,
144
+ height,
145
+ width,
146
+ prompt_embeds=prompt_embeds,
147
+ pooled_prompt_embeds=pooled_prompt_embeds,
148
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
149
+ max_sequence_length=max_sequence_length,
150
+ )
151
+
152
+ self._guidance_scale = guidance_scale
153
+ self._joint_attention_kwargs = joint_attention_kwargs
154
+ self._interrupt = False
155
+
156
+ # 2. Define call parameters
157
+ if prompt is not None and isinstance(prompt, str):
158
+ batch_size = 1
159
+ elif prompt is not None and isinstance(prompt, list):
160
+ batch_size = len(prompt)
161
+ else:
162
+ batch_size = prompt_embeds.shape[0]
163
+
164
+ device = self._execution_device
165
+
166
+ lora_scale = (
167
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
168
+ )
169
+ (
170
+ prompt_embeds,
171
+ pooled_prompt_embeds,
172
+ text_ids,
173
+ ) = self.encode_prompt(
174
+ prompt=prompt,
175
+ prompt_2=prompt_2,
176
+ prompt_embeds=prompt_embeds,
177
+ pooled_prompt_embeds=pooled_prompt_embeds,
178
+ device=device,
179
+ num_images_per_prompt=num_images_per_prompt,
180
+ max_sequence_length=max_sequence_length,
181
+ lora_scale=lora_scale,
182
+ )
183
+
184
+ # 4. Prepare latent variables
185
+ num_channels_latents = self.transformer.config.in_channels // 4
186
+ latents, latent_image_ids = self.prepare_latents(
187
+ batch_size * num_images_per_prompt,
188
+ num_channels_latents,
189
+ height,
190
+ width,
191
+ prompt_embeds.dtype,
192
+ device,
193
+ generator,
194
+ latents,
195
+ )
196
+
197
+ # 5. Prepare timesteps
198
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
199
+ image_seq_len = latents.shape[1]
200
+ mu = calculate_shift(
201
+ image_seq_len,
202
+ self.scheduler.config.base_image_seq_len,
203
+ self.scheduler.config.max_image_seq_len,
204
+ self.scheduler.config.base_shift,
205
+ self.scheduler.config.max_shift,
206
+ )
207
+ timesteps, num_inference_steps = retrieve_timesteps(
208
+ self.scheduler,
209
+ num_inference_steps,
210
+ device,
211
+ timesteps,
212
+ sigmas,
213
+ mu=mu,
214
+ )
215
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
216
+ self._num_timesteps = len(timesteps)
217
+
218
+ # handle guidance
219
+ if self.transformer.config.guidance_embeds:
220
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
221
+ guidance = guidance.expand(latents.shape[0])
222
+ else:
223
+ guidance = None
224
+
225
+ # 6. Denoising loop
226
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
227
+ for i, t in enumerate(timesteps):
228
+ if self.interrupt:
229
+ continue
230
+
231
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
232
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
233
+
234
+ noise_pred = self.transformer(
235
+ hidden_states=latents,
236
+ timestep=timestep / 1000,
237
+ guidance=guidance,
238
+ pooled_projections=pooled_prompt_embeds,
239
+ encoder_hidden_states=prompt_embeds,
240
+ txt_ids=text_ids,
241
+ img_ids=latent_image_ids,
242
+ joint_attention_kwargs=self.joint_attention_kwargs,
243
+ return_dict=False,
244
+ ##################################################
245
+ height=height,
246
+ ##################################################
247
+ )[0]
248
+
249
+ # compute the previous noisy sample x_t -> x_t-1
250
+ latents_dtype = latents.dtype
251
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
252
+
253
+ if latents.dtype != latents_dtype:
254
+ if torch.backends.mps.is_available():
255
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
256
+ latents = latents.to(latents_dtype)
257
+
258
+ if callback_on_step_end is not None:
259
+ callback_kwargs = {}
260
+ for k in callback_on_step_end_tensor_inputs:
261
+ callback_kwargs[k] = locals()[k]
262
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
263
+
264
+ latents = callback_outputs.pop("latents", latents)
265
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
266
+
267
+ # call the callback, if provided
268
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
269
+ progress_bar.update()
270
+
271
+ if XLA_AVAILABLE:
272
+ xm.mark_step()
273
+
274
+ if output_type == "latent":
275
+ image = latents
276
+
277
+ else:
278
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
279
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
280
+ image = self.vae.decode(latents, return_dict=False)[0]
281
+ image = self.image_processor.postprocess(image, output_type=output_type)
282
+
283
+ # Offload all models
284
+ self.maybe_free_model_hooks()
285
+
286
+ if not return_dict:
287
+ return (image,)
288
+
289
+ return FluxPipelineOutput(images=image)
290
+
291
+
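+ # Illustrative usage sketch (not part of the original commit): `FluxPipeline_call`
+ # mirrors `FluxPipeline.__call__` but additionally forwards `height` to the
+ # transformer so the patched blocks can reshape and store attention maps. One
+ # plausible way to wire it up is to bind it onto the pipeline class; the model id
+ # below is only an example.
+ #
+ #   import torch
+ #   from diffusers import FluxPipeline
+ #
+ #   FluxPipeline.__call__ = FluxPipeline_call
+ #   pipe = FluxPipeline.from_pretrained(
+ #       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ #   ).to("cuda")
+ #   image = pipe("a cat holding a sign", num_inference_steps=28).images[0]
+ #   # `attn_maps` should now hold whatever the patched attention processors recorded.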
292
+ def UNet2DConditionModelForward(
293
+ self,
294
+ sample: torch.Tensor,
295
+ timestep: Union[torch.Tensor, float, int],
296
+ encoder_hidden_states: torch.Tensor,
297
+ class_labels: Optional[torch.Tensor] = None,
298
+ timestep_cond: Optional[torch.Tensor] = None,
299
+ attention_mask: Optional[torch.Tensor] = None,
300
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
301
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
302
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
303
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
304
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
305
+ encoder_attention_mask: Optional[torch.Tensor] = None,
306
+ return_dict: bool = True,
307
+ ) -> Union[UNet2DConditionOutput, Tuple]:
308
+ r"""
309
+ The [`UNet2DConditionModel`] forward method.
310
+
311
+ Args:
312
+ sample (`torch.Tensor`):
313
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
314
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
315
+ encoder_hidden_states (`torch.Tensor`):
316
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
317
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
318
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
319
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
320
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
321
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
322
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
323
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
324
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
325
+ negative values to the attention scores corresponding to "discard" tokens.
326
+ cross_attention_kwargs (`dict`, *optional*):
327
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
328
+ `self.processor` in
329
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
330
+ added_cond_kwargs: (`dict`, *optional*):
331
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
332
+ are passed along to the UNet blocks.
333
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
334
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
335
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
336
+ A tensor that if specified is added to the residual of the middle unet block.
337
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
338
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
339
+ encoder_attention_mask (`torch.Tensor`):
340
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
341
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
342
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
343
+ return_dict (`bool`, *optional*, defaults to `True`):
344
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
345
+ tuple.
346
+
347
+ Returns:
348
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
349
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
350
+ otherwise a `tuple` is returned where the first element is the sample tensor.
351
+ """
352
+ # By default samples have to be at least a multiple of the overall upsampling factor.
353
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
354
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
355
+ # on the fly if necessary.
356
+ default_overall_up_factor = 2**self.num_upsamplers
357
+
358
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
359
+ forward_upsample_size = False
360
+ upsample_size = None
361
+
362
+ for dim in sample.shape[-2:]:
363
+ if dim % default_overall_up_factor != 0:
364
+ # Forward upsample size to force interpolation output size.
365
+ forward_upsample_size = True
366
+ break
367
+
368
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
369
+ # expects mask of shape:
370
+ # [batch, key_tokens]
371
+ # adds singleton query_tokens dimension:
372
+ # [batch, 1, key_tokens]
373
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
374
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
375
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
376
+ if attention_mask is not None:
377
+ # assume that mask is expressed as:
378
+ # (1 = keep, 0 = discard)
379
+ # convert mask into a bias that can be added to attention scores:
380
+ # (keep = +0, discard = -10000.0)
381
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
382
+ attention_mask = attention_mask.unsqueeze(1)
383
+
384
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
385
+ if encoder_attention_mask is not None:
386
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
387
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
388
+
389
+ # 0. center input if necessary
390
+ if self.config.center_input_sample:
391
+ sample = 2 * sample - 1.0
392
+
393
+ # 1. time
394
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
395
+ emb = self.time_embedding(t_emb, timestep_cond)
396
+ aug_emb = None
397
+
398
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
399
+ if class_emb is not None:
400
+ if self.config.class_embeddings_concat:
401
+ emb = torch.cat([emb, class_emb], dim=-1)
402
+ else:
403
+ emb = emb + class_emb
404
+
405
+ aug_emb = self.get_aug_embed(
406
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
407
+ )
408
+ if self.config.addition_embed_type == "image_hint":
409
+ aug_emb, hint = aug_emb
410
+ sample = torch.cat([sample, hint], dim=1)
411
+
412
+ emb = emb + aug_emb if aug_emb is not None else emb
413
+
414
+ if self.time_embed_act is not None:
415
+ emb = self.time_embed_act(emb)
416
+
417
+ encoder_hidden_states = self.process_encoder_hidden_states(
418
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
419
+ )
420
+
421
+ # 2. pre-process
422
+ sample = self.conv_in(sample)
423
+
424
+ # 2.5 GLIGEN position net
425
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
426
+ cross_attention_kwargs = cross_attention_kwargs.copy()
427
+ gligen_args = cross_attention_kwargs.pop("gligen")
428
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
429
+
430
+ # 3. down
431
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
432
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
433
+ ################################################################################
434
+ if cross_attention_kwargs is None:
435
+ cross_attention_kwargs = {'timestep' : timestep}
436
+ else:
437
+ cross_attention_kwargs['timestep'] = timestep
438
+ ################################################################################
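+ # `timestep` rides along in `cross_attention_kwargs` so that every attention layer
+ # (and the attention processors patched elsewhere in this repo) can associate the
+ # maps it records with the current denoising step.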
439
+
440
+
441
+ if cross_attention_kwargs is not None:
442
+ cross_attention_kwargs = cross_attention_kwargs.copy()
443
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
444
+ else:
445
+ lora_scale = 1.0
446
+
447
+ if USE_PEFT_BACKEND:
448
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
449
+ scale_lora_layers(self, lora_scale)
450
+
451
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
452
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
453
+ is_adapter = down_intrablock_additional_residuals is not None
454
+ # maintain backward compatibility for legacy usage, where
455
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
456
+ # but can only use one or the other
457
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
458
+ deprecate(
459
+ "T2I should not use down_block_additional_residuals",
460
+ "1.3.0",
461
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
462
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
463
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
464
+ standard_warn=False,
465
+ )
466
+ down_intrablock_additional_residuals = down_block_additional_residuals
467
+ is_adapter = True
468
+
469
+ down_block_res_samples = (sample,)
470
+ for downsample_block in self.down_blocks:
471
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
472
+ # For t2i-adapter CrossAttnDownBlock2D
473
+ additional_residuals = {}
474
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
475
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
476
+
477
+ sample, res_samples = downsample_block(
478
+ hidden_states=sample,
479
+ temb=emb,
480
+ encoder_hidden_states=encoder_hidden_states,
481
+ attention_mask=attention_mask,
482
+ cross_attention_kwargs=cross_attention_kwargs,
483
+ encoder_attention_mask=encoder_attention_mask,
484
+ **additional_residuals,
485
+ )
486
+ else:
487
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
488
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
489
+ sample += down_intrablock_additional_residuals.pop(0)
490
+
491
+ down_block_res_samples += res_samples
492
+
493
+ if is_controlnet:
494
+ new_down_block_res_samples = ()
495
+
496
+ for down_block_res_sample, down_block_additional_residual in zip(
497
+ down_block_res_samples, down_block_additional_residuals
498
+ ):
499
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
500
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
501
+
502
+ down_block_res_samples = new_down_block_res_samples
503
+
504
+ # 4. mid
505
+ if self.mid_block is not None:
506
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
507
+ sample = self.mid_block(
508
+ sample,
509
+ emb,
510
+ encoder_hidden_states=encoder_hidden_states,
511
+ attention_mask=attention_mask,
512
+ cross_attention_kwargs=cross_attention_kwargs,
513
+ encoder_attention_mask=encoder_attention_mask,
514
+ )
515
+ else:
516
+ sample = self.mid_block(sample, emb)
517
+
518
+ # To support T2I-Adapter-XL
519
+ if (
520
+ is_adapter
521
+ and len(down_intrablock_additional_residuals) > 0
522
+ and sample.shape == down_intrablock_additional_residuals[0].shape
523
+ ):
524
+ sample += down_intrablock_additional_residuals.pop(0)
525
+
526
+ if is_controlnet:
527
+ sample = sample + mid_block_additional_residual
528
+
529
+ # 5. up
530
+ for i, upsample_block in enumerate(self.up_blocks):
531
+ is_final_block = i == len(self.up_blocks) - 1
532
+
533
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
534
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
535
+
536
+ # if we have not reached the final block and need to forward the
537
+ # upsample size, we do it here
538
+ if not is_final_block and forward_upsample_size:
539
+ upsample_size = down_block_res_samples[-1].shape[2:]
540
+
541
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
542
+ sample = upsample_block(
543
+ hidden_states=sample,
544
+ temb=emb,
545
+ res_hidden_states_tuple=res_samples,
546
+ encoder_hidden_states=encoder_hidden_states,
547
+ cross_attention_kwargs=cross_attention_kwargs,
548
+ upsample_size=upsample_size,
549
+ attention_mask=attention_mask,
550
+ encoder_attention_mask=encoder_attention_mask,
551
+ )
552
+ else:
553
+ sample = upsample_block(
554
+ hidden_states=sample,
555
+ temb=emb,
556
+ res_hidden_states_tuple=res_samples,
557
+ upsample_size=upsample_size,
558
+ )
559
+
560
+ # 6. post-process
561
+ if self.conv_norm_out:
562
+ sample = self.conv_norm_out(sample)
563
+ sample = self.conv_act(sample)
564
+ sample = self.conv_out(sample)
565
+
566
+ if USE_PEFT_BACKEND:
567
+ # remove `lora_scale` from each PEFT layer
568
+ unscale_lora_layers(self, lora_scale)
569
+
570
+ if not return_dict:
571
+ return (sample,)
572
+
573
+ return UNet2DConditionOutput(sample=sample)
574
+
575
+
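+ # Like UNet2DConditionModelForward above, the forwards below mirror their upstream
+ # diffusers counterparts; the sections fenced with '####' are the only additions,
+ # plumbing `timestep` and `height` through to the attention layers for
+ # attention-map extraction.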
576
+ def SD3Transformer2DModelForward(
577
+ self,
578
+ hidden_states: torch.FloatTensor,
579
+ encoder_hidden_states: torch.FloatTensor = None,
580
+ pooled_projections: torch.FloatTensor = None,
581
+ timestep: torch.LongTensor = None,
582
+ block_controlnet_hidden_states: List = None,
583
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
584
+ return_dict: bool = True,
585
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
586
+ """
587
+ The [`SD3Transformer2DModel`] forward method.
588
+
589
+ Args:
590
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
591
+ Input `hidden_states`.
592
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
593
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
594
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
595
+ from the embeddings of input conditions.
596
+ timestep ( `torch.LongTensor`):
597
+ Used to indicate denoising step.
598
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
599
+ A list of tensors that if specified are added to the residuals of transformer blocks.
600
+ joint_attention_kwargs (`dict`, *optional*):
601
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
602
+ `self.processor` in
603
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
604
+ return_dict (`bool`, *optional*, defaults to `True`):
605
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
606
+ tuple.
607
+
608
+ Returns:
609
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
610
+ `tuple` where the first element is the sample tensor.
611
+ """
612
+ if joint_attention_kwargs is not None:
613
+ joint_attention_kwargs = joint_attention_kwargs.copy()
614
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
615
+ else:
616
+ lora_scale = 1.0
617
+
618
+ if USE_PEFT_BACKEND:
619
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
620
+ scale_lora_layers(self, lora_scale)
621
+ else:
622
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
623
+ logger.warning(
624
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
625
+ )
626
+
627
+ height, width = hidden_states.shape[-2:]
628
+
629
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
630
+ temb = self.time_text_embed(timestep, pooled_projections)
631
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
632
+
633
+ for index_block, block in enumerate(self.transformer_blocks):
634
+ if self.training and self.gradient_checkpointing:
635
+
636
+ def create_custom_forward(module, return_dict=None):
637
+ def custom_forward(*inputs):
638
+ if return_dict is not None:
639
+ return module(*inputs, return_dict=return_dict)
640
+ else:
641
+ return module(*inputs)
642
+
643
+ return custom_forward
644
+
645
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
646
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
647
+ create_custom_forward(block),
648
+ hidden_states,
649
+ encoder_hidden_states,
650
+ temb,
651
+ **ckpt_kwargs,
652
+ )
653
+
654
+ else:
655
+ encoder_hidden_states, hidden_states = block(
656
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
657
+ ##########################################################################################
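+ # extra kwargs for the patched JointTransformerBlockForward; `height` is given in
+ # patch units so flattened attention maps can later be reshaped to a 2-D grid.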
658
+ timestep=timestep, height=height // self.config.patch_size,
659
+ ##########################################################################################
660
+ )
661
+
662
+ # controlnet residual
663
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
664
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
665
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
666
+
667
+ hidden_states = self.norm_out(hidden_states, temb)
668
+ hidden_states = self.proj_out(hidden_states)
669
+
670
+ # unpatchify
671
+ patch_size = self.config.patch_size
672
+ height = height // patch_size
673
+ width = width // patch_size
674
+
675
+ hidden_states = hidden_states.reshape(
676
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
677
+ )
678
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
679
+ output = hidden_states.reshape(
680
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
681
+ )
682
+
683
+ if USE_PEFT_BACKEND:
684
+ # remove `lora_scale` from each PEFT layer
685
+ unscale_lora_layers(self, lora_scale)
686
+
687
+ if not return_dict:
688
+ return (output,)
689
+
690
+ return Transformer2DModelOutput(sample=output)
691
+
692
+
693
+ def FluxTransformer2DModelForward(
694
+ self,
695
+ hidden_states: torch.Tensor,
696
+ encoder_hidden_states: torch.Tensor = None,
697
+ pooled_projections: torch.Tensor = None,
698
+ timestep: torch.LongTensor = None,
699
+ img_ids: torch.Tensor = None,
700
+ txt_ids: torch.Tensor = None,
701
+ guidance: torch.Tensor = None,
702
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
703
+ controlnet_block_samples=None,
704
+ controlnet_single_block_samples=None,
705
+ return_dict: bool = True,
706
+ controlnet_blocks_repeat: bool = False,
707
+ ##################################################
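+ # extra argument: image height from the pipeline, forwarded (divided by the patch
+ # size) to each transformer block so attention maps can be unflattened later.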
708
+ height: int = None,
709
+ ##################################################
710
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
711
+ """
712
+ The [`FluxTransformer2DModel`] forward method.
713
+
714
+ Args:
715
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
716
+ Input `hidden_states`.
717
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
718
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
719
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
720
+ from the embeddings of input conditions.
721
+ timestep ( `torch.LongTensor`):
722
+ Used to indicate denoising step.
723
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
724
+ A list of tensors that if specified are added to the residuals of transformer blocks.
725
+ joint_attention_kwargs (`dict`, *optional*):
726
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
727
+ `self.processor` in
728
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
729
+ return_dict (`bool`, *optional*, defaults to `True`):
730
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
731
+ tuple.
732
+
733
+ Returns:
734
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
735
+ `tuple` where the first element is the sample tensor.
736
+ """
737
+ if joint_attention_kwargs is not None:
738
+ joint_attention_kwargs = joint_attention_kwargs.copy()
739
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
740
+ else:
741
+ lora_scale = 1.0
742
+
743
+ if USE_PEFT_BACKEND:
744
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
745
+ scale_lora_layers(self, lora_scale)
746
+ else:
747
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
748
+ logger.warning(
749
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
750
+ )
751
+ hidden_states = self.x_embedder(hidden_states)
752
+
753
+ timestep = timestep.to(hidden_states.dtype) * 1000
754
+ if guidance is not None:
755
+ guidance = guidance.to(hidden_states.dtype) * 1000
756
+ else:
757
+ guidance = None
758
+ temb = (
759
+ self.time_text_embed(timestep, pooled_projections)
760
+ if guidance is None
761
+ else self.time_text_embed(timestep, guidance, pooled_projections)
762
+ )
763
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
764
+
765
+ if txt_ids.ndim == 3:
766
+ logger.warning(
767
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
768
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
769
+ )
770
+ txt_ids = txt_ids[0]
771
+ if img_ids.ndim == 3:
772
+ logger.warning(
773
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
774
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
775
+ )
776
+ img_ids = img_ids[0]
777
+
778
+ ids = torch.cat((txt_ids, img_ids), dim=0)
779
+ image_rotary_emb = self.pos_embed(ids)
780
+
781
+ for index_block, block in enumerate(self.transformer_blocks):
782
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
783
+
784
+ def create_custom_forward(module, return_dict=None):
785
+ def custom_forward(*inputs):
786
+ if return_dict is not None:
787
+ return module(*inputs, return_dict=return_dict)
788
+ else:
789
+ return module(*inputs)
790
+
791
+ return custom_forward
792
+
793
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
794
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
795
+ create_custom_forward(block),
796
+ hidden_states,
797
+ encoder_hidden_states,
798
+ temb,
799
+ image_rotary_emb,
800
+ **ckpt_kwargs,
801
+ )
802
+
803
+ else:
804
+ encoder_hidden_states, hidden_states = block(
805
+ hidden_states=hidden_states,
806
+ encoder_hidden_states=encoder_hidden_states,
807
+ temb=temb,
808
+ image_rotary_emb=image_rotary_emb,
809
+ joint_attention_kwargs=joint_attention_kwargs,
810
+ ##########################################################################################
811
+ timestep=timestep, height=height // self.config.patch_size,
812
+ ##########################################################################################
813
+ )
814
+
815
+ # controlnet residual
816
+ if controlnet_block_samples is not None:
817
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
818
+ interval_control = int(np.ceil(interval_control))
819
+ # For Xlabs ControlNet.
820
+ if controlnet_blocks_repeat:
821
+ hidden_states = (
822
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
823
+ )
824
+ else:
825
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
826
+
827
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
828
+
829
+ for index_block, block in enumerate(self.single_transformer_blocks):
830
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
831
+
832
+ def create_custom_forward(module, return_dict=None):
833
+ def custom_forward(*inputs):
834
+ if return_dict is not None:
835
+ return module(*inputs, return_dict=return_dict)
836
+ else:
837
+ return module(*inputs)
838
+
839
+ return custom_forward
840
+
841
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
842
+ hidden_states = torch.utils.checkpoint.checkpoint(
843
+ create_custom_forward(block),
844
+ hidden_states,
845
+ temb,
846
+ image_rotary_emb,
847
+ **ckpt_kwargs,
848
+ )
849
+
850
+ else:
851
+ hidden_states = block(
852
+ hidden_states=hidden_states,
853
+ temb=temb,
854
+ image_rotary_emb=image_rotary_emb,
855
+ joint_attention_kwargs=joint_attention_kwargs,
856
+ )
857
+
858
+ # controlnet residual
859
+ if controlnet_single_block_samples is not None:
860
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
861
+ interval_control = int(np.ceil(interval_control))
862
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
863
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
864
+ + controlnet_single_block_samples[index_block // interval_control]
865
+ )
866
+
867
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
868
+
869
+ hidden_states = self.norm_out(hidden_states, temb)
870
+ output = self.proj_out(hidden_states)
871
+
872
+ if USE_PEFT_BACKEND:
873
+ # remove `lora_scale` from each PEFT layer
874
+ unscale_lora_layers(self, lora_scale)
875
+
876
+ if not return_dict:
877
+ return (output,)
878
+
879
+ return Transformer2DModelOutput(sample=output)
880
+
881
+
882
+ def Transformer2DModelForward(
883
+ self,
884
+ hidden_states: torch.Tensor,
885
+ encoder_hidden_states: Optional[torch.Tensor] = None,
886
+ timestep: Optional[torch.LongTensor] = None,
887
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
888
+ class_labels: Optional[torch.LongTensor] = None,
889
+ cross_attention_kwargs: Dict[str, Any] = None,
890
+ attention_mask: Optional[torch.Tensor] = None,
891
+ encoder_attention_mask: Optional[torch.Tensor] = None,
892
+ return_dict: bool = True,
893
+ ):
894
+ """
895
+ The [`Transformer2DModel`] forward method.
896
+
897
+ Args:
898
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
899
+ Input `hidden_states`.
900
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
901
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
902
+ self-attention.
903
+ timestep ( `torch.LongTensor`, *optional*):
904
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
905
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
906
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
907
+ `AdaLayerZeroNorm`.
908
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
909
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
910
+ `self.processor` in
911
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
912
+ attention_mask ( `torch.Tensor`, *optional*):
913
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
914
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
915
+ negative values to the attention scores corresponding to "discard" tokens.
916
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
917
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
918
+
919
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
920
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
921
+
922
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
923
+ above. This bias will be added to the cross-attention scores.
924
+ return_dict (`bool`, *optional*, defaults to `True`):
925
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
926
+ tuple.
927
+
928
+ Returns:
929
+ If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
930
+ otherwise a `tuple` where the first element is the sample tensor.
931
+ """
932
+ if cross_attention_kwargs is not None:
933
+ if cross_attention_kwargs.get("scale", None) is not None:
934
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
935
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
936
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
937
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
938
+ # expects mask of shape:
939
+ # [batch, key_tokens]
940
+ # adds singleton query_tokens dimension:
941
+ # [batch, 1, key_tokens]
942
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
943
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
944
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
945
+ if attention_mask is not None and attention_mask.ndim == 2:
946
+ # assume that mask is expressed as:
947
+ # (1 = keep, 0 = discard)
948
+ # convert mask into a bias that can be added to attention scores:
949
+ # (keep = +0, discard = -10000.0)
950
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
951
+ attention_mask = attention_mask.unsqueeze(1)
952
+
953
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
954
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
955
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
956
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
957
+
958
+ # 1. Input
959
+ if self.is_input_continuous:
960
+ batch_size, _, height, width = hidden_states.shape
961
+ residual = hidden_states
962
+ hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
963
+ elif self.is_input_vectorized:
964
+ hidden_states = self.latent_image_embedding(hidden_states)
965
+ elif self.is_input_patches:
966
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
967
+ hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
968
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
969
+ )
970
+
971
+ ####################################################################################################
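+ # NOTE: assumes `cross_attention_kwargs` is not None here; the patched
+ # UNet2DConditionModelForward above always injects `timestep` into it before the
+ # transformer blocks run, so the dict exists by the time this is reached.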
972
+ cross_attention_kwargs['height'] = height
973
+ cross_attention_kwargs['width'] = width
974
+ ####################################################################################################
975
+
976
+ # 2. Blocks
977
+ for block in self.transformer_blocks:
978
+ if self.training and self.gradient_checkpointing:
979
+
980
+ def create_custom_forward(module, return_dict=None):
981
+ def custom_forward(*inputs):
982
+ if return_dict is not None:
983
+ return module(*inputs, return_dict=return_dict)
984
+ else:
985
+ return module(*inputs)
986
+
987
+ return custom_forward
988
+
989
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
990
+ hidden_states = torch.utils.checkpoint.checkpoint(
991
+ create_custom_forward(block),
992
+ hidden_states,
993
+ attention_mask,
994
+ encoder_hidden_states,
995
+ encoder_attention_mask,
996
+ timestep,
997
+ cross_attention_kwargs,
998
+ class_labels,
999
+ **ckpt_kwargs,
1000
+ )
1001
+ else:
1002
+ hidden_states = block(
1003
+ hidden_states,
1004
+ attention_mask=attention_mask,
1005
+ encoder_hidden_states=encoder_hidden_states,
1006
+ encoder_attention_mask=encoder_attention_mask,
1007
+ timestep=timestep,
1008
+ cross_attention_kwargs=cross_attention_kwargs,
1009
+ class_labels=class_labels,
1010
+ )
1011
+
1012
+ # 3. Output
1013
+ if self.is_input_continuous:
1014
+ output = self._get_output_for_continuous_inputs(
1015
+ hidden_states=hidden_states,
1016
+ residual=residual,
1017
+ batch_size=batch_size,
1018
+ height=height,
1019
+ width=width,
1020
+ inner_dim=inner_dim,
1021
+ )
1022
+ elif self.is_input_vectorized:
1023
+ output = self._get_output_for_vectorized_inputs(hidden_states)
1024
+ elif self.is_input_patches:
1025
+ output = self._get_output_for_patched_inputs(
1026
+ hidden_states=hidden_states,
1027
+ timestep=timestep,
1028
+ class_labels=class_labels,
1029
+ embedded_timestep=embedded_timestep,
1030
+ height=height,
1031
+ width=width,
1032
+ )
1033
+
1034
+ if not return_dict:
1035
+ return (output,)
1036
+
1037
+ return Transformer2DModelOutput(sample=output)
1038
+
1039
+
1040
+ def BasicTransformerBlockForward(
1041
+ self,
1042
+ hidden_states: torch.Tensor,
1043
+ attention_mask: Optional[torch.Tensor] = None,
1044
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1045
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1046
+ timestep: Optional[torch.LongTensor] = None,
1047
+ cross_attention_kwargs: Dict[str, Any] = None,
1048
+ class_labels: Optional[torch.LongTensor] = None,
1049
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1050
+ ) -> torch.Tensor:
1051
+ if cross_attention_kwargs is not None:
1052
+ if cross_attention_kwargs.get("scale", None) is not None:
1053
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1054
+
1055
+ # Notice that normalization is always applied before the real computation in the following blocks.
1056
+ # 0. Self-Attention
1057
+ batch_size = hidden_states.shape[0]
1058
+
1059
+ if self.norm_type == "ada_norm":
1060
+ norm_hidden_states = self.norm1(hidden_states, timestep)
1061
+ elif self.norm_type == "ada_norm_zero":
1062
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1063
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1064
+ )
1065
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
1066
+ norm_hidden_states = self.norm1(hidden_states)
1067
+ elif self.norm_type == "ada_norm_continuous":
1068
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
1069
+ elif self.norm_type == "ada_norm_single":
1070
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1071
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1072
+ ).chunk(6, dim=1)
1073
+ norm_hidden_states = self.norm1(hidden_states)
1074
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1075
+ norm_hidden_states = norm_hidden_states.squeeze(1)
1076
+ else:
1077
+ raise ValueError("Incorrect norm used")
1078
+
1079
+ if self.pos_embed is not None:
1080
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1081
+
1082
+ # 1. Prepare GLIGEN inputs
1083
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1084
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
1085
+
1086
+ ################################################################################
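+ # Only forward the kwargs that this attention processor actually accepts, so stock
+ # processors keep working while patched ones also receive `timestep`/`height`.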
1087
+ attn_parameters = set(inspect.signature(self.attn1.processor.__call__).parameters.keys())
1088
+ ################################################################################
1089
+
1090
+ attn_output = self.attn1(
1091
+ norm_hidden_states,
1092
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1093
+ attention_mask=attention_mask,
1094
+ ################################################################################
1095
+ **{k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters},
1096
+ ################################################################################
1097
+ )
1098
+ if self.norm_type == "ada_norm_zero":
1099
+ attn_output = gate_msa.unsqueeze(1) * attn_output
1100
+ elif self.norm_type == "ada_norm_single":
1101
+ attn_output = gate_msa * attn_output
1102
+
1103
+ hidden_states = attn_output + hidden_states
1104
+ if hidden_states.ndim == 4:
1105
+ hidden_states = hidden_states.squeeze(1)
1106
+
1107
+ # 1.2 GLIGEN Control
1108
+ if gligen_kwargs is not None:
1109
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
1110
+
1111
+ # 3. Cross-Attention
1112
+ if self.attn2 is not None:
1113
+ if self.norm_type == "ada_norm":
1114
+ norm_hidden_states = self.norm2(hidden_states, timestep)
1115
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
1116
+ norm_hidden_states = self.norm2(hidden_states)
1117
+ elif self.norm_type == "ada_norm_single":
1118
+ # For PixArt norm2 isn't applied here:
1119
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1120
+ norm_hidden_states = hidden_states
1121
+ elif self.norm_type == "ada_norm_continuous":
1122
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
1123
+ else:
1124
+ raise ValueError("Incorrect norm")
1125
+
1126
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1127
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1128
+
1129
+ attn_output = self.attn2(
1130
+ norm_hidden_states,
1131
+ encoder_hidden_states=encoder_hidden_states,
1132
+ attention_mask=encoder_attention_mask,
1133
+ **cross_attention_kwargs,
1134
+ )
1135
+ hidden_states = attn_output + hidden_states
1136
+
1137
+ # 4. Feed-forward
1138
+ # i2vgen doesn't have this norm 🤷‍♂️
1139
+ if self.norm_type == "ada_norm_continuous":
1140
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
1141
+ elif not self.norm_type == "ada_norm_single":
1142
+ norm_hidden_states = self.norm3(hidden_states)
1143
+
1144
+ if self.norm_type == "ada_norm_zero":
1145
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1146
+
1147
+ if self.norm_type == "ada_norm_single":
1148
+ norm_hidden_states = self.norm2(hidden_states)
1149
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1150
+
1151
+ if self._chunk_size is not None:
1152
+ # "feed_forward_chunk_size" can be used to save memory
1153
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1154
+ else:
1155
+ ff_output = self.ff(norm_hidden_states)
1156
+
1157
+ if self.norm_type == "ada_norm_zero":
1158
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
1159
+ elif self.norm_type == "ada_norm_single":
1160
+ ff_output = gate_mlp * ff_output
1161
+
1162
+ hidden_states = ff_output + hidden_states
1163
+ if hidden_states.ndim == 4:
1164
+ hidden_states = hidden_states.squeeze(1)
1165
+
1166
+ return hidden_states
1167
+
1168
+
1169
+ def JointTransformerBlockForward(
+     self,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor,
+     temb: torch.FloatTensor,
+     ############################################################
+     height: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     ############################################################
+ ):
+     if self.use_dual_attention:
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+             hidden_states, emb=temb
+         )
+     else:
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+     if self.context_pre_only:
+         norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+     else:
+         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+             encoder_hidden_states, emb=temb
+         )
+
+     # Attention.
+     attn_output, context_attn_output = self.attn(
+         hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
+         ############################################################
+         timestep=timestep, height=height,
+         ############################################################
+     )
+
+     # Process attention outputs for the `hidden_states`.
+     attn_output = gate_msa.unsqueeze(1) * attn_output
+     hidden_states = hidden_states + attn_output
+
+     if self.use_dual_attention:
+         attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+         attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+         hidden_states = hidden_states + attn_output2
+
+     norm_hidden_states = self.norm2(hidden_states)
+     norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+     if self._chunk_size is not None:
+         # "feed_forward_chunk_size" can be used to save memory
+         ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+     else:
+         ff_output = self.ff(norm_hidden_states)
+     ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+     hidden_states = hidden_states + ff_output
+
+     # Process attention outputs for the `encoder_hidden_states`.
+     if self.context_pre_only:
+         encoder_hidden_states = None
+     else:
+         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+         encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+         if self._chunk_size is not None:
+             # "feed_forward_chunk_size" can be used to save memory
+             context_ff_output = _chunked_feed_forward(
+                 self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+             )
+         else:
+             context_ff_output = self.ff_context(norm_encoder_hidden_states)
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+     return encoder_hidden_states, hidden_states
+
+
+ def FluxTransformerBlockForward(
+     self,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor,
+     temb: torch.FloatTensor,
+     image_rotary_emb=None,
+     joint_attention_kwargs=None,
+     ############################################################
+     height: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     ############################################################
+ ):
+     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+     norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+         encoder_hidden_states, emb=temb
+     )
+     joint_attention_kwargs = joint_attention_kwargs or {}
+     # Attention.
+     attn_output, context_attn_output = self.attn(
+         hidden_states=norm_hidden_states,
+         encoder_hidden_states=norm_encoder_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         ############################################################
+         timestep=timestep, height=height,
+         ############################################################
+         **joint_attention_kwargs,
+     )
+
+     # Process attention outputs for the `hidden_states`.
+     attn_output = gate_msa.unsqueeze(1) * attn_output
+     hidden_states = hidden_states + attn_output
+
+     norm_hidden_states = self.norm2(hidden_states)
+     norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+     ff_output = self.ff(norm_hidden_states)
+     ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+     hidden_states = hidden_states + ff_output
+
+     # Process attention outputs for the `encoder_hidden_states`.
+     context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+     encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+     norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+     norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+     context_ff_output = self.ff_context(norm_encoder_hidden_states)
+     encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+     if encoder_hidden_states.dtype == torch.float16:
+         encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+     return encoder_hidden_states, hidden_states
+
+
+ def attn_call(
+     self,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     temb: Optional[torch.Tensor] = None,
+     height: int = None,
+     width: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     *args,
+     **kwargs,
+ ) -> torch.Tensor:
+     if len(args) > 0 or kwargs.get("scale", None) is not None:
+         deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+         deprecate("scale", "1.0.0", deprecation_message)
+
+     residual = hidden_states
+
+     if attn.spatial_norm is not None:
+         hidden_states = attn.spatial_norm(hidden_states, temb)
+
+     input_ndim = hidden_states.ndim
+
+     if input_ndim == 4:
+         batch_size, channel, height, width = hidden_states.shape
+         hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+     batch_size, sequence_length, _ = (
+         hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+     )
+     attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+     if attn.group_norm is not None:
+         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+     query = attn.to_q(hidden_states)
+
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+     elif attn.norm_cross:
+         encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+     key = attn.to_k(encoder_hidden_states)
+     value = attn.to_v(encoder_hidden_states)
+
+     query = attn.head_to_batch_dim(query)
+     key = attn.head_to_batch_dim(key)
+     value = attn.head_to_batch_dim(value)
+
+     attention_probs = attn.get_attention_scores(query, key, attention_mask)
+     ####################################################################################################
+     if hasattr(self, "store_attn_map"):
+         self.attn_map = rearrange(attention_probs, 'b (h w) d -> b d h w', h=height)
+         self.timestep = int(timestep.item())
+     ####################################################################################################
+     hidden_states = torch.bmm(attention_probs, value)
+     hidden_states = attn.batch_to_head_dim(hidden_states)
+
+     # linear proj
+     hidden_states = attn.to_out[0](hidden_states)
+     # dropout
+     hidden_states = attn.to_out[1](hidden_states)
+
+     if input_ndim == 4:
+         hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+     if attn.residual_connection:
+         hidden_states = hidden_states + residual
+
+     hidden_states = hidden_states / attn.rescale_output_factor
+
+     return hidden_states
+
+
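+ # Minimal usage sketch (hypothetical helper, not part of the original module): attn_call stores
+ # the map before batch_to_head_dim is undone, so the heads are still folded into the batch
+ # dimension and self.attn_map has shape (batch * heads, text_len, height, width). Splitting the
+ # heads back out needs the head count, passed explicitly here.
+ def _example_unfold_attn_call_heads(processor, num_heads):
+     attn_map = processor.attn_map  # (batch * heads, text_len, height, width)
+     bh, text_len, h, w = attn_map.shape
+     return attn_map.reshape(bh // num_heads, num_heads, text_len, h, w)
+
+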
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+     # Reference implementation of torch.nn.functional.scaled_dot_product_attention
+     # that additionally returns the softmaxed attention weights.
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+         attn_bias = attn_bias.to(query.dtype)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+         else:
+             attn_bias += attn_mask
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+
+     return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
+
+
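+ # Small sanity-check sketch (hypothetical helper): the reference implementation above mirrors
+ # torch.nn.functional.scaled_dot_product_attention but also returns the attention weights that
+ # the store_attn_map branches below record. With dropout_p=0.0 both paths are deterministic and
+ # should agree up to numerical tolerance.
+ def _example_check_reference_sdpa():
+     query = torch.randn(1, 8, 16, 64)
+     key = torch.randn(1, 8, 16, 64)
+     value = torch.randn(1, 8, 16, 64)
+     out_ref, attn_weight = scaled_dot_product_attention(query, key, value, dropout_p=0.0)
+     out_torch = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
+     assert torch.allclose(out_ref, out_torch, atol=1e-5)
+     assert attn_weight.shape == (1, 8, 16, 16)
+
+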
+ def attn_call2_0(
+     self,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     temb: Optional[torch.Tensor] = None,
+     height: int = None,
+     width: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     *args,
+     **kwargs,
+ ) -> torch.Tensor:
+     if len(args) > 0 or kwargs.get("scale", None) is not None:
+         deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+         deprecate("scale", "1.0.0", deprecation_message)
+
+     residual = hidden_states
+     if attn.spatial_norm is not None:
+         hidden_states = attn.spatial_norm(hidden_states, temb)
+
+     input_ndim = hidden_states.ndim
+
+     if input_ndim == 4:
+         batch_size, channel, height, width = hidden_states.shape
+         hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+     batch_size, sequence_length, _ = (
+         hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+     )
+
+     if attention_mask is not None:
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+         # scaled_dot_product_attention expects attention_mask shape to be
+         # (batch, heads, source_length, target_length)
+         attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+     if attn.group_norm is not None:
+         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+     query = attn.to_q(hidden_states)
+
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+     elif attn.norm_cross:
+         encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+     key = attn.to_k(encoder_hidden_states)
+     value = attn.to_v(encoder_hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     # the output of sdp = (batch, num_heads, seq_len, head_dim)
+     # TODO: add support for attn.scale when we move to Torch 2.1
+     ####################################################################################################
+     if hasattr(self, "store_attn_map"):
+         hidden_states, attention_probs = scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+         # unflatten the query axis: (batch, heads, h*w, text_len) -> (batch, heads, h, w, text_len)
+         self.attn_map = rearrange(attention_probs, 'batch attn_head (h w) attn_dim -> batch attn_head h w attn_dim', h=height)
+         self.timestep = int(timestep.item())
+     else:
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+     ####################################################################################################
+
+     hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # (b, heads, h*w, head_dim) -> (b, h*w, heads*head_dim)
+     hidden_states = hidden_states.to(query.dtype)
+
+     # linear proj
+     hidden_states = attn.to_out[0](hidden_states)
+     # dropout
+     hidden_states = attn.to_out[1](hidden_states)
+
+     if input_ndim == 4:
+         hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+     if attn.residual_connection:
+         hidden_states = hidden_states + residual
+
+     hidden_states = hidden_states / attn.rescale_output_factor
+
+     return hidden_states
+
+
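+ # Minimal usage sketch (hypothetical helper, not part of the original module): once a processor
+ # that routes through attn_call2_0 has `store_attn_map` set, its per-head cross-attention map can
+ # be read back and reduced over heads.
+ def _example_read_attn_map(processor):
+     attn_map = processor.attn_map          # (batch, heads, height, width, text_len)
+     timestep = processor.timestep          # python int
+     per_token_map = attn_map.mean(dim=1)   # average over heads -> (batch, height, width, text_len)
+     return per_token_map, timestep
+
+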
+ def lora_attn_call(self, attn: Attention, hidden_states, height, width, *args, **kwargs):
+     self_cls_name = self.__class__.__name__
+     deprecate(
+         self_cls_name,
+         "0.26.0",
+         (
+             f"Make sure to use {self_cls_name[4:]} instead by setting "
+             "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+             " `LoraLoaderMixin.load_lora_weights`"
+         ),
+     )
+     attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+     attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+     attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+     attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+     attn._modules.pop("processor")
+     attn.processor = AttnProcessor()
+     ####################################################################################################
+     attn.processor.__call__ = attn_call.__get__(attn.processor, AttnProcessor)
+     ####################################################################################################
+
+     if hasattr(self, "store_attn_map"):
+         attn.processor.store_attn_map = True
+
+     return attn.processor(attn, hidden_states, height, width, *args, **kwargs)
+
+
+ def lora_attn_call2_0(self, attn: Attention, hidden_states, height, width, *args, **kwargs):
+     self_cls_name = self.__class__.__name__
+     deprecate(
+         self_cls_name,
+         "0.26.0",
+         (
+             f"Make sure to use {self_cls_name[4:]} instead by setting "
+             "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+             " `LoraLoaderMixin.load_lora_weights`"
+         ),
+     )
+     attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
+     attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
+     attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
+     attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
+
+     attn._modules.pop("processor")
+     attn.processor = AttnProcessor2_0()
+     ####################################################################################################
+     attn.processor.__call__ = attn_call2_0.__get__(attn.processor, AttnProcessor2_0)
+     ####################################################################################################
+
+     if hasattr(self, "store_attn_map"):
+         attn.processor.store_attn_map = True
+
+     return attn.processor(attn, hidden_states, height, width, *args, **kwargs)
+
+
+ def joint_attn_call2_0(
+     self,
+     attn: Attention,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor = None,
+     attention_mask: Optional[torch.FloatTensor] = None,
+     ############################################################
+     height: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     ############################################################
+     *args,
+     **kwargs,
+ ) -> torch.FloatTensor:
+     residual = hidden_states
+
+     batch_size = hidden_states.shape[0]
+
+     # `sample` projections.
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # `context` projections.
+     if encoder_hidden_states is not None:
+         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+
+         if attn.norm_added_q is not None:
+             encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+         if attn.norm_added_k is not None:
+             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+         query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+         key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+         value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+     ####################################################################################################
+     if hasattr(self, "store_attn_map"):
+         hidden_states, attention_probs = scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         image_length = query.shape[2] - encoder_hidden_states_query_proj.shape[2]
+
+         # keep only the image-query -> text-key block, e.g. (4, 24, 4429, 4429) -> (4, 24, 4096, 333)
+         attention_probs = attention_probs[:, :, :image_length, image_length:].cpu()
+
+         self.attn_map = rearrange(
+             attention_probs,
+             'batch attn_head (height width) attn_dim -> batch attn_head height width attn_dim',
+             height=height
+         )  # (4, 24, 4096, 333) -> (4, 24, height, width, 333)
+         self.timestep = timestep[0].cpu().item()  # TODO: int -> list
+     else:
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+     ####################################################################################################
+
+     hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+     hidden_states = hidden_states.to(query.dtype)
+
+     if encoder_hidden_states is not None:
+         # Split the attention outputs.
+         hidden_states, encoder_hidden_states = (
+             hidden_states[:, : residual.shape[1]],
+             hidden_states[:, residual.shape[1] :],
+         )
+         if not attn.context_pre_only:
+             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+     # linear proj
+     hidden_states = attn.to_out[0](hidden_states)
+     # dropout
+     hidden_states = attn.to_out[1](hidden_states)
+
+     if encoder_hidden_states is not None:
+         return hidden_states, encoder_hidden_states
+     else:
+         return hidden_states
+
+
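+ # Minimal post-processing sketch (hypothetical helper): joint_attn_call2_0 stores only the
+ # image-query -> text-key block of the attention matrix, already moved to CPU, with shape
+ # (batch, heads, height, width, text_len). One way to turn it into per-token heatmaps:
+ def _example_token_heatmaps(processor):
+     attn_map = processor.attn_map.float()                 # (batch, heads, height, width, text_len)
+     heatmaps = attn_map.mean(dim=1)                       # average heads -> (batch, height, width, text_len)
+     heatmaps = heatmaps.permute(0, 3, 1, 2)               # -> (batch, text_len, height, width)
+     heatmaps = heatmaps / heatmaps.amax(dim=(-1, -2), keepdim=True).clamp(min=1e-8)
+     return heatmaps
+
+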
+ # FluxAttnProcessor2_0
+ def flux_attn_call2_0(
+     self,
+     attn: Attention,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor = None,
+     attention_mask: Optional[torch.FloatTensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+     ############################################################
+     height: int = None,
+     timestep: Optional[torch.Tensor] = None,
+     ############################################################
+ ) -> torch.FloatTensor:
+     batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+     # `sample` projections.
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+     if encoder_hidden_states is not None:
+         # `context` projections.
+         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+
+         if attn.norm_added_q is not None:
+             encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+         if attn.norm_added_k is not None:
+             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+         # attention
+         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         query = apply_rotary_emb(query, image_rotary_emb)
+         key = apply_rotary_emb(key, image_rotary_emb)
+
+     ####################################################################################################
+     if hasattr(self, "store_attn_map"):
+         hidden_states, attention_probs = scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         image_length = query.shape[2] - encoder_hidden_states_query_proj.shape[2]
+
+         # text tokens come first in the Flux sequence, so keep the image-query -> text-key block,
+         # e.g. (1, 24, 4608, 4608) -> (1, 24, 4096, 512)
+         attention_probs = attention_probs[:, :, -image_length:, :-image_length].cpu()
+
+         self.attn_map = rearrange(
+             attention_probs,
+             'batch attn_head (height width) attn_dim -> batch attn_head height width attn_dim',
+             height=height
+         )  # -> (batch, attn_head, height, width, text_len)
+         self.timestep = timestep[0].cpu().item()  # TODO: int -> list
+     else:
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+     ####################################################################################################
+     hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+     hidden_states = hidden_states.to(query.dtype)
+
+     if encoder_hidden_states is not None:
+         encoder_hidden_states, hidden_states = (
+             hidden_states[:, : encoder_hidden_states.shape[1]],
+             hidden_states[:, encoder_hidden_states.shape[1] :],
+         )
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+         return hidden_states, encoder_hidden_states
+     else:
+         return hidden_states
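+
+
+ # Minimal hook-up sketch (hypothetical helper; real pipelines may need different module or
+ # attribute names): the processors above only record maps when a `store_attn_map` attribute is
+ # present, so collection can be enabled by flagging every attention processor of a loaded
+ # transformer, e.g. `pipe.transformer`.
+ def _example_enable_attn_map_storage(transformer):
+     for _, module in transformer.named_modules():
+         if hasattr(module, "processor"):
+             module.processor.store_attn_map = True
+     return transformer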