AlanB committed on
Commit d4fdf03
1 Parent(s): 2b49720

Replaced with the right pipeline

Files changed (1)
  1. pipeline.py +512 -675
pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -13,30 +13,23 @@
13
  # limitations under the License.
14
 
15
  import inspect
16
- from typing import Callable, Dict, List, Optional, Tuple, Union
17
 
18
- import numpy as np
19
  import torch
20
  from transformers import (
21
- BertModel,
22
- BertTokenizer,
23
- CLIPImageProcessor,
24
- MT5Tokenizer,
25
  T5EncoderModel,
 
26
  )
27
 
28
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
29
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
- from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
31
- from diffusers.models.embeddings import get_2d_rotary_pos_embed
32
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
34
- from diffusers.pipelines.stable_diffusion.safety_checker import (
35
- StableDiffusionSafetyChecker,
36
- )
37
- from diffusers.schedulers import DDPMScheduler
38
  from diffusers.utils import (
39
- deprecate,
40
  is_torch_xla_available,
41
  logging,
42
  replace_example_docstring,
@@ -58,114 +51,28 @@ EXAMPLE_DOC_STRING = """
58
  Examples:
59
  ```py
60
  >>> import torch
61
- >>> from diffusers import FlowMatchEulerDiscreteScheduler
62
- >>> from diffusers.utils import load_image
63
- >>> from PIL import Image
64
- >>> from torchvision import transforms
65
- >>> from pipeline_hunyuandit_differential_img2img import HunyuanDiTDifferentialImg2ImgPipeline
66
- >>> pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(
67
- >>> "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
68
- >>> ).to("cuda")
69
- >>> source_image = load_image(
70
- >>> "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png"
71
- >>> )
72
- >>> map = load_image(
73
- >>> "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png"
74
- >>> )
75
- >>> prompt = "a green pear"
76
- >>> negative_prompt = "blurry"
77
- >>> image = pipe(
78
- >>> prompt=prompt,
79
- >>> negative_prompt=negative_prompt,
80
- >>> image=source_image,
81
- >>> num_inference_steps=28,
82
- >>> guidance_scale=4.5,
83
- >>> strength=1.0,
84
- >>> map=map,
85
- >>> ).images[0]
86
-
87
- ```
88
- """
89
 
90
- STANDARD_RATIO = np.array(
91
- [
92
- 1.0, # 1:1
93
- 4.0 / 3.0, # 4:3
94
- 3.0 / 4.0, # 3:4
95
- 16.0 / 9.0, # 16:9
96
- 9.0 / 16.0, # 9:16
97
- ]
98
- )
99
- STANDARD_SHAPE = [
100
- [(1024, 1024), (1280, 1280)], # 1:1
101
- [(1024, 768), (1152, 864), (1280, 960)], # 4:3
102
- [(768, 1024), (864, 1152), (960, 1280)], # 3:4
103
- [(1280, 768)], # 16:9
104
- [(768, 1280)], # 9:16
105
- ]
106
- STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
107
- SUPPORTED_SHAPE = [
108
- (1024, 1024),
109
- (1280, 1280), # 1:1
110
- (1024, 768),
111
- (1152, 864),
112
- (1280, 960), # 4:3
113
- (768, 1024),
114
- (864, 1152),
115
- (960, 1280), # 3:4
116
- (1280, 768), # 16:9
117
- (768, 1280), # 9:16
118
- ]
119
-
120
-
121
- def map_to_standard_shapes(target_width, target_height):
122
- target_ratio = target_width / target_height
123
- closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
124
- closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
125
- width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
126
- return width, height
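For reference, the resolution-binning helper deleted here snaps an arbitrary request to the nearest supported HunyuanDiT shape, first by aspect ratio and then by area. Below is a minimal standalone sketch with a worked example; the constants are copied from the deleted code above and are not part of the new pipeline.

```py
# Standalone sketch of the deleted resolution-binning helper, for reference only.
import numpy as np

STANDARD_RATIO = np.array([1.0, 4 / 3, 3 / 4, 16 / 9, 9 / 16])
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],             # 1:1
    [(1024, 768), (1152, 864), (1280, 960)],  # 4:3
    [(768, 1024), (864, 1152), (960, 1280)],  # 3:4
    [(1280, 768)],                            # 16:9
    [(768, 1280)],                            # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]

def map_to_standard_shapes(target_width, target_height):
    # pick the supported aspect ratio closest to the request, then the closest area in that bucket
    ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_width / target_height))
    area_idx = np.argmin(np.abs(STANDARD_AREA[ratio_idx] - target_width * target_height))
    return STANDARD_SHAPE[ratio_idx][area_idx]

# 1200x800 (3:2) is closest to the 4:3 bucket, and 1200*800 = 960000 is closest to 1152*864.
print(map_to_standard_shapes(1200, 800))  # (1152, 864)
```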
127
-
128
-
129
- def get_resize_crop_region_for_grid(src, tgt_size):
130
- th = tw = tgt_size
131
- h, w = src
132
-
133
- r = h / w
134
-
135
- # resize
136
- if r > 1:
137
- resize_height = th
138
- resize_width = int(round(th / h * w))
139
- else:
140
- resize_width = tw
141
- resize_height = int(round(tw / w * h))
142
 
143
- crop_top = int(round((th - resize_height) / 2.0))
144
- crop_left = int(round((tw - resize_width) / 2.0))
 
 
145
 
146
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
 
147
 
 
148
 
149
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
150
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
151
- """
152
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
153
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
154
- """
155
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
156
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
157
- # rescale the results from guidance (fixes overexposure)
158
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
159
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
160
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
161
- return noise_cfg
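The `rescale_noise_cfg` helper removed here implements Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are Flawed": the CFG output is rescaled to the standard deviation of the text-conditioned branch and then blended back by `guidance_rescale`. A minimal sketch of where it sits relative to the classifier-free-guidance combine, mirroring the deleted call site further down in this diff; the tensors and scale values are dummies.

```py
# Sketch of guidance rescale around a CFG combine step (dummy tensors, illustrative values).
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)  # undo the std inflation caused by CFG (fixes overexposure)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

noise_pred_uncond = torch.randn(1, 4, 128, 128)
noise_pred_text = torch.randn(1, 4, 128, 128)
guidance_scale, guidance_rescale = 5.0, 0.7

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
```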
162
 
163
 
164
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
165
  def retrieve_latents(
166
- encoder_output: torch.Tensor,
167
- generator: Optional[torch.Generator] = None,
168
- sample_mode: str = "sample",
169
  ):
170
  if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
171
  return encoder_output.latent_dist.sample(generator)
@@ -237,138 +144,225 @@ def retrieve_timesteps(
237
  return timesteps, num_inference_steps
238
 
239
 
240
- class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
241
  r"""
242
- Differential Pipeline for English/Chinese-to-image generation using HunyuanDiT.
243
-
244
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
245
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
246
-
247
- HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
248
- ourselves)
249
-
250
  Args:
251
  vae ([`AutoencoderKL`]):
252
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
253
- `sdxl-vae-fp16-fix`.
254
- text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
255
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
256
- HunyuanDiT uses a fine-tuned [bilingual CLIP].
257
- tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
258
- A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
259
- transformer ([`HunyuanDiT2DModel`]):
260
- The HunyuanDiT model designed by Tencent Hunyuan.
261
- text_encoder_2 (`T5EncoderModel`):
262
- The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
263
- tokenizer_2 (`MT5Tokenizer`):
264
- The tokenizer for the mT5 embedder.
265
- scheduler ([`DDPMScheduler`]):
266
- A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
267
  """
268
 
269
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
270
- _optional_components = [
271
- "safety_checker",
272
- "feature_extractor",
273
- "text_encoder_2",
274
- "tokenizer_2",
275
- "text_encoder",
276
- "tokenizer",
277
- ]
278
- _exclude_from_cpu_offload = ["safety_checker"]
279
- _callback_tensor_inputs = [
280
- "latents",
281
- "prompt_embeds",
282
- "negative_prompt_embeds",
283
- "prompt_embeds_2",
284
- "negative_prompt_embeds_2",
285
- ]
286
 
287
  def __init__(
288
  self,
 
 
289
  vae: AutoencoderKL,
290
- text_encoder: BertModel,
291
- tokenizer: BertTokenizer,
292
- transformer: HunyuanDiT2DModel,
293
- scheduler: DDPMScheduler,
294
- safety_checker: StableDiffusionSafetyChecker,
295
- feature_extractor: CLIPImageProcessor,
296
- requires_safety_checker: bool = True,
297
- text_encoder_2=T5EncoderModel,
298
- tokenizer_2=MT5Tokenizer,
299
  ):
300
  super().__init__()
301
 
302
  self.register_modules(
303
  vae=vae,
304
  text_encoder=text_encoder,
 
 
305
  tokenizer=tokenizer,
306
  tokenizer_2=tokenizer_2,
 
307
  transformer=transformer,
308
  scheduler=scheduler,
309
- safety_checker=safety_checker,
310
- feature_extractor=feature_extractor,
311
- text_encoder_2=text_encoder_2,
312
  )
313
 
314
- if safety_checker is None and requires_safety_checker:
315
- logger.warning(
316
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
317
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
318
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
319
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
320
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
321
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
322
- )
323
 
324
- if safety_checker is not None and feature_extractor is None:
325
- raise ValueError(
326
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
327
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  )
329
 
330
- self.vae_scale_factor = (
331
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
332
- )
333
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
334
- self.mask_processor = VaeImageProcessor(
335
- vae_scale_factor=self.vae_scale_factor,
336
- do_normalize=False,
337
- do_convert_grayscale=True,
338
  )
339
- self.register_to_config(requires_safety_checker=requires_safety_checker)
340
- self.default_sample_size = (
341
- self.transformer.config.sample_size
342
- if hasattr(self, "transformer") and self.transformer is not None
343
- else 128
344
  )
345
 
346
- # copied from diffusers.pipelines.huanyuandit.pipeline_huanyuandit.HunyuanDiTPipeline.encode_prompt
347
  def encode_prompt(
348
  self,
349
- prompt: str,
350
- device: torch.device = None,
351
- dtype: torch.dtype = None,
 
352
  num_images_per_prompt: int = 1,
353
  do_classifier_free_guidance: bool = True,
354
- negative_prompt: Optional[str] = None,
355
- prompt_embeds: Optional[torch.Tensor] = None,
356
- negative_prompt_embeds: Optional[torch.Tensor] = None,
357
- prompt_attention_mask: Optional[torch.Tensor] = None,
358
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
359
- max_sequence_length: Optional[int] = None,
360
- text_encoder_index: int = 0,
 
 
361
  ):
362
  r"""
363
- Encodes the prompt into text encoder hidden states.
364
 
365
  Args:
366
  prompt (`str` or `List[str]`, *optional*):
367
  prompt to be encoded
368
  device: (`torch.device`):
369
  torch device
370
- dtype (`torch.dtype`):
371
- torch dtype
372
  num_images_per_prompt (`int`):
373
  number of images that should be generated per prompt
374
  do_classifier_free_guidance (`bool`):
@@ -377,194 +371,155 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
377
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
378
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
379
  less than `1`).
380
- prompt_embeds (`torch.Tensor`, *optional*):
381
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
382
  provided, text embeddings will be generated from `prompt` input argument.
383
- negative_prompt_embeds (`torch.Tensor`, *optional*):
384
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
385
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
386
  argument.
387
- prompt_attention_mask (`torch.Tensor`, *optional*):
388
- Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
389
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
390
- Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
391
- max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
392
- text_encoder_index (`int`, *optional*):
393
- Index of the text encoder to use. `0` for clip and `1` for T5.
 
 
 
394
  """
395
- if dtype is None:
396
- if self.text_encoder_2 is not None:
397
- dtype = self.text_encoder_2.dtype
398
- elif self.transformer is not None:
399
- dtype = self.transformer.dtype
400
- else:
401
- dtype = None
402
-
403
- if device is None:
404
- device = self._execution_device
405
-
406
- tokenizers = [self.tokenizer, self.tokenizer_2]
407
- text_encoders = [self.text_encoder, self.text_encoder_2]
408
-
409
- tokenizer = tokenizers[text_encoder_index]
410
- text_encoder = text_encoders[text_encoder_index]
411
-
412
- if max_sequence_length is None:
413
- if text_encoder_index == 0:
414
- max_length = 77
415
- if text_encoder_index == 1:
416
- max_length = 256
417
- else:
418
- max_length = max_sequence_length
419
 
420
- if prompt is not None and isinstance(prompt, str):
421
- batch_size = 1
422
- elif prompt is not None and isinstance(prompt, list):
423
  batch_size = len(prompt)
424
  else:
425
  batch_size = prompt_embeds.shape[0]
426
 
427
  if prompt_embeds is None:
428
- text_inputs = tokenizer(
429
- prompt,
430
- padding="max_length",
431
- max_length=max_length,
432
- truncation=True,
433
- return_attention_mask=True,
434
- return_tensors="pt",
435
  )
436
- text_input_ids = text_inputs.input_ids
437
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
438
-
439
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
440
- text_input_ids, untruncated_ids
441
- ):
442
- removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
443
- logger.warning(
444
- "The following part of your input was truncated because CLIP can only handle sequences up to"
445
- f" {tokenizer.model_max_length} tokens: {removed_text}"
446
- )
447
 
448
- prompt_attention_mask = text_inputs.attention_mask.to(device)
449
- prompt_embeds = text_encoder(
450
- text_input_ids.to(device),
451
- attention_mask=prompt_attention_mask,
 
452
  )
453
- prompt_embeds = prompt_embeds[0]
454
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
455
 
456
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
 
457
 
458
- bs_embed, seq_len, _ = prompt_embeds.shape
459
- # duplicate text embeddings for each generation per prompt, using mps friendly method
460
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
461
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
462
 
463
- # get unconditional embeddings for classifier free guidance
464
  if do_classifier_free_guidance and negative_prompt_embeds is None:
465
- uncond_tokens: List[str]
466
- if negative_prompt is None:
467
- uncond_tokens = [""] * batch_size
468
- elif prompt is not None and type(prompt) is not type(negative_prompt):
469
  raise TypeError(
470
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
471
  f" {type(prompt)}."
472
  )
473
- elif isinstance(negative_prompt, str):
474
- uncond_tokens = [negative_prompt]
475
  elif batch_size != len(negative_prompt):
476
  raise ValueError(
477
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
478
  f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
479
  " the batch size of `prompt`."
480
  )
481
- else:
482
- uncond_tokens = negative_prompt
483
-
484
- max_length = prompt_embeds.shape[1]
485
- uncond_input = tokenizer(
486
- uncond_tokens,
487
- padding="max_length",
488
- max_length=max_length,
489
- truncation=True,
490
- return_tensors="pt",
491
- )
492
 
493
- negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
494
- negative_prompt_embeds = text_encoder(
495
- uncond_input.input_ids.to(device),
496
- attention_mask=negative_prompt_attention_mask,
 
 
497
  )
498
- negative_prompt_embeds = negative_prompt_embeds[0]
499
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
500
-
501
- if do_classifier_free_guidance:
502
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
503
- seq_len = negative_prompt_embeds.shape[1]
504
-
505
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
506
 
507
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
508
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
 
 
509
 
510
- return (
511
- prompt_embeds,
512
- negative_prompt_embeds,
513
- prompt_attention_mask,
514
- negative_prompt_attention_mask,
515
- )
516
 
517
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
518
- def run_safety_checker(self, image, device, dtype):
519
- if self.safety_checker is None:
520
- has_nsfw_concept = None
521
- else:
522
- if torch.is_tensor(image):
523
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
524
- else:
525
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
526
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
527
- image, has_nsfw_concept = self.safety_checker(
528
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
529
  )
530
- return image, has_nsfw_concept
531
-
532
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
533
- def prepare_extra_step_kwargs(self, generator, eta):
534
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
535
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
536
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
537
- # and should be between [0, 1]
538
-
539
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
540
- extra_step_kwargs = {}
541
- if accepts_eta:
542
- extra_step_kwargs["eta"] = eta
543
-
544
- # check if the scheduler accepts generator
545
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
546
- if accepts_generator:
547
- extra_step_kwargs["generator"] = generator
548
- return extra_step_kwargs
549
 
550
  def check_inputs(
551
  self,
552
  prompt,
553
- height,
554
- width,
 
555
  negative_prompt=None,
 
 
556
  prompt_embeds=None,
557
  negative_prompt_embeds=None,
558
- prompt_attention_mask=None,
559
- negative_prompt_attention_mask=None,
560
- prompt_embeds_2=None,
561
- negative_prompt_embeds_2=None,
562
- prompt_attention_mask_2=None,
563
- negative_prompt_attention_mask_2=None,
564
  callback_on_step_end_tensor_inputs=None,
 
565
  ):
566
- if height % 8 != 0 or width % 8 != 0:
567
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
568
  if callback_on_step_end_tensor_inputs is not None and not all(
569
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
570
  ):
@@ -577,36 +532,43 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
577
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
578
  " only forward one of the two."
579
  )
580
- elif prompt is None and prompt_embeds is None:
581
  raise ValueError(
582
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
 
 
 
 
 
 
583
  )
584
- elif prompt is None and prompt_embeds_2 is None:
585
  raise ValueError(
586
- "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
587
  )
588
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
589
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
590
-
591
- if prompt_embeds is not None and prompt_attention_mask is None:
592
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
593
-
594
- if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
595
- raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
596
 
597
  if negative_prompt is not None and negative_prompt_embeds is not None:
598
  raise ValueError(
599
  f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
600
  f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
601
  )
602
-
603
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
604
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
605
-
606
- if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
 
607
  raise ValueError(
608
- "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
 
609
  )
 
610
  if prompt_embeds is not None and negative_prompt_embeds is not None:
611
  if prompt_embeds.shape != negative_prompt_embeds.shape:
612
  raise ValueError(
@@ -614,38 +576,31 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
614
  f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
615
  f" {negative_prompt_embeds.shape}."
616
  )
617
- if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
618
- if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
619
- raise ValueError(
620
- "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
621
- f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
622
- f" {negative_prompt_embeds_2.shape}."
623
- )
624
 
625
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
626
  def get_timesteps(self, num_inference_steps, strength, device):
627
  # get the original timestep using init_timestep
628
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
629
 
630
- t_start = max(num_inference_steps - init_timestep, 0)
631
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
632
- if hasattr(self.scheduler, "set_begin_index"):
633
- self.scheduler.set_begin_index(t_start * self.scheduler.order)
634
 
635
  return timesteps, num_inference_steps - t_start
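The `strength` arithmetic kept in `get_timesteps` simply decides how much of the schedule is skipped before denoising starts; a worked example with illustrative values:

```py
# Worked example of the strength -> schedule arithmetic above (illustrative values).
num_inference_steps = 50
strength = 0.8

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10

# For a first-order scheduler the pipeline then runs scheduler.timesteps[t_start:],
# i.e. the last 40 of 50 steps. strength=1.0 keeps the whole schedule (the input image
# is essentially ignored), while strength=0.0 leaves no steps to run.
```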
636
 
637
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
638
  def prepare_latents(
639
- self,
640
- batch_size,
641
- num_channels_latents,
642
- height,
643
- width,
644
- image,
645
- timestep,
646
- dtype,
647
- device,
648
- generator=None,
649
  ):
650
  shape = (
651
  batch_size,
@@ -655,6 +610,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
655
  )
656
 
657
  image = image.to(device=device, dtype=dtype)
 
658
  if isinstance(generator, list) and len(generator) != batch_size:
659
  raise ValueError(
660
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -665,25 +621,13 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
665
  retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
666
  ]
667
  init_latents = torch.cat(init_latents, dim=0)
668
-
669
  else:
670
  init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
671
 
672
- init_latents = init_latents * self.vae.config.scaling_factor
 
673
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
674
  # expand init_latents for batch_size
675
- deprecation_message = (
676
- f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
677
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
678
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
679
- " your script to pass as many initial images as text prompts to suppress this warning."
680
- )
681
- deprecate(
682
- "len(prompt) != len(image)",
683
- "1.0.0",
684
- deprecation_message,
685
- standard_warn=False,
686
- )
687
  additional_image_per_prompt = batch_size // init_latents.shape[0]
688
  init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
689
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
@@ -696,9 +640,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
696
  shape = init_latents.shape
697
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
698
 
699
- # get latents
700
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
701
- latents = init_latents
702
 
703
  return latents
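`prepare_latents` above follows the usual img2img recipe: encode the input image with the VAE, scale the latents, then noise them to the level of the first retained timestep. The sketch below captures that generic pattern based on the lines visible in this hunk; the exact scaling and noising calls on the new SD3 side fall outside the shown context, so treat it as a shape of the idea, not the verbatim implementation.

```py
# Generic img2img latent preparation, sketched from the hunk above (not the verbatim
# new implementation; the SD3 side of this hunk is not fully shown).
import torch

def prepare_init_latents(vae, scheduler, image, timestep, generator=None):
    init_latents = vae.encode(image).latent_dist.sample(generator)   # retrieve_latents(..., "sample")
    init_latents = init_latents * vae.config.scaling_factor          # match the training-time latent scale
    noise = torch.randn_like(init_latents)
    # DDPM-style schedulers expose add_noise(original, noise, t); flow-matching schedulers
    # noise the sample differently, which is why this line changes in the diff.
    return scheduler.add_noise(init_latents, noise, timestep)
```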
704
 
@@ -707,8 +650,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
707
  return self._guidance_scale
708
 
709
  @property
710
- def guidance_rescale(self):
711
- return self._guidance_rescale
712
 
713
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
714
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -730,185 +673,144 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
730
  def __call__(
731
  self,
732
  prompt: Union[str, List[str]] = None,
733
- image: PipelineImageInput = None,
734
- strength: float = 0.8,
735
  height: Optional[int] = None,
736
  width: Optional[int] = None,
737
- num_inference_steps: Optional[int] = 50,
 
 
738
  timesteps: List[int] = None,
739
- sigmas: List[float] = None,
740
- guidance_scale: Optional[float] = 5.0,
741
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
 
742
  num_images_per_prompt: Optional[int] = 1,
743
- eta: Optional[float] = 0.0,
744
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
745
- latents: Optional[torch.Tensor] = None,
746
- prompt_embeds: Optional[torch.Tensor] = None,
747
- prompt_embeds_2: Optional[torch.Tensor] = None,
748
- negative_prompt_embeds: Optional[torch.Tensor] = None,
749
- negative_prompt_embeds_2: Optional[torch.Tensor] = None,
750
- prompt_attention_mask: Optional[torch.Tensor] = None,
751
- prompt_attention_mask_2: Optional[torch.Tensor] = None,
752
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
753
- negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
754
  output_type: Optional[str] = "pil",
755
  return_dict: bool = True,
756
- callback_on_step_end: Optional[
757
- Union[
758
- Callable[[int, int, Dict], None],
759
- PipelineCallback,
760
- MultiPipelineCallbacks,
761
- ]
762
- ] = None,
763
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
764
- guidance_rescale: float = 0.0,
765
- original_size: Optional[Tuple[int, int]] = (1024, 1024),
766
- target_size: Optional[Tuple[int, int]] = None,
767
- crops_coords_top_left: Tuple[int, int] = (0, 0),
768
- use_resolution_binning: bool = True,
769
  map: PipelineImageInput = None,
770
- denoising_start: Optional[float] = None,
771
  ):
772
  r"""
773
- The call function to the pipeline for generation with HunyuanDiT.
774
 
775
  Args:
776
  prompt (`str` or `List[str]`, *optional*):
777
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
778
- image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
779
- `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
780
- numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
781
- or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
782
- list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
783
- latents as `image`, but if passing latents directly it is not encoded again.
784
- strength (`float`, *optional*, defaults to 0.8):
785
- Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
786
- starting point and more noise is added the higher the `strength`. The number of denoising steps depends
787
- on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
788
- process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
789
- essentially ignores `image`.
790
- height (`int`):
791
- The height in pixels of the generated image.
792
- width (`int`):
793
- The width in pixels of the generated image.
794
  num_inference_steps (`int`, *optional*, defaults to 50):
795
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
796
- expense of slower inference. This parameter is modulated by `strength`.
797
  timesteps (`List[int]`, *optional*):
798
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
799
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
800
  passed will be used. Must be in descending order.
801
- sigmas (`List[float]`, *optional*):
802
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
803
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
804
- will be used.
805
- guidance_scale (`float`, *optional*, defaults to 7.5):
806
- A higher guidance scale value encourages the model to generate images closely linked to the text
807
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
808
  negative_prompt (`str` or `List[str]`, *optional*):
809
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
810
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
811
  num_images_per_prompt (`int`, *optional*, defaults to 1):
812
  The number of images to generate per prompt.
813
- eta (`float`, *optional*, defaults to 0.0):
814
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
815
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
816
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
817
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
818
- generation deterministic.
819
- prompt_embeds (`torch.Tensor`, *optional*):
820
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
821
- provided, text embeddings are generated from the `prompt` input argument.
822
- prompt_embeds_2 (`torch.Tensor`, *optional*):
823
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
824
- provided, text embeddings are generated from the `prompt` input argument.
825
- negative_prompt_embeds (`torch.Tensor`, *optional*):
826
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
827
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
828
- negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
829
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
830
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
831
- prompt_attention_mask (`torch.Tensor`, *optional*):
832
- Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
833
- prompt_attention_mask_2 (`torch.Tensor`, *optional*):
834
- Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
835
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
836
- Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
837
- negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
838
- Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
839
  output_type (`str`, *optional*, defaults to `"pil"`):
840
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
 
841
  return_dict (`bool`, *optional*, defaults to `True`):
842
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
843
- plain tuple.
844
- callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
845
- A callback function or a list of callback functions to be called at the end of each denoising step.
846
- callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
847
- A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
848
- inputs will be passed.
849
- guidance_rescale (`float`, *optional*, defaults to 0.0):
850
- Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
851
- Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
852
- original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
853
- The original size of the image. Used to calculate the time ids.
854
- target_size (`Tuple[int, int]`, *optional*):
855
- The target size of the image. Used to calculate the time ids.
856
- crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
857
- The top left coordinates of the crop. Used to calculate the time ids.
858
- use_resolution_binning (`bool`, *optional*, defaults to `True`):
859
- Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
860
- standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
861
- 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
862
- denoising_start (`float`, *optional*):
863
- When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
864
- bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
865
- it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
866
- strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
867
- is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
868
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
869
  Examples:
870
 
871
  Returns:
872
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
873
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
874
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
875
- second element is a list of `bool`s indicating whether the corresponding generated image contains
876
- "not-safe-for-work" (nsfw) content.
877
  """
878
 
879
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
880
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
881
-
882
- # 0. default height and width
883
  height = height or self.default_sample_size * self.vae_scale_factor
884
  width = width or self.default_sample_size * self.vae_scale_factor
885
- height = int((height // 16) * 16)
886
- width = int((width // 16) * 16)
887
-
888
- if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
889
- width, height = map_to_standard_shapes(width, height)
890
- height = int(height)
891
- width = int(width)
892
- logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
893
 
894
  # 1. Check inputs. Raise error if not correct
895
  self.check_inputs(
896
  prompt,
897
- height,
898
- width,
899
- negative_prompt,
900
- prompt_embeds,
901
- negative_prompt_embeds,
902
- prompt_attention_mask,
903
- negative_prompt_attention_mask,
904
- prompt_embeds_2,
905
- negative_prompt_embeds_2,
906
- prompt_attention_mask_2,
907
- negative_prompt_attention_mask_2,
908
- callback_on_step_end_tensor_inputs,
909
  )
 
910
  self._guidance_scale = guidance_scale
911
- self._guidance_rescale = guidance_rescale
912
  self._interrupt = False
913
 
914
  # 2. Define call parameters
@@ -921,59 +823,42 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
921
 
922
  device = self._execution_device
923
 
924
- # 3. Encode input prompt
925
-
926
  (
927
  prompt_embeds,
928
  negative_prompt_embeds,
929
- prompt_attention_mask,
930
- negative_prompt_attention_mask,
931
  ) = self.encode_prompt(
932
  prompt=prompt,
933
- device=device,
934
- dtype=self.transformer.dtype,
935
- num_images_per_prompt=num_images_per_prompt,
936
- do_classifier_free_guidance=self.do_classifier_free_guidance,
937
  negative_prompt=negative_prompt,
 
 
 
938
  prompt_embeds=prompt_embeds,
939
  negative_prompt_embeds=negative_prompt_embeds,
940
- prompt_attention_mask=prompt_attention_mask,
941
- negative_prompt_attention_mask=negative_prompt_attention_mask,
942
- max_sequence_length=77,
943
- text_encoder_index=0,
944
- )
945
- (
946
- prompt_embeds_2,
947
- negative_prompt_embeds_2,
948
- prompt_attention_mask_2,
949
- negative_prompt_attention_mask_2,
950
- ) = self.encode_prompt(
951
- prompt=prompt,
952
  device=device,
953
- dtype=self.transformer.dtype,
954
  num_images_per_prompt=num_images_per_prompt,
955
- do_classifier_free_guidance=self.do_classifier_free_guidance,
956
- negative_prompt=negative_prompt,
957
- prompt_embeds=prompt_embeds_2,
958
- negative_prompt_embeds=negative_prompt_embeds_2,
959
- prompt_attention_mask=prompt_attention_mask_2,
960
- negative_prompt_attention_mask=negative_prompt_attention_mask_2,
961
- max_sequence_length=256,
962
- text_encoder_index=1,
963
  )
964
 
965
- # 4. Preprocess image
966
  init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
967
  map = self.mask_processor.preprocess(
968
- map,
969
- height=height // self.vae_scale_factor,
970
- width=width // self.vae_scale_factor,
971
  ).to(device)
972
 
973
- # 5. Prepare timesteps
974
- timesteps, num_inference_steps = retrieve_timesteps(
975
- self.scheduler, num_inference_steps, device, timesteps, sigmas
976
- )
977
 
978
  # begin diff diff change
979
  total_time_steps = num_inference_steps
@@ -982,58 +867,25 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
982
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
983
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
984
 
985
- # 6. Prepare latent variables
986
  num_channels_latents = self.transformer.config.in_channels
987
- latents = self.prepare_latents(
988
- batch_size * num_images_per_prompt,
989
- num_channels_latents,
990
- height,
991
- width,
992
- init_image,
993
- latent_timestep,
994
- prompt_embeds.dtype,
995
- device,
996
- generator,
997
- )
998
-
999
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1000
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1001
-
1002
- # 8. create image_rotary_emb, style embedding & time ids
1003
- grid_height = height // 8 // self.transformer.config.patch_size
1004
- grid_width = width // 8 // self.transformer.config.patch_size
1005
- base_size = 512 // 8 // self.transformer.config.patch_size
1006
- grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
1007
- image_rotary_emb = get_2d_rotary_pos_embed(
1008
- self.transformer.inner_dim // self.transformer.num_heads,
1009
- grid_crops_coords,
1010
- (grid_height, grid_width),
1011
- )
1012
-
1013
- style = torch.tensor([0], device=device)
1014
 
1015
- target_size = target_size or (height, width)
1016
- add_time_ids = list(original_size + target_size + crops_coords_top_left)
1017
- add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
1018
 
1019
- if self.do_classifier_free_guidance:
1020
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1021
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
1022
- prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1023
- prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
1024
- add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
1025
- style = torch.cat([style] * 2, dim=0)
1026
-
1027
- prompt_embeds = prompt_embeds.to(device=device)
1028
- prompt_attention_mask = prompt_attention_mask.to(device=device)
1029
- prompt_embeds_2 = prompt_embeds_2.to(device=device)
1030
- prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
1031
- add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
1032
- batch_size * num_images_per_prompt, 1
1033
- )
1034
- style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
1035
- # 9. Denoising loop
1036
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1037
  # preparations for diff diff
1038
  original_with_noise = self.prepare_latents(
1039
  batch_size * num_images_per_prompt,
@@ -1048,15 +900,16 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
1048
  )
1049
  thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
1050
  thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
1051
- masks = map.squeeze() > (thresholds + (denoising_start or 0))
1052
  # end diff diff preparations
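This retained block is the core of differential img2img: `original_with_noise` stacks the source latents noised at every timestep, and `masks` holds one boolean mask per step, comparing the grayscale `map` against a threshold that rises linearly towards 1 over the schedule. As written in this hunk, pixels whose map value is still above the current threshold get re-synchronised with the noised original, so higher map values keep a region tied to the source image for more of the schedule while lower values release it for editing earlier. A minimal sketch of the mechanism follows; the blend itself sits outside this hunk, so it is shown as the usual differential-diffusion form.

```py
# Minimal sketch of the per-step masking prepared above (shapes simplified).
import torch

total_steps = 28
change_map = torch.rand(64, 64)                        # preprocessed grayscale `map`, values in [0, 1]
thresholds = torch.arange(total_steps) / total_steps   # 0, 1/28, 2/28, ...
masks = change_map > thresholds.view(-1, 1, 1)         # one mask per timestep, shape (T, H, W)

# Inside the denoising loop (i = step index), the usual differential-diffusion blend is:
#   mask = masks[i].to(latents.dtype)
#   latents = original_with_noise[i] * mask + latents * (1 - mask)
# i.e. masked pixels are reset to the source latents at the current noise level,
# unmasked pixels keep the denoised prediction.
```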
1053
- self._num_timesteps = len(timesteps)
1054
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1055
  for i, t in enumerate(timesteps):
1056
  if self.interrupt:
1057
  continue
 
1058
  # diff diff
1059
- if i == 0 and denoising_start is None:
1060
  latents = original_with_noise[:1]
1061
  else:
1062
  mask = masks[i].unsqueeze(0).to(latents.dtype)
@@ -1066,40 +919,30 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
1066
 
1067
  # expand the latents if we are doing classifier free guidance
1068
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1069
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
1070
 
1071
- # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
1072
- t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
1073
- dtype=latent_model_input.dtype
1074
- )
1075
-
1076
- # predict the noise residual
1077
  noise_pred = self.transformer(
1078
- latent_model_input,
1079
- t_expand,
1080
  encoder_hidden_states=prompt_embeds,
1081
- text_embedding_mask=prompt_attention_mask,
1082
- encoder_hidden_states_t5=prompt_embeds_2,
1083
- text_embedding_mask_t5=prompt_attention_mask_2,
1084
- image_meta_size=add_time_ids,
1085
- style=style,
1086
- image_rotary_emb=image_rotary_emb,
1087
  return_dict=False,
1088
  )[0]
1089
 
1090
- noise_pred, _ = noise_pred.chunk(2, dim=1)
1091
-
1092
  # perform guidance
1093
  if self.do_classifier_free_guidance:
1094
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1095
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1096
-
1097
- if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1098
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1099
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1100
 
1101
  # compute the previous noisy sample x_t -> x_t-1
1102
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1103
 
1104
  if callback_on_step_end is not None:
1105
  callback_kwargs = {}
@@ -1111,9 +954,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
1111
  latents = callback_outputs.pop("latents", latents)
1112
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1113
  negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1114
- prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
1115
- negative_prompt_embeds_2 = callback_outputs.pop(
1116
- "negative_prompt_embeds_2", negative_prompt_embeds_2
1117
  )
1118
 
1119
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1122,24 +964,19 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
1122
  if XLA_AVAILABLE:
1123
  xm.mark_step()
1124
 
1125
- if not output_type == "latent":
1126
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1127
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1128
- else:
1129
  image = latents
1130
- has_nsfw_concept = None
1131
 
1132
- if has_nsfw_concept is None:
1133
- do_denormalize = [True] * image.shape[0]
1134
  else:
1135
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1136
 
1137
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
1138
 
1139
  # Offload all models
1140
  self.maybe_free_model_hooks()
1141
 
1142
  if not return_dict:
1143
- return (image, has_nsfw_concept)
1144
 
1145
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
13
  # limitations under the License.
14
 
15
  import inspect
16
+ from typing import Callable, Dict, List, Optional, Union
17
 
 
18
  import torch
19
  from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
 
 
22
  T5EncoderModel,
23
+ T5TokenizerFast,
24
  )
25
 
 
26
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
27
+ from diffusers.models.autoencoders import AutoencoderKL
28
+ from diffusers.models.transformers import SD3Transformer2DModel
29
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 
 
 
32
  from diffusers.utils import (
 
33
  is_torch_xla_available,
34
  logging,
35
  replace_example_docstring,
 
51
  Examples:
52
  ```py
53
  >>> import torch
54
 
55
+ >>> from diffusers import AutoPipelineForImage2Image
56
+ >>> from diffusers.utils import load_image
57
 
58
+ >>> device = "cuda"
59
+ >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
60
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
61
+ >>> pipe = pipe.to(device)
62
 
63
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
64
+ >>> init_image = load_image(url).resize((512, 512))
65
 
66
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
67
 
68
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
69
+ ```
70
+ """
71
 
72
 
73
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
74
  def retrieve_latents(
75
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 
 
76
  ):
77
  if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
78
  return encoder_output.latent_dist.sample(generator)
 
144
  return timesteps, num_inference_steps
145
 
146
 
147
+ class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
148
  r"""
 
 
 
 
 
 
 
 
149
  Args:
150
+ transformer ([`SD3Transformer2DModel`]):
151
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
152
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
153
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
154
  vae ([`AutoencoderKL`]):
155
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
156
+ text_encoder ([`CLIPTextModelWithProjection`]):
157
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
158
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
159
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
160
+ as its dimension.
161
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
162
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
163
+ specifically the
164
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
165
+ variant.
166
+ text_encoder_3 ([`T5EncoderModel`]):
167
+ Frozen text-encoder. Stable Diffusion 3 uses
168
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
169
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
170
+ tokenizer (`CLIPTokenizer`):
171
+ Tokenizer of class
172
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
173
+ tokenizer_2 (`CLIPTokenizer`):
174
+ Second Tokenizer of class
175
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
176
+ tokenizer_3 (`T5TokenizerFast`):
177
+ Tokenizer of class
178
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
179
  """
180
 
181
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
182
+ _optional_components = []
183
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
184
 
185
  def __init__(
186
  self,
187
+ transformer: SD3Transformer2DModel,
188
+ scheduler: FlowMatchEulerDiscreteScheduler,
189
  vae: AutoencoderKL,
190
+ text_encoder: CLIPTextModelWithProjection,
191
+ tokenizer: CLIPTokenizer,
192
+ text_encoder_2: CLIPTextModelWithProjection,
193
+ tokenizer_2: CLIPTokenizer,
194
+ text_encoder_3: T5EncoderModel,
195
+ tokenizer_3: T5TokenizerFast,
 
 
 
196
  ):
197
  super().__init__()
198
 
199
  self.register_modules(
200
  vae=vae,
201
  text_encoder=text_encoder,
202
+ text_encoder_2=text_encoder_2,
203
+ text_encoder_3=text_encoder_3,
204
  tokenizer=tokenizer,
205
  tokenizer_2=tokenizer_2,
206
+ tokenizer_3=tokenizer_3,
207
  transformer=transformer,
208
  scheduler=scheduler,
209
+ )
210
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
211
+ self.image_processor = VaeImageProcessor(
212
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
213
+ )
214
+ self.mask_processor = VaeImageProcessor(
215
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
216
  )
217
 
218
+ self.tokenizer_max_length = self.tokenizer.model_max_length
219
+ self.default_sample_size = self.transformer.config.sample_size
220
 
221
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
222
+ def _get_t5_prompt_embeds(
223
+ self,
224
+ prompt: Union[str, List[str]] = None,
225
+ num_images_per_prompt: int = 1,
226
+ max_sequence_length: int = 256,
227
+ device: Optional[torch.device] = None,
228
+ dtype: Optional[torch.dtype] = None,
229
+ ):
230
+ device = device or self._execution_device
231
+ dtype = dtype or self.text_encoder.dtype
232
+
233
+ prompt = [prompt] if isinstance(prompt, str) else prompt
234
+ batch_size = len(prompt)
235
+
236
+ if self.text_encoder_3 is None:
237
+ return torch.zeros(
238
+ (
239
+ batch_size * num_images_per_prompt,
240
+ self.tokenizer_max_length,
241
+ self.transformer.config.joint_attention_dim,
242
+ ),
243
+ device=device,
244
+ dtype=dtype,
245
  )
246
 
247
+ text_inputs = self.tokenizer_3(
248
+ prompt,
249
+ padding="max_length",
250
+ max_length=max_sequence_length,
251
+ truncation=True,
252
+ add_special_tokens=True,
253
+ return_tensors="pt",
 
254
  )
255
+ text_input_ids = text_inputs.input_ids
256
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
257
+
258
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
259
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
260
+ logger.warning(
261
+ "The following part of your input was truncated because `max_sequence_length` is set to "
262
+ f" {max_sequence_length} tokens: {removed_text}"
263
+ )
264
+
265
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
266
+
267
+ dtype = self.text_encoder_3.dtype
268
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
269
+
270
+ _, seq_len, _ = prompt_embeds.shape
271
+
272
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
273
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
274
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
275
+
276
+ return prompt_embeds
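Note the fallback at the top of `_get_t5_prompt_embeds`: when `text_encoder_3` is dropped to save memory, the T5 stream is replaced by an all-zeros tensor whose shape still matches what the transformer expects. A shape-only illustration; the 77 and 4096 figures are the usual SD3-medium defaults and are assumptions here.

```py
# Shape of the zero placeholder returned when text_encoder_3 is None
# (77 = CLIP tokenizer max length, 4096 = joint_attention_dim; assumed SD3-medium values).
import torch

batch_size, num_images_per_prompt = 2, 1
placeholder = torch.zeros(batch_size * num_images_per_prompt, 77, 4096)
print(placeholder.shape)  # torch.Size([2, 77, 4096])
```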
277
+
278
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
279
+ def _get_clip_prompt_embeds(
280
+ self,
281
+ prompt: Union[str, List[str]],
282
+ num_images_per_prompt: int = 1,
283
+ device: Optional[torch.device] = None,
284
+ clip_skip: Optional[int] = None,
285
+ clip_model_index: int = 0,
286
+ ):
287
+ device = device or self._execution_device
288
+
289
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
290
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
291
+
292
+ tokenizer = clip_tokenizers[clip_model_index]
293
+ text_encoder = clip_text_encoders[clip_model_index]
294
+
295
+ prompt = [prompt] if isinstance(prompt, str) else prompt
296
+ batch_size = len(prompt)
297
+
298
+ text_inputs = tokenizer(
299
+ prompt,
300
+ padding="max_length",
301
+ max_length=self.tokenizer_max_length,
302
+ truncation=True,
303
+ return_tensors="pt",
304
  )
305
 
306
+ text_input_ids = text_inputs.input_ids
307
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
308
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
309
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
310
+ logger.warning(
311
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
312
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
313
+ )
314
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
315
+ pooled_prompt_embeds = prompt_embeds[0]
316
+
317
+ if clip_skip is None:
318
+ prompt_embeds = prompt_embeds.hidden_states[-2]
319
+ else:
320
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
321
+
322
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
323
+
324
+ _, seq_len, _ = prompt_embeds.shape
325
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
326
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
328
+
329
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
330
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
331
+
332
+ return prompt_embeds, pooled_prompt_embeds
333
+
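# [Editor's sketch, not part of the diff] The two CLIP encoders are selected by `clip_model_index`;
# each call returns per-token and pooled embeddings (again assuming a loaded `pipe`):
#
#     tok_emb, pooled = pipe._get_clip_prompt_embeds(["a green pear"], clip_model_index=0)
#     tok_emb_2, pooled_2 = pipe._get_clip_prompt_embeds(["a green pear"], clip_model_index=1)
#     # token embeddings have shape (1, tokenizer_max_length, hidden size of the chosen encoder)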
334
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
335
  def encode_prompt(
336
  self,
337
+ prompt: Union[str, List[str]],
338
+ prompt_2: Union[str, List[str]],
339
+ prompt_3: Union[str, List[str]],
340
+ device: Optional[torch.device] = None,
341
  num_images_per_prompt: int = 1,
342
  do_classifier_free_guidance: bool = True,
343
+ negative_prompt: Optional[Union[str, List[str]]] = None,
344
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
345
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
346
+ prompt_embeds: Optional[torch.FloatTensor] = None,
347
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
348
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
349
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
350
+ clip_skip: Optional[int] = None,
351
+ max_sequence_length: int = 256,
352
  ):
353
  r"""
 
354
 
355
  Args:
356
  prompt (`str` or `List[str]`, *optional*):
357
  prompt to be encoded
358
+ prompt_2 (`str` or `List[str]`, *optional*):
359
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
360
+ used in all text-encoders
361
+ prompt_3 (`str` or `List[str]`, *optional*):
362
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
363
+ used in all text-encoders
364
  device: (`torch.device`):
365
  torch device
 
 
366
  num_images_per_prompt (`int`):
367
  number of images that should be generated per prompt
368
  do_classifier_free_guidance (`bool`):
 
371
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
372
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
373
  less than `1`).
374
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
375
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
376
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
377
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
378
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
379
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
380
+ prompt_embeds (`torch.FloatTensor`, *optional*):
381
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
382
  provided, text embeddings will be generated from `prompt` input argument.
383
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
384
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
385
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
386
  argument.
387
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
388
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
389
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
390
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
391
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
392
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
393
+ input argument.
394
+ clip_skip (`int`, *optional*):
395
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
396
+ the output of the pre-final layer will be used for computing the prompt embeddings.
397
  """
398
+ device = device or self._execution_device
399
 
400
+ prompt = [prompt] if isinstance(prompt, str) else prompt
401
+ if prompt is not None:
 
402
  batch_size = len(prompt)
403
  else:
404
  batch_size = prompt_embeds.shape[0]
405
 
406
  if prompt_embeds is None:
407
+ prompt_2 = prompt_2 or prompt
408
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
409
+
410
+ prompt_3 = prompt_3 or prompt
411
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
412
+
413
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
414
+ prompt=prompt,
415
+ device=device,
416
+ num_images_per_prompt=num_images_per_prompt,
417
+ clip_skip=clip_skip,
418
+ clip_model_index=0,
419
  )
420
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
421
+ prompt=prompt_2,
422
+ device=device,
423
+ num_images_per_prompt=num_images_per_prompt,
424
+ clip_skip=clip_skip,
425
+ clip_model_index=1,
426
+ )
427
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
 
 
 
428
 
429
+ t5_prompt_embed = self._get_t5_prompt_embeds(
430
+ prompt=prompt_3,
431
+ num_images_per_prompt=num_images_per_prompt,
432
+ max_sequence_length=max_sequence_length,
433
+ device=device,
434
  )
 
 
435
 
436
+ clip_prompt_embeds = torch.nn.functional.pad(
437
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
438
+ )
439
 
440
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
441
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
 
 
442
 
 
443
  if do_classifier_free_guidance and negative_prompt_embeds is None:
444
+ negative_prompt = negative_prompt or ""
445
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
446
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
447
+
448
+ # normalize str to list
449
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
450
+ negative_prompt_2 = (
451
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
452
+ )
453
+ negative_prompt_3 = (
454
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
455
+ )
456
+
457
+ if prompt is not None and type(prompt) is not type(negative_prompt):
458
  raise TypeError(
459
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
460
  f" {type(prompt)}."
461
  )
 
 
462
  elif batch_size != len(negative_prompt):
463
  raise ValueError(
464
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
465
  f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
466
  " the batch size of `prompt`."
467
  )
468
 
469
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
470
+ negative_prompt,
471
+ device=device,
472
+ num_images_per_prompt=num_images_per_prompt,
473
+ clip_skip=None,
474
+ clip_model_index=0,
475
  )
476
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
477
+ negative_prompt_2,
478
+ device=device,
479
+ num_images_per_prompt=num_images_per_prompt,
480
+ clip_skip=None,
481
+ clip_model_index=1,
482
+ )
483
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
484
 
485
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
486
+ prompt=negative_prompt_3,
487
+ num_images_per_prompt=num_images_per_prompt,
488
+ max_sequence_length=max_sequence_length,
489
+ device=device,
490
+ )
491
 
492
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
493
+ negative_clip_prompt_embeds,
494
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
495
+ )
 
 
496
 
497
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
498
+ negative_pooled_prompt_embeds = torch.cat(
499
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
 
500
  )
501
+
502
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
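# [Editor's sketch, not part of the diff] The fusion performed above, reproduced on dummy tensors
# (the feature sizes below are placeholders, not the real model dimensions):
#
#     import torch
#     clip_1 = torch.randn(1, 77, 768)   # tokens from text_encoder
#     clip_2 = torch.randn(1, 77, 1280)  # tokens from text_encoder_2
#     t5 = torch.randn(1, 256, 4096)     # tokens from text_encoder_3
#     clip = torch.cat([clip_1, clip_2], dim=-1)                                 # (1, 77, 2048)
#     clip = torch.nn.functional.pad(clip, (0, t5.shape[-1] - clip.shape[-1]))   # (1, 77, 4096)
#     prompt_embeds = torch.cat([clip, t5], dim=-2)                              # (1, 333, 4096)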
503
 
504
  def check_inputs(
505
  self,
506
  prompt,
507
+ prompt_2,
508
+ prompt_3,
509
+ strength,
510
  negative_prompt=None,
511
+ negative_prompt_2=None,
512
+ negative_prompt_3=None,
513
  prompt_embeds=None,
514
  negative_prompt_embeds=None,
515
+ pooled_prompt_embeds=None,
516
+ negative_pooled_prompt_embeds=None,
 
 
 
 
517
  callback_on_step_end_tensor_inputs=None,
518
+ max_sequence_length=None,
519
  ):
520
+ if strength < 0 or strength > 1:
521
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
522
+
523
  if callback_on_step_end_tensor_inputs is not None and not all(
524
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
525
  ):
 
532
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
533
  " only forward one of the two."
534
  )
535
+ elif prompt_2 is not None and prompt_embeds is not None:
536
  raise ValueError(
537
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
538
+ " only forward one of the two."
539
+ )
540
+ elif prompt_3 is not None and prompt_embeds is not None:
541
+ raise ValueError(
542
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
543
+ " only forward one of the two."
544
  )
545
+ elif prompt is None and prompt_embeds is None:
546
  raise ValueError(
547
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
548
  )
549
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
550
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
551
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
552
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
553
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
554
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
 
 
555
 
556
  if negative_prompt is not None and negative_prompt_embeds is not None:
557
  raise ValueError(
558
  f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
559
  f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
560
  )
561
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
562
+ raise ValueError(
563
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
564
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
565
+ )
566
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
567
  raise ValueError(
568
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
569
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
570
  )
571
+
572
  if prompt_embeds is not None and negative_prompt_embeds is not None:
573
  if prompt_embeds.shape != negative_prompt_embeds.shape:
574
  raise ValueError(
 
576
  f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
577
  f" {negative_prompt_embeds.shape}."
578
  )
579
 
580
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
581
+ raise ValueError(
582
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
583
+ )
584
+
585
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
586
+ raise ValueError(
587
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
588
+ )
589
+
590
+ if max_sequence_length is not None and max_sequence_length > 512:
591
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
592
+
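# [Editor's sketch, not part of the diff] check_inputs() fails fast on obviously bad arguments,
# e.g. (assuming a loaded `pipe`):
#
#     pipe.check_inputs("a cat", None, None, strength=1.5)
#     # ValueError: The value of strength should be in [0.0, 1.0] but is 1.5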
593
  def get_timesteps(self, num_inference_steps, strength, device):
594
  # get the original timestep using init_timestep
595
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
596
 
597
+ t_start = int(max(num_inference_steps - init_timestep, 0))
598
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
 
599
 
600
  return timesteps, num_inference_steps - t_start
601
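# [Editor's note, not part of the diff] Worked example for get_timesteps(): with
# num_inference_steps=50 and strength=0.6, init_timestep = min(50 * 0.6, 50) = 30 and
# t_start = 50 - 30 = 20, so only the last 30 of the 50 scheduler timesteps are run.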
 
 
602
  def prepare_latents(
603
+ self, batch_size, num_channels_latents, height, width, image, timestep, dtype, device, generator=None
604
  ):
605
  shape = (
606
  batch_size,
 
610
  )
611
 
612
  image = image.to(device=device, dtype=dtype)
613
+
614
  if isinstance(generator, list) and len(generator) != batch_size:
615
  raise ValueError(
616
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
 
621
  retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
622
  ]
623
  init_latents = torch.cat(init_latents, dim=0)
 
624
  else:
625
  init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
626
 
627
+ init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
628
+
629
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
630
  # expand init_latents for batch_size
631
  additional_image_per_prompt = batch_size // init_latents.shape[0]
632
  init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
633
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
 
640
  shape = init_latents.shape
641
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
642
 
643
+ init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
644
+ latents = init_latents.to(device=device, dtype=dtype)
 
645
 
646
  return latents
647
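# [Editor's note, not part of the diff] The latent normalization used in prepare_latents() and
# undone before decoding at the end of __call__():
#
#     # encode side:  z = (vae.encode(x).latent - vae.config.shift_factor) * vae.config.scaling_factor
#     # decode side:  x = vae.decode(z / vae.config.scaling_factor + vae.config.shift_factor)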
 
 
650
  return self._guidance_scale
651
 
652
  @property
653
+ def clip_skip(self):
654
+ return self._clip_skip
655
 
656
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
657
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
 
673
  def __call__(
674
  self,
675
  prompt: Union[str, List[str]] = None,
676
+ prompt_2: Optional[Union[str, List[str]]] = None,
677
+ prompt_3: Optional[Union[str, List[str]]] = None,
678
  height: Optional[int] = None,
679
  width: Optional[int] = None,
680
+ image: PipelineImageInput = None,
681
+ strength: float = 0.6,
682
+ num_inference_steps: int = 50,
683
  timesteps: List[int] = None,
684
+ guidance_scale: float = 7.0,
 
685
  negative_prompt: Optional[Union[str, List[str]]] = None,
686
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
687
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
688
  num_images_per_prompt: Optional[int] = 1,
 
689
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
690
+ latents: Optional[torch.FloatTensor] = None,
691
+ prompt_embeds: Optional[torch.FloatTensor] = None,
692
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
693
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
694
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
 
 
 
 
695
  output_type: Optional[str] = "pil",
696
  return_dict: bool = True,
697
+ clip_skip: Optional[int] = None,
698
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
699
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
700
+ max_sequence_length: int = 256,
 
 
 
 
701
  map: PipelineImageInput = None,
 
702
  ):
703
  r"""
704
+ Function invoked when calling the pipeline for generation.
705
 
706
  Args:
707
  prompt (`str` or `List[str]`, *optional*):
708
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
709
+ instead.
710
+ prompt_2 (`str` or `List[str]`, *optional*):
711
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
712
+ be used instead.
713
+ prompt_3 (`str` or `List[str]`, *optional*):
714
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
715
+ be used instead.
716
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
717
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
718
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
719
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
720
  num_inference_steps (`int`, *optional*, defaults to 50):
721
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
722
+ expense of slower inference.
723
  timesteps (`List[int]`, *optional*):
724
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
725
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
726
  passed will be used. Must be in descending order.
727
+ guidance_scale (`float`, *optional*, defaults to 7.0):
728
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
729
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
730
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
731
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
732
+ usually at the expense of lower image quality.
 
733
  negative_prompt (`str` or `List[str]`, *optional*):
734
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
735
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
736
+ less than `1`).
737
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
738
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
739
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
740
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
741
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
742
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
743
  num_images_per_prompt (`int`, *optional*, defaults to 1):
744
  The number of images to generate per prompt.
 
 
 
745
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
746
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
747
+ to make generation deterministic.
748
+ latents (`torch.FloatTensor`, *optional*):
749
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
750
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
751
+ tensor will be generated by sampling using the supplied random `generator`.
752
+ prompt_embeds (`torch.FloatTensor`, *optional*):
753
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
754
+ provided, text embeddings will be generated from `prompt` input argument.
755
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
756
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
757
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
758
+ argument.
759
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
760
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
761
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
762
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
763
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
764
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
765
+ input argument.
 
 
766
  output_type (`str`, *optional*, defaults to `"pil"`):
767
+ The output format of the generated image. Choose between
768
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
769
  return_dict (`bool`, *optional*, defaults to `True`):
770
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
771
+ of a plain tuple.
772
+ callback_on_step_end (`Callable`, *optional*):
773
+ A function that is called at the end of each denoising step during inference. The function is called
774
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
775
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
776
+ `callback_on_step_end_tensor_inputs`.
777
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
778
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
779
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
780
+ `._callback_tensor_inputs` attribute of your pipeline class.
781
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
782
+
783
  Examples:
784
 
785
  Returns:
786
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
787
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
788
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
 
 
789
  """
790
 
791
+ # 0. Default height and width
 
 
 
792
  height = height or self.default_sample_size * self.vae_scale_factor
793
  width = width or self.default_sample_size * self.vae_scale_factor
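# [Editor's note, not part of the diff] With typical SD3-style configs (transformer sample_size 128
# and a VAE scale factor of 8; these values are assumptions, read the actual ones from the loaded
# models), the defaults above resolve to 128 * 8 = 1024 pixels per side.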
794
 
795
  # 1. Check inputs. Raise error if not correct
796
  self.check_inputs(
797
  prompt,
798
+ prompt_2,
799
+ prompt_3,
800
+ strength,
801
+ negative_prompt=negative_prompt,
802
+ negative_prompt_2=negative_prompt_2,
803
+ negative_prompt_3=negative_prompt_3,
804
+ prompt_embeds=prompt_embeds,
805
+ negative_prompt_embeds=negative_prompt_embeds,
806
+ pooled_prompt_embeds=pooled_prompt_embeds,
807
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
808
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
809
+ max_sequence_length=max_sequence_length,
810
  )
811
+
812
  self._guidance_scale = guidance_scale
813
+ self._clip_skip = clip_skip
814
  self._interrupt = False
815
 
816
  # 2. Define call parameters
 
823
 
824
  device = self._execution_device
825
 
 
 
826
  (
827
  prompt_embeds,
828
  negative_prompt_embeds,
829
+ pooled_prompt_embeds,
830
+ negative_pooled_prompt_embeds,
831
  ) = self.encode_prompt(
832
  prompt=prompt,
833
+ prompt_2=prompt_2,
834
+ prompt_3=prompt_3,
 
 
835
  negative_prompt=negative_prompt,
836
+ negative_prompt_2=negative_prompt_2,
837
+ negative_prompt_3=negative_prompt_3,
838
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
839
  prompt_embeds=prompt_embeds,
840
  negative_prompt_embeds=negative_prompt_embeds,
841
+ pooled_prompt_embeds=pooled_prompt_embeds,
842
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
843
  device=device,
844
+ clip_skip=self.clip_skip,
845
  num_images_per_prompt=num_images_per_prompt,
846
+ max_sequence_length=max_sequence_length,
847
  )
848
 
849
+ if self.do_classifier_free_guidance:
850
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
851
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
852
+
853
+ # 3. Preprocess image
854
  init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
855
+
856
  map = self.mask_processor.preprocess(
857
+ map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor
 
 
858
  ).to(device)
859
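# [Editor's note, not part of the diff] The change map is resized to latent resolution
# (height // vae_scale_factor, width // vae_scale_factor) so that its per-step thresholds can be
# compared element-wise against the latents inside the denoising loop.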
 
860
+ # 4. Prepare timesteps
861
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
 
862
 
863
  # begin diff diff change
864
  total_time_steps = num_inference_steps
 
867
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
868
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
869
 
870
+ # 5. Prepare latent variables
871
  num_channels_latents = self.transformer.config.in_channels
872
+ if latents is None:
873
+ latents = self.prepare_latents(
874
+ batch_size * num_images_per_prompt,
875
+ num_channels_latents,
876
+ height,
877
+ width,
878
+ init_image,
879
+ latent_timestep,
880
+ prompt_embeds.dtype,
881
+ device,
882
+ generator,
883
+ )
884
 
885
+ # 6. Denoising loop
886
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
887
+ self._num_timesteps = len(timesteps)
888
889
  # preparations for diff diff
890
  original_with_noise = self.prepare_latents(
891
  batch_size * num_images_per_prompt,
 
900
  )
901
  thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
902
  thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
903
+ masks = map.squeeze() > thresholds
904
  # end diff diff preparations
905
+
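# [Editor's sketch, not part of the diff] How the per-step masks behave: with, say, 4 total steps,
# a pixel whose map value is 0.6 yields a True mask (and is therefore taken from the re-noised
# original in the blend step) while the threshold is below 0.6, and evolves freely afterwards:
#
#     import torch
#     total = 4
#     change_map = torch.tensor([[0.6]])
#     thr = (torch.arange(total, dtype=change_map.dtype) / total).unsqueeze(1).unsqueeze(1)
#     print((change_map.squeeze() > thr).squeeze())  # tensor([ True,  True,  True, False])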
906
  with self.progress_bar(total=num_inference_steps) as progress_bar:
907
  for i, t in enumerate(timesteps):
908
  if self.interrupt:
909
  continue
910
+
911
  # diff diff
912
+ if i == 0:
913
  latents = original_with_noise[:1]
914
  else:
915
  mask = masks[i].unsqueeze(0).to(latents.dtype)
 
919
 
920
  # expand the latents if we are doing classifier free guidance
921
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
922
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
923
+ timestep = t.expand(latent_model_input.shape[0])
924
925
  noise_pred = self.transformer(
926
+ hidden_states=latent_model_input,
927
+ timestep=timestep,
928
  encoder_hidden_states=prompt_embeds,
929
+ pooled_projections=pooled_prompt_embeds,
930
  return_dict=False,
931
  )[0]
932
 
 
 
933
  # perform guidance
934
  if self.do_classifier_free_guidance:
935
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
936
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
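# [Editor's sketch, not part of the diff] The classifier-free-guidance combination above, on toy
# tensors:
#
#     import torch
#     uncond, text = torch.zeros(1, 4), torch.ones(1, 4)
#     guided = uncond + 7.0 * (text - uncond)  # all 7.0 here; a scale > 1 pushes away from uncond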
 
 
 
 
937
 
938
  # compute the previous noisy sample x_t -> x_t-1
939
+ latents_dtype = latents.dtype
940
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
941
+
942
+ if latents.dtype != latents_dtype:
943
+ if torch.backends.mps.is_available():
944
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
945
+ latents = latents.to(latents_dtype)
946
 
947
  if callback_on_step_end is not None:
948
  callback_kwargs = {}
 
954
  latents = callback_outputs.pop("latents", latents)
955
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
956
  negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
957
+ negative_pooled_prompt_embeds = callback_outputs.pop(
958
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
 
959
  )
960
 
961
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
964
  if XLA_AVAILABLE:
965
  xm.mark_step()
966
 
967
+ if output_type == "latent":
 
 
 
968
  image = latents
 
969
 
 
 
970
  else:
971
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
972
 
973
+ image = self.vae.decode(latents, return_dict=False)[0]
974
+ image = self.image_processor.postprocess(image, output_type=output_type)
975
 
976
  # Offload all models
977
  self.maybe_free_model_hooks()
978
 
979
  if not return_dict:
980
+ return (image,)
981
 
982
+ return StableDiffusion3PipelineOutput(images=image)