Diffusers
Commit 040ea64 (verified) · TalHach61 committed · 1 parent: 66fe703

Upload 5 files

bria_utils.py ADDED
@@ -0,0 +1,71 @@
from typing import Union, Optional, List
import torch
from diffusers.utils import logging
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f"{max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Pad with zeros up to max_sequence_length
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


# In order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    sigmas = timesteps / num_train_timesteps

    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    new_sigmas = sigmas[inds]
    return new_sigmas


def is_ng_none(negative_prompt):
    return (
        negative_prompt is None
        or negative_prompt == ""
        or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
        or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
    )
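
These helpers can be exercised on their own. The sketch below is illustrative only: the small public T5 checkpoint named here is an assumption for demonstration, not the encoder this repository ships with.

    # Illustrative only: "google/t5-v1_1-small" is a stand-in encoder, not the one used by Bria.
    import torch
    from transformers import T5EncoderModel, T5TokenizerFast
    from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none

    tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-small")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

    with torch.no_grad():
        embeds = get_t5_prompt_embeds(
            tokenizer,
            text_encoder,
            prompt=["a photo of a red bicycle"],
            num_images_per_prompt=2,
            max_sequence_length=128,
        )
    print(embeds.shape)  # (batch * num_images_per_prompt, 128, hidden_dim), zero-padded to 128 tokens

    # 30 inference sigmas sampled uniformly from the 1000 training sigmas (descending from 1.0)
    sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)
    print(len(sigmas), sigmas[0], sigmas[-1])

    print(is_ng_none(None), is_ng_none(""), is_ng_none(["low quality"]))  # True True False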
controlnet_bria.py ADDED
@@ -0,0 +1,650 @@
1
+ # type: ignore
2
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from transformer_bria import TimestepProjEmbeddings
23
+ from diffusers.models.controlnet import zero_module, BaseOutput
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import PeftAdapterMixin
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
29
+
30
+ # from transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
31
+ from diffusers.models.transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
32
+
33
+ from diffusers.models.attention_processor import AttentionProcessor
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class BriaControlNetOutput(BaseOutput):
40
+ controlnet_block_samples: Tuple[torch.Tensor]
41
+ controlnet_single_block_samples: Tuple[torch.Tensor]
42
+
43
+
44
+ class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
45
+ _supports_gradient_checkpointing = True
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ patch_size: int = 1,
51
+ in_channels: int = 64,
52
+ num_layers: int = 19,
53
+ num_single_layers: int = 38,
54
+ attention_head_dim: int = 128,
55
+ num_attention_heads: int = 24,
56
+ joint_attention_dim: int = 4096,
57
+ pooled_projection_dim: int = 768,
58
+ guidance_embeds: bool = False,
59
+ axes_dims_rope: List[int] = [16, 56, 56],
60
+ num_mode: int = None,
61
+ rope_theta: int = 10000,
62
+ time_theta: int = 10000,
63
+ ):
64
+ super().__init__()
65
+ self.out_channels = in_channels
66
+ self.inner_dim = num_attention_heads * attention_head_dim
67
+
68
+ # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
69
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=rope_theta, axes_dim=axes_dims_rope)
70
+
71
+ # text_time_guidance_cls = (
72
+ # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
73
+ # )
74
+ # self.time_text_embed = text_time_guidance_cls(
75
+ # embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
76
+ # )
77
+ self.time_embed = TimestepProjEmbeddings(
78
+ embedding_dim=self.inner_dim, time_theta=time_theta
79
+ )
80
+
81
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
82
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
83
+
84
+ self.transformer_blocks = nn.ModuleList(
85
+ [
86
+ FluxTransformerBlock(
87
+ dim=self.inner_dim,
88
+ num_attention_heads=num_attention_heads,
89
+ attention_head_dim=attention_head_dim,
90
+ )
91
+ for i in range(num_layers)
92
+ ]
93
+ )
94
+
95
+ self.single_transformer_blocks = nn.ModuleList(
96
+ [
97
+ FluxSingleTransformerBlock(
98
+ dim=self.inner_dim,
99
+ num_attention_heads=num_attention_heads,
100
+ attention_head_dim=attention_head_dim,
101
+ )
102
+ for i in range(num_single_layers)
103
+ ]
104
+ )
105
+
106
+ # controlnet_blocks
107
+ self.controlnet_blocks = nn.ModuleList([])
108
+ for _ in range(len(self.transformer_blocks)):
109
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
110
+
111
+ self.controlnet_single_blocks = nn.ModuleList([])
112
+ for _ in range(len(self.single_transformer_blocks)):
113
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
114
+
115
+ self.union = num_mode is not None and num_mode > 0
116
+ if self.union:
117
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
118
+
119
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
120
+
121
+ self.gradient_checkpointing = False
122
+
123
+ @property
124
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
125
+ def attn_processors(self):
126
+ r"""
127
+ Returns:
128
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
130
+ """
131
+ # set recursively
132
+ processors = {}
133
+
134
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
135
+ if hasattr(module, "get_processor"):
136
+ processors[f"{name}.processor"] = module.get_processor()
137
+
138
+ for sub_name, child in module.named_children():
139
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
140
+
141
+ return processors
142
+
143
+ for name, module in self.named_children():
144
+ fn_recursive_add_processors(name, module, processors)
145
+
146
+ return processors
147
+
148
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
149
+ def set_attn_processor(self, processor):
150
+ r"""
151
+ Sets the attention processor to use to compute attention.
152
+
153
+ Parameters:
154
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
155
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
156
+ for **all** `Attention` layers.
157
+
158
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
159
+ processor. This is strongly recommended when setting trainable attention processors.
160
+
161
+ """
162
+ count = len(self.attn_processors.keys())
163
+
164
+ if isinstance(processor, dict) and len(processor) != count:
165
+ raise ValueError(
166
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
167
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
168
+ )
169
+
170
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
171
+ if hasattr(module, "set_processor"):
172
+ if not isinstance(processor, dict):
173
+ module.set_processor(processor)
174
+ else:
175
+ module.set_processor(processor.pop(f"{name}.processor"))
176
+
177
+ for sub_name, child in module.named_children():
178
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
179
+
180
+ for name, module in self.named_children():
181
+ fn_recursive_attn_processor(name, module, processor)
182
+
183
+ def _set_gradient_checkpointing(self, module, value=False):
184
+ if hasattr(module, "gradient_checkpointing"):
185
+ module.gradient_checkpointing = value
186
+
187
+ @classmethod
188
+ def from_transformer(
189
+ cls,
190
+ transformer,
191
+ num_layers: int = 4,
192
+ num_single_layers: int = 10,
193
+ attention_head_dim: int = 128,
194
+ num_attention_heads: int = 24,
195
+ load_weights_from_transformer=True,
196
+ ):
197
+ config = transformer.config
198
+ config["num_layers"] = num_layers
199
+ config["num_single_layers"] = num_single_layers
200
+ config["attention_head_dim"] = attention_head_dim
201
+ config["num_attention_heads"] = num_attention_heads
202
+
203
+ controlnet = cls(**config)
204
+
205
+ if load_weights_from_transformer:
206
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
207
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
208
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
209
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
210
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
211
+ controlnet.single_transformer_blocks.load_state_dict(
212
+ transformer.single_transformer_blocks.state_dict(), strict=False
213
+ )
214
+
215
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
216
+
217
+ return controlnet
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ controlnet_cond: torch.Tensor,
223
+ controlnet_mode: torch.Tensor = None,
224
+ conditioning_scale: float = 1.0,
225
+ encoder_hidden_states: torch.Tensor = None,
226
+ pooled_projections: torch.Tensor = None,
227
+ timestep: torch.LongTensor = None,
228
+ img_ids: torch.Tensor = None,
229
+ txt_ids: torch.Tensor = None,
230
+ guidance: torch.Tensor = None,
231
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
232
+ return_dict: bool = True,
233
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
234
+ """
235
+ The [`BriaControlNetModel`] forward method.
236
+
237
+ Args:
238
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
239
+ Input `hidden_states`.
240
+ controlnet_cond (`torch.Tensor`):
241
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
242
+ controlnet_mode (`torch.Tensor`):
243
+ The mode tensor of shape `(batch_size, 1)`.
244
+ conditioning_scale (`float`, defaults to `1.0`):
245
+ The scale factor for ControlNet outputs.
246
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
247
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
248
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
249
+ from the embeddings of input conditions.
250
+ timestep ( `torch.LongTensor`):
251
+ Used to indicate denoising step.
252
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
253
+ A list of tensors that if specified are added to the residuals of transformer blocks.
254
+ joint_attention_kwargs (`dict`, *optional*):
255
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
256
+ `self.processor` in
257
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
258
+ return_dict (`bool`, *optional*, defaults to `True`):
259
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
260
+ tuple.
261
+
262
+ Returns:
263
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
264
+ `tuple` where the first element is the sample tensor.
265
+ """
266
+ if guidance is not None:
267
+ print("guidance is not supported in BriaControlNetModel")
268
+ if pooled_projections is not None:
269
+ print("pooled_projections is not supported in BriaControlNetModel")
270
+ if joint_attention_kwargs is not None:
271
+ joint_attention_kwargs = joint_attention_kwargs.copy()
272
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
273
+ else:
274
+ lora_scale = 1.0
275
+
276
+ if USE_PEFT_BACKEND:
277
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
278
+ scale_lora_layers(self, lora_scale)
279
+ else:
280
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
281
+ logger.warning(
282
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
283
+ )
284
+ hidden_states = self.x_embedder(hidden_states)
285
+
286
+ # add
287
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
288
+
289
+ timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
290
+ if guidance is not None:
291
+ guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
292
+ else:
293
+ guidance = None
294
+ # temb = (
295
+ # self.time_text_embed(timestep, pooled_projections)
296
+ # if guidance is None
297
+ # else self.time_text_embed(timestep, guidance, pooled_projections)
298
+ # )
299
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
300
+
301
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
302
+
303
+ if self.union:
304
+ # union mode
305
+ if controlnet_mode is None:
306
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
307
+ # union mode emb
308
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
309
+ if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:
310
+ controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, 2048)
311
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
312
+ txt_ids = torch.cat((txt_ids[:, 0:1, :], txt_ids), dim=1)
313
+
314
+ # if txt_ids.ndim == 3:
315
+ # logger.warning(
316
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
317
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
318
+ # )
319
+ # txt_ids = txt_ids[0]
320
+ # if img_ids.ndim == 3:
321
+ # logger.warning(
322
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
323
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
324
+ # )
325
+ # img_ids = img_ids[0]
326
+
327
+ # ids = torch.cat((txt_ids, img_ids), dim=0)
328
+ ids = torch.cat((txt_ids, img_ids), dim=1)
329
+ image_rotary_emb = self.pos_embed(ids)
330
+
331
+ block_samples = ()
332
+ for index_block, block in enumerate(self.transformer_blocks):
333
+ if self.training and self.gradient_checkpointing:
334
+
335
+ def create_custom_forward(module, return_dict=None):
336
+ def custom_forward(*inputs):
337
+ if return_dict is not None:
338
+ return module(*inputs, return_dict=return_dict)
339
+ else:
340
+ return module(*inputs)
341
+
342
+ return custom_forward
343
+
344
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
345
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ encoder_hidden_states,
349
+ temb,
350
+ image_rotary_emb,
351
+ **ckpt_kwargs,
352
+ )
353
+
354
+ else:
355
+ encoder_hidden_states, hidden_states = block(
356
+ hidden_states=hidden_states,
357
+ encoder_hidden_states=encoder_hidden_states,
358
+ temb=temb,
359
+ image_rotary_emb=image_rotary_emb,
360
+ )
361
+ block_samples = block_samples + (hidden_states,)
362
+
363
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
364
+
365
+ single_block_samples = ()
366
+ for index_block, block in enumerate(self.single_transformer_blocks):
367
+ if self.training and self.gradient_checkpointing:
368
+
369
+ def create_custom_forward(module, return_dict=None):
370
+ def custom_forward(*inputs):
371
+ if return_dict is not None:
372
+ return module(*inputs, return_dict=return_dict)
373
+ else:
374
+ return module(*inputs)
375
+
376
+ return custom_forward
377
+
378
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
379
+ hidden_states = torch.utils.checkpoint.checkpoint(
380
+ create_custom_forward(block),
381
+ hidden_states,
382
+ temb,
383
+ image_rotary_emb,
384
+ **ckpt_kwargs,
385
+ )
386
+
387
+ else:
388
+ hidden_states = block(
389
+ hidden_states=hidden_states,
390
+ temb=temb,
391
+ image_rotary_emb=image_rotary_emb,
392
+ )
393
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
394
+
395
+ # controlnet block
396
+ controlnet_block_samples = ()
397
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
398
+ block_sample = controlnet_block(block_sample)
399
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
400
+
401
+ controlnet_single_block_samples = ()
402
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
403
+ single_block_sample = controlnet_block(single_block_sample)
404
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
405
+
406
+ # scaling
407
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
408
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
409
+
410
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
411
+ controlnet_single_block_samples = (
412
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
413
+ )
414
+
415
+ if USE_PEFT_BACKEND:
416
+ # remove `lora_scale` from each PEFT layer
417
+ unscale_lora_layers(self, lora_scale)
418
+
419
+ if not return_dict:
420
+ return (controlnet_block_samples, controlnet_single_block_samples)
421
+
422
+ return BriaControlNetOutput(
423
+ controlnet_block_samples=controlnet_block_samples,
424
+ controlnet_single_block_samples=controlnet_single_block_samples,
425
+ )
426
+
427
+
428
+ class BriaMultiControlNetModel(ModelMixin):
429
+ r"""
430
+ `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
431
+
432
+ This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
433
+ compatible with `BriaControlNetModel`.
434
+
435
+ Args:
436
+ controlnets (`List[BriaControlNetModel]`):
437
+ Provides additional conditioning to the transformer during the denoising process. You must set multiple
438
+ `BriaControlNetModel` as a list.
439
+ """
440
+
441
+ def __init__(self, controlnets):
442
+ super().__init__()
443
+ self.nets = nn.ModuleList(controlnets)
444
+
445
+ def forward(
446
+ self,
447
+ hidden_states: torch.FloatTensor,
448
+ controlnet_cond: List[torch.tensor],
449
+ controlnet_mode: List[torch.tensor],
450
+ conditioning_scale: List[float],
451
+ encoder_hidden_states: torch.Tensor = None,
452
+ pooled_projections: torch.Tensor = None,
453
+ timestep: torch.LongTensor = None,
454
+ img_ids: torch.Tensor = None,
455
+ txt_ids: torch.Tensor = None,
456
+ guidance: torch.Tensor = None,
457
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
458
+ return_dict: bool = True,
459
+ ) -> Union[BriaControlNetOutput, Tuple]:
460
+ # ControlNet-Union with multiple conditions
461
+ # only load one ControlNet to save memory
462
+ if len(self.nets) == 1 and self.nets[0].union:
463
+ controlnet = self.nets[0]
464
+
465
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
466
+ block_samples, single_block_samples = controlnet(
467
+ hidden_states=hidden_states,
468
+ controlnet_cond=image,
469
+ controlnet_mode=mode[:, None],
470
+ conditioning_scale=scale,
471
+ timestep=timestep,
472
+ guidance=guidance,
473
+ pooled_projections=pooled_projections,
474
+ encoder_hidden_states=encoder_hidden_states,
475
+ txt_ids=txt_ids,
476
+ img_ids=img_ids,
477
+ joint_attention_kwargs=joint_attention_kwargs,
478
+ return_dict=return_dict,
479
+ )
480
+
481
+ # merge samples
482
+ if i == 0:
483
+ control_block_samples = block_samples
484
+ control_single_block_samples = single_block_samples
485
+ else:
486
+ control_block_samples = [
487
+ control_block_sample + block_sample
488
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
489
+ ]
490
+
491
+ control_single_block_samples = [
492
+ control_single_block_sample + block_sample
493
+ for control_single_block_sample, block_sample in zip(
494
+ control_single_block_samples, single_block_samples
495
+ )
496
+ ]
497
+
498
+ # Regular Multi-ControlNets
499
+ # load all ControlNets into memory
500
+ else:
501
+ for i, (image, mode, scale, controlnet) in enumerate(
502
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
503
+ ):
504
+ block_samples, single_block_samples = controlnet(
505
+ hidden_states=hidden_states,
506
+ controlnet_cond=image,
507
+ controlnet_mode=mode[:, None],
508
+ conditioning_scale=scale,
509
+ timestep=timestep,
510
+ guidance=guidance,
511
+ pooled_projections=pooled_projections,
512
+ encoder_hidden_states=encoder_hidden_states,
513
+ txt_ids=txt_ids,
514
+ img_ids=img_ids,
515
+ joint_attention_kwargs=joint_attention_kwargs,
516
+ return_dict=return_dict,
517
+ )
518
+
519
+ # merge samples
520
+ if i == 0:
521
+ control_block_samples = block_samples
522
+ control_single_block_samples = single_block_samples
523
+ else:
524
+ if block_samples is not None and control_block_samples is not None:
525
+ control_block_samples = [
526
+ control_block_sample + block_sample
527
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
528
+ ]
529
+ if single_block_samples is not None and control_single_block_samples is not None:
530
+ control_single_block_samples = [
531
+ control_single_block_sample + block_sample
532
+ for control_single_block_sample, block_sample in zip(
533
+ control_single_block_samples, single_block_samples
534
+ )
535
+ ]
536
+
537
+ return control_block_samples, control_single_block_samples
538
+
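
As a quick orientation aid, here is a hedged sketch of instantiating the model above directly from its config. The tiny layer counts and dimensions are illustrative assumptions, not the shipped configuration, and it assumes `controlnet_bria.py`, `transformer_bria.py`, and a compatible diffusers version are importable.

    # Minimal sketch, assuming controlnet_bria.py and transformer_bria.py are on the path.
    # All sizes below are illustrative, not the real Bria ControlNet configuration.
    from controlnet_bria import BriaControlNetModel

    controlnet = BriaControlNetModel(
        patch_size=1,
        in_channels=64,
        num_layers=2,                 # 19 in the full model
        num_single_layers=4,          # 38 in the full model
        attention_head_dim=32,        # 128 in the full model
        num_attention_heads=4,        # 24 in the full model
        joint_attention_dim=256,      # 4096 in the full model
        axes_dims_rope=[8, 12, 12],   # must sum to attention_head_dim
        num_mode=None,                # an int > 0 enables the ControlNet-Union mode embedding
    )

    # One zero-initialized projection per (single) transformer block, so an untrained
    # ControlNet contributes nothing to the base transformer's residuals.
    print(len(controlnet.controlnet_blocks), len(controlnet.controlnet_single_blocks))  # 2 4
    print(controlnet.controlnet_blocks[0].weight.abs().sum().item())                    # 0.0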
pipeline_bria.py ADDED
@@ -0,0 +1,576 @@
1
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps, calculate_shift
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+
6
+ from transformers import (
7
+ T5EncoderModel,
8
+ T5TokenizerFast,
9
+ )
10
+
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from diffusers import AutoencoderKL , DDIMScheduler, EulerAncestralDiscreteScheduler
13
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
14
+ from diffusers.schedulers import KarrasDiffusionSchedulers
15
+ from diffusers.loaders import FluxLoraLoaderMixin
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ is_torch_xla_available,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
26
+ from transformer_bria import BriaTransformer2DModel
27
+ from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none
28
+ import numpy as np
29
+
30
+ if is_torch_xla_available():
31
+ import torch_xla.core.xla_model as xm
32
+
33
+ XLA_AVAILABLE = True
34
+ else:
35
+ XLA_AVAILABLE = False
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+ >>> from diffusers import StableDiffusion3Pipeline
45
+
46
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
47
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+ >>> prompt = "A cat holding a sign that says hello world"
51
+ >>> image = pipe(prompt).images[0]
52
+ >>> image.save("sd3.png")
53
+ ```
54
+ """
55
+
56
+ T5_PRECISION = torch.float16
57
+
58
+ """
59
+ Based on FluxPipeline with several changes:
60
+ - no pooled embeddings
61
+ - We use zero padding for prompts
62
+ - No guidance embedding since this is not a distilled version
63
+ """
64
+ class BriaPipeline(FluxPipeline):
65
+ r"""
66
+ Args:
67
+ transformer ([`BriaTransformer2DModel`]):
68
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
69
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
70
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
71
+ vae ([`AutoencoderKL`]):
72
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
73
+ text_encoder ([`T5EncoderModel`]):
74
+ Frozen text-encoder. Stable Diffusion 3 uses
75
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
76
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
77
+ tokenizer (`T5TokenizerFast`):
78
+ Tokenizer of class
79
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ transformer: BriaTransformer2DModel,
85
+ scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers],
86
+ vae: AutoencoderKL,
87
+ text_encoder: T5EncoderModel,
88
+ tokenizer: T5TokenizerFast
89
+ ):
90
+ self.register_modules(
91
+ vae=vae,
92
+ text_encoder=text_encoder,
93
+ tokenizer=tokenizer,
94
+ transformer=transformer,
95
+ scheduler=scheduler,
96
+ )
97
+
98
+ # TODO - why is this different from the official Flux (-1)?
99
+ self.vae_scale_factor = (
100
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
101
+ )
102
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
103
+ self.default_sample_size = 64 # due to patchify => 128x128 latents => 1024x1024 resolution
104
+
105
+ # T5 is sensitive to precision, so we keep the precision used at precompute time and cast as needed
106
+ self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
107
+ for block in self.text_encoder.encoder.block:
108
+ block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
109
+
110
+ def encode_prompt(
111
+ self,
112
+ prompt: Union[str, List[str]],
113
+ device: Optional[torch.device] = None,
114
+ num_images_per_prompt: int = 1,
115
+ do_classifier_free_guidance: bool = True,
116
+ negative_prompt: Optional[Union[str, List[str]]] = None,
117
+ prompt_embeds: Optional[torch.FloatTensor] = None,
118
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
119
+ max_sequence_length: int = 128,
120
+ lora_scale: Optional[float] = None,
121
+ ):
122
+ r"""
123
+
124
+ Args:
125
+ prompt (`str` or `List[str]`, *optional*):
126
+ prompt to be encoded
127
+ device: (`torch.device`):
128
+ torch device
129
+ num_images_per_prompt (`int`):
130
+ number of images that should be generated per prompt
131
+ do_classifier_free_guidance (`bool`):
132
+ whether to use classifier free guidance or not
133
+ negative_prompt (`str` or `List[str]`, *optional*):
134
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
135
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
136
+ less than `1`).
137
+ prompt_embeds (`torch.FloatTensor`, *optional*):
138
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
139
+ provided, text embeddings will be generated from `prompt` input argument.
140
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
141
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
142
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
143
+ argument.
144
+ """
145
+ device = device or self._execution_device
146
+
147
+ # set lora scale so that monkey patched LoRA
148
+ # function of text encoder can correctly access it
149
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
150
+ self._lora_scale = lora_scale
151
+
152
+ # dynamically adjust the LoRA scale
153
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
154
+ scale_lora_layers(self.text_encoder, lora_scale)
155
+
156
+ prompt = [prompt] if isinstance(prompt, str) else prompt
157
+ if prompt is not None:
158
+ batch_size = len(prompt)
159
+ else:
160
+ batch_size = prompt_embeds.shape[0]
161
+
162
+ if prompt_embeds is None:
163
+ prompt_embeds = get_t5_prompt_embeds(
164
+ self.tokenizer,
165
+ self.text_encoder,
166
+ prompt=prompt,
167
+ num_images_per_prompt=num_images_per_prompt,
168
+ max_sequence_length=max_sequence_length,
169
+ device=device,
170
+ ).to(dtype=self.transformer.dtype)
171
+
172
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
173
+ if not is_ng_none(negative_prompt):
174
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
175
+
176
+ if prompt is not None and type(prompt) is not type(negative_prompt):
177
+ raise TypeError(
178
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
179
+ f" {type(prompt)}."
180
+ )
181
+ elif batch_size != len(negative_prompt):
182
+ raise ValueError(
183
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
184
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
185
+ " the batch size of `prompt`."
186
+ )
187
+
188
+ negative_prompt_embeds = get_t5_prompt_embeds(
189
+ self.tokenizer,
190
+ self.text_encoder,
191
+ prompt=negative_prompt,
192
+ num_images_per_prompt=num_images_per_prompt,
193
+ max_sequence_length=max_sequence_length,
194
+ device=device,
195
+ ).to(dtype=self.transformer.dtype)
196
+ else:
197
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
198
+
199
+ if self.text_encoder is not None:
200
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
201
+ # Retrieve the original scale by scaling back the LoRA layers
202
+ unscale_lora_layers(self.text_encoder, lora_scale)
203
+
204
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
205
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
206
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
207
+
208
+ return prompt_embeds, negative_prompt_embeds, text_ids
209
+
210
+ @property
211
+ def guidance_scale(self):
212
+ return self._guidance_scale
213
+
214
+
215
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
216
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
217
+ # corresponds to doing no classifier free guidance.
218
+ @property
219
+ def do_classifier_free_guidance(self):
220
+ return self._guidance_scale > 1
221
+
222
+ @property
223
+ def joint_attention_kwargs(self):
224
+ return self._joint_attention_kwargs
225
+
226
+ @property
227
+ def num_timesteps(self):
228
+ return self._num_timesteps
229
+
230
+ @property
231
+ def interrupt(self):
232
+ return self._interrupt
233
+
234
+ @torch.no_grad()
235
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
236
+ def __call__(
237
+ self,
238
+ prompt: Union[str, List[str]] = None,
239
+ height: Optional[int] = None,
240
+ width: Optional[int] = None,
241
+ num_inference_steps: int = 30,
242
+ timesteps: List[int] = None,
243
+ guidance_scale: float = 5,
244
+ negative_prompt: Optional[Union[str, List[str]]] = None,
245
+ num_images_per_prompt: Optional[int] = 1,
246
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
247
+ latents: Optional[torch.FloatTensor] = None,
248
+ prompt_embeds: Optional[torch.FloatTensor] = None,
249
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
250
+ output_type: Optional[str] = "pil",
251
+ return_dict: bool = True,
252
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
253
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
254
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
255
+ max_sequence_length: int = 128,
256
+ clip_value:Union[None,float] = None,
257
+ normalize:bool = False,
258
+ ):
259
+ r"""
260
+ Function invoked when calling the pipeline for generation.
261
+
262
+ Args:
263
+ prompt (`str` or `List[str]`, *optional*):
264
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
265
+ instead.
266
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
267
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
268
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
269
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
270
+ num_inference_steps (`int`, *optional*, defaults to 30):
271
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
272
+ expense of slower inference.
273
+ timesteps (`List[int]`, *optional*):
274
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
275
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
276
+ passed will be used. Must be in descending order.
277
+ guidance_scale (`float`, *optional*, defaults to 5.0):
278
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
279
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
280
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
281
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
282
+ usually at the expense of lower image quality.
283
+ negative_prompt (`str` or `List[str]`, *optional*):
284
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
285
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
286
+ less than `1`).
287
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
288
+ The number of images to generate per prompt.
289
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
290
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
291
+ to make generation deterministic.
292
+ latents (`torch.FloatTensor`, *optional*):
293
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
294
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
295
+ tensor will be generated by sampling using the supplied random `generator`.
296
+ prompt_embeds (`torch.FloatTensor`, *optional*):
297
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
298
+ provided, text embeddings will be generated from `prompt` input argument.
299
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
301
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
302
+ argument.
303
+ output_type (`str`, *optional*, defaults to `"pil"`):
304
+ The output format of the generate image. Choose between
305
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
306
+ return_dict (`bool`, *optional*, defaults to `True`):
307
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
308
+ of a plain tuple.
309
+ joint_attention_kwargs (`dict`, *optional*):
310
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
311
+ `self.processor` in
312
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
313
+ callback_on_step_end (`Callable`, *optional*):
314
+ A function called at the end of each denoising step during inference. The function is called
315
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
316
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
317
+ `callback_on_step_end_tensor_inputs`.
318
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
319
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
320
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
321
+ `._callback_tensor_inputs` attribute of your pipeline class.
322
+ max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
323
+
324
+ Examples:
325
+
326
+ Returns:
327
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
328
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
329
+ images.
330
+ """
331
+
332
+ height = height or self.default_sample_size * self.vae_scale_factor
333
+ width = width or self.default_sample_size * self.vae_scale_factor
334
+
335
+ # 1. Check inputs. Raise error if not correct
336
+ self.check_inputs(
337
+ prompt=prompt,
338
+ height=height,
339
+ width=width,
340
+ prompt_embeds=prompt_embeds,
341
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
342
+ max_sequence_length=max_sequence_length,
343
+ )
344
+
345
+ self._guidance_scale = guidance_scale
346
+ self._joint_attention_kwargs = joint_attention_kwargs
347
+ self._interrupt = False
348
+
349
+ # 2. Define call parameters
350
+ if prompt is not None and isinstance(prompt, str):
351
+ batch_size = 1
352
+ elif prompt is not None and isinstance(prompt, list):
353
+ batch_size = len(prompt)
354
+ else:
355
+ batch_size = prompt_embeds.shape[0]
356
+
357
+ device = self._execution_device
358
+
359
+ lora_scale = (
360
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
361
+ )
362
+
363
+ (
364
+ prompt_embeds,
365
+ negative_prompt_embeds,
366
+ text_ids
367
+ ) = self.encode_prompt(
368
+ prompt=prompt,
369
+ negative_prompt=negative_prompt,
370
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
371
+ prompt_embeds=prompt_embeds,
372
+ negative_prompt_embeds=negative_prompt_embeds,
373
+ device=device,
374
+ num_images_per_prompt=num_images_per_prompt,
375
+ max_sequence_length=max_sequence_length,
376
+ lora_scale=lora_scale,
377
+ )
378
+
379
+ if self.do_classifier_free_guidance:
380
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
381
+
382
+
383
+
384
+ # 5. Prepare latent variables
385
+ num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we divide by 4
386
+ latents, latent_image_ids = self.prepare_latents(
387
+ batch_size * num_images_per_prompt,
388
+ num_channels_latents,
389
+ height,
390
+ width,
391
+ prompt_embeds.dtype,
392
+ device,
393
+ generator,
394
+ latents,
395
+ )
396
+
397
+ if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
398
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
399
+ image_seq_len = latents.shape[1] # Shift by height - Why just height?
400
+ print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
401
+
402
+ mu = calculate_shift(
403
+ image_seq_len,
404
+ self.scheduler.config.base_image_seq_len,
405
+ self.scheduler.config.max_image_seq_len,
406
+ self.scheduler.config.base_shift,
407
+ self.scheduler.config.max_shift,
408
+ )
409
+ timesteps, num_inference_steps = retrieve_timesteps(
410
+ self.scheduler,
411
+ num_inference_steps,
412
+ device,
413
+ timesteps,
414
+ sigmas,
415
+ mu=mu,
416
+ )
417
+ else:
418
+ # 4. Prepare timesteps
419
+ # Sample from training sigmas
420
+ if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler):
421
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
422
+ else:
423
+ sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps)
424
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas)
425
+
426
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
427
+ self._num_timesteps = len(timesteps)
428
+
429
+ # Support different diffusers versions
430
+ if len(latent_image_ids.shape)==2:
431
+ text_ids=text_ids.squeeze()
432
+
433
+ # 6. Denoising loop
434
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
435
+ for i, t in enumerate(timesteps):
436
+ if self.interrupt:
437
+ continue
438
+
439
+ # expand the latents if we are doing classifier free guidance
440
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
441
+ if type(self.scheduler)!=FlowMatchEulerDiscreteScheduler:
442
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
443
+
444
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
445
+ timestep = t.expand(latent_model_input.shape[0])
446
+
447
+ # This predicts "v" for flow matching, or eps for diffusion
448
+ noise_pred = self.transformer(
449
+ hidden_states=latent_model_input,
450
+ timestep=timestep,
451
+ encoder_hidden_states=prompt_embeds,
452
+ joint_attention_kwargs=self.joint_attention_kwargs,
453
+ return_dict=False,
454
+ txt_ids=text_ids,
455
+ img_ids=latent_image_ids,
456
+ )[0]
457
+
458
+ # perform guidance
459
+ if self.do_classifier_free_guidance:
460
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
461
+ cfg_noise_pred_text = noise_pred_text.std()
462
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
463
+
464
+ if normalize:
465
+ noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred
466
+
467
+ if clip_value:
468
+ assert clip_value>0
469
+ noise_pred = noise_pred.clip(-clip_value,clip_value)
470
+
471
+ # compute the previous noisy sample x_t -> x_t-1
472
+ latents_dtype = latents.dtype
473
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
474
+
475
+ if latents.dtype != latents_dtype:
476
+ if torch.backends.mps.is_available():
477
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
478
+ latents = latents.to(latents_dtype)
479
+
480
+ if callback_on_step_end is not None:
481
+ callback_kwargs = {}
482
+ for k in callback_on_step_end_tensor_inputs:
483
+ callback_kwargs[k] = locals()[k]
484
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
485
+
486
+ latents = callback_outputs.pop("latents", latents)
487
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
488
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
489
+
490
+ # call the callback, if provided
491
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
492
+ progress_bar.update()
493
+
494
+ if XLA_AVAILABLE:
495
+ xm.mark_step()
496
+
497
+ if output_type == "latent":
498
+ image = latents
499
+
500
+ else:
501
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
502
+ latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
503
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
504
+ image = self.image_processor.postprocess(image, output_type=output_type)
505
+
506
+ # Offload all models
507
+ self.maybe_free_model_hooks()
508
+
509
+ if not return_dict:
510
+ return (image,)
511
+
512
+ return FluxPipelineOutput(images=image)
513
+
514
+ def check_inputs(
515
+ self,
516
+ prompt,
517
+ height,
518
+ width,
519
+ negative_prompt=None,
520
+ prompt_embeds=None,
521
+ negative_prompt_embeds=None,
522
+ callback_on_step_end_tensor_inputs=None,
523
+ max_sequence_length=None,
524
+ ):
525
+ if height % 8 != 0 or width % 8 != 0:
526
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
527
+
528
+ if callback_on_step_end_tensor_inputs is not None and not all(
529
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
530
+ ):
531
+ raise ValueError(
532
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
533
+ )
534
+
535
+ if prompt is not None and prompt_embeds is not None:
536
+ raise ValueError(
537
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
538
+ " only forward one of the two."
539
+ )
540
+ elif prompt is None and prompt_embeds is None:
541
+ raise ValueError(
542
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
543
+ )
544
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
545
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
546
+
547
+ if negative_prompt is not None and negative_prompt_embeds is not None:
548
+ raise ValueError(
549
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
550
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
551
+ )
552
+
553
+
554
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
555
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
556
+ raise ValueError(
557
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
558
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
559
+ f" {negative_prompt_embeds.shape}."
560
+ )
561
+
562
+ if max_sequence_length is not None and max_sequence_length > 512:
563
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
564
+
565
+ def to(self, *args, **kwargs):
566
+ DiffusionPipeline.to(self, *args, **kwargs)
567
+ # T5 is sensitive to precision, so we keep the precision used at precompute time and cast as needed
568
+ self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
569
+ for block in self.text_encoder.encoder.block:
570
+ block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
571
+
572
+ return self
573
+
574
+
575
+
576
+
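
For completeness, a hedged sketch of driving this pipeline end to end. The checkpoint id is a placeholder (the actual Bria repository this commit accompanies is not named here), and the bfloat16 dtype and CUDA device are assumptions.

    # Hedged usage sketch. "briaai/<bria-model-repo>" is a placeholder, not a real id;
    # substitute the checkpoint this pipeline is meant to load.
    import torch
    from pipeline_bria import BriaPipeline

    pipe = BriaPipeline.from_pretrained("briaai/<bria-model-repo>", torch_dtype=torch.bfloat16)
    pipe.to("cuda")  # BriaPipeline.to() re-casts the T5 encoder back to T5_PRECISION afterwards

    image = pipe(
        prompt="A photo of a lighthouse at dusk",
        negative_prompt="blurry, low quality",  # used because guidance_scale > 1 enables CFG
        num_inference_steps=30,
        guidance_scale=5.0,
        height=1024,
        width=1024,
    ).images[0]
    image.save("bria_sample.png")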
pipeline_bria_controlnet.py ADDED
@@ -0,0 +1,532 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+ import torch
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ T5TokenizerFast,
20
+ )
21
+ from diffusers.image_processor import PipelineImageInput
22
+
23
+ from diffusers import AutoencoderKL # Waiting for diffusers update
24
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
25
+ from diffusers.schedulers import KarrasDiffusionSchedulers
26
+ from diffusers.utils import logging
27
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
28
+ from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps
29
+ from .controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
30
+
31
+ from .pipeline_bria import BriaPipeline
32
+ from .transformer_bria import BriaTransformer2DModel
33
+ from .bria_utils import get_original_sigmas
34
+
35
+ XLA_AVAILABLE = False
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class BriaControlNetPipeline(BriaPipeline):
42
+ r"""
43
+ Args:
44
+ transformer ([`BriaTransformer2DModel`]):
45
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
46
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
47
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
48
+ vae ([`AutoencoderKL`]):
49
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
50
+ text_encoder ([`T5EncoderModel`]):
51
+ Frozen text-encoder. Stable Diffusion 3 uses
52
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
53
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
54
+ tokenizer (`T5TokenizerFast`):
55
+ Tokenizer of class
56
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
57
+ """
58
+
59
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
60
+ _optional_components = []
61
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
62
+
63
+ def __init__( # EYAL - removed clip text encoder + tokenizer
64
+ self,
65
+ transformer: BriaTransformer2DModel,
66
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
67
+ vae: AutoencoderKL,
68
+ text_encoder: T5EncoderModel,
69
+ tokenizer: T5TokenizerFast,
70
+ controlnet: BriaControlNetModel,
71
+ ):
72
+ super().__init__(
73
+ transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
74
+ )
75
+ self.register_modules(controlnet=controlnet)
76
+
77
+ def prepare_image(
78
+ self,
79
+ image,
80
+ width,
81
+ height,
82
+ batch_size,
83
+ num_images_per_prompt,
84
+ device,
85
+ dtype,
86
+ do_classifier_free_guidance=False,
87
+ guess_mode=False,
88
+ ):
89
+ if isinstance(image, torch.Tensor):
90
+ pass
91
+ else:
92
+ image = self.image_processor.preprocess(image, height=height, width=width)
93
+
94
+ image_batch_size = image.shape[0]
95
+
96
+ if image_batch_size == 1:
97
+ repeat_by = batch_size
98
+ else:
99
+ # image batch size is the same as prompt batch size
100
+ repeat_by = num_images_per_prompt
101
+
102
+ image = image.repeat_interleave(repeat_by, dim=0)
103
+
104
+ image = image.to(device=device, dtype=dtype)
105
+
106
+ if do_classifier_free_guidance and not guess_mode:
107
+ image = torch.cat([image] * 2)
108
+
109
+ return image
110
+
111
+ def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
112
+ num_channels_latents = self.transformer.config.in_channels // 4
113
+ control_image = self.prepare_image(
114
+ image=control_image,
115
+ width=width,
116
+ height=height,
117
+ batch_size=batch_size * num_images_per_prompt,
118
+ num_images_per_prompt=num_images_per_prompt,
119
+ device=device,
120
+ dtype=self.vae.dtype,
121
+ )
122
+ height, width = control_image.shape[-2:]
123
+
124
+ # vae encode
125
+ control_image = self.vae.encode(control_image).latent_dist.sample()
126
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
127
+
128
+ # pack
129
+ height_control_image, width_control_image = control_image.shape[2:]
130
+ control_image = self._pack_latents(
131
+ control_image,
132
+ batch_size * num_images_per_prompt,
133
+ num_channels_latents,
134
+ height_control_image,
135
+ width_control_image,
136
+ )
137
+
138
+ # Here we ensure that `control_mode` has the same length as the control_image.
139
+ if control_mode is not None:
140
+ if not isinstance(control_mode, int):
141
+ raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
142
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
143
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
144
+
145
+ return control_image, control_mode
146
+
147
+ def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
148
+ num_channels_latents = self.transformer.config.in_channels // 4
149
+ control_images = []
150
+ for i, control_image_ in enumerate(control_image):
151
+ control_image_ = self.prepare_image(
152
+ image=control_image_,
153
+ width=width,
154
+ height=height,
155
+ batch_size=batch_size * num_images_per_prompt,
156
+ num_images_per_prompt=num_images_per_prompt,
157
+ device=device,
158
+ dtype=self.vae.dtype,
159
+ )
160
+ height, width = control_image_.shape[-2:]
161
+
162
+ # vae encode
163
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
164
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
165
+
166
+ # pack
167
+ height_control_image, width_control_image = control_image_.shape[2:]
168
+ control_image_ = self._pack_latents(
169
+ control_image_,
170
+ batch_size * num_images_per_prompt,
171
+ num_channels_latents,
172
+ height_control_image,
173
+ width_control_image,
174
+ )
175
+ control_images.append(control_image_)
176
+
177
+ control_image = control_images
178
+
179
+ # Here we ensure that `control_mode` has the same length as the control_image.
180
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
181
+ raise ValueError(
182
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
183
+ + " length as the number of controlnets (control images) specified"
184
+ )
185
+ if not isinstance(control_mode, list):
186
+ control_mode = [control_mode] * len(control_image)
187
+ # set control mode
188
+ control_modes = []
189
+ for cmode in control_mode:
190
+ if cmode is None:
191
+ cmode = -1
192
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
193
+ control_modes.append(control_mode)
194
+ control_mode = control_modes
195
+
196
+ return control_image, control_mode
197
+
198
+ def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
199
+ controlnet_keep = []
200
+ for i in range(len(timesteps)):
201
+ keeps = [
202
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
203
+ for s, e in zip(control_guidance_start, control_guidance_end)
204
+ ]
205
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
206
+ return controlnet_keep
207
+
208
+ def get_control_start_end(self, control_guidance_start, control_guidance_end):
209
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
210
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
211
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
212
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
213
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
214
+ mult = 1 # scalar start/end values cover the single-ControlNet case; multi-ControlNet callers pass lists
215
+ control_guidance_start, control_guidance_end = (
216
+ mult * [control_guidance_start],
217
+ mult * [control_guidance_end],
218
+ )
219
+
220
+ return control_guidance_start, control_guidance_end
221
+
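Aside: the keep schedule produced by `get_controlnet_keep()` above gates the ControlNet residuals by normalized progress through the denoising loop. A standalone sketch of the same expression, with hypothetical values (4 steps, ControlNet active for the first half):

num_steps = 4
start, end = 0.0, 0.5  # control_guidance_start / control_guidance_end
keep = [
    1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
    for i in range(num_steps)
]
print(keep)  # [1.0, 1.0, 0.0, 0.0] -> the ControlNet residual is dropped after 50% of the steps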
222
+ @torch.no_grad()
223
+ def __call__(
224
+ self,
225
+ prompt: Union[str, List[str]] = None,
226
+ height: Optional[int] = None,
227
+ width: Optional[int] = None,
228
+ num_inference_steps: int = 30,
229
+ timesteps: List[int] = None,
230
+ guidance_scale: float = 3.5,
231
+ control_guidance_start: Union[float, List[float]] = 0.0,
232
+ control_guidance_end: Union[float, List[float]] = 1.0,
233
+ control_image: Optional[PipelineImageInput] = None,
234
+ control_mode: Optional[Union[int, List[int]]] = None,
235
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ num_images_per_prompt: Optional[int] = 1,
238
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
239
+ latents: Optional[torch.FloatTensor] = None,
240
+ prompt_embeds: Optional[torch.FloatTensor] = None,
241
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
242
+ output_type: Optional[str] = "pil",
243
+ return_dict: bool = True,
244
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
245
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
246
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
247
+ max_sequence_length: int = 128,
248
+ ):
249
+ r"""
250
+ Function invoked when calling the pipeline for generation.
251
+
252
+ Args:
253
+ prompt (`str` or `List[str]`, *optional*):
254
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
255
+ instead.
256
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
257
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
258
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
259
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
260
+ num_inference_steps (`int`, *optional*, defaults to 30):
261
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
262
+ expense of slower inference.
263
+ timesteps (`List[int]`, *optional*):
264
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
265
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
266
+ passed will be used. Must be in descending order.
267
+ guidance_scale (`float`, *optional*, defaults to 3.5):
268
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
269
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
270
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
271
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
272
+ usually at the expense of lower image quality.
273
+ negative_prompt (`str` or `List[str]`, *optional*):
274
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
275
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
276
+ less than `1`).
277
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
278
+ The number of images to generate per prompt.
279
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
280
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
281
+ to make generation deterministic.
282
+ latents (`torch.FloatTensor`, *optional*):
283
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
284
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
285
+ tensor will be generated by sampling using the supplied random `generator`.
286
+ prompt_embeds (`torch.FloatTensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generated image. Choose between
295
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
298
+ of a plain tuple.
299
+ joint_attention_kwargs (`dict`, *optional*):
300
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
301
+ `self.processor` in
302
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
303
+ callback_on_step_end (`Callable`, *optional*):
304
+ A function that is called at the end of each denoising step during inference. The function is called
305
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
306
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
307
+ `callback_on_step_end_tensor_inputs`.
308
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
309
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
310
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
311
+ `._callback_tensor_inputs` attribute of your pipeline class.
312
+ max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
313
+
314
+ Examples:
315
+
316
+ Returns:
317
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`:
318
+ [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a
319
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
320
+ """
321
+
322
+ height = height or self.default_sample_size * self.vae_scale_factor
323
+ width = width or self.default_sample_size * self.vae_scale_factor
324
+ control_guidance_start, control_guidance_end = self.get_control_start_end(
325
+ control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
326
+ )
327
+
328
+ # 1. Check inputs. Raise error if not correct
329
+ self.check_inputs(
330
+ prompt,
331
+ height,
332
+ width,
333
+ negative_prompt=negative_prompt,
334
+ prompt_embeds=prompt_embeds,
335
+ negative_prompt_embeds=negative_prompt_embeds,
336
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
337
+ max_sequence_length=max_sequence_length,
338
+ )
339
+
340
+ self._guidance_scale = guidance_scale
341
+ self._joint_attention_kwargs = joint_attention_kwargs
342
+ self._interrupt = False
343
+
344
+ # 2. Define call parameters
345
+ if prompt is not None and isinstance(prompt, str):
346
+ batch_size = 1
347
+ elif prompt is not None and isinstance(prompt, list):
348
+ batch_size = len(prompt)
349
+ else:
350
+ batch_size = prompt_embeds.shape[0]
351
+
352
+ device = self._execution_device
353
+
354
+ lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
355
+
356
+ (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
357
+ prompt=prompt,
358
+ negative_prompt=negative_prompt,
359
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
360
+ prompt_embeds=prompt_embeds,
361
+ negative_prompt_embeds=negative_prompt_embeds,
362
+ device=device,
363
+ num_images_per_prompt=num_images_per_prompt,
364
+ max_sequence_length=max_sequence_length,
365
+ lora_scale=lora_scale,
366
+ )
367
+
368
+ if self.do_classifier_free_guidance:
369
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
370
+
371
+ # 3. Prepare control image
372
+ if control_image is not None:
373
+ if isinstance(self.controlnet, BriaControlNetModel):
374
+ control_image, control_mode = self.prepare_control(
375
+ control_image=control_image,
376
+ width=width,
377
+ height=height,
378
+ batch_size=batch_size,
379
+ num_images_per_prompt=num_images_per_prompt,
380
+ device=device,
381
+ control_mode=control_mode,
382
+ )
383
+ elif isinstance(self.controlnet, BriaMultiControlNetModel):
384
+ control_image, control_mode = self.prepare_multi_control(
385
+ control_image=control_image,
386
+ width=width,
387
+ height=height,
388
+ batch_size=batch_size,
389
+ num_images_per_prompt=num_images_per_prompt,
390
+ device=device,
391
+ control_mode=control_mode,
392
+ )
393
+
394
+ # 4. Prepare timesteps
395
+ # Sample from training sigmas
396
+ sigmas = get_original_sigmas(
397
+ num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
398
+ )
399
+ timesteps, num_inference_steps = retrieve_timesteps(
400
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
401
+ )
402
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
403
+ self._num_timesteps = len(timesteps)
404
+
405
+ # 5. Prepare latent variables
406
+ num_channels_latents = self.transformer.config.in_channels // 4 # latents are packed with patch_size=2, so we divide by 4
407
+ latents, latent_image_ids = self.prepare_latents(
408
+ batch_size=batch_size * num_images_per_prompt,
409
+ num_channels_latents=num_channels_latents,
410
+ height=height,
411
+ width=width,
412
+ dtype=prompt_embeds.dtype,
413
+ device=device,
414
+ generator=generator,
415
+ latents=latents,
416
+ )
417
+
418
+ # 6. Create tensor stating which controlnets to keep
419
+ if control_image is not None:
420
+ controlnet_keep = self.get_controlnet_keep(
421
+ timesteps=timesteps,
422
+ control_guidance_start=control_guidance_start,
423
+ control_guidance_end=control_guidance_end,
424
+ )
425
+
426
+ # EYAL - added the CFG loop
427
+ # 7. Denoising loop
428
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
429
+ for i, t in enumerate(timesteps):
430
+ if self.interrupt:
431
+ continue
432
+
433
+ # expand the latents if we are doing classifier free guidance
434
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
435
+ # if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
436
+ if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
437
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
438
+
439
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
440
+ timestep = t.expand(latent_model_input.shape[0])
441
+
442
+ # Handling ControlNet
443
+ if control_image is not None:
444
+ if isinstance(controlnet_keep[i], list):
445
+ if isinstance(controlnet_conditioning_scale, list):
446
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
447
+ else:
448
+ cond_scale = [controlnet_conditioning_scale * s for s in controlnet_keep[i]]
449
+ else:
450
+ controlnet_cond_scale = controlnet_conditioning_scale
451
+ if isinstance(controlnet_cond_scale, list):
452
+ controlnet_cond_scale = controlnet_cond_scale[0]
453
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
454
+
455
+ # controlnet
456
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
457
+ hidden_states=latents,
458
+ controlnet_cond=control_image,
459
+ controlnet_mode=control_mode,
460
+ conditioning_scale=cond_scale,
461
+ timestep=timestep,
462
+ # guidance=guidance,
463
+ # pooled_projections=pooled_prompt_embeds,
464
+ encoder_hidden_states=prompt_embeds,
465
+ txt_ids=text_ids,
466
+ img_ids=latent_image_ids,
467
+ joint_attention_kwargs=self.joint_attention_kwargs,
468
+ return_dict=False,
469
+ )
470
+ else:
471
+ controlnet_block_samples, controlnet_single_block_samples = None, None
472
+
473
+ # This predicts the flow-matching velocity ("v")
474
+ noise_pred = self.transformer(
475
+ hidden_states=latent_model_input,
476
+ timestep=timestep,
477
+ encoder_hidden_states=prompt_embeds,
478
+ joint_attention_kwargs=self.joint_attention_kwargs,
479
+ return_dict=False,
480
+ txt_ids=text_ids,
481
+ img_ids=latent_image_ids,
482
+ controlnet_block_samples=controlnet_block_samples,
483
+ controlnet_single_block_samples=controlnet_single_block_samples,
484
+ )[0]
485
+
486
+ # perform guidance
487
+ if self.do_classifier_free_guidance:
488
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
489
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
490
+
491
+ # compute the previous noisy sample x_t -> x_t-1
492
+ latents_dtype = latents.dtype
493
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
494
+
495
+ if latents.dtype != latents_dtype:
496
+ if torch.backends.mps.is_available():
497
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
498
+ latents = latents.to(latents_dtype)
499
+
500
+ if callback_on_step_end is not None:
501
+ callback_kwargs = {}
502
+ for k in callback_on_step_end_tensor_inputs:
503
+ callback_kwargs[k] = locals()[k]
504
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
505
+
506
+ latents = callback_outputs.pop("latents", latents)
507
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
508
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
509
+
510
+ # call the callback, if provided
511
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
512
+ progress_bar.update()
513
+
514
+ if XLA_AVAILABLE:
515
+ xm.mark_step()
516
+
517
+ if output_type == "latent":
518
+ image = latents
519
+
520
+ else:
521
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
522
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
523
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
524
+ image = self.image_processor.postprocess(image, output_type=output_type)
525
+
526
+ # Offload all models
527
+ self.maybe_free_model_hooks()
528
+
529
+ if not return_dict:
530
+ return (image,)
531
+
532
+ return FluxPipelineOutput(images=image)
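A minimal usage sketch for the pipeline defined above, assuming the files in this commit are published as a Hub repository with custom code; the repository id, conditioning image, and mode id below are placeholders rather than real values:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Placeholder repo id -- substitute the actual Bria ControlNet checkpoint.
pipe = DiffusionPipeline.from_pretrained(
    "briaai/your-bria-controlnet-checkpoint",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda")

control_image = load_image("pose.png")   # placeholder conditioning image
image = pipe(
    prompt="a dancer on a beach at sunset",
    control_image=control_image,
    control_mode=0,                       # integer mode id, as expected by prepare_control()
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
image.save("result.png")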
transformer_bria.py ADDED
@@ -0,0 +1,335 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from diffusers.models.normalization import AdaLayerNormContinuous
9
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
10
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
11
+ from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
12
+ from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
13
+
14
+ # Support different diffusers versions
15
+ try:
16
+ from diffusers.models.embeddings import FluxPosEmbed as EmbedND
17
+ except ImportError:
18
+ from diffusers.models.transformers.transformer_flux import rope
19
+ class EmbedND(nn.Module):
20
+ def __init__(self, theta: int, axes_dim: List[int]):
21
+ super().__init__()
22
+ self.theta = theta
23
+ self.axes_dim = axes_dim
24
+
25
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
26
+ n_axes = ids.shape[-1]
27
+ emb = torch.cat(
28
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
29
+ dim=-3,
30
+ )
31
+ return emb.unsqueeze(1)
32
+
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+ class Timesteps(nn.Module):
38
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1,time_theta=10000):
39
+ super().__init__()
40
+ self.num_channels = num_channels
41
+ self.flip_sin_to_cos = flip_sin_to_cos
42
+ self.downscale_freq_shift = downscale_freq_shift
43
+ self.scale = scale
44
+ self.time_theta=time_theta
45
+
46
+ def forward(self, timesteps):
47
+ t_emb = get_timestep_embedding(
48
+ timesteps,
49
+ self.num_channels,
50
+ flip_sin_to_cos=self.flip_sin_to_cos,
51
+ downscale_freq_shift=self.downscale_freq_shift,
52
+ scale=self.scale,
53
+ max_period=self.time_theta
54
+ )
55
+ return t_emb
56
+
57
+ class TimestepProjEmbeddings(nn.Module):
58
+ def __init__(self, embedding_dim, time_theta):
59
+ super().__init__()
60
+
61
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0,time_theta=time_theta)
62
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
63
+
64
+ def forward(self, timestep, dtype):
65
+ timesteps_proj = self.time_proj(timestep)
66
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
67
+ return timesteps_emb
68
+
69
+ """
70
+ Based on Flux's transformer (FluxTransformer2DModel) with several changes:
71
+ - no pooled embeddings
72
+ - We use zero padding for prompts
73
+ - No guidance embedding since this is not a distilled version
74
+ """
75
+ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
76
+ """
77
+ The Transformer model introduced in Flux.
78
+
79
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
80
+
81
+ Parameters:
82
+ patch_size (`int`): Patch size to turn the input data into small patches.
83
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
84
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
85
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
86
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
87
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
88
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
89
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
90
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
91
+ """
92
+
93
+ _supports_gradient_checkpointing = True
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ patch_size: int = 1,
99
+ in_channels: int = 64,
100
+ num_layers: int = 19,
101
+ num_single_layers: int = 38,
102
+ attention_head_dim: int = 128,
103
+ num_attention_heads: int = 24,
104
+ joint_attention_dim: int = 4096,
105
+ pooled_projection_dim: int = None,
106
+ guidance_embeds: bool = False,
107
+ axes_dims_rope: List[int] = [16, 56, 56],
108
+ rope_theta = 10000,
109
+ time_theta = 10000
110
+ ):
111
+ super().__init__()
112
+ self.out_channels = in_channels
113
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
114
+
115
+ self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
116
+
117
+
118
+ self.time_embed = TimestepProjEmbeddings(
119
+ embedding_dim=self.inner_dim,time_theta=time_theta
120
+ )
121
+
122
+ # if pooled_projection_dim:
123
+ # self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
124
+
125
+ if guidance_embeds:
126
+ self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
127
+
128
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
129
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
130
+
131
+ self.transformer_blocks = nn.ModuleList(
132
+ [
133
+ FluxTransformerBlock(
134
+ dim=self.inner_dim,
135
+ num_attention_heads=self.config.num_attention_heads,
136
+ attention_head_dim=self.config.attention_head_dim,
137
+ )
138
+ for i in range(self.config.num_layers)
139
+ ]
140
+ )
141
+
142
+ self.single_transformer_blocks = nn.ModuleList(
143
+ [
144
+ FluxSingleTransformerBlock(
145
+ dim=self.inner_dim,
146
+ num_attention_heads=self.config.num_attention_heads,
147
+ attention_head_dim=self.config.attention_head_dim,
148
+ )
149
+ for i in range(self.config.num_single_layers)
150
+ ]
151
+ )
152
+
153
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
154
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
155
+
156
+ self.gradient_checkpointing = False
157
+
158
+ def _set_gradient_checkpointing(self, module, value=False):
159
+ if hasattr(module, "gradient_checkpointing"):
160
+ module.gradient_checkpointing = value
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ encoder_hidden_states: torch.Tensor = None,
166
+ pooled_projections: torch.Tensor = None,
167
+ timestep: torch.LongTensor = None,
168
+ img_ids: torch.Tensor = None,
169
+ txt_ids: torch.Tensor = None,
170
+ guidance: torch.Tensor = None,
171
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
172
+ return_dict: bool = True,
173
+ controlnet_block_samples = None,
174
+ controlnet_single_block_samples=None,
175
+
176
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
177
+ """
178
+ The [`FluxTransformer2DModel`] forward method.
179
+
180
+ Args:
181
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
182
+ Input `hidden_states`.
183
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
184
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
185
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
186
+ from the embeddings of input conditions.
187
+ timestep ( `torch.LongTensor`):
188
+ Used to indicate denoising step.
189
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
190
+ A list of tensors that if specified are added to the residuals of transformer blocks.
191
+ joint_attention_kwargs (`dict`, *optional*):
192
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
193
+ `self.processor` in
194
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
195
+ return_dict (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
197
+ tuple.
198
+
199
+ Returns:
200
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
201
+ `tuple` where the first element is the sample tensor.
202
+ """
203
+ if joint_attention_kwargs is not None:
204
+ joint_attention_kwargs = joint_attention_kwargs.copy()
205
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
206
+ else:
207
+ lora_scale = 1.0
208
+
209
+ if USE_PEFT_BACKEND:
210
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
211
+ scale_lora_layers(self, lora_scale)
212
+ else:
213
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
214
+ logger.warning(
215
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
216
+ )
217
+ hidden_states = self.x_embedder(hidden_states)
218
+
219
+ timestep = timestep.to(hidden_states.dtype)
220
+ if guidance is not None:
221
+ guidance = guidance.to(hidden_states.dtype)
222
+ else:
223
+ guidance = None
224
+
225
+ # temb = (
226
+ # self.time_text_embed(timestep, pooled_projections)
227
+ # if guidance is None
228
+ # else self.time_text_embed(timestep, guidance, pooled_projections)
229
+ # )
230
+
231
+ temb = self.time_embed(timestep,dtype=hidden_states.dtype)
232
+
233
+ # if pooled_projections:
234
+ # temb+=self.pooled_text_embed(pooled_projections)
235
+
236
+ if guidance is not None:
237
+ temb+=self.guidance_embed(guidance,dtype=hidden_states.dtype)
238
+
239
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
240
+
241
+ if len(txt_ids.shape)==2:
242
+ ids = torch.cat((txt_ids, img_ids), dim=0)
243
+ else:
244
+ ids = torch.cat((txt_ids, img_ids), dim=1)
245
+ image_rotary_emb = self.pos_embed(ids)
246
+
247
+ for index_block, block in enumerate(self.transformer_blocks):
248
+ if self.training and self.gradient_checkpointing:
249
+
250
+ def create_custom_forward(module, return_dict=None):
251
+ def custom_forward(*inputs):
252
+ if return_dict is not None:
253
+ return module(*inputs, return_dict=return_dict)
254
+ else:
255
+ return module(*inputs)
256
+
257
+ return custom_forward
258
+
259
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
260
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
261
+ create_custom_forward(block),
262
+ hidden_states,
263
+ encoder_hidden_states,
264
+ temb,
265
+ image_rotary_emb,
266
+ **ckpt_kwargs,
267
+ )
268
+
269
+ else:
270
+ encoder_hidden_states, hidden_states = block(
271
+ hidden_states=hidden_states,
272
+ encoder_hidden_states=encoder_hidden_states,
273
+ temb=temb,
274
+ image_rotary_emb=image_rotary_emb,
275
+ )
276
+
277
+ # controlnet residual
278
+ if controlnet_block_samples is not None:
279
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
280
+ interval_control = int(np.ceil(interval_control))
281
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
282
+
283
+
284
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
285
+
286
+ for index_block, block in enumerate(self.single_transformer_blocks):
287
+ if self.training and self.gradient_checkpointing:
288
+
289
+ def create_custom_forward(module, return_dict=None):
290
+ def custom_forward(*inputs):
291
+ if return_dict is not None:
292
+ return module(*inputs, return_dict=return_dict)
293
+ else:
294
+ return module(*inputs)
295
+
296
+ return custom_forward
297
+
298
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
299
+ hidden_states = torch.utils.checkpoint.checkpoint(
300
+ create_custom_forward(block),
301
+ hidden_states,
302
+ temb,
303
+ image_rotary_emb,
304
+ **ckpt_kwargs,
305
+ )
306
+
307
+ else:
308
+ hidden_states = block(
309
+ hidden_states=hidden_states,
310
+ temb=temb,
311
+ image_rotary_emb=image_rotary_emb,
312
+ )
313
+
314
+ # controlnet residual
315
+ if controlnet_single_block_samples is not None:
316
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
317
+ interval_control = int(np.ceil(interval_control))
318
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
319
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
320
+ + controlnet_single_block_samples[index_block // interval_control]
321
+ )
322
+
323
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
324
+
325
+ hidden_states = self.norm_out(hidden_states, temb)
326
+ output = self.proj_out(hidden_states)
327
+
328
+ if USE_PEFT_BACKEND:
329
+ # remove `lora_scale` from each PEFT layer
330
+ unscale_lora_layers(self, lora_scale)
331
+
332
+ if not return_dict:
333
+ return (output,)
334
+
335
+ return Transformer2DModelOutput(sample=output)