naonauno commited on
Commit
9721396
·
verified ·
1 Parent(s): 607c81c

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +112 -145
  2. model.py +303 -0
  3. pipeline.py +1378 -0
  4. requirements.txt +558 -6
app.py CHANGED
@@ -1,154 +1,121 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
  prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
1
  import torch
2
+ import random
3
+ import datetime
4
+ import json
5
+ import os
6
+ from PIL import Image
7
+ import gradio as gr
8
+ from model import UNet2DConditionModelEx
9
+ from pipeline import StableDiffusionControlLoraV3Pipeline
10
+
11
+ def setup_pipeline():
12
+ print("Loading models...")
13
+ unet = UNet2DConditionModelEx.from_pretrained(
14
+ "runwayml/stable-diffusion-v1-5",
15
+ subfolder="unet",
16
+ torch_dtype=torch.float16
17
+ )
18
+ unet = unet.add_extra_conditions("ow-gbi-control-lora")
19
 
20
+ pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-v1-5",
22
+ unet=unet,
23
+ torch_dtype=torch.float16
24
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ lora_path = "owgbi-Dataset2-6000.safetensors"
27
+ pipe.load_lora_weights(lora_path)
 
 
 
 
 
 
28
 
29
+ pipe.enable_xformers_memory_efficient_attention()
30
+ pipe.enable_model_cpu_offload()
 
 
 
 
 
31
 
32
+ print("Pipeline ready!")
33
+ return pipe
 
 
 
 
 
 
34
 
35
+ pipe = setup_pipeline()
 
 
 
 
 
 
36
 
37
+ def generate_image_core(image, prompt, negative_prompt, seed, guidance_scale, steps, strength, num_images_per_prompt, guidance_rescale):
38
+ try:
39
+ pipe.unet.enable_gradient_checkpointing()
40
+ generator = torch.manual_seed(seed)
41
+ result = pipe(
42
  prompt,
43
+ negative_prompt=negative_prompt,
44
+ num_inference_steps=steps,
45
+ generator=generator,
46
+ image=image,
47
+ guidance_scale=guidance_scale,
48
+ strength=strength,
49
+ num_images_per_prompt=num_images_per_prompt,
50
+ guidance_rescale=guidance_rescale,
51
+ )
52
+ return result.images
53
+ except RuntimeError as e:
54
+ if "CUDA out of memory" in str(e):
55
+ print("Error: CUDA out of memory. Skipping this request.")
56
+ return [Image.new("RGB", (512, 512), color="gray")]
57
+ else:
58
+ raise e
59
+
60
+ def save_images(image, generated_images, prompt, negative_prompt, seed, guidance_scale, steps, strength, num_images_per_prompt, guidance_rescale):
61
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
62
+ save_path = os.path.join('generated_images', timestamp)
63
+ os.makedirs(save_path, exist_ok=True)
64
+
65
+ image.save(os.path.join(save_path, "reference.png"))
66
+ for i, gen_image in enumerate(generated_images):
67
+ gen_image.save(os.path.join(save_path, f"generated_{i}.png"))
68
+
69
+ parameters = {
70
+ "prompt": prompt,
71
+ "negative_prompt": negative_prompt,
72
+ "seed": seed,
73
+ "guidance_scale": guidance_scale,
74
+ "steps": steps,
75
+ "strength": strength,
76
+ "num_images_per_prompt": num_images_per_prompt,
77
+ "guidance_rescale": guidance_rescale
78
+ }
79
+ with open(os.path.join(save_path, "parameters.json"), "w") as f:
80
+ json.dump(parameters, f, indent=4)
81
+
82
+ def inference(image, prompt, negative_prompt, seed, guidance_scale, steps, strength, num_images_per_prompt, guidance_rescale):
83
+ try:
84
+ generated_images = generate_image_core(
85
+ image, prompt, negative_prompt, seed, guidance_scale,
86
+ steps, strength, num_images_per_prompt, guidance_rescale
87
+ )
88
+ save_images(
89
+ image, generated_images, prompt, negative_prompt, seed,
90
+ guidance_scale, steps, strength, num_images_per_prompt, guidance_rescale
91
+ )
92
+ return generated_images
93
+ except RuntimeError as e:
94
+ if "CUDA out of memory" in str(e):
95
+ print("Error: CUDA out of memory. Skipping this request.")
96
+ return [Image.new("RGB", (512, 512), color="gray")]
97
+ else:
98
+ raise e
99
+
100
+ # Create the Gradio interface
101
+ iface = gr.Interface(
102
+ fn=inference,
103
+ inputs=[
104
+ gr.Image(type="pil", label="Input Image"),
105
+ gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here..."),
106
+ gr.Textbox(label="Negative Prompt", lines=2, placeholder="Enter your negative prompt here..."),
107
+ gr.Number(label="Seed", value=random.randint(1, 10000)),
108
+ gr.Number(label="Guidance Scale", value=7.5),
109
+ gr.Number(label="Steps", value=25, precision=0),
110
+ gr.Number(label="Strength", value=0.8),
111
+ gr.Number(label="Number of Images", value=1, precision=0),
112
+ gr.Number(label="Guidance Rescale", value=1.0)
113
+ ],
114
+ outputs=gr.Gallery(label="Generated Images"),
115
+ title="Terrain Generator",
116
+ description="Generate terrain images using Stable Diffusion with ControlNet and LoRA.",
117
+ )
118
+
119
+ # Launch the interface
120
  if __name__ == "__main__":
121
+ iface.launch()
model.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import copy
4
+ import torch
5
+ from torch import nn, svd_lowrank
6
+
7
+ from peft.tuners.lora import LoraLayer, Conv2d as PeftConv2d
8
+ from diffusers.configuration_utils import register_to_config
9
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel as UNet2DConditionModel
10
+
11
+
12
+ class UNet2DConditionModelEx(UNet2DConditionModel):
13
+ @register_to_config
14
+ def __init__(
15
+ self,
16
+ sample_size: Optional[int] = None,
17
+ in_channels: int = 4,
18
+ out_channels: int = 4,
19
+ center_input_sample: bool = False,
20
+ flip_sin_to_cos: bool = True,
21
+ freq_shift: int = 0,
22
+ down_block_types: Tuple[str] = (
23
+ "CrossAttnDownBlock2D",
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "DownBlock2D",
27
+ ),
28
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
29
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
30
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
31
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
32
+ layers_per_block: Union[int, Tuple[int]] = 2,
33
+ downsample_padding: int = 1,
34
+ mid_block_scale_factor: float = 1,
35
+ dropout: float = 0.0,
36
+ act_fn: str = "silu",
37
+ norm_num_groups: Optional[int] = 32,
38
+ norm_eps: float = 1e-5,
39
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
40
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
41
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
42
+ encoder_hid_dim: Optional[int] = None,
43
+ encoder_hid_dim_type: Optional[str] = None,
44
+ attention_head_dim: Union[int, Tuple[int]] = 8,
45
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
46
+ dual_cross_attention: bool = False,
47
+ use_linear_projection: bool = False,
48
+ class_embed_type: Optional[str] = None,
49
+ addition_embed_type: Optional[str] = None,
50
+ addition_time_embed_dim: Optional[int] = None,
51
+ num_class_embeds: Optional[int] = None,
52
+ upcast_attention: bool = False,
53
+ resnet_time_scale_shift: str = "default",
54
+ resnet_skip_time_act: bool = False,
55
+ resnet_out_scale_factor: float = 1.0,
56
+ time_embedding_type: str = "positional",
57
+ time_embedding_dim: Optional[int] = None,
58
+ time_embedding_act_fn: Optional[str] = None,
59
+ timestep_post_act: Optional[str] = None,
60
+ time_cond_proj_dim: Optional[int] = None,
61
+ conv_in_kernel: int = 3,
62
+ conv_out_kernel: int = 3,
63
+ projection_class_embeddings_input_dim: Optional[int] = None,
64
+ attention_type: str = "default",
65
+ class_embeddings_concat: bool = False,
66
+ mid_block_only_cross_attention: Optional[bool] = None,
67
+ cross_attention_norm: Optional[str] = None,
68
+ addition_embed_type_num_heads: int = 64,
69
+ extra_condition_names: List[str] = [],
70
+ ):
71
+ num_extra_conditions = len(extra_condition_names)
72
+ super().__init__(
73
+ sample_size=sample_size,
74
+ in_channels=in_channels * (1 + num_extra_conditions),
75
+ out_channels=out_channels,
76
+ center_input_sample=center_input_sample,
77
+ flip_sin_to_cos=flip_sin_to_cos,
78
+ freq_shift=freq_shift,
79
+ down_block_types=down_block_types,
80
+ mid_block_type=mid_block_type,
81
+ up_block_types=up_block_types,
82
+ only_cross_attention=only_cross_attention,
83
+ block_out_channels=block_out_channels,
84
+ layers_per_block=layers_per_block,
85
+ downsample_padding=downsample_padding,
86
+ mid_block_scale_factor=mid_block_scale_factor,
87
+ dropout=dropout,
88
+ act_fn=act_fn,
89
+ norm_num_groups=norm_num_groups,
90
+ norm_eps=norm_eps,
91
+ cross_attention_dim=cross_attention_dim,
92
+ transformer_layers_per_block=transformer_layers_per_block,
93
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
94
+ encoder_hid_dim=encoder_hid_dim,
95
+ encoder_hid_dim_type=encoder_hid_dim_type,
96
+ attention_head_dim=attention_head_dim,
97
+ num_attention_heads=num_attention_heads,
98
+ dual_cross_attention=dual_cross_attention,
99
+ use_linear_projection=use_linear_projection,
100
+ class_embed_type=class_embed_type,
101
+ addition_embed_type=addition_embed_type,
102
+ addition_time_embed_dim=addition_time_embed_dim,
103
+ num_class_embeds=num_class_embeds,
104
+ upcast_attention=upcast_attention,
105
+ resnet_time_scale_shift=resnet_time_scale_shift,
106
+ resnet_skip_time_act=resnet_skip_time_act,
107
+ resnet_out_scale_factor=resnet_out_scale_factor,
108
+ time_embedding_type=time_embedding_type,
109
+ time_embedding_dim=time_embedding_dim,
110
+ time_embedding_act_fn=time_embedding_act_fn,
111
+ timestep_post_act=timestep_post_act,
112
+ time_cond_proj_dim=time_cond_proj_dim,
113
+ conv_in_kernel=conv_in_kernel,
114
+ conv_out_kernel=conv_out_kernel,
115
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
116
+ attention_type=attention_type,
117
+ class_embeddings_concat=class_embeddings_concat,
118
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
119
+ cross_attention_norm=cross_attention_norm,
120
+ addition_embed_type_num_heads=addition_embed_type_num_heads,)
121
+ self._internal_dict = copy.deepcopy(self._internal_dict)
122
+ self.config.in_channels = in_channels
123
+ self.config.extra_condition_names = extra_condition_names
124
+
125
+ @property
126
+ def extra_condition_names(self) -> List[str]:
127
+ return self.config.extra_condition_names
128
+
129
+ def add_extra_conditions(self, extra_condition_names: Union[str, List[str]]):
130
+ if isinstance(extra_condition_names, str):
131
+ extra_condition_names = [extra_condition_names]
132
+ conv_in_kernel = self.config.conv_in_kernel
133
+ conv_in_weight = self.conv_in.weight
134
+ self.config.extra_condition_names += extra_condition_names
135
+ full_in_channels = self.config.in_channels * (1 + len(self.config.extra_condition_names))
136
+ new_conv_in_weight = torch.zeros(
137
+ conv_in_weight.shape[0], full_in_channels, conv_in_kernel, conv_in_kernel,
138
+ dtype=conv_in_weight.dtype,
139
+ device=conv_in_weight.device,)
140
+ new_conv_in_weight[:,:conv_in_weight.shape[1]] = conv_in_weight
141
+ self.conv_in.weight = nn.Parameter(
142
+ new_conv_in_weight.data,
143
+ requires_grad=conv_in_weight.requires_grad,)
144
+ self.conv_in.in_channels = full_in_channels
145
+
146
+ return self
147
+
148
+ def activate_extra_condition_adapters(self):
149
+ lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
150
+ if len(lora_layers) > 0:
151
+ self._hf_peft_config_loaded = True
152
+ for lora_layer in lora_layers:
153
+ adapter_names = [k for k in lora_layer.scaling.keys() if k in self.config.extra_condition_names]
154
+ adapter_names += lora_layer.active_adapters
155
+ adapter_names = list(set(adapter_names))
156
+ lora_layer.set_adapter(adapter_names)
157
+
158
+ def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
159
+ if isinstance(scale, float):
160
+ scale = [scale] * len(self.config.extra_condition_names)
161
+
162
+ lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
163
+ for s, n in zip(scale, self.config.extra_condition_names):
164
+ for lora_layer in lora_layers:
165
+ lora_layer.set_scale(n, s)
166
+
167
+ @property
168
+ def default_half_lora_target_modules(self) -> List[str]:
169
+ module_names = []
170
+ for name, module in self.named_modules():
171
+ if "conv_out" in name or "up_blocks" in name:
172
+ continue
173
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
174
+ module_names.append(name)
175
+ return list(set(module_names))
176
+
177
+ @property
178
+ def default_full_lora_target_modules(self) -> List[str]:
179
+ module_names = []
180
+ for name, module in self.named_modules():
181
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
182
+ module_names.append(name)
183
+ return list(set(module_names))
184
+
185
+ @property
186
+ def default_half_skip_attn_lora_target_modules(self) -> List[str]:
187
+ return [
188
+ module_name
189
+ for module_name in self.default_half_lora_target_modules
190
+ if all(
191
+ not module_name.endswith(attn_name)
192
+ for attn_name in
193
+ ["to_k", "to_q", "to_v", "to_out.0"]
194
+ )
195
+ ]
196
+
197
+ @property
198
+ def default_full_skip_attn_lora_target_modules(self) -> List[str]:
199
+ return [
200
+ module_name
201
+ for module_name in self.default_full_lora_target_modules
202
+ if all(
203
+ not module_name.endswith(attn_name)
204
+ for attn_name in
205
+ ["to_k", "to_q", "to_v", "to_out.0"]
206
+ )
207
+ ]
208
+
209
+ def forward(
210
+ self,
211
+ sample: torch.Tensor,
212
+ timestep: Union[torch.Tensor, float, int],
213
+ encoder_hidden_states: torch.Tensor,
214
+ class_labels: Optional[torch.Tensor] = None,
215
+ timestep_cond: Optional[torch.Tensor] = None,
216
+ attention_mask: Optional[torch.Tensor] = None,
217
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
218
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
219
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
220
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
221
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ extra_conditions: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
224
+ return_dict: bool = True,
225
+ ) -> Union[UNet2DConditionOutput, Tuple]:
226
+ if extra_conditions is not None:
227
+ if isinstance(extra_conditions, list):
228
+ extra_conditions = torch.cat(extra_conditions, dim=1)
229
+ sample = torch.cat([sample, extra_conditions], dim=1)
230
+ return super().forward(
231
+ sample=sample,
232
+ timestep=timestep,
233
+ encoder_hidden_states=encoder_hidden_states,
234
+ class_labels=class_labels,
235
+ timestep_cond=timestep_cond,
236
+ attention_mask=attention_mask,
237
+ cross_attention_kwargs=cross_attention_kwargs,
238
+ added_cond_kwargs=added_cond_kwargs,
239
+ down_block_additional_residuals=down_block_additional_residuals,
240
+ mid_block_additional_residual=mid_block_additional_residual,
241
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
242
+ encoder_attention_mask=encoder_attention_mask,
243
+ return_dict=return_dict,)
244
+
245
+
246
+ class PeftConv2dEx(PeftConv2d):
247
+ def reset_lora_parameters(self, adapter_name, init_lora_weights):
248
+ if init_lora_weights is False:
249
+ return
250
+
251
+ if isinstance(init_lora_weights, str) and "pissa" in init_lora_weights.lower():
252
+ if self.conv2d_pissa_init(adapter_name, init_lora_weights):
253
+ return
254
+ # Failed
255
+ init_lora_weights = "gaussian"
256
+
257
+ super(PeftConv2d, self).reset_lora_parameters(adapter_name, init_lora_weights)
258
+
259
+ def conv2d_pissa_init(self, adapter_name, init_lora_weights):
260
+ weight = weight_ori = self.get_base_layer().weight
261
+ weight = weight.flatten(start_dim=1)
262
+ if self.r[adapter_name] > weight.shape[0]:
263
+ return False
264
+ dtype = weight.dtype
265
+ if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
266
+ raise TypeError(
267
+ "Please initialize PiSSA under float32, float16, or bfloat16. "
268
+ "Subsequently, re-quantize the residual model to help minimize quantization errors."
269
+ )
270
+ weight = weight.to(torch.float32)
271
+
272
+ if init_lora_weights == "pissa":
273
+ # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
274
+ V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
275
+ Vr = V[:, : self.r[adapter_name]]
276
+ Sr = S[: self.r[adapter_name]]
277
+ Sr /= self.scaling[adapter_name]
278
+ Uhr = Uh[: self.r[adapter_name]]
279
+ elif len(init_lora_weights.split("_niter_")) == 2:
280
+ Vr, Sr, Ur = svd_lowrank(
281
+ weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
282
+ )
283
+ Sr /= self.scaling[adapter_name]
284
+ Uhr = Ur.t()
285
+ else:
286
+ raise ValueError(
287
+ f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
288
+ )
289
+
290
+ lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
291
+ lora_B = Vr @ torch.diag(torch.sqrt(Sr))
292
+ self.lora_A[adapter_name].weight.data = lora_A.view([-1] + list(weight_ori.shape[1:]))
293
+ self.lora_B[adapter_name].weight.data = lora_B.view([-1, self.r[adapter_name]] + [1] * (weight_ori.ndim - 2))
294
+ weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
295
+ weight = weight.to(dtype)
296
+ self.get_base_layer().weight.data = weight.view_as(weight_ori)
297
+
298
+ return True
299
+
300
+
301
+ # Patch peft conv2d
302
+ PeftConv2d.reset_lora_parameters = PeftConv2dEx.reset_lora_parameters
303
+ PeftConv2d.conv2d_pissa_init = PeftConv2dEx.conv2d_pissa_init
pipeline.py ADDED
@@ -0,0 +1,1378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
9
+
10
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models import AutoencoderKL, ImageProjection
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ deprecate,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
26
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
28
+ from model import UNet2DConditionModelEx
29
+
30
+
31
+ from huggingface_hub.utils import validate_hf_hub_args
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> # !pip install opencv-python transformers accelerate
41
+ >>> from diffusers import UniPCMultistepScheduler
42
+ >>> from diffusers.utils import load_image
43
+ >>> from model import UNet2DConditionModelEx
44
+ >>> from pipeline import StableDiffusionControlLoraV3Pipeline
45
+ >>> import numpy as np
46
+ >>> import torch
47
+
48
+ >>> import cv2
49
+ >>> from PIL import Image
50
+
51
+ >>> # download an image
52
+ >>> image = load_image(
53
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
54
+ ... )
55
+ >>> image = np.array(image)
56
+
57
+ >>> # get canny image
58
+ >>> image = cv2.Canny(image, 100, 200)
59
+ >>> image = image[:, :, None]
60
+ >>> image = np.concatenate([image, image, image], axis=2)
61
+ >>> canny_image = Image.fromarray(image)
62
+
63
+ >>> # load stable diffusion v1-5 and control-lora-v3
64
+ >>> unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
65
+ ... "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
66
+ ... )
67
+ >>> unet = unet.add_extra_conditions(["canny"])
68
+ >>> pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
69
+ ... "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
70
+ ... )
71
+ >>> # load attention processors
72
+ >>> pipe.load_lora_weights("HighCWu/sd-control-lora-v3-canny")
73
+
74
+ >>> # speed up diffusion process with faster scheduler and memory optimization
75
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
76
+ >>> # remove following line if xformers is not installed
77
+ >>> pipe.enable_xformers_memory_efficient_attention()
78
+
79
+ >>> pipe.enable_model_cpu_offload()
80
+
81
+ >>> # generate image
82
+ >>> generator = torch.manual_seed(0)
83
+ >>> image = pipe(
84
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
85
+ ... ).images[0]
86
+ ```
87
+ """
88
+
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
91
+ def retrieve_timesteps(
92
+ scheduler,
93
+ num_inference_steps: Optional[int] = None,
94
+ device: Optional[Union[str, torch.device]] = None,
95
+ timesteps: Optional[List[int]] = None,
96
+ sigmas: Optional[List[float]] = None,
97
+ **kwargs,
98
+ ):
99
+ """
100
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
101
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
102
+
103
+ Args:
104
+ scheduler (`SchedulerMixin`):
105
+ The scheduler to get timesteps from.
106
+ num_inference_steps (`int`):
107
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
108
+ must be `None`.
109
+ device (`str` or `torch.device`, *optional*):
110
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
111
+ timesteps (`List[int]`, *optional*):
112
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
113
+ `num_inference_steps` and `sigmas` must be `None`.
114
+ sigmas (`List[float]`, *optional*):
115
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
116
+ `num_inference_steps` and `timesteps` must be `None`.
117
+
118
+ Returns:
119
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
120
+ second element is the number of inference steps.
121
+ """
122
+ if timesteps is not None and sigmas is not None:
123
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
124
+ if timesteps is not None:
125
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accept_sigmas:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ else:
145
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ return timesteps, num_inference_steps
148
+
149
+
150
+ class StableDiffusionControlLoraV3Pipeline(
151
+ DiffusionPipeline,
152
+ StableDiffusionMixin,
153
+ TextualInversionLoaderMixin,
154
+ LoraLoaderMixin,
155
+ IPAdapterMixin,
156
+ FromSingleFileMixin,
157
+ ):
158
+ r"""
159
+ Pipeline for text-to-image generation using Stable Diffusion with extra condition guidance.
160
+
161
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
162
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
163
+
164
+ The pipeline also inherits the following loading methods:
165
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
166
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
167
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
168
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
169
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
170
+
171
+ Args:
172
+ vae ([`AutoencoderKL`]):
173
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
174
+ text_encoder ([`~transformers.CLIPTextModel`]):
175
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
176
+ tokenizer ([`~transformers.CLIPTokenizer`]):
177
+ A `CLIPTokenizer` to tokenize text.
178
+ unet ([`UNet2DConditionModelEx`]):
179
+ A `UNet2DConditionModelEx` to denoise the encoded image latents with extra conditions.
180
+ scheduler ([`SchedulerMixin`]):
181
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
182
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
183
+ safety_checker ([`StableDiffusionSafetyChecker`]):
184
+ Classification module that estimates whether generated images could be considered offensive or harmful.
185
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
186
+ about a model's potential harms.
187
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
188
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
192
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
193
+ _exclude_from_cpu_offload = ["safety_checker"]
194
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
195
+
196
+ def __init__(
197
+ self,
198
+ vae: AutoencoderKL,
199
+ text_encoder: CLIPTextModel,
200
+ tokenizer: CLIPTokenizer,
201
+ unet: UNet2DConditionModelEx,
202
+ scheduler: KarrasDiffusionSchedulers,
203
+ safety_checker: StableDiffusionSafetyChecker,
204
+ feature_extractor: CLIPImageProcessor,
205
+ image_encoder: CLIPVisionModelWithProjection = None,
206
+ requires_safety_checker: bool = True,
207
+ ):
208
+ super().__init__()
209
+
210
+ if safety_checker is None and requires_safety_checker:
211
+ logger.warning(
212
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
213
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
214
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
215
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
216
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
217
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
218
+ )
219
+
220
+ if safety_checker is not None and feature_extractor is None:
221
+ raise ValueError(
222
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
223
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
224
+ )
225
+
226
+ self.register_modules(
227
+ vae=vae,
228
+ text_encoder=text_encoder,
229
+ tokenizer=tokenizer,
230
+ unet=unet,
231
+ scheduler=scheduler,
232
+ safety_checker=safety_checker,
233
+ feature_extractor=feature_extractor,
234
+ image_encoder=image_encoder,
235
+ )
236
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
237
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
238
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
239
+
240
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
241
+ def _encode_prompt(
242
+ self,
243
+ prompt,
244
+ device,
245
+ num_images_per_prompt,
246
+ do_classifier_free_guidance,
247
+ negative_prompt=None,
248
+ prompt_embeds: Optional[torch.Tensor] = None,
249
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
250
+ lora_scale: Optional[float] = None,
251
+ **kwargs,
252
+ ):
253
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
254
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
255
+
256
+ prompt_embeds_tuple = self.encode_prompt(
257
+ prompt=prompt,
258
+ device=device,
259
+ num_images_per_prompt=num_images_per_prompt,
260
+ do_classifier_free_guidance=do_classifier_free_guidance,
261
+ negative_prompt=negative_prompt,
262
+ prompt_embeds=prompt_embeds,
263
+ negative_prompt_embeds=negative_prompt_embeds,
264
+ lora_scale=lora_scale,
265
+ **kwargs,
266
+ )
267
+
268
+ # concatenate for backwards comp
269
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
270
+
271
+ return prompt_embeds
272
+
273
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
274
+ def encode_prompt(
275
+ self,
276
+ prompt,
277
+ device,
278
+ num_images_per_prompt,
279
+ do_classifier_free_guidance,
280
+ negative_prompt=None,
281
+ prompt_embeds: Optional[torch.Tensor] = None,
282
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
283
+ lora_scale: Optional[float] = None,
284
+ clip_skip: Optional[int] = None,
285
+ ):
286
+ r"""
287
+ Encodes the prompt into text encoder hidden states.
288
+
289
+ Args:
290
+ prompt (`str` or `List[str]`, *optional*):
291
+ prompt to be encoded
292
+ device: (`torch.device`):
293
+ torch device
294
+ num_images_per_prompt (`int`):
295
+ number of images that should be generated per prompt
296
+ do_classifier_free_guidance (`bool`):
297
+ whether to use classifier free guidance or not
298
+ negative_prompt (`str` or `List[str]`, *optional*):
299
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
300
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
301
+ less than `1`).
302
+ prompt_embeds (`torch.Tensor`, *optional*):
303
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
304
+ provided, text embeddings will be generated from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
306
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
307
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
308
+ argument.
309
+ lora_scale (`float`, *optional*):
310
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
311
+ clip_skip (`int`, *optional*):
312
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
313
+ the output of the pre-final layer will be used for computing the prompt embeddings.
314
+ """
315
+ # set lora scale so that monkey patched LoRA
316
+ # function of text encoder can correctly access it
317
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
318
+ self._lora_scale = lora_scale
319
+
320
+ # dynamically adjust the LoRA scale
321
+ if not USE_PEFT_BACKEND:
322
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
323
+ else:
324
+ scale_lora_layers(self.text_encoder, lora_scale)
325
+
326
+ if prompt is not None and isinstance(prompt, str):
327
+ batch_size = 1
328
+ elif prompt is not None and isinstance(prompt, list):
329
+ batch_size = len(prompt)
330
+ else:
331
+ batch_size = prompt_embeds.shape[0]
332
+
333
+ if prompt_embeds is None:
334
+ # textual inversion: process multi-vector tokens if necessary
335
+ if isinstance(self, TextualInversionLoaderMixin):
336
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
337
+
338
+ text_inputs = self.tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=self.tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+ text_input_ids = text_inputs.input_ids
346
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
347
+
348
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
349
+ text_input_ids, untruncated_ids
350
+ ):
351
+ removed_text = self.tokenizer.batch_decode(
352
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
353
+ )
354
+ logger.warning(
355
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
356
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
357
+ )
358
+
359
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
360
+ attention_mask = text_inputs.attention_mask.to(device)
361
+ else:
362
+ attention_mask = None
363
+
364
+ if clip_skip is None:
365
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
366
+ prompt_embeds = prompt_embeds[0]
367
+ else:
368
+ prompt_embeds = self.text_encoder(
369
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
370
+ )
371
+ # Access the `hidden_states` first, that contains a tuple of
372
+ # all the hidden states from the encoder layers. Then index into
373
+ # the tuple to access the hidden states from the desired layer.
374
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
375
+ # We also need to apply the final LayerNorm here to not mess with the
376
+ # representations. The `last_hidden_states` that we typically use for
377
+ # obtaining the final prompt representations passes through the LayerNorm
378
+ # layer.
379
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
380
+
381
+ if self.text_encoder is not None:
382
+ prompt_embeds_dtype = self.text_encoder.dtype
383
+ elif self.unet is not None:
384
+ prompt_embeds_dtype = self.unet.dtype
385
+ else:
386
+ prompt_embeds_dtype = prompt_embeds.dtype
387
+
388
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ bs_embed, seq_len, _ = prompt_embeds.shape
391
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
392
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
393
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
394
+
395
+ # get unconditional embeddings for classifier free guidance
396
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
397
+ uncond_tokens: List[str]
398
+ if negative_prompt is None:
399
+ uncond_tokens = [""] * batch_size
400
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
401
+ raise TypeError(
402
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
403
+ f" {type(prompt)}."
404
+ )
405
+ elif isinstance(negative_prompt, str):
406
+ uncond_tokens = [negative_prompt]
407
+ elif batch_size != len(negative_prompt):
408
+ raise ValueError(
409
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
410
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
411
+ " the batch size of `prompt`."
412
+ )
413
+ else:
414
+ uncond_tokens = negative_prompt
415
+
416
+ # textual inversion: process multi-vector tokens if necessary
417
+ if isinstance(self, TextualInversionLoaderMixin):
418
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
419
+
420
+ max_length = prompt_embeds.shape[1]
421
+ uncond_input = self.tokenizer(
422
+ uncond_tokens,
423
+ padding="max_length",
424
+ max_length=max_length,
425
+ truncation=True,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
430
+ attention_mask = uncond_input.attention_mask.to(device)
431
+ else:
432
+ attention_mask = None
433
+
434
+ negative_prompt_embeds = self.text_encoder(
435
+ uncond_input.input_ids.to(device),
436
+ attention_mask=attention_mask,
437
+ )
438
+ negative_prompt_embeds = negative_prompt_embeds[0]
439
+
440
+ if do_classifier_free_guidance:
441
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
442
+ seq_len = negative_prompt_embeds.shape[1]
443
+
444
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
447
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
448
+
449
+ if self.text_encoder is not None:
450
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
451
+ # Retrieve the original scale by scaling back the LoRA layers
452
+ unscale_lora_layers(self.text_encoder, lora_scale)
453
+
454
+ return prompt_embeds, negative_prompt_embeds
455
+
456
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
457
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
458
+ dtype = next(self.image_encoder.parameters()).dtype
459
+
460
+ if not isinstance(image, torch.Tensor):
461
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+ if output_hidden_states:
465
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
466
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
467
+ uncond_image_enc_hidden_states = self.image_encoder(
468
+ torch.zeros_like(image), output_hidden_states=True
469
+ ).hidden_states[-2]
470
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
471
+ num_images_per_prompt, dim=0
472
+ )
473
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
474
+ else:
475
+ image_embeds = self.image_encoder(image).image_embeds
476
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
477
+ uncond_image_embeds = torch.zeros_like(image_embeds)
478
+
479
+ return image_embeds, uncond_image_embeds
480
+
481
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
482
+ def prepare_ip_adapter_image_embeds(
483
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
484
+ ):
485
+ if ip_adapter_image_embeds is None:
486
+ if not isinstance(ip_adapter_image, list):
487
+ ip_adapter_image = [ip_adapter_image]
488
+
489
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
490
+ raise ValueError(
491
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
492
+ )
493
+
494
+ image_embeds = []
495
+ for single_ip_adapter_image, image_proj_layer in zip(
496
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
497
+ ):
498
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
499
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
500
+ single_ip_adapter_image, device, 1, output_hidden_state
501
+ )
502
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
503
+ single_negative_image_embeds = torch.stack(
504
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
505
+ )
506
+
507
+ if do_classifier_free_guidance:
508
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
509
+ single_image_embeds = single_image_embeds.to(device)
510
+
511
+ image_embeds.append(single_image_embeds)
512
+ else:
513
+ repeat_dims = [1]
514
+ image_embeds = []
515
+ for single_image_embeds in ip_adapter_image_embeds:
516
+ if do_classifier_free_guidance:
517
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
518
+ single_image_embeds = single_image_embeds.repeat(
519
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520
+ )
521
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
522
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
523
+ )
524
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
525
+ else:
526
+ single_image_embeds = single_image_embeds.repeat(
527
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
528
+ )
529
+ image_embeds.append(single_image_embeds)
530
+
531
+ return image_embeds
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
534
+ def run_safety_checker(self, image, device, dtype):
535
+ if self.safety_checker is None:
536
+ has_nsfw_concept = None
537
+ else:
538
+ has_nsfw_concept = None
539
+ #if torch.is_tensor(image):
540
+ # feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
541
+ #else:
542
+ # feature_extractor_input = self.image_processor.numpy_to_pil(image)
543
+ #safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
544
+ #image, has_nsfw_concept = self.safety_checker(
545
+ # images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
546
+ #)
547
+ return image, has_nsfw_concept
548
+
549
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
550
+ def decode_latents(self, latents):
551
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
552
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
553
+
554
+ latents = 1 / self.vae.config.scaling_factor * latents
555
+ image = self.vae.decode(latents, return_dict=False)[0]
556
+ image = (image / 2 + 0.5).clamp(0, 1)
557
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
558
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
559
+ return image
560
+
561
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
562
+ def prepare_extra_step_kwargs(self, generator, eta):
563
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
564
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
565
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
566
+ # and should be between [0, 1]
567
+
568
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
569
+ extra_step_kwargs = {}
570
+ if accepts_eta:
571
+ extra_step_kwargs["eta"] = eta
572
+
573
+ # check if the scheduler accepts generator
574
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
575
+ if accepts_generator:
576
+ extra_step_kwargs["generator"] = generator
577
+ return extra_step_kwargs
578
+
579
+ def check_inputs(
580
+ self,
581
+ prompt,
582
+ image,
583
+ callback_steps,
584
+ negative_prompt=None,
585
+ prompt_embeds=None,
586
+ negative_prompt_embeds=None,
587
+ ip_adapter_image=None,
588
+ ip_adapter_image_embeds=None,
589
+ extra_condition_scale=1.0,
590
+ control_guidance_start=0.0,
591
+ control_guidance_end=1.0,
592
+ callback_on_step_end_tensor_inputs=None,
593
+ ):
594
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
595
+ raise ValueError(
596
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
597
+ f" {type(callback_steps)}."
598
+ )
599
+
600
+ if callback_on_step_end_tensor_inputs is not None and not all(
601
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
602
+ ):
603
+ raise ValueError(
604
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
605
+ )
606
+
607
+ if prompt is not None and prompt_embeds is not None:
608
+ raise ValueError(
609
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
610
+ " only forward one of the two."
611
+ )
612
+ elif prompt is None and prompt_embeds is None:
613
+ raise ValueError(
614
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
615
+ )
616
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
617
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
618
+
619
+ if negative_prompt is not None and negative_prompt_embeds is not None:
620
+ raise ValueError(
621
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
622
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
623
+ )
624
+
625
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
626
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
627
+ raise ValueError(
628
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
629
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
630
+ f" {negative_prompt_embeds.shape}."
631
+ )
632
+
633
+ # Check `image`
634
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
635
+ num_extra_conditions = len(unet.extra_condition_names)
636
+ if num_extra_conditions == 1:
637
+ self.check_image(image, prompt, prompt_embeds)
638
+ elif num_extra_conditions > 1:
639
+ if not isinstance(image, list):
640
+ raise TypeError("For multiple extra conditions: `image` must be type `list`")
641
+
642
+ # When `image` is a nested list:
643
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
644
+ elif any(isinstance(i, list) for i in image):
645
+ transposed_image = [list(t) for t in zip(*image)]
646
+ if len(transposed_image) != num_extra_conditions:
647
+ raise ValueError(
648
+ f"For multiple extra conditions: if you pass`image` as a list of list, each sublist must have the same length as the number of extra conditions, but the sublists in `image` got {len(transposed_image)} images and {num_extra_conditions} extra conditions."
649
+ )
650
+ for image_ in transposed_image:
651
+ self.check_image(image_, prompt, prompt_embeds)
652
+ elif len(image) != num_extra_conditions:
653
+ raise ValueError(
654
+ f"For multiple extra conditions: `image` must have the same length as the number of extra conditions, but got {len(image)} images and {num_extra_conditions} extra conditions."
655
+ )
656
+ else:
657
+ for image_ in image:
658
+ self.check_image(image_, prompt, prompt_embeds)
659
+ else:
660
+ assert False
661
+
+         # Check `extra_condition_scale`
+         if num_extra_conditions == 1:
+             if not isinstance(extra_condition_scale, float):
+                 raise TypeError("For single extra condition: `extra_condition_scale` must be type `float`.")
+         elif num_extra_conditions >= 1:
+             if isinstance(extra_condition_scale, list):
+                 if any(isinstance(i, list) for i in extra_condition_scale):
+                     raise ValueError(
+                         "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
+                         "The conditioning scale must be fixed across the batch."
+                     )
+             elif isinstance(extra_condition_scale, list) and len(extra_condition_scale) != num_extra_conditions:
+                 raise ValueError(
+                     "For multiple extra conditions: When `extra_condition_scale` is specified as `list`, it must have"
+                     " the same length as the number of extra conditions"
+                 )
+         else:
+             assert False
+
+         if not isinstance(control_guidance_start, (tuple, list)):
+             control_guidance_start = [control_guidance_start]
+
+         if not isinstance(control_guidance_end, (tuple, list)):
+             control_guidance_end = [control_guidance_end]
+
+         if len(control_guidance_start) != len(control_guidance_end):
+             raise ValueError(
+                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+             )
+
+         if num_extra_conditions > 1:
+             if len(control_guidance_start) != num_extra_conditions:
+                 raise ValueError(
+                     f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_extra_conditions} extra conditions available. Make sure to provide {num_extra_conditions}."
+                 )
+
+         for start, end in zip(control_guidance_start, control_guidance_end):
+             if start >= end:
+                 raise ValueError(
+                     f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+                 )
+             if start < 0.0:
+                 raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+             if end > 1.0:
+                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+             raise ValueError(
+                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+             )
+
+         if ip_adapter_image_embeds is not None:
+             if not isinstance(ip_adapter_image_embeds, list):
+                 raise ValueError(
+                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                 )
+             elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                 raise ValueError(
+                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                 )
+
723
+     def check_image(self, image, prompt, prompt_embeds):
+         image_is_pil = isinstance(image, PIL.Image.Image)
+         image_is_tensor = isinstance(image, torch.Tensor)
+         image_is_np = isinstance(image, np.ndarray)
+         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+         image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+         if (
+             not image_is_pil
+             and not image_is_tensor
+             and not image_is_np
+             and not image_is_pil_list
+             and not image_is_tensor_list
+             and not image_is_np_list
+         ):
+             raise TypeError(
+                 f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+             )
+
+         if image_is_pil:
+             image_batch_size = 1
+         else:
+             image_batch_size = len(image)
+
+         if prompt is not None and isinstance(prompt, str):
+             prompt_batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             prompt_batch_size = len(prompt)
+         elif prompt_embeds is not None:
+             prompt_batch_size = prompt_embeds.shape[0]
+
+         if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+             raise ValueError(
+                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+             )
+
760
+     def prepare_image(
+         self,
+         image,
+         width,
+         height,
+         batch_size,
+         num_images_per_prompt,
+         device,
+         dtype,
+         do_classifier_free_guidance=False,
+     ):
+         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+         image_batch_size = image.shape[0]
+
+         if image_batch_size == 1:
+             repeat_by = batch_size
+         else:
+             # image batch size is the same as prompt batch size
+             repeat_by = num_images_per_prompt
+
+         image = image.repeat_interleave(repeat_by, dim=0)
+
+         image = image.to(device=device, dtype=dtype)
+
+         if do_classifier_free_guidance:
+             image = torch.cat([image] * 2)
+
+         return image
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+         shape = (
+             batch_size,
+             num_channels_latents,
+             int(height) // self.vae_scale_factor,
+             int(width) // self.vae_scale_factor,
+         )
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+     def get_guidance_scale_embedding(
+         self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+     ) -> torch.Tensor:
+         """
+         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+         Args:
+             w (`torch.Tensor`):
+                 Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+             embedding_dim (`int`, *optional*, defaults to 512):
+                 Dimension of the embeddings to generate.
+             dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                 Data type of the generated embeddings.
+
+         Returns:
+             `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+         """
+         assert len(w.shape) == 1
+         w = w * 1000.0
+
+         half_dim = embedding_dim // 2
+         emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+         emb = w.to(dtype)[:, None] * emb[None, :]
+         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+         if embedding_dim % 2 == 1:  # zero pad
+             emb = torch.nn.functional.pad(emb, (0, 1))
+         assert emb.shape == (w.shape[0], embedding_dim)
+         return emb
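+     # Note (illustrative): in `__call__` below this is only used when the UNet was trained with
+     # `time_cond_proj_dim` (guidance-distilled models); the caller passes `guidance_scale - 1`
+     # repeated over the batch, so the returned embedding has shape
+     # (batch_size * num_images_per_prompt, time_cond_proj_dim).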
842
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def clip_skip(self):
+         return self._clip_skip
+
+     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+     # corresponds to doing no classifier free guidance.
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+     @property
+     def cross_attention_kwargs(self):
+         return self._cross_attention_kwargs
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
866
+     @classmethod
+     @validate_hf_hub_args
+     def lora_state_dict(
+         cls,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         **kwargs,
+     ):
+         # Override to add support for different LoRA alphas
+         state_dict, network_alphas = super(StableDiffusionControlLoraV3Pipeline, cls).lora_state_dict(
+             pretrained_model_name_or_path_or_dict, **kwargs
+         )
+         if network_alphas is None:
+             network_alphas = {}
+         for k, v in state_dict.items():
+             if ".lora_A." in k:
+                 network_alphas[".".join(k.split(".lora_A.")[0].split(".") + ["alpha"])] = v.shape[0]
+         return state_dict, network_alphas
+
884
+     def load_lora_weights(
+         self,
+         pretrained_model_name_or_path_or_dict: Union[
+             Union[str, Dict[str, torch.Tensor]],
+             List[Union[str, Dict[str, torch.Tensor]]]
+         ],
+         adapter_name=None,
+         **kwargs
+     ):
+         unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
+         num_condition_names = len(unet.extra_condition_names)
+         in_channels = unet.config.in_channels
+
+         kwargs["weight_name"] = kwargs.pop("weight_name", "pytorch_lora_weights.safetensors")
+
+         if adapter_name is not None and adapter_name not in unet.extra_condition_names:
+             unet._hf_peft_config_loaded = True
+             super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
+             unet.set_adapter(adapter_name)
+             return
+
+         if not isinstance(pretrained_model_name_or_path_or_dict, list):
+             pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] * num_condition_names
+         pretrained_model_name_or_path_or_dict_list = pretrained_model_name_or_path_or_dict
+
+         assert len(pretrained_model_name_or_path_or_dict) == len(unet.extra_condition_names)
+
+         adapter_name_ori = adapter_name
+         for i, (pretrained_model_name_or_path_or_dict, adapter_name) in enumerate(zip(
+             pretrained_model_name_or_path_or_dict_list,
+             unet.extra_condition_names
+         )):
+             _kwargs = {**kwargs}
+             subfolder = _kwargs.pop("subfolder", None)
+             if isinstance(subfolder, list):
+                 subfolder = subfolder[i]
+
+             if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                 pretrained_model_name_or_path_or_dict, _ = self.lora_state_dict(
+                     pretrained_model_name_or_path_or_dict,
+                     subfolder=subfolder,
+                     **_kwargs
+                 )
+
+             if adapter_name_ori is not None:
+                 # only load lora of the input adapter name, then break the loop
+                 i = unet.extra_condition_names.index(adapter_name_ori)
+                 adapter_name = adapter_name_ori
+
+             unet_conv_in_lora_A_name, old_weight = ([
+                 (k, v)
+                 for k, v in pretrained_model_name_or_path_or_dict.items()
+                 if "unet." in k and ".conv_in." in k and ".lora_A." in k
+             ] + [(None, None)])[0]
+             if unet_conv_in_lora_A_name is not None:
+                 in_weight = old_weight[:, :in_channels]
+                 cond_weight = old_weight[:, in_channels:]
+                 zero_weight = torch.zeros_like(in_weight)
+                 new_weight = torch.cat(
+                     [in_weight] +
+                     [zero_weight] * i +
+                     [cond_weight] +
+                     [zero_weight] * (num_condition_names - i - 1),
+                     dim=1
+                 )
+                 pretrained_model_name_or_path_or_dict[unet_conv_in_lora_A_name] = new_weight
+
+             super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **_kwargs)
+
+             if adapter_name_ori is not None:
+                 break
+
+         unet.activate_extra_condition_adapters()
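+         # Usage sketch (illustrative, repo ids are placeholders): with two extra conditions one
+         # entry can be passed per condition,
+         #     pipe.load_lora_weights(["user/canny-control-lora", "user/pose-control-lora"])
+         # or a single path/dict can be given, which is then reused for every condition.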
957
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         image: PipelineImageInput = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         timesteps: List[int] = None,
+         sigmas: List[float] = None,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         extra_condition_scale: Union[float, List[float]] = 1.0,
+         control_guidance_start: Union[float, List[float]] = 0.0,
+         control_guidance_end: Union[float, List[float]] = 1.0,
+         clip_skip: Optional[int] = None,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         **kwargs,
+     ):
992
+         r"""
+         The call function to the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+             image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                 The extra input condition that provides guidance to the `unet` for generation, after being encoded by
+                 the `vae`. If the type is specified as `torch.Tensor`, its `vae` latent representation is passed to
+                 the UNet. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image
+                 default to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly.
+                 If multiple extra conditions are specified in `unet`, images must be passed as a list such that each
+                 element of the list can be correctly batched for input to `unet`. When `prompt` is a list, and if a
+                 list of images is passed for `unet`, each will be paired with each prompt in the `prompt` list. This
+                 also applies to multiple extra conditions, where a list of image lists can be passed to batch for
+                 each prompt and each extra condition.
+             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             timesteps (`List[int]`, *optional*):
+                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                 passed will be used. Must be in descending order.
+             sigmas (`List[float]`, *optional*):
+                 Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                 will be used.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 A higher guidance scale value encourages the model to generate images closely linked to the text
+                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                 pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                 to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                 generation deterministic.
+             latents (`torch.Tensor`, *optional*):
+                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor is generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                 provided, text embeddings are generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+             ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                 Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                 of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                 contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that is called every `callback_steps` steps during inference. The function is called with
+                 the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function is called. If not specified, the callback is called at
+                 every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             extra_condition_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                 The control LoRA scale of the `unet`. If multiple extra conditions are specified in `unet`, you can set
+                 the corresponding scale as a list.
+             control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                 The percentage of total steps at which the extra condition starts applying.
+             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                 The percentage of total steps at which the extra condition stops applying.
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+             callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                 A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                 each denoising step during inference, with the following arguments: `callback_on_step_end(self:
+                 DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                 list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+             callback_on_step_end_tensor_inputs (`List`, *optional*):
+                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                 will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                 the `._callback_tensor_inputs` attribute of your pipeline class.
+
+         Examples:
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                 otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                 second element is a list of `bool`s indicating whether the corresponding generated image contains
+                 "not-safe-for-work" (nsfw) content.
+         """
1096
+
+         callback = kwargs.pop("callback", None)
+         callback_steps = kwargs.pop("callback_steps", None)
+
+         if callback is not None:
+             deprecate(
+                 "callback",
+                 "1.0.0",
+                 "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+             )
+         if callback_steps is not None:
+             deprecate(
+                 "callback_steps",
+                 "1.0.0",
+                 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+             )
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
+         num_extra_conditions = len(unet.extra_condition_names)
+
+         # align format for control guidance
+         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+             mult = num_extra_conditions
+             control_guidance_start, control_guidance_end = (
+                 mult * [control_guidance_start],
+                 mult * [control_guidance_end],
+             )
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             image,
+             callback_steps,
+             negative_prompt,
+             prompt_embeds,
+             negative_prompt_embeds,
+             ip_adapter_image,
+             ip_adapter_image_embeds,
+             extra_condition_scale,
+             control_guidance_start,
+             control_guidance_end,
+             callback_on_step_end_tensor_inputs,
+         )
1146
+
+         self._guidance_scale = guidance_scale
+         self._clip_skip = clip_skip
+         self._cross_attention_kwargs = cross_attention_kwargs
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         if num_extra_conditions > 1 and isinstance(extra_condition_scale, float):
+             extra_condition_scale = [extra_condition_scale] * num_extra_conditions
+
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+         )
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             self.do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+             clip_skip=self.clip_skip,
+         )
+         # For classifier free guidance, we need to do two forward passes.
+         # Here we concatenate the unconditional and text embeddings into a single batch
+         # to avoid doing two forward passes
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 self.do_classifier_free_guidance,
+             )
+
1194
+         # 4. Prepare image
+         if num_extra_conditions == 1:
+             image = self.prepare_image(
+                 image=image,
+                 width=width,
+                 height=height,
+                 batch_size=batch_size * num_images_per_prompt,
+                 num_images_per_prompt=num_images_per_prompt,
+                 device=device,
+                 dtype=unet.dtype,
+                 do_classifier_free_guidance=self.do_classifier_free_guidance,
+             )
+             height, width = image.shape[-2:]
+             image = (
+                 self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
+             )
+         elif num_extra_conditions >= 1:
+             images = []
+
+             # Nested lists as extra condition
+             if isinstance(image[0], list):
+                 # Transpose the nested image list
+                 image = [list(t) for t in zip(*image)]
+
+             for image_ in image:
+                 image_ = self.prepare_image(
+                     image=image_,
+                     width=width,
+                     height=height,
+                     batch_size=batch_size * num_images_per_prompt,
+                     num_images_per_prompt=num_images_per_prompt,
+                     device=device,
+                     dtype=unet.dtype,
+                     do_classifier_free_guidance=self.do_classifier_free_guidance,
+                 )
+
+                 images.append(image_)
+
+             image = images
+             height, width = image[0].shape[-2:]
+             image = [
+                 self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
+                 for image in images
+             ]
+         else:
+             assert False
1240
+
+         # 5. Prepare timesteps
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler, num_inference_steps, device, timesteps, sigmas
+         )
+         self._num_timesteps = len(timesteps)
+
+         # 6. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6.5 Optionally get Guidance Scale Embedding
+         timestep_cond = None
+         if self.unet.config.time_cond_proj_dim is not None:
+             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+             timestep_cond = self.get_guidance_scale_embedding(
+                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+             ).to(device=device, dtype=latents.dtype)
+
+         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7.1 Add image embeds for IP-Adapter
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+             else None
+         )
+
+         # 7.2 Create tensor stating which extra_conditions to keep
+         extra_condition_keep = []
+         for i in range(len(timesteps)):
+             keeps = [
+                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                 for s, e in zip(control_guidance_start, control_guidance_end)
+             ]
+             extra_condition_keep.append(keeps[0] if num_extra_conditions == 1 else keeps)
+
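+         # Illustration of the keep flags computed above: with 10 steps, control_guidance_start=0.0
+         # and control_guidance_end=0.5, the flag is 1.0 for steps 0-4 and 0.0 for steps 5-9, so the
+         # extra condition only steers the first half of the denoising trajectory.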
1287
+         # 8. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         is_unet_compiled = is_compiled_module(self.unet)
+         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # Relevant thread:
+                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                 if is_unet_compiled and is_torch_higher_equal_2_1:
+                     torch._inductor.cudagraph_mark_step_begin()
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 if isinstance(extra_condition_keep[i], list):
+                     cond_scale = [c * s for c, s in zip(extra_condition_scale, extra_condition_keep[i])]
+                 else:
+                     extra_cond_scale = extra_condition_scale
+                     if isinstance(extra_cond_scale, list):
+                         extra_cond_scale = extra_cond_scale[0]
+                     cond_scale = extra_cond_scale * extra_condition_keep[i]
+
+                 self.unet.set_extra_condition_scale(cond_scale)
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     timestep_cond=timestep_cond,
+                     cross_attention_kwargs=self.cross_attention_kwargs,
+                     added_cond_kwargs=added_cond_kwargs,
+                     extra_conditions=image,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         step_idx = i // getattr(self.scheduler, "order", 1)
+                         callback(step_idx, t, latents)
+
+         self.unet.set_extra_condition_scale(1.0)
1349
+
+         # If we do sequential model offloading, let's offload unet
+         # manually for max memory savings
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.unet.to("cpu")
+             torch.cuda.empty_cache()
+
+         if not output_type == "latent":
+             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                 0
+             ]
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+         else:
+             image = latents
+             has_nsfw_concept = None
+
+         if has_nsfw_concept is None:
+             do_denormalize = [True] * image.shape[0]
+         else:
+             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
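A minimal usage sketch for the pipeline defined above (illustrative only, not part of the commit): the checkpoint id, the LoRA path and the condition image are placeholders, and the exact way `UNet2DConditionModelEx` registers its extra condition names is defined in model.py, so the construction step may differ in practice.

import torch
from diffusers.utils import load_image

from model import UNet2DConditionModelEx
from pipeline import StableDiffusionControlLoraV3Pipeline

# Build the pipeline around the extended UNet (standard diffusers from_pretrained flow);
# attaching the extra condition names to the UNet is handled in model.py.
unet = UNet2DConditionModelEx.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
).to("cuda")

# One control LoRA per extra condition (see load_lora_weights above); the path is a placeholder.
pipe.load_lora_weights("path/to/control-lora")

condition = load_image("condition.png")  # e.g. a canny map matching the trained condition
result = pipe(
    "a photo of a cat",
    image=condition,
    num_inference_steps=30,
    guidance_scale=7.5,
    extra_condition_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
).images[0]
result.save("out.png")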
requirements.txt CHANGED
@@ -1,6 +1,558 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers
1
+ absl-py==1.4.0
2
+ accelerate==1.2.1
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.4.4
5
+ aiohttp==3.11.11
6
+ aiosignal==1.3.2
7
+ alabaster==1.0.0
8
+ albucore==0.0.19
9
+ albumentations==1.4.20
10
+ altair==5.5.0
11
+ annotated-types==0.7.0
12
+ anyio==3.7.1
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ array_record==0.6.0
16
+ arviz==0.20.0
17
+ astropy==6.1.7
18
+ astropy-iers-data==0.2025.1.6.0.33.42
19
+ astunparse==1.6.3
20
+ async-timeout==4.0.3
21
+ atpublic==4.1.0
22
+ attrs==24.3.0
23
+ audioread==3.0.1
24
+ autograd==1.7.0
25
+ babel==2.16.0
26
+ backcall==0.2.0
27
+ beautifulsoup4==4.12.3
28
+ bigframes==1.31.0
29
+ bigquery-magics==0.5.0
30
+ bleach==6.2.0
31
+ blinker==1.9.0
32
+ blis==0.7.11
33
+ blosc2==2.7.1
34
+ bokeh==3.6.2
35
+ Bottleneck==1.4.2
36
+ bqplot==0.12.44
37
+ branca==0.8.1
38
+ CacheControl==0.14.2
39
+ cachetools==5.5.0
40
+ catalogue==2.0.10
41
+ certifi==2024.12.14
42
+ cffi==1.17.1
43
+ chardet==5.2.0
44
+ charset-normalizer==3.4.1
45
+ chex==0.1.88
46
+ clarabel==0.9.0
47
+ click==8.1.8
48
+ cloudpathlib==0.20.0
49
+ cloudpickle==3.1.0
50
+ cmake==3.31.2
51
+ cmdstanpy==1.2.5
52
+ colorcet==3.1.0
53
+ colorlover==0.3.0
54
+ colour==0.1.5
55
+ community==1.0.0b1
56
+ confection==0.1.5
57
+ cons==0.4.6
58
+ contourpy==1.3.1
59
+ cryptography==43.0.3
60
+ cuda-python==12.2.1
61
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
62
+ cufflinks==0.17.3
63
+ cupy-cuda12x==12.2.0
64
+ cvxopt==1.3.2
65
+ cvxpy==1.6.0
66
+ cycler==0.12.1
67
+ cymem==2.0.10
68
+ Cython==3.0.11
69
+ dask==2024.10.0
70
+ datascience==0.17.6
71
+ db-dtypes==1.3.1
72
+ dbus-python==1.2.18
73
+ debugpy==1.8.0
74
+ decorator==4.4.2
75
+ defusedxml==0.7.1
76
+ Deprecated==1.2.15
77
+ diffusers==0.32.1
78
+ discord.py==2.4.0
79
+ distro==1.9.0
80
+ dlib==19.24.2
81
+ dm-tree==0.1.8
82
+ docker-pycreds==0.4.0
83
+ docstring_parser==0.16
84
+ docutils==0.21.2
85
+ dopamine_rl==4.1.0
86
+ duckdb==1.1.3
87
+ earthengine-api==1.4.4
88
+ easydict==1.13
89
+ editdistance==0.8.1
90
+ eerepr==0.0.4
91
+ einops==0.8.0
92
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
93
+ entrypoints==0.4
94
+ et_xmlfile==2.0.0
95
+ etils==1.11.0
96
+ etuples==0.3.9
97
+ eval_type_backport==0.2.2
98
+ exceptiongroup==1.2.2
99
+ fastai==2.7.18
100
+ fastapi==0.115.6
101
+ fastcore==1.7.28
102
+ fastdownload==0.0.7
103
+ fastjsonschema==2.21.1
104
+ fastprogress==1.0.3
105
+ fastrlock==0.8.3
106
+ ffmpy==0.5.0
107
+ filelock==3.16.1
108
+ firebase-admin==6.6.0
109
+ Flask==3.1.0
110
+ flatbuffers==24.12.23
111
+ flax==0.10.2
112
+ folium==0.19.4
113
+ fonttools==4.55.3
114
+ frozendict==2.4.6
115
+ frozenlist==1.5.0
116
+ fsspec==2024.10.0
117
+ future==1.0.0
118
+ gast==0.6.0
119
+ gcsfs==2024.10.0
120
+ GDAL==3.6.4
121
+ gdown==5.2.0
122
+ geemap==0.35.1
123
+ gensim==4.3.3
124
+ geocoder==1.38.1
125
+ geographiclib==2.0
126
+ geopandas==1.0.1
127
+ geopy==2.4.1
128
+ gin-config==0.5.0
129
+ gitdb==4.0.12
130
+ GitPython==3.1.44
131
+ glob2==0.7
132
+ google==2.0.3
133
+ google-ai-generativelanguage==0.6.10
134
+ google-api-core==2.19.2
135
+ google-api-python-client==2.155.0
136
+ google-auth==2.27.0
137
+ google-auth-httplib2==0.2.0
138
+ google-auth-oauthlib==1.2.1
139
+ google-cloud-aiplatform==1.74.0
140
+ google-cloud-bigquery==3.25.0
141
+ google-cloud-bigquery-connection==1.17.0
142
+ google-cloud-bigquery-storage==2.27.0
143
+ google-cloud-bigtable==2.27.0
144
+ google-cloud-core==2.4.1
145
+ google-cloud-datastore==2.20.2
146
+ google-cloud-firestore==2.19.0
147
+ google-cloud-functions==1.19.0
148
+ google-cloud-iam==2.17.0
149
+ google-cloud-language==2.16.0
150
+ google-cloud-pubsub==2.27.2
151
+ google-cloud-resource-manager==1.14.0
152
+ google-cloud-storage==2.19.0
153
+ google-cloud-translate==3.19.0
154
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
155
+ google-crc32c==1.6.0
156
+ google-genai==0.3.0
157
+ google-generativeai==0.8.3
158
+ google-pasta==0.2.0
159
+ google-resumable-media==2.7.2
160
+ googleapis-common-protos==1.66.0
161
+ googledrivedownloader==0.4
162
+ gradio==5.12.0
163
+ gradio_client==1.5.4
164
+ graphviz==0.20.3
165
+ greenlet==3.1.1
166
+ grpc-google-iam-v1==0.14.0
167
+ grpcio==1.69.0
168
+ grpcio-status==1.62.3
169
+ gspread==6.1.4
170
+ gspread-dataframe==4.0.0
171
+ gym==0.25.2
172
+ gym-notices==0.0.8
173
+ h11==0.14.0
174
+ h5netcdf==1.4.1
175
+ h5py==3.12.1
176
+ holidays==0.64
177
+ holoviews==1.20.0
178
+ html5lib==1.1
179
+ httpcore==1.0.7
180
+ httpimport==1.4.0
181
+ httplib2==0.22.0
182
+ httpx==0.28.1
183
+ huggingface-hub==0.27.1
184
+ humanize==4.11.0
185
+ hyperopt==0.2.7
186
+ ibis-framework==9.2.0
187
+ idna==3.10
188
+ imageio==2.36.1
189
+ imageio-ffmpeg==0.5.1
190
+ imagesize==1.4.1
191
+ imbalanced-learn==0.13.0
192
+ imgaug==0.4.0
193
+ immutabledict==4.2.1
194
+ importlib_metadata==8.5.0
195
+ importlib_resources==6.5.2
196
+ imutils==0.5.4
197
+ inflect==7.5.0
198
+ iniconfig==2.0.0
199
+ intel-cmplr-lib-ur==2025.0.4
200
+ intel-openmp==2025.0.4
201
+ ipyevents==2.0.2
202
+ ipyfilechooser==0.6.0
203
+ ipykernel==5.5.6
204
+ ipyleaflet==0.19.2
205
+ ipyparallel==8.8.0
206
+ ipython==7.34.0
207
+ ipython-genutils==0.2.0
208
+ ipython-sql==0.5.0
209
+ ipytree==0.2.2
210
+ ipywidgets==7.7.1
211
+ itsdangerous==2.2.0
212
+ jax==0.4.33
213
+ jax-cuda12-pjrt==0.4.33
214
+ jax-cuda12-plugin==0.4.33
215
+ jaxlib==0.4.33
216
+ jeepney==0.7.1
217
+ jellyfish==1.1.0
218
+ jieba==0.42.1
219
+ Jinja2==3.1.5
220
+ jiter==0.8.2
221
+ joblib==1.4.2
222
+ jsonpatch==1.33
223
+ jsonpickle==4.0.1
224
+ jsonpointer==3.0.0
225
+ jsonschema==4.23.0
226
+ jsonschema-specifications==2024.10.1
227
+ jupyter-client==6.1.12
228
+ jupyter-console==6.1.0
229
+ jupyter-leaflet==0.19.2
230
+ jupyter-server==1.24.0
231
+ jupyter_core==5.7.2
232
+ jupyterlab_pygments==0.3.0
233
+ jupyterlab_widgets==3.0.13
234
+ kaggle==1.6.17
235
+ kagglehub==0.3.6
236
+ keras==3.5.0
237
+ keyring==23.5.0
238
+ kiwisolver==1.4.8
239
+ langchain==0.3.14
240
+ langchain-core==0.3.29
241
+ langchain-text-splitters==0.3.5
242
+ langcodes==3.5.0
243
+ langsmith==0.2.10
244
+ language_data==1.3.0
245
+ launchpadlib==1.10.16
246
+ lazr.restfulclient==0.14.4
247
+ lazr.uri==1.0.6
248
+ lazy_loader==0.4
249
+ libclang==18.1.1
250
+ libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-24.10.1-py3-none-manylinux_2_28_x86_64.whl
251
+ librosa==0.10.2.post1
252
+ lightgbm==4.5.0
253
+ linkify-it-py==2.0.3
254
+ llvmlite==0.43.0
255
+ locket==1.0.0
256
+ logical-unification==0.4.6
257
+ lxml==5.3.0
258
+ marisa-trie==1.2.1
259
+ Markdown==3.7
260
+ markdown-it-py==3.0.0
261
+ MarkupSafe==2.1.5
262
+ matplotlib==3.10.0
263
+ matplotlib-inline==0.1.7
264
+ matplotlib-venn==1.1.1
265
+ mdit-py-plugins==0.4.2
266
+ mdurl==0.1.2
267
+ miniKanren==1.0.3
268
+ missingno==0.5.2
269
+ mistune==3.1.0
270
+ mizani==0.13.1
271
+ mkl==2025.0.1
272
+ ml-dtypes==0.4.1
273
+ mlxtend==0.23.3
274
+ more-itertools==10.5.0
275
+ moviepy==1.0.3
276
+ mpmath==1.3.0
277
+ msgpack==1.1.0
278
+ multidict==6.1.0
279
+ multipledispatch==1.0.0
280
+ multitasking==0.0.11
281
+ murmurhash==1.0.11
282
+ music21==9.3.0
283
+ namex==0.0.8
284
+ narwhals==1.21.1
285
+ natsort==8.4.0
286
+ nbclassic==1.1.0
287
+ nbclient==0.10.2
288
+ nbconvert==7.16.5
289
+ nbformat==5.10.4
290
+ ndindex==1.9.2
291
+ nest-asyncio==1.6.0
292
+ networkx==3.4.2
293
+ nibabel==5.3.2
294
+ nltk==3.9.1
295
+ notebook==6.5.5
296
+ notebook_shim==0.2.4
297
+ numba==0.60.0
298
+ numexpr==2.10.2
299
+ numpy==1.26.4
300
+ nvidia-cublas-cu12==12.6.4.1
301
+ nvidia-cuda-cupti-cu12==12.6.80
302
+ nvidia-cuda-nvcc-cu12==12.6.85
303
+ nvidia-cuda-runtime-cu12==12.6.77
304
+ nvidia-cudnn-cu12==9.6.0.74
305
+ nvidia-cufft-cu12==11.3.0.4
306
+ nvidia-curand-cu12==10.3.7.77
307
+ nvidia-cusolver-cu12==11.7.1.2
308
+ nvidia-cusparse-cu12==12.5.4.2
309
+ nvidia-nccl-cu12==2.24.3
310
+ nvidia-nvjitlink-cu12==12.6.85
311
+ nvtx==0.2.10
312
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-24.10.0-py3-none-any.whl
313
+ oauth2client==4.1.3
314
+ oauthlib==3.2.2
315
+ openai==1.59.4
316
+ opencv-contrib-python==4.10.0.84
317
+ opencv-python==4.10.0.84
318
+ opencv-python-headless==4.10.0.84
319
+ openpyxl==3.1.5
320
+ opentelemetry-api==1.29.0
321
+ opentelemetry-sdk==1.29.0
322
+ opentelemetry-semantic-conventions==0.50b0
323
+ opt_einsum==3.4.0
324
+ optax==0.2.4
325
+ optree==0.13.1
326
+ orbax-checkpoint==0.6.4
327
+ orjson==3.10.13
328
+ osqp==0.6.7.post3
329
+ packaging==24.2
330
+ pandas==2.2.2
331
+ pandas-datareader==0.10.0
332
+ pandas-gbq==0.26.1
333
+ pandas-stubs==2.2.2.240909
334
+ pandocfilters==1.5.1
335
+ panel==1.5.5
336
+ param==2.2.0
337
+ parso==0.8.4
338
+ parsy==2.1
339
+ partd==1.4.2
340
+ pathlib==1.0.1
341
+ patsy==1.0.1
342
+ peewee==3.17.8
343
+ peft==0.14.0
344
+ pexpect==4.9.0
345
+ pickleshare==0.7.5
346
+ pillow==11.1.0
347
+ platformdirs==4.3.6
348
+ plotly==5.24.1
349
+ plotnine==0.14.5
350
+ pluggy==1.5.0
351
+ ply==3.11
352
+ polars==1.9.0
353
+ pooch==1.8.2
354
+ portpicker==1.5.2
355
+ preshed==3.0.9
356
+ prettytable==3.12.0
357
+ proglog==0.1.10
358
+ progressbar2==4.5.0
359
+ prometheus_client==0.21.1
360
+ promise==2.3
361
+ prompt_toolkit==3.0.48
362
+ propcache==0.2.1
363
+ prophet==1.1.6
364
+ proto-plus==1.25.0
365
+ protobuf==4.25.5
366
+ psutil==5.9.5
367
+ psycopg2==2.9.10
368
+ ptyprocess==0.7.0
369
+ py-cpuinfo==9.0.0
370
+ py4j==0.10.9.7
371
+ pyarrow==17.0.0
372
+ pyasn1==0.6.1
373
+ pyasn1_modules==0.4.1
374
+ pycocotools==2.0.8
375
+ pycparser==2.22
376
+ pydantic==2.10.4
377
+ pydantic_core==2.27.2
378
+ pydata-google-auth==1.9.0
379
+ pydot==3.0.4
380
+ pydotplus==2.0.2
381
+ PyDrive==1.3.1
382
+ PyDrive2==1.21.3
383
+ pydub==0.25.1
384
+ pyerfa==2.0.1.5
385
+ pygame==2.6.1
386
+ pygit2==1.16.0
387
+ Pygments==2.18.0
388
+ PyGObject==3.42.1
389
+ PyJWT==2.10.1
390
+ pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
391
+ pylibcugraph-cu12==24.10.0
392
+ pylibraft-cu12==24.10.0
393
+ pymc==5.19.1
394
+ pymystem3==0.2.0
395
+ pynvjitlink-cu12==0.4.0
396
+ pyogrio==0.10.0
397
+ Pyomo==6.8.2
398
+ PyOpenGL==3.1.7
399
+ pyOpenSSL==24.2.1
400
+ pyparsing==3.2.1
401
+ pyperclip==1.9.0
402
+ pyproj==3.7.0
403
+ pyshp==2.3.1
404
+ PySocks==1.7.1
405
+ pyspark==3.5.4
406
+ pytensor==2.26.4
407
+ pytest==8.3.4
408
+ python-apt==0.0.0
409
+ python-box==7.3.0
410
+ python-dateutil==2.8.2
411
+ python-dotenv==1.0.1
412
+ python-louvain==0.16
413
+ python-multipart==0.0.20
414
+ python-slugify==8.0.4
415
+ python-utils==3.9.1
416
+ pytz==2024.2
417
+ pyviz_comms==3.0.3
418
+ PyYAML==6.0.2
419
+ pyzmq==24.0.1
420
+ qdldl==0.1.7.post5
421
+ ratelim==0.1.6
422
+ referencing==0.35.1
423
+ regex==2024.11.6
424
+ requests==2.32.3
425
+ requests-oauthlib==1.3.1
426
+ requests-toolbelt==1.0.0
427
+ requirements-parser==0.9.0
428
+ rich==13.9.4
429
+ rmm-cu12==24.10.0
430
+ rpds-py==0.22.3
431
+ rpy2==3.4.2
432
+ rsa==4.9
433
+ ruff==0.9.1
434
+ safehttpx==0.1.6
435
+ safetensors==0.5.1
436
+ scikit-image==0.25.0
437
+ scikit-learn==1.6.0
438
+ scipy==1.13.1
439
+ scooby==0.10.0
440
+ scs==3.2.7.post2
441
+ seaborn==0.13.2
442
+ SecretStorage==3.3.1
443
+ semantic-version==2.10.0
444
+ Send2Trash==1.8.3
445
+ sentence-transformers==3.3.1
446
+ sentencepiece==0.2.0
447
+ sentry-sdk==2.19.2
448
+ setproctitle==1.3.4
449
+ shap==0.46.0
450
+ shapely==2.0.6
451
+ shellingham==1.5.4
452
+ simple-parsing==0.1.6
453
+ six==1.17.0
454
+ sklearn-compat==0.1.3
455
+ sklearn-pandas==2.2.0
456
+ slicer==0.0.8
457
+ smart-open==7.1.0
458
+ smmap==5.0.2
459
+ sniffio==1.3.1
460
+ snowballstemmer==2.2.0
461
+ soundfile==0.13.0
462
+ soupsieve==2.6
463
+ soxr==0.5.0.post1
464
+ spacy==3.7.5
465
+ spacy-legacy==3.0.12
466
+ spacy-loggers==1.0.5
467
+ Sphinx==8.1.3
468
+ sphinxcontrib-applehelp==2.0.0
469
+ sphinxcontrib-devhelp==2.0.0
470
+ sphinxcontrib-htmlhelp==2.1.0
471
+ sphinxcontrib-jsmath==1.0.1
472
+ sphinxcontrib-qthelp==2.0.0
473
+ sphinxcontrib-serializinghtml==2.0.0
474
+ SQLAlchemy==2.0.36
475
+ sqlglot==25.1.0
476
+ sqlparse==0.5.3
477
+ srsly==2.5.0
478
+ stanio==0.5.1
479
+ starlette==0.41.3
480
+ statsmodels==0.14.4
481
+ stringzilla==3.11.3
482
+ sympy==1.13.1
483
+ tables==3.10.1
484
+ tabulate==0.9.0
485
+ tbb==2022.0.0
486
+ tcmlib==1.2.0
487
+ tenacity==9.0.0
488
+ tensorboard==2.17.1
489
+ tensorboard-data-server==0.7.2
490
+ tensorflow==2.17.1
491
+ tensorflow-datasets==4.9.7
492
+ tensorflow-hub==0.16.1
493
+ tensorflow-io-gcs-filesystem==0.37.1
494
+ tensorflow-metadata==1.13.1
495
+ tensorflow-probability==0.24.0
496
+ tensorstore==0.1.71
497
+ termcolor==2.5.0
498
+ terminado==0.18.1
499
+ text-unidecode==1.3
500
+ textblob==0.17.1
501
+ tf-slim==1.1.0
502
+ tf_keras==2.17.0
503
+ thinc==8.2.5
504
+ threadpoolctl==3.5.0
505
+ tifffile==2024.12.12
506
+ timm==1.0.12
507
+ tinycss2==1.4.0
508
+ tokenizers==0.21.0
509
+ toml==0.10.2
510
+ tomli==2.2.1
511
+ tomlkit==0.13.2
512
+ toolz==0.12.1
513
+ torch @ https://download.pytorch.org/whl/cu121_full/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
514
+ torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
515
+ torchsummary==1.5.1
516
+ torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl
517
+ tornado==6.3.3
518
+ tqdm==4.67.1
519
+ traitlets==5.7.1
520
+ traittypes==0.2.1
521
+ transformers==4.47.1
522
+ tweepy==4.14.0
523
+ typeguard==4.4.1
524
+ typer==0.15.1
525
+ types-pytz==2024.2.0.20241221
526
+ types-setuptools==75.6.0.20241223
527
+ typing_extensions==4.12.2
528
+ tzdata==2024.2
529
+ tzlocal==5.2
530
+ uc-micro-py==1.0.3
531
+ umf==0.9.1
532
+ uritemplate==4.1.1
533
+ urllib3==2.3.0
534
+ uvicorn==0.34.0
535
+ vega-datasets==0.9.0
536
+ wadllib==1.3.6
537
+ wandb==0.19.1
538
+ wasabi==1.1.3
539
+ wcwidth==0.2.13
540
+ weasel==0.4.1
541
+ webcolors==24.11.1
542
+ webencodings==0.5.1
543
+ websocket-client==1.8.0
544
+ websockets==14.1
545
+ Werkzeug==3.1.3
546
+ widgetsnbextension==3.6.10
547
+ wordcloud==1.9.4
548
+ wrapt==1.17.0
549
+ xarray==2025.1.0
550
+ xarray-einstats==0.8.0
551
+ xformers==0.0.29.post1
552
+ xgboost==2.1.3
553
+ xlrd==2.0.1
554
+ xyzservices==2024.9.0
555
+ yarl==1.18.3
556
+ yellowbrick==1.5
557
+ yfinance==0.2.51
558
+ zipp==3.21.0