Diffusers
TalHach61 committed · verified
Commit bbf18c3 · 1 Parent(s): dd99fd9

Upload 2 files

Files changed (2)
  1. controlnet_bria.py +649 -0
  2. pipeline_bria_controlnet.py +532 -0
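These two files add a Flux-style ControlNet for the Bria transformer (`BriaControlNetModel`, plus a `BriaMultiControlNetModel` wrapper) and a matching `BriaControlNetPipeline`. For orientation, below is a minimal usage sketch; it is not part of the commit, the checkpoint ids and the image URL are placeholders, and it assumes the repository's sibling modules (`transformer_bria.py`, `pipeline_bria.py`, `bria_utils.py`) are importable, e.g. as the model repository's custom code.

# Rough usage sketch (not part of this commit). Checkpoint ids and the image URL are
# placeholders, and the imports assume the repository's modules are importable.
import torch
from diffusers.utils import load_image
from controlnet_bria import BriaControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

controlnet = BriaControlNetModel.from_pretrained(
    "briaai/example-controlnet",  # placeholder id
    torch_dtype=torch.bfloat16,
)
pipe = BriaControlNetPipeline.from_pretrained(
    "briaai/example-base-model",  # placeholder id
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
).to("cuda")

control_image = load_image("https://example.com/pose.png")  # placeholder conditioning image
image = pipe(
    prompt="a portrait photo of an astronaut",
    control_image=control_image,
    control_mode=None,                  # an int mode index is only needed for a ControlNet-Union checkpoint
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
image.save("output.png")
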
controlnet_bria.py ADDED
@@ -0,0 +1,649 @@
1
+ # type: ignore
2
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from transformer_bria import TimestepProjEmbeddings
23
+ from diffusers.models.controlnet import zero_module, BaseOutput
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import PeftAdapterMixin
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
29
+
30
+ # from transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
31
+ from diffusers.models.transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
32
+
33
+ from diffusers.models.attention_processor import AttentionProcessor
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class BriaControlNetOutput(BaseOutput):
40
+ controlnet_block_samples: Tuple[torch.Tensor]
41
+ controlnet_single_block_samples: Tuple[torch.Tensor]
42
+
43
+
44
+ class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
45
+ _supports_gradient_checkpointing = True
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ patch_size: int = 1,
51
+ in_channels: int = 64,
52
+ num_layers: int = 19,
53
+ num_single_layers: int = 38,
54
+ attention_head_dim: int = 128,
55
+ num_attention_heads: int = 24,
56
+ joint_attention_dim: int = 4096,
57
+ pooled_projection_dim: int = 768,
58
+ guidance_embeds: bool = False,
59
+ axes_dims_rope: List[int] = [16, 56, 56],
60
+ num_mode: int = None,
61
+ rope_theta: int = 10000,
62
+ time_theta: int = 10000,
63
+ ):
64
+ super().__init__()
65
+ self.out_channels = in_channels
66
+ self.inner_dim = num_attention_heads * attention_head_dim
67
+
68
+ # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
69
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=rope_theta, axes_dim=axes_dims_rope)
70
+
71
+ # text_time_guidance_cls = (
72
+ # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
73
+ # )
74
+ # self.time_text_embed = text_time_guidance_cls(
75
+ # embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
76
+ # )
77
+ self.time_embed = TimestepProjEmbeddings(
78
+ embedding_dim=self.inner_dim,time_theta=time_theta
79
+ )
80
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
81
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
82
+
83
+ self.transformer_blocks = nn.ModuleList(
84
+ [
85
+ FluxTransformerBlock(
86
+ dim=self.inner_dim,
87
+ num_attention_heads=num_attention_heads,
88
+ attention_head_dim=attention_head_dim,
89
+ )
90
+ for i in range(num_layers)
91
+ ]
92
+ )
93
+
94
+ self.single_transformer_blocks = nn.ModuleList(
95
+ [
96
+ FluxSingleTransformerBlock(
97
+ dim=self.inner_dim,
98
+ num_attention_heads=num_attention_heads,
99
+ attention_head_dim=attention_head_dim,
100
+ )
101
+ for i in range(num_single_layers)
102
+ ]
103
+ )
104
+
105
+ # controlnet_blocks
106
+ self.controlnet_blocks = nn.ModuleList([])
107
+ for _ in range(len(self.transformer_blocks)):
108
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
109
+
110
+ self.controlnet_single_blocks = nn.ModuleList([])
111
+ for _ in range(len(self.single_transformer_blocks)):
112
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
113
+
114
+ self.union = num_mode is not None and num_mode > 0
115
+ if self.union:
116
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
117
+
118
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
119
+
120
+ self.gradient_checkpointing = False
121
+
122
+ @property
123
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
124
+ def attn_processors(self):
125
+ r"""
126
+ Returns:
127
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
128
+ indexed by its weight name.
129
+ """
130
+ # set recursively
131
+ processors = {}
132
+
133
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
134
+ if hasattr(module, "get_processor"):
135
+ processors[f"{name}.processor"] = module.get_processor()
136
+
137
+ for sub_name, child in module.named_children():
138
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
139
+
140
+ return processors
141
+
142
+ for name, module in self.named_children():
143
+ fn_recursive_add_processors(name, module, processors)
144
+
145
+ return processors
146
+
147
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
148
+ def set_attn_processor(self, processor):
149
+ r"""
150
+ Sets the attention processor to use to compute attention.
151
+
152
+ Parameters:
153
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
154
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
155
+ for **all** `Attention` layers.
156
+
157
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
158
+ processor. This is strongly recommended when setting trainable attention processors.
159
+
160
+ """
161
+ count = len(self.attn_processors.keys())
162
+
163
+ if isinstance(processor, dict) and len(processor) != count:
164
+ raise ValueError(
165
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
166
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
167
+ )
168
+
169
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
170
+ if hasattr(module, "set_processor"):
171
+ if not isinstance(processor, dict):
172
+ module.set_processor(processor)
173
+ else:
174
+ module.set_processor(processor.pop(f"{name}.processor"))
175
+
176
+ for sub_name, child in module.named_children():
177
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
178
+
179
+ for name, module in self.named_children():
180
+ fn_recursive_attn_processor(name, module, processor)
181
+
182
+ def _set_gradient_checkpointing(self, module, value=False):
183
+ if hasattr(module, "gradient_checkpointing"):
184
+ module.gradient_checkpointing = value
185
+
186
+ @classmethod
187
+ def from_transformer(
188
+ cls,
189
+ transformer,
190
+ num_layers: int = 4,
191
+ num_single_layers: int = 10,
192
+ attention_head_dim: int = 128,
193
+ num_attention_heads: int = 24,
194
+ load_weights_from_transformer=True,
195
+ ):
196
+ config = dict(transformer.config) # copy into a plain dict; the FrozenDict returned by `.config` cannot be mutated
197
+ config["num_layers"] = num_layers
198
+ config["num_single_layers"] = num_single_layers
199
+ config["attention_head_dim"] = attention_head_dim
200
+ config["num_attention_heads"] = num_attention_heads
201
+
202
+ controlnet = cls.from_config(config) # `from_config` ignores the private `_`-prefixed config keys that `cls(**config)` would reject
203
+
204
+ if load_weights_from_transformer:
205
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
206
+ controlnet.time_embed.load_state_dict(transformer.time_embed.state_dict()) # this model defines `time_embed`; assumes the base transformer uses the same attribute name
207
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
208
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
209
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
210
+ controlnet.single_transformer_blocks.load_state_dict(
211
+ transformer.single_transformer_blocks.state_dict(), strict=False
212
+ )
213
+
214
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
215
+
216
+ return controlnet
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.Tensor,
221
+ controlnet_cond: torch.Tensor,
222
+ controlnet_mode: torch.Tensor = None,
223
+ conditioning_scale: float = 1.0,
224
+ encoder_hidden_states: torch.Tensor = None,
225
+ pooled_projections: torch.Tensor = None,
226
+ timestep: torch.LongTensor = None,
227
+ img_ids: torch.Tensor = None,
228
+ txt_ids: torch.Tensor = None,
229
+ guidance: torch.Tensor = None,
230
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
231
+ return_dict: bool = True,
232
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
233
+ """
234
+ The [`BriaControlNetModel`] forward method.
235
+
236
+ Args:
237
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
238
+ Input `hidden_states`.
239
+ controlnet_cond (`torch.Tensor`):
240
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
241
+ controlnet_mode (`torch.Tensor`):
242
+ The mode tensor of shape `(batch_size, 1)`.
243
+ conditioning_scale (`float`, defaults to `1.0`):
244
+ The scale factor for ControlNet outputs.
245
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
246
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
247
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
248
+ from the embeddings of input conditions.
249
+ timestep ( `torch.LongTensor`):
250
+ Used to indicate denoising step.
251
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
252
+ A list of tensors that if specified are added to the residuals of transformer blocks.
253
+ joint_attention_kwargs (`dict`, *optional*):
254
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
255
+ `self.processor` in
256
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
257
+ return_dict (`bool`, *optional*, defaults to `True`):
258
+ Whether or not to return a [`BriaControlNetOutput`] instead of a plain
259
+ tuple.
260
+
261
+ Returns:
262
+ If `return_dict` is True, a [`BriaControlNetOutput`] is returned, otherwise a
263
+ `tuple` where the first element is the sample tensor.
264
+ """
265
+ if guidance is not None:
+ logger.warning("guidance is not supported in BriaControlNetModel")
+ if pooled_projections is not None:
+ logger.warning("pooled_projections is not supported in BriaControlNetModel")
269
+ if joint_attention_kwargs is not None:
270
+ joint_attention_kwargs = joint_attention_kwargs.copy()
271
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
272
+ else:
273
+ lora_scale = 1.0
274
+
275
+ if USE_PEFT_BACKEND:
276
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
277
+ scale_lora_layers(self, lora_scale)
278
+ else:
279
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
280
+ logger.warning(
281
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
282
+ )
283
+ hidden_states = self.x_embedder(hidden_states)
284
+
285
+ # add
286
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
287
+
288
+ timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
289
+ if guidance is not None:
290
+ guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
291
+ else:
292
+ guidance = None
293
+ # temb = (
294
+ # self.time_text_embed(timestep, pooled_projections)
295
+ # if guidance is None
296
+ # else self.time_text_embed(timestep, guidance, pooled_projections)
297
+ # )
298
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
299
+
300
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
301
+
302
+ if self.union:
303
+ # union mode
304
+ if controlnet_mode is None:
305
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
306
+ # union mode emb
307
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
308
+ if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:
309
+ controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, 2048)
310
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
311
+ txt_ids = torch.cat((txt_ids[:, 0:1, :], txt_ids), dim=1)
312
+
313
+ # if txt_ids.ndim == 3:
314
+ # logger.warning(
315
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
316
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
317
+ # )
318
+ # txt_ids = txt_ids[0]
319
+ # if img_ids.ndim == 3:
320
+ # logger.warning(
321
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
322
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
323
+ # )
324
+ # img_ids = img_ids[0]
325
+
326
+ # ids = torch.cat((txt_ids, img_ids), dim=0)
327
+ ids = torch.cat((txt_ids, img_ids), dim=1)
328
+ image_rotary_emb = self.pos_embed(ids)
329
+
330
+ block_samples = ()
331
+ for index_block, block in enumerate(self.transformer_blocks):
332
+ if self.training and self.gradient_checkpointing:
333
+
334
+ def create_custom_forward(module, return_dict=None):
335
+ def custom_forward(*inputs):
336
+ if return_dict is not None:
337
+ return module(*inputs, return_dict=return_dict)
338
+ else:
339
+ return module(*inputs)
340
+
341
+ return custom_forward
342
+
343
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(block),
346
+ hidden_states,
347
+ encoder_hidden_states,
348
+ temb,
349
+ image_rotary_emb,
350
+ **ckpt_kwargs,
351
+ )
352
+
353
+ else:
354
+ encoder_hidden_states, hidden_states = block(
355
+ hidden_states=hidden_states,
356
+ encoder_hidden_states=encoder_hidden_states,
357
+ temb=temb,
358
+ image_rotary_emb=image_rotary_emb,
359
+ )
360
+ block_samples = block_samples + (hidden_states,)
361
+
362
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
363
+
364
+ single_block_samples = ()
365
+ for index_block, block in enumerate(self.single_transformer_blocks):
366
+ if self.training and self.gradient_checkpointing:
367
+
368
+ def create_custom_forward(module, return_dict=None):
369
+ def custom_forward(*inputs):
370
+ if return_dict is not None:
371
+ return module(*inputs, return_dict=return_dict)
372
+ else:
373
+ return module(*inputs)
374
+
375
+ return custom_forward
376
+
377
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
378
+ hidden_states = torch.utils.checkpoint.checkpoint(
379
+ create_custom_forward(block),
380
+ hidden_states,
381
+ temb,
382
+ image_rotary_emb,
383
+ **ckpt_kwargs,
384
+ )
385
+
386
+ else:
387
+ hidden_states = block(
388
+ hidden_states=hidden_states,
389
+ temb=temb,
390
+ image_rotary_emb=image_rotary_emb,
391
+ )
392
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
393
+
394
+ # controlnet block
395
+ controlnet_block_samples = ()
396
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
397
+ block_sample = controlnet_block(block_sample)
398
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
399
+
400
+ controlnet_single_block_samples = ()
401
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
402
+ single_block_sample = controlnet_block(single_block_sample)
403
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
404
+
405
+ # scaling
406
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
407
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
408
+
409
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
410
+ controlnet_single_block_samples = (
411
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
412
+ )
413
+
414
+ if USE_PEFT_BACKEND:
415
+ # remove `lora_scale` from each PEFT layer
416
+ unscale_lora_layers(self, lora_scale)
417
+
418
+ if not return_dict:
419
+ return (controlnet_block_samples, controlnet_single_block_samples)
420
+
421
+ return BriaControlNetOutput(
422
+ controlnet_block_samples=controlnet_block_samples,
423
+ controlnet_single_block_samples=controlnet_single_block_samples,
424
+ )
425
+
426
+
427
+ class BriaMultiControlNetModel(ModelMixin):
428
+ r"""
429
+ `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
430
+
431
+ This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
432
+ compatible with `BriaControlNetModel`.
433
+
434
+ Args:
435
+ controlnets (`List[BriaControlNetModel]`):
436
+ Provides additional conditioning to the transformer during the denoising process. You must pass multiple
+ `BriaControlNetModel` instances as a list.
438
+ """
439
+
440
+ def __init__(self, controlnets):
441
+ super().__init__()
442
+ self.nets = nn.ModuleList(controlnets)
443
+
444
+ def forward(
445
+ self,
446
+ hidden_states: torch.FloatTensor,
447
+ controlnet_cond: List[torch.tensor],
448
+ controlnet_mode: List[torch.tensor],
449
+ conditioning_scale: List[float],
450
+ encoder_hidden_states: torch.Tensor = None,
451
+ pooled_projections: torch.Tensor = None,
452
+ timestep: torch.LongTensor = None,
453
+ img_ids: torch.Tensor = None,
454
+ txt_ids: torch.Tensor = None,
455
+ guidance: torch.Tensor = None,
456
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
457
+ return_dict: bool = True,
458
+ ) -> Union[BriaControlNetOutput, Tuple]:
459
+ # ControlNet-Union with multiple conditions
460
+ # only load one ControlNet to save memory
461
+ if len(self.nets) == 1 and self.nets[0].union:
462
+ controlnet = self.nets[0]
463
+
464
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
465
+ block_samples, single_block_samples = controlnet(
466
+ hidden_states=hidden_states,
467
+ controlnet_cond=image,
468
+ controlnet_mode=mode[:, None],
469
+ conditioning_scale=scale,
470
+ timestep=timestep,
471
+ guidance=guidance,
472
+ pooled_projections=pooled_projections,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ txt_ids=txt_ids,
475
+ img_ids=img_ids,
476
+ joint_attention_kwargs=joint_attention_kwargs,
477
+ return_dict=return_dict,
478
+ )
479
+
480
+ # merge samples
481
+ if i == 0:
482
+ control_block_samples = block_samples
483
+ control_single_block_samples = single_block_samples
484
+ else:
485
+ control_block_samples = [
486
+ control_block_sample + block_sample
487
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
488
+ ]
489
+
490
+ control_single_block_samples = [
491
+ control_single_block_sample + block_sample
492
+ for control_single_block_sample, block_sample in zip(
493
+ control_single_block_samples, single_block_samples
494
+ )
495
+ ]
496
+
497
+ # Regular Multi-ControlNets
498
+ # load all ControlNets into memory
499
+ else:
500
+ for i, (image, mode, scale, controlnet) in enumerate(
501
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
502
+ ):
503
+ block_samples, single_block_samples = controlnet(
504
+ hidden_states=hidden_states,
505
+ controlnet_cond=image,
506
+ controlnet_mode=mode[:, None],
507
+ conditioning_scale=scale,
508
+ timestep=timestep,
509
+ guidance=guidance,
510
+ pooled_projections=pooled_projections,
511
+ encoder_hidden_states=encoder_hidden_states,
512
+ txt_ids=txt_ids,
513
+ img_ids=img_ids,
514
+ joint_attention_kwargs=joint_attention_kwargs,
515
+ return_dict=return_dict,
516
+ )
517
+
518
+ # merge samples
519
+ if i == 0:
520
+ control_block_samples = block_samples
521
+ control_single_block_samples = single_block_samples
522
+ else:
523
+ if block_samples is not None and control_block_samples is not None:
524
+ control_block_samples = [
525
+ control_block_sample + block_sample
526
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
527
+ ]
528
+ if single_block_samples is not None and control_single_block_samples is not None:
529
+ control_single_block_samples = [
530
+ control_single_block_sample + block_sample
531
+ for control_single_block_sample, block_sample in zip(
532
+ control_single_block_samples, single_block_samples
533
+ )
534
+ ]
535
+
536
+ return control_block_samples, control_single_block_samples
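
`BriaControlNetModel.from_transformer` is the training-time entry point: it copies the positional, timestep and text embedders plus the first `num_layers` / `num_single_layers` blocks from a base Bria transformer, and the ControlNet-specific projections are zero-initialized so the branch contributes nothing at the start of training. A minimal sketch of how it might be called follows; the checkpoint id and subfolder are placeholders, and `BriaTransformer2DModel` is assumed to come from this repository's `transformer_bria.py`.

# Sketch only: initialize a ControlNet from a pretrained base transformer for training.
from transformer_bria import BriaTransformer2DModel   # module from this repository
from controlnet_bria import BriaControlNetModel

base = BriaTransformer2DModel.from_pretrained(
    "briaai/example-base-model", subfolder="transformer"  # placeholder id/subfolder
)
controlnet = BriaControlNetModel.from_transformer(
    base,
    num_layers=4,           # double-stream blocks copied from the base model
    num_single_layers=10,   # single-stream blocks copied from the base model
    attention_head_dim=128,
    num_attention_heads=24,
    load_weights_from_transformer=True,
)
print(f"{sum(p.numel() for p in controlnet.parameters()) / 1e6:.1f}M parameters")
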
pipeline_bria_controlnet.py ADDED
@@ -0,0 +1,532 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+ import torch
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ T5TokenizerFast,
20
+ )
21
+ from diffusers.image_processor import PipelineImageInput
22
+
23
+ from diffusers import AutoencoderKL # Waiting for diffusers update
24
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
25
+ from diffusers.schedulers import KarrasDiffusionSchedulers
26
+ from diffusers.utils import logging
27
+ from diffusers.pipelines.flux.pipeline_output import BriaPipelineOutput
28
+ from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps
29
+ from .controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
30
+
31
+ from .pipeline_bria import BriaPipeline
32
+ from transformer_bria import BriaTransformer2DModel
33
+ from bria_utils import get_original_sigmas
34
+
35
+ XLA_AVAILABLE = False
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class BriaControlNetPipeline(BriaPipeline):
42
+ r"""
43
+ Args:
44
+ transformer ([`BriaTransformer2DModel`]):
45
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
46
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
47
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
48
+ vae ([`AutoencoderKL`]):
49
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
50
+ text_encoder ([`T5EncoderModel`]):
51
+ Frozen text-encoder. Stable Diffusion 3 uses
52
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
53
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
54
+ tokenizer (`T5TokenizerFast`):
55
+ Tokenizer of class
56
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
57
+ """
58
+
59
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
60
+ _optional_components = []
61
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
62
+
63
+ def __init__( # EYAL - removed clip text encoder + tokenizer
64
+ self,
65
+ transformer: BriaTransformer2DModel,
66
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
67
+ vae: AutoencoderKL,
68
+ text_encoder: T5EncoderModel,
69
+ tokenizer: T5TokenizerFast,
70
+ controlnet: BriaControlNetModel,
71
+ ):
72
+ super().__init__(
73
+ transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
74
+ )
75
+ self.register_modules(controlnet=controlnet)
76
+
77
+ def prepare_image(
78
+ self,
79
+ image,
80
+ width,
81
+ height,
82
+ batch_size,
83
+ num_images_per_prompt,
84
+ device,
85
+ dtype,
86
+ do_classifier_free_guidance=False,
87
+ guess_mode=False,
88
+ ):
89
+ if isinstance(image, torch.Tensor):
90
+ pass
91
+ else:
92
+ image = self.image_processor.preprocess(image, height=height, width=width)
93
+
94
+ image_batch_size = image.shape[0]
95
+
96
+ if image_batch_size == 1:
97
+ repeat_by = batch_size
98
+ else:
99
+ # image batch size is the same as prompt batch size
100
+ repeat_by = num_images_per_prompt
101
+
102
+ image = image.repeat_interleave(repeat_by, dim=0)
103
+
104
+ image = image.to(device=device, dtype=dtype)
105
+
106
+ if do_classifier_free_guidance and not guess_mode:
107
+ image = torch.cat([image] * 2)
108
+
109
+ return image
110
+
111
+ def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
112
+ num_channels_latents = self.transformer.config.in_channels // 4
113
+ control_image = self.prepare_image(
114
+ image=control_image,
115
+ width=width,
116
+ height=height,
117
+ batch_size=batch_size * num_images_per_prompt,
118
+ num_images_per_prompt=num_images_per_prompt,
119
+ device=device,
120
+ dtype=self.vae.dtype,
121
+ )
122
+ height, width = control_image.shape[-2:]
123
+
124
+ # vae encode
125
+ control_image = self.vae.encode(control_image).latent_dist.sample()
126
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
127
+
128
+ # pack
129
+ height_control_image, width_control_image = control_image.shape[2:]
130
+ control_image = self._pack_latents(
131
+ control_image,
132
+ batch_size * num_images_per_prompt,
133
+ num_channels_latents,
134
+ height_control_image,
135
+ width_control_image,
136
+ )
137
+
138
+ # Here we ensure that `control_mode` has the same length as the control_image.
139
+ if control_mode is not None:
140
+ if not isinstance(control_mode, int):
141
+ raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
142
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
143
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
144
+
145
+ return control_image, control_mode
146
+
147
+ def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
148
+ num_channels_latents = self.transformer.config.in_channels // 4
149
+ control_images = []
150
+ for i, control_image_ in enumerate(control_image):
151
+ control_image_ = self.prepare_image(
152
+ image=control_image_,
153
+ width=width,
154
+ height=height,
155
+ batch_size=batch_size * num_images_per_prompt,
156
+ num_images_per_prompt=num_images_per_prompt,
157
+ device=device,
158
+ dtype=self.vae.dtype,
159
+ )
160
+ height, width = control_image_.shape[-2:]
161
+
162
+ # vae encode
163
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
164
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
165
+
166
+ # pack
167
+ height_control_image, width_control_image = control_image_.shape[2:]
168
+ control_image_ = self._pack_latents(
169
+ control_image_,
170
+ batch_size * num_images_per_prompt,
171
+ num_channels_latents,
172
+ height_control_image,
173
+ width_control_image,
174
+ )
175
+ control_images.append(control_image_)
176
+
177
+ control_image = control_images
178
+
179
+ # Here we ensure that `control_mode` has the same length as the control_image.
180
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
181
+ raise ValueError(
182
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
183
+ + " length as the number of controlnets (control images) specified"
184
+ )
185
+ if not isinstance(control_mode, list):
186
+ control_mode = [control_mode] * len(control_image)
187
+ # set control mode
188
+ control_modes = []
189
+ for cmode in control_mode:
190
+ if cmode is None:
191
+ cmode = -1
192
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
193
+ control_modes.append(control_mode)
194
+ control_mode = control_modes
195
+
196
+ return control_image, control_mode
197
+
198
+ def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
199
+ controlnet_keep = []
200
+ for i in range(len(timesteps)):
201
+ keeps = [
202
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
203
+ for s, e in zip(control_guidance_start, control_guidance_end)
204
+ ]
205
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
206
+ return controlnet_keep
207
+
208
+ def get_control_start_end(self, control_guidance_start, control_guidance_end):
209
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
210
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
211
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
212
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
213
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
214
+ mult = 1 # TODO - why is this 1?
215
+ control_guidance_start, control_guidance_end = (
216
+ mult * [control_guidance_start],
217
+ mult * [control_guidance_end],
218
+ )
219
+
220
+ return control_guidance_start, control_guidance_end
221
+
222
+ @torch.no_grad()
223
+ def __call__(
224
+ self,
225
+ prompt: Union[str, List[str]] = None,
226
+ height: Optional[int] = None,
227
+ width: Optional[int] = None,
228
+ num_inference_steps: int = 30,
229
+ timesteps: List[int] = None,
230
+ guidance_scale: float = 3.5,
231
+ control_guidance_start: Union[float, List[float]] = 0.0,
232
+ control_guidance_end: Union[float, List[float]] = 1.0,
233
+ control_image: Optional[PipelineImageInput] = None,
234
+ control_mode: Optional[Union[int, List[int]]] = None,
235
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ num_images_per_prompt: Optional[int] = 1,
238
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
239
+ latents: Optional[torch.FloatTensor] = None,
240
+ prompt_embeds: Optional[torch.FloatTensor] = None,
241
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
242
+ output_type: Optional[str] = "pil",
243
+ return_dict: bool = True,
244
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
245
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
246
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
247
+ max_sequence_length: int = 128,
248
+ ):
249
+ r"""
250
+ Function invoked when calling the pipeline for generation.
251
+
252
+ Args:
253
+ prompt (`str` or `List[str]`, *optional*):
254
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
256
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
257
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
258
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
259
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
260
+ num_inference_steps (`int`, *optional*, defaults to 30):
261
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
262
+ expense of slower inference.
263
+ timesteps (`List[int]`, *optional*):
264
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
265
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
266
+ passed will be used. Must be in descending order.
267
+ guidance_scale (`float`, *optional*, defaults to 3.5):
268
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
269
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
270
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
271
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
272
+ usually at the expense of lower image quality.
273
+ negative_prompt (`str` or `List[str]`, *optional*):
274
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
275
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
276
+ less than `1`).
277
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
278
+ The number of images to generate per prompt.
279
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
280
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
281
+ to make generation deterministic.
282
+ latents (`torch.FloatTensor`, *optional*):
283
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
284
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
285
+ tensor will be generated by sampling using the supplied random `generator`.
286
+ prompt_embeds (`torch.FloatTensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generate image. Choose between
295
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`BriaPipelineOutput`] instead
298
+ of a plain tuple.
299
+ joint_attention_kwargs (`dict`, *optional*):
300
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
301
+ `self.processor` in
302
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
303
+ callback_on_step_end (`Callable`, *optional*):
304
+ A function that calls at the end of each denoising steps during the inference. The function is called
305
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
306
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
307
+ `callback_on_step_end_tensor_inputs`.
308
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
309
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
310
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
311
+ `._callback_tensor_inputs` attribute of your pipeline class.
312
+ max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
313
+
314
+ Examples:
315
+
316
+ Returns:
317
+ [`BriaPipelineOutput`] or `tuple`:
+ [`BriaPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
320
+ """
321
+
322
+ height = height or self.default_sample_size * self.vae_scale_factor
323
+ width = width or self.default_sample_size * self.vae_scale_factor
324
+ control_guidance_start, control_guidance_end = self.get_control_start_end(
325
+ control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
326
+ )
327
+
328
+ # 1. Check inputs. Raise error if not correct
329
+ self.check_inputs(
330
+ prompt,
331
+ height,
332
+ width,
333
+ negative_prompt=negative_prompt,
334
+ prompt_embeds=prompt_embeds,
335
+ negative_prompt_embeds=negative_prompt_embeds,
336
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
337
+ max_sequence_length=max_sequence_length,
338
+ )
339
+
340
+ self._guidance_scale = guidance_scale
341
+ self._joint_attention_kwargs = joint_attention_kwargs
342
+ self._interrupt = False
343
+
344
+ # 2. Define call parameters
345
+ if prompt is not None and isinstance(prompt, str):
346
+ batch_size = 1
347
+ elif prompt is not None and isinstance(prompt, list):
348
+ batch_size = len(prompt)
349
+ else:
350
+ batch_size = prompt_embeds.shape[0]
351
+
352
+ device = self._execution_device
353
+
354
+ lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
355
+
356
+ (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
357
+ prompt=prompt,
358
+ negative_prompt=negative_prompt,
359
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
360
+ prompt_embeds=prompt_embeds,
361
+ negative_prompt_embeds=negative_prompt_embeds,
362
+ device=device,
363
+ num_images_per_prompt=num_images_per_prompt,
364
+ max_sequence_length=max_sequence_length,
365
+ lora_scale=lora_scale,
366
+ )
367
+
368
+ if self.do_classifier_free_guidance:
369
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
370
+
371
+ # 3. Prepare control image
372
+ if control_image is not None:
373
+ if isinstance(self.controlnet, BriaControlNetModel):
374
+ control_image, control_mode = self.prepare_control(
375
+ control_image=control_image,
376
+ width=width,
377
+ height=height,
378
+ batch_size=batch_size,
379
+ num_images_per_prompt=num_images_per_prompt,
380
+ device=device,
381
+ control_mode=control_mode,
382
+ )
383
+ elif isinstance(self.controlnet, BriaMultiControlNetModel):
384
+ control_image, control_mode = self.prepare_multi_control(
385
+ control_image=control_image,
386
+ width=width,
387
+ height=height,
388
+ batch_size=batch_size,
389
+ num_images_per_prompt=num_images_per_prompt,
390
+ device=device,
391
+ control_mode=control_mode,
392
+ )
393
+
394
+ # 4. Prepare timesteps
395
+ # Sample from training sigmas
396
+ sigmas = get_original_sigmas(
397
+ num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
398
+ )
399
+ timesteps, num_inference_steps = retrieve_timesteps(
400
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
401
+ )
402
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
403
+ self._num_timesteps = len(timesteps)
404
+
405
+ # 5. Prepare latent variables
406
+ num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we divide by 4
407
+ latents, latent_image_ids = self.prepare_latents(
408
+ batch_size=batch_size * num_images_per_prompt,
409
+ num_channels_latents=num_channels_latents,
410
+ height=height,
411
+ width=width,
412
+ dtype=prompt_embeds.dtype,
413
+ device=device,
414
+ generator=generator,
415
+ latents=latents,
416
+ )
417
+
418
+ # 6. Create tensor stating which controlnets to keep
419
+ if control_image is not None:
420
+ controlnet_keep = self.get_controlnet_keep(
421
+ timesteps=timesteps,
422
+ control_guidance_start=control_guidance_start,
423
+ control_guidance_end=control_guidance_end,
424
+ )
425
+
426
+ # EYAL - added the CFG loop
427
+ # 7. Denoising loop
428
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
429
+ for i, t in enumerate(timesteps):
430
+ if self.interrupt:
431
+ continue
432
+
433
+ # expand the latents if we are doing classifier free guidance
434
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
435
+ # if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
436
+ if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
437
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
438
+
439
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
440
+ timestep = t.expand(latent_model_input.shape[0])
441
+
442
+ # Handling ControlNet
443
+ if control_image is not None:
444
+ if isinstance(controlnet_keep[i], list):
445
+ if isinstance(controlnet_conditioning_scale, list):
446
+ cond_scale = controlnet_conditioning_scale
447
+ else:
448
+ # broadcast a scalar conditioning scale over the per-controlnet keep mask (zipping a float would raise a TypeError)
+ cond_scale = [controlnet_conditioning_scale * s for s in controlnet_keep[i]]
449
+ else:
450
+ controlnet_cond_scale = controlnet_conditioning_scale
451
+ if isinstance(controlnet_cond_scale, list):
452
+ controlnet_cond_scale = controlnet_cond_scale[0]
453
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
454
+
455
+ # controlnet
456
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
457
+ hidden_states=latents,
458
+ controlnet_cond=control_image,
459
+ controlnet_mode=control_mode,
460
+ conditioning_scale=cond_scale,
461
+ timestep=timestep,
462
+ # guidance=guidance,
463
+ # pooled_projections=pooled_prompt_embeds,
464
+ encoder_hidden_states=prompt_embeds,
465
+ txt_ids=text_ids,
466
+ img_ids=latent_image_ids,
467
+ joint_attention_kwargs=self.joint_attention_kwargs,
468
+ return_dict=False,
469
+ )
470
+ else:
471
+ controlnet_block_samples, controlnet_single_block_samples = None, None
472
+
473
+ # This predicts the velocity ("v") used by flow matching
474
+ noise_pred = self.transformer(
475
+ hidden_states=latent_model_input,
476
+ timestep=timestep,
477
+ encoder_hidden_states=prompt_embeds,
478
+ joint_attention_kwargs=self.joint_attention_kwargs,
479
+ return_dict=False,
480
+ txt_ids=text_ids,
481
+ img_ids=latent_image_ids,
482
+ controlnet_block_samples=controlnet_block_samples,
483
+ controlnet_single_block_samples=controlnet_single_block_samples,
484
+ )[0]
485
+
486
+ # perform guidance
487
+ if self.do_classifier_free_guidance:
488
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
489
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
490
+
491
+ # compute the previous noisy sample x_t -> x_t-1
492
+ latents_dtype = latents.dtype
493
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
494
+
495
+ if latents.dtype != latents_dtype:
496
+ if torch.backends.mps.is_available():
497
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
498
+ latents = latents.to(latents_dtype)
499
+
500
+ if callback_on_step_end is not None:
501
+ callback_kwargs = {}
502
+ for k in callback_on_step_end_tensor_inputs:
503
+ callback_kwargs[k] = locals()[k]
504
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
505
+
506
+ latents = callback_outputs.pop("latents", latents)
507
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
508
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
509
+
510
+ # call the callback, if provided
511
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
512
+ progress_bar.update()
513
+
514
+ if XLA_AVAILABLE:
515
+ xm.mark_step()
516
+
517
+ if output_type == "latent":
518
+ image = latents
519
+
520
+ else:
521
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
522
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
523
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
524
+ image = self.image_processor.postprocess(image, output_type=output_type)
525
+
526
+ # Offload all models
527
+ self.maybe_free_model_hooks()
528
+
529
+ if not return_dict:
530
+ return (image,)
531
+
532
+ return BriaPipelineOutput(images=image)
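
For multiple conditions, `control_image`, `control_mode`, and `controlnet_conditioning_scale` are passed as lists of equal length, and `controlnet` is either a single ControlNet-Union model or a `BriaMultiControlNetModel` wrapping several `BriaControlNetModel` instances. The sketch below is a hedged example of that call pattern, with placeholder checkpoint ids, local file names, and mode indices.

# Sketch: two conditions routed through one ControlNet-Union checkpoint (all ids/paths are placeholders).
import torch
from diffusers.utils import load_image
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

union_controlnet = BriaControlNetModel.from_pretrained(
    "briaai/example-controlnet-union", torch_dtype=torch.bfloat16  # placeholder id
)
pipe = BriaControlNetPipeline.from_pretrained(
    "briaai/example-base-model",  # placeholder id
    controlnet=BriaMultiControlNetModel([union_controlnet]),
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="an isometric render of a tiny city at night",
    control_image=[load_image("depth.png"), load_image("canny.png")],  # one image per condition
    control_mode=[0, 1],                        # mode index of each condition in the union checkpoint
    controlnet_conditioning_scale=[0.7, 0.5],   # one scale per condition
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
image.save("multi_control_output.png")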