taki0112 commited on
Commit
4f4656c
·
1 Parent(s): 2b3f761
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pipelines.inverted_ve_pipeline import STYLE_DESCRIPTION_DICT, create_image_grid
3
+ import gradio as gr
4
+ import os, json
5
+
6
+ from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
7
+ from diffusers import AutoencoderKL
8
+ from random import randint
9
+ from utils import init_latent
10
+
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ if device == 'cpu':
13
+ torch_dtype = torch.float32
14
+ else:
15
+ torch_dtype = torch.float16
16
+
17
+ def memory_efficient(model):
18
+ try:
19
+ model.to(device)
20
+ except Exception as e:
21
+ print("Error moving model to device:", e)
22
+
23
+ try:
24
+ model.enable_model_cpu_offload()
25
+ except AttributeError:
26
+ print("enable_model_cpu_offload is not supported.")
27
+ try:
28
+ model.enable_vae_slicing()
29
+ except AttributeError:
30
+ print("enable_vae_slicing is not supported.")
31
+ if device == 'cuda':
32
+ try:
33
+ model.enable_xformers_memory_efficient_attention()
34
+ except AttributeError:
35
+ print("enable_xformers_memory_efficient_attention is not supported.")
36
+
37
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype)
38
+ model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
39
+
40
+ print("vae")
41
+ memory_efficient(vae)
42
+ print("SDXL")
43
+ memory_efficient(model)
44
+
45
+
46
+ # controlnet_scale, canny thres 1, 2 (2 > 1, 2:1, 3:1)
47
+
48
+ def parse_config(config):
49
+ with open(config, 'r') as f:
50
+ config = json.load(f)
51
+ return config
52
+
53
+
54
+ def load_example_style():
55
+ folder_path = 'assets/ref'
56
+ examples = []
57
+ for filename in os.listdir(folder_path):
58
+ if filename.endswith((".png")):
59
+ image_path = os.path.join(folder_path, filename)
60
+ image_name = os.path.basename(image_path)
61
+ style_name = image_name.split('_')[1]
62
+
63
+ config_path = './config/{}.json'.format(style_name)
64
+ config = parse_config(config_path)
65
+ inf_object_name = config["inference_info"]["inf_object_list"][0]
66
+
67
+ image_info = [image_path, style_name, inf_object_name, 1, 50]
68
+ examples.append(image_info)
69
+
70
+ return examples
71
+
72
+ def style_fn(image_path, style_name, content_text, output_number, diffusion_step=50):
73
+ """
74
+
75
+ :param style_name: 어떤 json 파일 부를거냐 ?
76
+ :param content_text: 어떤 콘텐츠로 변화를 원하니 ?
77
+ :param output_number: 몇개 생성할거니 ?
78
+ :return:
79
+ """
80
+ config_path = './config/{}.json'.format(style_name)
81
+ config = parse_config(config_path)
82
+
83
+ inf_object = content_text
84
+ inf_seeds = [randint(0, 10**10) for _ in range(int(output_number))]
85
+ # inf_seeds = [i for i in range(int(output_number))]
86
+
87
+
88
+ activate_layer_indices_list = config['inference_info']['activate_layer_indices_list']
89
+ activate_step_indices_list = config['inference_info']['activate_step_indices_list']
90
+ ref_seed = config['reference_info']['ref_seeds'][0]
91
+
92
+ attn_map_save_steps = config['inference_info']['attn_map_save_steps']
93
+ guidance_scale = config['guidance_scale']
94
+ use_inf_negative_prompt = config['inference_info']['use_negative_prompt']
95
+
96
+ style_name = config["style_name_list"][0]
97
+
98
+ ref_object = config["reference_info"]["ref_object_list"][0]
99
+ ref_with_style_description = config['reference_info']['with_style_description']
100
+ inf_with_style_description = config['inference_info']['with_style_description']
101
+
102
+ use_shared_attention = config['inference_info']['use_shared_attention']
103
+ adain_queries = config['inference_info']['adain_queries']
104
+ adain_keys = config['inference_info']['adain_keys']
105
+ adain_values = config['inference_info']['adain_values']
106
+
107
+ use_advanced_sampling = config['inference_info']['use_advanced_sampling']
108
+
109
+ style_description_pos, style_description_neg = STYLE_DESCRIPTION_DICT[style_name][0], \
110
+ STYLE_DESCRIPTION_DICT[style_name][1]
111
+
112
+ # Inference
113
+ with torch.inference_mode():
114
+ grid = None
115
+ if ref_with_style_description:
116
+ ref_prompt = style_description_pos.replace("{object}", ref_object)
117
+ else:
118
+ ref_prompt = ref_object
119
+
120
+ if inf_with_style_description:
121
+ inf_prompt = style_description_pos.replace("{object}", inf_object)
122
+ else:
123
+ inf_prompt = inf_object
124
+
125
+ for activate_layer_indices in activate_layer_indices_list:
126
+
127
+ for activate_step_indices in activate_step_indices_list:
128
+
129
+ str_activate_layer, str_activate_step = model.activate_layer(
130
+ activate_layer_indices=activate_layer_indices,
131
+ attn_map_save_steps=attn_map_save_steps,
132
+ activate_step_indices=activate_step_indices, use_shared_attention=use_shared_attention,
133
+ adain_queries=adain_queries,
134
+ adain_keys=adain_keys,
135
+ adain_values=adain_values,
136
+ )
137
+ # ref_latent = model.get_init_latent(ref_seed, precomputed_path=None)
138
+ ref_latent = init_latent(model, device_name=device, dtype=torch_dtype, seed=ref_seed)
139
+ latents = [ref_latent]
140
+
141
+ for inf_seed in inf_seeds:
142
+ # latents.append(model.get_init_latent(inf_seed, precomputed_path=None))
143
+ inf_latent = init_latent(model, device_name=device, dtype=torch_dtype, seed=inf_seed)
144
+ latents.append(inf_latent)
145
+
146
+ latents = torch.cat(latents, dim=0)
147
+ latents.to(device)
148
+
149
+ images = model(
150
+ prompt=ref_prompt,
151
+ negative_prompt=style_description_neg,
152
+ guidance_scale=guidance_scale,
153
+ num_inference_steps=diffusion_step,
154
+ latents=latents,
155
+ num_images_per_prompt=len(inf_seeds) + 1,
156
+ target_prompt=inf_prompt,
157
+ use_inf_negative_prompt=use_inf_negative_prompt,
158
+ use_advanced_sampling=use_advanced_sampling
159
+ )[0][1:]
160
+
161
+ n_row = 1
162
+ n_col = len(inf_seeds) # 원본추가하려면 + 1
163
+
164
+ # make grid
165
+ grid = create_image_grid(images, n_row, n_col, padding=10)
166
+
167
+ torch.cuda.empty_cache()
168
+
169
+ return grid
170
+
171
+ description_md = """
172
+
173
+ ### We introduce `Visual Style Prompting`, which reflects the style of a reference image to the images generated by a pretrained text-to-image diffusion model without finetuning or optimization (e.g., Figure N).
174
+ ### 📖 [[Paper](https://arxiv.org/abs/2402.12974)] | ✨ [[Project page](https://curryjung.github.io/VisualStylePrompt)] | ✨ [[Code](https://github.com/naver-ai/Visual-Style-Prompting)]
175
+ ### 🔥 [[w/ Controlnet ver](https://huggingface.co/spaces/naver-ai/VisualStylePrompting_Controlnet)]
176
+ ---
177
+ ### To try out our vanilla demo,
178
+ 1. Choose a `style reference` from the collection of images below.
179
+ 2. Enter the `text prompt`.
180
+ 3. Choose the `number of outputs`.
181
+
182
+ ### To achieve faster results, we recommend lowering the diffusion steps to 30.
183
+ ### Enjoy ! 😄
184
+ """
185
+
186
+ iface_style = gr.Interface(
187
+ fn=style_fn,
188
+ inputs=[
189
+ gr.components.Image(label="Style Image"),
190
+ gr.components.Textbox(label='Style name', visible=False),
191
+ gr.components.Textbox(label="Text prompt", placeholder="Enter Text prompt"),
192
+ gr.components.Textbox(label="Number of outputs", placeholder="Enter Number of outputs"),
193
+ gr.components.Slider(minimum=50, maximum=50, step=10, value=50, label="Diffusion steps")
194
+ ],
195
+ outputs=gr.components.Image(type="pil"),
196
+ title="🎨 Visual Style Prompting (default)",
197
+ description=description_md,
198
+ examples=load_example_style(),
199
+ )
200
+
201
+ iface_style.launch(debug=True)
config/chinese-ink-paint.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "chinese-ink-paint"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 1
15
+ ],
16
+ "ref_object_list": [
17
+ "A horse"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A tiger"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/cloud.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "cloud"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 3
15
+ ],
16
+ "ref_object_list": [
17
+ "a Cloud in the sky"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5
44
+
45
+ ],
46
+ "inf_object_list": [
47
+ "A photo of a dog"
48
+ ],
49
+ "with_style_description": true,
50
+ "negative_prompts": false,
51
+ "external_init_noise_path": false,
52
+ "attn_map_save_steps": [],
53
+ "guidance_scale": 7.0,
54
+ "use_negative_prompt": true,
55
+ "activate_step_indices_list": [
56
+ [
57
+ [
58
+ 0,
59
+ 49
60
+ ]
61
+ ]
62
+ ],
63
+ "use_advanced_sampling": true,
64
+ "use_shared_attention": false,
65
+ "adain_queries": true,
66
+ "adain_keys": true,
67
+ "adain_values": false
68
+ }
69
+ }
config/default.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "anime",
7
+ "Artstyle_Pop_Art",
8
+ "low_poly",
9
+ "line_art"
10
+ ],
11
+ "save_info": {
12
+ "base_exp_dir": "experiments",
13
+ "base_exp_name": "results"
14
+ },
15
+ "reference_info": {
16
+ "ref_seeds": [
17
+ 42
18
+ ],
19
+ "ref_object_list": [
20
+ "cat"
21
+ ],
22
+ "with_style_description": true,
23
+ "external_init_noise_path": false,
24
+ "guidance_scale": 7.0,
25
+ "use_negative_prompt": true
26
+ },
27
+ "inference_info": {
28
+ "activate_layer_indices_list": [
29
+ [
30
+ [
31
+ 0,
32
+ 0
33
+ ],
34
+ [
35
+ 128,
36
+ 140
37
+ ]
38
+ ]
39
+ ],
40
+ "inf_seeds": [
41
+ 0,
42
+ 1,
43
+ 2,
44
+ 3,
45
+ 4,
46
+ 5,
47
+ 6,
48
+ 7,
49
+ 8,
50
+ 9
51
+ ],
52
+ "inf_object_list": [
53
+ "A photo of a dog"
54
+ ],
55
+ "with_style_description": true,
56
+ "negative_prompts": false,
57
+ "external_init_noise_path": false,
58
+ "attn_map_save_steps": [],
59
+ "guidance_scale": 7.0,
60
+ "use_negative_prompt": true,
61
+ "activate_step_indices_list": [
62
+ [
63
+ [
64
+ 0,
65
+ 49
66
+ ]
67
+ ]
68
+ ],
69
+ "use_advanced_sampling": true,
70
+ "use_shared_attention": false,
71
+ "adain_queries": true,
72
+ "adain_keys": true,
73
+ "adain_values": false
74
+ }
75
+ }
config/digital-art.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "digital-art"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 42
15
+ ],
16
+ "ref_object_list": [
17
+ "A robot"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A woman playing basketball"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/fire.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "fire"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 3
15
+ ],
16
+ "ref_object_list": [
17
+ "fire"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A dragon"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/klimt.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "klimt"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 1
15
+ ],
16
+ "ref_object_list": [
17
+ "the kiss"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "Frog"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/line-art.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "line-art"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 42
15
+ ],
16
+ "ref_object_list": [
17
+ "an owl"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A dragon"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/low-poly.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "low-poly"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 42
15
+ ],
16
+ "ref_object_list": [
17
+ "A cat"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A rhino"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/munch.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "munch"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 1
15
+ ],
16
+ "ref_object_list": [
17
+ "The scream"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A dragon"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/totoro.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "totoro"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 42
15
+ ],
16
+ "ref_object_list": [
17
+ "totoro holding a tiny umbrella in the rain"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 108,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "1 cute bird holding a tiny umbrella, forward facing"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": true,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
config/van-gogh.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "cuda",
3
+ "precomputed_path": "./precomputed",
4
+ "guidance_scale": 7.0,
5
+ "style_name_list": [
6
+ "van-gogh"
7
+ ],
8
+ "save_info": {
9
+ "base_exp_dir": "experiments",
10
+ "base_exp_name": "results"
11
+ },
12
+ "reference_info": {
13
+ "ref_seeds": [
14
+ 1
15
+ ],
16
+ "ref_object_list": [
17
+ "The Starry Night"
18
+ ],
19
+ "with_style_description": true,
20
+ "external_init_noise_path": false,
21
+ "guidance_scale": 7.0,
22
+ "use_negative_prompt": true
23
+ },
24
+ "inference_info": {
25
+ "activate_layer_indices_list": [
26
+ [
27
+ [
28
+ 0,
29
+ 0
30
+ ],
31
+ [
32
+ 128,
33
+ 140
34
+ ]
35
+ ]
36
+ ],
37
+ "inf_seeds": [
38
+ 0,
39
+ 1,
40
+ 2,
41
+ 3,
42
+ 4,
43
+ 5,
44
+ 6,
45
+ 7,
46
+ 8,
47
+ 9
48
+ ],
49
+ "inf_object_list": [
50
+ "A dragon"
51
+ ],
52
+ "with_style_description": true,
53
+ "negative_prompts": false,
54
+ "external_init_noise_path": false,
55
+ "attn_map_save_steps": [],
56
+ "guidance_scale": 7.0,
57
+ "use_negative_prompt": true,
58
+ "activate_step_indices_list": [
59
+ [
60
+ [
61
+ 0,
62
+ 49
63
+ ]
64
+ ]
65
+ ],
66
+ "use_advanced_sampling": true,
67
+ "use_shared_attention": false,
68
+ "adain_queries": true,
69
+ "adain_keys": true,
70
+ "adain_values": false
71
+ }
72
+ }
pipelines/__init__.py ADDED
File without changes
pipelines/controlnet.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalControlnetMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2DCrossAttn,
37
+ get_down_block,
38
+ )
39
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @dataclass
46
+ class ControlNetOutput(BaseOutput):
47
+ """
48
+ The output of [`ControlNetModel`].
49
+
50
+ Args:
51
+ down_block_res_samples (`tuple[torch.Tensor]`):
52
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
53
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
54
+ used to condition the original UNet's downsampling activations.
55
+ mid_down_block_re_sample (`torch.Tensor`):
56
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
57
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
58
+ Output can be used to condition the original UNet's middle block activation.
59
+ """
60
+
61
+ down_block_res_samples: Tuple[torch.Tensor]
62
+ mid_block_res_sample: torch.Tensor
63
+
64
+
65
+ class ControlNetConditioningEmbedding(nn.Module):
66
+ """
67
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
68
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
69
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
70
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
71
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
72
+ model) to encode image-space conditions ... into feature maps ..."
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ conditioning_embedding_channels: int,
78
+ conditioning_channels: int = 3,
79
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
80
+ ):
81
+ super().__init__()
82
+
83
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
84
+
85
+ self.blocks = nn.ModuleList([])
86
+
87
+ for i in range(len(block_out_channels) - 1):
88
+ channel_in = block_out_channels[i]
89
+ channel_out = block_out_channels[i + 1]
90
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
91
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
92
+
93
+ self.conv_out = zero_module(
94
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
95
+ )
96
+
97
+ def forward(self, conditioning):
98
+ embedding = self.conv_in(conditioning)
99
+ embedding = F.silu(embedding)
100
+
101
+ for block in self.blocks:
102
+ embedding = block(embedding)
103
+ embedding = F.silu(embedding)
104
+
105
+ embedding = self.conv_out(embedding)
106
+
107
+ return embedding
108
+
109
+
110
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
111
+ """
112
+ A ControlNet model.
113
+
114
+ Args:
115
+ in_channels (`int`, defaults to 4):
116
+ The number of channels in the input sample.
117
+ flip_sin_to_cos (`bool`, defaults to `True`):
118
+ Whether to flip the sin to cos in the time embedding.
119
+ freq_shift (`int`, defaults to 0):
120
+ The frequency shift to apply to the time embedding.
121
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
122
+ The tuple of downsample blocks to use.
123
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
124
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
125
+ The tuple of output channels for each block.
126
+ layers_per_block (`int`, defaults to 2):
127
+ The number of layers per block.
128
+ downsample_padding (`int`, defaults to 1):
129
+ The padding to use for the downsampling convolution.
130
+ mid_block_scale_factor (`float`, defaults to 1):
131
+ The scale factor to use for the mid block.
132
+ act_fn (`str`, defaults to "silu"):
133
+ The activation function to use.
134
+ norm_num_groups (`int`, *optional*, defaults to 32):
135
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
136
+ in post-processing.
137
+ norm_eps (`float`, defaults to 1e-5):
138
+ The epsilon to use for the normalization.
139
+ cross_attention_dim (`int`, defaults to 1280):
140
+ The dimension of the cross attention features.
141
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
142
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
143
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
144
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
145
+ encoder_hid_dim (`int`, *optional*, defaults to None):
146
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
147
+ dimension to `cross_attention_dim`.
148
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
149
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
150
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
151
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
152
+ The dimension of the attention heads.
153
+ use_linear_projection (`bool`, defaults to `False`):
154
+ class_embed_type (`str`, *optional*, defaults to `None`):
155
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
156
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
157
+ addition_embed_type (`str`, *optional*, defaults to `None`):
158
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
159
+ "text". "text" will use the `TextTimeEmbedding` layer.
160
+ num_class_embeds (`int`, *optional*, defaults to 0):
161
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
162
+ class conditioning with `class_embed_type` equal to `None`.
163
+ upcast_attention (`bool`, defaults to `False`):
164
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
165
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
166
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
167
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
168
+ `class_embed_type="projection"`.
169
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
170
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
171
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
172
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
173
+ global_pool_conditions (`bool`, defaults to `False`):
174
+ """
175
+
176
+ _supports_gradient_checkpointing = True
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ in_channels: int = 4,
182
+ conditioning_channels: int = 3,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlock2D",
187
+ "CrossAttnDownBlock2D",
188
+ "CrossAttnDownBlock2D",
189
+ "DownBlock2D",
190
+ ),
191
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
192
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
193
+ layers_per_block: int = 2,
194
+ downsample_padding: int = 1,
195
+ mid_block_scale_factor: float = 1,
196
+ act_fn: str = "silu",
197
+ norm_num_groups: Optional[int] = 32,
198
+ norm_eps: float = 1e-5,
199
+ cross_attention_dim: int = 1280,
200
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
201
+ encoder_hid_dim: Optional[int] = None,
202
+ encoder_hid_dim_type: Optional[str] = None,
203
+ attention_head_dim: Union[int, Tuple[int]] = 8,
204
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
205
+ use_linear_projection: bool = False,
206
+ class_embed_type: Optional[str] = None,
207
+ addition_embed_type: Optional[str] = None,
208
+ addition_time_embed_dim: Optional[int] = None,
209
+ num_class_embeds: Optional[int] = None,
210
+ upcast_attention: bool = False,
211
+ resnet_time_scale_shift: str = "default",
212
+ projection_class_embeddings_input_dim: Optional[int] = None,
213
+ controlnet_conditioning_channel_order: str = "rgb",
214
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
215
+ global_pool_conditions: bool = False,
216
+ addition_embed_type_num_heads=64,
217
+ ):
218
+ super().__init__()
219
+
220
+ # If `num_attention_heads` is not defined (which is the case for most models)
221
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
222
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
223
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
224
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
225
+ # which is why we correct for the naming here.
226
+ num_attention_heads = num_attention_heads or attention_head_dim
227
+
228
+ # Check inputs
229
+ if len(block_out_channels) != len(down_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
232
+ )
233
+
234
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if isinstance(transformer_layers_per_block, int):
245
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
246
+
247
+ # input
248
+ conv_in_kernel = 3
249
+ conv_in_padding = (conv_in_kernel - 1) // 2
250
+ self.conv_in = nn.Conv2d(
251
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
252
+ )
253
+
254
+ # time
255
+ time_embed_dim = block_out_channels[0] * 4
256
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
257
+ timestep_input_dim = block_out_channels[0]
258
+ self.time_embedding = TimestepEmbedding(
259
+ timestep_input_dim,
260
+ time_embed_dim,
261
+ act_fn=act_fn,
262
+ )
263
+
264
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
265
+ encoder_hid_dim_type = "text_proj"
266
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
267
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
268
+
269
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
270
+ raise ValueError(
271
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
272
+ )
273
+
274
+ if encoder_hid_dim_type == "text_proj":
275
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
276
+ elif encoder_hid_dim_type == "text_image_proj":
277
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
278
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
279
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
280
+ self.encoder_hid_proj = TextImageProjection(
281
+ text_embed_dim=encoder_hid_dim,
282
+ image_embed_dim=cross_attention_dim,
283
+ cross_attention_dim=cross_attention_dim,
284
+ )
285
+
286
+ elif encoder_hid_dim_type is not None:
287
+ raise ValueError(
288
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
289
+ )
290
+ else:
291
+ self.encoder_hid_proj = None
292
+
293
+ # class embedding
294
+ if class_embed_type is None and num_class_embeds is not None:
295
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
296
+ elif class_embed_type == "timestep":
297
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
298
+ elif class_embed_type == "identity":
299
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
300
+ elif class_embed_type == "projection":
301
+ if projection_class_embeddings_input_dim is None:
302
+ raise ValueError(
303
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
304
+ )
305
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
306
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
307
+ # 2. it projects from an arbitrary input dimension.
308
+ #
309
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
310
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
311
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
312
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
313
+ else:
314
+ self.class_embedding = None
315
+
316
+ if addition_embed_type == "text":
317
+ if encoder_hid_dim is not None:
318
+ text_time_embedding_from_dim = encoder_hid_dim
319
+ else:
320
+ text_time_embedding_from_dim = cross_attention_dim
321
+
322
+ self.add_embedding = TextTimeEmbedding(
323
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
324
+ )
325
+ elif addition_embed_type == "text_image":
326
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
327
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
328
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
329
+ self.add_embedding = TextImageTimeEmbedding(
330
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
331
+ )
332
+ elif addition_embed_type == "text_time":
333
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
334
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
335
+
336
+ elif addition_embed_type is not None:
337
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
338
+
339
+ # control net conditioning embedding
340
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
341
+ conditioning_embedding_channels=block_out_channels[0],
342
+ block_out_channels=conditioning_embedding_out_channels,
343
+ conditioning_channels=conditioning_channels,
344
+ )
345
+
346
+ self.down_blocks = nn.ModuleList([])
347
+ self.controlnet_down_blocks = nn.ModuleList([])
348
+
349
+ if isinstance(only_cross_attention, bool):
350
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
351
+
352
+ if isinstance(attention_head_dim, int):
353
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
354
+
355
+ if isinstance(num_attention_heads, int):
356
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
357
+
358
+ # down
359
+ output_channel = block_out_channels[0]
360
+
361
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
362
+ controlnet_block = zero_module(controlnet_block)
363
+ self.controlnet_down_blocks.append(controlnet_block)
364
+
365
+ for i, down_block_type in enumerate(down_block_types):
366
+ input_channel = output_channel
367
+ output_channel = block_out_channels[i]
368
+ is_final_block = i == len(block_out_channels) - 1
369
+
370
+ down_block = get_down_block(
371
+ down_block_type,
372
+ num_layers=layers_per_block,
373
+ transformer_layers_per_block=transformer_layers_per_block[i],
374
+ in_channels=input_channel,
375
+ out_channels=output_channel,
376
+ temb_channels=time_embed_dim,
377
+ add_downsample=not is_final_block,
378
+ resnet_eps=norm_eps,
379
+ resnet_act_fn=act_fn,
380
+ resnet_groups=norm_num_groups,
381
+ cross_attention_dim=cross_attention_dim,
382
+ num_attention_heads=num_attention_heads[i],
383
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
384
+ downsample_padding=downsample_padding,
385
+ use_linear_projection=use_linear_projection,
386
+ only_cross_attention=only_cross_attention[i],
387
+ upcast_attention=upcast_attention,
388
+ resnet_time_scale_shift=resnet_time_scale_shift,
389
+ )
390
+ self.down_blocks.append(down_block)
391
+
392
+ for _ in range(layers_per_block):
393
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
394
+ controlnet_block = zero_module(controlnet_block)
395
+ self.controlnet_down_blocks.append(controlnet_block)
396
+
397
+ if not is_final_block:
398
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
399
+ controlnet_block = zero_module(controlnet_block)
400
+ self.controlnet_down_blocks.append(controlnet_block)
401
+
402
+ # mid
403
+ mid_block_channel = block_out_channels[-1]
404
+
405
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
406
+ controlnet_block = zero_module(controlnet_block)
407
+ self.controlnet_mid_block = controlnet_block
408
+
409
+ self.mid_block = UNetMidBlock2DCrossAttn(
410
+ transformer_layers_per_block=transformer_layers_per_block[-1],
411
+ in_channels=mid_block_channel,
412
+ temb_channels=time_embed_dim,
413
+ resnet_eps=norm_eps,
414
+ resnet_act_fn=act_fn,
415
+ output_scale_factor=mid_block_scale_factor,
416
+ resnet_time_scale_shift=resnet_time_scale_shift,
417
+ cross_attention_dim=cross_attention_dim,
418
+ num_attention_heads=num_attention_heads[-1],
419
+ resnet_groups=norm_num_groups,
420
+ use_linear_projection=use_linear_projection,
421
+ upcast_attention=upcast_attention,
422
+ )
423
+
424
+ @classmethod
425
+ def from_unet(
426
+ cls,
427
+ unet: UNet2DConditionModel,
428
+ controlnet_conditioning_channel_order: str = "rgb",
429
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
430
+ load_weights_from_unet: bool = True,
431
+ ):
432
+ r"""
433
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
434
+
435
+ Parameters:
436
+ unet (`UNet2DConditionModel`):
437
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
438
+ where applicable.
439
+ """
440
+ transformer_layers_per_block = (
441
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
442
+ )
443
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
444
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
445
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
446
+ addition_time_embed_dim = (
447
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
448
+ )
449
+
450
+ controlnet = cls(
451
+ encoder_hid_dim=encoder_hid_dim,
452
+ encoder_hid_dim_type=encoder_hid_dim_type,
453
+ addition_embed_type=addition_embed_type,
454
+ addition_time_embed_dim=addition_time_embed_dim,
455
+ transformer_layers_per_block=transformer_layers_per_block,
456
+ in_channels=unet.config.in_channels,
457
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
458
+ freq_shift=unet.config.freq_shift,
459
+ down_block_types=unet.config.down_block_types,
460
+ only_cross_attention=unet.config.only_cross_attention,
461
+ block_out_channels=unet.config.block_out_channels,
462
+ layers_per_block=unet.config.layers_per_block,
463
+ downsample_padding=unet.config.downsample_padding,
464
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
465
+ act_fn=unet.config.act_fn,
466
+ norm_num_groups=unet.config.norm_num_groups,
467
+ norm_eps=unet.config.norm_eps,
468
+ cross_attention_dim=unet.config.cross_attention_dim,
469
+ attention_head_dim=unet.config.attention_head_dim,
470
+ num_attention_heads=unet.config.num_attention_heads,
471
+ use_linear_projection=unet.config.use_linear_projection,
472
+ class_embed_type=unet.config.class_embed_type,
473
+ num_class_embeds=unet.config.num_class_embeds,
474
+ upcast_attention=unet.config.upcast_attention,
475
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
476
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
477
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
478
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
479
+ )
480
+
481
+ if load_weights_from_unet:
482
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
483
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
484
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
485
+
486
+ if controlnet.class_embedding:
487
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
488
+
489
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
490
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
491
+
492
+ return controlnet
493
+
494
+ @property
495
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
496
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
497
+ r"""
498
+ Returns:
499
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
500
+ indexed by its weight name.
501
+ """
502
+ # set recursively
503
+ processors = {}
504
+
505
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
506
+ if hasattr(module, "get_processor"):
507
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
508
+
509
+ for sub_name, child in module.named_children():
510
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
511
+
512
+ return processors
513
+
514
+ for name, module in self.named_children():
515
+ fn_recursive_add_processors(name, module, processors)
516
+
517
+ return processors
518
+
519
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
520
+ def set_attn_processor(
521
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
522
+ ):
523
+ r"""
524
+ Sets the attention processor to use to compute attention.
525
+
526
+ Parameters:
527
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
528
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
529
+ for **all** `Attention` layers.
530
+
531
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
532
+ processor. This is strongly recommended when setting trainable attention processors.
533
+
534
+ """
535
+ count = len(self.attn_processors.keys())
536
+
537
+ if isinstance(processor, dict) and len(processor) != count:
538
+ raise ValueError(
539
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
540
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
541
+ )
542
+
543
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
544
+ if hasattr(module, "set_processor"):
545
+ if not isinstance(processor, dict):
546
+ module.set_processor(processor, _remove_lora=_remove_lora)
547
+ else:
548
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
549
+
550
+ for sub_name, child in module.named_children():
551
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
552
+
553
+ for name, module in self.named_children():
554
+ fn_recursive_attn_processor(name, module, processor)
555
+
556
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
557
+ def set_default_attn_processor(self):
558
+ """
559
+ Disables custom attention processors and sets the default attention implementation.
560
+ """
561
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
562
+ processor = AttnAddedKVProcessor()
563
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
564
+ processor = AttnProcessor()
565
+ else:
566
+ raise ValueError(
567
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
568
+ )
569
+
570
+ self.set_attn_processor(processor, _remove_lora=True)
571
+
572
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
573
+ def set_attention_slice(self, slice_size):
574
+ r"""
575
+ Enable sliced attention computation.
576
+
577
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
578
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
579
+
580
+ Args:
581
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
582
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
583
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
584
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
585
+ must be a multiple of `slice_size`.
586
+ """
587
+ sliceable_head_dims = []
588
+
589
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
590
+ if hasattr(module, "set_attention_slice"):
591
+ sliceable_head_dims.append(module.sliceable_head_dim)
592
+
593
+ for child in module.children():
594
+ fn_recursive_retrieve_sliceable_dims(child)
595
+
596
+ # retrieve number of attention layers
597
+ for module in self.children():
598
+ fn_recursive_retrieve_sliceable_dims(module)
599
+
600
+ num_sliceable_layers = len(sliceable_head_dims)
601
+
602
+ if slice_size == "auto":
603
+ # half the attention head size is usually a good trade-off between
604
+ # speed and memory
605
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
606
+ elif slice_size == "max":
607
+ # make smallest slice possible
608
+ slice_size = num_sliceable_layers * [1]
609
+
610
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
611
+
612
+ if len(slice_size) != len(sliceable_head_dims):
613
+ raise ValueError(
614
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
615
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
616
+ )
617
+
618
+ for i in range(len(slice_size)):
619
+ size = slice_size[i]
620
+ dim = sliceable_head_dims[i]
621
+ if size is not None and size > dim:
622
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
623
+
624
+ # Recursively walk through all the children.
625
+ # Any children which exposes the set_attention_slice method
626
+ # gets the message
627
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
628
+ if hasattr(module, "set_attention_slice"):
629
+ module.set_attention_slice(slice_size.pop())
630
+
631
+ for child in module.children():
632
+ fn_recursive_set_attention_slice(child, slice_size)
633
+
634
+ reversed_slice_size = list(reversed(slice_size))
635
+ for module in self.children():
636
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
637
+
638
+ def _set_gradient_checkpointing(self, module, value=False):
639
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
640
+ module.gradient_checkpointing = value
641
+
642
+ def forward(
643
+ self,
644
+ sample: torch.FloatTensor,
645
+ timestep: Union[torch.Tensor, float, int],
646
+ encoder_hidden_states: torch.Tensor,
647
+ controlnet_cond: torch.FloatTensor,
648
+ conditioning_scale: float = 1.0,
649
+ class_labels: Optional[torch.Tensor] = None,
650
+ timestep_cond: Optional[torch.Tensor] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
653
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
654
+ guess_mode: bool = False,
655
+ return_dict: bool = True,
656
+ ) -> Union[ControlNetOutput, Tuple]:
657
+ """
658
+ The [`ControlNetModel`] forward method.
659
+
660
+ Args:
661
+ sample (`torch.FloatTensor`):
662
+ The noisy input tensor.
663
+ timestep (`Union[torch.Tensor, float, int]`):
664
+ The number of timesteps to denoise an input.
665
+ encoder_hidden_states (`torch.Tensor`):
666
+ The encoder hidden states.
667
+ controlnet_cond (`torch.FloatTensor`):
668
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
669
+ conditioning_scale (`float`, defaults to `1.0`):
670
+ The scale factor for ControlNet outputs.
671
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
672
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
673
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
674
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
675
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
676
+ embeddings.
677
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
678
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
679
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
680
+ negative values to the attention scores corresponding to "discard" tokens.
681
+ added_cond_kwargs (`dict`):
682
+ Additional conditions for the Stable Diffusion XL UNet.
683
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
684
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
685
+ guess_mode (`bool`, defaults to `False`):
686
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
687
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
688
+ return_dict (`bool`, defaults to `True`):
689
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
690
+
691
+ Returns:
692
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
693
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
694
+ returned where the first element is the sample tensor.
695
+ """
696
+ # check channel order
697
+ channel_order = self.config.controlnet_conditioning_channel_order
698
+
699
+ if channel_order == "rgb":
700
+ # in rgb order by default
701
+ ...
702
+ elif channel_order == "bgr":
703
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
704
+ else:
705
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
706
+
707
+ # prepare attention_mask
708
+ if attention_mask is not None:
709
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
710
+ attention_mask = attention_mask.unsqueeze(1)
711
+
712
+ # 1. time
713
+ timesteps = timestep
714
+ if not torch.is_tensor(timesteps):
715
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
716
+ # This would be a good case for the `match` statement (Python 3.10+)
717
+ is_mps = sample.device.type == "mps"
718
+ if isinstance(timestep, float):
719
+ dtype = torch.float32 if is_mps else torch.float64
720
+ else:
721
+ dtype = torch.int32 if is_mps else torch.int64
722
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
723
+ elif len(timesteps.shape) == 0:
724
+ timesteps = timesteps[None].to(sample.device)
725
+
726
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
727
+ timesteps = timesteps.expand(sample.shape[0])
728
+
729
+ t_emb = self.time_proj(timesteps)
730
+
731
+ # timesteps does not contain any weights and will always return f32 tensors
732
+ # but time_embedding might actually be running in fp16. so we need to cast here.
733
+ # there might be better ways to encapsulate this.
734
+ t_emb = t_emb.to(dtype=sample.dtype)
735
+
736
+ emb = self.time_embedding(t_emb, timestep_cond)
737
+ aug_emb = None
738
+
739
+ if self.class_embedding is not None:
740
+ if class_labels is None:
741
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
742
+
743
+ if self.config.class_embed_type == "timestep":
744
+ class_labels = self.time_proj(class_labels)
745
+
746
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
747
+ emb = emb + class_emb
748
+
749
+ if self.config.addition_embed_type is not None:
750
+ if self.config.addition_embed_type == "text":
751
+ aug_emb = self.add_embedding(encoder_hidden_states)
752
+
753
+ elif self.config.addition_embed_type == "text_time":
754
+ if "text_embeds" not in added_cond_kwargs:
755
+ raise ValueError(
756
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
757
+ )
758
+ text_embeds = added_cond_kwargs.get("text_embeds")
759
+ if "time_ids" not in added_cond_kwargs:
760
+ raise ValueError(
761
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
762
+ )
763
+ time_ids = added_cond_kwargs.get("time_ids")
764
+ time_embeds = self.add_time_proj(time_ids.flatten())
765
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
766
+
767
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
768
+ add_embeds = add_embeds.to(emb.dtype)
769
+ aug_emb = self.add_embedding(add_embeds)
770
+
771
+ emb = emb + aug_emb if aug_emb is not None else emb
772
+
773
+ # 2. pre-process
774
+ sample = self.conv_in(sample)
775
+
776
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
777
+ sample = sample + controlnet_cond
778
+
779
+ # 3. down
780
+ down_block_res_samples = (sample,)
781
+ for downsample_block in self.down_blocks:
782
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
783
+ sample, res_samples = downsample_block(
784
+ hidden_states=sample,
785
+ temb=emb,
786
+ encoder_hidden_states=encoder_hidden_states,
787
+ attention_mask=attention_mask,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ )
790
+ else:
791
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
792
+
793
+ down_block_res_samples += res_samples
794
+
795
+ # 4. mid
796
+ if self.mid_block is not None:
797
+ sample = self.mid_block(
798
+ sample,
799
+ emb,
800
+ encoder_hidden_states=encoder_hidden_states,
801
+ attention_mask=attention_mask,
802
+ cross_attention_kwargs=cross_attention_kwargs,
803
+ )
804
+
805
+ # 5. Control net blocks
806
+
807
+ controlnet_down_block_res_samples = ()
808
+
809
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
810
+ down_block_res_sample = controlnet_block(down_block_res_sample)
811
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
812
+
813
+ down_block_res_samples = controlnet_down_block_res_samples
814
+
815
+ mid_block_res_sample = self.controlnet_mid_block(sample)
816
+
817
+ # 6. scaling
818
+ if guess_mode and not self.config.global_pool_conditions:
819
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
820
+ scales = scales * conditioning_scale
821
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
822
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
823
+ else:
824
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
825
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
826
+
827
+ if self.config.global_pool_conditions:
828
+ down_block_res_samples = [
829
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
830
+ ]
831
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
832
+
833
+ if not return_dict:
834
+ return (down_block_res_samples, mid_block_res_sample)
835
+
836
+ return ControlNetOutput(
837
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
838
+ )
839
+
840
+
841
+ def zero_module(module):
842
+ for p in module.parameters():
843
+ nn.init.zeros_(p)
844
+ return module
pipelines/inverted_ve_pipeline.py ADDED
@@ -0,0 +1,1615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from diffusers import StableDiffusionPipeline
3
+ import torch
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union, Any, Dict
6
+ import numpy as np
7
+ from diffusers.utils import deprecate, logging, BaseOutput
8
+ from einops import rearrange, repeat
9
+ from torch.nn.functional import grid_sample
10
+ from torch.nn import functional as nnf
11
+ import torchvision.transforms as T
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel, attention_processor
14
+ from diffusers.schedulers import KarrasDiffusionSchedulers
15
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
16
+ import PIL
17
+ from PIL import Image
18
+ from kornia.morphology import dilation
19
+ from collections import OrderedDict
20
+ from packaging import version
21
+ import inspect
22
+ from diffusers.utils import (
23
+ deprecate,
24
+ is_accelerate_available,
25
+ is_accelerate_version,
26
+ logging,
27
+ replace_example_docstring,
28
+ )
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ import torch.nn as nn
31
+
32
+ T = torch.Tensor
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class StyleAlignedArgs:
37
+ share_group_norm: bool = True
38
+ share_layer_norm: bool = True,
39
+ share_attention: bool = True
40
+ adain_queries: bool = True
41
+ adain_keys: bool = True
42
+ adain_values: bool = False
43
+ full_attention_share: bool = False
44
+ keys_scale: float = 1.
45
+ only_self_level: float = 0.
46
+
47
+ def expand_first(feat: T, scale=1., ) -> T:
48
+ b = feat.shape[0]
49
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
50
+ if scale == 1:
51
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
52
+ else:
53
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
54
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
55
+ return feat_style.reshape(*feat.shape)
56
+
57
+
58
+ def concat_first(feat: T, dim=2, scale=1.) -> T:
59
+ feat_style = expand_first(feat, scale=scale)
60
+ return torch.cat((feat, feat_style), dim=dim)
61
+
62
+
63
+ def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
64
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
65
+ feat_mean = feat.mean(dim=-2, keepdims=True)
66
+ return feat_mean, feat_std
67
+
68
+
69
+ def adain(feat: T) -> T:
70
+ feat_mean, feat_std = calc_mean_std(feat)
71
+ feat_style_mean = expand_first(feat_mean)
72
+ feat_style_std = expand_first(feat_std)
73
+ feat = (feat - feat_mean) / feat_std
74
+ feat = feat * feat_style_std + feat_style_mean
75
+ return feat
76
+
77
+
78
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
79
+
80
+
81
+ EXAMPLE_DOC_STRING = """
82
+ Examples:
83
+ ```py
84
+ >>> import torch
85
+ >>> from diffusers import StableDiffusionPipeline
86
+
87
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
88
+ >>> pipe = pipe.to("cuda")
89
+
90
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
91
+ >>> image = pipe(prompt).images[0]
92
+ ```
93
+ """
94
+
95
+ # ACTIVATE_STEP_CANDIDATE = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 1]
96
+
97
+
98
+ def create_image_grid(image_list, rows, cols, padding=10):
99
+ # Ensure the number of rows and columns doesn't exceed the number of images
100
+ rows = min(rows, len(image_list))
101
+ cols = min(cols, len(image_list))
102
+
103
+ # Get the dimensions of a single image
104
+ image_width, image_height = image_list[0].size
105
+
106
+ # Calculate the size of the output image
107
+ grid_width = cols * (image_width + padding) - padding
108
+ grid_height = rows * (image_height + padding) - padding
109
+
110
+ # Create an empty grid image
111
+ grid_image = Image.new('RGB', (grid_width, grid_height), (255, 255, 255))
112
+
113
+ # Paste images into the grid
114
+ for i, img in enumerate(image_list[:rows * cols]):
115
+ row = i // cols
116
+ col = i % cols
117
+ x = col * (image_width + padding)
118
+ y = row * (image_height + padding)
119
+ grid_image.paste(img, (x, y))
120
+
121
+ return grid_image
122
+
123
+
124
+
125
+
126
+ class CrossFrameAttnProcessor_backup:
127
+ def __init__(self, unet_chunk_size=2):
128
+ self.unet_chunk_size = unet_chunk_size
129
+
130
+ def __call__(
131
+ self,
132
+ attn,
133
+ hidden_states,
134
+ encoder_hidden_states=None,
135
+ attention_mask=None):
136
+
137
+
138
+ batch_size, sequence_length, _ = hidden_states.shape
139
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
140
+ query = attn.to_q(hidden_states)
141
+
142
+ is_cross_attention = encoder_hidden_states is not None
143
+ if encoder_hidden_states is None:
144
+ encoder_hidden_states = hidden_states
145
+ # elif attn.cross_attention_norm:
146
+ # encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+ # Sparse Attention
150
+ if not is_cross_attention:
151
+ video_length = key.size()[0] // self.unet_chunk_size
152
+ # former_frame_index = torch.arange(video_length) - 1
153
+ # former_frame_index[0] = 0
154
+ # import pdb; pdb.set_trace()
155
+
156
+ # if video_length > 3:
157
+ # import pdb; pdb.set_trace()
158
+ former_frame_index = [0] * video_length
159
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
160
+ key = key[:, former_frame_index]
161
+ key = rearrange(key, "b f d c -> (b f) d c")
162
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
163
+ value = value[:, former_frame_index]
164
+ value = rearrange(value, "b f d c -> (b f) d c")
165
+
166
+
167
+ query = attn.head_to_batch_dim(query)
168
+ key = attn.head_to_batch_dim(key)
169
+ value = attn.head_to_batch_dim(value)
170
+
171
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
172
+ hidden_states = torch.bmm(attention_probs, value)
173
+ hidden_states = attn.batch_to_head_dim(hidden_states)
174
+
175
+ # linear proj
176
+ hidden_states = attn.to_out[0](hidden_states)
177
+ # dropout
178
+ hidden_states = attn.to_out[1](hidden_states)
179
+
180
+ return hidden_states
181
+
182
+
183
+ class SharedAttentionProcessor:
184
+ def __init__(self,
185
+ adain_keys=True,
186
+ adain_queries=True,
187
+ adain_values=False,
188
+ keys_scale=1.,
189
+ attn_map_save_steps=[]):
190
+ super().__init__()
191
+ self.adain_queries = adain_queries
192
+ self.adain_keys = adain_keys
193
+ self.adain_values = adain_values
194
+ # self.full_attention_share = style_aligned_args.full_attention_share
195
+ self.keys_scale = keys_scale
196
+ self.attn_map_save_steps = attn_map_save_steps
197
+
198
+
199
+ def __call__(
200
+ self,
201
+ attn: attention_processor.Attention,
202
+ hidden_states,
203
+ encoder_hidden_states=None,
204
+ attention_mask=None,
205
+ **kwargs
206
+ ):
207
+
208
+ if not hasattr(attn, "attn_map"):
209
+ setattr(attn, "attn_map", {})
210
+ setattr(attn, "inference_step", 0)
211
+ else:
212
+ attn.inference_step += 1
213
+
214
+ residual = hidden_states
215
+ input_ndim = hidden_states.ndim
216
+ if input_ndim == 4:
217
+ batch_size, channel, height, width = hidden_states.shape
218
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
219
+ batch_size, sequence_length, _ = (
220
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
221
+ )
222
+
223
+ is_cross_attention = encoder_hidden_states is not None
224
+
225
+ if attention_mask is not None:
226
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
227
+ # scaled_dot_product_attention expects attention_mask shape to be
228
+ # (batch, heads, source_length, target_length)
229
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
230
+
231
+ if attn.group_norm is not None:
232
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
233
+
234
+ query = attn.to_q(hidden_states)
235
+
236
+ if encoder_hidden_states is None:
237
+ encoder_hidden_states = hidden_states
238
+ # elif attn.cross_attention_norm:
239
+ # encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
240
+ key = attn.to_k(encoder_hidden_states)
241
+ value = attn.to_v(encoder_hidden_states)
242
+
243
+ inner_dim = key.shape[-1]
244
+ head_dim = inner_dim // attn.heads
245
+
246
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
247
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+ # if self.step >= self.start_inject:
250
+
251
+
252
+ if not is_cross_attention:# and self.share_attention:
253
+ if self.adain_queries:
254
+ query = adain(query)
255
+ if self.adain_keys:
256
+ key = adain(key)
257
+ if self.adain_values:
258
+ value = adain(value)
259
+ key = concat_first(key, -2, scale=self.keys_scale)
260
+ value = concat_first(value, -2)
261
+ hidden_states = nnf.scaled_dot_product_attention(
262
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
263
+ )
264
+ else:
265
+ hidden_states = nnf.scaled_dot_product_attention(
266
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
267
+ )
268
+
269
+
270
+
271
+
272
+ # hidden_states = adain(hidden_states)
273
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
274
+ hidden_states = hidden_states.to(query.dtype)
275
+
276
+ # linear proj
277
+ hidden_states = attn.to_out[0](hidden_states)
278
+ # dropout
279
+ hidden_states = attn.to_out[1](hidden_states)
280
+
281
+ if input_ndim == 4:
282
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
283
+
284
+ if attn.residual_connection:
285
+ hidden_states = hidden_states + residual
286
+
287
+ hidden_states = hidden_states / attn.rescale_output_factor
288
+ return hidden_states
289
+
290
+
291
+ class SharedAttentionProcessor_v2:
292
+ def __init__(self,
293
+ adain_keys=True,
294
+ adain_queries=True,
295
+ adain_values=False,
296
+ keys_scale=1.,
297
+ attn_map_save_steps=[]):
298
+ super().__init__()
299
+ self.adain_queries = adain_queries
300
+ self.adain_keys = adain_keys
301
+ self.adain_values = adain_values
302
+ # self.full_attention_share = style_aligned_args.full_attention_share
303
+ self.keys_scale = keys_scale
304
+ self.attn_map_save_steps = attn_map_save_steps
305
+
306
+
307
+ def __call__(
308
+ self,
309
+ attn: attention_processor.Attention,
310
+ hidden_states,
311
+ encoder_hidden_states=None,
312
+ attention_mask=None,
313
+ **kwargs
314
+ ):
315
+
316
+ if not hasattr(attn, "attn_map"):
317
+ setattr(attn, "attn_map", {})
318
+ setattr(attn, "inference_step", 0)
319
+ else:
320
+ attn.inference_step += 1
321
+
322
+ residual = hidden_states
323
+ input_ndim = hidden_states.ndim
324
+ if input_ndim == 4:
325
+ batch_size, channel, height, width = hidden_states.shape
326
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
327
+ batch_size, sequence_length, _ = (
328
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
329
+ )
330
+
331
+ is_cross_attention = encoder_hidden_states is not None
332
+
333
+ if attention_mask is not None:
334
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
335
+ # scaled_dot_product_attention expects attention_mask shape to be
336
+ # (batch, heads, source_length, target_length)
337
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
338
+
339
+ if attn.group_norm is not None:
340
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
341
+
342
+ query = attn.to_q(hidden_states)
343
+
344
+
345
+ if encoder_hidden_states is None:
346
+ encoder_hidden_states = hidden_states
347
+ # elif attn.cross_attention_norm:
348
+ # encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
349
+ key = attn.to_k(encoder_hidden_states)
350
+ value = attn.to_v(encoder_hidden_states)
351
+
352
+ tmp_query_shape = query.shape
353
+ tmp_key_shape = key.shape
354
+ tmp_value_shape = value.shape
355
+
356
+
357
+ inner_dim = key.shape[-1]
358
+ head_dim = inner_dim // attn.heads
359
+
360
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
361
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
+ # if self.step >= self.start_inject:
364
+
365
+
366
+ if not is_cross_attention:# and self.share_attention:
367
+ if self.adain_queries:
368
+ query = adain(query)
369
+ if self.adain_keys:
370
+ key = adain(key)
371
+ if self.adain_values:
372
+ value = adain(value)
373
+ key = concat_first(key, -2, scale=self.keys_scale)
374
+ value = concat_first(value, -2)
375
+ # hidden_states = nnf.scaled_dot_product_attention(
376
+ # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
377
+ # )
378
+
379
+ if attn.inference_step in self.attn_map_save_steps:
380
+
381
+ query = query.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
382
+ key = key.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
383
+ value = value.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
384
+
385
+ query = attn.head_to_batch_dim(query)
386
+ key = attn.head_to_batch_dim(key)
387
+ value = attn.head_to_batch_dim(value)
388
+
389
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
390
+
391
+ if attn.inference_step in self.attn_map_save_steps:
392
+ attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
393
+
394
+ hidden_states = torch.bmm(attention_probs, value)
395
+ hidden_states = attn.batch_to_head_dim(hidden_states)
396
+ else:
397
+ hidden_states = nnf.scaled_dot_product_attention(
398
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
399
+ )
400
+ # hidden_states = adain(hidden_states)
401
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
402
+ hidden_states = hidden_states.to(query.dtype)
403
+
404
+ else:
405
+
406
+ hidden_states = nnf.scaled_dot_product_attention(
407
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
408
+ )
409
+ # hidden_states = adain(hidden_states)
410
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
411
+ hidden_states = hidden_states.to(query.dtype)
412
+
413
+ # linear proj
414
+ hidden_states = attn.to_out[0](hidden_states)
415
+ # dropout
416
+ hidden_states = attn.to_out[1](hidden_states)
417
+
418
+ if input_ndim == 4:
419
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
420
+
421
+ if attn.residual_connection:
422
+ hidden_states = hidden_states + residual
423
+
424
+ hidden_states = hidden_states / attn.rescale_output_factor
425
+
426
+ if attn.inference_step == 49:
427
+ #initialize inference step
428
+ attn.inference_step = -1
429
+
430
+ return hidden_states
431
+
432
+
433
+ def swapping_attention(key, value, chunk_size=2):
434
+ chunk_length = key.size()[0] // chunk_size # [text-condition, null-condition]
435
+ reference_image_index = [0] * chunk_length # [0 0 0 0 0]
436
+ key = rearrange(key, "(b f) d c -> b f d c", f=chunk_length)
437
+ key = key[:, reference_image_index] # ref to all
438
+ key = rearrange(key, "b f d c -> (b f) d c")
439
+ value = rearrange(value, "(b f) d c -> b f d c", f=chunk_length)
440
+ value = value[:, reference_image_index] # ref to all
441
+ value = rearrange(value, "b f d c -> (b f) d c")
442
+
443
+ return key, value
444
+
445
+ class CrossFrameAttnProcessor:
446
+ def __init__(self, unet_chunk_size=2, attn_map_save_steps=[],activate_step_indices=None):
447
+ self.unet_chunk_size = unet_chunk_size
448
+ self.attn_map_save_steps = attn_map_save_steps
449
+ self.activate_step_indices = activate_step_indices
450
+
451
+ def __call__(
452
+ self,
453
+ attn,
454
+ hidden_states,
455
+ encoder_hidden_states=None,
456
+ attention_mask=None):
457
+
458
+ if not hasattr(attn, "attn_map"):
459
+ setattr(attn, "attn_map", {})
460
+ setattr(attn, "inference_step", 0)
461
+ else:
462
+ attn.inference_step += 1
463
+
464
+
465
+
466
+ batch_size, sequence_length, _ = hidden_states.shape
467
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
468
+ query = attn.to_q(hidden_states)
469
+
470
+ is_cross_attention = encoder_hidden_states is not None
471
+ if encoder_hidden_states is None:
472
+ encoder_hidden_states = hidden_states
473
+
474
+ key = attn.to_k(encoder_hidden_states)
475
+ value = attn.to_v(encoder_hidden_states)
476
+
477
+ is_in_inference_step = False
478
+
479
+ if self.activate_step_indices is not None:
480
+ for activate_step_index in self.activate_step_indices:
481
+ if attn.inference_step >= activate_step_index[0] and attn.inference_step <= activate_step_index[1]:
482
+ is_in_inference_step = True
483
+ break
484
+
485
+ # Swapping Attention
486
+ if not is_cross_attention and is_in_inference_step:
487
+ key, value = swapping_attention(key, value, self.unet_chunk_size)
488
+
489
+
490
+
491
+
492
+ query = attn.head_to_batch_dim(query)
493
+ key = attn.head_to_batch_dim(key)
494
+ value = attn.head_to_batch_dim(value)
495
+
496
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
497
+
498
+ if attn.inference_step in self.attn_map_save_steps:
499
+ attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
500
+
501
+ hidden_states = torch.bmm(attention_probs, value)
502
+ hidden_states = attn.batch_to_head_dim(hidden_states)
503
+
504
+ # linear proj
505
+ hidden_states = attn.to_out[0](hidden_states)
506
+ # dropout
507
+ hidden_states = attn.to_out[1](hidden_states)
508
+
509
+ if attn.inference_step == 49:
510
+ attn.inference_step = -1
511
+
512
+ return hidden_states
513
+
514
+
515
+
516
+
517
+ class CrossFrameAttnProcessor4Inversion:
518
+ def __init__(self, unet_chunk_size=2, attn_map_save_steps=[],activate_step_indices=None):
519
+ self.unet_chunk_size = unet_chunk_size
520
+ self.attn_map_save_steps = attn_map_save_steps
521
+ self.activate_step_indices = activate_step_indices
522
+
523
+ def __call__(
524
+ self,
525
+ attn,
526
+ hidden_states,
527
+ encoder_hidden_states=None,
528
+ attention_mask=None):
529
+
530
+ if not hasattr(attn, "attn_map"):
531
+ setattr(attn, "attn_map", {})
532
+ setattr(attn, "inference_step", 0)
533
+ else:
534
+ attn.inference_step += 1
535
+
536
+
537
+
538
+ batch_size, sequence_length, _ = hidden_states.shape
539
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
540
+ query = attn.to_q(hidden_states)
541
+
542
+ is_cross_attention = encoder_hidden_states is not None
543
+ if encoder_hidden_states is None:
544
+ encoder_hidden_states = hidden_states
545
+ # elif attn.cross_attention_norm:
546
+ # encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
547
+ key = attn.to_k(encoder_hidden_states)
548
+ value = attn.to_v(encoder_hidden_states)
549
+
550
+ is_in_inference_step = False
551
+
552
+ if self.activate_step_indices is not None:
553
+ for activate_step_index in self.activate_step_indices:
554
+ if attn.inference_step >= activate_step_index[0] and attn.inference_step <= activate_step_index[1]:
555
+ is_in_inference_step = True
556
+ break
557
+
558
+ # Swapping Attention
559
+ if not is_cross_attention and is_in_inference_step:
560
+ key, value = swapping_attention(key, value, self.unet_chunk_size)
561
+
562
+
563
+
564
+ query = attn.head_to_batch_dim(query)
565
+ key = attn.head_to_batch_dim(key)
566
+ value = attn.head_to_batch_dim(value)
567
+
568
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
569
+
570
+ # if attn.inference_step > 45 and attn.inference_step < 50:
571
+ # if attn.inference_step == 42 or attn.inference_step==49:
572
+ if attn.inference_step in self.attn_map_save_steps:
573
+ attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
574
+
575
+ hidden_states = torch.bmm(attention_probs, value)
576
+ hidden_states = attn.batch_to_head_dim(hidden_states)
577
+
578
+ # linear proj
579
+ hidden_states = attn.to_out[0](hidden_states)
580
+ # dropout
581
+ hidden_states = attn.to_out[1](hidden_states)
582
+
583
+ if attn.inference_step == 49:
584
+ #initialize inference step
585
+ attn.inference_step = -1
586
+
587
+ return hidden_states
588
+
589
+
590
+
591
+ class CrossFrameAttnProcessor_store:
592
+ def __init__(self, unet_chunk_size=2, attn_map_save_steps=[]):
593
+ self.unet_chunk_size = unet_chunk_size
594
+ self.attn_map_save_steps = attn_map_save_steps
595
+
596
+ def __call__(
597
+ self,
598
+ attn,
599
+ hidden_states,
600
+ encoder_hidden_states=None,
601
+ attention_mask=None):
602
+
603
+ batch_size, sequence_length, _ = hidden_states.shape
604
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
605
+ query = attn.to_q(hidden_states)
606
+
607
+ is_cross_attention = encoder_hidden_states is not None
608
+ if encoder_hidden_states is None:
609
+ encoder_hidden_states = hidden_states
610
+ # elif attn.cross_attention_norm:
611
+ # encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
612
+ key = attn.to_k(encoder_hidden_states)
613
+ value = attn.to_v(encoder_hidden_states)
614
+
615
+ # Swapping Attention
616
+ if not is_cross_attention:
617
+ key, value = swapping_attention(key, value, self.unet_chunk_size)
618
+
619
+
620
+ query = attn.head_to_batch_dim(query)
621
+ key = attn.head_to_batch_dim(key)
622
+ value = attn.head_to_batch_dim(value)
623
+
624
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
625
+
626
+ if not hasattr(attn, "attn_map"):
627
+ setattr(attn, "attn_map", {})
628
+ setattr(attn, "inference_step", 0)
629
+ else:
630
+ attn.inference_step += 1
631
+
632
+
633
+ # if attn.inference_step > 45 and attn.inference_step < 50:
634
+ # if attn.inference_step == 42 or attn.inference_step==49:
635
+ if attn.inference_step in self.attn_map_save_steps:
636
+ attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
637
+
638
+ hidden_states = torch.bmm(attention_probs, value)
639
+ hidden_states = attn.batch_to_head_dim(hidden_states)
640
+
641
+ # linear proj
642
+ hidden_states = attn.to_out[0](hidden_states)
643
+ # dropout
644
+ hidden_states = attn.to_out[1](hidden_states)
645
+
646
+ return hidden_states
647
+
648
+
649
+ class InvertedVEAttnProcessor:
650
+ def __init__(self, unet_chunk_size=2, scale=1.0):
651
+ self.unet_chunk_size = unet_chunk_size
652
+ self.scale = scale
653
+
654
+ def __call__(
655
+ self,
656
+ attn,
657
+ hidden_states,
658
+ encoder_hidden_states=None,
659
+ attention_mask=None):
660
+ batch_size, sequence_length, _ = hidden_states.shape
661
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
662
+ query = attn.to_q(hidden_states)
663
+
664
+ is_cross_attention = encoder_hidden_states is not None
665
+ if encoder_hidden_states is None:
666
+ encoder_hidden_states = hidden_states
667
+ elif attn.cross_attention_norm:
668
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
669
+ key = attn.to_k(encoder_hidden_states)
670
+ value = attn.to_v(encoder_hidden_states)
671
+
672
+ #Dual Attention
673
+ if not is_cross_attention:
674
+ ve_key = key.clone()
675
+ ve_value = value.clone()
676
+ video_length = ve_key.size()[0] // self.unet_chunk_size
677
+
678
+ former_frame_index = [0] * video_length
679
+ ve_key = rearrange(ve_key, "(b f) d c -> b f d c", f=video_length)
680
+ ve_key = ve_key[:, former_frame_index]
681
+ ve_key = rearrange(ve_key, "b f d c -> (b f) d c")
682
+ ve_value = rearrange(ve_value, "(b f) d c -> b f d c", f=video_length)
683
+ ve_value = ve_value[:, former_frame_index]
684
+ ve_value = rearrange(ve_value, "b f d c -> (b f) d c")
685
+
686
+ ve_key = attn.head_to_batch_dim(ve_key)
687
+ ve_value = attn.head_to_batch_dim(ve_value)
688
+ ve_query = attn.head_to_batch_dim(query)
689
+
690
+ ve_attention_probs = attn.get_attention_scores(ve_query, ve_key, attention_mask)
691
+ ve_hidden_states = torch.bmm(ve_attention_probs, ve_value)
692
+ ve_hidden_states = attn.batch_to_head_dim(ve_hidden_states)
693
+ ve_hidden_states[0,...] = 0
694
+ ve_hidden_states[video_length,...] = 0
695
+
696
+ query = attn.head_to_batch_dim(query)
697
+ key = attn.head_to_batch_dim(key)
698
+ value = attn.head_to_batch_dim(value)
699
+
700
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
701
+ hidden_states = torch.bmm(attention_probs, value)
702
+ hidden_states = attn.batch_to_head_dim(hidden_states)
703
+
704
+ hidden_states = hidden_states + ve_hidden_states * self.scale
705
+
706
+ else:
707
+ query = attn.head_to_batch_dim(query)
708
+ key = attn.head_to_batch_dim(key)
709
+ value = attn.head_to_batch_dim(value)
710
+
711
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
712
+ hidden_states = torch.bmm(attention_probs, value)
713
+ hidden_states = attn.batch_to_head_dim(hidden_states)
714
+
715
+
716
+
717
+ # linear proj
718
+ hidden_states = attn.to_out[0](hidden_states)
719
+ # dropout
720
+ hidden_states = attn.to_out[1](hidden_states)
721
+
722
+ return hidden_states
723
+
724
+ class AttnProcessor(nn.Module):
725
+ r"""
726
+ Default processor for performing attention-related computations.
727
+ """
728
+ def __init__(
729
+ self,
730
+ hidden_size=None,
731
+ cross_attention_dim=None,
732
+ ):
733
+ super().__init__()
734
+
735
+ def __call__(
736
+ self,
737
+ attn,
738
+ hidden_states,
739
+ encoder_hidden_states=None,
740
+ attention_mask=None,
741
+ temb=None,
742
+ ):
743
+
744
+ residual = hidden_states
745
+ # import pdb; pdb.set_trace()
746
+ # if attn.spatial_norm is not None:
747
+ # hidden_states = attn.spatial_norm(hidden_states, temb)
748
+
749
+ input_ndim = hidden_states.ndim
750
+
751
+ if input_ndim == 4:
752
+ batch_size, channel, height, width = hidden_states.shape
753
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
754
+
755
+ batch_size, sequence_length, _ = (
756
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
757
+ )
758
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
759
+
760
+ # if attn.group_norm is not None:
761
+ # hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
762
+
763
+ query = attn.to_q(hidden_states)
764
+
765
+ if encoder_hidden_states is None:
766
+ encoder_hidden_states = hidden_states
767
+ elif attn.norm_cross:
768
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
769
+
770
+ key = attn.to_k(encoder_hidden_states)
771
+ value = attn.to_v(encoder_hidden_states)
772
+
773
+ query = attn.head_to_batch_dim(query)
774
+ key = attn.head_to_batch_dim(key)
775
+ value = attn.head_to_batch_dim(value)
776
+
777
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
778
+ hidden_states = torch.bmm(attention_probs, value)
779
+ hidden_states = attn.batch_to_head_dim(hidden_states)
780
+
781
+ # linear proj
782
+ hidden_states = attn.to_out[0](hidden_states)
783
+ # dropout
784
+ hidden_states = attn.to_out[1](hidden_states)
785
+
786
+ if input_ndim == 4:
787
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
788
+
789
+ if attn.residual_connection:
790
+ hidden_states = hidden_states + residual
791
+
792
+ hidden_states = hidden_states / attn.rescale_output_factor
793
+
794
+ return hidden_states
795
+
796
+
797
+ @dataclass
798
+ class StableDiffusionPipelineOutput(BaseOutput):
799
+ """
800
+ Output class for Stable Diffusion pipelines.
801
+
802
+ Args:
803
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
804
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
805
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
806
+ nsfw_content_detected (`List[bool]`)
807
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
808
+ (nsfw) content, or `None` if safety checking could not be performed.
809
+ """
810
+
811
+ images: Union[List[PIL.Image.Image], np.ndarray]
812
+ nsfw_content_detected: Optional[List[bool]]
813
+
814
+ class FrozenDict(OrderedDict):
815
+ def __init__(self, *args, **kwargs):
816
+ super().__init__(*args, **kwargs)
817
+
818
+ for key, value in self.items():
819
+ setattr(self, key, value)
820
+
821
+ self.__frozen = True
822
+
823
+ def __delitem__(self, *args, **kwargs):
824
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
825
+
826
+ def setdefault(self, *args, **kwargs):
827
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
828
+
829
+ def pop(self, *args, **kwargs):
830
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
831
+
832
+ def update(self, *args, **kwargs):
833
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
834
+
835
+ def __setattr__(self, name, value):
836
+ if hasattr(self, "__frozen") and self.__frozen:
837
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
838
+ super().__setattr__(name, value)
839
+
840
+ def __setitem__(self, name, value):
841
+ if hasattr(self, "__frozen") and self.__frozen:
842
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
843
+ super().__setitem__(name, value)
844
+
845
+
846
+ class InvertedVEPipeline(StableDiffusionPipeline):
847
+ r"""
848
+ Pipeline for text-to-image generation using Stable Diffusion.
849
+
850
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
851
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
852
+
853
+ Args:
854
+ vae ([`AutoencoderKL`]):
855
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
856
+ text_encoder ([`CLIPTextModel`]):
857
+ Frozen text-encoder. Stable Diffusion uses the text portion of
858
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
859
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
860
+ tokenizer (`CLIPTokenizer`):
861
+ Tokenizer of class
862
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
863
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
864
+ scheduler ([`SchedulerMixin`]):
865
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
866
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
867
+ safety_checker ([`StableDiffusionSafetyChecker`]):
868
+ Classification module that estimates whether generated images could be considered offensive or harmful.
869
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
870
+ feature_extractor ([`CLIPFeatureExtractor`]):
871
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
872
+ """
873
+ _optional_components = ["safety_checker", "feature_extractor"]
874
+
875
+ def __init__(
876
+ self,
877
+ vae: AutoencoderKL,
878
+ text_encoder: CLIPTextModel,
879
+ tokenizer: CLIPTokenizer,
880
+ unet: UNet2DConditionModel,
881
+ scheduler: KarrasDiffusionSchedulers,
882
+ safety_checker: StableDiffusionSafetyChecker,
883
+ feature_extractor: CLIPFeatureExtractor,
884
+ requires_safety_checker: bool = True,
885
+ ):
886
+ # super().__init__()
887
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
888
+ safety_checker, feature_extractor, requires_safety_checker)
889
+
890
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
891
+ deprecation_message = (
892
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
893
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
894
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
895
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
896
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
897
+ " file"
898
+ )
899
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
900
+ new_config = dict(scheduler.config)
901
+ new_config["steps_offset"] = 1
902
+ scheduler._internal_dict = FrozenDict(new_config)
903
+
904
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
905
+ deprecation_message = (
906
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
907
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
908
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
909
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
910
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
911
+ )
912
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
913
+ new_config = dict(scheduler.config)
914
+ new_config["clip_sample"] = False
915
+ scheduler._internal_dict = FrozenDict(new_config)
916
+
917
+ if safety_checker is None and requires_safety_checker:
918
+ logger.warning(
919
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
920
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
921
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
922
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
923
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
924
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
925
+ )
926
+
927
+ if safety_checker is not None and feature_extractor is None:
928
+ raise ValueError(
929
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
930
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
931
+ )
932
+
933
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
934
+ version.parse(unet.config._diffusers_version).base_version
935
+ ) < version.parse("0.9.0.dev0")
936
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
937
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
938
+ deprecation_message = (
939
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
940
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
941
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
942
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
943
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
944
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
945
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
946
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
947
+ " the `unet/config.json` file"
948
+ )
949
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
950
+ new_config = dict(unet.config)
951
+ new_config["sample_size"] = 64
952
+ unet._internal_dict = FrozenDict(new_config)
953
+
954
+ self.register_modules(
955
+ vae=vae,
956
+ text_encoder=text_encoder,
957
+ tokenizer=tokenizer,
958
+ unet=unet,
959
+ scheduler=scheduler,
960
+ safety_checker=safety_checker,
961
+ feature_extractor=feature_extractor,
962
+ )
963
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
964
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
965
+
966
+ def enable_vae_slicing(self):
967
+ r"""
968
+ Enable sliced VAE decoding.
969
+
970
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
971
+ steps. This is useful to save some memory and allow larger batch sizes.
972
+ """
973
+ self.vae.enable_slicing()
974
+
975
+ def disable_vae_slicing(self):
976
+ r"""
977
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
978
+ computing decoding in one step.
979
+ """
980
+ self.vae.disable_slicing()
981
+
982
+ def enable_vae_tiling(self):
983
+ r"""
984
+ Enable tiled VAE decoding.
985
+
986
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
987
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
988
+ """
989
+ self.vae.enable_tiling()
990
+
991
+ def disable_vae_tiling(self):
992
+ r"""
993
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
994
+ computing decoding in one step.
995
+ """
996
+ self.vae.disable_tiling()
997
+
998
+ def enable_sequential_cpu_offload(self, gpu_id=0):
999
+ r"""
1000
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
1001
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
1002
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
1003
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
1004
+ `enable_model_cpu_offload`, but performance is lower.
1005
+ """
1006
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
1007
+ from accelerate import cpu_offload
1008
+ else:
1009
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
1010
+
1011
+ device = torch.device(f"cuda:{gpu_id}")
1012
+
1013
+ if self.device.type != "cpu":
1014
+ self.to("cpu", silence_dtype_warnings=True)
1015
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1016
+
1017
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
1018
+ cpu_offload(cpu_offloaded_model, device)
1019
+
1020
+ if self.safety_checker is not None:
1021
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
1022
+
1023
+ def enable_model_cpu_offload(self, gpu_id=0):
1024
+ r"""
1025
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
1026
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
1027
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
1028
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
1029
+ """
1030
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1031
+ from accelerate import cpu_offload_with_hook
1032
+ else:
1033
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
1034
+
1035
+ device = torch.device(f"cuda:{gpu_id}")
1036
+
1037
+ if self.device.type != "cpu":
1038
+ self.to("cpu", silence_dtype_warnings=True)
1039
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1040
+
1041
+ hook = None
1042
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
1043
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
1044
+
1045
+ if self.safety_checker is not None:
1046
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
1047
+
1048
+ # We'll offload the last model manually.
1049
+ self.final_offload_hook = hook
1050
+
1051
+ @property
1052
+ def _execution_device(self):
1053
+ r"""
1054
+ Returns the device on which the pipeline's models will be executed. After calling
1055
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
1056
+ hooks.
1057
+ """
1058
+ if not hasattr(self.unet, "_hf_hook"):
1059
+ return self.device
1060
+ for module in self.unet.modules():
1061
+ if (
1062
+ hasattr(module, "_hf_hook")
1063
+ and hasattr(module._hf_hook, "execution_device")
1064
+ and module._hf_hook.execution_device is not None
1065
+ ):
1066
+ return torch.device(module._hf_hook.execution_device)
1067
+ return self.device
1068
+
1069
+
1070
+ def run_safety_checker(self, image, device, dtype):
1071
+ if self.safety_checker is not None:
1072
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
1073
+ image, has_nsfw_concept = self.safety_checker(
1074
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
1075
+ )
1076
+ else:
1077
+ has_nsfw_concept = None
1078
+ return image, has_nsfw_concept
1079
+
1080
+ def decode_latents(self, latents):
1081
+ latents = 1 / self.vae.config.scaling_factor * latents
1082
+ image = self.vae.decode(latents).sample
1083
+ image = (image / 2 + 0.5).clamp(0, 1)
1084
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1085
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1086
+ return image
1087
+
1088
+ def prepare_extra_step_kwargs(self, generator, eta):
1089
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1090
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1091
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1092
+ # and should be between [0, 1]
1093
+
1094
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
1095
+ extra_step_kwargs = {}
1096
+ if accepts_eta:
1097
+ extra_step_kwargs["eta"] = eta
1098
+
1099
+ # check if the scheduler accepts generator
1100
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
1101
+ if accepts_generator:
1102
+ extra_step_kwargs["generator"] = generator
1103
+ return extra_step_kwargs
1104
+
1105
+ def check_inputs(
1106
+ self,
1107
+ prompt,
1108
+ height,
1109
+ width,
1110
+ callback_steps,
1111
+ negative_prompt=None,
1112
+ prompt_embeds=None,
1113
+ negative_prompt_embeds=None,
1114
+ ):
1115
+ if height % 8 != 0 or width % 8 != 0:
1116
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
1117
+
1118
+ if (callback_steps is None) or (
1119
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
1120
+ ):
1121
+ raise ValueError(
1122
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
1123
+ f" {type(callback_steps)}."
1124
+ )
1125
+
1126
+ if prompt is not None and prompt_embeds is not None:
1127
+ raise ValueError(
1128
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
1129
+ " only forward one of the two."
1130
+ )
1131
+ elif prompt is None and prompt_embeds is None:
1132
+ raise ValueError(
1133
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
1134
+ )
1135
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
1136
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
1137
+
1138
+ if negative_prompt is not None and negative_prompt_embeds is not None:
1139
+ raise ValueError(
1140
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
1141
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
1142
+ )
1143
+
1144
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
1145
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
1146
+ raise ValueError(
1147
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
1148
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
1149
+ f" {negative_prompt_embeds.shape}."
1150
+ )
1151
+
1152
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
1153
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
1154
+ if isinstance(generator, list) and len(generator) != batch_size:
1155
+ raise ValueError(
1156
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1157
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1158
+ )
1159
+
1160
+ if latents is None:
1161
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1162
+ else:
1163
+ latents = latents.to(device)
1164
+
1165
+ # scale the initial noise by the standard deviation required by the scheduler
1166
+ latents = latents * self.scheduler.init_noise_sigma
1167
+ return latents
1168
+
1169
+ @torch.no_grad()
1170
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1171
+ def __call__(
1172
+ self,
1173
+ prompt: Union[str, List[str]] = None,
1174
+ height: Optional[int] = None,
1175
+ width: Optional[int] = None,
1176
+ num_inference_steps: int = 50,
1177
+ guidance_scale: float = 7.5,
1178
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1179
+ num_images_per_prompt: Optional[int] = 1,
1180
+ eta: float = 0.0,
1181
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1182
+ latents: Optional[torch.FloatTensor] = None,
1183
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1184
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1185
+ output_type: Optional[str] = "pil",
1186
+ return_dict: bool = True,
1187
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1188
+ callback_steps: int = 1,
1189
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1190
+ target_prompt: Optional[str] = None,
1191
+ # device: Optional[Union[str, torch.device]] = "cpu",
1192
+ ):
1193
+ r"""
1194
+ Function invoked when calling the pipeline for generation.
1195
+
1196
+ Args:
1197
+ prompt (`str` or `List[str]`, *optional*):
1198
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1199
+ instead.
1200
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1201
+ The height in pixels of the generated image.
1202
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1203
+ The width in pixels of the generated image.
1204
+ num_inference_steps (`int`, *optional*, defaults to 50):
1205
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1206
+ expense of slower inference.
1207
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1208
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1209
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1210
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1211
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1212
+ usually at the expense of lower image quality.
1213
+ negative_prompt (`str` or `List[str]`, *optional*):
1214
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1215
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
1216
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
1217
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1218
+ The number of images to generate per prompt.
1219
+ eta (`float`, *optional*, defaults to 0.0):
1220
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1221
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1222
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1223
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1224
+ to make generation deterministic.
1225
+ latents (`torch.FloatTensor`, *optional*):
1226
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1227
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1228
+ tensor will ge generated by sampling using the supplied random `generator`.
1229
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1230
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1231
+ provided, text embeddings will be generated from `prompt` input argument.
1232
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1233
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1234
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1235
+ argument.
1236
+ output_type (`str`, *optional*, defaults to `"pil"`):
1237
+ The output format of the generate image. Choose between
1238
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1239
+ return_dict (`bool`, *optional*, defaults to `True`):
1240
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1241
+ plain tuple.
1242
+ callback (`Callable`, *optional*):
1243
+ A function that will be called every `callback_steps` steps during inference. The function will be
1244
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1245
+ callback_steps (`int`, *optional*, defaults to 1):
1246
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1247
+ called at every step.
1248
+ cross_attention_kwargs (`dict`, *optional*):
1249
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
1250
+ `self.processor` in
1251
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1252
+
1253
+ Examples:
1254
+
1255
+ Returns:
1256
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1257
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1258
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1259
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1260
+ (nsfw) content, according to the `safety_checker`.
1261
+ """
1262
+ # 0. Default height and width to unet
1263
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1264
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1265
+
1266
+ # 1. Check inputs. Raise error if not correct
1267
+ self.check_inputs(
1268
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
1269
+ )
1270
+
1271
+ # 2. Define call parameters
1272
+ if prompt is not None and isinstance(prompt, str):
1273
+ batch_size = 1
1274
+ elif prompt is not None and isinstance(prompt, list):
1275
+ batch_size = len(prompt)
1276
+ else:
1277
+ batch_size = prompt_embeds.shape[0]
1278
+
1279
+ device = self._execution_device
1280
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1281
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1282
+ # corresponds to doing no classifier free guidance.
1283
+ do_classifier_free_guidance = guidance_scale > 1.0
1284
+
1285
+ # 3. Encode input prompt
1286
+ # import pdb; pdb.set_trace()
1287
+
1288
+
1289
+ prompt_embeds = self._encode_prompt(
1290
+ prompt,
1291
+ device,
1292
+ num_images_per_prompt,
1293
+ do_classifier_free_guidance,
1294
+ negative_prompt,
1295
+ prompt_embeds=prompt_embeds,
1296
+ negative_prompt_embeds=negative_prompt_embeds,
1297
+ )
1298
+
1299
+ # import pdb; pdb.set_trace()
1300
+
1301
+ if target_prompt is not None:
1302
+ target_prompt_embeds = self._encode_prompt(
1303
+ target_prompt,
1304
+ device,
1305
+ num_images_per_prompt,
1306
+ do_classifier_free_guidance,
1307
+ negative_prompt,
1308
+ prompt_embeds=None,
1309
+ negative_prompt_embeds=negative_prompt_embeds,
1310
+ )
1311
+ prompt_embeds[num_images_per_prompt+1: ] = target_prompt_embeds[num_images_per_prompt+1:]
1312
+ import pdb; pdb.set_trace()
1313
+
1314
+ # 4. Prepare timesteps
1315
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1316
+ timesteps = self.scheduler.timesteps
1317
+
1318
+ # 5. Prepare latent variables
1319
+ num_channels_latents = self.unet.in_channels
1320
+ latents = self.prepare_latents(
1321
+ batch_size * num_images_per_prompt,
1322
+ num_channels_latents,
1323
+ height,
1324
+ width,
1325
+ prompt_embeds.dtype,
1326
+ device,
1327
+ generator,
1328
+ latents,
1329
+ )
1330
+
1331
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1332
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1333
+
1334
+ # 7. Denoising loop
1335
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1336
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1337
+ for i, t in enumerate(timesteps):
1338
+ # expand the latents if we are doing classifier free guidance
1339
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1340
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1341
+
1342
+ # predict the noise residual
1343
+ noise_pred = self.unet(
1344
+ latent_model_input,
1345
+ t,
1346
+ encoder_hidden_states=prompt_embeds,
1347
+ cross_attention_kwargs=cross_attention_kwargs,
1348
+ ).sample
1349
+
1350
+ # perform guidance
1351
+ if do_classifier_free_guidance:
1352
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1353
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1354
+
1355
+ # compute the previous noisy sample x_t -> x_t-1
1356
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1357
+
1358
+ # call the callback, if provided
1359
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1360
+ progress_bar.update()
1361
+ if callback is not None and i % callback_steps == 0:
1362
+ callback(i, t, latents)
1363
+
1364
+ if output_type == "latent":
1365
+ image = latents
1366
+ has_nsfw_concept = None
1367
+ elif output_type == "pil":
1368
+ # 8. Post-processing
1369
+ image = self.decode_latents(latents)
1370
+
1371
+ # 9. Run safety checker
1372
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1373
+
1374
+ # 10. Convert to PIL
1375
+ image = self.numpy_to_pil(image)
1376
+ else:
1377
+ # 8. Post-processing
1378
+ image = self.decode_latents(latents)
1379
+
1380
+ # 9. Run safety checker
1381
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1382
+
1383
+ # Offload last model to CPU
1384
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1385
+ self.final_offload_hook.offload()
1386
+
1387
+ if not return_dict:
1388
+ return (image, has_nsfw_concept)
1389
+
1390
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1391
+
1392
+
1393
+ ACTIVATE_LAYER_CANDIDATE= [
1394
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
1395
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
1396
+ 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
1397
+ 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
1398
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
1399
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
1400
+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
1401
+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor', #8
1402
+
1403
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor',
1404
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
1405
+ 'down_blocks.2.attentions.0.transformer_blocks.1.attn1.processor',
1406
+ 'down_blocks.2.attentions.0.transformer_blocks.1.attn2.processor',
1407
+ 'down_blocks.2.attentions.0.transformer_blocks.2.attn1.processor',
1408
+ 'down_blocks.2.attentions.0.transformer_blocks.2.attn2.processor',
1409
+ 'down_blocks.2.attentions.0.transformer_blocks.3.attn1.processor',
1410
+ 'down_blocks.2.attentions.0.transformer_blocks.3.attn2.processor',
1411
+ 'down_blocks.2.attentions.0.transformer_blocks.4.attn1.processor',
1412
+ 'down_blocks.2.attentions.0.transformer_blocks.4.attn2.processor',
1413
+ 'down_blocks.2.attentions.0.transformer_blocks.5.attn1.processor',
1414
+ 'down_blocks.2.attentions.0.transformer_blocks.5.attn2.processor',
1415
+ 'down_blocks.2.attentions.0.transformer_blocks.6.attn1.processor',
1416
+ 'down_blocks.2.attentions.0.transformer_blocks.6.attn2.processor',
1417
+ 'down_blocks.2.attentions.0.transformer_blocks.7.attn1.processor',
1418
+ 'down_blocks.2.attentions.0.transformer_blocks.7.attn2.processor',
1419
+ 'down_blocks.2.attentions.0.transformer_blocks.8.attn1.processor',
1420
+ 'down_blocks.2.attentions.0.transformer_blocks.8.attn2.processor',
1421
+ 'down_blocks.2.attentions.0.transformer_blocks.9.attn1.processor',
1422
+ 'down_blocks.2.attentions.0.transformer_blocks.9.attn2.processor', #20
1423
+
1424
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor',
1425
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
1426
+ 'down_blocks.2.attentions.1.transformer_blocks.1.attn1.processor',
1427
+ 'down_blocks.2.attentions.1.transformer_blocks.1.attn2.processor',
1428
+ 'down_blocks.2.attentions.1.transformer_blocks.2.attn1.processor',
1429
+ 'down_blocks.2.attentions.1.transformer_blocks.2.attn2.processor',
1430
+ 'down_blocks.2.attentions.1.transformer_blocks.3.attn1.processor',
1431
+ 'down_blocks.2.attentions.1.transformer_blocks.3.attn2.processor',
1432
+ 'down_blocks.2.attentions.1.transformer_blocks.4.attn1.processor',
1433
+ 'down_blocks.2.attentions.1.transformer_blocks.4.attn2.processor',
1434
+ 'down_blocks.2.attentions.1.transformer_blocks.5.attn1.processor',
1435
+ 'down_blocks.2.attentions.1.transformer_blocks.5.attn2.processor',
1436
+ 'down_blocks.2.attentions.1.transformer_blocks.6.attn1.processor',
1437
+ 'down_blocks.2.attentions.1.transformer_blocks.6.attn2.processor',
1438
+ 'down_blocks.2.attentions.1.transformer_blocks.7.attn1.processor',
1439
+ 'down_blocks.2.attentions.1.transformer_blocks.7.attn2.processor',
1440
+ 'down_blocks.2.attentions.1.transformer_blocks.8.attn1.processor',
1441
+ 'down_blocks.2.attentions.1.transformer_blocks.8.attn2.processor',
1442
+ 'down_blocks.2.attentions.1.transformer_blocks.9.attn1.processor',
1443
+ 'down_blocks.2.attentions.1.transformer_blocks.9.attn2.processor',#20
1444
+
1445
+ 'mid_block.attentions.0.transformer_blocks.0.attn1.processor',
1446
+ 'mid_block.attentions.0.transformer_blocks.0.attn2.processor',
1447
+ 'mid_block.attentions.0.transformer_blocks.1.attn1.processor',
1448
+ 'mid_block.attentions.0.transformer_blocks.1.attn2.processor',
1449
+ 'mid_block.attentions.0.transformer_blocks.2.attn1.processor',
1450
+ 'mid_block.attentions.0.transformer_blocks.2.attn2.processor',
1451
+ 'mid_block.attentions.0.transformer_blocks.3.attn1.processor',
1452
+ 'mid_block.attentions.0.transformer_blocks.3.attn2.processor',
1453
+ 'mid_block.attentions.0.transformer_blocks.4.attn1.processor',
1454
+ 'mid_block.attentions.0.transformer_blocks.4.attn2.processor',
1455
+ 'mid_block.attentions.0.transformer_blocks.5.attn1.processor',
1456
+ 'mid_block.attentions.0.transformer_blocks.5.attn2.processor',
1457
+ 'mid_block.attentions.0.transformer_blocks.6.attn1.processor',
1458
+ 'mid_block.attentions.0.transformer_blocks.6.attn2.processor',
1459
+ 'mid_block.attentions.0.transformer_blocks.7.attn1.processor',
1460
+ 'mid_block.attentions.0.transformer_blocks.7.attn2.processor',
1461
+ 'mid_block.attentions.0.transformer_blocks.8.attn1.processor',
1462
+ 'mid_block.attentions.0.transformer_blocks.8.attn2.processor',
1463
+ 'mid_block.attentions.0.transformer_blocks.9.attn1.processor',
1464
+ 'mid_block.attentions.0.transformer_blocks.9.attn2.processor', #20
1465
+
1466
+ 'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor',
1467
+ 'up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
1468
+ 'up_blocks.0.attentions.0.transformer_blocks.1.attn1.processor',
1469
+ 'up_blocks.0.attentions.0.transformer_blocks.1.attn2.processor',
1470
+ 'up_blocks.0.attentions.0.transformer_blocks.2.attn1.processor',
1471
+ 'up_blocks.0.attentions.0.transformer_blocks.2.attn2.processor',
1472
+ 'up_blocks.0.attentions.0.transformer_blocks.3.attn1.processor',
1473
+ 'up_blocks.0.attentions.0.transformer_blocks.3.attn2.processor',
1474
+ 'up_blocks.0.attentions.0.transformer_blocks.4.attn1.processor',
1475
+ 'up_blocks.0.attentions.0.transformer_blocks.4.attn2.processor',
1476
+ 'up_blocks.0.attentions.0.transformer_blocks.5.attn1.processor',
1477
+ 'up_blocks.0.attentions.0.transformer_blocks.5.attn2.processor',
1478
+ 'up_blocks.0.attentions.0.transformer_blocks.6.attn1.processor',
1479
+ 'up_blocks.0.attentions.0.transformer_blocks.6.attn2.processor',
1480
+ 'up_blocks.0.attentions.0.transformer_blocks.7.attn1.processor',
1481
+ 'up_blocks.0.attentions.0.transformer_blocks.7.attn2.processor',
1482
+ 'up_blocks.0.attentions.0.transformer_blocks.8.attn1.processor',
1483
+ 'up_blocks.0.attentions.0.transformer_blocks.8.attn2.processor',
1484
+ 'up_blocks.0.attentions.0.transformer_blocks.9.attn1.processor',
1485
+ 'up_blocks.0.attentions.0.transformer_blocks.9.attn2.processor',#20
1486
+
1487
+ 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor',
1488
+ 'up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
1489
+ 'up_blocks.0.attentions.1.transformer_blocks.1.attn1.processor',
1490
+ 'up_blocks.0.attentions.1.transformer_blocks.1.attn2.processor',
1491
+ 'up_blocks.0.attentions.1.transformer_blocks.2.attn1.processor',
1492
+ 'up_blocks.0.attentions.1.transformer_blocks.2.attn2.processor',
1493
+ 'up_blocks.0.attentions.1.transformer_blocks.3.attn1.processor',
1494
+ 'up_blocks.0.attentions.1.transformer_blocks.3.attn2.processor',
1495
+ 'up_blocks.0.attentions.1.transformer_blocks.4.attn1.processor',
1496
+ 'up_blocks.0.attentions.1.transformer_blocks.4.attn2.processor',
1497
+ 'up_blocks.0.attentions.1.transformer_blocks.5.attn1.processor',
1498
+ 'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor',
1499
+ 'up_blocks.0.attentions.1.transformer_blocks.6.attn1.processor',
1500
+ 'up_blocks.0.attentions.1.transformer_blocks.6.attn2.processor',
1501
+ 'up_blocks.0.attentions.1.transformer_blocks.7.attn1.processor',
1502
+ 'up_blocks.0.attentions.1.transformer_blocks.7.attn2.processor',
1503
+ 'up_blocks.0.attentions.1.transformer_blocks.8.attn1.processor',
1504
+ 'up_blocks.0.attentions.1.transformer_blocks.8.attn2.processor',
1505
+ 'up_blocks.0.attentions.1.transformer_blocks.9.attn1.processor',
1506
+ 'up_blocks.0.attentions.1.transformer_blocks.9.attn2.processor',#20
1507
+
1508
+ 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor',
1509
+ 'up_blocks.0.attentions.2.transformer_blocks.0.attn2.processor',
1510
+ 'up_blocks.0.attentions.2.transformer_blocks.1.attn1.processor',
1511
+ 'up_blocks.0.attentions.2.transformer_blocks.1.attn2.processor',
1512
+ 'up_blocks.0.attentions.2.transformer_blocks.2.attn1.processor',
1513
+ 'up_blocks.0.attentions.2.transformer_blocks.2.attn2.processor',
1514
+ 'up_blocks.0.attentions.2.transformer_blocks.3.attn1.processor',
1515
+ 'up_blocks.0.attentions.2.transformer_blocks.3.attn2.processor',
1516
+ 'up_blocks.0.attentions.2.transformer_blocks.4.attn1.processor',
1517
+ 'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor',
1518
+ 'up_blocks.0.attentions.2.transformer_blocks.5.attn1.processor',
1519
+ 'up_blocks.0.attentions.2.transformer_blocks.5.attn2.processor',
1520
+ 'up_blocks.0.attentions.2.transformer_blocks.6.attn1.processor',
1521
+ 'up_blocks.0.attentions.2.transformer_blocks.6.attn2.processor',
1522
+ 'up_blocks.0.attentions.2.transformer_blocks.7.attn1.processor',
1523
+ 'up_blocks.0.attentions.2.transformer_blocks.7.attn2.processor',
1524
+ 'up_blocks.0.attentions.2.transformer_blocks.8.attn1.processor',
1525
+ 'up_blocks.0.attentions.2.transformer_blocks.8.attn2.processor',
1526
+ 'up_blocks.0.attentions.2.transformer_blocks.9.attn1.processor',
1527
+ 'up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor', #20
1528
+
1529
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
1530
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
1531
+ 'up_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
1532
+ 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
1533
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
1534
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
1535
+ 'up_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
1536
+ 'up_blocks.1.attentions.1.transformer_blocks.1.attn2.processor',
1537
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor',
1538
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
1539
+ 'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor',
1540
+ 'up_blocks.1.attentions.2.transformer_blocks.1.attn2.processor',#12
1541
+
1542
+ ]
1543
+
1544
+ STYLE_DESCRIPTION_DICT = {
1545
+ "chinese-ink-paint":("{object} in colorful chinese ink paintings style",""),
1546
+ "cloud":("Photography of {object}, realistic",""),
1547
+ "digital-art":("{object} in digital glitch arts style",""),
1548
+ "fire":("{object} photography, realistic, black background'",""),
1549
+ "klimt":("{object} in style of Gustav Klimt",""),
1550
+ "line-art":("line art drawing of {object} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",""),
1551
+ "low-poly":("low-poly style of {object} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
1552
+ "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"),
1553
+ "munch":("{object} in Edvard Munch style",""),
1554
+ "van-gogh":("{object}, Van Gogh",""),
1555
+ "totoro":("{object}, art by studio ghibli, cinematic, masterpiece,key visual, studio anime, highly detailed",
1556
+ "photo, deformed, black and white, realism, disfigured, low contrast"),
1557
+
1558
+ "realistic": ("A portrait of {object}, photorealistic, 35mm film, realistic",
1559
+ "gray, ugly, deformed, noisy, blurry"),
1560
+
1561
+ "line_art": ("line art drawing of {object} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
1562
+ "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic"
1563
+ ) ,
1564
+
1565
+ "anime": ("anime artwork of {object} . anime style, key visual, vibrant, studio anime, highly detailed",
1566
+ "photo, deformed, black and white, realism, disfigured, low contrast"
1567
+ ),
1568
+
1569
+ "Artstyle_Pop_Art" : ("pop Art style of {object} . bright colors, bold outlines, popular culture themes, ironic or kitsch",
1570
+ "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, minimalist"
1571
+ ),
1572
+
1573
+ "Artstyle_Pointillism": ("pointillism style of {object} . composed entirely of small, distinct dots of color, vibrant, highly detailed",
1574
+ "line drawing, smooth shading, large color fields, simplistic"
1575
+ ),
1576
+
1577
+ "origami": ("origami style of {object} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
1578
+ "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
1579
+ ),
1580
+
1581
+ "craft_clay": ("play-doh style of {object} . sculpture, clay art, centered composition, Claymation",
1582
+ "sloppy, messy, grainy, highly detailed, ultra textured, photo"
1583
+ ),
1584
+
1585
+ "low_poly" : ("low-poly style of {object} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
1586
+ "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
1587
+ ),
1588
+
1589
+ "Artstyle_watercolor": ("watercolor painting of {object} . vibrant, beautiful, painterly, detailed, textural, artistic",
1590
+ "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
1591
+ ),
1592
+
1593
+ "Papercraft_Collage" : ("collage style of {object} . mixed media, layered, textural, detailed, artistic",
1594
+ "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic"
1595
+ ),
1596
+
1597
+ "Artstyle_Impressionist" : ("impressionist painting of {object} . loose brushwork, vibrant color, light and shadow play, captures feeling over form",
1598
+ "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
1599
+ ),
1600
+ "realistic_bg_black":("{object} photography, realistic, black background",
1601
+ ""),
1602
+ "photography_realistic":("Photography of {object}, realistic",
1603
+ ""),
1604
+ "digital_art":("{object} in digital glitch arts style.",
1605
+ ""
1606
+ ),
1607
+ "chinese_painting":("{object} in traditional a chinese ink painting style.",
1608
+ ""
1609
+ ),
1610
+ "no_style":("{object}",
1611
+ ""),
1612
+ "kid_drawing":("{object} in kid crayon drawings style.",""),
1613
+ "onepiece":("{object}, wanostyle, angry looking, straw hat, looking at viewer, solo, upper body, masterpiece, best quality, (extremely detailed), watercolor, illustration, depth of field, sketch, dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, fire all around, spectacular, closed shirt",
1614
+ " watermark, text, error, blurry, jpeg artifacts, many objects, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature")
1615
+ }
pipelines/pipeline_controlnet_sd_xl.py ADDED
The diff for this file is too large to render. See raw diff
 
pipelines/pipeline_stable_diffusion_xl.py ADDED
@@ -0,0 +1,1792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+ import PIL
18
+ import torch
19
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
20
+
21
+ from diffusers.image_processor import VaeImageProcessor
22
+ from diffusers.loaders import (
23
+ FromSingleFileMixin,
24
+ StableDiffusionXLLoraLoaderMixin,
25
+ TextualInversionLoaderMixin,
26
+ )
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils import (
37
+ USE_PEFT_BACKEND,
38
+ deprecate,
39
+ is_invisible_watermark_available,
40
+ is_torch_xla_available,
41
+ logging,
42
+ replace_example_docstring,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import randn_tensor
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
49
+
50
+ from pipelines.inverted_ve_pipeline import CrossFrameAttnProcessor, ACTIVATE_LAYER_CANDIDATE, SharedAttentionProcessor, SharedAttentionProcessor_v2
51
+ from diffusers.models.attention_processor import AttnProcessor
52
+
53
+ import os
54
+
55
+ if is_invisible_watermark_available():
56
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
57
+
58
+ if is_torch_xla_available():
59
+ import torch_xla.core.xla_model as xm
60
+
61
+ XLA_AVAILABLE = True
62
+ else:
63
+ XLA_AVAILABLE = False
64
+
65
+
66
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
67
+
68
+ EXAMPLE_DOC_STRING = """
69
+ Examples:
70
+ ```py
71
+ >>> import torch
72
+ >>> from diffusers import StableDiffusionXLPipeline
73
+
74
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
75
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
76
+ ... )
77
+ >>> pipe = pipe.to("cuda")
78
+
79
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
80
+ >>> image = pipe(prompt).images[0]
81
+ ```
82
+ """
83
+
84
+
85
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
86
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
87
+ """
88
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
89
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
90
+ """
91
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
92
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
93
+ # rescale the results from guidance (fixes overexposure)
94
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
95
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
96
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
97
+ return noise_cfg
98
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
99
+ def retrieve_latents(encoder_output, generator):
100
+ if hasattr(encoder_output, "latent_dist"):
101
+ return encoder_output.latent_dist.sample(generator)
102
+ elif hasattr(encoder_output, "latents"):
103
+ return encoder_output.latents
104
+ else:
105
+ raise AttributeError("Could not access latents of provided encoder_output")
106
+
107
+
108
+
109
+ class StableDiffusionXLPipeline(
110
+ DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
111
+ ):
112
+ r"""
113
+ Pipeline for text-to-image generation using Stable Diffusion XL.
114
+
115
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
116
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
117
+
118
+ In addition the pipeline inherits the following loading methods:
119
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
120
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
121
+
122
+ as well as the following saving methods:
123
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
124
+
125
+ Args:
126
+ vae ([`AutoencoderKL`]):
127
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
128
+ text_encoder ([`CLIPTextModel`]):
129
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
130
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
131
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
132
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
133
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
134
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
135
+ specifically the
136
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
137
+ variant.
138
+ tokenizer (`CLIPTokenizer`):
139
+ Tokenizer of class
140
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
141
+ tokenizer_2 (`CLIPTokenizer`):
142
+ Second Tokenizer of class
143
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
144
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
145
+ scheduler ([`SchedulerMixin`]):
146
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
147
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
148
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
149
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
150
+ `stabilityai/stable-diffusion-xl-base-1-0`.
151
+ add_watermarker (`bool`, *optional*):
152
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
153
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
154
+ watermarker will be used.
155
+ """
156
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
157
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
158
+ _callback_tensor_inputs = [
159
+ "latents",
160
+ "prompt_embeds",
161
+ "negative_prompt_embeds",
162
+ "add_text_embeds",
163
+ "add_time_ids",
164
+ "negative_pooled_prompt_embeds",
165
+ "negative_add_time_ids",
166
+ ]
167
+
168
+ def __init__(
169
+ self,
170
+ vae: AutoencoderKL,
171
+ text_encoder: CLIPTextModel,
172
+ text_encoder_2: CLIPTextModelWithProjection,
173
+ tokenizer: CLIPTokenizer,
174
+ tokenizer_2: CLIPTokenizer,
175
+ unet: UNet2DConditionModel,
176
+ scheduler: KarrasDiffusionSchedulers,
177
+ force_zeros_for_empty_prompt: bool = True,
178
+ add_watermarker: Optional[bool] = None,
179
+ ):
180
+ super().__init__()
181
+
182
+ self.register_modules(
183
+ vae=vae,
184
+ text_encoder=text_encoder,
185
+ text_encoder_2=text_encoder_2,
186
+ tokenizer=tokenizer,
187
+ tokenizer_2=tokenizer_2,
188
+ unet=unet,
189
+ scheduler=scheduler,
190
+ )
191
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
192
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
193
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
194
+
195
+ self.default_sample_size = self.unet.config.sample_size
196
+
197
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
198
+
199
+ if add_watermarker:
200
+ self.watermark = StableDiffusionXLWatermarker()
201
+ else:
202
+ self.watermark = None
203
+
204
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
205
+ def enable_vae_slicing(self):
206
+ r"""
207
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
208
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
209
+ """
210
+ self.vae.enable_slicing()
211
+
212
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
213
+ def disable_vae_slicing(self):
214
+ r"""
215
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
216
+ computing decoding in one step.
217
+ """
218
+ self.vae.disable_slicing()
219
+
220
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
221
+ def enable_vae_tiling(self):
222
+ r"""
223
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
224
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
225
+ processing larger images.
226
+ """
227
+ self.vae.enable_tiling()
228
+
229
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
230
+ def disable_vae_tiling(self):
231
+ r"""
232
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
233
+ computing decoding in one step.
234
+ """
235
+ self.vae.disable_tiling()
236
+
237
+ def encode_prompt(
238
+ self,
239
+ prompt: str,
240
+ prompt_2: Optional[str] = None,
241
+ device: Optional[torch.device] = None,
242
+ num_images_per_prompt: int = 1,
243
+ do_classifier_free_guidance: bool = True,
244
+ negative_prompt: Optional[str] = None,
245
+ negative_prompt_2: Optional[str] = None,
246
+ prompt_embeds: Optional[torch.FloatTensor] = None,
247
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
248
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
249
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
250
+ lora_scale: Optional[float] = None,
251
+ clip_skip: Optional[int] = None,
252
+ ):
253
+ r"""
254
+ Encodes the prompt into text encoder hidden states.
255
+
256
+ Args:
257
+ prompt (`str` or `List[str]`, *optional*):
258
+ prompt to be encoded
259
+ prompt_2 (`str` or `List[str]`, *optional*):
260
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
261
+ used in both text-encoders
262
+ device: (`torch.device`):
263
+ torch device
264
+ num_images_per_prompt (`int`):
265
+ number of images that should be generated per prompt
266
+ do_classifier_free_guidance (`bool`):
267
+ whether to use classifier free guidance or not
268
+ negative_prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
270
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
271
+ less than `1`).
272
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
273
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
274
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
275
+ prompt_embeds (`torch.FloatTensor`, *optional*):
276
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
277
+ provided, text embeddings will be generated from `prompt` input argument.
278
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
279
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
280
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
281
+ argument.
282
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
283
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
284
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
285
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
286
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
287
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
288
+ input argument.
289
+ lora_scale (`float`, *optional*):
290
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
291
+ clip_skip (`int`, *optional*):
292
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
293
+ the output of the pre-final layer will be used for computing the prompt embeddings.
294
+ """
295
+ device = device or self._execution_device
296
+
297
+ # set lora scale so that monkey patched LoRA
298
+ # function of text encoder can correctly access it
299
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
300
+ self._lora_scale = lora_scale
301
+
302
+ # dynamically adjust the LoRA scale
303
+ if self.text_encoder is not None:
304
+ if not USE_PEFT_BACKEND:
305
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
306
+ else:
307
+ scale_lora_layers(self.text_encoder, lora_scale)
308
+
309
+ if self.text_encoder_2 is not None:
310
+ if not USE_PEFT_BACKEND:
311
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
312
+ else:
313
+ scale_lora_layers(self.text_encoder_2, lora_scale)
314
+
315
+ prompt = [prompt] if isinstance(prompt, str) else prompt
316
+
317
+ if prompt is not None:
318
+ batch_size = len(prompt)
319
+ else:
320
+ batch_size = prompt_embeds.shape[0]
321
+
322
+ # Define tokenizers and text encoders
323
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
324
+ text_encoders = (
325
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
326
+ )
327
+
328
+ if prompt_embeds is None:
329
+ prompt_2 = prompt_2 or prompt
330
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
331
+
332
+ # textual inversion: procecss multi-vector tokens if necessary
333
+ prompt_embeds_list = []
334
+ prompts = [prompt, prompt_2]
335
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
336
+ if isinstance(self, TextualInversionLoaderMixin):
337
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
338
+
339
+ text_inputs = tokenizer(
340
+ prompt,
341
+ padding="max_length",
342
+ max_length=tokenizer.model_max_length,
343
+ truncation=True,
344
+ return_tensors="pt",
345
+ )
346
+
347
+ text_input_ids = text_inputs.input_ids
348
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
349
+
350
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
351
+ text_input_ids, untruncated_ids
352
+ ):
353
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
354
+ logger.warning(
355
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
356
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
357
+ )
358
+
359
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
360
+
361
+ # We are only ALWAYS interested in the pooled output of the final text encoder
362
+ pooled_prompt_embeds = prompt_embeds[0]
363
+ if clip_skip is None:
364
+ prompt_embeds = prompt_embeds.hidden_states[-2]
365
+ else:
366
+ # "2" because SDXL always indexes from the penultimate layer.
367
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
368
+
369
+ prompt_embeds_list.append(prompt_embeds)
370
+
371
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
372
+
373
+ # get unconditional embeddings for classifier free guidance
374
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
375
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
376
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
377
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
378
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
379
+ negative_prompt = negative_prompt or ""
380
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
381
+
382
+ # normalize str to list
383
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
384
+ negative_prompt_2 = (
385
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
386
+ )
387
+
388
+ uncond_tokens: List[str]
389
+ if prompt is not None and type(prompt) is not type(negative_prompt):
390
+ raise TypeError(
391
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
392
+ f" {type(prompt)}."
393
+ )
394
+ elif batch_size != len(negative_prompt):
395
+ raise ValueError(
396
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
397
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
398
+ " the batch size of `prompt`."
399
+ )
400
+ else:
401
+ uncond_tokens = [negative_prompt, negative_prompt_2]
402
+
403
+ negative_prompt_embeds_list = []
404
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
405
+ if isinstance(self, TextualInversionLoaderMixin):
406
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
407
+
408
+ max_length = prompt_embeds.shape[1]
409
+ uncond_input = tokenizer(
410
+ negative_prompt,
411
+ padding="max_length",
412
+ max_length=max_length,
413
+ truncation=True,
414
+ return_tensors="pt",
415
+ )
416
+
417
+ negative_prompt_embeds = text_encoder(
418
+ uncond_input.input_ids.to(device),
419
+ output_hidden_states=True,
420
+ )
421
+ # We are only ALWAYS interested in the pooled output of the final text encoder
422
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
423
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
424
+
425
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
426
+
427
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
428
+
429
+ if self.text_encoder_2 is not None:
430
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
431
+ else:
432
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
433
+
434
+ bs_embed, seq_len, _ = prompt_embeds.shape
435
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
436
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
437
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
438
+
439
+ if do_classifier_free_guidance:
440
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
441
+ seq_len = negative_prompt_embeds.shape[1]
442
+
443
+ if self.text_encoder_2 is not None:
444
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
445
+ else:
446
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
447
+
448
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
449
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
450
+
451
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
452
+ bs_embed * num_images_per_prompt, -1
453
+ )
454
+ if do_classifier_free_guidance:
455
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
456
+ bs_embed * num_images_per_prompt, -1
457
+ )
458
+
459
+
460
+ if self.text_encoder is not None:
461
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
462
+ # Retrieve the original scale by scaling back the LoRA layers
463
+ unscale_lora_layers(self.text_encoder, lora_scale)
464
+
465
+ if self.text_encoder_2 is not None:
466
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
467
+ # Retrieve the original scale by scaling back the LoRA layers
468
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
469
+
470
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
471
+
472
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
473
+ def prepare_extra_step_kwargs(self, generator, eta):
474
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
475
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
476
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
477
+ # and should be between [0, 1]
478
+
479
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
480
+ extra_step_kwargs = {}
481
+ if accepts_eta:
482
+ extra_step_kwargs["eta"] = eta
483
+
484
+ # check if the scheduler accepts generator
485
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
486
+ if accepts_generator:
487
+ extra_step_kwargs["generator"] = generator
488
+ return extra_step_kwargs
489
+
490
+ def check_inputs(
491
+ self,
492
+ prompt,
493
+ prompt_2,
494
+ height,
495
+ width,
496
+ callback_steps,
497
+ negative_prompt=None,
498
+ negative_prompt_2=None,
499
+ prompt_embeds=None,
500
+ negative_prompt_embeds=None,
501
+ pooled_prompt_embeds=None,
502
+ negative_pooled_prompt_embeds=None,
503
+ callback_on_step_end_tensor_inputs=None,
504
+ ):
505
+ if height % 8 != 0 or width % 8 != 0:
506
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
507
+
508
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
509
+ raise ValueError(
510
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
511
+ f" {type(callback_steps)}."
512
+ )
513
+
514
+ if callback_on_step_end_tensor_inputs is not None and not all(
515
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
516
+ ):
517
+ raise ValueError(
518
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
519
+ )
520
+
521
+ if prompt is not None and prompt_embeds is not None:
522
+ raise ValueError(
523
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
524
+ " only forward one of the two."
525
+ )
526
+ elif prompt_2 is not None and prompt_embeds is not None:
527
+ raise ValueError(
528
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
529
+ " only forward one of the two."
530
+ )
531
+ elif prompt is None and prompt_embeds is None:
532
+ raise ValueError(
533
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
534
+ )
535
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
536
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
537
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
538
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
539
+
540
+ if negative_prompt is not None and negative_prompt_embeds is not None:
541
+ raise ValueError(
542
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
543
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
544
+ )
545
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
546
+ raise ValueError(
547
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
548
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
549
+ )
550
+
551
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
552
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
553
+ raise ValueError(
554
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
555
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
556
+ f" {negative_prompt_embeds.shape}."
557
+ )
558
+
559
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
560
+ raise ValueError(
561
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
562
+ )
563
+
564
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
565
+ raise ValueError(
566
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
567
+ )
568
+ def prepare_img_latents(
569
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
570
+ ):
571
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
572
+ raise ValueError(
573
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
574
+ )
575
+
576
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
577
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
578
+ self.text_encoder_2.to("cpu")
579
+ torch.cuda.empty_cache()
580
+
581
+ image = image.to(device=device, dtype=dtype)
582
+
583
+ batch_size = batch_size * num_images_per_prompt
584
+
585
+ if image.shape[1] == 4:
586
+ init_latents = image
587
+
588
+ else:
589
+ # make sure the VAE is in float32 mode, as it overflows in float16
590
+ if self.vae.config.force_upcast:
591
+ image = image.float()
592
+ self.vae.to(dtype=torch.float32)
593
+
594
+ if isinstance(generator, list) and len(generator) != batch_size:
595
+ raise ValueError(
596
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
597
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
598
+ )
599
+
600
+ elif isinstance(generator, list):
601
+ init_latents = [
602
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
603
+ for i in range(batch_size)
604
+ ]
605
+ init_latents = torch.cat(init_latents, dim=0)
606
+ else:
607
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
608
+
609
+ if self.vae.config.force_upcast:
610
+ self.vae.to(dtype)
611
+
612
+ init_latents = init_latents.to(dtype)
613
+ init_latents = self.vae.config.scaling_factor * init_latents
614
+
615
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
616
+ # expand init_latents for batch_size
617
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
618
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
619
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
620
+ raise ValueError(
621
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
622
+ )
623
+ else:
624
+ init_latents = torch.cat([init_latents], dim=0)
625
+
626
+ if add_noise:
627
+ shape = init_latents.shape
628
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
629
+ # get latents
630
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
631
+
632
+ latents = init_latents
633
+
634
+ return latents
635
+
636
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
637
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
638
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
639
+ if isinstance(generator, list) and len(generator) != batch_size:
640
+ raise ValueError(
641
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
642
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
643
+ )
644
+
645
+ if latents is None:
646
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
647
+ else:
648
+ latents = latents.to(device)
649
+
650
+ # scale the initial noise by the standard deviation required by the scheduler
651
+ latents = latents * self.scheduler.init_noise_sigma
652
+ return latents
653
+
654
+ def _get_add_time_ids(
655
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
656
+ ):
657
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
658
+
659
+ passed_add_embed_dim = (
660
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
661
+ )
662
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
663
+
664
+ if expected_add_embed_dim != passed_add_embed_dim:
665
+ raise ValueError(
666
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
667
+ )
668
+
669
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
670
+ return add_time_ids
671
+
672
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
673
+ def upcast_vae(self):
674
+ dtype = self.vae.dtype
675
+ self.vae.to(dtype=torch.float32)
676
+ use_torch_2_0_or_xformers = isinstance(
677
+ self.vae.decoder.mid_block.attentions[0].processor,
678
+ (
679
+ AttnProcessor2_0,
680
+ XFormersAttnProcessor,
681
+ LoRAXFormersAttnProcessor,
682
+ LoRAAttnProcessor2_0,
683
+ ),
684
+ )
685
+ # if xformers or torch_2_0 is used attention block does not need
686
+ # to be in float32 which can save lots of memory
687
+ if use_torch_2_0_or_xformers:
688
+ self.vae.post_quant_conv.to(dtype)
689
+ self.vae.decoder.conv_in.to(dtype)
690
+ self.vae.decoder.mid_block.to(dtype)
691
+
692
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
693
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
694
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
695
+
696
+ The suffixes after the scaling factors represent the stages where they are being applied.
697
+
698
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
699
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
700
+
701
+ Args:
702
+ s1 (`float`):
703
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
704
+ mitigate "oversmoothing effect" in the enhanced denoising process.
705
+ s2 (`float`):
706
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
707
+ mitigate "oversmoothing effect" in the enhanced denoising process.
708
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
709
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
710
+ """
711
+ if not hasattr(self, "unet"):
712
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
713
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
714
+
715
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
716
+ def disable_freeu(self):
717
+ """Disables the FreeU mechanism if enabled."""
718
+ self.unet.disable_freeu()
719
+
720
+ @property
721
+ def guidance_scale(self):
722
+ return self._guidance_scale
723
+
724
+ @property
725
+ def guidance_rescale(self):
726
+ return self._guidance_rescale
727
+
728
+ @property
729
+ def clip_skip(self):
730
+ return self._clip_skip
731
+
732
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
733
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
734
+ # corresponds to doing no classifier free guidance.
735
+ @property
736
+ def do_classifier_free_guidance(self):
737
+ return self._guidance_scale > 1
738
+
739
+ @property
740
+ def cross_attention_kwargs(self):
741
+ return self._cross_attention_kwargs
742
+
743
+ @property
744
+ def denoising_end(self):
745
+ return self._denoising_end
746
+
747
+ @property
748
+ def num_timesteps(self):
749
+ return self._num_timesteps
750
+
751
+ @torch.no_grad()
752
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
753
+ def __call__(
754
+ self,
755
+ prompt: Union[str, List[str]] = None,
756
+ prompt_2: Optional[Union[str, List[str]]] = None,
757
+ height: Optional[int] = None,
758
+ width: Optional[int] = None,
759
+ num_inference_steps: int = 50,
760
+ denoising_end: Optional[float] = None,
761
+ guidance_scale: float = 5.0,
762
+ negative_prompt: Optional[Union[str, List[str]]] = None,
763
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
764
+ num_images_per_prompt: Optional[int] = 1,
765
+ eta: float = 0.0,
766
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
767
+ latents: Optional[torch.FloatTensor] = None,
768
+ prompt_embeds: Optional[torch.FloatTensor] = None,
769
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
770
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
771
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
772
+ output_type: Optional[str] = "pil",
773
+ return_dict: bool = True,
774
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
775
+ guidance_rescale: float = 0.0,
776
+ original_size: Optional[Tuple[int, int]] = None,
777
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
778
+ target_size: Optional[Tuple[int, int]] = None,
779
+ negative_original_size: Optional[Tuple[int, int]] = None,
780
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
781
+ negative_target_size: Optional[Tuple[int, int]] = None,
782
+ clip_skip: Optional[int] = None,
783
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
784
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
785
+ use_prompt_as_null = False,
786
+ image = None,
787
+ **kwargs,
788
+ ):
789
+ r"""
790
+ Function invoked when calling the pipeline for generation.
791
+
792
+ Args:
793
+ prompt (`str` or `List[str]`, *optional*):
794
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
795
+ instead.
796
+ prompt_2 (`str` or `List[str]`, *optional*):
797
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
798
+ used in both text-encoders
799
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
800
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
801
+ Anything below 512 pixels won't work well for
802
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
803
+ and checkpoints that are not specifically fine-tuned on low resolutions.
804
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
805
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
806
+ Anything below 512 pixels won't work well for
807
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
808
+ and checkpoints that are not specifically fine-tuned on low resolutions.
809
+ num_inference_steps (`int`, *optional*, defaults to 50):
810
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
811
+ expense of slower inference.
812
+ denoising_end (`float`, *optional*):
813
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
814
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
815
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
816
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
817
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
818
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
819
+ guidance_scale (`float`, *optional*, defaults to 5.0):
820
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
821
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
822
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
823
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
824
+ usually at the expense of lower image quality.
825
+ negative_prompt (`str` or `List[str]`, *optional*):
826
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
827
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
828
+ less than `1`).
829
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
830
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
831
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
832
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
833
+ The number of images to generate per prompt.
834
+ eta (`float`, *optional*, defaults to 0.0):
835
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
836
+ [`schedulers.DDIMScheduler`], will be ignored for others.
837
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
838
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
839
+ to make generation deterministic.
840
+ latents (`torch.FloatTensor`, *optional*):
841
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
842
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
843
+ tensor will ge generated by sampling using the supplied random `generator`.
844
+ prompt_embeds (`torch.FloatTensor`, *optional*):
845
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
846
+ provided, text embeddings will be generated from `prompt` input argument.
847
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
848
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
849
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
850
+ argument.
851
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
852
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
853
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
854
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
855
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
856
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
857
+ input argument.
858
+ output_type (`str`, *optional*, defaults to `"pil"`):
859
+ The output format of the generate image. Choose between
860
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
861
+ return_dict (`bool`, *optional*, defaults to `True`):
862
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
863
+ of a plain tuple.
864
+ cross_attention_kwargs (`dict`, *optional*):
865
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
866
+ `self.processor` in
867
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
868
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
869
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
870
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
871
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
872
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
873
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
874
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
875
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
876
+ explained in section 2.2 of
877
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
878
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
879
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
880
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
881
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
882
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
883
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
884
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
885
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
886
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
887
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
888
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
889
+ micro-conditioning as explained in section 2.2 of
890
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
891
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
892
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
893
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
894
+ micro-conditioning as explained in section 2.2 of
895
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
896
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
897
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
898
+ To negatively condition the generation process based on a target image resolution. It should be as same
899
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
900
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
901
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
902
+ callback_on_step_end (`Callable`, *optional*):
903
+ A function that calls at the end of each denoising steps during the inference. The function is called
904
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
905
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
906
+ `callback_on_step_end_tensor_inputs`.
907
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
908
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
909
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
910
+ `._callback_tensor_inputs` attribute of your pipeine class.
911
+
912
+ Examples:
913
+
914
+ Returns:
915
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
916
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
917
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
918
+ """
919
+
920
+
921
+
922
+
923
+ callback = kwargs.pop("callback", None)
924
+ callback_steps = kwargs.pop("callback_steps", None)
925
+
926
+ if callback is not None:
927
+ deprecate(
928
+ "callback",
929
+ "1.0.0",
930
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
931
+ )
932
+ if callback_steps is not None:
933
+ deprecate(
934
+ "callback_steps",
935
+ "1.0.0",
936
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
937
+ )
938
+
939
+
940
+ if image is not None:
941
+ z0 = self.image_processor.preprocess(image)
942
+
943
+
944
+ # 0. Default height and width to unet
945
+ height = height or self.default_sample_size * self.vae_scale_factor
946
+ width = width or self.default_sample_size * self.vae_scale_factor
947
+
948
+ original_size = original_size or (height, width)
949
+ target_size = target_size or (height, width)
950
+
951
+ # 1. Check inputs. Raise error if not correct
952
+ self.check_inputs(
953
+ prompt,
954
+ prompt_2,
955
+ height,
956
+ width,
957
+ callback_steps,
958
+ negative_prompt,
959
+ negative_prompt_2,
960
+ prompt_embeds,
961
+ negative_prompt_embeds,
962
+ pooled_prompt_embeds,
963
+ negative_pooled_prompt_embeds,
964
+ callback_on_step_end_tensor_inputs,
965
+ )
966
+
967
+ self._guidance_scale = guidance_scale
968
+ self._guidance_rescale = guidance_rescale
969
+ self._clip_skip = clip_skip
970
+ self._cross_attention_kwargs = cross_attention_kwargs
971
+ self._denoising_end = denoising_end
972
+
973
+ # 2. Define call parameters
974
+ if prompt is not None and isinstance(prompt, str):
975
+ batch_size = 1
976
+ elif prompt is not None and isinstance(prompt, list):
977
+ batch_size = len(prompt)
978
+ else:
979
+ batch_size = prompt_embeds.shape[0]
980
+
981
+ device = self._execution_device
982
+
983
+ # 3. Encode input prompt
984
+ lora_scale = (
985
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
986
+ )
987
+
988
+
989
+ (
990
+ prompt_embeds,
991
+ negative_prompt_embeds,
992
+ pooled_prompt_embeds,
993
+ negative_pooled_prompt_embeds,
994
+ ) = self.encode_prompt(
995
+ prompt=prompt,
996
+ prompt_2=prompt_2,
997
+ device=device,
998
+ num_images_per_prompt=num_images_per_prompt,
999
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1000
+ negative_prompt=negative_prompt,
1001
+ negative_prompt_2=negative_prompt_2,
1002
+ prompt_embeds=prompt_embeds,
1003
+ negative_prompt_embeds=negative_prompt_embeds,
1004
+ pooled_prompt_embeds=pooled_prompt_embeds,
1005
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1006
+ lora_scale=lora_scale,
1007
+ clip_skip=self.clip_skip,
1008
+ )
1009
+
1010
+ if kwargs['target_prompt'] is not None:
1011
+ (
1012
+ prompt_embeds_,
1013
+ negative_prompt_embeds_,
1014
+ pooled_prompt_embeds_,
1015
+ negative_pooled_prompt_embeds_,
1016
+ ) = self.encode_prompt(
1017
+ prompt=kwargs['target_prompt'],
1018
+ prompt_2=prompt_2,
1019
+ device=device,
1020
+ num_images_per_prompt=num_images_per_prompt,
1021
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1022
+ # negative_prompt=negative_prompt,
1023
+ negative_prompt=None, #if kwargs["target_neg"] is None else kwargs["target_neg"],
1024
+ # negative_prompt_2=negative_prompt_2,
1025
+ negative_prompt_2=None,
1026
+ prompt_embeds=None,
1027
+ negative_prompt_embeds=None,
1028
+ pooled_prompt_embeds=None,
1029
+ negative_pooled_prompt_embeds=None,
1030
+ lora_scale=lora_scale,
1031
+ clip_skip=self.clip_skip,
1032
+ )
1033
+
1034
+ prompt_embeds[1:] = prompt_embeds_[1:]
1035
+ pooled_prompt_embeds[1:] = pooled_prompt_embeds_[1:]
1036
+ if not kwargs['use_inf_negative_prompt']:
1037
+ negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
1038
+ negative_pooled_prompt_embeds[1:] = negative_pooled_prompt_embeds_[1:]
1039
+
1040
+
1041
+ # 4. Prepare timesteps
1042
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1043
+
1044
+ timesteps = self.scheduler.timesteps
1045
+
1046
+
1047
+ # 5. Prepare latent variables
1048
+ num_channels_latents = self.unet.config.in_channels
1049
+ latents = self.prepare_latents(
1050
+ batch_size * num_images_per_prompt,
1051
+ num_channels_latents,
1052
+ height,
1053
+ width,
1054
+ prompt_embeds.dtype,
1055
+ device,
1056
+ generator,
1057
+ latents,
1058
+ )
1059
+
1060
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1061
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1062
+
1063
+ # 7. Prepare added time ids & embeddings
1064
+ add_text_embeds = pooled_prompt_embeds
1065
+ if self.text_encoder_2 is None:
1066
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1067
+ else:
1068
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1069
+
1070
+ add_time_ids = self._get_add_time_ids(
1071
+ original_size,
1072
+ crops_coords_top_left,
1073
+ target_size,
1074
+ dtype=prompt_embeds.dtype,
1075
+ text_encoder_projection_dim=text_encoder_projection_dim,
1076
+ )
1077
+ if negative_original_size is not None and negative_target_size is not None:
1078
+ negative_add_time_ids = self._get_add_time_ids(
1079
+ negative_original_size,
1080
+ negative_crops_coords_top_left,
1081
+ negative_target_size,
1082
+ dtype=prompt_embeds.dtype,
1083
+ text_encoder_projection_dim=text_encoder_projection_dim,
1084
+ )
1085
+ else:
1086
+ negative_add_time_ids = add_time_ids
1087
+
1088
+ if self.do_classifier_free_guidance:
1089
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1090
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1091
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1092
+
1093
+ prompt_embeds = prompt_embeds.to(device)
1094
+ add_text_embeds = add_text_embeds.to(device)
1095
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1096
+
1097
+ # 8. Denoising loop
1098
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1099
+
1100
+ # 8.1 Apply denoising_end
1101
+ if (
1102
+ self.denoising_end is not None
1103
+ and isinstance(self.denoising_end, float)
1104
+ and self.denoising_end > 0
1105
+ and self.denoising_end < 1
1106
+ ):
1107
+ discrete_timestep_cutoff = int(
1108
+ round(
1109
+ self.scheduler.config.num_train_timesteps
1110
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1111
+ )
1112
+ )
1113
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1114
+ timesteps = timesteps[:num_inference_steps]
1115
+
1116
+ self._num_timesteps = len(timesteps)
1117
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1118
+ for i, t in enumerate(timesteps):
1119
+
1120
+
1121
+ if image is not None:
1122
+ zt = self.prepare_img_latents(z0,t.repeat(1),1, num_images_per_prompt,prompt_embeds.dtype,device,generator,True)# add_noise/
1123
+
1124
+ latents[0] = zt[0]
1125
+
1126
+ # expand the latents if we are doing classifier free guidance
1127
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1128
+
1129
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1130
+
1131
+ # predict the noise residual
1132
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1133
+ noise_pred = self.unet(
1134
+ latent_model_input,
1135
+ t,
1136
+ encoder_hidden_states=prompt_embeds,
1137
+ cross_attention_kwargs=self.cross_attention_kwargs,
1138
+ added_cond_kwargs=added_cond_kwargs,
1139
+ return_dict=False,
1140
+ )[0]
1141
+
1142
+
1143
+ # perform guidance
1144
+ if self.do_classifier_free_guidance:
1145
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1146
+
1147
+ # tmp_noise_pred_text = noise_pred_text[0]######### reconstruction only
1148
+ # import pdb; pdb.set_trace()
1149
+ # if 1 < i < 3 and kwargs["use_advanced_sampling"]:
1150
+ if i < 3 and kwargs["use_advanced_sampling"]:
1151
+ noise_pred = noise_pred_uncond + 20.0 * (noise_pred_text - noise_pred_uncond)
1152
+ # noise_pred[0] = noise_pred_uncond[0] + self.guidance_scale * (noise_pred_text[0] - noise_pred_uncond[0])
1153
+ else:
1154
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1155
+
1156
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1157
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1158
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1159
+
1160
+ if use_prompt_as_null:
1161
+ noise_pred[0] = noise_pred_text[0]
1162
+
1163
+
1164
+ # noise_pred[0] = tmp_noise_pred_text######## reconstruction only
1165
+
1166
+ # compute the previous noisy sample x_t -> x_t-1
1167
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1168
+
1169
+ if callback_on_step_end is not None:
1170
+ callback_kwargs = {}
1171
+ for k in callback_on_step_end_tensor_inputs:
1172
+ callback_kwargs[k] = locals()[k]
1173
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1174
+
1175
+ latents = callback_outputs.pop("latents", latents)
1176
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1177
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1178
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1179
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1180
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1181
+ )
1182
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1183
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1184
+
1185
+ # call the callback, if provided
1186
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1187
+ progress_bar.update()
1188
+ if callback is not None and i % callback_steps == 0:
1189
+ step_idx = i // getattr(self.scheduler, "order", 1)
1190
+ callback(step_idx, t, latents)
1191
+
1192
+ if XLA_AVAILABLE:
1193
+ xm.mark_step()
1194
+
1195
+ if not output_type == "latent":
1196
+ # make sure the VAE is in float32 mode, as it overflows in float16
1197
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1198
+
1199
+ if needs_upcasting:
1200
+ self.upcast_vae()
1201
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1202
+ self.enable_vae_slicing()
1203
+
1204
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1205
+
1206
+ # cast back to fp16 if needed
1207
+ if needs_upcasting:
1208
+ self.vae.to(dtype=torch.float16)
1209
+ else:
1210
+ image = latents
1211
+
1212
+ if not output_type == "latent":
1213
+ # apply watermark if available
1214
+ if self.watermark is not None:
1215
+ image = self.watermark.apply_watermark(image)
1216
+
1217
+ image = self.image_processor.postprocess(image, output_type=output_type)
1218
+
1219
+ # Offload all models
1220
+ self.maybe_free_model_hooks()
1221
+
1222
+ if not return_dict:
1223
+ return (image,)
1224
+
1225
+ return StableDiffusionXLPipelineOutput(images=image)
1226
+
1227
+ @torch.no_grad()
1228
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1229
+ def inverted_ve_cross_frame_attn(
1230
+ self,
1231
+ prompt: Union[str, List[str]] = None,
1232
+ prompt_2: Optional[Union[str, List[str]]] = None,
1233
+ height: Optional[int] = None,
1234
+ width: Optional[int] = None,
1235
+ num_inference_steps: int = 50,
1236
+ denoising_end: Optional[float] = None,
1237
+ guidance_scale: float = 5.0,
1238
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1239
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1240
+ num_images_per_prompt: Optional[int] = 1,
1241
+ eta: float = 0.0,
1242
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1243
+ latents: Optional[torch.FloatTensor] = None,
1244
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1245
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1246
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1247
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1248
+ output_type: Optional[str] = "pil",
1249
+ return_dict: bool = True,
1250
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1251
+ guidance_rescale: float = 0.0,
1252
+ original_size: Optional[Tuple[int, int]] = None,
1253
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1254
+ target_size: Optional[Tuple[int, int]] = None,
1255
+ negative_original_size: Optional[Tuple[int, int]] = None,
1256
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1257
+ negative_target_size: Optional[Tuple[int, int]] = None,
1258
+ clip_skip: Optional[int] = None,
1259
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1260
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1261
+ **kwargs,
1262
+ ):
1263
+ r"""
1264
+ Function invoked when calling the pipeline for generation.
1265
+
1266
+ Args:
1267
+ prompt (`str` or `List[str]`, *optional*):
1268
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1269
+ instead.
1270
+ prompt_2 (`str` or `List[str]`, *optional*):
1271
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1272
+ used in both text-encoders
1273
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1274
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1275
+ Anything below 512 pixels won't work well for
1276
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1277
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1278
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1279
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1280
+ Anything below 512 pixels won't work well for
1281
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1282
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1283
+ num_inference_steps (`int`, *optional*, defaults to 50):
1284
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1285
+ expense of slower inference.
1286
+ denoising_end (`float`, *optional*):
1287
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1288
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1289
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1290
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1291
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1292
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1293
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1294
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1295
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1296
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1297
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1298
+ usually at the expense of lower image quality.
1299
+ negative_prompt (`str` or `List[str]`, *optional*):
1300
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1301
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1302
+ less than `1`).
1303
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1304
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1305
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1306
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1307
+ The number of images to generate per prompt.
1308
+ eta (`float`, *optional*, defaults to 0.0):
1309
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1310
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1311
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1312
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1313
+ to make generation deterministic.
1314
+ latents (`torch.FloatTensor`, *optional*):
1315
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1316
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1317
+ tensor will ge generated by sampling using the supplied random `generator`.
1318
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1319
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1320
+ provided, text embeddings will be generated from `prompt` input argument.
1321
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1322
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1323
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1324
+ argument.
1325
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1326
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1327
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1328
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1329
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1330
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1331
+ input argument.
1332
+ output_type (`str`, *optional*, defaults to `"pil"`):
1333
+ The output format of the generate image. Choose between
1334
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1335
+ return_dict (`bool`, *optional*, defaults to `True`):
1336
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1337
+ of a plain tuple.
1338
+ cross_attention_kwargs (`dict`, *optional*):
1339
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1340
+ `self.processor` in
1341
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1342
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1343
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1344
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
1345
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1346
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1347
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1348
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1349
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1350
+ explained in section 2.2 of
1351
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1352
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1353
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1354
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1355
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1356
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1357
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1358
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1359
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1360
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1361
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1362
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1363
+ micro-conditioning as explained in section 2.2 of
1364
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1365
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1366
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1367
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1368
+ micro-conditioning as explained in section 2.2 of
1369
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1370
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1371
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1372
+ To negatively condition the generation process based on a target image resolution. It should be as same
1373
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1374
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1375
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1376
+ callback_on_step_end (`Callable`, *optional*):
1377
+ A function that calls at the end of each denoising steps during the inference. The function is called
1378
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1379
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1380
+ `callback_on_step_end_tensor_inputs`.
1381
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1382
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1383
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1384
+ `._callback_tensor_inputs` attribute of your pipeine class.
1385
+
1386
+ Examples:
1387
+
1388
+ Returns:
1389
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1390
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1391
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1392
+ """
1393
+
1394
+
1395
+
1396
+ callback = kwargs.pop("callback", None)
1397
+ callback_steps = kwargs.pop("callback_steps", None)
1398
+
1399
+ if callback is not None:
1400
+ deprecate(
1401
+ "callback",
1402
+ "1.0.0",
1403
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1404
+ )
1405
+ if callback_steps is not None:
1406
+ deprecate(
1407
+ "callback_steps",
1408
+ "1.0.0",
1409
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1410
+ )
1411
+
1412
+ # 0. Default height and width to unet
1413
+ height = height or self.default_sample_size * self.vae_scale_factor
1414
+ width = width or self.default_sample_size * self.vae_scale_factor
1415
+
1416
+ original_size = original_size or (height, width)
1417
+ target_size = target_size or (height, width)
1418
+
1419
+ # 1. Check inputs. Raise error if not correct
1420
+ self.check_inputs(
1421
+ prompt,
1422
+ prompt_2,
1423
+ height,
1424
+ width,
1425
+ callback_steps,
1426
+ negative_prompt,
1427
+ negative_prompt_2,
1428
+ prompt_embeds,
1429
+ negative_prompt_embeds,
1430
+ pooled_prompt_embeds,
1431
+ negative_pooled_prompt_embeds,
1432
+ callback_on_step_end_tensor_inputs,
1433
+ )
1434
+
1435
+ self._guidance_scale = guidance_scale
1436
+ self._guidance_rescale = guidance_rescale
1437
+ self._clip_skip = clip_skip
1438
+ self._cross_attention_kwargs = cross_attention_kwargs
1439
+ self._denoising_end = denoising_end
1440
+
1441
+ # 2. Define call parameters
1442
+ if prompt is not None and isinstance(prompt, str):
1443
+ batch_size = 1
1444
+ elif prompt is not None and isinstance(prompt, list):
1445
+ batch_size = len(prompt)
1446
+ else:
1447
+ batch_size = prompt_embeds.shape[0]
1448
+
1449
+ device = self._execution_device
1450
+
1451
+ # 3. Encode input prompt
1452
+ lora_scale = (
1453
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1454
+ )
1455
+
1456
+ (
1457
+ prompt_embeds,
1458
+ negative_prompt_embeds,
1459
+ pooled_prompt_embeds,
1460
+ negative_pooled_prompt_embeds,
1461
+ ) = self.encode_prompt(
1462
+ prompt=prompt,
1463
+ prompt_2=prompt_2,
1464
+ device=device,
1465
+ num_images_per_prompt=num_images_per_prompt,
1466
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1467
+ negative_prompt=negative_prompt,
1468
+ negative_prompt_2=negative_prompt_2,
1469
+ prompt_embeds=prompt_embeds,
1470
+ negative_prompt_embeds=negative_prompt_embeds,
1471
+ pooled_prompt_embeds=pooled_prompt_embeds,
1472
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1473
+ lora_scale=lora_scale,
1474
+ clip_skip=self.clip_skip,
1475
+ )
1476
+
1477
+ if kwargs['target_prompt'] is not None:
1478
+ (
1479
+ prompt_embeds_,
1480
+ negative_prompt_embeds_,
1481
+ _,
1482
+ _,
1483
+ ) = self.encode_prompt(
1484
+ prompt=kwargs['target_prompt'],
1485
+ prompt_2=prompt_2,
1486
+ device=device,
1487
+ num_images_per_prompt=num_images_per_prompt,
1488
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1489
+ negative_prompt=kwargs['target_negative_prompt'] if kwargs['target_negative_prompt'] is not None else None,
1490
+ # negative_prompt=None,
1491
+ # negative_prompt_2=negative_prompt_2,
1492
+ negative_prompt_2=None,
1493
+ prompt_embeds=None,
1494
+ negative_prompt_embeds=None,
1495
+ pooled_prompt_embeds=None,
1496
+ negative_pooled_prompt_embeds=None,
1497
+ lora_scale=lora_scale,
1498
+ clip_skip=self.clip_skip,
1499
+ )
1500
+ prompt_embeds[1:] = prompt_embeds_[1:]
1501
+ if negative_prompt_embeds_ is not None:
1502
+ negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
1503
+
1504
+
1505
+ # 4. Prepare timesteps
1506
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1507
+
1508
+ timesteps = self.scheduler.timesteps
1509
+
1510
+ # 5. Prepare latent variables
1511
+ num_channels_latents = self.unet.config.in_channels
1512
+ latents = self.prepare_latents(
1513
+ batch_size * num_images_per_prompt,
1514
+ num_channels_latents,
1515
+ height,
1516
+ width,
1517
+ prompt_embeds.dtype,
1518
+ device,
1519
+ generator,
1520
+ latents,
1521
+ )
1522
+
1523
+
1524
+ latents_ = self.prepare_latents(
1525
+ batch_size * num_images_per_prompt,
1526
+ num_channels_latents,
1527
+ height,
1528
+ width,
1529
+ prompt_embeds.dtype,
1530
+ device,
1531
+ generator,
1532
+ # latents,
1533
+ )
1534
+
1535
+ # import pdb; pdb.set_trace()
1536
+
1537
+ # latents[1:] = latents_[1:]
1538
+ latents = torch.cat([latents.unsqueeze(0), latents_[1:]], dim=0)
1539
+
1540
+
1541
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1542
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1543
+
1544
+ # 7. Prepare added time ids & embeddings
1545
+ add_text_embeds = pooled_prompt_embeds
1546
+ if self.text_encoder_2 is None:
1547
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1548
+ else:
1549
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1550
+
1551
+ add_time_ids = self._get_add_time_ids(
1552
+ original_size,
1553
+ crops_coords_top_left,
1554
+ target_size,
1555
+ dtype=prompt_embeds.dtype,
1556
+ text_encoder_projection_dim=text_encoder_projection_dim,
1557
+ )
1558
+ if negative_original_size is not None and negative_target_size is not None:
1559
+ negative_add_time_ids = self._get_add_time_ids(
1560
+ negative_original_size,
1561
+ negative_crops_coords_top_left,
1562
+ negative_target_size,
1563
+ dtype=prompt_embeds.dtype,
1564
+ text_encoder_projection_dim=text_encoder_projection_dim,
1565
+ )
1566
+ else:
1567
+ negative_add_time_ids = add_time_ids
1568
+
1569
+ if self.do_classifier_free_guidance:
1570
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1571
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1572
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1573
+
1574
+ prompt_embeds = prompt_embeds.to(device)
1575
+ add_text_embeds = add_text_embeds.to(device)
1576
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1577
+
1578
+ # 8. Denoising loop
1579
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1580
+
1581
+ # 8.1 Apply denoising_end
1582
+ if (
1583
+ self.denoising_end is not None
1584
+ and isinstance(self.denoising_end, float)
1585
+ and self.denoising_end > 0
1586
+ and self.denoising_end < 1
1587
+ ):
1588
+ discrete_timestep_cutoff = int(
1589
+ round(
1590
+ self.scheduler.config.num_train_timesteps
1591
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1592
+ )
1593
+ )
1594
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1595
+ timesteps = timesteps[:num_inference_steps]
1596
+
1597
+ self._num_timesteps = len(timesteps)
1598
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1599
+ for i, t in enumerate(timesteps):
1600
+ # expand the latents if we are doing classifier free guidance
1601
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1602
+
1603
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1604
+
1605
+ # predict the noise residual
1606
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1607
+ noise_pred = self.unet(
1608
+ latent_model_input,
1609
+ t,
1610
+ encoder_hidden_states=prompt_embeds,
1611
+ cross_attention_kwargs=self.cross_attention_kwargs,
1612
+ added_cond_kwargs=added_cond_kwargs,
1613
+ return_dict=False,
1614
+ )[0]
1615
+
1616
+ # perform guidance
1617
+ if self.do_classifier_free_guidance:
1618
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1619
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1620
+ noise_pred[0] = noise_pred_uncond[0] #추가된것
1621
+
1622
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1623
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1624
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1625
+ noise_pred[0] = noise_pred_uncond[0] #추가된것
1626
+
1627
+ # compute the previous noisy sample x_t -> x_t-1
1628
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1629
+
1630
+ if callback_on_step_end is not None:
1631
+ callback_kwargs = {}
1632
+ for k in callback_on_step_end_tensor_inputs:
1633
+ callback_kwargs[k] = locals()[k]
1634
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1635
+
1636
+ latents = callback_outputs.pop("latents", latents)
1637
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1638
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1639
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1640
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1641
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1642
+ )
1643
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1644
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1645
+
1646
+ # call the callback, if provided
1647
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1648
+ progress_bar.update()
1649
+ if callback is not None and i % callback_steps == 0:
1650
+ step_idx = i // getattr(self.scheduler, "order", 1)
1651
+ callback(step_idx, t, latents)
1652
+
1653
+ if XLA_AVAILABLE:
1654
+ xm.mark_step()
1655
+
1656
+ if not output_type == "latent":
1657
+ # make sure the VAE is in float32 mode, as it overflows in float16
1658
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1659
+
1660
+ if needs_upcasting:
1661
+ self.upcast_vae()
1662
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1663
+
1664
+ self.enable_vae_slicing()
1665
+
1666
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1667
+
1668
+ # cast back to fp16 if needed
1669
+ if needs_upcasting:
1670
+ self.vae.to(dtype=torch.float16)
1671
+ else:
1672
+ image = latents
1673
+
1674
+ if not output_type == "latent":
1675
+ # apply watermark if available
1676
+ if self.watermark is not None:
1677
+ image = self.watermark.apply_watermark(image)
1678
+
1679
+ image = self.image_processor.postprocess(image, output_type=output_type)
1680
+
1681
+ # Offload all models
1682
+ self.maybe_free_model_hooks()
1683
+
1684
+ if not return_dict:
1685
+ return (image,)
1686
+
1687
+ return StableDiffusionXLPipelineOutput(images=image)
1688
+
1689
+
1690
+ @torch.no_grad()
1691
+ def activate_layer(self,
1692
+ activate_layer_indices,
1693
+ attn_map_save_steps=[],
1694
+ activate_step_indices = None,
1695
+ use_shared_attention = False,
1696
+ adain_queries=True,
1697
+ adain_keys=True,
1698
+ adain_values=False,
1699
+ ):
1700
+
1701
+
1702
+ attn_procs = {}
1703
+ activate_layer = []
1704
+ str_activate_layer = ""
1705
+ for activate_layer_index in activate_layer_indices:
1706
+ activate_layer += ACTIVATE_LAYER_CANDIDATE[activate_layer_index[0]:activate_layer_index[1]]
1707
+ str_activate_layer += str(activate_layer_index)
1708
+
1709
+ str_activate_step = ""
1710
+ for activate_step_index in activate_step_indices:
1711
+ str_activate_step += str(activate_step_index)
1712
+
1713
+ for name in self.unet.attn_processors.keys():
1714
+ if name in activate_layer:
1715
+ if not use_shared_attention:
1716
+ attn_procs[name] = CrossFrameAttnProcessor(unet_chunk_size=2,
1717
+ attn_map_save_steps=attn_map_save_steps,
1718
+ activate_step_indices=activate_step_indices)
1719
+ else:
1720
+
1721
+ activate_save_layer = [
1722
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
1723
+ 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor',
1724
+ 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor',
1725
+ 'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor',
1726
+ 'mid_block.attentions.0.transformer_blocks.0.attn1.processor'
1727
+ ]
1728
+ if name in activate_save_layer:
1729
+ attn_procs[name] = SharedAttentionProcessor_v2(
1730
+ adain_keys=adain_keys,
1731
+ adain_queries=adain_queries,
1732
+ adain_values=adain_values,
1733
+ attn_map_save_steps = attn_map_save_steps,
1734
+ keys_scale=1.0,
1735
+ )
1736
+ else:
1737
+ attn_procs[name] = SharedAttentionProcessor(
1738
+ # unet_chunk_size=2,
1739
+ # attn_map_save_steps=attn_map_save_steps,
1740
+ # activate_step_indices=activate_step_indices,
1741
+ adain_keys=adain_keys,
1742
+ adain_queries=adain_queries,
1743
+ adain_values=adain_values,
1744
+ keys_scale=1.0,
1745
+ )
1746
+ else :
1747
+ attn_procs[name] = AttnProcessor()
1748
+
1749
+ self.unet.set_attn_processor(attn_procs)
1750
+
1751
+ return str_activate_layer, str_activate_step
1752
+
1753
+ @torch.no_grad()
1754
+ def get_init_latent(self,
1755
+ precomputed_path,
1756
+ seed):
1757
+
1758
+
1759
+ if not os.path.exists(precomputed_path):
1760
+ os.makedirs(precomputed_path)
1761
+
1762
+ #search init latents in precomputed latents
1763
+ init_latent_name = f'init_latent_{seed}.pt'
1764
+ init_latent_path = os.path.join(precomputed_path, init_latent_name)
1765
+
1766
+
1767
+
1768
+ # 0. Default height and width to unet
1769
+ height = self.default_sample_size * self.vae_scale_factor
1770
+ width = self.default_sample_size * self.vae_scale_factor
1771
+
1772
+ num_channels_latents = self.unet.config.in_channels
1773
+
1774
+
1775
+
1776
+ if not os.path.exists(init_latent_path):
1777
+ print(f'init_latent_{seed}.pt is not exist')
1778
+ # device= self._execution_device
1779
+ device = torch.device("cpu")
1780
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
1781
+
1782
+
1783
+ shape = (1, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
1784
+
1785
+ init_latent = randn_tensor(shape, generator=generator, dtype = self.dtype, device=device)
1786
+
1787
+ torch.save(init_latent, init_latent_path)
1788
+ else:
1789
+ print(f'init_latent_{seed}.pt is exist')
1790
+ init_latent = torch.load(init_latent_path)
1791
+
1792
+ return init_latent
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ diffusers
3
+ transformers
4
+ accelerate
5
+ einops
6
+ kornia
7
+ gradio
8
+ torchvision
9
+ opencv-python
10
+ xformers
utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.utils.torch_utils import randn_tensor
3
+
4
+ import json, os, cv2
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ def parse_config(config):
9
+ with open(config, 'r') as f:
10
+ config = json.load(f)
11
+ return config
12
+
13
+ def load_config(config):
14
+ activate_layer_indices_list = config['inference_info']['activate_layer_indices_list']
15
+ activate_step_indices_list = config['inference_info']['activate_step_indices_list']
16
+ ref_seeds = config['reference_info']['ref_seeds']
17
+ inf_seeds = config['inference_info']['inf_seeds']
18
+
19
+ attn_map_save_steps = config['inference_info']['attn_map_save_steps']
20
+ precomputed_path = config['precomputed_path']
21
+ guidance_scale = config['guidance_scale']
22
+ use_inf_negative_prompt = config['inference_info']['use_negative_prompt']
23
+
24
+ style_name_list = config["style_name_list"]
25
+ ref_object_list = config["reference_info"]["ref_object_list"]
26
+ inf_object_list = config["inference_info"]["inf_object_list"]
27
+ ref_with_style_description = config['reference_info']['with_style_description']
28
+ inf_with_style_description = config['inference_info']['with_style_description']
29
+
30
+
31
+ use_shared_attention = config['inference_info']['use_shared_attention']
32
+ adain_queries = config['inference_info']['adain_queries']
33
+ adain_keys = config['inference_info']['adain_keys']
34
+ adain_values = config['inference_info']['adain_values']
35
+ use_advanced_sampling = config['inference_info']['use_advanced_sampling']
36
+
37
+ out = [
38
+ activate_layer_indices_list, activate_step_indices_list,
39
+ ref_seeds, inf_seeds,
40
+ attn_map_save_steps, precomputed_path, guidance_scale, use_inf_negative_prompt,
41
+ style_name_list, ref_object_list, inf_object_list, ref_with_style_description, inf_with_style_description,
42
+ use_shared_attention, adain_queries, adain_keys, adain_values, use_advanced_sampling
43
+
44
+ ]
45
+ return out
46
+
47
+ def memory_efficient(model, device):
48
+ try:
49
+ model.to(device)
50
+ except Exception as e:
51
+ print("Error moving model to device:", e)
52
+
53
+ try:
54
+ model.enable_model_cpu_offload()
55
+ except AttributeError:
56
+ print("enable_model_cpu_offload is not supported.")
57
+ try:
58
+ model.enable_vae_slicing()
59
+ except AttributeError:
60
+ print("enable_vae_slicing is not supported.")
61
+
62
+ try:
63
+ model.enable_vae_tiling()
64
+ except AttributeError:
65
+ print("enable_vae_tiling is not supported.")
66
+
67
+ try:
68
+ model.enable_xformers_memory_efficient_attention()
69
+ except AttributeError:
70
+ print("enable_xformers_memory_efficient_attention is not supported.")
71
+
72
+ def init_latent(model, device_name='cuda', dtype=torch.float16, seed=None):
73
+ scale_factor = model.vae_scale_factor
74
+ sample_size = model.default_sample_size
75
+ latent_dim = model.unet.config.in_channels
76
+
77
+ height = sample_size * scale_factor
78
+ width = sample_size * scale_factor
79
+
80
+ shape = (1, latent_dim, height // scale_factor, width // scale_factor)
81
+
82
+ device = torch.device(device_name)
83
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
84
+
85
+ latent = randn_tensor(shape, generator=generator, dtype=dtype, device=device)
86
+
87
+ return latent
88
+
89
+
90
+ def get_canny_edge_array(canny_img_path, threshold1=100,threshold2=200):
91
+ canny_image_list = []
92
+
93
+ # check if canny_img_path is a directory
94
+ if os.path.isdir(canny_img_path):
95
+ canny_img_list = os.listdir(canny_img_path)
96
+ for canny_img in canny_img_list:
97
+ canny_image_tmp = Image.open(os.path.join(canny_img_path, canny_img))
98
+ #resize image into1024x1024
99
+ canny_image_tmp = canny_image_tmp.resize((1024,1024))
100
+ canny_image_tmp = np.array(canny_image_tmp)
101
+ canny_image_tmp = cv2.Canny(canny_image_tmp, threshold1, threshold2)
102
+ canny_image_tmp = canny_image_tmp[:, :, None]
103
+ canny_image_tmp = np.concatenate([canny_image_tmp, canny_image_tmp, canny_image_tmp], axis=2)
104
+ canny_image = Image.fromarray(canny_image_tmp)
105
+ canny_image_list.append(canny_image)
106
+
107
+ return canny_image_list
108
+
109
+ def get_depth_map(image, feature_extractor, depth_estimator, device='cuda'):
110
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
111
+ with torch.no_grad(), torch.autocast(device):
112
+ depth_map = depth_estimator(image).predicted_depth
113
+
114
+ depth_map = torch.nn.functional.interpolate(
115
+ depth_map.unsqueeze(1),
116
+ size=(1024, 1024),
117
+ mode="bicubic",
118
+ align_corners=False,
119
+ )
120
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
121
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
122
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
123
+ image = torch.cat([depth_map] * 3, dim=1)
124
+
125
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
126
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
127
+
128
+ return image
129
+
130
+ def get_depth_edge_array(depth_img_path, feature_extractor, depth_estimator, device='cuda'):
131
+ depth_image_list = []
132
+
133
+ # check if canny_img_path is a directory
134
+ if os.path.isdir(depth_img_path):
135
+ depth_img_list = os.listdir(depth_img_path)
136
+ for depth_img in depth_img_list:
137
+ depth_image_tmp = Image.open(os.path.join(depth_img_path, depth_img)).convert('RGB')
138
+
139
+ # get depth map
140
+ depth_map = get_depth_map(depth_image_tmp, feature_extractor, depth_estimator, device)
141
+ depth_image_list.append(depth_map)
142
+
143
+ return depth_image_list
visualize_attention_src/__init__.py ADDED
File without changes
visualize_attention_src/pipeline_stable_diffusion_xl_attn.py ADDED
@@ -0,0 +1,1573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
20
+
21
+ from diffusers.image_processor import VaeImageProcessor
22
+ from diffusers.loaders import (
23
+ FromSingleFileMixin,
24
+ StableDiffusionXLLoraLoaderMixin,
25
+ TextualInversionLoaderMixin,
26
+ )
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils import (
37
+ USE_PEFT_BACKEND,
38
+ deprecate,
39
+ is_invisible_watermark_available,
40
+ is_torch_xla_available,
41
+ logging,
42
+ replace_example_docstring,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import randn_tensor
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
49
+
50
+
51
+ if is_invisible_watermark_available():
52
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
53
+
54
+ if is_torch_xla_available():
55
+ import torch_xla.core.xla_model as xm
56
+
57
+ XLA_AVAILABLE = True
58
+ else:
59
+ XLA_AVAILABLE = False
60
+
61
+
62
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63
+
64
+ EXAMPLE_DOC_STRING = """
65
+ Examples:
66
+ ```py
67
+ >>> import torch
68
+ >>> from diffusers import StableDiffusionXLPipeline
69
+
70
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
71
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
72
+ ... )
73
+ >>> pipe = pipe.to("cuda")
74
+
75
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
76
+ >>> image = pipe(prompt).images[0]
77
+ ```
78
+ """
79
+
80
+
81
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
82
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
83
+ """
84
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
85
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
86
+ """
87
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
88
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
89
+ # rescale the results from guidance (fixes overexposure)
90
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
91
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
92
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
93
+ return noise_cfg
94
+
95
+
96
+ class StableDiffusionXLPipeline(
97
+ DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
98
+ ):
99
+ r"""
100
+ Pipeline for text-to-image generation using Stable Diffusion XL.
101
+
102
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
103
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
104
+
105
+ In addition the pipeline inherits the following loading methods:
106
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
107
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
108
+
109
+ as well as the following saving methods:
110
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
111
+
112
+ Args:
113
+ vae ([`AutoencoderKL`]):
114
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
115
+ text_encoder ([`CLIPTextModel`]):
116
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
117
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
118
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
119
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
120
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
121
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
122
+ specifically the
123
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
124
+ variant.
125
+ tokenizer (`CLIPTokenizer`):
126
+ Tokenizer of class
127
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
128
+ tokenizer_2 (`CLIPTokenizer`):
129
+ Second Tokenizer of class
130
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
131
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
132
+ scheduler ([`SchedulerMixin`]):
133
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
134
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
135
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
136
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
137
+ `stabilityai/stable-diffusion-xl-base-1-0`.
138
+ add_watermarker (`bool`, *optional*):
139
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
140
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
141
+ watermarker will be used.
142
+ """
143
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
144
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
145
+ _callback_tensor_inputs = [
146
+ "latents",
147
+ "prompt_embeds",
148
+ "negative_prompt_embeds",
149
+ "add_text_embeds",
150
+ "add_time_ids",
151
+ "negative_pooled_prompt_embeds",
152
+ "negative_add_time_ids",
153
+ ]
154
+
155
+ def __init__(
156
+ self,
157
+ vae: AutoencoderKL,
158
+ text_encoder: CLIPTextModel,
159
+ text_encoder_2: CLIPTextModelWithProjection,
160
+ tokenizer: CLIPTokenizer,
161
+ tokenizer_2: CLIPTokenizer,
162
+ unet: UNet2DConditionModel,
163
+ scheduler: KarrasDiffusionSchedulers,
164
+ force_zeros_for_empty_prompt: bool = True,
165
+ add_watermarker: Optional[bool] = None,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ text_encoder_2=text_encoder_2,
173
+ tokenizer=tokenizer,
174
+ tokenizer_2=tokenizer_2,
175
+ unet=unet,
176
+ scheduler=scheduler,
177
+ )
178
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
179
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
181
+
182
+ self.default_sample_size = self.unet.config.sample_size
183
+
184
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
185
+
186
+ if add_watermarker:
187
+ self.watermark = StableDiffusionXLWatermarker()
188
+ else:
189
+ self.watermark = None
190
+
191
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
192
+ def enable_vae_slicing(self):
193
+ r"""
194
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
195
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
196
+ """
197
+ self.vae.enable_slicing()
198
+
199
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
200
+ def disable_vae_slicing(self):
201
+ r"""
202
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
203
+ computing decoding in one step.
204
+ """
205
+ self.vae.disable_slicing()
206
+
207
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
208
+ def enable_vae_tiling(self):
209
+ r"""
210
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
211
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
212
+ processing larger images.
213
+ """
214
+ self.vae.enable_tiling()
215
+
216
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
217
+ def disable_vae_tiling(self):
218
+ r"""
219
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
220
+ computing decoding in one step.
221
+ """
222
+ self.vae.disable_tiling()
223
+
224
+ def encode_prompt(
225
+ self,
226
+ prompt: str,
227
+ prompt_2: Optional[str] = None,
228
+ device: Optional[torch.device] = None,
229
+ num_images_per_prompt: int = 1,
230
+ do_classifier_free_guidance: bool = True,
231
+ negative_prompt: Optional[str] = None,
232
+ negative_prompt_2: Optional[str] = None,
233
+ prompt_embeds: Optional[torch.FloatTensor] = None,
234
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
235
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
236
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
237
+ lora_scale: Optional[float] = None,
238
+ clip_skip: Optional[int] = None,
239
+ ):
240
+ r"""
241
+ Encodes the prompt into text encoder hidden states.
242
+
243
+ Args:
244
+ prompt (`str` or `List[str]`, *optional*):
245
+ prompt to be encoded
246
+ prompt_2 (`str` or `List[str]`, *optional*):
247
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
248
+ used in both text-encoders
249
+ device: (`torch.device`):
250
+ torch device
251
+ num_images_per_prompt (`int`):
252
+ number of images that should be generated per prompt
253
+ do_classifier_free_guidance (`bool`):
254
+ whether to use classifier free guidance or not
255
+ negative_prompt (`str` or `List[str]`, *optional*):
256
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
257
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
258
+ less than `1`).
259
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
260
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
261
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
262
+ prompt_embeds (`torch.FloatTensor`, *optional*):
263
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
264
+ provided, text embeddings will be generated from `prompt` input argument.
265
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
266
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
267
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
268
+ argument.
269
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
270
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
271
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
272
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
273
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
274
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
275
+ input argument.
276
+ lora_scale (`float`, *optional*):
277
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
278
+ clip_skip (`int`, *optional*):
279
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
280
+ the output of the pre-final layer will be used for computing the prompt embeddings.
281
+ """
282
+ device = device or self._execution_device
283
+
284
+ # set lora scale so that monkey patched LoRA
285
+ # function of text encoder can correctly access it
286
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
287
+ self._lora_scale = lora_scale
288
+
289
+ # dynamically adjust the LoRA scale
290
+ if self.text_encoder is not None:
291
+ if not USE_PEFT_BACKEND:
292
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
293
+ else:
294
+ scale_lora_layers(self.text_encoder, lora_scale)
295
+
296
+ if self.text_encoder_2 is not None:
297
+ if not USE_PEFT_BACKEND:
298
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
299
+ else:
300
+ scale_lora_layers(self.text_encoder_2, lora_scale)
301
+
302
+ prompt = [prompt] if isinstance(prompt, str) else prompt
303
+
304
+ if prompt is not None:
305
+ batch_size = len(prompt)
306
+ else:
307
+ batch_size = prompt_embeds.shape[0]
308
+
309
+ # Define tokenizers and text encoders
310
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
311
+ text_encoders = (
312
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
313
+ )
314
+
315
+ if prompt_embeds is None:
316
+ prompt_2 = prompt_2 or prompt
317
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
318
+
319
+ # textual inversion: procecss multi-vector tokens if necessary
320
+ prompt_embeds_list = []
321
+ prompts = [prompt, prompt_2]
322
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
323
+ if isinstance(self, TextualInversionLoaderMixin):
324
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
325
+
326
+ text_inputs = tokenizer(
327
+ prompt,
328
+ padding="max_length",
329
+ max_length=tokenizer.model_max_length,
330
+ truncation=True,
331
+ return_tensors="pt",
332
+ )
333
+
334
+ text_input_ids = text_inputs.input_ids
335
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
336
+
337
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
338
+ text_input_ids, untruncated_ids
339
+ ):
340
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
341
+ logger.warning(
342
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
343
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
344
+ )
345
+
346
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
347
+
348
+ # We are only ALWAYS interested in the pooled output of the final text encoder
349
+ pooled_prompt_embeds = prompt_embeds[0]
350
+ if clip_skip is None:
351
+ prompt_embeds = prompt_embeds.hidden_states[-2]
352
+ else:
353
+ # "2" because SDXL always indexes from the penultimate layer.
354
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
355
+
356
+ prompt_embeds_list.append(prompt_embeds)
357
+
358
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
359
+
360
+ # get unconditional embeddings for classifier free guidance
361
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
362
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
363
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
364
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
365
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
366
+ negative_prompt = negative_prompt or ""
367
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
368
+
369
+ # normalize str to list
370
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
371
+ negative_prompt_2 = (
372
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
373
+ )
374
+
375
+ uncond_tokens: List[str]
376
+ if prompt is not None and type(prompt) is not type(negative_prompt):
377
+ raise TypeError(
378
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
379
+ f" {type(prompt)}."
380
+ )
381
+ elif batch_size != len(negative_prompt):
382
+ raise ValueError(
383
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
384
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
385
+ " the batch size of `prompt`."
386
+ )
387
+ else:
388
+ uncond_tokens = [negative_prompt, negative_prompt_2]
389
+
390
+ negative_prompt_embeds_list = []
391
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
394
+
395
+ max_length = prompt_embeds.shape[1]
396
+ uncond_input = tokenizer(
397
+ negative_prompt,
398
+ padding="max_length",
399
+ max_length=max_length,
400
+ truncation=True,
401
+ return_tensors="pt",
402
+ )
403
+
404
+ negative_prompt_embeds = text_encoder(
405
+ uncond_input.input_ids.to(device),
406
+ output_hidden_states=True,
407
+ )
408
+ # We are only ALWAYS interested in the pooled output of the final text encoder
409
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
410
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
411
+
412
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
413
+
414
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
415
+
416
+ if self.text_encoder_2 is not None:
417
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
418
+ else:
419
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
420
+
421
+ bs_embed, seq_len, _ = prompt_embeds.shape
422
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
423
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
424
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
425
+
426
+ if do_classifier_free_guidance:
427
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
428
+ seq_len = negative_prompt_embeds.shape[1]
429
+
430
+ if self.text_encoder_2 is not None:
431
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
432
+ else:
433
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
434
+
435
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
436
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
437
+
438
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
439
+ bs_embed * num_images_per_prompt, -1
440
+ )
441
+ if do_classifier_free_guidance:
442
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
443
+ bs_embed * num_images_per_prompt, -1
444
+ )
445
+
446
+ if self.text_encoder is not None:
447
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
448
+ # Retrieve the original scale by scaling back the LoRA layers
449
+ unscale_lora_layers(self.text_encoder, lora_scale)
450
+
451
+ if self.text_encoder_2 is not None:
452
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
453
+ # Retrieve the original scale by scaling back the LoRA layers
454
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
455
+
456
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
457
+
458
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
459
+ def prepare_extra_step_kwargs(self, generator, eta):
460
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
461
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
462
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
463
+ # and should be between [0, 1]
464
+
465
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
466
+ extra_step_kwargs = {}
467
+ if accepts_eta:
468
+ extra_step_kwargs["eta"] = eta
469
+
470
+ # check if the scheduler accepts generator
471
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
472
+ if accepts_generator:
473
+ extra_step_kwargs["generator"] = generator
474
+ return extra_step_kwargs
475
+
476
+ def check_inputs(
477
+ self,
478
+ prompt,
479
+ prompt_2,
480
+ height,
481
+ width,
482
+ callback_steps,
483
+ negative_prompt=None,
484
+ negative_prompt_2=None,
485
+ prompt_embeds=None,
486
+ negative_prompt_embeds=None,
487
+ pooled_prompt_embeds=None,
488
+ negative_pooled_prompt_embeds=None,
489
+ callback_on_step_end_tensor_inputs=None,
490
+ ):
491
+ if height % 8 != 0 or width % 8 != 0:
492
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
493
+
494
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
495
+ raise ValueError(
496
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
497
+ f" {type(callback_steps)}."
498
+ )
499
+
500
+ if callback_on_step_end_tensor_inputs is not None and not all(
501
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
502
+ ):
503
+ raise ValueError(
504
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
505
+ )
506
+
507
+ if prompt is not None and prompt_embeds is not None:
508
+ raise ValueError(
509
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
510
+ " only forward one of the two."
511
+ )
512
+ elif prompt_2 is not None and prompt_embeds is not None:
513
+ raise ValueError(
514
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
515
+ " only forward one of the two."
516
+ )
517
+ elif prompt is None and prompt_embeds is None:
518
+ raise ValueError(
519
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
520
+ )
521
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
522
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
523
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
524
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
525
+
526
+ if negative_prompt is not None and negative_prompt_embeds is not None:
527
+ raise ValueError(
528
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
529
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
530
+ )
531
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
532
+ raise ValueError(
533
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
534
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
535
+ )
536
+
537
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
538
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
539
+ raise ValueError(
540
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
541
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
542
+ f" {negative_prompt_embeds.shape}."
543
+ )
544
+
545
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
546
+ raise ValueError(
547
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
548
+ )
549
+
550
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
551
+ raise ValueError(
552
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
553
+ )
554
+
555
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
556
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
557
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
558
+ if isinstance(generator, list) and len(generator) != batch_size:
559
+ raise ValueError(
560
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
561
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
562
+ )
563
+
564
+ if latents is None:
565
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
566
+ else:
567
+ latents = latents.to(device)
568
+
569
+ # scale the initial noise by the standard deviation required by the scheduler
570
+ latents = latents * self.scheduler.init_noise_sigma
571
+ return latents
572
+
573
+ def _get_add_time_ids(
574
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
575
+ ):
576
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
577
+
578
+ passed_add_embed_dim = (
579
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
580
+ )
581
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
582
+
583
+ if expected_add_embed_dim != passed_add_embed_dim:
584
+ raise ValueError(
585
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
586
+ )
587
+
588
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
589
+ return add_time_ids
590
+
591
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
592
+ def upcast_vae(self):
593
+ dtype = self.vae.dtype
594
+ self.vae.to(dtype=torch.float32)
595
+ use_torch_2_0_or_xformers = isinstance(
596
+ self.vae.decoder.mid_block.attentions[0].processor,
597
+ (
598
+ AttnProcessor2_0,
599
+ XFormersAttnProcessor,
600
+ LoRAXFormersAttnProcessor,
601
+ LoRAAttnProcessor2_0,
602
+ ),
603
+ )
604
+ # if xformers or torch_2_0 is used attention block does not need
605
+ # to be in float32 which can save lots of memory
606
+ if use_torch_2_0_or_xformers:
607
+ self.vae.post_quant_conv.to(dtype)
608
+ self.vae.decoder.conv_in.to(dtype)
609
+ self.vae.decoder.mid_block.to(dtype)
610
+
611
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
612
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
613
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
614
+
615
+ The suffixes after the scaling factors represent the stages where they are being applied.
616
+
617
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
618
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
619
+
620
+ Args:
621
+ s1 (`float`):
622
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
623
+ mitigate "oversmoothing effect" in the enhanced denoising process.
624
+ s2 (`float`):
625
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
626
+ mitigate "oversmoothing effect" in the enhanced denoising process.
627
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
628
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
629
+ """
630
+ if not hasattr(self, "unet"):
631
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
632
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
633
+
634
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
635
+ def disable_freeu(self):
636
+ """Disables the FreeU mechanism if enabled."""
637
+ self.unet.disable_freeu()
638
+
639
+ @property
640
+ def guidance_scale(self):
641
+ return self._guidance_scale
642
+
643
+ @property
644
+ def guidance_rescale(self):
645
+ return self._guidance_rescale
646
+
647
+ @property
648
+ def clip_skip(self):
649
+ return self._clip_skip
650
+
651
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
652
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
653
+ # corresponds to doing no classifier free guidance.
654
+ @property
655
+ def do_classifier_free_guidance(self):
656
+ return self._guidance_scale > 1
657
+
658
+ @property
659
+ def cross_attention_kwargs(self):
660
+ return self._cross_attention_kwargs
661
+
662
+ @property
663
+ def denoising_end(self):
664
+ return self._denoising_end
665
+
666
+ @property
667
+ def num_timesteps(self):
668
+ return self._num_timesteps
669
+
670
+ @torch.no_grad()
671
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
672
+ def __call__(
673
+ self,
674
+ prompt: Union[str, List[str]] = None,
675
+ prompt_2: Optional[Union[str, List[str]]] = None,
676
+ height: Optional[int] = None,
677
+ width: Optional[int] = None,
678
+ num_inference_steps: int = 50,
679
+ denoising_end: Optional[float] = None,
680
+ guidance_scale: float = 5.0,
681
+ negative_prompt: Optional[Union[str, List[str]]] = None,
682
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
683
+ num_images_per_prompt: Optional[int] = 1,
684
+ eta: float = 0.0,
685
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
686
+ latents: Optional[torch.FloatTensor] = None,
687
+ prompt_embeds: Optional[torch.FloatTensor] = None,
688
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
689
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
690
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
691
+ output_type: Optional[str] = "pil",
692
+ return_dict: bool = True,
693
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
694
+ guidance_rescale: float = 0.0,
695
+ original_size: Optional[Tuple[int, int]] = None,
696
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
697
+ target_size: Optional[Tuple[int, int]] = None,
698
+ negative_original_size: Optional[Tuple[int, int]] = None,
699
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
700
+ negative_target_size: Optional[Tuple[int, int]] = None,
701
+ clip_skip: Optional[int] = None,
702
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
703
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
704
+ **kwargs,
705
+ ):
706
+ r"""
707
+ Function invoked when calling the pipeline for generation.
708
+
709
+ Args:
710
+ prompt (`str` or `List[str]`, *optional*):
711
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
712
+ instead.
713
+ prompt_2 (`str` or `List[str]`, *optional*):
714
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
715
+ used in both text-encoders
716
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
717
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
718
+ Anything below 512 pixels won't work well for
719
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
720
+ and checkpoints that are not specifically fine-tuned on low resolutions.
721
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
722
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
723
+ Anything below 512 pixels won't work well for
724
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
725
+ and checkpoints that are not specifically fine-tuned on low resolutions.
726
+ num_inference_steps (`int`, *optional*, defaults to 50):
727
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
728
+ expense of slower inference.
729
+ denoising_end (`float`, *optional*):
730
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
731
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
732
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
733
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
734
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
735
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
736
+ guidance_scale (`float`, *optional*, defaults to 5.0):
737
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
738
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
739
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
740
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
741
+ usually at the expense of lower image quality.
742
+ negative_prompt (`str` or `List[str]`, *optional*):
743
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
744
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
745
+ less than `1`).
746
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
747
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
748
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
749
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
750
+ The number of images to generate per prompt.
751
+ eta (`float`, *optional*, defaults to 0.0):
752
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
753
+ [`schedulers.DDIMScheduler`], will be ignored for others.
754
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
755
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
756
+ to make generation deterministic.
757
+ latents (`torch.FloatTensor`, *optional*):
758
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
759
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
760
+ tensor will ge generated by sampling using the supplied random `generator`.
761
+ prompt_embeds (`torch.FloatTensor`, *optional*):
762
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
763
+ provided, text embeddings will be generated from `prompt` input argument.
764
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
765
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
766
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
767
+ argument.
768
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
769
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
770
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
771
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
772
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
773
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
774
+ input argument.
775
+ output_type (`str`, *optional*, defaults to `"pil"`):
776
+ The output format of the generate image. Choose between
777
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
778
+ return_dict (`bool`, *optional*, defaults to `True`):
779
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
780
+ of a plain tuple.
781
+ cross_attention_kwargs (`dict`, *optional*):
782
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
783
+ `self.processor` in
784
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
785
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
786
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
787
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
788
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
789
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
790
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
791
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
792
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
793
+ explained in section 2.2 of
794
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
795
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
796
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
797
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
798
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
799
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
800
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
801
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
802
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
803
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
804
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
805
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
806
+ micro-conditioning as explained in section 2.2 of
807
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
808
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
809
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
810
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
811
+ micro-conditioning as explained in section 2.2 of
812
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
813
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
814
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
815
+ To negatively condition the generation process based on a target image resolution. It should be as same
816
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
817
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
818
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
819
+ callback_on_step_end (`Callable`, *optional*):
820
+ A function that calls at the end of each denoising steps during the inference. The function is called
821
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
822
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
823
+ `callback_on_step_end_tensor_inputs`.
824
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
825
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
826
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
827
+ `._callback_tensor_inputs` attribute of your pipeine class.
828
+
829
+ Examples:
830
+
831
+ Returns:
832
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
833
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
834
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
835
+ """
836
+
837
+ callback = kwargs.pop("callback", None)
838
+ callback_steps = kwargs.pop("callback_steps", None)
839
+
840
+ if callback is not None:
841
+ deprecate(
842
+ "callback",
843
+ "1.0.0",
844
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
845
+ )
846
+ if callback_steps is not None:
847
+ deprecate(
848
+ "callback_steps",
849
+ "1.0.0",
850
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
851
+ )
852
+
853
+ # 0. Default height and width to unet
854
+ height = height or self.default_sample_size * self.vae_scale_factor
855
+ width = width or self.default_sample_size * self.vae_scale_factor
856
+
857
+ original_size = original_size or (height, width)
858
+ target_size = target_size or (height, width)
859
+
860
+ # 1. Check inputs. Raise error if not correct
861
+ self.check_inputs(
862
+ prompt,
863
+ prompt_2,
864
+ height,
865
+ width,
866
+ callback_steps,
867
+ negative_prompt,
868
+ negative_prompt_2,
869
+ prompt_embeds,
870
+ negative_prompt_embeds,
871
+ pooled_prompt_embeds,
872
+ negative_pooled_prompt_embeds,
873
+ callback_on_step_end_tensor_inputs,
874
+ )
875
+
876
+ self._guidance_scale = guidance_scale
877
+ self._guidance_rescale = guidance_rescale
878
+ self._clip_skip = clip_skip
879
+ self._cross_attention_kwargs = cross_attention_kwargs
880
+ self._denoising_end = denoising_end
881
+
882
+ # 2. Define call parameters
883
+ if prompt is not None and isinstance(prompt, str):
884
+ batch_size = 1
885
+ elif prompt is not None and isinstance(prompt, list):
886
+ batch_size = len(prompt)
887
+ else:
888
+ batch_size = prompt_embeds.shape[0]
889
+
890
+ device = self._execution_device
891
+
892
+ # 3. Encode input prompt
893
+ lora_scale = (
894
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
895
+ )
896
+
897
+ (
898
+ prompt_embeds,
899
+ negative_prompt_embeds,
900
+ pooled_prompt_embeds,
901
+ negative_pooled_prompt_embeds,
902
+ ) = self.encode_prompt(
903
+ prompt=prompt,
904
+ prompt_2=prompt_2,
905
+ device=device,
906
+ num_images_per_prompt=num_images_per_prompt,
907
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
908
+ negative_prompt=negative_prompt,
909
+ negative_prompt_2=negative_prompt_2,
910
+ prompt_embeds=prompt_embeds,
911
+ negative_prompt_embeds=negative_prompt_embeds,
912
+ pooled_prompt_embeds=pooled_prompt_embeds,
913
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
914
+ lora_scale=lora_scale,
915
+ clip_skip=self.clip_skip,
916
+ )
917
+
918
+ if kwargs['target_prompt'] is not None:
919
+ (
920
+ prompt_embeds_,
921
+ negative_prompt_embeds_,
922
+ _,
923
+ _,
924
+ ) = self.encode_prompt(
925
+ prompt=kwargs['target_prompt'],
926
+ prompt_2=prompt_2,
927
+ device=device,
928
+ num_images_per_prompt=num_images_per_prompt,
929
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
930
+ # negative_prompt=negative_prompt,
931
+ negative_prompt=None,
932
+ # negative_prompt_2=negative_prompt_2,
933
+ negative_prompt_2=None,
934
+ prompt_embeds=None,
935
+ negative_prompt_embeds=None,
936
+ pooled_prompt_embeds=None,
937
+ negative_pooled_prompt_embeds=None,
938
+ lora_scale=lora_scale,
939
+ clip_skip=self.clip_skip,
940
+ )
941
+ prompt_embeds[1:] = prompt_embeds_[1:]
942
+ negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
943
+
944
+
945
+ # 4. Prepare timesteps
946
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
947
+
948
+ timesteps = self.scheduler.timesteps
949
+
950
+ # 5. Prepare latent variables
951
+ num_channels_latents = self.unet.config.in_channels
952
+ latents = self.prepare_latents(
953
+ batch_size * num_images_per_prompt,
954
+ num_channels_latents,
955
+ height,
956
+ width,
957
+ prompt_embeds.dtype,
958
+ device,
959
+ generator,
960
+ latents,
961
+ )
962
+
963
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
964
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
965
+
966
+ # 7. Prepare added time ids & embeddings
967
+ add_text_embeds = pooled_prompt_embeds
968
+ if self.text_encoder_2 is None:
969
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
970
+ else:
971
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
972
+
973
+ add_time_ids = self._get_add_time_ids(
974
+ original_size,
975
+ crops_coords_top_left,
976
+ target_size,
977
+ dtype=prompt_embeds.dtype,
978
+ text_encoder_projection_dim=text_encoder_projection_dim,
979
+ )
980
+ if negative_original_size is not None and negative_target_size is not None:
981
+ negative_add_time_ids = self._get_add_time_ids(
982
+ negative_original_size,
983
+ negative_crops_coords_top_left,
984
+ negative_target_size,
985
+ dtype=prompt_embeds.dtype,
986
+ text_encoder_projection_dim=text_encoder_projection_dim,
987
+ )
988
+ else:
989
+ negative_add_time_ids = add_time_ids
990
+
991
+ if self.do_classifier_free_guidance:
992
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
993
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
994
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
995
+
996
+ prompt_embeds = prompt_embeds.to(device)
997
+ add_text_embeds = add_text_embeds.to(device)
998
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
999
+
1000
+ # 8. Denoising loop
1001
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1002
+
1003
+ # 8.1 Apply denoising_end
1004
+ if (
1005
+ self.denoising_end is not None
1006
+ and isinstance(self.denoising_end, float)
1007
+ and self.denoising_end > 0
1008
+ and self.denoising_end < 1
1009
+ ):
1010
+ discrete_timestep_cutoff = int(
1011
+ round(
1012
+ self.scheduler.config.num_train_timesteps
1013
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1014
+ )
1015
+ )
1016
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1017
+ timesteps = timesteps[:num_inference_steps]
1018
+
1019
+ self._num_timesteps = len(timesteps)
1020
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1021
+ for i, t in enumerate(timesteps):
1022
+ # expand the latents if we are doing classifier free guidance
1023
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1024
+
1025
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1026
+
1027
+ # predict the noise residual
1028
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1029
+ noise_pred = self.unet(
1030
+ latent_model_input,
1031
+ t,
1032
+ encoder_hidden_states=prompt_embeds,
1033
+ cross_attention_kwargs=self.cross_attention_kwargs,
1034
+ added_cond_kwargs=added_cond_kwargs,
1035
+ return_dict=False,
1036
+ )[0]
1037
+
1038
+ # perform guidance
1039
+ if self.do_classifier_free_guidance:
1040
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1041
+
1042
+ if i < 3:
1043
+ noise_pred = noise_pred_uncond + 15.0 * (noise_pred_text - noise_pred_uncond)
1044
+ else:
1045
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1046
+
1047
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1048
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1049
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1050
+
1051
+ # compute the previous noisy sample x_t -> x_t-1
1052
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1053
+
1054
+ if callback_on_step_end is not None:
1055
+ callback_kwargs = {}
1056
+ for k in callback_on_step_end_tensor_inputs:
1057
+ callback_kwargs[k] = locals()[k]
1058
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1059
+
1060
+ latents = callback_outputs.pop("latents", latents)
1061
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1062
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1063
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1064
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1065
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1066
+ )
1067
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1068
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1069
+
1070
+ # call the callback, if provided
1071
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1072
+ progress_bar.update()
1073
+ if callback is not None and i % callback_steps == 0:
1074
+ step_idx = i // getattr(self.scheduler, "order", 1)
1075
+ callback(step_idx, t, latents)
1076
+
1077
+ if XLA_AVAILABLE:
1078
+ xm.mark_step()
1079
+
1080
+ if not output_type == "latent":
1081
+ # make sure the VAE is in float32 mode, as it overflows in float16
1082
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1083
+
1084
+ if needs_upcasting:
1085
+ self.upcast_vae()
1086
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1087
+
1088
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1089
+
1090
+ # cast back to fp16 if needed
1091
+ if needs_upcasting:
1092
+ self.vae.to(dtype=torch.float16)
1093
+ else:
1094
+ image = latents
1095
+
1096
+ if not output_type == "latent":
1097
+ # apply watermark if available
1098
+ if self.watermark is not None:
1099
+ image = self.watermark.apply_watermark(image)
1100
+
1101
+ image = self.image_processor.postprocess(image, output_type=output_type)
1102
+
1103
+ # Offload all models
1104
+ self.maybe_free_model_hooks()
1105
+
1106
+ if not return_dict:
1107
+ return (image,)
1108
+
1109
+ return StableDiffusionXLPipelineOutput(images=image)
1110
+
1111
+ @torch.no_grad()
1112
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1113
+ def inverted_ve_cross_frame_attn(
1114
+ self,
1115
+ prompt: Union[str, List[str]] = None,
1116
+ prompt_2: Optional[Union[str, List[str]]] = None,
1117
+ height: Optional[int] = None,
1118
+ width: Optional[int] = None,
1119
+ num_inference_steps: int = 50,
1120
+ denoising_end: Optional[float] = None,
1121
+ guidance_scale: float = 5.0,
1122
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1123
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1124
+ num_images_per_prompt: Optional[int] = 1,
1125
+ eta: float = 0.0,
1126
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1127
+ latents: Optional[torch.FloatTensor] = None,
1128
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1129
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1130
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1131
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1132
+ output_type: Optional[str] = "pil",
1133
+ return_dict: bool = True,
1134
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1135
+ guidance_rescale: float = 0.0,
1136
+ original_size: Optional[Tuple[int, int]] = None,
1137
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1138
+ target_size: Optional[Tuple[int, int]] = None,
1139
+ negative_original_size: Optional[Tuple[int, int]] = None,
1140
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1141
+ negative_target_size: Optional[Tuple[int, int]] = None,
1142
+ clip_skip: Optional[int] = None,
1143
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1144
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1145
+ **kwargs,
1146
+ ):
1147
+ r"""
1148
+ Function invoked when calling the pipeline for generation.
1149
+
1150
+ Args:
1151
+ prompt (`str` or `List[str]`, *optional*):
1152
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1153
+ instead.
1154
+ prompt_2 (`str` or `List[str]`, *optional*):
1155
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1156
+ used in both text-encoders
1157
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1158
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1159
+ Anything below 512 pixels won't work well for
1160
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1161
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1162
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1163
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1164
+ Anything below 512 pixels won't work well for
1165
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1166
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1167
+ num_inference_steps (`int`, *optional*, defaults to 50):
1168
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1169
+ expense of slower inference.
1170
+ denoising_end (`float`, *optional*):
1171
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1172
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1173
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1174
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1175
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1176
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1177
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1178
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1179
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1180
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1181
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1182
+ usually at the expense of lower image quality.
1183
+ negative_prompt (`str` or `List[str]`, *optional*):
1184
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1185
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1186
+ less than `1`).
1187
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1188
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1189
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1190
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1191
+ The number of images to generate per prompt.
1192
+ eta (`float`, *optional*, defaults to 0.0):
1193
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1194
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1195
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1196
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1197
+ to make generation deterministic.
1198
+ latents (`torch.FloatTensor`, *optional*):
1199
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1200
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1201
+ tensor will ge generated by sampling using the supplied random `generator`.
1202
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1203
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1204
+ provided, text embeddings will be generated from `prompt` input argument.
1205
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1206
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1207
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1208
+ argument.
1209
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1210
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1211
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1212
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1213
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1214
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1215
+ input argument.
1216
+ output_type (`str`, *optional*, defaults to `"pil"`):
1217
+ The output format of the generate image. Choose between
1218
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1219
+ return_dict (`bool`, *optional*, defaults to `True`):
1220
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1221
+ of a plain tuple.
1222
+ cross_attention_kwargs (`dict`, *optional*):
1223
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1224
+ `self.processor` in
1225
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1226
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1227
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1228
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
1229
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1230
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1231
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1232
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1233
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1234
+ explained in section 2.2 of
1235
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1236
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1237
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1238
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1239
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1240
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1241
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1242
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1243
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1244
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1245
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1246
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1247
+ micro-conditioning as explained in section 2.2 of
1248
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1249
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1250
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1251
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1252
+ micro-conditioning as explained in section 2.2 of
1253
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1254
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1255
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1256
+ To negatively condition the generation process based on a target image resolution. It should be as same
1257
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1258
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1259
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1260
+ callback_on_step_end (`Callable`, *optional*):
1261
+ A function that calls at the end of each denoising steps during the inference. The function is called
1262
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1263
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1264
+ `callback_on_step_end_tensor_inputs`.
1265
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1266
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1267
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1268
+ `._callback_tensor_inputs` attribute of your pipeine class.
1269
+
1270
+ Examples:
1271
+
1272
+ Returns:
1273
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1274
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1275
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1276
+ """
1277
+
1278
+ callback = kwargs.pop("callback", None)
1279
+ callback_steps = kwargs.pop("callback_steps", None)
1280
+
1281
+ if callback is not None:
1282
+ deprecate(
1283
+ "callback",
1284
+ "1.0.0",
1285
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1286
+ )
1287
+ if callback_steps is not None:
1288
+ deprecate(
1289
+ "callback_steps",
1290
+ "1.0.0",
1291
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1292
+ )
1293
+
1294
+ # 0. Default height and width to unet
1295
+ height = height or self.default_sample_size * self.vae_scale_factor
1296
+ width = width or self.default_sample_size * self.vae_scale_factor
1297
+
1298
+ original_size = original_size or (height, width)
1299
+ target_size = target_size or (height, width)
1300
+
1301
+ # 1. Check inputs. Raise error if not correct
1302
+ self.check_inputs(
1303
+ prompt,
1304
+ prompt_2,
1305
+ height,
1306
+ width,
1307
+ callback_steps,
1308
+ negative_prompt,
1309
+ negative_prompt_2,
1310
+ prompt_embeds,
1311
+ negative_prompt_embeds,
1312
+ pooled_prompt_embeds,
1313
+ negative_pooled_prompt_embeds,
1314
+ callback_on_step_end_tensor_inputs,
1315
+ )
1316
+
1317
+ self._guidance_scale = guidance_scale
1318
+ self._guidance_rescale = guidance_rescale
1319
+ self._clip_skip = clip_skip
1320
+ self._cross_attention_kwargs = cross_attention_kwargs
1321
+ self._denoising_end = denoising_end
1322
+
1323
+ # 2. Define call parameters
1324
+ if prompt is not None and isinstance(prompt, str):
1325
+ batch_size = 1
1326
+ elif prompt is not None and isinstance(prompt, list):
1327
+ batch_size = len(prompt)
1328
+ else:
1329
+ batch_size = prompt_embeds.shape[0]
1330
+
1331
+ device = self._execution_device
1332
+
1333
+ # 3. Encode input prompt
1334
+ lora_scale = (
1335
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1336
+ )
1337
+
1338
+ (
1339
+ prompt_embeds,
1340
+ negative_prompt_embeds,
1341
+ pooled_prompt_embeds,
1342
+ negative_pooled_prompt_embeds,
1343
+ ) = self.encode_prompt(
1344
+ prompt=prompt,
1345
+ prompt_2=prompt_2,
1346
+ device=device,
1347
+ num_images_per_prompt=num_images_per_prompt,
1348
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1349
+ negative_prompt=negative_prompt,
1350
+ negative_prompt_2=negative_prompt_2,
1351
+ prompt_embeds=prompt_embeds,
1352
+ negative_prompt_embeds=negative_prompt_embeds,
1353
+ pooled_prompt_embeds=pooled_prompt_embeds,
1354
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1355
+ lora_scale=lora_scale,
1356
+ clip_skip=self.clip_skip,
1357
+ )
1358
+
1359
+ if kwargs['target_prompt'] is not None:
1360
+ (
1361
+ prompt_embeds_,
1362
+ negative_prompt_embeds_,
1363
+ _,
1364
+ _,
1365
+ ) = self.encode_prompt(
1366
+ prompt=kwargs['target_prompt'],
1367
+ prompt_2=prompt_2,
1368
+ device=device,
1369
+ num_images_per_prompt=num_images_per_prompt,
1370
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1371
+ negative_prompt=kwargs['target_negative_prompt'] if kwargs['target_negative_prompt'] is not None else None,
1372
+ # negative_prompt=None,
1373
+ # negative_prompt_2=negative_prompt_2,
1374
+ negative_prompt_2=None,
1375
+ prompt_embeds=None,
1376
+ negative_prompt_embeds=None,
1377
+ pooled_prompt_embeds=None,
1378
+ negative_pooled_prompt_embeds=None,
1379
+ lora_scale=lora_scale,
1380
+ clip_skip=self.clip_skip,
1381
+ )
1382
+ prompt_embeds[1:] = prompt_embeds_[1:]
1383
+ if negative_prompt_embeds_ is not None:
1384
+ negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
1385
+
1386
+
1387
+ # 4. Prepare timesteps
1388
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1389
+
1390
+ timesteps = self.scheduler.timesteps
1391
+
1392
+ # 5. Prepare latent variables
1393
+ num_channels_latents = self.unet.config.in_channels
1394
+ latents = self.prepare_latents(
1395
+ batch_size * num_images_per_prompt,
1396
+ num_channels_latents,
1397
+ height,
1398
+ width,
1399
+ prompt_embeds.dtype,
1400
+ device,
1401
+ generator,
1402
+ latents,
1403
+ )
1404
+
1405
+
1406
+ # import pdb; pdb.set_trace()
1407
+
1408
+ latents_ = self.prepare_latents(
1409
+ batch_size * num_images_per_prompt,
1410
+ num_channels_latents,
1411
+ height,
1412
+ width,
1413
+ prompt_embeds.dtype,
1414
+ device,
1415
+ generator,
1416
+ # latents,
1417
+ )
1418
+
1419
+ # import pdb; pdb.set_trace()
1420
+
1421
+ # latents[1:] = latents_[1:]
1422
+ latents = torch.cat([latents.unsqueeze(0), latents_[1:]], dim=0)
1423
+
1424
+
1425
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1426
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1427
+
1428
+ # 7. Prepare added time ids & embeddings
1429
+ add_text_embeds = pooled_prompt_embeds
1430
+ if self.text_encoder_2 is None:
1431
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1432
+ else:
1433
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1434
+
1435
+ add_time_ids = self._get_add_time_ids(
1436
+ original_size,
1437
+ crops_coords_top_left,
1438
+ target_size,
1439
+ dtype=prompt_embeds.dtype,
1440
+ text_encoder_projection_dim=text_encoder_projection_dim,
1441
+ )
1442
+ if negative_original_size is not None and negative_target_size is not None:
1443
+ negative_add_time_ids = self._get_add_time_ids(
1444
+ negative_original_size,
1445
+ negative_crops_coords_top_left,
1446
+ negative_target_size,
1447
+ dtype=prompt_embeds.dtype,
1448
+ text_encoder_projection_dim=text_encoder_projection_dim,
1449
+ )
1450
+ else:
1451
+ negative_add_time_ids = add_time_ids
1452
+
1453
+ if self.do_classifier_free_guidance:
1454
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1455
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1456
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1457
+
1458
+ prompt_embeds = prompt_embeds.to(device)
1459
+ add_text_embeds = add_text_embeds.to(device)
1460
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1461
+
1462
+ # 8. Denoising loop
1463
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1464
+
1465
+ # 8.1 Apply denoising_end
1466
+ if (
1467
+ self.denoising_end is not None
1468
+ and isinstance(self.denoising_end, float)
1469
+ and self.denoising_end > 0
1470
+ and self.denoising_end < 1
1471
+ ):
1472
+ discrete_timestep_cutoff = int(
1473
+ round(
1474
+ self.scheduler.config.num_train_timesteps
1475
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1476
+ )
1477
+ )
1478
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1479
+ timesteps = timesteps[:num_inference_steps]
1480
+
1481
+ self._num_timesteps = len(timesteps)
1482
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1483
+ for i, t in enumerate(timesteps):
1484
+ # expand the latents if we are doing classifier free guidance
1485
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1486
+
1487
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1488
+
1489
+ # predict the noise residual
1490
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1491
+ noise_pred = self.unet(
1492
+ latent_model_input,
1493
+ t,
1494
+ encoder_hidden_states=prompt_embeds,
1495
+ cross_attention_kwargs=self.cross_attention_kwargs,
1496
+ added_cond_kwargs=added_cond_kwargs,
1497
+ return_dict=False,
1498
+ )[0]
1499
+
1500
+ # perform guidance
1501
+ if self.do_classifier_free_guidance:
1502
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1503
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1504
+ noise_pred[0] = noise_pred_uncond[0] #추가된것
1505
+
1506
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1507
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1508
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1509
+ noise_pred[0] = noise_pred_uncond[0] #추가된것
1510
+
1511
+ # compute the previous noisy sample x_t -> x_t-1
1512
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1513
+
1514
+ if callback_on_step_end is not None:
1515
+ callback_kwargs = {}
1516
+ for k in callback_on_step_end_tensor_inputs:
1517
+ callback_kwargs[k] = locals()[k]
1518
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1519
+
1520
+ latents = callback_outputs.pop("latents", latents)
1521
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1522
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1523
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1524
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1525
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1526
+ )
1527
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1528
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1529
+
1530
+ # call the callback, if provided
1531
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1532
+ progress_bar.update()
1533
+ if callback is not None and i % callback_steps == 0:
1534
+ step_idx = i // getattr(self.scheduler, "order", 1)
1535
+ callback(step_idx, t, latents)
1536
+
1537
+ if XLA_AVAILABLE:
1538
+ xm.mark_step()
1539
+
1540
+ if not output_type == "latent":
1541
+ # make sure the VAE is in float32 mode, as it overflows in float16
1542
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1543
+
1544
+ if needs_upcasting:
1545
+ self.upcast_vae()
1546
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1547
+
1548
+ self.enable_vae_slicing()
1549
+
1550
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1551
+
1552
+ # cast back to fp16 if needed
1553
+ if needs_upcasting:
1554
+ self.vae.to(dtype=torch.float16)
1555
+ else:
1556
+ image = latents
1557
+
1558
+ if not output_type == "latent":
1559
+ # apply watermark if available
1560
+ if self.watermark is not None:
1561
+ image = self.watermark.apply_watermark(image)
1562
+
1563
+ image = self.image_processor.postprocess(image, output_type=output_type)
1564
+
1565
+ # Offload all models
1566
+ self.maybe_free_model_hooks()
1567
+
1568
+ if not return_dict:
1569
+ return (image,)
1570
+
1571
+ return StableDiffusionXLPipelineOutput(images=image)
1572
+
1573
+
visualize_attention_src/save_attn_map_script.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pipelines.inverted_ve_pipeline import CrossFrameAttnProcessor, CrossFrameAttnProcessor_store, ACTIVATE_LAYER_CANDIDATE
3
+ from diffusers import DDIMScheduler, AutoencoderKL
4
+ import os
5
+ from PIL import Image
6
+ from utils import memory_efficient
7
+ from diffusers.models.attention_processor import AttnProcessor
8
+ from pipeline_stable_diffusion_xl_attn import StableDiffusionXLPipeline
9
+
10
+
11
+ def create_image_grid(image_list, rows, cols, padding=10):
12
+ # Ensure the number of rows and columns doesn't exceed the number of images
13
+ rows = min(rows, len(image_list))
14
+ cols = min(cols, len(image_list))
15
+
16
+ # Get the dimensions of a single image
17
+ image_width, image_height = image_list[0].size
18
+
19
+ # Calculate the size of the output image
20
+ grid_width = cols * (image_width + padding) - padding
21
+ grid_height = rows * (image_height + padding) - padding
22
+
23
+ # Create an empty grid image
24
+ grid_image = Image.new('RGB', (grid_width, grid_height), (255, 255, 255))
25
+
26
+ # Paste images into the grid
27
+ for i, img in enumerate(image_list[:rows * cols]):
28
+ row = i // cols
29
+ col = i % cols
30
+ x = col * (image_width + padding)
31
+ y = row * (image_height + padding)
32
+ grid_image.paste(img, (x, y))
33
+
34
+ return grid_image
35
+
36
+ def transform_variable_name(input_str, attn_map_save_step):
37
+ # Split the input string into parts using the dot as a separator
38
+ parts = input_str.split('.')
39
+
40
+ # Extract numerical indices from the parts
41
+ indices = [int(part) if part.isdigit() else part for part in parts]
42
+
43
+ # Build the desired output string
44
+ output_str = f'pipe.unet.{indices[0]}[{indices[1]}].{indices[2]}[{indices[3]}].{indices[4]}[{indices[5]}].{indices[6]}.attn_map[{attn_map_save_step}]'
45
+
46
+ return output_str
47
+
48
+
49
+ num_images_per_prompt = 4
50
+ seeds=[1] #craft_clay
51
+
52
+
53
+ activate_layer_indices_list = [
54
+ # ((0,28),(108,140)),
55
+ # ((0,48), (68,140)),
56
+ # ((0,48), (88,140)),
57
+ # ((0,48), (108,140)),
58
+ # ((0,48), (128,140)),
59
+ # ((0,48), (140,140)),
60
+ # ((0,28), (68,140)),
61
+ # ((0,28), (88,140)),
62
+ # ((0,28), (108,140)),
63
+ # ((0,28), (128,140)),
64
+ # ((0,28), (140,140)),
65
+ # ((0,8), (68,140)),
66
+ # ((0,8), (88,140)),
67
+ # ((0,8), (108,140)),
68
+ # ((0,8), (128,140)),
69
+ # ((0,8), (140,140)),
70
+ # ((0,0), (68,140)),
71
+ # ((0,0), (88,140)),
72
+ ((0,0), (108,140)),
73
+ # ((0,0), (128,140)),
74
+ # ((0,0), (140,140))
75
+ ]
76
+
77
+ save_layer_list = [
78
+ # 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', #68
79
+ # 'up_blocks.0.attentions.1.transformer_blocks.4.attn2.processor', #78
80
+ # 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', #88
81
+ # 'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor', #108
82
+ # 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', #128
83
+ # 'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor', #138
84
+
85
+ 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', #108
86
+ 'up_blocks.0.attentions.2.transformer_blocks.0.attn2.processor',
87
+ 'up_blocks.0.attentions.2.transformer_blocks.1.attn1.processor',
88
+ 'up_blocks.0.attentions.2.transformer_blocks.1.attn2.processor',
89
+ 'up_blocks.0.attentions.2.transformer_blocks.2.attn1.processor',
90
+ 'up_blocks.0.attentions.2.transformer_blocks.2.attn2.processor',
91
+ 'up_blocks.0.attentions.2.transformer_blocks.3.attn1.processor',
92
+ 'up_blocks.0.attentions.2.transformer_blocks.3.attn2.processor',
93
+ 'up_blocks.0.attentions.2.transformer_blocks.4.attn1.processor',
94
+ 'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor',
95
+ 'up_blocks.0.attentions.2.transformer_blocks.5.attn1.processor',
96
+ 'up_blocks.0.attentions.2.transformer_blocks.5.attn2.processor',
97
+ 'up_blocks.0.attentions.2.transformer_blocks.6.attn1.processor',
98
+ 'up_blocks.0.attentions.2.transformer_blocks.6.attn2.processor',
99
+ 'up_blocks.0.attentions.2.transformer_blocks.7.attn1.processor',
100
+ 'up_blocks.0.attentions.2.transformer_blocks.7.attn2.processor',
101
+ 'up_blocks.0.attentions.2.transformer_blocks.8.attn1.processor',
102
+ 'up_blocks.0.attentions.2.transformer_blocks.8.attn2.processor',
103
+ 'up_blocks.0.attentions.2.transformer_blocks.9.attn1.processor',
104
+ 'up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor',
105
+
106
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', #128
107
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
108
+ 'up_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
109
+ 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
110
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
111
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
112
+ 'up_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
113
+ 'up_blocks.1.attentions.1.transformer_blocks.1.attn2.processor',
114
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor',
115
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
116
+ 'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor',
117
+ 'up_blocks.1.attentions.2.transformer_blocks.1.attn2.processor',
118
+ ]
119
+
120
+ attn_map_save_steps = [20]
121
+ # attn_map_save_steps = [10,20,30,40]
122
+
123
+ results_dir = 'saved_attention_map_results'
124
+ if not os.path.exists(results_dir):
125
+ os.makedirs(results_dir)
126
+
127
+ base_model_path = "runwayml/stable-diffusion-v1-5"
128
+ vae_model_path = "stabilityai/sd-vae-ft-mse"
129
+ image_encoder_path = "models/image_encoder/"
130
+
131
+
132
+ object_list = [
133
+ "cat",
134
+ # "woman",
135
+ # "dog",
136
+ # "horse",
137
+ # "motorcycle"
138
+ ]
139
+
140
+ target_object_list = [
141
+ # "Null",
142
+ "dog",
143
+ # "clock",
144
+ # "car"
145
+ # "panda",
146
+ # "bridge",
147
+ # "flower"
148
+ ]
149
+
150
+ prompt_neg_prompt_pair_dicts = {
151
+
152
+ # "line_art": ("line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
153
+ # "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic"
154
+ # ) ,
155
+
156
+ # "anime": ("anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
157
+ # "photo, deformed, black and white, realism, disfigured, low contrast"
158
+ # ),
159
+
160
+ # "Artstyle_Pop_Art" : ("pop Art style {prompt} . bright colors, bold outlines, popular culture themes, ironic or kitsch",
161
+ # "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, minimalist"
162
+ # ),
163
+
164
+ # "Artstyle_Pointillism": ("pointillism style {prompt} . composed entirely of small, distinct dots of color, vibrant, highly detailed",
165
+ # "line drawing, smooth shading, large color fields, simplistic"
166
+ # ),
167
+
168
+ # "origami": ("origami style {prompt} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
169
+ # "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
170
+ # ),
171
+
172
+ "craft_clay": ("play-doh style {prompt} . sculpture, clay art, centered composition, Claymation",
173
+ "sloppy, messy, grainy, highly detailed, ultra textured, photo"
174
+ ),
175
+
176
+ # "low_poly" : ("low-poly style {prompt} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
177
+ # "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
178
+ # ),
179
+
180
+ # "Artstyle_watercolor": ("watercolor painting {prompt} . vibrant, beautiful, painterly, detailed, textural, artistic",
181
+ # "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
182
+ # ),
183
+
184
+ # "Papercraft_Collage" : ("collage style {prompt} . mixed media, layered, textural, detailed, artistic",
185
+ # "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic"
186
+ # ),
187
+
188
+ # "Artstyle_Impressionist" : ("impressionist painting {prompt} . loose brushwork, vibrant color, light and shadow play, captures feeling over form",
189
+ # "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
190
+ # )
191
+
192
+ }
193
+
194
+
195
+
196
+ noise_scheduler = DDIMScheduler(
197
+ num_train_timesteps=1000,
198
+ beta_start=0.00085,
199
+ beta_end=0.012,
200
+ beta_schedule="scaled_linear",
201
+ clip_sample=False,
202
+ set_alpha_to_one=False,
203
+ steps_offset=1,
204
+ )
205
+
206
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
207
+ if device == 'cpu':
208
+ torch_dtype = torch.float32
209
+ else:
210
+ torch_dtype = torch.float16
211
+
212
+ vae = AutoencoderKL.from_pretrained(vae_model_path, torch_dtype=torch_dtype)
213
+ pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
214
+
215
+
216
+ memory_efficient(vae, device)
217
+ memory_efficient(pipe, device)
218
+
219
+ for seed in seeds:
220
+ for activate_layer_indices in activate_layer_indices_list:
221
+ attn_procs = {}
222
+ activate_layers = []
223
+ str_activate_layer = ""
224
+ for activate_layer_index in activate_layer_indices:
225
+ activate_layers += ACTIVATE_LAYER_CANDIDATE[activate_layer_index[0]:activate_layer_index[1]]
226
+ str_activate_layer += str(activate_layer_index)
227
+
228
+
229
+ for name in pipe.unet.attn_processors.keys():
230
+ if name in activate_layers:
231
+ if name in save_layer_list:
232
+ print(f"layer:{name}")
233
+ attn_procs[name] = CrossFrameAttnProcessor_store(unet_chunk_size=2, attn_map_save_steps=attn_map_save_steps)
234
+ else:
235
+ print(f"layer:{name}")
236
+ attn_procs[name] = CrossFrameAttnProcessor(unet_chunk_size=2)
237
+ else :
238
+ attn_procs[name] = AttnProcessor()
239
+ pipe.unet.set_attn_processor(attn_procs)
240
+
241
+
242
+ for target_object in target_object_list:
243
+ target_prompt = f"A photo of a {target_object}"
244
+
245
+ for object in object_list:
246
+ for key in prompt_neg_prompt_pair_dicts.keys():
247
+ prompt, negative_prompt = prompt_neg_prompt_pair_dicts[key]
248
+
249
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
250
+
251
+ images = pipe(
252
+ prompt=prompt.replace("{prompt}", object),
253
+ guidance_scale = 7.0,
254
+ num_images_per_prompt = num_images_per_prompt,
255
+ target_prompt = target_prompt,
256
+ generator=generator,
257
+
258
+ )[0]
259
+
260
+
261
+ #make grid
262
+ grid = create_image_grid(images, 1, num_images_per_prompt)
263
+
264
+ save_name = f"{key}_src_{object}_tgt_{target_object}_activate_layer_{str_activate_layer}_seed_{seed}.png"
265
+ save_path = os.path.join(results_dir, save_name)
266
+
267
+ grid.save(save_path)
268
+
269
+ print("Saved image to: ", save_path)
270
+
271
+ #save attn map
272
+ for attn_map_save_step in attn_map_save_steps:
273
+ attn_map_save_name = f"attn_map_raw_{key}_src_{object}_tgt_{target_object}_activate_layer_{str_activate_layer}_attn_map_step_{attn_map_save_step}_seed_{seed}.pt"
274
+ attn_map_dic = {}
275
+ # for activate_layer in activate_layers:
276
+ for activate_layer in save_layer_list:
277
+ attn_map_var_name = transform_variable_name(activate_layer, attn_map_save_step)
278
+ exec(f"attn_map_dic[\"{activate_layer}\"] = {attn_map_var_name}")
279
+
280
+ torch.save(attn_map_dic, os.path.join(results_dir, attn_map_save_name))
281
+ print("Saved attn map to: ", os.path.join(results_dir, attn_map_save_name))
282
+
283
+
visualize_attention_src/utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from PIL import Image
3
+
4
+ def get_image(image_path, row, col, image_size=1024, grid_width=1):
5
+
6
+ left_point = (image_size + grid_width) * col
7
+ up_point = (image_size + grid_width) * row
8
+ right_point = left_point + image_size
9
+ down_point = up_point + image_size
10
+
11
+ if type(image_path) is str:
12
+ image = Image.open(image_path)
13
+ else:
14
+ image = image_path
15
+ croped_image = image.crop((left_point, up_point, right_point, down_point))
16
+ return croped_image
17
+
18
+ def get_image_v2(image_path, row, col, image_size=1024, grid_row_space=1, grid_col_space=1):
19
+
20
+ left_point = (image_size + grid_col_space) * col
21
+ up_point = (image_size + grid_row_space) * row
22
+ right_point = left_point + image_size
23
+ down_point = up_point + image_size
24
+
25
+ if type(image_path) is str:
26
+ image = Image.open(image_path)
27
+ else:
28
+ image = image_path
29
+ croped_image = image.crop((left_point, up_point, right_point, down_point))
30
+ return croped_image
31
+
32
+ def create_image(row, col, image_size=1024, grid_width=1, background_color=(255,255,255), top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
33
+
34
+ image = Image.new('RGB', (image_size * col + grid_width * (col - 1) + left_padding , image_size * row + grid_width * (row - 1)), background_color)
35
+ return image
36
+
37
+ def paste_image(grid, image, row, col, image_size=1024, grid_width=1, top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
38
+ left_point = (image_size + grid_width) * col + left_padding
39
+ up_point = (image_size + grid_width) * row + top_padding
40
+ right_point = left_point + image_size
41
+ down_point = up_point + image_size
42
+ grid.paste(image, (left_point, up_point, right_point, down_point))
43
+
44
+ return grid
45
+
46
+ def paste_image_v2(grid, image, row, col, grid_size=1024, grid_width=1, top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
47
+ left_point = (grid_size + grid_width) * col + left_padding
48
+ up_point = (grid_size + grid_width) * row + top_padding
49
+
50
+ image_width, image_height = image.size
51
+
52
+ right_point = left_point + image_width
53
+ down_point = up_point + image_height
54
+
55
+ grid.paste(image, (left_point, up_point, right_point, down_point))
56
+
57
+ return grid
58
+
59
+
60
+ def pivot_figure(file_path, image_size=1024, grid_width=1):
61
+ if type(file_path) is str:
62
+ image = Image.open(file_path)
63
+ else:
64
+ image = file_path
65
+ image_col = image.width // image_size
66
+ image_row = image.height // image_size
67
+
68
+
69
+ grid = create_image(image_col, image_row, image_size, grid_width)
70
+
71
+ for row in range(image_row):
72
+ for col in range(image_col):
73
+ croped_image = get_image(image, row, col, image_size, grid_width)
74
+ grid = paste_image(grid, croped_image, col, row, image_size, grid_width)
75
+
76
+ return grid
77
+
78
+ def horizontal_flip_figure(file_path, image_size=1024, grid_width=1):
79
+ if type(file_path) is str:
80
+ image = Image.open(file_path)
81
+ else:
82
+ image = file_path
83
+ image_col = image.width // image_size
84
+ image_row = image.height // image_size
85
+
86
+ grid = create_image(image_row, image_col, image_size, grid_width)
87
+
88
+ for row in range(image_row):
89
+ for col in range(image_col):
90
+ croped_image = get_image(image, row, image_col - col - 1, image_size, grid_width)
91
+ grid = paste_image(grid, croped_image, row, col, image_size, grid_width)
92
+
93
+ return grid
94
+
95
+ def vertical_flip_figure(file_path, image_size=1024, grid_width=1):
96
+ if type(file_path) is str:
97
+ image = Image.open(file_path)
98
+ else:
99
+ image = file_path
100
+
101
+ image_col = image.width // image_size
102
+ image_row = image.height // image_size
103
+
104
+ grid = create_image(image_row, image_col, image_size, grid_width)
105
+
106
+ for row in range(image_row):
107
+ for col in range(image_col):
108
+ croped_image = get_image(image, image_row - row - 1, col, image_size, grid_width)
109
+ grid = paste_image(grid, croped_image, row, col, image_size, grid_width)
110
+
111
+ return grid
visualize_attention_src/visualize_attn_map_script.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from PIL import Image
4
+ import numpy as np
5
+ from ipycanvas import Canvas
6
+ import cv2
7
+
8
+ from visualize_attention_src.utils import get_image
9
+
10
+ exp_dir = "saved_attention_map_results"
11
+
12
+ style_name = "line_art"
13
+ src_name = "cat"
14
+ tgt_name = "dog"
15
+
16
+ steps = ["20"]
17
+ seed = "4"
18
+ saved_dtype = "tensor"
19
+
20
+
21
+ attn_map_raws = []
22
+ for step in steps:
23
+ attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}" # new
24
+
25
+ if saved_dtype == 'uint8':
26
+ attn_map_name = attn_map_name_wo_ext + '_uint8.npy'
27
+ attn_map_path = os.path.join(exp_dir, attn_map_name)
28
+ attn_map_raws.append(np.load(attn_map_path, allow_pickle=True))
29
+
30
+ else:
31
+ attn_map_name = attn_map_name_wo_ext + '.pt'
32
+ attn_map_path = os.path.join(exp_dir, attn_map_name)
33
+ attn_map_raws.append(torch.load(attn_map_path))
34
+ print(attn_map_path)
35
+
36
+ attn_map_path = os.path.join(exp_dir, attn_map_name)
37
+
38
+ print(f"{step} is on memory")
39
+
40
+ keys = [key for key in attn_map_raws[0].keys()]
41
+
42
+
43
+ print(len(keys))
44
+ key = keys[0]
45
+
46
+ ########################
47
+ tgt_idx = 3 # indicating the location of generated images.
48
+
49
+ attn_map_paired_rgb_grid_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png"
50
+
51
+ attn_map_paired_rgb_grid_path = os.path.join(exp_dir, attn_map_paired_rgb_grid_name)
52
+ print(attn_map_paired_rgb_grid_path)
53
+ attn_map_paired_rgb_grid = Image.open(attn_map_paired_rgb_grid_path)
54
+
55
+ attn_map_src_img = get_image(attn_map_paired_rgb_grid, row = 0, col = 0, image_size = 1024, grid_width = 10)
56
+ attn_map_tgt_img = get_image(attn_map_paired_rgb_grid, row = 0, col = tgt_idx, image_size = 1024, grid_width = 10)
57
+
58
+
59
+ h, w = 256, 256
60
+ num_of_grid = 64
61
+
62
+ plus_50 = 0
63
+
64
+ # key_idx_list = [0,2,4,6,8,10]
65
+ key_idx_list = [6, 28]
66
+ # (108 -> 0, 109 -> 1, ... , 140 -> 32)
67
+ # if Swapping Attentio nin (108, 140) layer , use key_idx_list = [6, 28].
68
+ # 6==early upblock, 28==late upblock
69
+
70
+ saved_attention_map_idx = [0]
71
+
72
+ source_image = attn_map_src_img
73
+ target_image = attn_map_tgt_img
74
+
75
+ # resize
76
+ source_image = source_image.resize((h, w))
77
+ target_image = target_image.resize((h, w))
78
+
79
+ # convert to numpy array
80
+ source_image = np.array(source_image)
81
+ target_image = np.array(target_image)
82
+
83
+ canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True)
84
+ canvas.put_image_data(source_image, w * 3, 0)
85
+ canvas.put_image_data(target_image, 0, 0)
86
+
87
+ canvas.put_image_data(source_image, w * 3, h)
88
+ canvas.put_image_data(target_image, 0, h)
89
+
90
+ # Display the canvas
91
+ # display(canvas)
92
+
93
+
94
+ def save_to_file(*args, **kwargs):
95
+ canvas.to_file("my_file1.png")
96
+
97
+
98
+ # Listen to changes on the ``image_data`` trait and call ``save_to_file`` when it changes.
99
+ canvas.observe(save_to_file, "image_data")
100
+
101
+
102
+ def on_click(x, y):
103
+ cnt = 0
104
+ canvas.put_image_data(target_image, 0, 0)
105
+
106
+ print(x, y)
107
+ # draw a point
108
+ canvas.fill_style = 'red'
109
+ canvas.fill_circle(x, y, 4)
110
+
111
+ for step_i, step in enumerate(range(len(saved_attention_map_idx))):
112
+
113
+ attn_map_raw = attn_map_raws[step_i]
114
+
115
+ for key_i, key_idx in enumerate(key_idx_list):
116
+ key = keys[key_idx]
117
+
118
+ num_of_grid = int(attn_map_raw[key].shape[-1] ** (0.5))
119
+
120
+ # normalize x,y
121
+ grid_x_idx = int(x / (w / num_of_grid))
122
+ grid_y_idx = int(y / (h / num_of_grid))
123
+
124
+ print(grid_x_idx, grid_y_idx)
125
+
126
+ grid_idx = grid_x_idx + grid_y_idx * num_of_grid
127
+
128
+ attn_map = attn_map_raw[key][tgt_idx * 10:10 + tgt_idx * 10, grid_idx, :]
129
+
130
+ attn_map = attn_map.sum(dim=0)
131
+
132
+ attn_map = attn_map.reshape(num_of_grid, num_of_grid)
133
+
134
+ # process attn_map to pil
135
+ attn_map = attn_map.detach().cpu().numpy()
136
+ # attn_map = attn_map / attn_map.max()
137
+ # normalized_attn_map = attn_map
138
+ normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
139
+ normalized_attn_map = 1.0 - normalized_attn_map
140
+
141
+ heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET)
142
+ heatmap = cv2.resize(heatmap, (w, h))
143
+
144
+ attn_map = normalized_attn_map * 255
145
+
146
+ attn_map = attn_map.astype(np.uint8)
147
+
148
+ attn_map = cv2.cvtColor(attn_map, cv2.COLOR_GRAY2RGB)
149
+ # attn_map = cv2.cvtColor(attn_map, cv2.COLORMAP_JET)
150
+ attn_map = cv2.resize(attn_map, (w, h))
151
+
152
+ # draw attn_map
153
+ canvas.put_image_data(attn_map, w + step_i * 4 * w, h * key_i)
154
+ # canvas.put_image_data(attn_map, w , h*key_i)
155
+
156
+ # blend attn_map and target image
157
+ alpha = 0.85
158
+ blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0)
159
+
160
+ # draw blended image
161
+ canvas.put_image_data(blended_image, w * 2 + step_i * 4 * w, h * key_i)
162
+
163
+ cnt += 1
164
+
165
+ # Attach the event handler to the canvas
166
+
167
+
168
+ canvas.on_mouse_down(on_click)