AlanB commited on
Commit
05645fd
1 Parent(s): c446a24

Pipeline from https://github.com/huggingface/diffusers/pull/9268

Browse files
Files changed (1) hide show
  1. pipeline.py +982 -0
pipeline.py ADDED
@@ -0,0 +1,982 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.loaders import FluxLoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from diffusers.models.transformers import FluxTransformer2DModel
26
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils import (
30
+ USE_PEFT_BACKEND,
31
+ is_torch_xla_available,
32
+ logging,
33
+ replace_example_docstring,
34
+ scale_lora_layers,
35
+ unscale_lora_layers,
36
+ )
37
+ from diffusers.utils.torch_utils import randn_tensor
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers.utils import load_image
55
+ >>> from pipeline import FluxDifferentialImg2ImgPipeline
56
+
57
+ >>> image = load_image(
58
+ >>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
59
+ >>> )
60
+
61
+ >>> mask = load_image(
62
+ >>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
63
+ >>> )
64
+
65
+ >>> pipe = FluxDifferentialImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
66
+ >>> pipe.enable_model_cpu_offload()
67
+
68
+ >>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
69
+ >>> out = pipe(
70
+ >>> prompt=prompt,
71
+ >>> num_inference_steps=20,
72
+ >>> guidance_scale=7.5,
73
+ >>> image=image,
74
+ >>> mask_image=mask,
75
+ >>> strength=1.0,
76
+ >>> ).images[0]
77
+
78
+ >>> out.save("image.png")
79
+ ```
80
+ """
81
+
82
+
83
+ def calculate_shift(
84
+ image_seq_len,
85
+ base_seq_len: int = 256,
86
+ max_seq_len: int = 4096,
87
+ base_shift: float = 0.5,
88
+ max_shift: float = 1.16,
89
+ ):
90
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
91
+ b = base_shift - m * base_seq_len
92
+ mu = image_seq_len * m + b
93
+ return mu
94
+
95
+
96
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
97
+ def retrieve_latents(
98
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
99
+ ):
100
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
101
+ return encoder_output.latent_dist.sample(generator)
102
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
103
+ return encoder_output.latent_dist.mode()
104
+ elif hasattr(encoder_output, "latents"):
105
+ return encoder_output.latents
106
+ else:
107
+ raise AttributeError("Could not access latents of provided encoder_output")
108
+
109
+
110
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
111
+ def retrieve_timesteps(
112
+ scheduler,
113
+ num_inference_steps: Optional[int] = None,
114
+ device: Optional[Union[str, torch.device]] = None,
115
+ timesteps: Optional[List[int]] = None,
116
+ sigmas: Optional[List[float]] = None,
117
+ **kwargs,
118
+ ):
119
+ """
120
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
121
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
122
+
123
+ Args:
124
+ scheduler (`SchedulerMixin`):
125
+ The scheduler to get timesteps from.
126
+ num_inference_steps (`int`):
127
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
128
+ must be `None`.
129
+ device (`str` or `torch.device`, *optional*):
130
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
131
+ timesteps (`List[int]`, *optional*):
132
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
133
+ `num_inference_steps` and `sigmas` must be `None`.
134
+ sigmas (`List[float]`, *optional*):
135
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
136
+ `num_inference_steps` and `timesteps` must be `None`.
137
+
138
+ Returns:
139
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
140
+ second element is the number of inference steps.
141
+ """
142
+ if timesteps is not None and sigmas is not None:
143
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
144
+ if timesteps is not None:
145
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
146
+ if not accepts_timesteps:
147
+ raise ValueError(
148
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
149
+ f" timestep schedules. Please check whether you are using the correct scheduler."
150
+ )
151
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
152
+ timesteps = scheduler.timesteps
153
+ num_inference_steps = len(timesteps)
154
+ elif sigmas is not None:
155
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
156
+ if not accept_sigmas:
157
+ raise ValueError(
158
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
159
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
160
+ )
161
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
162
+ timesteps = scheduler.timesteps
163
+ num_inference_steps = len(timesteps)
164
+ else:
165
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
166
+ timesteps = scheduler.timesteps
167
+ return timesteps, num_inference_steps
168
+
169
+
170
+ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
171
+ r"""
172
+ Differential Image to Image pipeline for the Flux family of models.
173
+
174
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
175
+
176
+ Args:
177
+ transformer ([`FluxTransformer2DModel`]):
178
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
179
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
180
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
181
+ vae ([`AutoencoderKL`]):
182
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
183
+ text_encoder ([`CLIPTextModel`]):
184
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
185
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
186
+ text_encoder_2 ([`T5EncoderModel`]):
187
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
188
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
189
+ tokenizer (`CLIPTokenizer`):
190
+ Tokenizer of class
191
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
192
+ tokenizer_2 (`T5TokenizerFast`):
193
+ Second Tokenizer of class
194
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
195
+ """
196
+
197
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
198
+ _optional_components = []
199
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
200
+
201
+ def __init__(
202
+ self,
203
+ scheduler: FlowMatchEulerDiscreteScheduler,
204
+ vae: AutoencoderKL,
205
+ text_encoder: CLIPTextModel,
206
+ tokenizer: CLIPTokenizer,
207
+ text_encoder_2: T5EncoderModel,
208
+ tokenizer_2: T5TokenizerFast,
209
+ transformer: FluxTransformer2DModel,
210
+ ):
211
+ super().__init__()
212
+
213
+ self.register_modules(
214
+ vae=vae,
215
+ text_encoder=text_encoder,
216
+ text_encoder_2=text_encoder_2,
217
+ tokenizer=tokenizer,
218
+ tokenizer_2=tokenizer_2,
219
+ transformer=transformer,
220
+ scheduler=scheduler,
221
+ )
222
+ self.vae_scale_factor = (
223
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
224
+ )
225
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
226
+ self.mask_processor = VaeImageProcessor(
227
+ vae_scale_factor=self.vae_scale_factor,
228
+ vae_latent_channels=self.vae.config.latent_channels,
229
+ do_normalize=False,
230
+ do_binarize=False,
231
+ do_convert_grayscale=True,
232
+ )
233
+ self.tokenizer_max_length = (
234
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
235
+ )
236
+ self.default_sample_size = 64
237
+
238
+ def _get_t5_prompt_embeds(
239
+ self,
240
+ prompt: Union[str, List[str]] = None,
241
+ num_images_per_prompt: int = 1,
242
+ max_sequence_length: int = 512,
243
+ device: Optional[torch.device] = None,
244
+ dtype: Optional[torch.dtype] = None,
245
+ ):
246
+ device = device or self._execution_device
247
+ dtype = dtype or self.text_encoder.dtype
248
+
249
+ prompt = [prompt] if isinstance(prompt, str) else prompt
250
+ batch_size = len(prompt)
251
+
252
+ text_inputs = self.tokenizer_2(
253
+ prompt,
254
+ padding="max_length",
255
+ max_length=max_sequence_length,
256
+ truncation=True,
257
+ return_length=False,
258
+ return_overflowing_tokens=False,
259
+ return_tensors="pt",
260
+ )
261
+ text_input_ids = text_inputs.input_ids
262
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
263
+
264
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
266
+ logger.warning(
267
+ "The following part of your input was truncated because `max_sequence_length` is set to "
268
+ f" {max_sequence_length} tokens: {removed_text}"
269
+ )
270
+
271
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
272
+
273
+ dtype = self.text_encoder_2.dtype
274
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
275
+
276
+ _, seq_len, _ = prompt_embeds.shape
277
+
278
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
279
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
280
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
281
+
282
+ return prompt_embeds
283
+
284
+ def _get_clip_prompt_embeds(
285
+ self,
286
+ prompt: Union[str, List[str]],
287
+ num_images_per_prompt: int = 1,
288
+ device: Optional[torch.device] = None,
289
+ ):
290
+ device = device or self._execution_device
291
+
292
+ prompt = [prompt] if isinstance(prompt, str) else prompt
293
+ batch_size = len(prompt)
294
+
295
+ text_inputs = self.tokenizer(
296
+ prompt,
297
+ padding="max_length",
298
+ max_length=self.tokenizer_max_length,
299
+ truncation=True,
300
+ return_overflowing_tokens=False,
301
+ return_length=False,
302
+ return_tensors="pt",
303
+ )
304
+
305
+ text_input_ids = text_inputs.input_ids
306
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
307
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
308
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
309
+ logger.warning(
310
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
311
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
312
+ )
313
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
314
+
315
+ # Use pooled output of CLIPTextModel
316
+ prompt_embeds = prompt_embeds.pooler_output
317
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
318
+
319
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
320
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
321
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
322
+
323
+ return prompt_embeds
324
+
325
+ def encode_prompt(
326
+ self,
327
+ prompt: Union[str, List[str]],
328
+ prompt_2: Union[str, List[str]],
329
+ device: Optional[torch.device] = None,
330
+ num_images_per_prompt: int = 1,
331
+ prompt_embeds: Optional[torch.FloatTensor] = None,
332
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ max_sequence_length: int = 512,
334
+ lora_scale: Optional[float] = None,
335
+ ):
336
+ r"""
337
+
338
+ Args:
339
+ prompt (`str` or `List[str]`, *optional*):
340
+ prompt to be encoded
341
+ prompt_2 (`str` or `List[str]`, *optional*):
342
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
343
+ used in all text-encoders
344
+ device: (`torch.device`):
345
+ torch device
346
+ num_images_per_prompt (`int`):
347
+ number of images that should be generated per prompt
348
+ prompt_embeds (`torch.FloatTensor`, *optional*):
349
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
350
+ provided, text embeddings will be generated from `prompt` input argument.
351
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
352
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
353
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
354
+ lora_scale (`float`, *optional*):
355
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
356
+ """
357
+ device = device or self._execution_device
358
+
359
+ # set lora scale so that monkey patched LoRA
360
+ # function of text encoder can correctly access it
361
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
362
+ self._lora_scale = lora_scale
363
+
364
+ # dynamically adjust the LoRA scale
365
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
366
+ scale_lora_layers(self.text_encoder, lora_scale)
367
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
368
+ scale_lora_layers(self.text_encoder_2, lora_scale)
369
+
370
+ prompt = [prompt] if isinstance(prompt, str) else prompt
371
+ if prompt is not None:
372
+ batch_size = len(prompt)
373
+ else:
374
+ batch_size = prompt_embeds.shape[0]
375
+
376
+ if prompt_embeds is None:
377
+ prompt_2 = prompt_2 or prompt
378
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
379
+
380
+ # We only use the pooled prompt output from the CLIPTextModel
381
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
382
+ prompt=prompt,
383
+ device=device,
384
+ num_images_per_prompt=num_images_per_prompt,
385
+ )
386
+ prompt_embeds = self._get_t5_prompt_embeds(
387
+ prompt=prompt_2,
388
+ num_images_per_prompt=num_images_per_prompt,
389
+ max_sequence_length=max_sequence_length,
390
+ device=device,
391
+ )
392
+
393
+ if self.text_encoder is not None:
394
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
395
+ # Retrieve the original scale by scaling back the LoRA layers
396
+ unscale_lora_layers(self.text_encoder, lora_scale)
397
+
398
+ if self.text_encoder_2 is not None:
399
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
400
+ # Retrieve the original scale by scaling back the LoRA layers
401
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
402
+
403
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
404
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
405
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
406
+
407
+ return prompt_embeds, pooled_prompt_embeds, text_ids
408
+
409
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
410
+ if isinstance(generator, list):
411
+ image_latents = [
412
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
413
+ for i in range(image.shape[0])
414
+ ]
415
+ image_latents = torch.cat(image_latents, dim=0)
416
+ else:
417
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
418
+
419
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
420
+
421
+ return image_latents
422
+
423
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
424
+ def get_timesteps(self, timesteps, num_inference_steps, strength, device):
425
+ # get the original timestep using init_timestep
426
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
427
+
428
+ t_start = int(max(num_inference_steps - init_timestep, 0))
429
+ timesteps = timesteps[t_start * self.scheduler.order :]
430
+ if hasattr(self.scheduler, "set_begin_index"):
431
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
432
+
433
+ return timesteps, num_inference_steps - t_start
434
+
435
+ def check_inputs(
436
+ self,
437
+ prompt,
438
+ prompt_2,
439
+ strength,
440
+ height,
441
+ width,
442
+ prompt_embeds=None,
443
+ pooled_prompt_embeds=None,
444
+ callback_on_step_end_tensor_inputs=None,
445
+ max_sequence_length=None,
446
+ ):
447
+ if strength < 0 or strength > 1:
448
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
449
+
450
+ if height % 8 != 0 or width % 8 != 0:
451
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
452
+
453
+ if callback_on_step_end_tensor_inputs is not None and not all(
454
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
455
+ ):
456
+ raise ValueError(
457
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
458
+ )
459
+
460
+ if prompt is not None and prompt_embeds is not None:
461
+ raise ValueError(
462
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
463
+ " only forward one of the two."
464
+ )
465
+ elif prompt_2 is not None and prompt_embeds is not None:
466
+ raise ValueError(
467
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
468
+ " only forward one of the two."
469
+ )
470
+ elif prompt is None and prompt_embeds is None:
471
+ raise ValueError(
472
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
473
+ )
474
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
475
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
476
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
477
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
478
+
479
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
480
+ raise ValueError(
481
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
482
+ )
483
+
484
+ if max_sequence_length is not None and max_sequence_length > 512:
485
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
486
+
487
+ @staticmethod
488
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
489
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
490
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
491
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
492
+
493
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
494
+
495
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
496
+ latent_image_ids = latent_image_ids.reshape(
497
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
498
+ )
499
+
500
+ return latent_image_ids.to(device=device, dtype=dtype)
501
+
502
+ @staticmethod
503
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
504
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
505
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
506
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
507
+
508
+ return latents
509
+
510
+ @staticmethod
511
+ def _unpack_latents(latents, height, width, vae_scale_factor):
512
+ batch_size, num_patches, channels = latents.shape
513
+
514
+ height = height // vae_scale_factor
515
+ width = width // vae_scale_factor
516
+
517
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
518
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
519
+
520
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
521
+
522
+ return latents
523
+
524
+ def prepare_latents(
525
+ self,
526
+ batch_size,
527
+ num_channels_latents,
528
+ height,
529
+ width,
530
+ dtype,
531
+ device,
532
+ generator,
533
+ latents=None,
534
+ image=None,
535
+ timestep=None,
536
+ is_strength_max=None,
537
+ ):
538
+ if isinstance(generator, list) and len(generator) != batch_size:
539
+ raise ValueError(
540
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
541
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
542
+ )
543
+
544
+ if (image is None or timestep is None) and not is_strength_max:
545
+ raise ValueError(
546
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
547
+ "However, either the image or the noise timestep has not been provided."
548
+ )
549
+
550
+ height = 2 * (int(height) // self.vae_scale_factor)
551
+ width = 2 * (int(width) // self.vae_scale_factor)
552
+
553
+ shape = (batch_size, num_channels_latents, height, width)
554
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
555
+ # return latents.to(device=device, dtype=dtype), latent_image_ids
556
+
557
+ if latents is None:
558
+ image = image.to(device=device, dtype=dtype)
559
+ image_latents = self._encode_vae_image(image=image, generator=generator)
560
+ else:
561
+ image_latents = latents
562
+
563
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
564
+ latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
565
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
566
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
567
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
568
+ return latents, noise, image_latents, latent_image_ids
569
+
570
+ def prepare_mask_latents(
571
+ self,
572
+ mask,
573
+ masked_image,
574
+ batch_size,
575
+ num_images_per_prompt,
576
+ height,
577
+ width,
578
+ dtype,
579
+ device,
580
+ generator,
581
+ ):
582
+ # resize the mask to latents shape as we concatenate the mask to the latents
583
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
584
+ # and half precision
585
+ mask = torch.nn.functional.interpolate(
586
+ mask, size=(2 * height // self.vae_scale_factor, 2 * width // self.vae_scale_factor)
587
+ )
588
+ mask = mask.to(device=device, dtype=dtype)
589
+
590
+ batch_size = batch_size * num_images_per_prompt
591
+
592
+ masked_image = masked_image.to(device=device, dtype=dtype)
593
+
594
+ if masked_image.shape[1] == 16:
595
+ masked_image_latents = masked_image
596
+ else:
597
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
598
+
599
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
600
+
601
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
602
+ if mask.shape[0] < batch_size:
603
+ if not batch_size % mask.shape[0] == 0:
604
+ raise ValueError(
605
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
606
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
607
+ " of masks that you pass is divisible by the total requested batch size."
608
+ )
609
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
610
+ if masked_image_latents.shape[0] < batch_size:
611
+ if not batch_size % masked_image_latents.shape[0] == 0:
612
+ raise ValueError(
613
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
614
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
615
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
616
+ )
617
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
618
+
619
+ # aligning device to prevent device errors when concating it with the latent model input
620
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
621
+ return mask, masked_image_latents
622
+
623
+ @property
624
+ def guidance_scale(self):
625
+ return self._guidance_scale
626
+
627
+ @property
628
+ def joint_attention_kwargs(self):
629
+ return self._joint_attention_kwargs
630
+
631
+ @property
632
+ def num_timesteps(self):
633
+ return self._num_timesteps
634
+
635
+ @property
636
+ def interrupt(self):
637
+ return self._interrupt
638
+
639
+ @torch.no_grad()
640
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
641
+ def __call__(
642
+ self,
643
+ prompt: Union[str, List[str]] = None,
644
+ prompt_2: Optional[Union[str, List[str]]] = None,
645
+ image: PipelineImageInput = None,
646
+ mask_image: PipelineImageInput = None,
647
+ height: Optional[int] = None,
648
+ width: Optional[int] = None,
649
+ padding_mask_crop: Optional[int] = None,
650
+ strength: float = 0.6,
651
+ num_inference_steps: int = 28,
652
+ timesteps: List[int] = None,
653
+ guidance_scale: float = 7.0,
654
+ num_images_per_prompt: Optional[int] = 1,
655
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
656
+ latents: Optional[torch.FloatTensor] = None,
657
+ prompt_embeds: Optional[torch.FloatTensor] = None,
658
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
659
+ output_type: Optional[str] = "pil",
660
+ return_dict: bool = True,
661
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
662
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
663
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
664
+ max_sequence_length: int = 512,
665
+ ):
666
+ r"""
667
+ Function invoked when calling the pipeline for generation.
668
+
669
+ Args:
670
+ prompt (`str` or `List[str]`, *optional*):
671
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
672
+ instead.
673
+ prompt_2 (`str` or `List[str]`, *optional*):
674
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
675
+ will be used instead
676
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
677
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
678
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
679
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
680
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
681
+ latents as `image`, but if passing latents directly it is not encoded again.
682
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
683
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
684
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
685
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
686
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
687
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
688
+ 1)`, or `(H, W)`.
689
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
690
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
691
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
692
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
693
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
694
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
695
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
696
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
697
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
698
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
699
+ the image is large and contain information irrelevant for inpainting, such as background.
700
+ strength (`float`, *optional*, defaults to 1.0):
701
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
702
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
703
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
704
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
705
+ essentially ignores `image`.
706
+ num_inference_steps (`int`, *optional*, defaults to 50):
707
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
708
+ expense of slower inference.
709
+ timesteps (`List[int]`, *optional*):
710
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
711
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
712
+ passed will be used. Must be in descending order.
713
+ guidance_scale (`float`, *optional*, defaults to 7.0):
714
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
715
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
716
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
717
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
718
+ usually at the expense of lower image quality.
719
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
720
+ The number of images to generate per prompt.
721
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
722
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
723
+ to make generation deterministic.
724
+ latents (`torch.FloatTensor`, *optional*):
725
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
726
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
727
+ tensor will ge generated by sampling using the supplied random `generator`.
728
+ prompt_embeds (`torch.FloatTensor`, *optional*):
729
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
730
+ provided, text embeddings will be generated from `prompt` input argument.
731
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
732
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
733
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
734
+ output_type (`str`, *optional*, defaults to `"pil"`):
735
+ The output format of the generate image. Choose between
736
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
737
+ return_dict (`bool`, *optional*, defaults to `True`):
738
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
739
+ joint_attention_kwargs (`dict`, *optional*):
740
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
741
+ `self.processor` in
742
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
743
+ callback_on_step_end (`Callable`, *optional*):
744
+ A function that calls at the end of each denoising steps during the inference. The function is called
745
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
746
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
747
+ `callback_on_step_end_tensor_inputs`.
748
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
749
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
750
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
751
+ `._callback_tensor_inputs` attribute of your pipeline class.
752
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
753
+
754
+ Examples:
755
+
756
+ Returns:
757
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
758
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
759
+ images.
760
+ """
761
+
762
+ height = height or self.default_sample_size * self.vae_scale_factor
763
+ width = width or self.default_sample_size * self.vae_scale_factor
764
+
765
+ # 1. Check inputs. Raise error if not correct
766
+ self.check_inputs(
767
+ prompt,
768
+ prompt_2,
769
+ strength,
770
+ height,
771
+ width,
772
+ prompt_embeds=prompt_embeds,
773
+ pooled_prompt_embeds=pooled_prompt_embeds,
774
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
775
+ max_sequence_length=max_sequence_length,
776
+ )
777
+
778
+ self._guidance_scale = guidance_scale
779
+ self._joint_attention_kwargs = joint_attention_kwargs
780
+ self._interrupt = False
781
+ is_strength_max = strength == 1.0
782
+
783
+ # 2. Preprocess mask and image
784
+ if padding_mask_crop is not None:
785
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
786
+ resize_mode = "fill"
787
+ else:
788
+ crops_coords = None
789
+ resize_mode = "default"
790
+
791
+ original_image = image
792
+ init_image = self.image_processor.preprocess(
793
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
794
+ )
795
+ init_image = init_image.to(dtype=torch.float32)
796
+
797
+ # 3. Define call parameters
798
+ if prompt is not None and isinstance(prompt, str):
799
+ batch_size = 1
800
+ elif prompt is not None and isinstance(prompt, list):
801
+ batch_size = len(prompt)
802
+ else:
803
+ batch_size = prompt_embeds.shape[0]
804
+
805
+ device = self._execution_device
806
+
807
+ lora_scale = (
808
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
809
+ )
810
+ (
811
+ prompt_embeds,
812
+ pooled_prompt_embeds,
813
+ text_ids,
814
+ ) = self.encode_prompt(
815
+ prompt=prompt,
816
+ prompt_2=prompt_2,
817
+ prompt_embeds=prompt_embeds,
818
+ pooled_prompt_embeds=pooled_prompt_embeds,
819
+ device=device,
820
+ num_images_per_prompt=num_images_per_prompt,
821
+ max_sequence_length=max_sequence_length,
822
+ lora_scale=lora_scale,
823
+ )
824
+
825
+ # 4.Prepare timesteps
826
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
827
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
828
+ mu = calculate_shift(
829
+ image_seq_len,
830
+ self.scheduler.config.base_image_seq_len,
831
+ self.scheduler.config.max_image_seq_len,
832
+ self.scheduler.config.base_shift,
833
+ self.scheduler.config.max_shift,
834
+ )
835
+ timesteps, num_inference_steps = retrieve_timesteps(
836
+ self.scheduler,
837
+ num_inference_steps,
838
+ device,
839
+ timesteps,
840
+ sigmas,
841
+ mu=mu,
842
+ )
843
+ timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device)
844
+
845
+ if num_inference_steps < 1:
846
+ raise ValueError(
847
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
848
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
849
+ )
850
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
851
+
852
+ # 5. Prepare latent variables
853
+ num_channels_latents = self.transformer.config.in_channels // 4
854
+
855
+ latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
856
+ batch_size * num_images_per_prompt,
857
+ num_channels_latents,
858
+ height,
859
+ width,
860
+ prompt_embeds.dtype,
861
+ device,
862
+ generator,
863
+ latents,
864
+ init_image,
865
+ latent_timestep,
866
+ is_strength_max,
867
+ )
868
+
869
+ # start diff diff preparation
870
+ original_mask = self.mask_processor.preprocess(
871
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
872
+ )
873
+
874
+ masked_image = init_image * original_mask
875
+ original_mask, _ = self.prepare_mask_latents(
876
+ original_mask,
877
+ masked_image,
878
+ batch_size,
879
+ num_images_per_prompt,
880
+ height,
881
+ width,
882
+ prompt_embeds.dtype,
883
+ device,
884
+ generator,
885
+ )
886
+
887
+ mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
888
+ mask_thresholds = mask_thresholds.unsqueeze(1).unsqueeze(1).to(device)
889
+ masks = (original_mask > mask_thresholds)
890
+ masks = self._pack_latents(
891
+ masks.repeat(num_channels_latents, 1, 1, 1).permute(1, 0, 2, 3),
892
+ len(mask_thresholds),
893
+ num_channels_latents,
894
+ 2 * (int(height) // self.vae_scale_factor),
895
+ 2 * (int(width) // self.vae_scale_factor),
896
+ )
897
+ # end diff diff preparation
898
+
899
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
900
+
901
+ # 6. Denoising loop
902
+ latents_dtype = latents.dtype
903
+ # for 64 channel transformer only.
904
+ image_latent = original_image_latents
905
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
906
+ for i, t in enumerate(timesteps):
907
+ if self.interrupt:
908
+ continue
909
+
910
+ timestep = t.expand(latents.shape[0]).to(latents_dtype)
911
+
912
+ # handle guidance
913
+ if self.transformer.config.guidance_embeds:
914
+ guidance = torch.tensor([guidance_scale], device=device)
915
+ guidance = guidance.expand(latents.shape[0])
916
+ else:
917
+ guidance = None
918
+
919
+ noise_pred = self.transformer(
920
+ hidden_states=latents,
921
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
922
+ timestep=timestep / 1000,
923
+ guidance=guidance,
924
+ pooled_projections=pooled_prompt_embeds,
925
+ encoder_hidden_states=prompt_embeds,
926
+ txt_ids=text_ids,
927
+ img_ids=latent_image_ids,
928
+ joint_attention_kwargs=self.joint_attention_kwargs,
929
+ return_dict=False,
930
+ )[0]
931
+
932
+ # compute the previous noisy sample x_t -> x_t-1
933
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
934
+
935
+ if i < len(timesteps) - 1:
936
+ noise_timestep = timesteps[i + 1]
937
+ image_latent = self.scheduler.scale_noise(
938
+ original_image_latents, torch.tensor([noise_timestep]), noise
939
+ )
940
+
941
+ # start diff diff
942
+ mask = masks[i].to(latents_dtype)
943
+ latents = image_latent * mask + latents * (1 - mask)
944
+ # end diff diff
945
+
946
+ if latents.dtype != latents_dtype:
947
+ if torch.backends.mps.is_available():
948
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
949
+ latents = latents.to(latents_dtype)
950
+
951
+ if callback_on_step_end is not None:
952
+ callback_kwargs = {}
953
+ for k in callback_on_step_end_tensor_inputs:
954
+ callback_kwargs[k] = locals()[k]
955
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
956
+
957
+ latents = callback_outputs.pop("latents", latents)
958
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
959
+
960
+ # call the callback, if provided
961
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
962
+ progress_bar.update()
963
+
964
+ if XLA_AVAILABLE:
965
+ xm.mark_step()
966
+
967
+ if output_type == "latent":
968
+ image = latents
969
+
970
+ else:
971
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
972
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
973
+ image = self.vae.decode(latents, return_dict=False)[0]
974
+ image = self.image_processor.postprocess(image, output_type=output_type)
975
+
976
+ # Offload all models
977
+ self.maybe_free_model_hooks()
978
+
979
+ if not return_dict:
980
+ return (image,)
981
+
982
+ return FluxPipelineOutput(images=image)