songweig commited on
Commit
99e3c03
1 Parent(s): ea48617

udpate to sdxl

Browse files
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  import numpy as np
9
  from torchvision import transforms
10
 
11
- from models.region_diffusion import RegionDiffusion
12
  from utils.attention_utils import get_token_maps
13
  from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
  get_attention_control_input, get_gradient_guidance_input
@@ -61,7 +61,7 @@ def load_url_params(url_params):
61
 
62
  def main():
63
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
- model = RegionDiffusion(device)
65
 
66
  def generate(
67
  text_input: str,
@@ -81,8 +81,8 @@ def main():
81
  run_dir = 'results/'
82
  os.makedirs(run_dir, exist_ok=True)
83
  # Load region diffusion model.
84
- height = int(height) if height else 512
85
- width = int(width) if width else 512
86
  steps = 41 if not steps else steps
87
  guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
  text_input = rich_text_input if rich_text_input != '' and rich_text_input != None else text_input
@@ -117,19 +117,19 @@ def main():
117
  else:
118
  model.reset_attention_maps()
119
  model.remove_tokenmap_hooks()
120
- plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
121
- height=height, width=width, num_inference_steps=steps,
122
- guidance_scale=guidance_weight)
123
  print('time lapses to get attention maps: %.4f' %
124
  (time.time()-begin_time))
125
  seed_everything(seed)
126
  color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
- 512//8, 512//8, color_target_token_ids[:-1], seed,
128
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
  return_vis=True)
130
  seed_everything(seed)
131
  model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
- 512//8, 512//8, region_target_token_ids[:-1], seed,
133
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
  return_vis=True)
135
  color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
@@ -146,14 +146,14 @@ def main():
146
  # generate image from rich text
147
  begin_time = time.time()
148
  seed_everything(seed)
149
- rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
- height=height, width=width, num_inference_steps=steps,
151
- guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
- text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
- inject_background=inject_background)
154
  print('time lapses to generate image from rich text: %.4f' %
155
  (time.time()-begin_time))
156
- return [plain_img[0], rich_img[0], segments_vis, token_maps]
157
 
158
  with gr.Blocks(css=css) as demo:
159
  url_params = gr.JSON({}, visible=False, label="URL Params")
@@ -226,12 +226,12 @@ def main():
226
  maximum=50,
227
  step=0.1,
228
  value=8.5)
229
- width = gr.Dropdown(choices=[512],
230
- value=512,
231
  label='Width',
232
  visible=True)
233
- height = gr.Dropdown(choices=[512],
234
- value=512,
235
  label='height',
236
  visible=True)
237
 
@@ -243,7 +243,7 @@ def main():
243
  with gr.Column():
244
  richtext_result = gr.Image(
245
  label='Rich-text', elem_id="rich-text-image")
246
- richtext_result.style(height=512)
247
  with gr.Row():
248
  plaintext_result = gr.Image(
249
  label='Plain-text', elem_id="plain-text-image")
@@ -265,22 +265,22 @@ def main():
265
  [
266
  '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
267
  '',
268
- 5,
 
269
  0.3,
270
- 0,
271
  0.5,
272
- 6,
273
  0,
274
  None,
275
  ],
276
  [
277
- '{"ops":[{"insert":"A "},{"attributes":{"link":"Thor Kitchen 30 Inch Wide Freestanding Gas Range with Automatic Re-Ignition System"},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
278
  '',
279
- 7,
280
- 0.5,
281
- 0,
282
  0.5,
283
- 6,
284
  0,
285
  None,
286
  ],
@@ -325,7 +325,7 @@ def main():
325
  '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
  'lowres, had anatomy, bad hands, cropped, worst quality',
327
  11,
328
- 0.3,
329
  0.3,
330
  0.3,
331
  6,
@@ -333,10 +333,10 @@ def main():
333
  None,
334
  ],
335
  [
336
- '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
  'lowres, had anatomy, bad hands, cropped, worst quality',
338
  11,
339
- 0.3,
340
  0.3,
341
  0.3,
342
  6,
@@ -344,36 +344,25 @@ def main():
344
  None,
345
  ],
346
  [
347
- '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
348
- '',
349
- 10,
350
- 0.4,
351
  0.5,
352
  0.3,
 
353
  6,
354
  0.5,
355
  None,
356
  ],
357
  [
358
- '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
359
- '',
360
- 3,
361
- 0.3,
362
- 0,
363
- 0,
364
- 9,
365
- 1,
366
- None,
367
- ],
368
- [
369
- '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
370
  '',
371
- 5,
372
- 0.4,
373
- 0.3,
374
  0.3,
375
- 5,
376
- 0.6,
377
  None,
378
  ],
379
  ]
@@ -403,21 +392,21 @@ def main():
403
  with gr.Row():
404
  style_examples = [
405
  [
406
- '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
407
  '',
408
  10,
409
- 0.4,
410
  0,
411
- 0.2,
412
- 3,
413
  0,
414
  None,
415
  ],
416
  [
417
  '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
418
  'worst quality, dark, poor quality',
419
- 5,
420
- 0.3,
421
  0,
422
  0,
423
  9,
@@ -425,10 +414,10 @@ def main():
425
  None,
426
  ],
427
  [
428
- '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
429
  '',
430
  2,
431
- 0.35,
432
  0,
433
  0,
434
  6,
@@ -462,35 +451,35 @@ def main():
462
  with gr.Row():
463
  size_examples = [
464
  [
465
- '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
466
- 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
467
  5,
468
  0.3,
469
  0,
470
  0,
471
- 13,
472
  1,
473
  None,
474
  ],
475
  [
476
- '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
477
- 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
478
  5,
479
  0.3,
480
  0,
481
  0,
482
- 13,
483
  1,
484
  None,
485
  ],
486
  [
487
- '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
488
- 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
489
  5,
490
  0.3,
491
  0,
492
  0,
493
- 13,
494
  1,
495
  None,
496
  ],
 
8
  import numpy as np
9
  from torchvision import transforms
10
 
11
+ from models.region_diffusion_xl import RegionDiffusionXL
12
  from utils.attention_utils import get_token_maps
13
  from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
  get_attention_control_input, get_gradient_guidance_input
 
61
 
62
  def main():
63
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model = RegionDiffusionXL()
65
 
66
  def generate(
67
  text_input: str,
 
81
  run_dir = 'results/'
82
  os.makedirs(run_dir, exist_ok=True)
83
  # Load region diffusion model.
84
+ height = int(height) if height else 1024
85
+ width = int(width) if width else 1024
86
  steps = 41 if not steps else steps
87
  guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
  text_input = rich_text_input if rich_text_input != '' and rich_text_input != None else text_input
 
117
  else:
118
  model.reset_attention_maps()
119
  model.remove_tokenmap_hooks()
120
+ plain_img = model.sample([base_text_prompt], negative_prompt=[negative_text],
121
+ height=height, width=width, num_inference_steps=steps,
122
+ guidance_scale=guidance_weight, run_rich_text=False)
123
  print('time lapses to get attention maps: %.4f' %
124
  (time.time()-begin_time))
125
  seed_everything(seed)
126
  color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
+ 1024//8, 1024//8, color_target_token_ids[:-1], seed,
128
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
  return_vis=True)
130
  seed_everything(seed)
131
  model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
+ 1024//8, 1024//8, region_target_token_ids[:-1], seed,
133
  base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
  return_vis=True)
135
  color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
 
146
  # generate image from rich text
147
  begin_time = time.time()
148
  seed_everything(seed)
149
+ rich_img = model.sample(region_text_prompts, negative_prompt=[negative_text],
150
+ height=height, width=width, num_inference_steps=steps,
151
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ inject_background=inject_background, run_rich_text=True)
154
  print('time lapses to generate image from rich text: %.4f' %
155
  (time.time()-begin_time))
156
+ return [plain_img.images[0], rich_img.images[0], segments_vis, token_maps]
157
 
158
  with gr.Blocks(css=css) as demo:
159
  url_params = gr.JSON({}, visible=False, label="URL Params")
 
226
  maximum=50,
227
  step=0.1,
228
  value=8.5)
229
+ width = gr.Dropdown(choices=[1024],
230
+ value=1024,
231
  label='Width',
232
  visible=True)
233
+ height = gr.Dropdown(choices=[1024],
234
+ value=1024,
235
  label='height',
236
  visible=True)
237
 
 
243
  with gr.Column():
244
  richtext_result = gr.Image(
245
  label='Rich-text', elem_id="rich-text-image")
246
+ richtext_result.style(height=1024)
247
  with gr.Row():
248
  plaintext_result = gr.Image(
249
  label='Plain-text', elem_id="plain-text-image")
 
265
  [
266
  '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
267
  '',
268
+ 9,
269
+ 0.3,
270
  0.3,
 
271
  0.5,
272
+ 3,
273
  0,
274
  None,
275
  ],
276
  [
277
+ '{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard.\n"}]}',
278
  '',
279
+ 12,
280
+ 0.4,
281
+ 0.3,
282
  0.5,
283
+ 4,
284
  0,
285
  None,
286
  ],
 
325
  '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
  'lowres, had anatomy, bad hands, cropped, worst quality',
327
  11,
328
+ 0.5,
329
  0.3,
330
  0.3,
331
  6,
 
333
  None,
334
  ],
335
  [
336
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#ff5df1"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
  'lowres, had anatomy, bad hands, cropped, worst quality',
338
  11,
339
+ 0.5,
340
  0.3,
341
  0.3,
342
  6,
 
344
  None,
345
  ],
346
  [
347
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
348
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
349
+ 11,
 
350
  0.5,
351
  0.3,
352
+ 0.3,
353
  6,
354
  0.5,
355
  None,
356
  ],
357
  [
358
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
 
 
 
 
 
 
 
 
 
 
 
359
  '',
360
+ 10,
361
+ 0.5,
362
+ 0.5,
363
  0.3,
364
+ 7,
365
+ 0.5,
366
  None,
367
  ],
368
  ]
 
392
  with gr.Row():
393
  style_examples = [
394
  [
395
+ '{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
396
  '',
397
  10,
398
+ 0.6,
399
  0,
400
+ 0.4,
401
+ 5,
402
  0,
403
  None,
404
  ],
405
  [
406
  '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
407
  'worst quality, dark, poor quality',
408
+ 2,
409
+ 0.45,
410
  0,
411
  0,
412
  9,
 
414
  None,
415
  ],
416
  [
417
+ '{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
418
  '',
419
  2,
420
+ 0.6,
421
  0,
422
  0,
423
  6,
 
451
  with gr.Row():
452
  size_examples = [
453
  [
454
+ '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
455
+ '',
456
  5,
457
  0.3,
458
  0,
459
  0,
460
+ 3,
461
  1,
462
  None,
463
  ],
464
  [
465
+ '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
466
+ '',
467
  5,
468
  0.3,
469
  0,
470
  0,
471
+ 3,
472
  1,
473
  None,
474
  ],
475
  [
476
+ '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
477
+ '',
478
  5,
479
  0.3,
480
  0,
481
  0,
482
+ 3,
483
  1,
484
  None,
485
  ],
app_sd.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import transforms
10
+
11
+ from models.region_diffusion import RegionDiffusion
12
+ from utils.attention_utils import get_token_maps
13
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
+ get_attention_control_input, get_gradient_guidance_input
15
+
16
+
17
+ import gradio as gr
18
+ from PIL import Image, ImageOps
19
+ from share_btn import community_icon_html, loading_icon_html, share_js, css
20
+
21
+
22
+ help_text = """
23
+ If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
+ 1. If you format only a portion of a word rather than the complete word, an error may occur.
25
+ 2. If you use font color and get completely corrupted results, you may consider decrease the color weight lambda.
26
+ 3. Consider using a different seed.
27
+ """
28
+
29
+
30
+ canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
+ get_js_data = """
32
+ async (text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, rich_text_input, background_aug) => {
33
+ const richEl = document.getElementById("rich-text-root");
34
+ const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, JSON.stringify(data), background_aug];
36
+ }
37
+ """
38
+ set_js_data = """
39
+ async (text_input) => {
40
+ const richEl = document.getElementById("rich-text-root");
41
+ const data = text_input ? JSON.parse(text_input) : null;
42
+ if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
43
+ }
44
+ """
45
+
46
+ get_window_url_params = """
47
+ async (url_params) => {
48
+ const params = new URLSearchParams(window.location.search);
49
+ url_params = Object.fromEntries(params);
50
+ return [url_params];
51
+ }
52
+ """
53
+
54
+
55
+ def load_url_params(url_params):
56
+ if 'prompt' in url_params:
57
+ return gr.update(visible=True), url_params
58
+ else:
59
+ return gr.update(visible=False), url_params
60
+
61
+
62
+ def main():
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model = RegionDiffusion(device)
65
+
66
+ def generate(
67
+ text_input: str,
68
+ negative_text: str,
69
+ height: int,
70
+ width: int,
71
+ seed: int,
72
+ steps: int,
73
+ num_segments: int,
74
+ segment_threshold: float,
75
+ inject_interval: float,
76
+ guidance_weight: float,
77
+ color_guidance_weight: float,
78
+ rich_text_input: str,
79
+ background_aug: bool,
80
+ ):
81
+ run_dir = 'results/'
82
+ os.makedirs(run_dir, exist_ok=True)
83
+ # Load region diffusion model.
84
+ height = int(height)
85
+ width = int(width)
86
+ steps = 41 if not steps else steps
87
+ guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
+ text_input = rich_text_input if rich_text_input != '' else text_input
89
+ print('text_input', text_input)
90
+ if (text_input == '' or rich_text_input == ''):
91
+ raise gr.Error("Please enter some text.")
92
+ # parse json to span attributes
93
+ base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
94
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
95
+ json.loads(text_input))
96
+
97
+ # create control input for region diffusion
98
+ region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
99
+ model, base_text_prompt, style_text_prompts, footnote_text_prompts,
100
+ footnote_target_tokens, color_text_prompts, color_names)
101
+
102
+ # create control input for cross attention
103
+ text_format_dict = get_attention_control_input(
104
+ model, base_tokens, size_text_prompts_and_sizes)
105
+
106
+ # create control input for region guidance
107
+ text_format_dict, color_target_token_ids = get_gradient_guidance_input(
108
+ model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
109
+
110
+ seed_everything(seed)
111
+
112
+ # get token maps from plain text to image generation.
113
+ begin_time = time.time()
114
+ if model.selfattn_maps is None and model.crossattn_maps is None:
115
+ model.remove_tokenmap_hooks()
116
+ model.register_tokenmap_hooks()
117
+ else:
118
+ model.reset_attention_maps()
119
+ model.remove_tokenmap_hooks()
120
+ plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
121
+ height=height, width=width, num_inference_steps=steps,
122
+ guidance_scale=guidance_weight)
123
+ print('time lapses to get attention maps: %.4f' %
124
+ (time.time()-begin_time))
125
+ seed_everything(seed)
126
+ color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
+ 512//8, 512//8, color_target_token_ids[:-1], seed,
128
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
+ return_vis=True)
130
+ seed_everything(seed)
131
+ model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
+ 512//8, 512//8, region_target_token_ids[:-1], seed,
133
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
+ return_vis=True)
135
+ color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
136
+ interpolation=transforms.InterpolationMode.BICUBIC,
137
+ antialias=True)
138
+ for color_obj_mask in color_obj_masks]
139
+ text_format_dict['color_obj_atten'] = color_obj_masks
140
+ model.remove_tokenmap_hooks()
141
+
142
+ # generate image from rich text
143
+ begin_time = time.time()
144
+ seed_everything(seed)
145
+ if background_aug:
146
+ bg_aug_end = 500
147
+ else:
148
+ bg_aug_end = 1000
149
+ rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
+ height=height, width=width, num_inference_steps=steps,
151
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ bg_aug_end=bg_aug_end)
154
+ print('time lapses to generate image from rich text: %.4f' %
155
+ (time.time()-begin_time))
156
+ return [plain_img[0], rich_img[0], segments_vis, token_maps]
157
+
158
+ with gr.Blocks(css=css) as demo:
159
+ url_params = gr.JSON({}, visible=False, label="URL Params")
160
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
161
+ <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
162
+ <p> UMD, Adobe, CMU <p/>
163
+ <p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
164
+ <p> For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
165
+ with gr.Row():
166
+ with gr.Column():
167
+ rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
168
+ rich_text_input = gr.Textbox(value="", visible=False)
169
+ text_input = gr.Textbox(
170
+ label='Rich-text JSON Input',
171
+ visible=False,
172
+ max_lines=1,
173
+ placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
174
+ elem_id="text_input"
175
+ )
176
+ negative_prompt = gr.Textbox(
177
+ label='Negative Prompt',
178
+ max_lines=1,
179
+ placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
180
+ elem_id="negative_prompt"
181
+ )
182
+ segment_threshold = gr.Slider(label='Token map threshold',
183
+ info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
184
+ minimum=0,
185
+ maximum=1,
186
+ step=0.01,
187
+ value=0.25)
188
+ inject_interval = gr.Slider(label='Detail preservation',
189
+ info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
190
+ minimum=0,
191
+ maximum=1,
192
+ step=0.01,
193
+ value=0.)
194
+ color_guidance_weight = gr.Slider(label='Color weight',
195
+ info='(To obtain more precise color, increase this, while too large value may cause artifacts.)',
196
+ minimum=0,
197
+ maximum=2,
198
+ step=0.1,
199
+ value=0.5)
200
+ num_segments = gr.Slider(label='Number of segments',
201
+ minimum=2,
202
+ maximum=20,
203
+ step=1,
204
+ value=9)
205
+ seed = gr.Slider(label='Seed',
206
+ minimum=0,
207
+ maximum=100000,
208
+ step=1,
209
+ value=6,
210
+ elem_id="seed"
211
+ )
212
+ background_aug = gr.Checkbox(
213
+ label='Precise region alignment',
214
+ info='(For strict region alignment, select this option, but beware of potential artifacts when using with style.)',
215
+ value=True)
216
+ with gr.Accordion('Other Parameters', open=False):
217
+ steps = gr.Slider(label='Number of Steps',
218
+ minimum=0,
219
+ maximum=500,
220
+ step=1,
221
+ value=41)
222
+ guidance_weight = gr.Slider(label='CFG weight',
223
+ minimum=0,
224
+ maximum=50,
225
+ step=0.1,
226
+ value=8.5)
227
+ width = gr.Dropdown(choices=[512],
228
+ value=512,
229
+ label='Width',
230
+ visible=True)
231
+ height = gr.Dropdown(choices=[512],
232
+ value=512,
233
+ label='height',
234
+ visible=True)
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=1, min_width=100):
238
+ generate_button = gr.Button("Generate")
239
+ load_params_button = gr.Button(
240
+ "Load from URL Params", visible=True)
241
+ with gr.Column():
242
+ richtext_result = gr.Image(
243
+ label='Rich-text', elem_id="rich-text-image")
244
+ richtext_result.style(height=512)
245
+ with gr.Row():
246
+ plaintext_result = gr.Image(
247
+ label='Plain-text', elem_id="plain-text-image")
248
+ segments = gr.Image(label='Segmentation')
249
+ with gr.Row():
250
+ token_map = gr.Image(label='Token Maps')
251
+ with gr.Row(visible=False) as share_row:
252
+ with gr.Group(elem_id="share-btn-container"):
253
+ community_icon = gr.HTML(community_icon_html)
254
+ loading_icon = gr.HTML(loading_icon_html)
255
+ share_button = gr.Button(
256
+ "Share to community", elem_id="share-btn")
257
+ share_button.click(None, [], [], _js=share_js)
258
+ with gr.Row():
259
+ gr.Markdown(help_text)
260
+
261
+ with gr.Row():
262
+ footnote_examples = [
263
+ [
264
+ '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
265
+ '',
266
+ 5,
267
+ 0.3,
268
+ 0,
269
+ 6,
270
+ 1,
271
+ None,
272
+ True
273
+ ],
274
+ [
275
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a stove with gas burners and a built-in oven "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
276
+ '',
277
+ 6,
278
+ 0.5,
279
+ 0,
280
+ 6,
281
+ 1,
282
+ None,
283
+ True
284
+ ],
285
+ [
286
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
287
+ '',
288
+ 4,
289
+ 0.3,
290
+ 0,
291
+ 4,
292
+ 1,
293
+ None,
294
+ True
295
+ ],
296
+ ]
297
+
298
+ gr.Examples(examples=footnote_examples,
299
+ label='Footnote examples',
300
+ inputs=[
301
+ text_input,
302
+ negative_prompt,
303
+ num_segments,
304
+ segment_threshold,
305
+ inject_interval,
306
+ seed,
307
+ color_guidance_weight,
308
+ rich_text_input,
309
+ background_aug,
310
+ ],
311
+ outputs=[
312
+ plaintext_result,
313
+ richtext_result,
314
+ segments,
315
+ token_map,
316
+ ],
317
+ fn=generate,
318
+ # cache_examples=True,
319
+ examples_per_page=20)
320
+ with gr.Row():
321
+ color_examples = [
322
+ [
323
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#00ffff"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
324
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
325
+ 9,
326
+ 0.25,
327
+ 0.3,
328
+ 6,
329
+ 0.5,
330
+ None,
331
+ True
332
+ ],
333
+ [
334
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#eeeeee"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
335
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
336
+ 9,
337
+ 0.25,
338
+ 0.3,
339
+ 6,
340
+ 0.1,
341
+ None,
342
+ True
343
+ ],
344
+ [
345
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
346
+ '',
347
+ 5,
348
+ 0.3,
349
+ 0.5,
350
+ 6,
351
+ 0.5,
352
+ None,
353
+ False
354
+ ],
355
+ [
356
+ '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
357
+ '',
358
+ 3,
359
+ 0.3,
360
+ 0,
361
+ 9,
362
+ 1,
363
+ None,
364
+ False
365
+ ],
366
+ [
367
+ '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
368
+ '',
369
+ 5,
370
+ 0.3,
371
+ 0,
372
+ 5,
373
+ 0.6,
374
+ None,
375
+ False
376
+ ],
377
+ ]
378
+ gr.Examples(examples=color_examples,
379
+ label='Font color examples',
380
+ inputs=[
381
+ text_input,
382
+ negative_prompt,
383
+ num_segments,
384
+ segment_threshold,
385
+ inject_interval,
386
+ seed,
387
+ color_guidance_weight,
388
+ rich_text_input,
389
+ background_aug,
390
+ ],
391
+ outputs=[
392
+ plaintext_result,
393
+ richtext_result,
394
+ segments,
395
+ token_map,
396
+ ],
397
+ fn=generate,
398
+ # cache_examples=True,
399
+ examples_per_page=20)
400
+
401
+ with gr.Row():
402
+ style_examples = [
403
+ [
404
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
405
+ '',
406
+ 10,
407
+ 0.45,
408
+ 0,
409
+ 0.2,
410
+ 3,
411
+ 0.5,
412
+ None,
413
+ False
414
+ ],
415
+ [
416
+ '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
417
+ 'worst quality, dark, poor quality',
418
+ 2,
419
+ 0.45,
420
+ 0,
421
+ 9,
422
+ 0.5,
423
+ None,
424
+ False
425
+ ],
426
+ [
427
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
428
+ '',
429
+ 2,
430
+ 0.45,
431
+ 0,
432
+ 0,
433
+ 6,
434
+ 0.5,
435
+ None,
436
+ False
437
+ ],
438
+ ]
439
+ gr.Examples(examples=style_examples,
440
+ label='Font style examples',
441
+ inputs=[
442
+ text_input,
443
+ negative_prompt,
444
+ num_segments,
445
+ segment_threshold,
446
+ inject_interval,
447
+ seed,
448
+ color_guidance_weight,
449
+ rich_text_input,
450
+ background_aug,
451
+ ],
452
+ outputs=[
453
+ plaintext_result,
454
+ richtext_result,
455
+ segments,
456
+ token_map,
457
+ ],
458
+ fn=generate,
459
+ # cache_examples=True,
460
+ examples_per_page=20)
461
+
462
+ with gr.Row():
463
+ size_examples = [
464
+ [
465
+ '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
466
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
467
+ 5,
468
+ 0.3,
469
+ 0,
470
+ 13,
471
+ 1,
472
+ None,
473
+ False
474
+ ],
475
+ [
476
+ '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
477
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
478
+ 5,
479
+ 0.3,
480
+ 0,
481
+ 13,
482
+ 1,
483
+ None,
484
+ False
485
+ ],
486
+ [
487
+ '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
488
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
489
+ 5,
490
+ 0.3,
491
+ 0,
492
+ 13,
493
+ 1,
494
+ None,
495
+ False
496
+ ],
497
+ ]
498
+ gr.Examples(examples=size_examples,
499
+ label='Font size examples',
500
+ inputs=[
501
+ text_input,
502
+ negative_prompt,
503
+ num_segments,
504
+ segment_threshold,
505
+ inject_interval,
506
+ seed,
507
+ color_guidance_weight,
508
+ rich_text_input,
509
+ background_aug,
510
+ ],
511
+ outputs=[
512
+ plaintext_result,
513
+ richtext_result,
514
+ segments,
515
+ token_map,
516
+ ],
517
+ fn=generate,
518
+ # cache_examples=True,
519
+ examples_per_page=20)
520
+ generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
521
+ fn=generate,
522
+ inputs=[
523
+ text_input,
524
+ negative_prompt,
525
+ height,
526
+ width,
527
+ seed,
528
+ steps,
529
+ num_segments,
530
+ segment_threshold,
531
+ inject_interval,
532
+ guidance_weight,
533
+ color_guidance_weight,
534
+ rich_text_input,
535
+ background_aug
536
+ ],
537
+ outputs=[plaintext_result, richtext_result, segments, token_map],
538
+ _js=get_js_data
539
+ ).then(
540
+ fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
541
+ text_input.change(
542
+ fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
543
+ # load url param prompt to textinput
544
+ load_params_button.click(fn=lambda x: x['prompt'], inputs=[
545
+ url_params], outputs=[text_input], queue=False)
546
+ demo.load(
547
+ fn=load_url_params,
548
+ inputs=[url_params],
549
+ outputs=[load_params_button, url_params],
550
+ _js=get_window_url_params
551
+ )
552
+ demo.queue(concurrency_count=1)
553
+ demo.launch(share=False)
554
+
555
+
556
+ if __name__ == "__main__":
557
+ main()
models/attention.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -11,378 +11,19 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- import math
15
- import warnings
16
- from dataclasses import dataclass
17
- from typing import Optional
18
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch import nn
22
 
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.models.modeling_utils import ModelMixin
25
- from diffusers.models.embeddings import ImagePositionalEmbeddings
26
- from diffusers.utils import BaseOutput
27
- from diffusers.utils.import_utils import is_xformers_available
28
-
29
-
30
- @dataclass
31
- class Transformer2DModelOutput(BaseOutput):
32
- """
33
- Args:
34
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
35
- Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
36
- for the unnoised latent pixels.
37
- """
38
-
39
- sample: torch.FloatTensor
40
-
41
-
42
- if is_xformers_available():
43
- import xformers
44
- import xformers.ops
45
- else:
46
- xformers = None
47
-
48
-
49
- class Transformer2DModel(ModelMixin, ConfigMixin):
50
- """
51
- Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
52
- embeddings) inputs.
53
-
54
- When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
55
- transformer action. Finally, reshape to image.
56
-
57
- When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
58
- embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
59
- classes of unnoised image.
60
-
61
- Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
62
- image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
63
-
64
- Parameters:
65
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
66
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
67
- in_channels (`int`, *optional*):
68
- Pass if the input is continuous. The number of channels in the input and output.
69
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
70
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
71
- cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
72
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
73
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
74
- `ImagePositionalEmbeddings`.
75
- num_vector_embeds (`int`, *optional*):
76
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
77
- Includes the class for the masked latent pixel.
78
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
79
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
80
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
81
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
82
- up to but not more than steps than `num_embeds_ada_norm`.
83
- attention_bias (`bool`, *optional*):
84
- Configure if the TransformerBlocks' attention should contain a bias parameter.
85
- """
86
-
87
- @register_to_config
88
- def __init__(
89
- self,
90
- num_attention_heads: int = 16,
91
- attention_head_dim: int = 88,
92
- in_channels: Optional[int] = None,
93
- num_layers: int = 1,
94
- dropout: float = 0.0,
95
- norm_num_groups: int = 32,
96
- cross_attention_dim: Optional[int] = None,
97
- attention_bias: bool = False,
98
- sample_size: Optional[int] = None,
99
- num_vector_embeds: Optional[int] = None,
100
- activation_fn: str = "geglu",
101
- num_embeds_ada_norm: Optional[int] = None,
102
- use_linear_projection: bool = False,
103
- only_cross_attention: bool = False,
104
- ):
105
- super().__init__()
106
- self.use_linear_projection = use_linear_projection
107
- self.num_attention_heads = num_attention_heads
108
- self.attention_head_dim = attention_head_dim
109
- inner_dim = num_attention_heads * attention_head_dim
110
-
111
- # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
112
- # Define whether input is continuous or discrete depending on configuration
113
- self.is_input_continuous = in_channels is not None
114
- self.is_input_vectorized = num_vector_embeds is not None
115
-
116
- if self.is_input_continuous and self.is_input_vectorized:
117
- raise ValueError(
118
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
119
- " sure that either `in_channels` or `num_vector_embeds` is None."
120
- )
121
- elif not self.is_input_continuous and not self.is_input_vectorized:
122
- raise ValueError(
123
- f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
124
- " sure that either `in_channels` or `num_vector_embeds` is not None."
125
- )
126
-
127
- # 2. Define input layers
128
- if self.is_input_continuous:
129
- self.in_channels = in_channels
130
-
131
- self.norm = torch.nn.GroupNorm(
132
- num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
133
- if use_linear_projection:
134
- self.proj_in = nn.Linear(in_channels, inner_dim)
135
- else:
136
- self.proj_in = nn.Conv2d(
137
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
138
- elif self.is_input_vectorized:
139
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
140
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
141
-
142
- self.height = sample_size
143
- self.width = sample_size
144
- self.num_vector_embeds = num_vector_embeds
145
- self.num_latent_pixels = self.height * self.width
146
-
147
- self.latent_image_embedding = ImagePositionalEmbeddings(
148
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
149
- )
150
-
151
- # 3. Define transformers blocks
152
- self.transformer_blocks = nn.ModuleList(
153
- [
154
- BasicTransformerBlock(
155
- inner_dim,
156
- num_attention_heads,
157
- attention_head_dim,
158
- dropout=dropout,
159
- cross_attention_dim=cross_attention_dim,
160
- activation_fn=activation_fn,
161
- num_embeds_ada_norm=num_embeds_ada_norm,
162
- attention_bias=attention_bias,
163
- only_cross_attention=only_cross_attention,
164
- )
165
- for d in range(num_layers)
166
- ]
167
- )
168
-
169
- # 4. Define output layers
170
- if self.is_input_continuous:
171
- if use_linear_projection:
172
- self.proj_out = nn.Linear(in_channels, inner_dim)
173
- else:
174
- self.proj_out = nn.Conv2d(
175
- inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
176
- elif self.is_input_vectorized:
177
- self.norm_out = nn.LayerNorm(inner_dim)
178
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
179
-
180
- def _set_attention_slice(self, slice_size):
181
- for block in self.transformer_blocks:
182
- block._set_attention_slice(slice_size)
183
-
184
- def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
185
- text_format_dict={}, return_dict: bool = True):
186
- """
187
- Args:
188
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
189
- When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
190
- hidden_states
191
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
192
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
193
- self-attention.
194
- timestep ( `torch.long`, *optional*):
195
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
196
- return_dict (`bool`, *optional*, defaults to `True`):
197
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
198
-
199
- Returns:
200
- [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
201
- if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
202
- tensor.
203
- """
204
- # 1. Input
205
- if self.is_input_continuous:
206
- batch, channel, height, weight = hidden_states.shape
207
- residual = hidden_states
208
-
209
- hidden_states = self.norm(hidden_states)
210
- if not self.use_linear_projection:
211
- hidden_states = self.proj_in(hidden_states)
212
- inner_dim = hidden_states.shape[1]
213
- hidden_states = hidden_states.permute(
214
- 0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
215
- else:
216
- inner_dim = hidden_states.shape[1]
217
- hidden_states = hidden_states.permute(
218
- 0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
219
- hidden_states = self.proj_in(hidden_states)
220
- elif self.is_input_vectorized:
221
- hidden_states = self.latent_image_embedding(hidden_states)
222
-
223
- # 2. Blocks
224
- for block in self.transformer_blocks:
225
- hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep,
226
- text_format_dict=text_format_dict)
227
-
228
- # 3. Output
229
- if self.is_input_continuous:
230
- if not self.use_linear_projection:
231
- hidden_states = (
232
- hidden_states.reshape(batch, height, weight, inner_dim).permute(
233
- 0, 3, 1, 2).contiguous()
234
- )
235
- hidden_states = self.proj_out(hidden_states)
236
- else:
237
- hidden_states = self.proj_out(hidden_states)
238
- hidden_states = (
239
- hidden_states.reshape(batch, height, weight, inner_dim).permute(
240
- 0, 3, 1, 2).contiguous()
241
- )
242
-
243
- output = hidden_states + residual
244
- elif self.is_input_vectorized:
245
- hidden_states = self.norm_out(hidden_states)
246
- logits = self.out(hidden_states)
247
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
248
- logits = logits.permute(0, 2, 1)
249
-
250
- # log(p(x_0))
251
- output = F.log_softmax(logits.double(), dim=1).float()
252
-
253
- if not return_dict:
254
- return (output,)
255
-
256
- return Transformer2DModelOutput(sample=output)
257
-
258
- def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
259
- for block in self.transformer_blocks:
260
- block._set_use_memory_efficient_attention_xformers(
261
- use_memory_efficient_attention_xformers)
262
-
263
-
264
- class AttentionBlock(nn.Module):
265
- """
266
- An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
267
- to the N-d case.
268
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
269
- Uses three q, k, v linear layers to compute attention.
270
-
271
- Parameters:
272
- channels (`int`): The number of channels in the input and output.
273
- num_head_channels (`int`, *optional*):
274
- The number of channels in each head. If None, then `num_heads` = 1.
275
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
276
- rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
277
- eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
278
- """
279
-
280
- def __init__(
281
- self,
282
- channels: int,
283
- num_head_channels: Optional[int] = None,
284
- norm_num_groups: int = 32,
285
- rescale_output_factor: float = 1.0,
286
- eps: float = 1e-5,
287
- ):
288
- super().__init__()
289
- self.channels = channels
290
-
291
- self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
292
- self.num_head_size = num_head_channels
293
- self.group_norm = nn.GroupNorm(
294
- num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
295
-
296
- # define q,k,v as linear layers
297
- self.query = nn.Linear(channels, channels)
298
- self.key = nn.Linear(channels, channels)
299
- self.value = nn.Linear(channels, channels)
300
-
301
- self.rescale_output_factor = rescale_output_factor
302
- self.proj_attn = nn.Linear(channels, channels, 1)
303
-
304
- def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
305
- new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
306
- # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
307
- new_projection = projection.view(
308
- new_projection_shape).permute(0, 2, 1, 3)
309
- return new_projection
310
-
311
- def forward(self, hidden_states):
312
- residual = hidden_states
313
- batch, channel, height, width = hidden_states.shape
314
-
315
- # norm
316
- hidden_states = self.group_norm(hidden_states)
317
-
318
- hidden_states = hidden_states.view(
319
- batch, channel, height * width).transpose(1, 2)
320
-
321
- # proj to q, k, v
322
- query_proj = self.query(hidden_states)
323
- key_proj = self.key(hidden_states)
324
- value_proj = self.value(hidden_states)
325
-
326
- scale = 1 / math.sqrt(self.channels / self.num_heads)
327
-
328
- # get scores
329
- if self.num_heads > 1:
330
- query_states = self.transpose_for_scores(query_proj)
331
- key_states = self.transpose_for_scores(key_proj)
332
- value_states = self.transpose_for_scores(value_proj)
333
-
334
- # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
335
- # or reformulate this into a 3D problem?
336
- # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
337
- # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
338
- # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
339
- attention_scores = torch.matmul(
340
- query_states, key_states.transpose(-1, -2)) * scale
341
- else:
342
- query_states, key_states, value_states = query_proj, key_proj, value_proj
343
-
344
- attention_scores = torch.baddbmm(
345
- torch.empty(
346
- query_states.shape[0],
347
- query_states.shape[1],
348
- key_states.shape[1],
349
- dtype=query_states.dtype,
350
- device=query_states.device,
351
- ),
352
- query_states,
353
- key_states.transpose(-1, -2),
354
- beta=0,
355
- alpha=scale,
356
- )
357
-
358
- attention_probs = torch.softmax(
359
- attention_scores.float(), dim=-1).type(attention_scores.dtype)
360
-
361
- # compute attention output
362
- if self.num_heads > 1:
363
- # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
364
- # or reformulate this into a 3D problem?
365
- # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
366
- # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
367
- # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
368
- hidden_states = torch.matmul(attention_probs, value_states)
369
- hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
370
- new_hidden_states_shape = hidden_states.size()[
371
- :-2] + (self.channels,)
372
- hidden_states = hidden_states.view(new_hidden_states_shape)
373
- else:
374
- hidden_states = torch.bmm(attention_probs, value_states)
375
-
376
- # compute next hidden_states
377
- hidden_states = self.proj_attn(hidden_states)
378
- hidden_states = hidden_states.transpose(
379
- -1, -2).reshape(batch, channel, height, width)
380
-
381
- # res connect and rescale
382
- hidden_states = (hidden_states + residual) / self.rescale_output_factor
383
- return hidden_states
384
 
 
385
 
 
386
  class BasicTransformerBlock(nn.Module):
387
  r"""
388
  A basic Transformer block.
@@ -392,7 +33,11 @@ class BasicTransformerBlock(nn.Module):
392
  num_attention_heads (`int`): The number of heads to use for multi-head attention.
393
  attention_head_dim (`int`): The number of channels in each head.
394
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
395
- cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
 
 
 
 
396
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
397
  num_embeds_ada_norm (:
398
  obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
@@ -411,264 +56,153 @@ class BasicTransformerBlock(nn.Module):
411
  num_embeds_ada_norm: Optional[int] = None,
412
  attention_bias: bool = False,
413
  only_cross_attention: bool = False,
 
 
 
 
 
414
  ):
415
  super().__init__()
416
  self.only_cross_attention = only_cross_attention
417
- self.attn1 = CrossAttention(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  query_dim=dim,
419
  heads=num_attention_heads,
420
  dim_head=attention_head_dim,
421
  dropout=dropout,
422
  bias=attention_bias,
423
  cross_attention_dim=cross_attention_dim if only_cross_attention else None,
424
- ) # is a self-attention
425
- self.ff = FeedForward(dim, dropout=dropout,
426
- activation_fn=activation_fn)
427
- self.attn2 = CrossAttention(
428
- query_dim=dim,
429
- cross_attention_dim=cross_attention_dim,
430
- heads=num_attention_heads,
431
- dim_head=attention_head_dim,
432
- dropout=dropout,
433
- bias=attention_bias,
434
- ) # is self-attn if context is none
435
-
436
- # layer norms
437
- self.use_ada_layer_norm = num_embeds_ada_norm is not None
438
- if self.use_ada_layer_norm:
439
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
440
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
441
- else:
442
- self.norm1 = nn.LayerNorm(dim)
443
- self.norm2 = nn.LayerNorm(dim)
444
- self.norm3 = nn.LayerNorm(dim)
445
-
446
- # if xformers is installed try to use memory_efficient_attention by default
447
- if is_xformers_available():
448
- try:
449
- self._set_use_memory_efficient_attention_xformers(True)
450
- except Exception as e:
451
- warnings.warn(
452
- "Could not enable memory efficient attention. Make sure xformers is installed"
453
- f" correctly and a GPU is available: {e}"
454
- )
455
 
456
- def _set_attention_slice(self, slice_size):
457
- self.attn1._slice_size = slice_size
458
- self.attn2._slice_size = slice_size
459
-
460
- def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
461
- if not is_xformers_available():
462
- print("Here is how to install it")
463
- raise ModuleNotFoundError(
464
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
465
- " xformers",
466
- name="xformers",
467
- )
468
- elif not torch.cuda.is_available():
469
- raise ValueError(
470
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
471
- " available for GPU "
472
  )
 
 
 
 
 
 
 
 
 
473
  else:
474
- try:
475
- # Make sure we can run the memory efficient attention
476
- _ = xformers.ops.memory_efficient_attention(
477
- torch.randn((1, 2, 40), device="cuda"),
478
- torch.randn((1, 2, 40), device="cuda"),
479
- torch.randn((1, 2, 40), device="cuda"),
480
- )
481
- except Exception as e:
482
- raise e
483
- self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
484
- self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
485
 
486
- def forward(self, hidden_states, context=None, timestep=None, text_format_dict={}):
487
- # 1. Self-Attention
488
- norm_hidden_states = (
489
- self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(
490
- hidden_states)
491
- )
 
492
 
493
- if self.only_cross_attention:
494
- attn_out, _ = self.attn1(
495
- norm_hidden_states, context=context, text_format_dict=text_format_dict) + hidden_states
496
- hidden_states = attn_out + hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  else:
498
- attn_out, _ = self.attn1(norm_hidden_states)
499
- hidden_states = attn_out + hidden_states
500
 
501
- # 2. Cross-Attention
502
- norm_hidden_states = (
503
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
504
- hidden_states)
 
 
 
 
505
  )
506
- attn_out, _ = self.attn2(
507
- norm_hidden_states, context=context, text_format_dict=text_format_dict)
508
- hidden_states = attn_out + hidden_states
509
 
510
- # 3. Feed-forward
511
- hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
 
 
 
512
 
513
- return hidden_states
 
 
 
 
 
 
 
514
 
 
 
515
 
516
- class CrossAttention(nn.Module):
517
- r"""
518
- A cross attention layer.
519
 
520
- Parameters:
521
- query_dim (`int`): The number of channels in the query.
522
- cross_attention_dim (`int`, *optional*):
523
- The number of channels in the context. If not given, defaults to `query_dim`.
524
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
525
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
526
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
527
- bias (`bool`, *optional*, defaults to False):
528
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
529
- """
530
 
531
- def __init__(
532
- self,
533
- query_dim: int,
534
- cross_attention_dim: Optional[int] = None,
535
- heads: int = 8,
536
- dim_head: int = 64,
537
- dropout: float = 0.0,
538
- bias=False,
539
- ):
540
- super().__init__()
541
- inner_dim = dim_head * heads
542
- self.is_cross_attn = cross_attention_dim is not None
543
- cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
544
-
545
- self.scale = dim_head**-0.5
546
- self.heads = heads
547
- # for slice_size > 0 the attention score computation
548
- # is split across the batch axis to save memory
549
- # You can set slice_size with `set_attention_slice`
550
- self._slice_size = None
551
- self._use_memory_efficient_attention_xformers = False
552
-
553
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
554
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
555
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
556
-
557
- self.to_out = nn.ModuleList([])
558
- self.to_out.append(nn.Linear(inner_dim, query_dim))
559
- self.to_out.append(nn.Dropout(dropout))
560
-
561
- def reshape_heads_to_batch_dim(self, tensor):
562
- batch_size, seq_len, dim = tensor.shape
563
- head_size = self.heads
564
- tensor = tensor.reshape(batch_size, seq_len,
565
- head_size, dim // head_size)
566
- tensor = tensor.permute(0, 2, 1, 3).reshape(
567
- batch_size * head_size, seq_len, dim // head_size)
568
- return tensor
569
-
570
- def reshape_batch_dim_to_heads(self, tensor):
571
- batch_size, seq_len, dim = tensor.shape
572
- head_size = self.heads
573
- tensor = tensor.reshape(batch_size // head_size,
574
- head_size, seq_len, dim)
575
- tensor = tensor.permute(0, 2, 1, 3).reshape(
576
- batch_size // head_size, seq_len, dim * head_size)
577
- return tensor
578
-
579
- def reshape_batch_dim_to_heads_and_average(self, tensor):
580
- batch_size, seq_len, seq_len2 = tensor.shape
581
- head_size = self.heads
582
- tensor = tensor.reshape(batch_size // head_size,
583
- head_size, seq_len, seq_len2)
584
- return tensor.mean(1)
585
-
586
- def forward(self, hidden_states, real_attn_probs=None, context=None, mask=None, text_format_dict={}):
587
- batch_size, sequence_length, _ = hidden_states.shape
588
-
589
- query = self.to_q(hidden_states)
590
- context = context if context is not None else hidden_states
591
- key = self.to_k(context)
592
- value = self.to_v(context)
593
-
594
- dim = query.shape[-1]
595
-
596
- query = self.reshape_heads_to_batch_dim(query)
597
- key = self.reshape_heads_to_batch_dim(key)
598
- value = self.reshape_heads_to_batch_dim(value)
599
-
600
- # attention, what we cannot get enough of
601
- if self._use_memory_efficient_attention_xformers:
602
- hidden_states = self._memory_efficient_attention_xformers(
603
- query, key, value)
604
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
605
- hidden_states = hidden_states.to(query.dtype)
606
  else:
607
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
608
- # only this attention function is used
609
- hidden_states, attn_probs = self._attention(
610
- query, key, value, real_attn_probs, **text_format_dict)
611
-
612
- # linear proj
613
- hidden_states = self.to_out[0](hidden_states)
614
- # dropout
615
- hidden_states = self.to_out[1](hidden_states)
616
- return hidden_states, attn_probs
617
-
618
- def _qk(self, query, key):
619
- return torch.baddbmm(
620
- torch.empty(query.shape[0], query.shape[1], key.shape[1],
621
- dtype=query.dtype, device=query.device),
622
- query,
623
- key.transpose(-1, -2),
624
- beta=0,
625
- alpha=self.scale,
626
- )
627
 
628
- def _attention(self, query, key, value, real_attn_probs=None, word_pos=None, font_size=None,
629
- **kwargs):
630
- attention_scores = self._qk(query, key)
631
-
632
- # Font size V2:
633
- if self.is_cross_attn and word_pos is not None and font_size is not None:
634
- assert key.shape[1] == 77
635
- attention_score_exp = attention_scores.exp()
636
- font_size_abs, font_size_sign = font_size.abs(), font_size.sign()
637
- attention_score_exp[:, :, word_pos] = attention_score_exp[:, :, word_pos].clone(
638
- )*font_size_abs
639
- attention_probs = attention_score_exp / \
640
- attention_score_exp.sum(-1, True)
641
- attention_probs[:, :, word_pos] *= font_size_sign
642
- else:
643
- attention_probs = attention_scores.softmax(dim=-1)
644
 
645
- # compute attention output
646
- if real_attn_probs is None:
647
- hidden_states = torch.bmm(attention_probs, value)
648
- else:
649
- if isinstance(real_attn_probs, dict):
650
- for pos1, pos2 in zip(real_attn_probs['inject_pos'][0], real_attn_probs['inject_pos'][1]):
651
- attention_probs[:, :,
652
- pos2] = real_attn_probs['reference'][:, :, pos1]
653
- hidden_states = torch.bmm(attention_probs, value)
654
- else:
655
- hidden_states = torch.bmm(real_attn_probs, value)
656
-
657
- # reshape hidden_states
658
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
659
-
660
- # we also return the map averaged over heads to save memory footprint
661
- attention_probs_avg = self.reshape_batch_dim_to_heads_and_average(
662
- attention_probs)
663
- return hidden_states, [attention_probs_avg, attention_probs]
664
-
665
- def _memory_efficient_attention_xformers(self, query, key, value):
666
- query = query.contiguous()
667
- key = key.contiguous()
668
- value = value.contiguous()
669
- hidden_states = xformers.ops.memory_efficient_attention(
670
- query, key, value, attn_bias=None)
671
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
672
  return hidden_states
673
 
674
 
@@ -682,6 +216,7 @@ class FeedForward(nn.Module):
682
  mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
683
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
684
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
 
685
  """
686
 
687
  def __init__(
@@ -691,23 +226,31 @@ class FeedForward(nn.Module):
691
  mult: int = 4,
692
  dropout: float = 0.0,
693
  activation_fn: str = "geglu",
 
694
  ):
695
  super().__init__()
696
  inner_dim = int(dim * mult)
697
  dim_out = dim_out if dim_out is not None else dim
698
 
699
- if activation_fn == "geglu":
700
- geglu = GEGLU(dim, inner_dim)
 
 
 
 
701
  elif activation_fn == "geglu-approximate":
702
- geglu = ApproximateGELU(dim, inner_dim)
703
 
704
  self.net = nn.ModuleList([])
705
  # project in
706
- self.net.append(geglu)
707
  # project dropout
708
  self.net.append(nn.Dropout(dropout))
709
  # project out
710
  self.net.append(nn.Linear(inner_dim, dim_out))
 
 
 
711
 
712
  def forward(self, hidden_states):
713
  for module in self.net:
@@ -715,7 +258,28 @@ class FeedForward(nn.Module):
715
  return hidden_states
716
 
717
 
718
- # feedforward
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719
  class GEGLU(nn.Module):
720
  r"""
721
  A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
@@ -775,130 +339,53 @@ class AdaLayerNorm(nn.Module):
775
  return x
776
 
777
 
778
- class DualTransformer2DModel(nn.Module):
 
 
779
  """
780
- Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
781
 
782
- Parameters:
783
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
784
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
785
- in_channels (`int`, *optional*):
786
- Pass if the input is continuous. The number of channels in the input and output.
787
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
788
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
789
- cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
790
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
791
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
792
- `ImagePositionalEmbeddings`.
793
- num_vector_embeds (`int`, *optional*):
794
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
795
- Includes the class for the masked latent pixel.
796
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
797
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
798
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
799
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
800
- up to but not more than steps than `num_embeds_ada_norm`.
801
- attention_bias (`bool`, *optional*):
802
- Configure if the TransformerBlocks' attention should contain a bias parameter.
803
  """
804
 
805
  def __init__(
806
- self,
807
- num_attention_heads: int = 16,
808
- attention_head_dim: int = 88,
809
- in_channels: Optional[int] = None,
810
- num_layers: int = 1,
811
- dropout: float = 0.0,
812
- norm_num_groups: int = 32,
813
- cross_attention_dim: Optional[int] = None,
814
- attention_bias: bool = False,
815
- sample_size: Optional[int] = None,
816
- num_vector_embeds: Optional[int] = None,
817
- activation_fn: str = "geglu",
818
- num_embeds_ada_norm: Optional[int] = None,
819
  ):
820
  super().__init__()
821
- self.transformers = nn.ModuleList(
822
- [
823
- Transformer2DModel(
824
- num_attention_heads=num_attention_heads,
825
- attention_head_dim=attention_head_dim,
826
- in_channels=in_channels,
827
- num_layers=num_layers,
828
- dropout=dropout,
829
- norm_num_groups=norm_num_groups,
830
- cross_attention_dim=cross_attention_dim,
831
- attention_bias=attention_bias,
832
- sample_size=sample_size,
833
- num_vector_embeds=num_vector_embeds,
834
- activation_fn=activation_fn,
835
- num_embeds_ada_norm=num_embeds_ada_norm,
836
- )
837
- for _ in range(2)
838
- ]
839
- )
840
 
841
- # Variables that can be set by a pipeline:
842
-
843
- # The ratio of transformer1 to transformer2's output states to be combined during inference
844
- self.mix_ratio = 0.5
845
-
846
- # The shape of `encoder_hidden_states` is expected to be
847
- # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
848
- self.condition_lengths = [77, 257]
849
-
850
- # Which transformer to use to encode which condition.
851
- # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
852
- self.transformer_index_for_condition = [1, 0]
853
-
854
- def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
855
- """
856
- Args:
857
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
858
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
859
- hidden_states
860
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
861
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
862
- self-attention.
863
- timestep ( `torch.long`, *optional*):
864
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
865
- return_dict (`bool`, *optional*, defaults to `True`):
866
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
867
-
868
- Returns:
869
- [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
870
- if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
871
- tensor.
872
- """
873
- input_states = hidden_states
874
-
875
- encoded_states = []
876
- tokens_start = 0
877
- for i in range(2):
878
- # for each of the two transformers, pass the corresponding condition tokens
879
- condition_state = encoder_hidden_states[:,
880
- tokens_start: tokens_start + self.condition_lengths[i]]
881
- transformer_index = self.transformer_index_for_condition[i]
882
- encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
883
- 0
884
- ]
885
- encoded_states.append(encoded_state - input_states)
886
- tokens_start += self.condition_lengths[i]
887
-
888
- output_states = encoded_states[0] * self.mix_ratio + \
889
- encoded_states[1] * (1 - self.mix_ratio)
890
- output_states = output_states + input_states
891
-
892
- if not return_dict:
893
- return (output_states,)
894
-
895
- return Transformer2DModelOutput(sample=output_states)
896
-
897
- def _set_attention_slice(self, slice_size):
898
- for transformer in self.transformers:
899
- transformer._set_attention_slice(slice_size)
900
-
901
- def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
902
- for transformer in self.transformers:
903
- transformer._set_use_memory_efficient_attention_xformers(
904
- use_memory_efficient_attention_xformers)
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ from typing import Any, Dict, Optional
 
 
 
15
 
16
  import torch
17
  import torch.nn.functional as F
18
  from torch import nn
19
 
20
+ from diffusers.utils import maybe_allow_in_graph
21
+ from diffusers.models.activations import get_activation
22
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ from models.attention_processor import Attention
25
 
26
+ @maybe_allow_in_graph
27
  class BasicTransformerBlock(nn.Module):
28
  r"""
29
  A basic Transformer block.
 
33
  num_attention_heads (`int`): The number of heads to use for multi-head attention.
34
  attention_head_dim (`int`): The number of channels in each head.
35
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
37
+ only_cross_attention (`bool`, *optional*):
38
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
39
+ double_self_attention (`bool`, *optional*):
40
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
41
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
42
  num_embeds_ada_norm (:
43
  obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
 
56
  num_embeds_ada_norm: Optional[int] = None,
57
  attention_bias: bool = False,
58
  only_cross_attention: bool = False,
59
+ double_self_attention: bool = False,
60
+ upcast_attention: bool = False,
61
+ norm_elementwise_affine: bool = True,
62
+ norm_type: str = "layer_norm",
63
+ final_dropout: bool = False,
64
  ):
65
  super().__init__()
66
  self.only_cross_attention = only_cross_attention
67
+
68
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
69
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
70
+
71
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
72
+ raise ValueError(
73
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
74
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
75
+ )
76
+
77
+ # Define 3 blocks. Each block has its own normalization layer.
78
+ # 1. Self-Attn
79
+ if self.use_ada_layer_norm:
80
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
81
+ elif self.use_ada_layer_norm_zero:
82
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
83
+ else:
84
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
85
+ self.attn1 = Attention(
86
  query_dim=dim,
87
  heads=num_attention_heads,
88
  dim_head=attention_head_dim,
89
  dropout=dropout,
90
  bias=attention_bias,
91
  cross_attention_dim=cross_attention_dim if only_cross_attention else None,
92
+ upcast_attention=upcast_attention,
93
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # 2. Cross-Attn
96
+ if cross_attention_dim is not None or double_self_attention:
97
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
98
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
99
+ # the second cross attention block.
100
+ self.norm2 = (
101
+ AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ if self.use_ada_layer_norm
103
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
 
 
 
 
 
 
 
104
  )
105
+ self.attn2 = Attention(
106
+ query_dim=dim,
107
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
108
+ heads=num_attention_heads,
109
+ dim_head=attention_head_dim,
110
+ dropout=dropout,
111
+ bias=attention_bias,
112
+ upcast_attention=upcast_attention,
113
+ ) # is self-attn if encoder_hidden_states is none
114
  else:
115
+ self.norm2 = None
116
+ self.attn2 = None
 
 
 
 
 
 
 
 
 
117
 
118
+ # 3. Feed-forward
119
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
120
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
121
+
122
+ # let chunk size default to None
123
+ self._chunk_size = None
124
+ self._chunk_dim = 0
125
 
126
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
127
+ # Sets chunk feed-forward
128
+ self._chunk_size = chunk_size
129
+ self._chunk_dim = dim
130
+
131
+ def forward(
132
+ self,
133
+ hidden_states: torch.FloatTensor,
134
+ attention_mask: Optional[torch.FloatTensor] = None,
135
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
136
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
137
+ timestep: Optional[torch.LongTensor] = None,
138
+ cross_attention_kwargs: Dict[str, Any] = None,
139
+ class_labels: Optional[torch.LongTensor] = None,
140
+ ):
141
+ # Notice that normalization is always applied before the real computation in the following blocks.
142
+ # 1. Self-Attention
143
+ if self.use_ada_layer_norm:
144
+ norm_hidden_states = self.norm1(hidden_states, timestep)
145
+ elif self.use_ada_layer_norm_zero:
146
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
147
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
148
+ )
149
  else:
150
+ norm_hidden_states = self.norm1(hidden_states)
 
151
 
152
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
153
+
154
+ # Rich-Text: ignore the attention probs
155
+ attn_output, _ = self.attn1(
156
+ norm_hidden_states,
157
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
158
+ attention_mask=attention_mask,
159
+ **cross_attention_kwargs,
160
  )
161
+ if self.use_ada_layer_norm_zero:
162
+ attn_output = gate_msa.unsqueeze(1) * attn_output
163
+ hidden_states = attn_output + hidden_states
164
 
165
+ # 2. Cross-Attention
166
+ if self.attn2 is not None:
167
+ norm_hidden_states = (
168
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
169
+ )
170
 
171
+ # Rich-Text: ignore the attention probs
172
+ attn_output, _ = self.attn2(
173
+ norm_hidden_states,
174
+ encoder_hidden_states=encoder_hidden_states,
175
+ attention_mask=encoder_attention_mask,
176
+ **cross_attention_kwargs,
177
+ )
178
+ hidden_states = attn_output + hidden_states
179
 
180
+ # 3. Feed-forward
181
+ norm_hidden_states = self.norm3(hidden_states)
182
 
183
+ if self.use_ada_layer_norm_zero:
184
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
185
 
186
+ if self._chunk_size is not None:
187
+ # "feed_forward_chunk_size" can be used to save memory
188
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
189
+ raise ValueError(
190
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
191
+ )
 
 
 
 
192
 
193
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
194
+ ff_output = torch.cat(
195
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
196
+ dim=self._chunk_dim,
197
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  else:
199
+ ff_output = self.ff(norm_hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ if self.use_ada_layer_norm_zero:
202
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
203
+
204
+ hidden_states = ff_output + hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  return hidden_states
207
 
208
 
 
216
  mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
  """
221
 
222
  def __init__(
 
226
  mult: int = 4,
227
  dropout: float = 0.0,
228
  activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
  ):
231
  super().__init__()
232
  inner_dim = int(dim * mult)
233
  dim_out = dim_out if dim_out is not None else dim
234
 
235
+ if activation_fn == "gelu":
236
+ act_fn = GELU(dim, inner_dim)
237
+ if activation_fn == "gelu-approximate":
238
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
239
+ elif activation_fn == "geglu":
240
+ act_fn = GEGLU(dim, inner_dim)
241
  elif activation_fn == "geglu-approximate":
242
+ act_fn = ApproximateGELU(dim, inner_dim)
243
 
244
  self.net = nn.ModuleList([])
245
  # project in
246
+ self.net.append(act_fn)
247
  # project dropout
248
  self.net.append(nn.Dropout(dropout))
249
  # project out
250
  self.net.append(nn.Linear(inner_dim, dim_out))
251
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
252
+ if final_dropout:
253
+ self.net.append(nn.Dropout(dropout))
254
 
255
  def forward(self, hidden_states):
256
  for module in self.net:
 
258
  return hidden_states
259
 
260
 
261
+ class GELU(nn.Module):
262
+ r"""
263
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
264
+ """
265
+
266
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
267
+ super().__init__()
268
+ self.proj = nn.Linear(dim_in, dim_out)
269
+ self.approximate = approximate
270
+
271
+ def gelu(self, gate):
272
+ if gate.device.type != "mps":
273
+ return F.gelu(gate, approximate=self.approximate)
274
+ # mps: gelu is not implemented for float16
275
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
276
+
277
+ def forward(self, hidden_states):
278
+ hidden_states = self.proj(hidden_states)
279
+ hidden_states = self.gelu(hidden_states)
280
+ return hidden_states
281
+
282
+
283
  class GEGLU(nn.Module):
284
  r"""
285
  A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
 
339
  return x
340
 
341
 
342
+ class AdaLayerNormZero(nn.Module):
343
+ """
344
+ Norm layer adaptive layer norm zero (adaLN-Zero).
345
  """
 
346
 
347
+ def __init__(self, embedding_dim, num_embeddings):
348
+ super().__init__()
349
+
350
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
351
+
352
+ self.silu = nn.SiLU()
353
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
354
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
355
+
356
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
357
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
358
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
359
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
360
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
361
+
362
+
363
+ class AdaGroupNorm(nn.Module):
364
+ """
365
+ GroupNorm layer modified to incorporate timestep embeddings.
 
 
366
  """
367
 
368
  def __init__(
369
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
 
 
 
 
 
 
 
 
 
 
 
 
370
  ):
371
  super().__init__()
372
+ self.num_groups = num_groups
373
+ self.eps = eps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
+ if act_fn is None:
376
+ self.act = None
377
+ else:
378
+ self.act = get_activation(act_fn)
379
+
380
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
381
+
382
+ def forward(self, x, emb):
383
+ if self.act:
384
+ emb = self.act(emb)
385
+ emb = self.linear(emb)
386
+ emb = emb[:, :, None, None]
387
+ scale, shift = emb.chunk(2, dim=1)
388
+
389
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
390
+ x = x * (1 + scale) + shift
391
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/attention_processor.py ADDED
@@ -0,0 +1,1687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging, maybe_allow_in_graph
21
+ from diffusers.utils.import_utils import is_xformers_available
22
+
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ @maybe_allow_in_graph
35
+ class Attention(nn.Module):
36
+ r"""
37
+ A cross attention layer.
38
+
39
+ Parameters:
40
+ query_dim (`int`): The number of channels in the query.
41
+ cross_attention_dim (`int`, *optional*):
42
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
43
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
44
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
45
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
46
+ bias (`bool`, *optional*, defaults to False):
47
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ query_dim: int,
53
+ cross_attention_dim: Optional[int] = None,
54
+ heads: int = 8,
55
+ dim_head: int = 64,
56
+ dropout: float = 0.0,
57
+ bias=False,
58
+ upcast_attention: bool = False,
59
+ upcast_softmax: bool = False,
60
+ cross_attention_norm: Optional[str] = None,
61
+ cross_attention_norm_num_groups: int = 32,
62
+ added_kv_proj_dim: Optional[int] = None,
63
+ norm_num_groups: Optional[int] = None,
64
+ spatial_norm_dim: Optional[int] = None,
65
+ out_bias: bool = True,
66
+ scale_qk: bool = True,
67
+ only_cross_attention: bool = False,
68
+ eps: float = 1e-5,
69
+ rescale_output_factor: float = 1.0,
70
+ residual_connection: bool = False,
71
+ _from_deprecated_attn_block=False,
72
+ processor: Optional["AttnProcessor"] = None,
73
+ ):
74
+ super().__init__()
75
+ inner_dim = dim_head * heads
76
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
77
+ self.upcast_attention = upcast_attention
78
+ self.upcast_softmax = upcast_softmax
79
+ self.rescale_output_factor = rescale_output_factor
80
+ self.residual_connection = residual_connection
81
+ self.dropout = dropout
82
+
83
+ # we make use of this private variable to know whether this class is loaded
84
+ # with an deprecated state dict so that we can convert it on the fly
85
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
86
+
87
+ self.scale_qk = scale_qk
88
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
89
+
90
+ self.heads = heads
91
+ # for slice_size > 0 the attention score computation
92
+ # is split across the batch axis to save memory
93
+ # You can set slice_size with `set_attention_slice`
94
+ self.sliceable_head_dim = heads
95
+
96
+ self.added_kv_proj_dim = added_kv_proj_dim
97
+ self.only_cross_attention = only_cross_attention
98
+
99
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
100
+ raise ValueError(
101
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
102
+ )
103
+
104
+ if norm_num_groups is not None:
105
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
106
+ else:
107
+ self.group_norm = None
108
+
109
+ if spatial_norm_dim is not None:
110
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
111
+ else:
112
+ self.spatial_norm = None
113
+
114
+ if cross_attention_norm is None:
115
+ self.norm_cross = None
116
+ elif cross_attention_norm == "layer_norm":
117
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
118
+ elif cross_attention_norm == "group_norm":
119
+ if self.added_kv_proj_dim is not None:
120
+ # The given `encoder_hidden_states` are initially of shape
121
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
122
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
123
+ # before the projection, so we need to use `added_kv_proj_dim` as
124
+ # the number of channels for the group norm.
125
+ norm_cross_num_channels = added_kv_proj_dim
126
+ else:
127
+ norm_cross_num_channels = cross_attention_dim
128
+
129
+ self.norm_cross = nn.GroupNorm(
130
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
131
+ )
132
+ else:
133
+ raise ValueError(
134
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
135
+ )
136
+
137
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
138
+
139
+ if not self.only_cross_attention:
140
+ # only relevant for the `AddedKVProcessor` classes
141
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
142
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
143
+ else:
144
+ self.to_k = None
145
+ self.to_v = None
146
+
147
+ if self.added_kv_proj_dim is not None:
148
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
149
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
150
+
151
+ self.to_out = nn.ModuleList([])
152
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
153
+ self.to_out.append(nn.Dropout(dropout))
154
+
155
+ # set attention processor
156
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
157
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
158
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
159
+ if processor is None:
160
+ processor = (
161
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
162
+ )
163
+ self.set_processor(processor)
164
+
165
+ # Rich-Text: util function for averaging over attention heads
166
+ def reshape_batch_dim_to_heads_and_average(self, tensor):
167
+ batch_size, seq_len, seq_len2 = tensor.shape
168
+ head_size = self.heads
169
+ tensor = tensor.reshape(batch_size // head_size,
170
+ head_size, seq_len, seq_len2)
171
+ return tensor.mean(1)
172
+
173
+ def set_use_memory_efficient_attention_xformers(
174
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
175
+ ):
176
+ is_lora = hasattr(self, "processor") and isinstance(
177
+ self.processor,
178
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
179
+ )
180
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
181
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
182
+ )
183
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
184
+ self.processor,
185
+ (
186
+ AttnAddedKVProcessor,
187
+ AttnAddedKVProcessor2_0,
188
+ SlicedAttnAddedKVProcessor,
189
+ XFormersAttnAddedKVProcessor,
190
+ LoRAAttnAddedKVProcessor,
191
+ ),
192
+ )
193
+
194
+ if use_memory_efficient_attention_xformers:
195
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
196
+ raise NotImplementedError(
197
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
198
+ )
199
+ if not is_xformers_available():
200
+ raise ModuleNotFoundError(
201
+ (
202
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
203
+ " xformers"
204
+ ),
205
+ name="xformers",
206
+ )
207
+ elif not torch.cuda.is_available():
208
+ raise ValueError(
209
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
210
+ " only available for GPU "
211
+ )
212
+ else:
213
+ try:
214
+ # Make sure we can run the memory efficient attention
215
+ _ = xformers.ops.memory_efficient_attention(
216
+ torch.randn((1, 2, 40), device="cuda"),
217
+ torch.randn((1, 2, 40), device="cuda"),
218
+ torch.randn((1, 2, 40), device="cuda"),
219
+ )
220
+ except Exception as e:
221
+ raise e
222
+
223
+ if is_lora:
224
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
225
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
226
+ processor = LoRAXFormersAttnProcessor(
227
+ hidden_size=self.processor.hidden_size,
228
+ cross_attention_dim=self.processor.cross_attention_dim,
229
+ rank=self.processor.rank,
230
+ attention_op=attention_op,
231
+ )
232
+ processor.load_state_dict(self.processor.state_dict())
233
+ processor.to(self.processor.to_q_lora.up.weight.device)
234
+ elif is_custom_diffusion:
235
+ processor = CustomDiffusionXFormersAttnProcessor(
236
+ train_kv=self.processor.train_kv,
237
+ train_q_out=self.processor.train_q_out,
238
+ hidden_size=self.processor.hidden_size,
239
+ cross_attention_dim=self.processor.cross_attention_dim,
240
+ attention_op=attention_op,
241
+ )
242
+ processor.load_state_dict(self.processor.state_dict())
243
+ if hasattr(self.processor, "to_k_custom_diffusion"):
244
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
245
+ elif is_added_kv_processor:
246
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
247
+ # which uses this type of cross attention ONLY because the attention mask of format
248
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
249
+ # throw warning
250
+ logger.info(
251
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
252
+ )
253
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
254
+ else:
255
+ processor = XFormersAttnProcessor(attention_op=attention_op)
256
+ else:
257
+ if is_lora:
258
+ attn_processor_class = (
259
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
260
+ )
261
+ processor = attn_processor_class(
262
+ hidden_size=self.processor.hidden_size,
263
+ cross_attention_dim=self.processor.cross_attention_dim,
264
+ rank=self.processor.rank,
265
+ )
266
+ processor.load_state_dict(self.processor.state_dict())
267
+ processor.to(self.processor.to_q_lora.up.weight.device)
268
+ elif is_custom_diffusion:
269
+ processor = CustomDiffusionAttnProcessor(
270
+ train_kv=self.processor.train_kv,
271
+ train_q_out=self.processor.train_q_out,
272
+ hidden_size=self.processor.hidden_size,
273
+ cross_attention_dim=self.processor.cross_attention_dim,
274
+ )
275
+ processor.load_state_dict(self.processor.state_dict())
276
+ if hasattr(self.processor, "to_k_custom_diffusion"):
277
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
278
+ else:
279
+ # set attention processor
280
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
281
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
282
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
283
+ processor = (
284
+ AttnProcessor2_0()
285
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
286
+ else AttnProcessor()
287
+ )
288
+
289
+ self.set_processor(processor)
290
+
291
+ def set_attention_slice(self, slice_size):
292
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
293
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
294
+
295
+ if slice_size is not None and self.added_kv_proj_dim is not None:
296
+ processor = SlicedAttnAddedKVProcessor(slice_size)
297
+ elif slice_size is not None:
298
+ processor = SlicedAttnProcessor(slice_size)
299
+ elif self.added_kv_proj_dim is not None:
300
+ processor = AttnAddedKVProcessor()
301
+ else:
302
+ # set attention processor
303
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
304
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
305
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
306
+ processor = (
307
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
308
+ )
309
+
310
+ self.set_processor(processor)
311
+
312
+ def set_processor(self, processor: "AttnProcessor"):
313
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
314
+ # pop `processor` from `self._modules`
315
+ if (
316
+ hasattr(self, "processor")
317
+ and isinstance(self.processor, torch.nn.Module)
318
+ and not isinstance(processor, torch.nn.Module)
319
+ ):
320
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
321
+ self._modules.pop("processor")
322
+
323
+ self.processor = processor
324
+
325
+ # Rich-Text: inject self-attention maps
326
+ def forward(self, hidden_states, real_attn_probs=None, attn_weights=None, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
327
+ # The `Attention` class can call different attention processors / attention functions
328
+ # here we simply pass along all tensors to the selected processor class
329
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
330
+ return self.processor(
331
+ self,
332
+ hidden_states,
333
+ real_attn_probs=real_attn_probs,
334
+ attn_weights=attn_weights,
335
+ encoder_hidden_states=encoder_hidden_states,
336
+ attention_mask=attention_mask,
337
+ **cross_attention_kwargs,
338
+ )
339
+
340
+ def batch_to_head_dim(self, tensor):
341
+ head_size = self.heads
342
+ batch_size, seq_len, dim = tensor.shape
343
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
344
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
345
+ return tensor
346
+
347
+ def head_to_batch_dim(self, tensor, out_dim=3):
348
+ head_size = self.heads
349
+ batch_size, seq_len, dim = tensor.shape
350
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
351
+ tensor = tensor.permute(0, 2, 1, 3)
352
+
353
+ if out_dim == 3:
354
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
355
+
356
+ return tensor
357
+
358
+ # Rich-Text: return attention scores
359
+ def get_attention_scores(self, query, key, attention_mask=None, attn_weights=False):
360
+ dtype = query.dtype
361
+ if self.upcast_attention:
362
+ query = query.float()
363
+ key = key.float()
364
+
365
+ if attention_mask is None:
366
+ baddbmm_input = torch.empty(
367
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
368
+ )
369
+ beta = 0
370
+ else:
371
+ baddbmm_input = attention_mask
372
+ beta = 1
373
+
374
+ attention_scores = torch.baddbmm(
375
+ baddbmm_input,
376
+ query,
377
+ key.transpose(-1, -2),
378
+ beta=beta,
379
+ alpha=self.scale,
380
+ )
381
+ del baddbmm_input
382
+
383
+ if self.upcast_softmax:
384
+ attention_scores = attention_scores.float()
385
+
386
+ # Rich-Text: font size
387
+ if attn_weights is not None:
388
+ assert key.shape[1] == 77
389
+ attention_scores_stable = attention_scores - attention_scores.max(-1, True)[0]
390
+ attention_score_exp = attention_scores_stable.float().exp()
391
+ # attention_score_exp = attention_scores.float().exp()
392
+ font_size_abs, font_size_sign = attn_weights['font_size'].abs(), attn_weights['font_size'].sign()
393
+ attention_score_exp[:, :, attn_weights['word_pos']] = attention_score_exp[:, :, attn_weights['word_pos']].clone(
394
+ )*font_size_abs
395
+ attention_probs = attention_score_exp / attention_score_exp.sum(-1, True)
396
+ attention_probs[:, :, attn_weights['word_pos']] *= font_size_sign
397
+ # import ipdb; ipdb.set_trace()
398
+ if attention_probs.isnan().any():
399
+ import ipdb; ipdb.set_trace()
400
+ else:
401
+ attention_probs = attention_scores.softmax(dim=-1)
402
+
403
+ del attention_scores
404
+
405
+ attention_probs = attention_probs.to(dtype)
406
+
407
+ return attention_probs
408
+
409
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
410
+ if batch_size is None:
411
+ deprecate(
412
+ "batch_size=None",
413
+ "0.0.15",
414
+ (
415
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
416
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
417
+ " `prepare_attention_mask` when preparing the attention_mask."
418
+ ),
419
+ )
420
+ batch_size = 1
421
+
422
+ head_size = self.heads
423
+ if attention_mask is None:
424
+ return attention_mask
425
+
426
+ current_length: int = attention_mask.shape[-1]
427
+ if current_length != target_length:
428
+ if attention_mask.device.type == "mps":
429
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
430
+ # Instead, we can manually construct the padding tensor.
431
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
432
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
433
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
434
+ else:
435
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
436
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
437
+ # remaining_length: int = target_length - current_length
438
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
439
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
440
+
441
+ if out_dim == 3:
442
+ if attention_mask.shape[0] < batch_size * head_size:
443
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
444
+ elif out_dim == 4:
445
+ attention_mask = attention_mask.unsqueeze(1)
446
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
447
+
448
+ return attention_mask
449
+
450
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
451
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
452
+
453
+ if isinstance(self.norm_cross, nn.LayerNorm):
454
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
455
+ elif isinstance(self.norm_cross, nn.GroupNorm):
456
+ # Group norm norms along the channels dimension and expects
457
+ # input to be in the shape of (N, C, *). In this case, we want
458
+ # to norm along the hidden dimension, so we need to move
459
+ # (batch_size, sequence_length, hidden_size) ->
460
+ # (batch_size, hidden_size, sequence_length)
461
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
462
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
463
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
464
+ else:
465
+ assert False
466
+
467
+ return encoder_hidden_states
468
+
469
+
470
+ class AttnProcessor:
471
+ r"""
472
+ Default processor for performing attention-related computations.
473
+ """
474
+
475
+ # Rich-Text: inject self-attention maps
476
+ def __call__(
477
+ self,
478
+ attn: Attention,
479
+ hidden_states,
480
+ real_attn_probs=None,
481
+ attn_weights=None,
482
+ encoder_hidden_states=None,
483
+ attention_mask=None,
484
+ temb=None,
485
+ ):
486
+ residual = hidden_states
487
+
488
+ if attn.spatial_norm is not None:
489
+ hidden_states = attn.spatial_norm(hidden_states, temb)
490
+
491
+ input_ndim = hidden_states.ndim
492
+
493
+ if input_ndim == 4:
494
+ batch_size, channel, height, width = hidden_states.shape
495
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
496
+
497
+ batch_size, sequence_length, _ = (
498
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
499
+ )
500
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
501
+
502
+ if attn.group_norm is not None:
503
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
504
+
505
+ query = attn.to_q(hidden_states)
506
+
507
+ if encoder_hidden_states is None:
508
+ encoder_hidden_states = hidden_states
509
+ elif attn.norm_cross:
510
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
511
+
512
+ key = attn.to_k(encoder_hidden_states)
513
+ value = attn.to_v(encoder_hidden_states)
514
+
515
+ query = attn.head_to_batch_dim(query)
516
+ key = attn.head_to_batch_dim(key)
517
+ value = attn.head_to_batch_dim(value)
518
+
519
+ if real_attn_probs is None:
520
+ # Rich-Text: font size
521
+ attention_probs = attn.get_attention_scores(query, key, attention_mask, attn_weights=attn_weights)
522
+ else:
523
+ # Rich-Text: inject self-attention maps
524
+ attention_probs = real_attn_probs
525
+ hidden_states = torch.bmm(attention_probs, value)
526
+ hidden_states = attn.batch_to_head_dim(hidden_states)
527
+
528
+ # linear proj
529
+ hidden_states = attn.to_out[0](hidden_states)
530
+ # dropout
531
+ hidden_states = attn.to_out[1](hidden_states)
532
+
533
+ if input_ndim == 4:
534
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
535
+
536
+ if attn.residual_connection:
537
+ hidden_states = hidden_states + residual
538
+
539
+ hidden_states = hidden_states / attn.rescale_output_factor
540
+
541
+ # Rich-Text Modified: return attn probs
542
+ # We return the map averaged over heads to save memory footprint
543
+ attention_probs_avg = attn.reshape_batch_dim_to_heads_and_average(
544
+ attention_probs)
545
+ return hidden_states, [attention_probs_avg, attention_probs]
546
+
547
+
548
+ class LoRALinearLayer(nn.Module):
549
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
550
+ super().__init__()
551
+
552
+ if rank > min(in_features, out_features):
553
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
554
+
555
+ self.down = nn.Linear(in_features, rank, bias=False)
556
+ self.up = nn.Linear(rank, out_features, bias=False)
557
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
558
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
559
+ self.network_alpha = network_alpha
560
+ self.rank = rank
561
+
562
+ nn.init.normal_(self.down.weight, std=1 / rank)
563
+ nn.init.zeros_(self.up.weight)
564
+
565
+ def forward(self, hidden_states):
566
+ orig_dtype = hidden_states.dtype
567
+ dtype = self.down.weight.dtype
568
+
569
+ down_hidden_states = self.down(hidden_states.to(dtype))
570
+ up_hidden_states = self.up(down_hidden_states)
571
+
572
+ if self.network_alpha is not None:
573
+ up_hidden_states *= self.network_alpha / self.rank
574
+
575
+ return up_hidden_states.to(orig_dtype)
576
+
577
+
578
+ class LoRAAttnProcessor(nn.Module):
579
+ r"""
580
+ Processor for implementing the LoRA attention mechanism.
581
+
582
+ Args:
583
+ hidden_size (`int`, *optional*):
584
+ The hidden size of the attention layer.
585
+ cross_attention_dim (`int`, *optional*):
586
+ The number of channels in the `encoder_hidden_states`.
587
+ rank (`int`, defaults to 4):
588
+ The dimension of the LoRA update matrices.
589
+ network_alpha (`int`, *optional*):
590
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
591
+ """
592
+
593
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
594
+ super().__init__()
595
+
596
+ self.hidden_size = hidden_size
597
+ self.cross_attention_dim = cross_attention_dim
598
+ self.rank = rank
599
+
600
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
601
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
602
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
603
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
604
+
605
+ def __call__(
606
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
607
+ ):
608
+ residual = hidden_states
609
+
610
+ if attn.spatial_norm is not None:
611
+ hidden_states = attn.spatial_norm(hidden_states, temb)
612
+
613
+ input_ndim = hidden_states.ndim
614
+
615
+ if input_ndim == 4:
616
+ batch_size, channel, height, width = hidden_states.shape
617
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
618
+
619
+ batch_size, sequence_length, _ = (
620
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
621
+ )
622
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
623
+
624
+ if attn.group_norm is not None:
625
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
626
+
627
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
628
+ query = attn.head_to_batch_dim(query)
629
+
630
+ if encoder_hidden_states is None:
631
+ encoder_hidden_states = hidden_states
632
+ elif attn.norm_cross:
633
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
634
+
635
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
636
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
637
+
638
+ key = attn.head_to_batch_dim(key)
639
+ value = attn.head_to_batch_dim(value)
640
+
641
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
642
+ hidden_states = torch.bmm(attention_probs, value)
643
+ hidden_states = attn.batch_to_head_dim(hidden_states)
644
+
645
+ # linear proj
646
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
647
+ # dropout
648
+ hidden_states = attn.to_out[1](hidden_states)
649
+
650
+ if input_ndim == 4:
651
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
652
+
653
+ if attn.residual_connection:
654
+ hidden_states = hidden_states + residual
655
+
656
+ hidden_states = hidden_states / attn.rescale_output_factor
657
+
658
+ return hidden_states
659
+
660
+
661
+ class CustomDiffusionAttnProcessor(nn.Module):
662
+ r"""
663
+ Processor for implementing attention for the Custom Diffusion method.
664
+
665
+ Args:
666
+ train_kv (`bool`, defaults to `True`):
667
+ Whether to newly train the key and value matrices corresponding to the text features.
668
+ train_q_out (`bool`, defaults to `True`):
669
+ Whether to newly train query matrices corresponding to the latent image features.
670
+ hidden_size (`int`, *optional*, defaults to `None`):
671
+ The hidden size of the attention layer.
672
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
673
+ The number of channels in the `encoder_hidden_states`.
674
+ out_bias (`bool`, defaults to `True`):
675
+ Whether to include the bias parameter in `train_q_out`.
676
+ dropout (`float`, *optional*, defaults to 0.0):
677
+ The dropout probability to use.
678
+ """
679
+
680
+ def __init__(
681
+ self,
682
+ train_kv=True,
683
+ train_q_out=True,
684
+ hidden_size=None,
685
+ cross_attention_dim=None,
686
+ out_bias=True,
687
+ dropout=0.0,
688
+ ):
689
+ super().__init__()
690
+ self.train_kv = train_kv
691
+ self.train_q_out = train_q_out
692
+
693
+ self.hidden_size = hidden_size
694
+ self.cross_attention_dim = cross_attention_dim
695
+
696
+ # `_custom_diffusion` id for easy serialization and loading.
697
+ if self.train_kv:
698
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
699
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
700
+ if self.train_q_out:
701
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
702
+ self.to_out_custom_diffusion = nn.ModuleList([])
703
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
704
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
705
+
706
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
707
+ batch_size, sequence_length, _ = hidden_states.shape
708
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
709
+ if self.train_q_out:
710
+ query = self.to_q_custom_diffusion(hidden_states)
711
+ else:
712
+ query = attn.to_q(hidden_states)
713
+
714
+ if encoder_hidden_states is None:
715
+ crossattn = False
716
+ encoder_hidden_states = hidden_states
717
+ else:
718
+ crossattn = True
719
+ if attn.norm_cross:
720
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
721
+
722
+ if self.train_kv:
723
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
724
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
725
+ else:
726
+ key = attn.to_k(encoder_hidden_states)
727
+ value = attn.to_v(encoder_hidden_states)
728
+
729
+ if crossattn:
730
+ detach = torch.ones_like(key)
731
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
732
+ key = detach * key + (1 - detach) * key.detach()
733
+ value = detach * value + (1 - detach) * value.detach()
734
+
735
+ query = attn.head_to_batch_dim(query)
736
+ key = attn.head_to_batch_dim(key)
737
+ value = attn.head_to_batch_dim(value)
738
+
739
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
740
+ hidden_states = torch.bmm(attention_probs, value)
741
+ hidden_states = attn.batch_to_head_dim(hidden_states)
742
+
743
+ if self.train_q_out:
744
+ # linear proj
745
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
746
+ # dropout
747
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
748
+ else:
749
+ # linear proj
750
+ hidden_states = attn.to_out[0](hidden_states)
751
+ # dropout
752
+ hidden_states = attn.to_out[1](hidden_states)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class AttnAddedKVProcessor:
758
+ r"""
759
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
760
+ encoder.
761
+ """
762
+
763
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
764
+ residual = hidden_states
765
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
766
+ batch_size, sequence_length, _ = hidden_states.shape
767
+
768
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
769
+
770
+ if encoder_hidden_states is None:
771
+ encoder_hidden_states = hidden_states
772
+ elif attn.norm_cross:
773
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
774
+
775
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
776
+
777
+ query = attn.to_q(hidden_states)
778
+ query = attn.head_to_batch_dim(query)
779
+
780
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
781
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
782
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
783
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
784
+
785
+ if not attn.only_cross_attention:
786
+ key = attn.to_k(hidden_states)
787
+ value = attn.to_v(hidden_states)
788
+ key = attn.head_to_batch_dim(key)
789
+ value = attn.head_to_batch_dim(value)
790
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
791
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
792
+ else:
793
+ key = encoder_hidden_states_key_proj
794
+ value = encoder_hidden_states_value_proj
795
+
796
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
797
+ hidden_states = torch.bmm(attention_probs, value)
798
+ hidden_states = attn.batch_to_head_dim(hidden_states)
799
+
800
+ # linear proj
801
+ hidden_states = attn.to_out[0](hidden_states)
802
+ # dropout
803
+ hidden_states = attn.to_out[1](hidden_states)
804
+
805
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
806
+ hidden_states = hidden_states + residual
807
+
808
+ return hidden_states
809
+
810
+
811
+ class AttnAddedKVProcessor2_0:
812
+ r"""
813
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
814
+ learnable key and value matrices for the text encoder.
815
+ """
816
+
817
+ def __init__(self):
818
+ if not hasattr(F, "scaled_dot_product_attention"):
819
+ raise ImportError(
820
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
821
+ )
822
+
823
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
824
+ residual = hidden_states
825
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
826
+ batch_size, sequence_length, _ = hidden_states.shape
827
+
828
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
829
+
830
+ if encoder_hidden_states is None:
831
+ encoder_hidden_states = hidden_states
832
+ elif attn.norm_cross:
833
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
834
+
835
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
836
+
837
+ query = attn.to_q(hidden_states)
838
+ query = attn.head_to_batch_dim(query, out_dim=4)
839
+
840
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
841
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
842
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
843
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
844
+
845
+ if not attn.only_cross_attention:
846
+ key = attn.to_k(hidden_states)
847
+ value = attn.to_v(hidden_states)
848
+ key = attn.head_to_batch_dim(key, out_dim=4)
849
+ value = attn.head_to_batch_dim(value, out_dim=4)
850
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
851
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
852
+ else:
853
+ key = encoder_hidden_states_key_proj
854
+ value = encoder_hidden_states_value_proj
855
+
856
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
857
+ # TODO: add support for attn.scale when we move to Torch 2.1
858
+ hidden_states = F.scaled_dot_product_attention(
859
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
860
+ )
861
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
862
+
863
+ # linear proj
864
+ hidden_states = attn.to_out[0](hidden_states)
865
+ # dropout
866
+ hidden_states = attn.to_out[1](hidden_states)
867
+
868
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
869
+ hidden_states = hidden_states + residual
870
+
871
+ return hidden_states
872
+
873
+
874
+ class LoRAAttnAddedKVProcessor(nn.Module):
875
+ r"""
876
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
877
+ encoder.
878
+
879
+ Args:
880
+ hidden_size (`int`, *optional*):
881
+ The hidden size of the attention layer.
882
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
883
+ The number of channels in the `encoder_hidden_states`.
884
+ rank (`int`, defaults to 4):
885
+ The dimension of the LoRA update matrices.
886
+
887
+ """
888
+
889
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
890
+ super().__init__()
891
+
892
+ self.hidden_size = hidden_size
893
+ self.cross_attention_dim = cross_attention_dim
894
+ self.rank = rank
895
+
896
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
897
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
898
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
899
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
900
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
901
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
902
+
903
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
904
+ residual = hidden_states
905
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
906
+ batch_size, sequence_length, _ = hidden_states.shape
907
+
908
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
909
+
910
+ if encoder_hidden_states is None:
911
+ encoder_hidden_states = hidden_states
912
+ elif attn.norm_cross:
913
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
914
+
915
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
916
+
917
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
918
+ query = attn.head_to_batch_dim(query)
919
+
920
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
921
+ encoder_hidden_states
922
+ )
923
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
924
+ encoder_hidden_states
925
+ )
926
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
927
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
928
+
929
+ if not attn.only_cross_attention:
930
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
931
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
932
+ key = attn.head_to_batch_dim(key)
933
+ value = attn.head_to_batch_dim(value)
934
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
935
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
936
+ else:
937
+ key = encoder_hidden_states_key_proj
938
+ value = encoder_hidden_states_value_proj
939
+
940
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
941
+ hidden_states = torch.bmm(attention_probs, value)
942
+ hidden_states = attn.batch_to_head_dim(hidden_states)
943
+
944
+ # linear proj
945
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
946
+ # dropout
947
+ hidden_states = attn.to_out[1](hidden_states)
948
+
949
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
950
+ hidden_states = hidden_states + residual
951
+
952
+ return hidden_states
953
+
954
+
955
+ class XFormersAttnAddedKVProcessor:
956
+ r"""
957
+ Processor for implementing memory efficient attention using xFormers.
958
+
959
+ Args:
960
+ attention_op (`Callable`, *optional*, defaults to `None`):
961
+ The base
962
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
963
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
964
+ operator.
965
+ """
966
+
967
+ def __init__(self, attention_op: Optional[Callable] = None):
968
+ self.attention_op = attention_op
969
+
970
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
971
+ residual = hidden_states
972
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
973
+ batch_size, sequence_length, _ = hidden_states.shape
974
+
975
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
976
+
977
+ if encoder_hidden_states is None:
978
+ encoder_hidden_states = hidden_states
979
+ elif attn.norm_cross:
980
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
981
+
982
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
983
+
984
+ query = attn.to_q(hidden_states)
985
+ query = attn.head_to_batch_dim(query)
986
+
987
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
988
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
989
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
990
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
991
+
992
+ if not attn.only_cross_attention:
993
+ key = attn.to_k(hidden_states)
994
+ value = attn.to_v(hidden_states)
995
+ key = attn.head_to_batch_dim(key)
996
+ value = attn.head_to_batch_dim(value)
997
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
998
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
999
+ else:
1000
+ key = encoder_hidden_states_key_proj
1001
+ value = encoder_hidden_states_value_proj
1002
+
1003
+ hidden_states = xformers.ops.memory_efficient_attention(
1004
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1005
+ )
1006
+ hidden_states = hidden_states.to(query.dtype)
1007
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1008
+
1009
+ # linear proj
1010
+ hidden_states = attn.to_out[0](hidden_states)
1011
+ # dropout
1012
+ hidden_states = attn.to_out[1](hidden_states)
1013
+
1014
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1015
+ hidden_states = hidden_states + residual
1016
+
1017
+ return hidden_states
1018
+
1019
+
1020
+ class XFormersAttnProcessor:
1021
+ r"""
1022
+ Processor for implementing memory efficient attention using xFormers.
1023
+
1024
+ Args:
1025
+ attention_op (`Callable`, *optional*, defaults to `None`):
1026
+ The base
1027
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1028
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1029
+ operator.
1030
+ """
1031
+
1032
+ def __init__(self, attention_op: Optional[Callable] = None):
1033
+ self.attention_op = attention_op
1034
+
1035
+ def __call__(
1036
+ self,
1037
+ attn: Attention,
1038
+ hidden_states: torch.FloatTensor,
1039
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1040
+ attention_mask: Optional[torch.FloatTensor] = None,
1041
+ temb: Optional[torch.FloatTensor] = None,
1042
+ ):
1043
+ residual = hidden_states
1044
+
1045
+ if attn.spatial_norm is not None:
1046
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1047
+
1048
+ input_ndim = hidden_states.ndim
1049
+
1050
+ if input_ndim == 4:
1051
+ batch_size, channel, height, width = hidden_states.shape
1052
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1053
+
1054
+ batch_size, key_tokens, _ = (
1055
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1056
+ )
1057
+
1058
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1059
+ if attention_mask is not None:
1060
+ # expand our mask's singleton query_tokens dimension:
1061
+ # [batch*heads, 1, key_tokens] ->
1062
+ # [batch*heads, query_tokens, key_tokens]
1063
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1064
+ # [batch*heads, query_tokens, key_tokens]
1065
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1066
+ _, query_tokens, _ = hidden_states.shape
1067
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1068
+
1069
+ if attn.group_norm is not None:
1070
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1071
+
1072
+ query = attn.to_q(hidden_states)
1073
+
1074
+ if encoder_hidden_states is None:
1075
+ encoder_hidden_states = hidden_states
1076
+ elif attn.norm_cross:
1077
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1078
+
1079
+ key = attn.to_k(encoder_hidden_states)
1080
+ value = attn.to_v(encoder_hidden_states)
1081
+
1082
+ query = attn.head_to_batch_dim(query).contiguous()
1083
+ key = attn.head_to_batch_dim(key).contiguous()
1084
+ value = attn.head_to_batch_dim(value).contiguous()
1085
+
1086
+ hidden_states = xformers.ops.memory_efficient_attention(
1087
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1088
+ )
1089
+ hidden_states = hidden_states.to(query.dtype)
1090
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1091
+
1092
+ # linear proj
1093
+ hidden_states = attn.to_out[0](hidden_states)
1094
+ # dropout
1095
+ hidden_states = attn.to_out[1](hidden_states)
1096
+
1097
+ if input_ndim == 4:
1098
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1099
+
1100
+ if attn.residual_connection:
1101
+ hidden_states = hidden_states + residual
1102
+
1103
+ hidden_states = hidden_states / attn.rescale_output_factor
1104
+
1105
+ return hidden_states
1106
+
1107
+
1108
+ class AttnProcessor2_0:
1109
+ r"""
1110
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1111
+ """
1112
+
1113
+ def __init__(self):
1114
+ if not hasattr(F, "scaled_dot_product_attention"):
1115
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1116
+
1117
+ def __call__(
1118
+ self,
1119
+ attn: Attention,
1120
+ hidden_states,
1121
+ encoder_hidden_states=None,
1122
+ attention_mask=None,
1123
+ temb=None,
1124
+ ):
1125
+ residual = hidden_states
1126
+
1127
+ if attn.spatial_norm is not None:
1128
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1129
+
1130
+ input_ndim = hidden_states.ndim
1131
+
1132
+ if input_ndim == 4:
1133
+ batch_size, channel, height, width = hidden_states.shape
1134
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1135
+
1136
+ batch_size, sequence_length, _ = (
1137
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1138
+ )
1139
+ inner_dim = hidden_states.shape[-1]
1140
+
1141
+ if attention_mask is not None:
1142
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1143
+ # scaled_dot_product_attention expects attention_mask shape to be
1144
+ # (batch, heads, source_length, target_length)
1145
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1146
+
1147
+ if attn.group_norm is not None:
1148
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1149
+
1150
+ query = attn.to_q(hidden_states)
1151
+
1152
+ if encoder_hidden_states is None:
1153
+ encoder_hidden_states = hidden_states
1154
+ elif attn.norm_cross:
1155
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1156
+
1157
+ key = attn.to_k(encoder_hidden_states)
1158
+ value = attn.to_v(encoder_hidden_states)
1159
+
1160
+ head_dim = inner_dim // attn.heads
1161
+
1162
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1163
+
1164
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1165
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1166
+
1167
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1168
+ # TODO: add support for attn.scale when we move to Torch 2.1
1169
+ hidden_states = F.scaled_dot_product_attention(
1170
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1171
+ )
1172
+
1173
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1174
+ hidden_states = hidden_states.to(query.dtype)
1175
+
1176
+ # linear proj
1177
+ hidden_states = attn.to_out[0](hidden_states)
1178
+ # dropout
1179
+ hidden_states = attn.to_out[1](hidden_states)
1180
+
1181
+ if input_ndim == 4:
1182
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1183
+
1184
+ if attn.residual_connection:
1185
+ hidden_states = hidden_states + residual
1186
+
1187
+ hidden_states = hidden_states / attn.rescale_output_factor
1188
+
1189
+ return hidden_states
1190
+
1191
+
1192
+ class LoRAXFormersAttnProcessor(nn.Module):
1193
+ r"""
1194
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1195
+
1196
+ Args:
1197
+ hidden_size (`int`, *optional*):
1198
+ The hidden size of the attention layer.
1199
+ cross_attention_dim (`int`, *optional*):
1200
+ The number of channels in the `encoder_hidden_states`.
1201
+ rank (`int`, defaults to 4):
1202
+ The dimension of the LoRA update matrices.
1203
+ attention_op (`Callable`, *optional*, defaults to `None`):
1204
+ The base
1205
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1206
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1207
+ operator.
1208
+ network_alpha (`int`, *optional*):
1209
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1210
+
1211
+ """
1212
+
1213
+ def __init__(
1214
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1215
+ ):
1216
+ super().__init__()
1217
+
1218
+ self.hidden_size = hidden_size
1219
+ self.cross_attention_dim = cross_attention_dim
1220
+ self.rank = rank
1221
+ self.attention_op = attention_op
1222
+
1223
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1224
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1225
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1226
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1227
+
1228
+ def __call__(
1229
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1230
+ ):
1231
+ residual = hidden_states
1232
+
1233
+ if attn.spatial_norm is not None:
1234
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1235
+
1236
+ input_ndim = hidden_states.ndim
1237
+
1238
+ if input_ndim == 4:
1239
+ batch_size, channel, height, width = hidden_states.shape
1240
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1241
+
1242
+ batch_size, sequence_length, _ = (
1243
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1244
+ )
1245
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1246
+
1247
+ if attn.group_norm is not None:
1248
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1249
+
1250
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1251
+ query = attn.head_to_batch_dim(query).contiguous()
1252
+
1253
+ if encoder_hidden_states is None:
1254
+ encoder_hidden_states = hidden_states
1255
+ elif attn.norm_cross:
1256
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1257
+
1258
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1259
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1260
+
1261
+ key = attn.head_to_batch_dim(key).contiguous()
1262
+ value = attn.head_to_batch_dim(value).contiguous()
1263
+
1264
+ hidden_states = xformers.ops.memory_efficient_attention(
1265
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1266
+ )
1267
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1268
+
1269
+ # linear proj
1270
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1271
+ # dropout
1272
+ hidden_states = attn.to_out[1](hidden_states)
1273
+
1274
+ if input_ndim == 4:
1275
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1276
+
1277
+ if attn.residual_connection:
1278
+ hidden_states = hidden_states + residual
1279
+
1280
+ hidden_states = hidden_states / attn.rescale_output_factor
1281
+
1282
+ return hidden_states
1283
+
1284
+
1285
+ class LoRAAttnProcessor2_0(nn.Module):
1286
+ r"""
1287
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1288
+ attention.
1289
+
1290
+ Args:
1291
+ hidden_size (`int`):
1292
+ The hidden size of the attention layer.
1293
+ cross_attention_dim (`int`, *optional*):
1294
+ The number of channels in the `encoder_hidden_states`.
1295
+ rank (`int`, defaults to 4):
1296
+ The dimension of the LoRA update matrices.
1297
+ network_alpha (`int`, *optional*):
1298
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1299
+ """
1300
+
1301
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1302
+ super().__init__()
1303
+ if not hasattr(F, "scaled_dot_product_attention"):
1304
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1305
+
1306
+ self.hidden_size = hidden_size
1307
+ self.cross_attention_dim = cross_attention_dim
1308
+ self.rank = rank
1309
+
1310
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1311
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1312
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1313
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1314
+
1315
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1316
+ residual = hidden_states
1317
+
1318
+ input_ndim = hidden_states.ndim
1319
+
1320
+ if input_ndim == 4:
1321
+ batch_size, channel, height, width = hidden_states.shape
1322
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1323
+
1324
+ batch_size, sequence_length, _ = (
1325
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1326
+ )
1327
+ inner_dim = hidden_states.shape[-1]
1328
+
1329
+ if attention_mask is not None:
1330
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1331
+ # scaled_dot_product_attention expects attention_mask shape to be
1332
+ # (batch, heads, source_length, target_length)
1333
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1334
+
1335
+ if attn.group_norm is not None:
1336
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1337
+
1338
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1339
+
1340
+ if encoder_hidden_states is None:
1341
+ encoder_hidden_states = hidden_states
1342
+ elif attn.norm_cross:
1343
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1344
+
1345
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1346
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1347
+
1348
+ head_dim = inner_dim // attn.heads
1349
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1350
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1351
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1352
+
1353
+ # TODO: add support for attn.scale when we move to Torch 2.1
1354
+ hidden_states = F.scaled_dot_product_attention(
1355
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1356
+ )
1357
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1358
+ hidden_states = hidden_states.to(query.dtype)
1359
+
1360
+ # linear proj
1361
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1362
+ # dropout
1363
+ hidden_states = attn.to_out[1](hidden_states)
1364
+
1365
+ if input_ndim == 4:
1366
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1367
+
1368
+ if attn.residual_connection:
1369
+ hidden_states = hidden_states + residual
1370
+
1371
+ hidden_states = hidden_states / attn.rescale_output_factor
1372
+
1373
+ return hidden_states
1374
+
1375
+
1376
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1377
+ r"""
1378
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1379
+
1380
+ Args:
1381
+ train_kv (`bool`, defaults to `True`):
1382
+ Whether to newly train the key and value matrices corresponding to the text features.
1383
+ train_q_out (`bool`, defaults to `True`):
1384
+ Whether to newly train query matrices corresponding to the latent image features.
1385
+ hidden_size (`int`, *optional*, defaults to `None`):
1386
+ The hidden size of the attention layer.
1387
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1388
+ The number of channels in the `encoder_hidden_states`.
1389
+ out_bias (`bool`, defaults to `True`):
1390
+ Whether to include the bias parameter in `train_q_out`.
1391
+ dropout (`float`, *optional*, defaults to 0.0):
1392
+ The dropout probability to use.
1393
+ attention_op (`Callable`, *optional*, defaults to `None`):
1394
+ The base
1395
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1396
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1397
+ """
1398
+
1399
+ def __init__(
1400
+ self,
1401
+ train_kv=True,
1402
+ train_q_out=False,
1403
+ hidden_size=None,
1404
+ cross_attention_dim=None,
1405
+ out_bias=True,
1406
+ dropout=0.0,
1407
+ attention_op: Optional[Callable] = None,
1408
+ ):
1409
+ super().__init__()
1410
+ self.train_kv = train_kv
1411
+ self.train_q_out = train_q_out
1412
+
1413
+ self.hidden_size = hidden_size
1414
+ self.cross_attention_dim = cross_attention_dim
1415
+ self.attention_op = attention_op
1416
+
1417
+ # `_custom_diffusion` id for easy serialization and loading.
1418
+ if self.train_kv:
1419
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1420
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1421
+ if self.train_q_out:
1422
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1423
+ self.to_out_custom_diffusion = nn.ModuleList([])
1424
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1425
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1426
+
1427
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1428
+ batch_size, sequence_length, _ = (
1429
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1430
+ )
1431
+
1432
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1433
+
1434
+ if self.train_q_out:
1435
+ query = self.to_q_custom_diffusion(hidden_states)
1436
+ else:
1437
+ query = attn.to_q(hidden_states)
1438
+
1439
+ if encoder_hidden_states is None:
1440
+ crossattn = False
1441
+ encoder_hidden_states = hidden_states
1442
+ else:
1443
+ crossattn = True
1444
+ if attn.norm_cross:
1445
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1446
+
1447
+ if self.train_kv:
1448
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
1449
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
1450
+ else:
1451
+ key = attn.to_k(encoder_hidden_states)
1452
+ value = attn.to_v(encoder_hidden_states)
1453
+
1454
+ if crossattn:
1455
+ detach = torch.ones_like(key)
1456
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1457
+ key = detach * key + (1 - detach) * key.detach()
1458
+ value = detach * value + (1 - detach) * value.detach()
1459
+
1460
+ query = attn.head_to_batch_dim(query).contiguous()
1461
+ key = attn.head_to_batch_dim(key).contiguous()
1462
+ value = attn.head_to_batch_dim(value).contiguous()
1463
+
1464
+ hidden_states = xformers.ops.memory_efficient_attention(
1465
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1466
+ )
1467
+ hidden_states = hidden_states.to(query.dtype)
1468
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1469
+
1470
+ if self.train_q_out:
1471
+ # linear proj
1472
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1473
+ # dropout
1474
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1475
+ else:
1476
+ # linear proj
1477
+ hidden_states = attn.to_out[0](hidden_states)
1478
+ # dropout
1479
+ hidden_states = attn.to_out[1](hidden_states)
1480
+ return hidden_states
1481
+
1482
+
1483
+ class SlicedAttnProcessor:
1484
+ r"""
1485
+ Processor for implementing sliced attention.
1486
+
1487
+ Args:
1488
+ slice_size (`int`, *optional*):
1489
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1490
+ `attention_head_dim` must be a multiple of the `slice_size`.
1491
+ """
1492
+
1493
+ def __init__(self, slice_size):
1494
+ self.slice_size = slice_size
1495
+
1496
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1497
+ residual = hidden_states
1498
+
1499
+ input_ndim = hidden_states.ndim
1500
+
1501
+ if input_ndim == 4:
1502
+ batch_size, channel, height, width = hidden_states.shape
1503
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1504
+
1505
+ batch_size, sequence_length, _ = (
1506
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1507
+ )
1508
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1509
+
1510
+ if attn.group_norm is not None:
1511
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1512
+
1513
+ query = attn.to_q(hidden_states)
1514
+ dim = query.shape[-1]
1515
+ query = attn.head_to_batch_dim(query)
1516
+
1517
+ if encoder_hidden_states is None:
1518
+ encoder_hidden_states = hidden_states
1519
+ elif attn.norm_cross:
1520
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1521
+
1522
+ key = attn.to_k(encoder_hidden_states)
1523
+ value = attn.to_v(encoder_hidden_states)
1524
+ key = attn.head_to_batch_dim(key)
1525
+ value = attn.head_to_batch_dim(value)
1526
+
1527
+ batch_size_attention, query_tokens, _ = query.shape
1528
+ hidden_states = torch.zeros(
1529
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1530
+ )
1531
+
1532
+ for i in range(batch_size_attention // self.slice_size):
1533
+ start_idx = i * self.slice_size
1534
+ end_idx = (i + 1) * self.slice_size
1535
+
1536
+ query_slice = query[start_idx:end_idx]
1537
+ key_slice = key[start_idx:end_idx]
1538
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1539
+
1540
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1541
+
1542
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1543
+
1544
+ hidden_states[start_idx:end_idx] = attn_slice
1545
+
1546
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1547
+
1548
+ # linear proj
1549
+ hidden_states = attn.to_out[0](hidden_states)
1550
+ # dropout
1551
+ hidden_states = attn.to_out[1](hidden_states)
1552
+
1553
+ if input_ndim == 4:
1554
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1555
+
1556
+ if attn.residual_connection:
1557
+ hidden_states = hidden_states + residual
1558
+
1559
+ hidden_states = hidden_states / attn.rescale_output_factor
1560
+
1561
+ return hidden_states
1562
+
1563
+
1564
+ class SlicedAttnAddedKVProcessor:
1565
+ r"""
1566
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1567
+
1568
+ Args:
1569
+ slice_size (`int`, *optional*):
1570
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1571
+ `attention_head_dim` must be a multiple of the `slice_size`.
1572
+ """
1573
+
1574
+ def __init__(self, slice_size):
1575
+ self.slice_size = slice_size
1576
+
1577
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1578
+ residual = hidden_states
1579
+
1580
+ if attn.spatial_norm is not None:
1581
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1582
+
1583
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1584
+
1585
+ batch_size, sequence_length, _ = hidden_states.shape
1586
+
1587
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1588
+
1589
+ if encoder_hidden_states is None:
1590
+ encoder_hidden_states = hidden_states
1591
+ elif attn.norm_cross:
1592
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1593
+
1594
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1595
+
1596
+ query = attn.to_q(hidden_states)
1597
+ dim = query.shape[-1]
1598
+ query = attn.head_to_batch_dim(query)
1599
+
1600
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1601
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1602
+
1603
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1604
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1605
+
1606
+ if not attn.only_cross_attention:
1607
+ key = attn.to_k(hidden_states)
1608
+ value = attn.to_v(hidden_states)
1609
+ key = attn.head_to_batch_dim(key)
1610
+ value = attn.head_to_batch_dim(value)
1611
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1612
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1613
+ else:
1614
+ key = encoder_hidden_states_key_proj
1615
+ value = encoder_hidden_states_value_proj
1616
+
1617
+ batch_size_attention, query_tokens, _ = query.shape
1618
+ hidden_states = torch.zeros(
1619
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1620
+ )
1621
+
1622
+ for i in range(batch_size_attention // self.slice_size):
1623
+ start_idx = i * self.slice_size
1624
+ end_idx = (i + 1) * self.slice_size
1625
+
1626
+ query_slice = query[start_idx:end_idx]
1627
+ key_slice = key[start_idx:end_idx]
1628
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1629
+
1630
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1631
+
1632
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1633
+
1634
+ hidden_states[start_idx:end_idx] = attn_slice
1635
+
1636
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1637
+
1638
+ # linear proj
1639
+ hidden_states = attn.to_out[0](hidden_states)
1640
+ # dropout
1641
+ hidden_states = attn.to_out[1](hidden_states)
1642
+
1643
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1644
+ hidden_states = hidden_states + residual
1645
+
1646
+ return hidden_states
1647
+
1648
+
1649
+ AttentionProcessor = Union[
1650
+ AttnProcessor,
1651
+ AttnProcessor2_0,
1652
+ XFormersAttnProcessor,
1653
+ SlicedAttnProcessor,
1654
+ AttnAddedKVProcessor,
1655
+ SlicedAttnAddedKVProcessor,
1656
+ AttnAddedKVProcessor2_0,
1657
+ XFormersAttnAddedKVProcessor,
1658
+ LoRAAttnProcessor,
1659
+ LoRAXFormersAttnProcessor,
1660
+ LoRAAttnProcessor2_0,
1661
+ LoRAAttnAddedKVProcessor,
1662
+ CustomDiffusionAttnProcessor,
1663
+ CustomDiffusionXFormersAttnProcessor,
1664
+ ]
1665
+
1666
+
1667
+ class SpatialNorm(nn.Module):
1668
+ """
1669
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1670
+ """
1671
+
1672
+ def __init__(
1673
+ self,
1674
+ f_channels,
1675
+ zq_channels,
1676
+ ):
1677
+ super().__init__()
1678
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1679
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1680
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1681
+
1682
+ def forward(self, f, zq):
1683
+ f_size = f.shape[-2:]
1684
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1685
+ norm_f = self.norm_layer(f)
1686
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1687
+ return new_f
models/dual_transformer_2d.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from models.transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more than steps than `num_embeds_ada_norm`.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention
118
+ return_dict (`bool`, *optional*, defaults to `True`):
119
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
+
121
+ Returns:
122
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
+ returning a tuple, the first element is the sample tensor.
125
+ """
126
+ input_states = hidden_states
127
+
128
+ encoded_states = []
129
+ tokens_start = 0
130
+ # attention_mask is not used yet
131
+ for i in range(2):
132
+ # for each of the two transformers, pass the corresponding condition tokens
133
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
+ transformer_index = self.transformer_index_for_condition[i]
135
+ encoded_state = self.transformers[transformer_index](
136
+ input_states,
137
+ encoder_hidden_states=condition_state,
138
+ timestep=timestep,
139
+ cross_attention_kwargs=cross_attention_kwargs,
140
+ return_dict=False,
141
+ )[0]
142
+ encoded_states.append(encoded_state - input_states)
143
+ tokens_start += self.condition_lengths[i]
144
+
145
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
+ output_states = output_states + input_states
147
+
148
+ if not return_dict:
149
+ return (output_states,)
150
+
151
+ return Transformer2DModelOutput(sample=output_states)
models/region_diffusion.py CHANGED
@@ -84,17 +84,19 @@ class RegionDiffusion(nn.Module):
84
  return text_embeddings
85
 
86
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
- latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, inject_background=0):
88
 
89
  if latents is None:
90
  latents = torch.randn(
91
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
 
93
- if inject_selfattn > 0 or inject_background > 0:
94
  latents_reference = latents.clone().detach()
95
  self.scheduler.set_timesteps(num_inference_steps)
96
  n_styles = text_embeddings.shape[0]-1
 
97
  assert n_styles == len(self.masks)
 
98
  with torch.autocast('cuda'):
99
  for i, t in enumerate(self.scheduler.timesteps):
100
 
@@ -102,34 +104,56 @@ class RegionDiffusion(nn.Module):
102
  with torch.no_grad():
103
  # tokens without any attributes
104
  feat_inject_step = t > (1-inject_selfattn) * 1000
105
- background_inject_step = i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0
106
  noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
107
- text_format_dict={})['sample']
 
 
 
108
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
109
- text_format_dict=text_format_dict)['sample']
 
 
110
  if inject_selfattn > 0 or inject_background > 0:
111
  noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
112
- text_format_dict={})['sample']
 
113
  self.register_selfattn_hooks(feat_inject_step)
114
  noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
115
- text_format_dict={})['sample']
 
116
  self.remove_selfattn_hooks()
117
  noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
118
  noise_pred_text = noise_pred_text_cur * self.masks[-1]
119
  # tokens with attributes
120
  for style_i, mask in enumerate(self.masks[:-1]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  self.register_replacement_hooks(feat_inject_step)
122
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
123
- text_format_dict={})['sample']
 
124
  self.remove_replacement_hooks()
125
  noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
126
  noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
127
-
128
- # perform classifier-free guidance
129
  noise_pred = noise_pred_uncond + guidance_scale * \
130
  (noise_pred_text - noise_pred_uncond)
131
 
132
- if inject_selfattn > 0 or inject_background > 0:
133
  noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
134
  (noise_pred_text_refer - noise_pred_uncond_refer)
135
 
@@ -154,21 +178,25 @@ class RegionDiffusion(nn.Module):
154
  latents_inp = 1 / 0.18215 * latents_0
155
  imgs = self.vae.decode(latents_inp).sample
156
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
 
 
 
 
 
157
  loss_total = 0.
158
  for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
 
 
159
  avg_rgb = (
160
  imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
161
  loss = self.color_loss(
162
  avg_rgb, rgb_val[:, :, 0, 0])*100
 
163
  loss_total += loss
164
  loss_total.backward()
165
  latents = (
166
- latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone()
167
 
168
- # apply background injection
169
- if background_inject_step:
170
- latents = latents_reference * self.masks[-1] + latents * \
171
- (1-self.masks[-1])
172
  return latents
173
 
174
  def predict_x0(self, x_t, eps_t, t):
@@ -244,7 +272,7 @@ class RegionDiffusion(nn.Module):
244
  return latents
245
 
246
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
247
- guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, inject_background=0):
248
 
249
  if isinstance(prompts, str):
250
  prompts = [prompts]
@@ -260,7 +288,7 @@ class RegionDiffusion(nn.Module):
260
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
261
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
262
  use_guidance=use_guidance, text_format_dict=text_format_dict,
263
- inject_selfattn=inject_selfattn, inject_background=inject_background) # [1, 4, 64, 64]
264
  # Img latents -> imgs
265
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
266
 
@@ -334,6 +362,8 @@ class RegionDiffusion(nn.Module):
334
  """
335
  # out[0] - final output of residual layer
336
  # out[1] - residual hidden feature
 
 
337
  assert out[1].shape[-1] == 16
338
  activations[name] = out[1].detach()
339
  attention_dict = collections.defaultdict(list)
@@ -459,3 +489,33 @@ class RegionDiffusion(nn.Module):
459
  def remove_selfattn_hooks(self):
460
  for hook in self.selfattn_forward_hooks:
461
  hook.remove()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return text_embeddings
85
 
86
  def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
+ latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, bg_aug_end=1000):
88
 
89
  if latents is None:
90
  latents = torch.randn(
91
  (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
 
93
+ if inject_selfattn > 0:
94
  latents_reference = latents.clone().detach()
95
  self.scheduler.set_timesteps(num_inference_steps)
96
  n_styles = text_embeddings.shape[0]-1
97
+ print(n_styles, len(self.masks))
98
  assert n_styles == len(self.masks)
99
+
100
  with torch.autocast('cuda'):
101
  for i, t in enumerate(self.scheduler.timesteps):
102
 
 
104
  with torch.no_grad():
105
  # tokens without any attributes
106
  feat_inject_step = t > (1-inject_selfattn) * 1000
 
107
  noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
108
+ # text_format_dict={})['sample']
109
+ )['sample']
110
+ # tokens without any style or footnote
111
+ self.register_fontsize_hooks(text_format_dict)
112
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
113
+ # text_format_dict=text_format_dict)['sample']
114
+ )['sample']
115
+ self.remove_fontsize_hooks()
116
  if inject_selfattn > 0 or inject_background > 0:
117
  noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
118
+ # text_format_dict={})['sample']
119
+ )['sample']
120
  self.register_selfattn_hooks(feat_inject_step)
121
  noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
122
+ # text_format_dict={})['sample']
123
+ )['sample']
124
  self.remove_selfattn_hooks()
125
  noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
126
  noise_pred_text = noise_pred_text_cur * self.masks[-1]
127
  # tokens with attributes
128
  for style_i, mask in enumerate(self.masks[:-1]):
129
+ if t > bg_aug_end:
130
+ rand_rgb = torch.rand([1, 3, 1, 1]).cuda()
131
+ black_background = torch.ones(
132
+ [1, 3, height, width]).cuda()*rand_rgb
133
+ black_latent = self.encode_imgs(
134
+ black_background)
135
+ noise = torch.randn_like(black_latent)
136
+ black_latent_noisy = self.scheduler.add_noise(
137
+ black_latent, noise, t)
138
+ masked_latent = (
139
+ mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy
140
+ noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[:1],
141
+ text_format_dict={})['sample']
142
+ else:
143
+ masked_latent = latents
144
  self.register_replacement_hooks(feat_inject_step)
145
  noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
146
+ # text_format_dict={})['sample']
147
+ )['sample']
148
  self.remove_replacement_hooks()
149
  noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
150
  noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
151
+
152
+ # perform guidance
153
  noise_pred = noise_pred_uncond + guidance_scale * \
154
  (noise_pred_text - noise_pred_uncond)
155
 
156
+ if inject_selfattn > 0:
157
  noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
158
  (noise_pred_text_refer - noise_pred_uncond_refer)
159
 
 
178
  latents_inp = 1 / 0.18215 * latents_0
179
  imgs = self.vae.decode(latents_inp).sample
180
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
181
+ # save_path = 'results/font_color/20230425/church_process/orange/'
182
+ # os.makedirs(save_path, exist_ok=True)
183
+ # torchvision.utils.save_image(
184
+ # imgs, os.path.join(save_path, 'step%d.png' % t))
185
+ # loss = (((imgs - text_format_dict['target_RGB'])*text_format_dict['color_obj_atten'][:, 0])**2).mean()*100
186
  loss_total = 0.
187
  for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
188
+ # loss = self.color_loss(
189
+ # imgs*attn_map[:, 0], rgb_val*attn_map[:, 0])*100
190
  avg_rgb = (
191
  imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
192
  loss = self.color_loss(
193
  avg_rgb, rgb_val[:, :, 0, 0])*100
194
+ # print(loss)
195
  loss_total += loss
196
  loss_total.backward()
197
  latents = (
198
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()
199
 
 
 
 
 
200
  return latents
201
 
202
  def predict_x0(self, x_t, eps_t, t):
 
272
  return latents
273
 
274
  def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
275
+ guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, bg_aug_end=1000):
276
 
277
  if isinstance(prompts, str):
278
  prompts = [prompts]
 
288
  latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
289
  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
290
  use_guidance=use_guidance, text_format_dict=text_format_dict,
291
+ inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end) # [1, 4, 64, 64]
292
  # Img latents -> imgs
293
  imgs = self.decode_latents(latents) # [1, 3, 512, 512]
294
 
 
362
  """
363
  # out[0] - final output of residual layer
364
  # out[1] - residual hidden feature
365
+ # import ipdb
366
+ # ipdb.set_trace()
367
  assert out[1].shape[-1] == 16
368
  activations[name] = out[1].detach()
369
  attention_dict = collections.defaultdict(list)
 
489
  def remove_selfattn_hooks(self):
490
  for hook in self.selfattn_forward_hooks:
491
  hook.remove()
492
+
493
+ def register_fontsize_hooks(self, text_format_dict={}):
494
+ r"""Function for registering hooks to replace self attention.
495
+ """
496
+ self.forward_fontsize_hooks = []
497
+
498
+ def adjust_attn_weights(name, module, args):
499
+ r"""
500
+ PyTorch Forward hook to save outputs at each forward pass.
501
+ """
502
+ if 'attn2' in name:
503
+ modified_args = (args[0], None, attn_weights)
504
+ return modified_args
505
+
506
+ if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
507
+ attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
508
+ else:
509
+ attn_weights = None
510
+
511
+ for name, module in self.unet.named_modules():
512
+ leaf_name = name.split('.')[-1]
513
+ if 'attn' in leaf_name and attn_weights is not None:
514
+ # Register hook to obtain outputs at every attention layer.
515
+ self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
516
+ partial(adjust_attn_weights, name)
517
+ ))
518
+
519
+ def remove_fontsize_hooks(self):
520
+ for hook in self.forward_fontsize_hooks:
521
+ hook.remove()
models/region_diffusion_xl.py ADDED
@@ -0,0 +1,1138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from diffusers.pipelines.stable_diffusion.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.py
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
8
+
9
+ from diffusers.image_processor import VaeImageProcessor
10
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
11
+ # from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
+ from diffusers.models import AutoencoderKL
13
+
14
+ from diffusers.models.attention_processor import (
15
+ AttnProcessor2_0,
16
+ LoRAAttnProcessor2_0,
17
+ LoRAXFormersAttnProcessor,
18
+ XFormersAttnProcessor,
19
+ )
20
+ from diffusers.schedulers import EulerDiscreteScheduler
21
+ from diffusers.utils import (
22
+ is_accelerate_available,
23
+ is_accelerate_version,
24
+ logging,
25
+ randn_tensor,
26
+ replace_example_docstring,
27
+ )
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
30
+
31
+ ### cutomized modules
32
+ import collections
33
+ from functools import partial
34
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
35
+
36
+ from models.unet_2d_condition import UNet2DConditionModel
37
+ from utils.attention_utils import CrossAttentionLayers_XL
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
43
+ """
44
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
45
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
46
+ """
47
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
48
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
49
+ # rescale the results from guidance (fixes overexposure)
50
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
51
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
52
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
53
+ return noise_cfg
54
+
55
+
56
+ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
57
+ r"""
58
+ Pipeline for text-to-image generation using Stable Diffusion.
59
+
60
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
61
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
62
+
63
+ In addition the pipeline inherits the following loading methods:
64
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
65
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
66
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
67
+
68
+ as well as the following saving methods:
69
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
70
+
71
+ Args:
72
+ vae ([`AutoencoderKL`]):
73
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
74
+ text_encoder ([`CLIPTextModel`]):
75
+ Frozen text-encoder. Stable Diffusion uses the text portion of
76
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
77
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
78
+ tokenizer (`CLIPTokenizer`):
79
+ Tokenizer of class
80
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
81
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
82
+ scheduler ([`SchedulerMixin`]):
83
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
84
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ load_path: str = "stabilityai/stable-diffusion-xl-base-1.0",
90
+ device: str = "cuda",
91
+ force_zeros_for_empty_prompt: bool = True,
92
+ ):
93
+ super().__init__()
94
+
95
+ # self.register_modules(
96
+ # vae=vae,
97
+ # text_encoder=text_encoder,
98
+ # text_encoder_2=text_encoder_2,
99
+ # tokenizer=tokenizer,
100
+ # tokenizer_2=tokenizer_2,
101
+ # unet=unet,
102
+ # scheduler=scheduler,
103
+ # )
104
+
105
+ # 1. Load the autoencoder model which will be used to decode the latents into image space.
106
+ self.vae = AutoencoderKL.from_pretrained(load_path, subfolder="vae", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
107
+
108
+ # 2. Load the tokenizer and text encoder to tokenize and encode the text.
109
+ self.tokenizer = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer')
110
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer_2')
111
+ self.text_encoder = CLIPTextModel.from_pretrained(load_path, subfolder='text_encoder', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
112
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(load_path, subfolder='text_encoder_2', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
113
+
114
+ # 3. The UNet model for generating the latents.
115
+ self.unet = UNet2DConditionModel.from_pretrained(load_path, subfolder="unet", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
116
+
117
+ # 4. Scheduler.
118
+ self.scheduler = EulerDiscreteScheduler.from_pretrained(load_path, subfolder="scheduler")
119
+
120
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
121
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
122
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
123
+ self.default_sample_size = self.unet.config.sample_size
124
+
125
+ self.watermark = StableDiffusionXLWatermarker()
126
+
127
+ self.device_type = device
128
+
129
+ self.masks = []
130
+ self.attention_maps = None
131
+ self.selfattn_maps = None
132
+ self.crossattn_maps = None
133
+ self.color_loss = torch.nn.functional.mse_loss
134
+ self.forward_hooks = []
135
+ self.forward_replacement_hooks = []
136
+
137
+ # Overwriting the method from diffusers.pipelines.diffusion_pipeline.DiffusionPipeline
138
+ @property
139
+ def device(self) -> torch.device:
140
+ r"""
141
+ Returns:
142
+ `torch.device`: The torch device on which the pipeline is located.
143
+ """
144
+
145
+ return torch.device(self.device_type)
146
+
147
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
148
+ def enable_vae_slicing(self):
149
+ r"""
150
+ Enable sliced VAE decoding.
151
+
152
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
153
+ steps. This is useful to save some memory and allow larger batch sizes.
154
+ """
155
+ self.vae.enable_slicing()
156
+
157
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
158
+ def disable_vae_slicing(self):
159
+ r"""
160
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
161
+ computing decoding in one step.
162
+ """
163
+ self.vae.disable_slicing()
164
+
165
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
166
+ def enable_vae_tiling(self):
167
+ r"""
168
+ Enable tiled VAE decoding.
169
+
170
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
171
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
172
+ """
173
+ self.vae.enable_tiling()
174
+
175
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
176
+ def disable_vae_tiling(self):
177
+ r"""
178
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
179
+ computing decoding in one step.
180
+ """
181
+ self.vae.disable_tiling()
182
+
183
+ def enable_sequential_cpu_offload(self, gpu_id=0):
184
+ r"""
185
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
186
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
187
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
188
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
189
+ `enable_model_cpu_offload`, but performance is lower.
190
+ """
191
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
192
+ from accelerate import cpu_offload
193
+ else:
194
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
195
+
196
+ device = torch.device(f"cuda:{gpu_id}")
197
+
198
+ if self.device.type != "cpu":
199
+ self.to("cpu", silence_dtype_warnings=True)
200
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
201
+
202
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
203
+ cpu_offload(cpu_offloaded_model, device)
204
+
205
+ def enable_model_cpu_offload(self, gpu_id=0):
206
+ r"""
207
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
208
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
209
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
210
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
211
+ """
212
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
213
+ from accelerate import cpu_offload_with_hook
214
+ else:
215
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
216
+
217
+ device = torch.device(f"cuda:{gpu_id}")
218
+
219
+ if self.device.type != "cpu":
220
+ self.to("cpu", silence_dtype_warnings=True)
221
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
222
+
223
+ model_sequence = (
224
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
225
+ )
226
+ model_sequence.extend([self.unet, self.vae])
227
+
228
+ hook = None
229
+ for cpu_offloaded_model in model_sequence:
230
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
231
+
232
+ # We'll offload the last model manually.
233
+ self.final_offload_hook = hook
234
+
235
+ @property
236
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
237
+ def _execution_device(self):
238
+ r"""
239
+ Returns the device on which the pipeline's models will be executed. After calling
240
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
241
+ hooks.
242
+ """
243
+ if not hasattr(self.unet, "_hf_hook"):
244
+ return self.device
245
+ for module in self.unet.modules():
246
+ if (
247
+ hasattr(module, "_hf_hook")
248
+ and hasattr(module._hf_hook, "execution_device")
249
+ and module._hf_hook.execution_device is not None
250
+ ):
251
+ return torch.device(module._hf_hook.execution_device)
252
+ return self.device
253
+
254
+ def encode_prompt(
255
+ self,
256
+ prompt,
257
+ device: Optional[torch.device] = None,
258
+ num_images_per_prompt: int = 1,
259
+ do_classifier_free_guidance: bool = True,
260
+ negative_prompt=None,
261
+ prompt_embeds: Optional[torch.FloatTensor] = None,
262
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
263
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
264
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
265
+ lora_scale: Optional[float] = None,
266
+ ):
267
+ r"""
268
+ Encodes the prompt into text encoder hidden states.
269
+
270
+ Args:
271
+ prompt (`str` or `List[str]`, *optional*):
272
+ prompt to be encoded
273
+ device: (`torch.device`):
274
+ torch device
275
+ num_images_per_prompt (`int`):
276
+ number of images that should be generated per prompt
277
+ do_classifier_free_guidance (`bool`):
278
+ whether to use classifier free guidance or not
279
+ negative_prompt (`str` or `List[str]`, *optional*):
280
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
281
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
282
+ less than `1`).
283
+ prompt_embeds (`torch.FloatTensor`, *optional*):
284
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
285
+ provided, text embeddings will be generated from `prompt` input argument.
286
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
287
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
288
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
289
+ argument.
290
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
291
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
292
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
293
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
294
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
295
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
296
+ input argument.
297
+ lora_scale (`float`, *optional*):
298
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
299
+ """
300
+ device = device or self._execution_device
301
+
302
+ # set lora scale so that monkey patched LoRA
303
+ # function of text encoder can correctly access it
304
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
305
+ self._lora_scale = lora_scale
306
+
307
+ if prompt is not None and isinstance(prompt, str):
308
+ batch_size = 1
309
+ elif prompt is not None and isinstance(prompt, list):
310
+ batch_size = len(prompt)
311
+ batch_size_neg = len(negative_prompt)
312
+ else:
313
+ batch_size = prompt_embeds.shape[0]
314
+
315
+ # Define tokenizers and text encoders
316
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
317
+ text_encoders = (
318
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
319
+ )
320
+
321
+ if prompt_embeds is None:
322
+ # textual inversion: procecss multi-vector tokens if necessary
323
+ prompt_embeds_list = []
324
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
325
+ if isinstance(self, TextualInversionLoaderMixin):
326
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
327
+
328
+ text_inputs = tokenizer(
329
+ prompt,
330
+ padding="max_length",
331
+ max_length=tokenizer.model_max_length,
332
+ truncation=True,
333
+ return_tensors="pt",
334
+ )
335
+ text_input_ids = text_inputs.input_ids
336
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
337
+
338
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
339
+ text_input_ids, untruncated_ids
340
+ ):
341
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
342
+ logger.warning(
343
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
344
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
345
+ )
346
+
347
+ prompt_embeds = text_encoder(
348
+ text_input_ids.to(device),
349
+ output_hidden_states=True,
350
+ )
351
+
352
+ # We are only ALWAYS interested in the pooled output of the final text encoder
353
+ pooled_prompt_embeds = prompt_embeds[0]
354
+ prompt_embeds = prompt_embeds.hidden_states[-2]
355
+
356
+ bs_embed, seq_len, _ = prompt_embeds.shape
357
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
358
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
359
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
360
+
361
+ prompt_embeds_list.append(prompt_embeds)
362
+
363
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
364
+
365
+ # get unconditional embeddings for classifier free guidance
366
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
367
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
368
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
369
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
370
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
371
+ negative_prompt = negative_prompt or ""
372
+ uncond_tokens: List[str]
373
+ if prompt is not None and type(prompt) is not type(negative_prompt):
374
+ raise TypeError(
375
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
376
+ f" {type(prompt)}."
377
+ )
378
+ elif isinstance(negative_prompt, str):
379
+ uncond_tokens = [negative_prompt]
380
+ # elif batch_size != len(negative_prompt):
381
+ # raise ValueError(
382
+ # f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
383
+ # f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
384
+ # " the batch size of `prompt`."
385
+ # )
386
+ else:
387
+ uncond_tokens = negative_prompt
388
+
389
+ negative_prompt_embeds_list = []
390
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
391
+ # textual inversion: procecss multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
394
+
395
+ max_length = prompt_embeds.shape[1]
396
+ uncond_input = tokenizer(
397
+ uncond_tokens,
398
+ padding="max_length",
399
+ max_length=max_length,
400
+ truncation=True,
401
+ return_tensors="pt",
402
+ )
403
+
404
+ negative_prompt_embeds = text_encoder(
405
+ uncond_input.input_ids.to(device),
406
+ output_hidden_states=True,
407
+ )
408
+ # We are only ALWAYS interested in the pooled output of the final text encoder
409
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
410
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
411
+
412
+ if do_classifier_free_guidance:
413
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
414
+ seq_len = negative_prompt_embeds.shape[1]
415
+
416
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
417
+
418
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
419
+ negative_prompt_embeds = negative_prompt_embeds.view(
420
+ batch_size_neg * num_images_per_prompt, seq_len, -1
421
+ )
422
+
423
+ # For classifier free guidance, we need to do two forward passes.
424
+ # Here we concatenate the unconditional and text embeddings into a single batch
425
+ # to avoid doing two forward passes
426
+
427
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
428
+
429
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
430
+
431
+ bs_embed = pooled_prompt_embeds.shape[0]
432
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
433
+ bs_embed * num_images_per_prompt, -1
434
+ )
435
+ bs_embed = negative_pooled_prompt_embeds.shape[0]
436
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
437
+ bs_embed * num_images_per_prompt, -1
438
+ )
439
+
440
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
441
+
442
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
443
+ def prepare_extra_step_kwargs(self, generator, eta):
444
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
445
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
446
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
447
+ # and should be between [0, 1]
448
+
449
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
450
+ extra_step_kwargs = {}
451
+ if accepts_eta:
452
+ extra_step_kwargs["eta"] = eta
453
+
454
+ # check if the scheduler accepts generator
455
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
456
+ if accepts_generator:
457
+ extra_step_kwargs["generator"] = generator
458
+ return extra_step_kwargs
459
+
460
+ def check_inputs(
461
+ self,
462
+ prompt,
463
+ height,
464
+ width,
465
+ callback_steps,
466
+ negative_prompt=None,
467
+ prompt_embeds=None,
468
+ negative_prompt_embeds=None,
469
+ pooled_prompt_embeds=None,
470
+ negative_pooled_prompt_embeds=None,
471
+ ):
472
+ if height % 8 != 0 or width % 8 != 0:
473
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
474
+
475
+ if (callback_steps is None) or (
476
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
477
+ ):
478
+ raise ValueError(
479
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
480
+ f" {type(callback_steps)}."
481
+ )
482
+
483
+ if prompt is not None and prompt_embeds is not None:
484
+ raise ValueError(
485
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
486
+ " only forward one of the two."
487
+ )
488
+ elif prompt is None and prompt_embeds is None:
489
+ raise ValueError(
490
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
491
+ )
492
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
493
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
494
+
495
+ if negative_prompt is not None and negative_prompt_embeds is not None:
496
+ raise ValueError(
497
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
498
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
499
+ )
500
+
501
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
502
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
503
+ raise ValueError(
504
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
505
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
506
+ f" {negative_prompt_embeds.shape}."
507
+ )
508
+
509
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
510
+ raise ValueError(
511
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
512
+ )
513
+
514
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
515
+ raise ValueError(
516
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
517
+ )
518
+
519
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
520
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
521
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
522
+ if isinstance(generator, list) and len(generator) != batch_size:
523
+ raise ValueError(
524
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
525
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
526
+ )
527
+
528
+ if latents is None:
529
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
530
+ else:
531
+ latents = latents.to(device)
532
+
533
+ # scale the initial noise by the standard deviation required by the scheduler
534
+ latents = latents * self.scheduler.init_noise_sigma
535
+ return latents
536
+
537
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
538
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
539
+
540
+ passed_add_embed_dim = (
541
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
542
+ )
543
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
544
+
545
+ if expected_add_embed_dim != passed_add_embed_dim:
546
+ raise ValueError(
547
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
548
+ )
549
+
550
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
551
+ return add_time_ids
552
+
553
+ @torch.no_grad()
554
+ def sample(
555
+ self,
556
+ prompt: Union[str, List[str]] = None,
557
+ height: Optional[int] = None,
558
+ width: Optional[int] = None,
559
+ num_inference_steps: int = 50,
560
+ guidance_scale: float = 5.0,
561
+ negative_prompt: Optional[Union[str, List[str]]] = None,
562
+ num_images_per_prompt: Optional[int] = 1,
563
+ eta: float = 0.0,
564
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
565
+ latents: Optional[torch.FloatTensor] = None,
566
+ prompt_embeds: Optional[torch.FloatTensor] = None,
567
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
568
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
569
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ output_type: Optional[str] = "pil",
571
+ return_dict: bool = True,
572
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
573
+ callback_steps: int = 1,
574
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
575
+ guidance_rescale: float = 0.0,
576
+ original_size: Optional[Tuple[int, int]] = None,
577
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
578
+ target_size: Optional[Tuple[int, int]] = None,
579
+ # Rich-Text args
580
+ use_guidance: bool = False,
581
+ inject_selfattn: float = 0.0,
582
+ inject_background: float = 0.0,
583
+ text_format_dict: Optional[dict] = None,
584
+ run_rich_text: bool = False,
585
+ ):
586
+ r"""
587
+ Function invoked when calling the pipeline for generation.
588
+
589
+ Args:
590
+ prompt (`str` or `List[str]`, *optional*):
591
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
592
+ instead.
593
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
594
+ The height in pixels of the generated image.
595
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
596
+ The width in pixels of the generated image.
597
+ num_inference_steps (`int`, *optional*, defaults to 50):
598
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
599
+ expense of slower inference.
600
+ guidance_scale (`float`, *optional*, defaults to 7.5):
601
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
602
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
603
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
604
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
605
+ usually at the expense of lower image quality.
606
+ negative_prompt (`str` or `List[str]`, *optional*):
607
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
608
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
609
+ less than `1`).
610
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
611
+ The number of images to generate per prompt.
612
+ eta (`float`, *optional*, defaults to 0.0):
613
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
614
+ [`schedulers.DDIMScheduler`], will be ignored for others.
615
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
616
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
617
+ to make generation deterministic.
618
+ latents (`torch.FloatTensor`, *optional*):
619
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
620
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
621
+ tensor will ge generated by sampling using the supplied random `generator`.
622
+ prompt_embeds (`torch.FloatTensor`, *optional*):
623
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
624
+ provided, text embeddings will be generated from `prompt` input argument.
625
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
626
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
627
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
628
+ argument.
629
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
630
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
631
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
632
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
633
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
634
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
635
+ input argument.
636
+ output_type (`str`, *optional*, defaults to `"pil"`):
637
+ The output format of the generate image. Choose between
638
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
639
+ return_dict (`bool`, *optional*, defaults to `True`):
640
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
641
+ plain tuple.
642
+ callback (`Callable`, *optional*):
643
+ A function that will be called every `callback_steps` steps during inference. The function will be
644
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
645
+ callback_steps (`int`, *optional*, defaults to 1):
646
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
647
+ called at every step.
648
+ cross_attention_kwargs (`dict`, *optional*):
649
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
650
+ `self.processor` in
651
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
652
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
653
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
654
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
655
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
656
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
657
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
658
+ TODO
659
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
660
+ TODO
661
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
662
+ TODO
663
+
664
+ Examples:
665
+
666
+ Returns:
667
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
668
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
669
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
670
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
671
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
672
+ """
673
+ # 0. Default height and width to unet
674
+ height = height or self.default_sample_size * self.vae_scale_factor
675
+ width = width or self.default_sample_size * self.vae_scale_factor
676
+
677
+ original_size = original_size or (height, width)
678
+ target_size = target_size or (height, width)
679
+
680
+ # 1. Check inputs. Raise error if not correct
681
+ self.check_inputs(
682
+ prompt,
683
+ height,
684
+ width,
685
+ callback_steps,
686
+ negative_prompt,
687
+ prompt_embeds,
688
+ negative_prompt_embeds,
689
+ pooled_prompt_embeds,
690
+ negative_pooled_prompt_embeds,
691
+ )
692
+
693
+ # 2. Define call parameters
694
+ if prompt is not None and isinstance(prompt, str):
695
+ batch_size = 1
696
+ elif prompt is not None and isinstance(prompt, list):
697
+ # TODO: support batched prompts
698
+ batch_size = 1
699
+ # batch_size = len(prompt)
700
+ else:
701
+ batch_size = prompt_embeds.shape[0]
702
+
703
+ device = self._execution_device
704
+
705
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
706
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
707
+ # corresponds to doing no classifier free guidance.
708
+ do_classifier_free_guidance = guidance_scale > 1.0
709
+
710
+ # 3. Encode input prompt
711
+ text_encoder_lora_scale = (
712
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
713
+ )
714
+ (
715
+ prompt_embeds,
716
+ negative_prompt_embeds,
717
+ pooled_prompt_embeds,
718
+ negative_pooled_prompt_embeds,
719
+ ) = self.encode_prompt(
720
+ prompt,
721
+ device,
722
+ num_images_per_prompt,
723
+ do_classifier_free_guidance,
724
+ negative_prompt,
725
+ prompt_embeds=prompt_embeds,
726
+ negative_prompt_embeds=negative_prompt_embeds,
727
+ pooled_prompt_embeds=pooled_prompt_embeds,
728
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
729
+ lora_scale=text_encoder_lora_scale,
730
+ )
731
+
732
+ # 4. Prepare timesteps
733
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
734
+
735
+ timesteps = self.scheduler.timesteps
736
+
737
+ # 5. Prepare latent variables
738
+ num_channels_latents = self.unet.config.in_channels
739
+ latents = self.prepare_latents(
740
+ batch_size * num_images_per_prompt,
741
+ num_channels_latents,
742
+ height,
743
+ width,
744
+ prompt_embeds.dtype,
745
+ device,
746
+ generator,
747
+ latents,
748
+ )
749
+
750
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
751
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
752
+
753
+ # 7. Prepare added time ids & embeddings
754
+ add_text_embeds = pooled_prompt_embeds
755
+ add_time_ids = self._get_add_time_ids(
756
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
757
+ )
758
+
759
+ if do_classifier_free_guidance:
760
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
761
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
762
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
763
+
764
+ prompt_embeds = prompt_embeds.to(device)
765
+ add_text_embeds = add_text_embeds.to(device)
766
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
767
+
768
+ # 8. Denoising loop
769
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
770
+ if run_rich_text:
771
+ if inject_selfattn > 0 or inject_background > 0:
772
+ latents_reference = latents.clone().detach()
773
+ n_styles = prompt_embeds.shape[0]-1
774
+ self.masks = [mask.to(dtype=prompt_embeds.dtype) for mask in self.masks]
775
+ print(n_styles, len(self.masks))
776
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
777
+ for i, t in enumerate(self.scheduler.timesteps):
778
+ # predict the noise residual
779
+ with torch.no_grad():
780
+ feat_inject_step = t > (1-inject_selfattn) * 1000
781
+ background_inject_step = i < inject_background * len(self.scheduler.timesteps)
782
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
783
+ # import ipdb;ipdb.set_trace()
784
+ # unconditional prediction
785
+ noise_pred_uncond_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[:1],
786
+ cross_attention_kwargs=cross_attention_kwargs,
787
+ added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
788
+ )['sample']
789
+ # tokens without any style or footnote
790
+ self.register_fontsize_hooks(text_format_dict)
791
+ noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
792
+ cross_attention_kwargs=cross_attention_kwargs,
793
+ added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
794
+ )['sample']
795
+ self.remove_fontsize_hooks()
796
+ if inject_selfattn > 0 or inject_background > 0:
797
+ latent_reference_model_input = self.scheduler.scale_model_input(latents_reference, t)
798
+ noise_pred_uncond_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[:1],
799
+ cross_attention_kwargs=cross_attention_kwargs,
800
+ added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
801
+ )['sample']
802
+ self.register_selfattn_hooks(feat_inject_step)
803
+ noise_pred_text_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
804
+ cross_attention_kwargs=cross_attention_kwargs,
805
+ added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
806
+ )['sample']
807
+ self.remove_selfattn_hooks()
808
+ noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
809
+ noise_pred_text = noise_pred_text_cur * self.masks[-1]
810
+ # tokens with style or footnote
811
+ for style_i, mask in enumerate(self.masks[:-1]):
812
+ self.register_replacement_hooks(feat_inject_step)
813
+ noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[style_i+1:style_i+2],
814
+ cross_attention_kwargs=cross_attention_kwargs,
815
+ added_cond_kwargs={"text_embeds": add_text_embeds[style_i+1:style_i+2], "time_ids": add_time_ids[:1]}
816
+ )['sample']
817
+ self.remove_replacement_hooks()
818
+ noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
819
+ noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
820
+
821
+ # perform guidance
822
+ noise_pred = noise_pred_uncond + guidance_scale * \
823
+ (noise_pred_text - noise_pred_uncond)
824
+
825
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
826
+ # TODO: Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
827
+ # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
828
+ raise NotImplementedError
829
+
830
+ if inject_selfattn > 0 or background_inject_step > 0:
831
+ noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
832
+ (noise_pred_text_refer - noise_pred_uncond_refer)
833
+
834
+ # compute the previous noisy sample x_t -> x_t-1
835
+ latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
836
+ torch.cat([latents, latents_reference]))[
837
+ 'prev_sample']
838
+ latents, latents_reference = torch.chunk(
839
+ latents_reference, 2, dim=0)
840
+
841
+ else:
842
+ # compute the previous noisy sample x_t -> x_t-1
843
+ latents = self.scheduler.step(noise_pred, t, latents)[
844
+ 'prev_sample']
845
+
846
+ # apply guidance
847
+ if use_guidance and t < text_format_dict['guidance_start_step']:
848
+ with torch.enable_grad():
849
+ if not latents.requires_grad:
850
+ latents.requires_grad = True
851
+ # import ipdb;ipdb.set_trace()
852
+ latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
853
+ latents_inp = latents_0 / self.vae.config.scaling_factor
854
+ imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
855
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
856
+ loss_total = 0.
857
+ for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
858
+ avg_rgb = (
859
+ imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
860
+ loss = self.color_loss(
861
+ avg_rgb, rgb_val[:, :, 0, 0])*100
862
+ loss_total += loss
863
+ loss_total.backward()
864
+ latents = (
865
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
866
+
867
+ # apply background injection
868
+ if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
869
+ latents = latents_reference * self.masks[-1] + latents * \
870
+ (1-self.masks[-1])
871
+
872
+ # call the callback, if provided
873
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
874
+ progress_bar.update()
875
+ if callback is not None and i % callback_steps == 0:
876
+ callback(i, t, latents)
877
+ else:
878
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
879
+ for i, t in enumerate(timesteps):
880
+ # expand the latents if we are doing classifier free guidance
881
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
882
+
883
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
884
+
885
+ # predict the noise residual
886
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
887
+ noise_pred = self.unet(
888
+ latent_model_input,
889
+ t,
890
+ encoder_hidden_states=prompt_embeds,
891
+ cross_attention_kwargs=cross_attention_kwargs,
892
+ added_cond_kwargs=added_cond_kwargs,
893
+ return_dict=False,
894
+ )[0]
895
+
896
+ # perform guidance
897
+ if do_classifier_free_guidance:
898
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
899
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
900
+
901
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
902
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
903
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
904
+
905
+ # compute the previous noisy sample x_t -> x_t-1
906
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
907
+
908
+ # call the callback, if provided
909
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
910
+ progress_bar.update()
911
+ if callback is not None and i % callback_steps == 0:
912
+ callback(i, t, latents)
913
+
914
+ # make sure the VAE is in float32 mode, as it overflows in float16
915
+ self.vae.to(dtype=torch.float32)
916
+
917
+ use_torch_2_0_or_xformers = isinstance(
918
+ self.vae.decoder.mid_block.attentions[0].processor,
919
+ (
920
+ AttnProcessor2_0,
921
+ XFormersAttnProcessor,
922
+ LoRAXFormersAttnProcessor,
923
+ LoRAAttnProcessor2_0,
924
+ ),
925
+ )
926
+ # if xformers or torch_2_0 is used attention block does not need
927
+ # to be in float32 which can save lots of memory
928
+ if use_torch_2_0_or_xformers:
929
+ self.vae.post_quant_conv.to(latents.dtype)
930
+ self.vae.decoder.conv_in.to(latents.dtype)
931
+ self.vae.decoder.mid_block.to(latents.dtype)
932
+ else:
933
+ latents = latents.float()
934
+
935
+ if not output_type == "latent":
936
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
937
+ else:
938
+ image = latents
939
+ return StableDiffusionXLPipelineOutput(images=image)
940
+
941
+ image = self.watermark.apply_watermark(image)
942
+ image = self.image_processor.postprocess(image, output_type=output_type)
943
+
944
+ # Offload last model to CPU
945
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
946
+ self.final_offload_hook.offload()
947
+
948
+ if not return_dict:
949
+ return (image,)
950
+
951
+ return StableDiffusionXLPipelineOutput(images=image)
952
+
953
+ def predict_x0(self, x_t, eps_t, t):
954
+ alpha_t = self.scheduler.alphas_cumprod[t.cpu().long().item()]
955
+ return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
956
+
957
+ def register_tokenmap_hooks(self):
958
+ r"""Function for registering hooks during evaluation.
959
+ We mainly store activation maps averaged over queries.
960
+ """
961
+ self.forward_hooks = []
962
+
963
+ def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
964
+ r"""
965
+ PyTorch Forward hook to save outputs at each forward pass.
966
+ """
967
+ # out[0] - final output of attention layer
968
+ # out[1] - attention probability matrices
969
+ if name in n_maps:
970
+ n_maps[name] += 1
971
+ else:
972
+ n_maps[name] = 1
973
+ if 'attn2' in name:
974
+ assert out[1][0].shape[-1] == 77
975
+ if name in CrossAttentionLayers_XL and n_maps[name] > 10:
976
+ # if n_maps[name] > 10:
977
+ if name in crossattn_maps:
978
+ crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
979
+ else:
980
+ crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
981
+ # For visualization
982
+ # crossattn_maps[name].append(out[1][0].detach().cpu()[1:2])
983
+ else:
984
+ assert out[1][0].shape[-1] != 77
985
+ # if name in SelfAttentionLayers and n_maps[name] > 10:
986
+ if n_maps[name] > 10:
987
+ if name in selfattn_maps:
988
+ selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
989
+ else:
990
+ selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
991
+
992
+ selfattn_maps = collections.defaultdict(list)
993
+ crossattn_maps = collections.defaultdict(list)
994
+ n_maps = collections.defaultdict(list)
995
+
996
+ for name, module in self.unet.named_modules():
997
+ leaf_name = name.split('.')[-1]
998
+ if 'attn' in leaf_name:
999
+ # Register hook to obtain outputs at every attention layer.
1000
+ self.forward_hooks.append(module.register_forward_hook(
1001
+ partial(save_activations, selfattn_maps,
1002
+ crossattn_maps, n_maps, name)
1003
+ ))
1004
+ # attention_dict is a dictionary containing attention maps for every attention layer
1005
+ self.selfattn_maps = selfattn_maps
1006
+ self.crossattn_maps = crossattn_maps
1007
+ self.n_maps = n_maps
1008
+
1009
+ def remove_tokenmap_hooks(self):
1010
+ for hook in self.forward_hooks:
1011
+ hook.remove()
1012
+ self.selfattn_maps = None
1013
+ self.crossattn_maps = None
1014
+ self.n_maps = None
1015
+
1016
+ def register_replacement_hooks(self, feat_inject_step=False):
1017
+ r"""Function for registering hooks to replace self attention.
1018
+ """
1019
+ self.forward_replacement_hooks = []
1020
+
1021
+ def replace_activations(name, module, args):
1022
+ r"""
1023
+ PyTorch Forward hook to save outputs at each forward pass.
1024
+ """
1025
+ if 'attn1' in name:
1026
+ modified_args = (args[0], self.self_attention_maps_cur[name])
1027
+ return modified_args
1028
+ # cross attention injection
1029
+ # elif 'attn2' in name:
1030
+ # modified_map = {
1031
+ # 'reference': self.self_attention_maps_cur[name],
1032
+ # 'inject_pos': self.inject_pos,
1033
+ # }
1034
+ # modified_args = (args[0], modified_map)
1035
+ # return modified_args
1036
+
1037
+ def replace_resnet_activations(name, module, args):
1038
+ r"""
1039
+ PyTorch Forward hook to save outputs at each forward pass.
1040
+ """
1041
+ modified_args = (args[0], args[1],
1042
+ self.self_attention_maps_cur[name])
1043
+ return modified_args
1044
+ for name, module in self.unet.named_modules():
1045
+ leaf_name = name.split('.')[-1]
1046
+ if 'attn' in leaf_name and feat_inject_step:
1047
+ # Register hook to obtain outputs at every attention layer.
1048
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
1049
+ partial(replace_activations, name)
1050
+ ))
1051
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
1052
+ # Register hook to obtain outputs at every attention layer.
1053
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
1054
+ partial(replace_resnet_activations, name)
1055
+ ))
1056
+
1057
+ def remove_replacement_hooks(self):
1058
+ for hook in self.forward_replacement_hooks:
1059
+ hook.remove()
1060
+
1061
+
1062
+ def register_selfattn_hooks(self, feat_inject_step=False):
1063
+ r"""Function for registering hooks during evaluation.
1064
+ We mainly store activation maps averaged over queries.
1065
+ """
1066
+ self.selfattn_forward_hooks = []
1067
+
1068
+ def save_activations(activations, name, module, inp, out):
1069
+ r"""
1070
+ PyTorch Forward hook to save outputs at each forward pass.
1071
+ """
1072
+ # out[0] - final output of attention layer
1073
+ # out[1] - attention probability matrix
1074
+ if 'attn2' in name:
1075
+ assert out[1][1].shape[-1] == 77
1076
+ # cross attention injection
1077
+ # activations[name] = out[1][1].detach()
1078
+ else:
1079
+ assert out[1][1].shape[-1] != 77
1080
+ activations[name] = out[1][1].detach()
1081
+
1082
+ def save_resnet_activations(activations, name, module, inp, out):
1083
+ r"""
1084
+ PyTorch Forward hook to save outputs at each forward pass.
1085
+ """
1086
+ # out[0] - final output of residual layer
1087
+ # out[1] - residual hidden feature
1088
+ # import ipdb;ipdb.set_trace()
1089
+ assert out[1].shape[-1] == 64
1090
+ activations[name] = out[1].detach()
1091
+ attention_dict = collections.defaultdict(list)
1092
+ for name, module in self.unet.named_modules():
1093
+ leaf_name = name.split('.')[-1]
1094
+ if 'attn' in leaf_name and feat_inject_step:
1095
+ # Register hook to obtain outputs at every attention layer.
1096
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
1097
+ partial(save_activations, attention_dict, name)
1098
+ ))
1099
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
1100
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
1101
+ partial(save_resnet_activations, attention_dict, name)
1102
+ ))
1103
+ # attention_dict is a dictionary containing attention maps for every attention layer
1104
+ self.self_attention_maps_cur = attention_dict
1105
+
1106
+ def remove_selfattn_hooks(self):
1107
+ for hook in self.selfattn_forward_hooks:
1108
+ hook.remove()
1109
+
1110
+ def register_fontsize_hooks(self, text_format_dict={}):
1111
+ r"""Function for registering hooks to replace self attention.
1112
+ """
1113
+ self.forward_fontsize_hooks = []
1114
+
1115
+ def adjust_attn_weights(name, module, args):
1116
+ r"""
1117
+ PyTorch Forward hook to save outputs at each forward pass.
1118
+ """
1119
+ if 'attn2' in name:
1120
+ modified_args = (args[0], None, attn_weights)
1121
+ return modified_args
1122
+
1123
+ if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
1124
+ attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
1125
+ else:
1126
+ attn_weights = None
1127
+
1128
+ for name, module in self.unet.named_modules():
1129
+ leaf_name = name.split('.')[-1]
1130
+ if 'attn' in leaf_name and attn_weights is not None:
1131
+ # Register hook to obtain outputs at every attention layer.
1132
+ self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
1133
+ partial(adjust_attn_weights, name)
1134
+ ))
1135
+
1136
+ def remove_fontsize_hooks(self):
1137
+ for hook in self.forward_fontsize_hooks:
1138
+ hook.remove()
models/resnet.py ADDED
@@ -0,0 +1,882 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention import AdaGroupNorm
25
+ from models.attention_processor import SpatialNorm
26
+
27
+
28
+ class Upsample1D(nn.Module):
29
+ """A 1D upsampling layer with an optional convolution.
30
+
31
+ Parameters:
32
+ channels (`int`):
33
+ number of channels in the inputs and outputs.
34
+ use_conv (`bool`, default `False`):
35
+ option to use a convolution.
36
+ use_conv_transpose (`bool`, default `False`):
37
+ option to use a convolution transpose.
38
+ out_channels (`int`, optional):
39
+ number of output channels. Defaults to `channels`.
40
+ """
41
+
42
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.use_conv_transpose = use_conv_transpose
48
+ self.name = name
49
+
50
+ self.conv = None
51
+ if use_conv_transpose:
52
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
53
+ elif use_conv:
54
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
55
+
56
+ def forward(self, inputs):
57
+ assert inputs.shape[1] == self.channels
58
+ if self.use_conv_transpose:
59
+ return self.conv(inputs)
60
+
61
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
62
+
63
+ if self.use_conv:
64
+ outputs = self.conv(outputs)
65
+
66
+ return outputs
67
+
68
+
69
+ class Downsample1D(nn.Module):
70
+ """A 1D downsampling layer with an optional convolution.
71
+
72
+ Parameters:
73
+ channels (`int`):
74
+ number of channels in the inputs and outputs.
75
+ use_conv (`bool`, default `False`):
76
+ option to use a convolution.
77
+ out_channels (`int`, optional):
78
+ number of output channels. Defaults to `channels`.
79
+ padding (`int`, default `1`):
80
+ padding for the convolution.
81
+ """
82
+
83
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
84
+ super().__init__()
85
+ self.channels = channels
86
+ self.out_channels = out_channels or channels
87
+ self.use_conv = use_conv
88
+ self.padding = padding
89
+ stride = 2
90
+ self.name = name
91
+
92
+ if use_conv:
93
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
94
+ else:
95
+ assert self.channels == self.out_channels
96
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
97
+
98
+ def forward(self, inputs):
99
+ assert inputs.shape[1] == self.channels
100
+ return self.conv(inputs)
101
+
102
+
103
+ class Upsample2D(nn.Module):
104
+ """A 2D upsampling layer with an optional convolution.
105
+
106
+ Parameters:
107
+ channels (`int`):
108
+ number of channels in the inputs and outputs.
109
+ use_conv (`bool`, default `False`):
110
+ option to use a convolution.
111
+ use_conv_transpose (`bool`, default `False`):
112
+ option to use a convolution transpose.
113
+ out_channels (`int`, optional):
114
+ number of output channels. Defaults to `channels`.
115
+ """
116
+
117
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
118
+ super().__init__()
119
+ self.channels = channels
120
+ self.out_channels = out_channels or channels
121
+ self.use_conv = use_conv
122
+ self.use_conv_transpose = use_conv_transpose
123
+ self.name = name
124
+
125
+ conv = None
126
+ if use_conv_transpose:
127
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
128
+ elif use_conv:
129
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
130
+
131
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
132
+ if name == "conv":
133
+ self.conv = conv
134
+ else:
135
+ self.Conv2d_0 = conv
136
+
137
+ def forward(self, hidden_states, output_size=None):
138
+ assert hidden_states.shape[1] == self.channels
139
+
140
+ if self.use_conv_transpose:
141
+ return self.conv(hidden_states)
142
+
143
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
144
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
145
+ # https://github.com/pytorch/pytorch/issues/86679
146
+ dtype = hidden_states.dtype
147
+ if dtype == torch.bfloat16:
148
+ hidden_states = hidden_states.to(torch.float32)
149
+
150
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
151
+ if hidden_states.shape[0] >= 64:
152
+ hidden_states = hidden_states.contiguous()
153
+
154
+ # if `output_size` is passed we force the interpolation output
155
+ # size and do not make use of `scale_factor=2`
156
+ if output_size is None:
157
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
158
+ else:
159
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
160
+
161
+ # If the input is bfloat16, we cast back to bfloat16
162
+ if dtype == torch.bfloat16:
163
+ hidden_states = hidden_states.to(dtype)
164
+
165
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
166
+ if self.use_conv:
167
+ if self.name == "conv":
168
+ hidden_states = self.conv(hidden_states)
169
+ else:
170
+ hidden_states = self.Conv2d_0(hidden_states)
171
+
172
+ return hidden_states
173
+
174
+
175
+ class Downsample2D(nn.Module):
176
+ """A 2D downsampling layer with an optional convolution.
177
+
178
+ Parameters:
179
+ channels (`int`):
180
+ number of channels in the inputs and outputs.
181
+ use_conv (`bool`, default `False`):
182
+ option to use a convolution.
183
+ out_channels (`int`, optional):
184
+ number of output channels. Defaults to `channels`.
185
+ padding (`int`, default `1`):
186
+ padding for the convolution.
187
+ """
188
+
189
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
190
+ super().__init__()
191
+ self.channels = channels
192
+ self.out_channels = out_channels or channels
193
+ self.use_conv = use_conv
194
+ self.padding = padding
195
+ stride = 2
196
+ self.name = name
197
+
198
+ if use_conv:
199
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
200
+ else:
201
+ assert self.channels == self.out_channels
202
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
203
+
204
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
205
+ if name == "conv":
206
+ self.Conv2d_0 = conv
207
+ self.conv = conv
208
+ elif name == "Conv2d_0":
209
+ self.conv = conv
210
+ else:
211
+ self.conv = conv
212
+
213
+ def forward(self, hidden_states):
214
+ assert hidden_states.shape[1] == self.channels
215
+ if self.use_conv and self.padding == 0:
216
+ pad = (0, 1, 0, 1)
217
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
218
+
219
+ assert hidden_states.shape[1] == self.channels
220
+ hidden_states = self.conv(hidden_states)
221
+
222
+ return hidden_states
223
+
224
+
225
+ class FirUpsample2D(nn.Module):
226
+ """A 2D FIR upsampling layer with an optional convolution.
227
+
228
+ Parameters:
229
+ channels (`int`):
230
+ number of channels in the inputs and outputs.
231
+ use_conv (`bool`, default `False`):
232
+ option to use a convolution.
233
+ out_channels (`int`, optional):
234
+ number of output channels. Defaults to `channels`.
235
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
236
+ kernel for the FIR filter.
237
+ """
238
+
239
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
240
+ super().__init__()
241
+ out_channels = out_channels if out_channels else channels
242
+ if use_conv:
243
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
244
+ self.use_conv = use_conv
245
+ self.fir_kernel = fir_kernel
246
+ self.out_channels = out_channels
247
+
248
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
249
+ """Fused `upsample_2d()` followed by `Conv2d()`.
250
+
251
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
252
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
253
+ arbitrary order.
254
+
255
+ Args:
256
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
257
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
258
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
259
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
260
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
261
+ factor: Integer upsampling factor (default: 2).
262
+ gain: Scaling factor for signal magnitude (default: 1.0).
263
+
264
+ Returns:
265
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
266
+ datatype as `hidden_states`.
267
+ """
268
+
269
+ assert isinstance(factor, int) and factor >= 1
270
+
271
+ # Setup filter kernel.
272
+ if kernel is None:
273
+ kernel = [1] * factor
274
+
275
+ # setup kernel
276
+ kernel = torch.tensor(kernel, dtype=torch.float32)
277
+ if kernel.ndim == 1:
278
+ kernel = torch.outer(kernel, kernel)
279
+ kernel /= torch.sum(kernel)
280
+
281
+ kernel = kernel * (gain * (factor**2))
282
+
283
+ if self.use_conv:
284
+ convH = weight.shape[2]
285
+ convW = weight.shape[3]
286
+ inC = weight.shape[1]
287
+
288
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
289
+
290
+ stride = (factor, factor)
291
+ # Determine data dimensions.
292
+ output_shape = (
293
+ (hidden_states.shape[2] - 1) * factor + convH,
294
+ (hidden_states.shape[3] - 1) * factor + convW,
295
+ )
296
+ output_padding = (
297
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
298
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
299
+ )
300
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
301
+ num_groups = hidden_states.shape[1] // inC
302
+
303
+ # Transpose weights.
304
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
305
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
306
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
307
+
308
+ inverse_conv = F.conv_transpose2d(
309
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
310
+ )
311
+
312
+ output = upfirdn2d_native(
313
+ inverse_conv,
314
+ torch.tensor(kernel, device=inverse_conv.device),
315
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
316
+ )
317
+ else:
318
+ pad_value = kernel.shape[0] - factor
319
+ output = upfirdn2d_native(
320
+ hidden_states,
321
+ torch.tensor(kernel, device=hidden_states.device),
322
+ up=factor,
323
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
324
+ )
325
+
326
+ return output
327
+
328
+ def forward(self, hidden_states):
329
+ if self.use_conv:
330
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
331
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
332
+ else:
333
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
334
+
335
+ return height
336
+
337
+
338
+ class FirDownsample2D(nn.Module):
339
+ """A 2D FIR downsampling layer with an optional convolution.
340
+
341
+ Parameters:
342
+ channels (`int`):
343
+ number of channels in the inputs and outputs.
344
+ use_conv (`bool`, default `False`):
345
+ option to use a convolution.
346
+ out_channels (`int`, optional):
347
+ number of output channels. Defaults to `channels`.
348
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
349
+ kernel for the FIR filter.
350
+ """
351
+
352
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
353
+ super().__init__()
354
+ out_channels = out_channels if out_channels else channels
355
+ if use_conv:
356
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
357
+ self.fir_kernel = fir_kernel
358
+ self.use_conv = use_conv
359
+ self.out_channels = out_channels
360
+
361
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
362
+ """Fused `Conv2d()` followed by `downsample_2d()`.
363
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
364
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
365
+ arbitrary order.
366
+
367
+ Args:
368
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
369
+ weight:
370
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
371
+ performed by `inChannels = x.shape[0] // numGroups`.
372
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
373
+ factor`, which corresponds to average pooling.
374
+ factor: Integer downsampling factor (default: 2).
375
+ gain: Scaling factor for signal magnitude (default: 1.0).
376
+
377
+ Returns:
378
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
379
+ same datatype as `x`.
380
+ """
381
+
382
+ assert isinstance(factor, int) and factor >= 1
383
+ if kernel is None:
384
+ kernel = [1] * factor
385
+
386
+ # setup kernel
387
+ kernel = torch.tensor(kernel, dtype=torch.float32)
388
+ if kernel.ndim == 1:
389
+ kernel = torch.outer(kernel, kernel)
390
+ kernel /= torch.sum(kernel)
391
+
392
+ kernel = kernel * gain
393
+
394
+ if self.use_conv:
395
+ _, _, convH, convW = weight.shape
396
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
397
+ stride_value = [factor, factor]
398
+ upfirdn_input = upfirdn2d_native(
399
+ hidden_states,
400
+ torch.tensor(kernel, device=hidden_states.device),
401
+ pad=((pad_value + 1) // 2, pad_value // 2),
402
+ )
403
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
404
+ else:
405
+ pad_value = kernel.shape[0] - factor
406
+ output = upfirdn2d_native(
407
+ hidden_states,
408
+ torch.tensor(kernel, device=hidden_states.device),
409
+ down=factor,
410
+ pad=((pad_value + 1) // 2, pad_value // 2),
411
+ )
412
+
413
+ return output
414
+
415
+ def forward(self, hidden_states):
416
+ if self.use_conv:
417
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
418
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
419
+ else:
420
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
421
+
422
+ return hidden_states
423
+
424
+
425
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
426
+ class KDownsample2D(nn.Module):
427
+ def __init__(self, pad_mode="reflect"):
428
+ super().__init__()
429
+ self.pad_mode = pad_mode
430
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
431
+ self.pad = kernel_1d.shape[1] // 2 - 1
432
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
433
+
434
+ def forward(self, inputs):
435
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
436
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
437
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
438
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
439
+ weight[indices, indices] = kernel
440
+ return F.conv2d(inputs, weight, stride=2)
441
+
442
+
443
+ class KUpsample2D(nn.Module):
444
+ def __init__(self, pad_mode="reflect"):
445
+ super().__init__()
446
+ self.pad_mode = pad_mode
447
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
448
+ self.pad = kernel_1d.shape[1] // 2 - 1
449
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
450
+
451
+ def forward(self, inputs):
452
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
453
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
454
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
455
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
456
+ weight[indices, indices] = kernel
457
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
458
+
459
+
460
+ class ResnetBlock2D(nn.Module):
461
+ r"""
462
+ A Resnet block.
463
+
464
+ Parameters:
465
+ in_channels (`int`): The number of channels in the input.
466
+ out_channels (`int`, *optional*, default to be `None`):
467
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
468
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
469
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
470
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
471
+ groups_out (`int`, *optional*, default to None):
472
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
473
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
474
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
475
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
476
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
477
+ "ada_group" for a stronger conditioning with scale and shift.
478
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
479
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
480
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
481
+ use_in_shortcut (`bool`, *optional*, default to `True`):
482
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
483
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
484
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
485
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
486
+ `conv_shortcut` output.
487
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
488
+ If None, same as `out_channels`.
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ *,
494
+ in_channels,
495
+ out_channels=None,
496
+ conv_shortcut=False,
497
+ dropout=0.0,
498
+ temb_channels=512,
499
+ groups=32,
500
+ groups_out=None,
501
+ pre_norm=True,
502
+ eps=1e-6,
503
+ non_linearity="swish",
504
+ skip_time_act=False,
505
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
506
+ kernel=None,
507
+ output_scale_factor=1.0,
508
+ use_in_shortcut=None,
509
+ up=False,
510
+ down=False,
511
+ conv_shortcut_bias: bool = True,
512
+ conv_2d_out_channels: Optional[int] = None,
513
+ ):
514
+ super().__init__()
515
+ self.pre_norm = pre_norm
516
+ self.pre_norm = True
517
+ self.in_channels = in_channels
518
+ out_channels = in_channels if out_channels is None else out_channels
519
+ self.out_channels = out_channels
520
+ self.use_conv_shortcut = conv_shortcut
521
+ self.up = up
522
+ self.down = down
523
+ self.output_scale_factor = output_scale_factor
524
+ self.time_embedding_norm = time_embedding_norm
525
+ self.skip_time_act = skip_time_act
526
+
527
+ if groups_out is None:
528
+ groups_out = groups
529
+
530
+ if self.time_embedding_norm == "ada_group":
531
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
532
+ elif self.time_embedding_norm == "spatial":
533
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
534
+ else:
535
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
536
+
537
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
538
+
539
+ if temb_channels is not None:
540
+ if self.time_embedding_norm == "default":
541
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
542
+ elif self.time_embedding_norm == "scale_shift":
543
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
544
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
545
+ self.time_emb_proj = None
546
+ else:
547
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
548
+ else:
549
+ self.time_emb_proj = None
550
+
551
+ if self.time_embedding_norm == "ada_group":
552
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
553
+ elif self.time_embedding_norm == "spatial":
554
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
555
+ else:
556
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
557
+
558
+ self.dropout = torch.nn.Dropout(dropout)
559
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
560
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
561
+
562
+ self.nonlinearity = get_activation(non_linearity)
563
+
564
+ self.upsample = self.downsample = None
565
+ if self.up:
566
+ if kernel == "fir":
567
+ fir_kernel = (1, 3, 3, 1)
568
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
569
+ elif kernel == "sde_vp":
570
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
571
+ else:
572
+ self.upsample = Upsample2D(in_channels, use_conv=False)
573
+ elif self.down:
574
+ if kernel == "fir":
575
+ fir_kernel = (1, 3, 3, 1)
576
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
577
+ elif kernel == "sde_vp":
578
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
579
+ else:
580
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
581
+
582
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
583
+
584
+ self.conv_shortcut = None
585
+ if self.use_in_shortcut:
586
+ self.conv_shortcut = torch.nn.Conv2d(
587
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
588
+ )
589
+
590
+ # Rich-Text: feature injection
591
+ def forward(self, input_tensor, temb, inject_states=None):
592
+ hidden_states = input_tensor
593
+
594
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
595
+ hidden_states = self.norm1(hidden_states, temb)
596
+ else:
597
+ hidden_states = self.norm1(hidden_states)
598
+
599
+ hidden_states = self.nonlinearity(hidden_states)
600
+
601
+ if self.upsample is not None:
602
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
603
+ if hidden_states.shape[0] >= 64:
604
+ input_tensor = input_tensor.contiguous()
605
+ hidden_states = hidden_states.contiguous()
606
+ input_tensor = self.upsample(input_tensor)
607
+ hidden_states = self.upsample(hidden_states)
608
+ elif self.downsample is not None:
609
+ input_tensor = self.downsample(input_tensor)
610
+ hidden_states = self.downsample(hidden_states)
611
+
612
+ hidden_states = self.conv1(hidden_states)
613
+
614
+ if self.time_emb_proj is not None:
615
+ if not self.skip_time_act:
616
+ temb = self.nonlinearity(temb)
617
+ temb = self.time_emb_proj(temb)[:, :, None, None]
618
+
619
+ if temb is not None and self.time_embedding_norm == "default":
620
+ hidden_states = hidden_states + temb
621
+
622
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
623
+ hidden_states = self.norm2(hidden_states, temb)
624
+ else:
625
+ hidden_states = self.norm2(hidden_states)
626
+
627
+ if temb is not None and self.time_embedding_norm == "scale_shift":
628
+ scale, shift = torch.chunk(temb, 2, dim=1)
629
+ hidden_states = hidden_states * (1 + scale) + shift
630
+
631
+ hidden_states = self.nonlinearity(hidden_states)
632
+
633
+ hidden_states = self.dropout(hidden_states)
634
+ hidden_states = self.conv2(hidden_states)
635
+
636
+ if self.conv_shortcut is not None:
637
+ input_tensor = self.conv_shortcut(input_tensor)
638
+
639
+ # Rich-Text: feature injection
640
+ if inject_states is not None:
641
+ output_tensor = (input_tensor + inject_states) / self.output_scale_factor
642
+ else:
643
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
644
+
645
+ return output_tensor, hidden_states
646
+
647
+
648
+ # unet_rl.py
649
+ def rearrange_dims(tensor):
650
+ if len(tensor.shape) == 2:
651
+ return tensor[:, :, None]
652
+ if len(tensor.shape) == 3:
653
+ return tensor[:, :, None, :]
654
+ elif len(tensor.shape) == 4:
655
+ return tensor[:, :, 0, :]
656
+ else:
657
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
658
+
659
+
660
+ class Conv1dBlock(nn.Module):
661
+ """
662
+ Conv1d --> GroupNorm --> Mish
663
+ """
664
+
665
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
666
+ super().__init__()
667
+
668
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
669
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
670
+ self.mish = nn.Mish()
671
+
672
+ def forward(self, inputs):
673
+ intermediate_repr = self.conv1d(inputs)
674
+ intermediate_repr = rearrange_dims(intermediate_repr)
675
+ intermediate_repr = self.group_norm(intermediate_repr)
676
+ intermediate_repr = rearrange_dims(intermediate_repr)
677
+ output = self.mish(intermediate_repr)
678
+ return output
679
+
680
+
681
+ # unet_rl.py
682
+ class ResidualTemporalBlock1D(nn.Module):
683
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
684
+ super().__init__()
685
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
686
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
687
+
688
+ self.time_emb_act = nn.Mish()
689
+ self.time_emb = nn.Linear(embed_dim, out_channels)
690
+
691
+ self.residual_conv = (
692
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
693
+ )
694
+
695
+ def forward(self, inputs, t):
696
+ """
697
+ Args:
698
+ inputs : [ batch_size x inp_channels x horizon ]
699
+ t : [ batch_size x embed_dim ]
700
+
701
+ returns:
702
+ out : [ batch_size x out_channels x horizon ]
703
+ """
704
+ t = self.time_emb_act(t)
705
+ t = self.time_emb(t)
706
+ out = self.conv_in(inputs) + rearrange_dims(t)
707
+ out = self.conv_out(out)
708
+ return out + self.residual_conv(inputs)
709
+
710
+
711
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
712
+ r"""Upsample2D a batch of 2D images with the given filter.
713
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
714
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
715
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
716
+ a: multiple of the upsampling factor.
717
+
718
+ Args:
719
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
720
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
721
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
722
+ factor: Integer upsampling factor (default: 2).
723
+ gain: Scaling factor for signal magnitude (default: 1.0).
724
+
725
+ Returns:
726
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
727
+ """
728
+ assert isinstance(factor, int) and factor >= 1
729
+ if kernel is None:
730
+ kernel = [1] * factor
731
+
732
+ kernel = torch.tensor(kernel, dtype=torch.float32)
733
+ if kernel.ndim == 1:
734
+ kernel = torch.outer(kernel, kernel)
735
+ kernel /= torch.sum(kernel)
736
+
737
+ kernel = kernel * (gain * (factor**2))
738
+ pad_value = kernel.shape[0] - factor
739
+ output = upfirdn2d_native(
740
+ hidden_states,
741
+ kernel.to(device=hidden_states.device),
742
+ up=factor,
743
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
744
+ )
745
+ return output
746
+
747
+
748
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
749
+ r"""Downsample2D a batch of 2D images with the given filter.
750
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
751
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
752
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
753
+ shape is a multiple of the downsampling factor.
754
+
755
+ Args:
756
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
757
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
758
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
759
+ factor: Integer downsampling factor (default: 2).
760
+ gain: Scaling factor for signal magnitude (default: 1.0).
761
+
762
+ Returns:
763
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
764
+ """
765
+
766
+ assert isinstance(factor, int) and factor >= 1
767
+ if kernel is None:
768
+ kernel = [1] * factor
769
+
770
+ kernel = torch.tensor(kernel, dtype=torch.float32)
771
+ if kernel.ndim == 1:
772
+ kernel = torch.outer(kernel, kernel)
773
+ kernel /= torch.sum(kernel)
774
+
775
+ kernel = kernel * gain
776
+ pad_value = kernel.shape[0] - factor
777
+ output = upfirdn2d_native(
778
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
779
+ )
780
+ return output
781
+
782
+
783
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
784
+ up_x = up_y = up
785
+ down_x = down_y = down
786
+ pad_x0 = pad_y0 = pad[0]
787
+ pad_x1 = pad_y1 = pad[1]
788
+
789
+ _, channel, in_h, in_w = tensor.shape
790
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
791
+
792
+ _, in_h, in_w, minor = tensor.shape
793
+ kernel_h, kernel_w = kernel.shape
794
+
795
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
796
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
797
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
798
+
799
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
800
+ out = out.to(tensor.device) # Move back to mps if necessary
801
+ out = out[
802
+ :,
803
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
804
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
805
+ :,
806
+ ]
807
+
808
+ out = out.permute(0, 3, 1, 2)
809
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
810
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
811
+ out = F.conv2d(out, w)
812
+ out = out.reshape(
813
+ -1,
814
+ minor,
815
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
816
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
817
+ )
818
+ out = out.permute(0, 2, 3, 1)
819
+ out = out[:, ::down_y, ::down_x, :]
820
+
821
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
822
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
823
+
824
+ return out.view(-1, channel, out_h, out_w)
825
+
826
+
827
+ class TemporalConvLayer(nn.Module):
828
+ """
829
+ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
830
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
831
+ """
832
+
833
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
834
+ super().__init__()
835
+ out_dim = out_dim or in_dim
836
+ self.in_dim = in_dim
837
+ self.out_dim = out_dim
838
+
839
+ # conv layers
840
+ self.conv1 = nn.Sequential(
841
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
842
+ )
843
+ self.conv2 = nn.Sequential(
844
+ nn.GroupNorm(32, out_dim),
845
+ nn.SiLU(),
846
+ nn.Dropout(dropout),
847
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
848
+ )
849
+ self.conv3 = nn.Sequential(
850
+ nn.GroupNorm(32, out_dim),
851
+ nn.SiLU(),
852
+ nn.Dropout(dropout),
853
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
854
+ )
855
+ self.conv4 = nn.Sequential(
856
+ nn.GroupNorm(32, out_dim),
857
+ nn.SiLU(),
858
+ nn.Dropout(dropout),
859
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
860
+ )
861
+
862
+ # zero out the last layer params,so the conv block is identity
863
+ nn.init.zeros_(self.conv4[-1].weight)
864
+ nn.init.zeros_(self.conv4[-1].bias)
865
+
866
+ def forward(self, hidden_states, num_frames=1):
867
+ hidden_states = (
868
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
869
+ )
870
+
871
+ identity = hidden_states
872
+ hidden_states = self.conv1(hidden_states)
873
+ hidden_states = self.conv2(hidden_states)
874
+ hidden_states = self.conv3(hidden_states)
875
+ hidden_states = self.conv4(hidden_states)
876
+
877
+ hidden_states = identity + hidden_states
878
+
879
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
880
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
881
+ )
882
+ return hidden_states
models/transformer_2d.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.models.embeddings import PatchEmbed
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+
27
+ from models.attention import BasicTransformerBlock
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ The output of [`Transformer2DModel`].
33
+
34
+ Args:
35
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
36
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
37
+ distributions for the unnoised latent pixels.
38
+ """
39
+
40
+ sample: torch.FloatTensor
41
+
42
+
43
+ class Transformer2DModel(ModelMixin, ConfigMixin):
44
+ """
45
+ A 2D Transformer model for image-like data.
46
+
47
+ Parameters:
48
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
49
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
50
+ in_channels (`int`, *optional*):
51
+ The number of channels in the input and output (specify if the input is **continuous**).
52
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
55
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
56
+ This is fixed during training since it is used to learn a number of position embeddings.
57
+ num_vector_embeds (`int`, *optional*):
58
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
59
+ Includes the class for the masked latent pixel.
60
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
61
+ num_embeds_ada_norm ( `int`, *optional*):
62
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
63
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
64
+ added to the hidden states.
65
+
66
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
67
+ attention_bias (`bool`, *optional*):
68
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
69
+ """
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ num_attention_heads: int = 16,
75
+ attention_head_dim: int = 88,
76
+ in_channels: Optional[int] = None,
77
+ out_channels: Optional[int] = None,
78
+ num_layers: int = 1,
79
+ dropout: float = 0.0,
80
+ norm_num_groups: int = 32,
81
+ cross_attention_dim: Optional[int] = None,
82
+ attention_bias: bool = False,
83
+ sample_size: Optional[int] = None,
84
+ num_vector_embeds: Optional[int] = None,
85
+ patch_size: Optional[int] = None,
86
+ activation_fn: str = "geglu",
87
+ num_embeds_ada_norm: Optional[int] = None,
88
+ use_linear_projection: bool = False,
89
+ only_cross_attention: bool = False,
90
+ upcast_attention: bool = False,
91
+ norm_type: str = "layer_norm",
92
+ norm_elementwise_affine: bool = True,
93
+ ):
94
+ super().__init__()
95
+ self.use_linear_projection = use_linear_projection
96
+ self.num_attention_heads = num_attention_heads
97
+ self.attention_head_dim = attention_head_dim
98
+ inner_dim = num_attention_heads * attention_head_dim
99
+
100
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
101
+ # Define whether input is continuous or discrete depending on configuration
102
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
103
+ self.is_input_vectorized = num_vector_embeds is not None
104
+ self.is_input_patches = in_channels is not None and patch_size is not None
105
+
106
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
107
+ deprecation_message = (
108
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
109
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
110
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
111
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
112
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
113
+ )
114
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
115
+ norm_type = "ada_norm"
116
+
117
+ if self.is_input_continuous and self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is None."
121
+ )
122
+ elif self.is_input_vectorized and self.is_input_patches:
123
+ raise ValueError(
124
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
125
+ " sure that either `num_vector_embeds` or `num_patches` is None."
126
+ )
127
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
128
+ raise ValueError(
129
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
130
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
131
+ )
132
+
133
+ # 2. Define input layers
134
+ if self.is_input_continuous:
135
+ self.in_channels = in_channels
136
+
137
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+ if use_linear_projection:
139
+ self.proj_in = nn.Linear(in_channels, inner_dim)
140
+ else:
141
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
+ elif self.is_input_vectorized:
143
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
+
146
+ self.height = sample_size
147
+ self.width = sample_size
148
+ self.num_vector_embeds = num_vector_embeds
149
+ self.num_latent_pixels = self.height * self.width
150
+
151
+ self.latent_image_embedding = ImagePositionalEmbeddings(
152
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
+ )
154
+ elif self.is_input_patches:
155
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
156
+
157
+ self.height = sample_size
158
+ self.width = sample_size
159
+
160
+ self.patch_size = patch_size
161
+ self.pos_embed = PatchEmbed(
162
+ height=sample_size,
163
+ width=sample_size,
164
+ patch_size=patch_size,
165
+ in_channels=in_channels,
166
+ embed_dim=inner_dim,
167
+ )
168
+
169
+ # 3. Define transformers blocks
170
+ self.transformer_blocks = nn.ModuleList(
171
+ [
172
+ BasicTransformerBlock(
173
+ inner_dim,
174
+ num_attention_heads,
175
+ attention_head_dim,
176
+ dropout=dropout,
177
+ cross_attention_dim=cross_attention_dim,
178
+ activation_fn=activation_fn,
179
+ num_embeds_ada_norm=num_embeds_ada_norm,
180
+ attention_bias=attention_bias,
181
+ only_cross_attention=only_cross_attention,
182
+ upcast_attention=upcast_attention,
183
+ norm_type=norm_type,
184
+ norm_elementwise_affine=norm_elementwise_affine,
185
+ )
186
+ for d in range(num_layers)
187
+ ]
188
+ )
189
+
190
+ # 4. Define output layers
191
+ self.out_channels = in_channels if out_channels is None else out_channels
192
+ if self.is_input_continuous:
193
+ # TODO: should use out_channels for continuous projections
194
+ if use_linear_projection:
195
+ self.proj_out = nn.Linear(inner_dim, in_channels)
196
+ else:
197
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
198
+ elif self.is_input_vectorized:
199
+ self.norm_out = nn.LayerNorm(inner_dim)
200
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
201
+ elif self.is_input_patches:
202
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
203
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
204
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
205
+
206
+ def forward(
207
+ self,
208
+ hidden_states: torch.Tensor,
209
+ encoder_hidden_states: Optional[torch.Tensor] = None,
210
+ timestep: Optional[torch.LongTensor] = None,
211
+ class_labels: Optional[torch.LongTensor] = None,
212
+ cross_attention_kwargs: Dict[str, Any] = None,
213
+ attention_mask: Optional[torch.Tensor] = None,
214
+ encoder_attention_mask: Optional[torch.Tensor] = None,
215
+ return_dict: bool = True,
216
+ ):
217
+ """
218
+ The [`Transformer2DModel`] forward method.
219
+
220
+ Args:
221
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
222
+ Input `hidden_states`.
223
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
224
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
225
+ self-attention.
226
+ timestep ( `torch.LongTensor`, *optional*):
227
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
228
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
229
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
230
+ `AdaLayerZeroNorm`.
231
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
232
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
233
+
234
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
235
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
236
+
237
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
238
+ above. This bias will be added to the cross-attention scores.
239
+ return_dict (`bool`, *optional*, defaults to `True`):
240
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
241
+ tuple.
242
+
243
+ Returns:
244
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
245
+ `tuple` where the first element is the sample tensor.
246
+ """
247
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
248
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
249
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
250
+ # expects mask of shape:
251
+ # [batch, key_tokens]
252
+ # adds singleton query_tokens dimension:
253
+ # [batch, 1, key_tokens]
254
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
255
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
256
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
257
+ if attention_mask is not None and attention_mask.ndim == 2:
258
+ # assume that mask is expressed as:
259
+ # (1 = keep, 0 = discard)
260
+ # convert mask into a bias that can be added to attention scores:
261
+ # (keep = +0, discard = -10000.0)
262
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
263
+ attention_mask = attention_mask.unsqueeze(1)
264
+
265
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
266
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
267
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
268
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
269
+
270
+ # 1. Input
271
+ if self.is_input_continuous:
272
+ batch, _, height, width = hidden_states.shape
273
+ residual = hidden_states
274
+
275
+ hidden_states = self.norm(hidden_states)
276
+ if not self.use_linear_projection:
277
+ hidden_states = self.proj_in(hidden_states)
278
+ inner_dim = hidden_states.shape[1]
279
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
280
+ else:
281
+ inner_dim = hidden_states.shape[1]
282
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
283
+ hidden_states = self.proj_in(hidden_states)
284
+ elif self.is_input_vectorized:
285
+ hidden_states = self.latent_image_embedding(hidden_states)
286
+ elif self.is_input_patches:
287
+ hidden_states = self.pos_embed(hidden_states)
288
+
289
+ # 2. Blocks
290
+ for block in self.transformer_blocks:
291
+ hidden_states = block(
292
+ hidden_states,
293
+ attention_mask=attention_mask,
294
+ encoder_hidden_states=encoder_hidden_states,
295
+ encoder_attention_mask=encoder_attention_mask,
296
+ timestep=timestep,
297
+ cross_attention_kwargs=cross_attention_kwargs,
298
+ class_labels=class_labels,
299
+ )
300
+
301
+ # 3. Output
302
+ if self.is_input_continuous:
303
+ if not self.use_linear_projection:
304
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
305
+ hidden_states = self.proj_out(hidden_states)
306
+ else:
307
+ hidden_states = self.proj_out(hidden_states)
308
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
309
+
310
+ output = hidden_states + residual
311
+ elif self.is_input_vectorized:
312
+ hidden_states = self.norm_out(hidden_states)
313
+ logits = self.out(hidden_states)
314
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
315
+ logits = logits.permute(0, 2, 1)
316
+
317
+ # log(p(x_0))
318
+ output = F.log_softmax(logits.double(), dim=1).float()
319
+ elif self.is_input_patches:
320
+ # TODO: cleanup!
321
+ conditioning = self.transformer_blocks[0].norm1.emb(
322
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
323
+ )
324
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
325
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
326
+ hidden_states = self.proj_out_2(hidden_states)
327
+
328
+ # unpatchify
329
+ height = width = int(hidden_states.shape[1] ** 0.5)
330
+ hidden_states = hidden_states.reshape(
331
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
332
+ )
333
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
334
+ output = hidden_states.reshape(
335
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
336
+ )
337
+
338
+ if not return_dict:
339
+ return (output,)
340
+
341
+ return Transformer2DModelOutput(sample=output)
models/unet_2d_blocks.py CHANGED
The diff for this file is too large to render. See raw diff
 
models/unet_2d_condition.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -12,21 +12,38 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from dataclasses import dataclass
15
- from typing import Optional, Tuple, Union
16
 
17
  import torch
18
  import torch.nn as nn
19
  import torch.utils.checkpoint
20
 
21
  from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.modeling_utils import ModelMixin
23
  from diffusers.utils import BaseOutput, logging
24
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
- from .unet_2d_blocks import (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  CrossAttnDownBlock2D,
27
  CrossAttnUpBlock2D,
28
  DownBlock2D,
29
  UNetMidBlock2DCrossAttn,
 
30
  UpBlock2D,
31
  get_down_block,
32
  get_up_block,
@@ -39,35 +56,43 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
  @dataclass
40
  class UNet2DConditionOutput(BaseOutput):
41
  """
 
 
42
  Args:
43
  sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
44
- Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
45
  """
46
 
47
- sample: torch.FloatTensor
48
 
49
 
50
- class UNet2DConditionModel(ModelMixin, ConfigMixin):
51
  r"""
52
- UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
53
- and returns sample shaped output.
54
 
55
- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
56
- implements for all the models (such as downloading or saving, etc.)
57
 
58
  Parameters:
59
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
60
  Height and width of input/output sample.
61
- in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
62
- out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
63
  center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
64
  flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
65
  Whether to flip the sin to cos in the time embedding.
66
  freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
67
  down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
  The tuple of downsample blocks to use.
69
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
 
 
 
70
  The tuple of upsample blocks to use.
 
 
 
71
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
  The tuple of output channels for each block.
73
  layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
@@ -75,9 +100,58 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
75
  mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
  norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
 
78
  norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
- cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
 
 
 
 
 
 
 
 
 
 
 
80
  attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  """
82
 
83
  _supports_gradient_checkpointing = True
@@ -97,50 +171,262 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
97
  "CrossAttnDownBlock2D",
98
  "DownBlock2D",
99
  ),
 
100
  up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
101
  only_cross_attention: Union[bool, Tuple[bool]] = False,
102
  block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
- layers_per_block: int = 2,
104
  downsample_padding: int = 1,
105
  mid_block_scale_factor: float = 1,
106
  act_fn: str = "silu",
107
- norm_num_groups: int = 32,
108
  norm_eps: float = 1e-5,
109
- cross_attention_dim: int = 1280,
 
 
 
110
  attention_head_dim: Union[int, Tuple[int]] = 8,
 
111
  dual_cross_attention: bool = False,
112
  use_linear_projection: bool = False,
 
 
 
113
  num_class_embeds: Optional[int] = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  ):
115
  super().__init__()
116
 
117
  self.sample_size = sample_size
118
- time_embed_dim = block_out_channels[0] * 4
119
- # import ipdb;ipdb.set_trace()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # input
122
- self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
 
 
 
123
 
124
  # time
125
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
126
- timestep_input_dim = block_out_channels[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  # class embedding
131
- if num_class_embeds is not None:
132
  self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  self.down_blocks = nn.ModuleList([])
135
- self.mid_block = None
136
  self.up_blocks = nn.ModuleList([])
137
 
138
  if isinstance(only_cross_attention, bool):
 
 
 
139
  only_cross_attention = [only_cross_attention] * len(down_block_types)
140
 
 
 
 
 
 
 
141
  if isinstance(attention_head_dim, int):
142
  attention_head_dim = (attention_head_dim,) * len(down_block_types)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # down
145
  output_channel = block_out_channels[0]
146
  for i, down_block_type in enumerate(down_block_types):
@@ -150,45 +436,78 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
150
 
151
  down_block = get_down_block(
152
  down_block_type,
153
- num_layers=layers_per_block,
 
154
  in_channels=input_channel,
155
  out_channels=output_channel,
156
- temb_channels=time_embed_dim,
157
  add_downsample=not is_final_block,
158
  resnet_eps=norm_eps,
159
  resnet_act_fn=act_fn,
160
  resnet_groups=norm_num_groups,
161
- cross_attention_dim=cross_attention_dim,
162
- attn_num_head_channels=attention_head_dim[i],
163
  downsample_padding=downsample_padding,
164
  dual_cross_attention=dual_cross_attention,
165
  use_linear_projection=use_linear_projection,
166
  only_cross_attention=only_cross_attention[i],
 
 
 
 
 
 
167
  )
168
  self.down_blocks.append(down_block)
169
 
170
  # mid
171
- self.mid_block = UNetMidBlock2DCrossAttn(
172
- in_channels=block_out_channels[-1],
173
- temb_channels=time_embed_dim,
174
- resnet_eps=norm_eps,
175
- resnet_act_fn=act_fn,
176
- output_scale_factor=mid_block_scale_factor,
177
- resnet_time_scale_shift="default",
178
- cross_attention_dim=cross_attention_dim,
179
- attn_num_head_channels=attention_head_dim[-1],
180
- resnet_groups=norm_num_groups,
181
- dual_cross_attention=dual_cross_attention,
182
- use_linear_projection=use_linear_projection,
183
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # count how many layers upsample the images
186
  self.num_upsamplers = 0
187
 
188
  # up
189
  reversed_block_out_channels = list(reversed(block_out_channels))
190
- reversed_attention_head_dim = list(reversed(attention_head_dim))
 
 
 
191
  only_cross_attention = list(reversed(only_cross_attention))
 
192
  output_channel = reversed_block_out_channels[0]
193
  for i, up_block_type in enumerate(up_block_types):
194
  is_final_block = i == len(block_out_channels) - 1
@@ -206,63 +525,176 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
206
 
207
  up_block = get_up_block(
208
  up_block_type,
209
- num_layers=layers_per_block + 1,
 
210
  in_channels=input_channel,
211
  out_channels=output_channel,
212
  prev_output_channel=prev_output_channel,
213
- temb_channels=time_embed_dim,
214
  add_upsample=add_upsample,
215
  resnet_eps=norm_eps,
216
  resnet_act_fn=act_fn,
217
  resnet_groups=norm_num_groups,
218
- cross_attention_dim=cross_attention_dim,
219
- attn_num_head_channels=reversed_attention_head_dim[i],
220
  dual_cross_attention=dual_cross_attention,
221
  use_linear_projection=use_linear_projection,
222
  only_cross_attention=only_cross_attention[i],
 
 
 
 
 
 
223
  )
224
  self.up_blocks.append(up_block)
225
  prev_output_channel = output_channel
226
 
227
  # out
228
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
229
- self.conv_act = nn.SiLU()
230
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
231
-
232
- def set_attention_slice(self, slice_size):
233
- head_dims = self.config.attention_head_dim
234
- head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
235
- if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
236
- raise ValueError(
237
- f"Make sure slice_size {slice_size} is a common divisor of "
238
- f"the number of heads used in cross_attention: {head_dims}"
239
  )
240
- if slice_size is not None and slice_size > min(head_dims):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  raise ValueError(
242
- f"slice_size {slice_size} has to be smaller or equal to "
243
- f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
244
  )
245
 
246
- for block in self.down_blocks:
247
- if hasattr(block, "attentions") and block.attentions is not None:
248
- block.set_attention_slice(slice_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
- self.mid_block.set_attention_slice(slice_size)
251
 
252
- for block in self.up_blocks:
253
- if hasattr(block, "attentions") and block.attentions is not None:
254
- block.set_attention_slice(slice_size)
 
 
 
 
 
 
 
 
255
 
256
- def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
257
- for block in self.down_blocks:
258
- if hasattr(block, "attentions") and block.attentions is not None:
259
- block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
 
260
 
261
- self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
262
 
263
- for block in self.up_blocks:
264
- if hasattr(block, "attentions") and block.attentions is not None:
265
- block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
266
 
267
  def _set_gradient_checkpointing(self, module, value=False):
268
  if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
@@ -274,24 +706,44 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
274
  timestep: Union[torch.Tensor, float, int],
275
  encoder_hidden_states: torch.Tensor,
276
  class_labels: Optional[torch.Tensor] = None,
277
- text_format_dict = {},
 
 
 
 
 
 
278
  return_dict: bool = True,
279
  ) -> Union[UNet2DConditionOutput, Tuple]:
280
  r"""
 
 
281
  Args:
282
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
283
- timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
284
- encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
 
 
 
 
 
 
285
  return_dict (`bool`, *optional*, defaults to `True`):
286
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
 
 
 
 
 
 
287
 
288
  Returns:
289
  [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
290
- [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
291
- returning a tuple, the first element is the sample tensor.
292
  """
293
  # By default samples have to be AT least a multiple of the overall upsampling factor.
294
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
295
  # However, the upsampling interpolation output size can be forced to fit any upsampling size
296
  # on the fly if necessary.
297
  default_overall_up_factor = 2**self.num_upsamplers
@@ -304,6 +756,27 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
304
  logger.info("Forward upsample size to force interpolation output size.")
305
  forward_upsample_size = True
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # 0. center input if necessary
308
  if self.config.center_input_sample:
309
  sample = 2 * sample - 1.0
@@ -312,8 +785,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
312
  timesteps = timestep
313
  if not torch.is_tensor(timesteps):
314
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
315
- timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
316
- elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
 
 
 
 
 
 
317
  timesteps = timesteps[None].to(sample.device)
318
 
319
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -321,47 +800,148 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
321
 
322
  t_emb = self.time_proj(timesteps)
323
 
324
- # timesteps does not contain any weights and will always return f32 tensors
325
  # but time_embedding might actually be running in fp16. so we need to cast here.
326
  # there might be better ways to encapsulate this.
327
- t_emb = t_emb.to(dtype=self.dtype)
328
- emb = self.time_embedding(t_emb)
 
 
329
 
330
- if self.config.num_class_embeds is not None:
331
  if class_labels is None:
332
  raise ValueError("class_labels should be provided when num_class_embeds > 0")
333
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
334
- emb = emb + class_emb
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  # 2. pre-process
337
  sample = self.conv_in(sample)
338
 
339
  # 3. down
340
  down_block_res_samples = (sample,)
341
  for downsample_block in self.down_blocks:
342
- if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
343
- if isinstance(downsample_block, CrossAttnDownBlock2D):
344
- sample, res_samples = downsample_block(
345
- hidden_states=sample,
346
- temb=emb,
347
- encoder_hidden_states=encoder_hidden_states,
348
- text_format_dict=text_format_dict
349
- )
350
- else:
351
- sample, res_samples = downsample_block(
352
- hidden_states=sample,
353
- temb=emb,
354
- encoder_hidden_states=encoder_hidden_states,
355
- )
356
  else:
357
- if isinstance(downsample_block, CrossAttnDownBlock2D):
358
- import ipdb;ipdb.set_trace()
359
  sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
 
360
  down_block_res_samples += res_samples
361
 
 
 
 
 
 
 
 
 
 
 
 
362
  # 4. mid
363
- sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
364
- text_format_dict=text_format_dict)
 
 
 
 
 
 
 
 
 
 
365
 
366
  # 5. up
367
  for i, upsample_block in enumerate(self.up_blocks):
@@ -375,34 +955,26 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
375
  if not is_final_block and forward_upsample_size:
376
  upsample_size = down_block_res_samples[-1].shape[2:]
377
 
378
- if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
379
- if isinstance(upsample_block, CrossAttnUpBlock2D):
380
- sample = upsample_block(
381
- hidden_states=sample,
382
- temb=emb,
383
- res_hidden_states_tuple=res_samples,
384
- encoder_hidden_states=encoder_hidden_states,
385
- upsample_size=upsample_size,
386
- text_format_dict=text_format_dict
387
- )
388
- else:
389
- sample = upsample_block(
390
- hidden_states=sample,
391
- temb=emb,
392
- res_hidden_states_tuple=res_samples,
393
- encoder_hidden_states=encoder_hidden_states,
394
- upsample_size=upsample_size,
395
- )
396
  else:
397
- if isinstance(upsample_block, CrossAttnUpBlock2D):
398
- upsample_block.attentions
399
- import ipdb;ipdb.set_trace()
400
  sample = upsample_block(
401
  hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
402
  )
 
403
  # 6. post-process
404
- sample = self.conv_norm_out(sample)
405
- sample = self.conv_act(sample)
 
406
  sample = self.conv_out(sample)
407
 
408
  if not return_dict:
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
 
17
  import torch
18
  import torch.nn as nn
19
  import torch.utils.checkpoint
20
 
21
  from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
  from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+
26
+ from diffusers.models.embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ TextImageProjection,
32
+ TextImageTimeEmbedding,
33
+ TextTimeEmbedding,
34
+ TimestepEmbedding,
35
+ Timesteps,
36
+ )
37
+ from diffusers.models.modeling_utils import ModelMixin
38
+
39
+ from models.attention_processor import AttentionProcessor, AttnProcessor
40
+
41
+ from models.unet_2d_blocks import (
42
  CrossAttnDownBlock2D,
43
  CrossAttnUpBlock2D,
44
  DownBlock2D,
45
  UNetMidBlock2DCrossAttn,
46
+ UNetMidBlock2DSimpleCrossAttn,
47
  UpBlock2D,
48
  get_down_block,
49
  get_up_block,
 
56
  @dataclass
57
  class UNet2DConditionOutput(BaseOutput):
58
  """
59
+ The output of [`UNet2DConditionModel`].
60
+
61
  Args:
62
  sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
  """
65
 
66
+ sample: torch.FloatTensor = None
67
 
68
 
69
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
70
  r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
+ shaped output.
73
 
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
75
+ for all models (such as downloading or saving).
76
 
77
  Parameters:
78
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
  Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
  center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
  flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
84
  Whether to flip the sin to cos in the time embedding.
85
  freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
  down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
  The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
  The tuple of upsample blocks to use.
93
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
  The tuple of output channels for each block.
98
  layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
 
100
  mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
102
  norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
103
+ If `None`, normalization and activation layers is skipped in post-processing.
104
  norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
105
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
106
+ The dimension of the cross attention features.
107
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
117
  attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
154
+ otherwise.
155
  """
156
 
157
  _supports_gradient_checkpointing = True
 
171
  "CrossAttnDownBlock2D",
172
  "DownBlock2D",
173
  ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
  up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
176
  only_cross_attention: Union[bool, Tuple[bool]] = False,
177
  block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
178
+ layers_per_block: Union[int, Tuple[int]] = 2,
179
  downsample_padding: int = 1,
180
  mid_block_scale_factor: float = 1,
181
  act_fn: str = "silu",
182
+ norm_num_groups: Optional[int] = 32,
183
  norm_eps: float = 1e-5,
184
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
185
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
186
+ encoder_hid_dim: Optional[int] = None,
187
+ encoder_hid_dim_type: Optional[str] = None,
188
  attention_head_dim: Union[int, Tuple[int]] = 8,
189
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
190
  dual_cross_attention: bool = False,
191
  use_linear_projection: bool = False,
192
+ class_embed_type: Optional[str] = None,
193
+ addition_embed_type: Optional[str] = None,
194
+ addition_time_embed_dim: Optional[int] = None,
195
  num_class_embeds: Optional[int] = None,
196
+ upcast_attention: bool = False,
197
+ resnet_time_scale_shift: str = "default",
198
+ resnet_skip_time_act: bool = False,
199
+ resnet_out_scale_factor: int = 1.0,
200
+ time_embedding_type: str = "positional",
201
+ time_embedding_dim: Optional[int] = None,
202
+ time_embedding_act_fn: Optional[str] = None,
203
+ timestep_post_act: Optional[str] = None,
204
+ time_cond_proj_dim: Optional[int] = None,
205
+ conv_in_kernel: int = 3,
206
+ conv_out_kernel: int = 3,
207
+ projection_class_embeddings_input_dim: Optional[int] = None,
208
+ class_embeddings_concat: bool = False,
209
+ mid_block_only_cross_attention: Optional[bool] = None,
210
+ cross_attention_norm: Optional[str] = None,
211
+ addition_embed_type_num_heads=64,
212
  ):
213
  super().__init__()
214
 
215
  self.sample_size = sample_size
216
+
217
+ if num_attention_heads is not None:
218
+ raise ValueError(
219
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
220
+ )
221
+
222
+ # If `num_attention_heads` is not defined (which is the case for most models)
223
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
224
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
225
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
226
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
227
+ # which is why we correct for the naming here.
228
+ num_attention_heads = num_attention_heads or attention_head_dim
229
+
230
+ # Check inputs
231
+ if len(down_block_types) != len(up_block_types):
232
+ raise ValueError(
233
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
234
+ )
235
+
236
+ if len(block_out_channels) != len(down_block_types):
237
+ raise ValueError(
238
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
239
+ )
240
+
241
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
242
+ raise ValueError(
243
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
244
+ )
245
+
246
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
264
+ )
265
 
266
  # input
267
+ conv_in_padding = (conv_in_kernel - 1) // 2
268
+ self.conv_in = nn.Conv2d(
269
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
270
+ )
271
 
272
  # time
273
+ if time_embedding_type == "fourier":
274
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
275
+ if time_embed_dim % 2 != 0:
276
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
277
+ self.time_proj = GaussianFourierProjection(
278
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
279
+ )
280
+ timestep_input_dim = time_embed_dim
281
+ elif time_embedding_type == "positional":
282
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
283
+
284
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
285
+ timestep_input_dim = block_out_channels[0]
286
+ else:
287
+ raise ValueError(
288
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
289
+ )
290
+
291
+ self.time_embedding = TimestepEmbedding(
292
+ timestep_input_dim,
293
+ time_embed_dim,
294
+ act_fn=act_fn,
295
+ post_act_fn=timestep_post_act,
296
+ cond_proj_dim=time_cond_proj_dim,
297
+ )
298
 
299
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
300
+ encoder_hid_dim_type = "text_proj"
301
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
302
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
303
+
304
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
305
+ raise ValueError(
306
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
307
+ )
308
+
309
+ if encoder_hid_dim_type == "text_proj":
310
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
311
+ elif encoder_hid_dim_type == "text_image_proj":
312
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
313
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
314
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
315
+ self.encoder_hid_proj = TextImageProjection(
316
+ text_embed_dim=encoder_hid_dim,
317
+ image_embed_dim=cross_attention_dim,
318
+ cross_attention_dim=cross_attention_dim,
319
+ )
320
+ elif encoder_hid_dim_type == "image_proj":
321
+ # Kandinsky 2.2
322
+ self.encoder_hid_proj = ImageProjection(
323
+ image_embed_dim=encoder_hid_dim,
324
+ cross_attention_dim=cross_attention_dim,
325
+ )
326
+ elif encoder_hid_dim_type is not None:
327
+ raise ValueError(
328
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
329
+ )
330
+ else:
331
+ self.encoder_hid_proj = None
332
 
333
  # class embedding
334
+ if class_embed_type is None and num_class_embeds is not None:
335
  self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
336
+ elif class_embed_type == "timestep":
337
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
338
+ elif class_embed_type == "identity":
339
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
340
+ elif class_embed_type == "projection":
341
+ if projection_class_embeddings_input_dim is None:
342
+ raise ValueError(
343
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
344
+ )
345
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
346
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
347
+ # 2. it projects from an arbitrary input dimension.
348
+ #
349
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
350
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
351
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
352
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
353
+ elif class_embed_type == "simple_projection":
354
+ if projection_class_embeddings_input_dim is None:
355
+ raise ValueError(
356
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
357
+ )
358
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
359
+ else:
360
+ self.class_embedding = None
361
+
362
+ if addition_embed_type == "text":
363
+ if encoder_hid_dim is not None:
364
+ text_time_embedding_from_dim = encoder_hid_dim
365
+ else:
366
+ text_time_embedding_from_dim = cross_attention_dim
367
+
368
+ self.add_embedding = TextTimeEmbedding(
369
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
370
+ )
371
+ elif addition_embed_type == "text_image":
372
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
373
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
374
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
375
+ self.add_embedding = TextImageTimeEmbedding(
376
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
377
+ )
378
+ elif addition_embed_type == "text_time":
379
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
380
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
381
+ elif addition_embed_type == "image":
382
+ # Kandinsky 2.2
383
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384
+ elif addition_embed_type == "image_hint":
385
+ # Kandinsky 2.2 ControlNet
386
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
387
+ elif addition_embed_type is not None:
388
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
389
+
390
+ if time_embedding_act_fn is None:
391
+ self.time_embed_act = None
392
+ else:
393
+ self.time_embed_act = get_activation(time_embedding_act_fn)
394
 
395
  self.down_blocks = nn.ModuleList([])
 
396
  self.up_blocks = nn.ModuleList([])
397
 
398
  if isinstance(only_cross_attention, bool):
399
+ if mid_block_only_cross_attention is None:
400
+ mid_block_only_cross_attention = only_cross_attention
401
+
402
  only_cross_attention = [only_cross_attention] * len(down_block_types)
403
 
404
+ if mid_block_only_cross_attention is None:
405
+ mid_block_only_cross_attention = False
406
+
407
+ if isinstance(num_attention_heads, int):
408
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
409
+
410
  if isinstance(attention_head_dim, int):
411
  attention_head_dim = (attention_head_dim,) * len(down_block_types)
412
 
413
+ if isinstance(cross_attention_dim, int):
414
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
415
+
416
+ if isinstance(layers_per_block, int):
417
+ layers_per_block = [layers_per_block] * len(down_block_types)
418
+
419
+ if isinstance(transformer_layers_per_block, int):
420
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
421
+
422
+ if class_embeddings_concat:
423
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
424
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
425
+ # regular time embeddings
426
+ blocks_time_embed_dim = time_embed_dim * 2
427
+ else:
428
+ blocks_time_embed_dim = time_embed_dim
429
+
430
  # down
431
  output_channel = block_out_channels[0]
432
  for i, down_block_type in enumerate(down_block_types):
 
436
 
437
  down_block = get_down_block(
438
  down_block_type,
439
+ num_layers=layers_per_block[i],
440
+ transformer_layers_per_block=transformer_layers_per_block[i],
441
  in_channels=input_channel,
442
  out_channels=output_channel,
443
+ temb_channels=blocks_time_embed_dim,
444
  add_downsample=not is_final_block,
445
  resnet_eps=norm_eps,
446
  resnet_act_fn=act_fn,
447
  resnet_groups=norm_num_groups,
448
+ cross_attention_dim=cross_attention_dim[i],
449
+ num_attention_heads=num_attention_heads[i],
450
  downsample_padding=downsample_padding,
451
  dual_cross_attention=dual_cross_attention,
452
  use_linear_projection=use_linear_projection,
453
  only_cross_attention=only_cross_attention[i],
454
+ upcast_attention=upcast_attention,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ resnet_skip_time_act=resnet_skip_time_act,
457
+ resnet_out_scale_factor=resnet_out_scale_factor,
458
+ cross_attention_norm=cross_attention_norm,
459
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
460
  )
461
  self.down_blocks.append(down_block)
462
 
463
  # mid
464
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
465
+ self.mid_block = UNetMidBlock2DCrossAttn(
466
+ transformer_layers_per_block=transformer_layers_per_block[-1],
467
+ in_channels=block_out_channels[-1],
468
+ temb_channels=blocks_time_embed_dim,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ output_scale_factor=mid_block_scale_factor,
472
+ resnet_time_scale_shift=resnet_time_scale_shift,
473
+ cross_attention_dim=cross_attention_dim[-1],
474
+ num_attention_heads=num_attention_heads[-1],
475
+ resnet_groups=norm_num_groups,
476
+ dual_cross_attention=dual_cross_attention,
477
+ use_linear_projection=use_linear_projection,
478
+ upcast_attention=upcast_attention,
479
+ )
480
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
481
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
482
+ in_channels=block_out_channels[-1],
483
+ temb_channels=blocks_time_embed_dim,
484
+ resnet_eps=norm_eps,
485
+ resnet_act_fn=act_fn,
486
+ output_scale_factor=mid_block_scale_factor,
487
+ cross_attention_dim=cross_attention_dim[-1],
488
+ attention_head_dim=attention_head_dim[-1],
489
+ resnet_groups=norm_num_groups,
490
+ resnet_time_scale_shift=resnet_time_scale_shift,
491
+ skip_time_act=resnet_skip_time_act,
492
+ only_cross_attention=mid_block_only_cross_attention,
493
+ cross_attention_norm=cross_attention_norm,
494
+ )
495
+ elif mid_block_type is None:
496
+ self.mid_block = None
497
+ else:
498
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
499
 
500
  # count how many layers upsample the images
501
  self.num_upsamplers = 0
502
 
503
  # up
504
  reversed_block_out_channels = list(reversed(block_out_channels))
505
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
506
+ reversed_layers_per_block = list(reversed(layers_per_block))
507
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
508
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
509
  only_cross_attention = list(reversed(only_cross_attention))
510
+
511
  output_channel = reversed_block_out_channels[0]
512
  for i, up_block_type in enumerate(up_block_types):
513
  is_final_block = i == len(block_out_channels) - 1
 
525
 
526
  up_block = get_up_block(
527
  up_block_type,
528
+ num_layers=reversed_layers_per_block[i] + 1,
529
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
530
  in_channels=input_channel,
531
  out_channels=output_channel,
532
  prev_output_channel=prev_output_channel,
533
+ temb_channels=blocks_time_embed_dim,
534
  add_upsample=add_upsample,
535
  resnet_eps=norm_eps,
536
  resnet_act_fn=act_fn,
537
  resnet_groups=norm_num_groups,
538
+ cross_attention_dim=reversed_cross_attention_dim[i],
539
+ num_attention_heads=reversed_num_attention_heads[i],
540
  dual_cross_attention=dual_cross_attention,
541
  use_linear_projection=use_linear_projection,
542
  only_cross_attention=only_cross_attention[i],
543
+ upcast_attention=upcast_attention,
544
+ resnet_time_scale_shift=resnet_time_scale_shift,
545
+ resnet_skip_time_act=resnet_skip_time_act,
546
+ resnet_out_scale_factor=resnet_out_scale_factor,
547
+ cross_attention_norm=cross_attention_norm,
548
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
549
  )
550
  self.up_blocks.append(up_block)
551
  prev_output_channel = output_channel
552
 
553
  # out
554
+ if norm_num_groups is not None:
555
+ self.conv_norm_out = nn.GroupNorm(
556
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
 
 
 
 
 
 
 
 
557
  )
558
+
559
+ self.conv_act = get_activation(act_fn)
560
+
561
+ else:
562
+ self.conv_norm_out = None
563
+ self.conv_act = None
564
+
565
+ conv_out_padding = (conv_out_kernel - 1) // 2
566
+ self.conv_out = nn.Conv2d(
567
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
568
+ )
569
+
570
+ @property
571
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
572
+ r"""
573
+ Returns:
574
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
575
+ indexed by its weight name.
576
+ """
577
+ # set recursively
578
+ processors = {}
579
+
580
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
581
+ if hasattr(module, "set_processor"):
582
+ processors[f"{name}.processor"] = module.processor
583
+
584
+ for sub_name, child in module.named_children():
585
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
586
+
587
+ return processors
588
+
589
+ for name, module in self.named_children():
590
+ fn_recursive_add_processors(name, module, processors)
591
+
592
+ return processors
593
+
594
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
595
+ r"""
596
+ Sets the attention processor to use to compute attention.
597
+
598
+ Parameters:
599
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
600
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
601
+ for **all** `Attention` layers.
602
+
603
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
604
+ processor. This is strongly recommended when setting trainable attention processors.
605
+
606
+ """
607
+ count = len(self.attn_processors.keys())
608
+
609
+ if isinstance(processor, dict) and len(processor) != count:
610
  raise ValueError(
611
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
612
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
613
  )
614
 
615
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
616
+ if hasattr(module, "set_processor"):
617
+ if not isinstance(processor, dict):
618
+ module.set_processor(processor)
619
+ else:
620
+ module.set_processor(processor.pop(f"{name}.processor"))
621
+
622
+ for sub_name, child in module.named_children():
623
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
624
+
625
+ for name, module in self.named_children():
626
+ fn_recursive_attn_processor(name, module, processor)
627
+
628
+ def set_default_attn_processor(self):
629
+ """
630
+ Disables custom attention processors and sets the default attention implementation.
631
+ """
632
+ self.set_attn_processor(AttnProcessor())
633
+
634
+ def set_attention_slice(self, slice_size):
635
+ r"""
636
+ Enable sliced attention computation.
637
+
638
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
639
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
640
+
641
+ Args:
642
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
643
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
644
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
645
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
646
+ must be a multiple of `slice_size`.
647
+ """
648
+ sliceable_head_dims = []
649
+
650
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
651
+ if hasattr(module, "set_attention_slice"):
652
+ sliceable_head_dims.append(module.sliceable_head_dim)
653
+
654
+ for child in module.children():
655
+ fn_recursive_retrieve_sliceable_dims(child)
656
+
657
+ # retrieve number of attention layers
658
+ for module in self.children():
659
+ fn_recursive_retrieve_sliceable_dims(module)
660
+
661
+ num_sliceable_layers = len(sliceable_head_dims)
662
+
663
+ if slice_size == "auto":
664
+ # half the attention head size is usually a good trade-off between
665
+ # speed and memory
666
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
667
+ elif slice_size == "max":
668
+ # make smallest slice possible
669
+ slice_size = num_sliceable_layers * [1]
670
 
671
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
672
 
673
+ if len(slice_size) != len(sliceable_head_dims):
674
+ raise ValueError(
675
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
676
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
677
+ )
678
+
679
+ for i in range(len(slice_size)):
680
+ size = slice_size[i]
681
+ dim = sliceable_head_dims[i]
682
+ if size is not None and size > dim:
683
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
684
 
685
+ # Recursively walk through all the children.
686
+ # Any children which exposes the set_attention_slice method
687
+ # gets the message
688
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
689
+ if hasattr(module, "set_attention_slice"):
690
+ module.set_attention_slice(slice_size.pop())
691
 
692
+ for child in module.children():
693
+ fn_recursive_set_attention_slice(child, slice_size)
694
 
695
+ reversed_slice_size = list(reversed(slice_size))
696
+ for module in self.children():
697
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
698
 
699
  def _set_gradient_checkpointing(self, module, value=False):
700
  if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
 
706
  timestep: Union[torch.Tensor, float, int],
707
  encoder_hidden_states: torch.Tensor,
708
  class_labels: Optional[torch.Tensor] = None,
709
+ timestep_cond: Optional[torch.Tensor] = None,
710
+ attention_mask: Optional[torch.Tensor] = None,
711
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
712
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
713
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
714
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
715
+ encoder_attention_mask: Optional[torch.Tensor] = None,
716
  return_dict: bool = True,
717
  ) -> Union[UNet2DConditionOutput, Tuple]:
718
  r"""
719
+ The [`UNet2DConditionModel`] forward method.
720
+
721
  Args:
722
+ sample (`torch.FloatTensor`):
723
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
724
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
725
+ encoder_hidden_states (`torch.FloatTensor`):
726
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
727
+ encoder_attention_mask (`torch.Tensor`):
728
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
729
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
730
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
731
  return_dict (`bool`, *optional*, defaults to `True`):
732
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
733
+ tuple.
734
+ cross_attention_kwargs (`dict`, *optional*):
735
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
736
+ added_cond_kwargs: (`dict`, *optional*):
737
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
738
+ are passed along to the UNet blocks.
739
 
740
  Returns:
741
  [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
742
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
743
+ a `tuple` is returned where the first element is the sample tensor.
744
  """
745
  # By default samples have to be AT least a multiple of the overall upsampling factor.
746
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
747
  # However, the upsampling interpolation output size can be forced to fit any upsampling size
748
  # on the fly if necessary.
749
  default_overall_up_factor = 2**self.num_upsamplers
 
756
  logger.info("Forward upsample size to force interpolation output size.")
757
  forward_upsample_size = True
758
 
759
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
760
+ # expects mask of shape:
761
+ # [batch, key_tokens]
762
+ # adds singleton query_tokens dimension:
763
+ # [batch, 1, key_tokens]
764
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
765
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
766
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
767
+ if attention_mask is not None:
768
+ # assume that mask is expressed as:
769
+ # (1 = keep, 0 = discard)
770
+ # convert mask into a bias that can be added to attention scores:
771
+ # (keep = +0, discard = -10000.0)
772
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
773
+ attention_mask = attention_mask.unsqueeze(1)
774
+
775
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
776
+ if encoder_attention_mask is not None:
777
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
778
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
779
+
780
  # 0. center input if necessary
781
  if self.config.center_input_sample:
782
  sample = 2 * sample - 1.0
 
785
  timesteps = timestep
786
  if not torch.is_tensor(timesteps):
787
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
788
+ # This would be a good case for the `match` statement (Python 3.10+)
789
+ is_mps = sample.device.type == "mps"
790
+ if isinstance(timestep, float):
791
+ dtype = torch.float32 if is_mps else torch.float64
792
+ else:
793
+ dtype = torch.int32 if is_mps else torch.int64
794
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
795
+ elif len(timesteps.shape) == 0:
796
  timesteps = timesteps[None].to(sample.device)
797
 
798
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
 
800
 
801
  t_emb = self.time_proj(timesteps)
802
 
803
+ # `Timesteps` does not contain any weights and will always return f32 tensors
804
  # but time_embedding might actually be running in fp16. so we need to cast here.
805
  # there might be better ways to encapsulate this.
806
+ t_emb = t_emb.to(dtype=sample.dtype)
807
+
808
+ emb = self.time_embedding(t_emb, timestep_cond)
809
+ aug_emb = None
810
 
811
+ if self.class_embedding is not None:
812
  if class_labels is None:
813
  raise ValueError("class_labels should be provided when num_class_embeds > 0")
 
 
814
 
815
+ if self.config.class_embed_type == "timestep":
816
+ class_labels = self.time_proj(class_labels)
817
+
818
+ # `Timesteps` does not contain any weights and will always return f32 tensors
819
+ # there might be better ways to encapsulate this.
820
+ class_labels = class_labels.to(dtype=sample.dtype)
821
+
822
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
823
+
824
+ if self.config.class_embeddings_concat:
825
+ emb = torch.cat([emb, class_emb], dim=-1)
826
+ else:
827
+ emb = emb + class_emb
828
+
829
+ if self.config.addition_embed_type == "text":
830
+ aug_emb = self.add_embedding(encoder_hidden_states)
831
+ elif self.config.addition_embed_type == "text_image":
832
+ # Kandinsky 2.1 - style
833
+ if "image_embeds" not in added_cond_kwargs:
834
+ raise ValueError(
835
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
836
+ )
837
+
838
+ image_embs = added_cond_kwargs.get("image_embeds")
839
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
840
+ aug_emb = self.add_embedding(text_embs, image_embs)
841
+ elif self.config.addition_embed_type == "text_time":
842
+ if "text_embeds" not in added_cond_kwargs:
843
+ raise ValueError(
844
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
845
+ )
846
+ text_embeds = added_cond_kwargs.get("text_embeds")
847
+ if "time_ids" not in added_cond_kwargs:
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
850
+ )
851
+ time_ids = added_cond_kwargs.get("time_ids")
852
+ time_embeds = self.add_time_proj(time_ids.flatten())
853
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
854
+
855
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
856
+ add_embeds = add_embeds.to(emb.dtype)
857
+ aug_emb = self.add_embedding(add_embeds)
858
+ elif self.config.addition_embed_type == "image":
859
+ # Kandinsky 2.2 - style
860
+ if "image_embeds" not in added_cond_kwargs:
861
+ raise ValueError(
862
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
863
+ )
864
+ image_embs = added_cond_kwargs.get("image_embeds")
865
+ aug_emb = self.add_embedding(image_embs)
866
+ elif self.config.addition_embed_type == "image_hint":
867
+ # Kandinsky 2.2 - style
868
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
869
+ raise ValueError(
870
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
871
+ )
872
+ image_embs = added_cond_kwargs.get("image_embeds")
873
+ hint = added_cond_kwargs.get("hint")
874
+ aug_emb, hint = self.add_embedding(image_embs, hint)
875
+ sample = torch.cat([sample, hint], dim=1)
876
+
877
+ emb = emb + aug_emb if aug_emb is not None else emb
878
+
879
+ if self.time_embed_act is not None:
880
+ emb = self.time_embed_act(emb)
881
+
882
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
883
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
884
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
885
+ # Kadinsky 2.1 - style
886
+ if "image_embeds" not in added_cond_kwargs:
887
+ raise ValueError(
888
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
889
+ )
890
+
891
+ image_embeds = added_cond_kwargs.get("image_embeds")
892
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
893
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
894
+ # Kandinsky 2.2 - style
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
898
+ )
899
+ image_embeds = added_cond_kwargs.get("image_embeds")
900
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
901
  # 2. pre-process
902
  sample = self.conv_in(sample)
903
 
904
  # 3. down
905
  down_block_res_samples = (sample,)
906
  for downsample_block in self.down_blocks:
907
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
908
+ sample, res_samples = downsample_block(
909
+ hidden_states=sample,
910
+ temb=emb,
911
+ encoder_hidden_states=encoder_hidden_states,
912
+ attention_mask=attention_mask,
913
+ cross_attention_kwargs=cross_attention_kwargs,
914
+ encoder_attention_mask=encoder_attention_mask,
915
+ )
 
 
 
 
 
916
  else:
 
 
917
  sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
918
+
919
  down_block_res_samples += res_samples
920
 
921
+ if down_block_additional_residuals is not None:
922
+ new_down_block_res_samples = ()
923
+
924
+ for down_block_res_sample, down_block_additional_residual in zip(
925
+ down_block_res_samples, down_block_additional_residuals
926
+ ):
927
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
928
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
929
+
930
+ down_block_res_samples = new_down_block_res_samples
931
+
932
  # 4. mid
933
+ if self.mid_block is not None:
934
+ sample = self.mid_block(
935
+ sample,
936
+ emb,
937
+ encoder_hidden_states=encoder_hidden_states,
938
+ attention_mask=attention_mask,
939
+ cross_attention_kwargs=cross_attention_kwargs,
940
+ encoder_attention_mask=encoder_attention_mask,
941
+ )
942
+
943
+ if mid_block_additional_residual is not None:
944
+ sample = sample + mid_block_additional_residual
945
 
946
  # 5. up
947
  for i, upsample_block in enumerate(self.up_blocks):
 
955
  if not is_final_block and forward_upsample_size:
956
  upsample_size = down_block_res_samples[-1].shape[2:]
957
 
958
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
959
+ sample = upsample_block(
960
+ hidden_states=sample,
961
+ temb=emb,
962
+ res_hidden_states_tuple=res_samples,
963
+ encoder_hidden_states=encoder_hidden_states,
964
+ cross_attention_kwargs=cross_attention_kwargs,
965
+ upsample_size=upsample_size,
966
+ attention_mask=attention_mask,
967
+ encoder_attention_mask=encoder_attention_mask,
968
+ )
 
 
 
 
 
 
 
969
  else:
 
 
 
970
  sample = upsample_block(
971
  hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
972
  )
973
+
974
  # 6. post-process
975
+ if self.conv_norm_out:
976
+ sample = self.conv_norm_out(sample)
977
+ sample = self.conv_act(sample)
978
  sample = self.conv_out(sample)
979
 
980
  if not return_dict:
requirements.txt CHANGED
@@ -6,4 +6,4 @@ transformers==4.26.0
6
  numpy==1.24.2
7
  seaborn==0.12.2
8
  accelerate==0.16.0
9
- scikit-learn==0.24.1
 
6
  numpy==1.24.2
7
  seaborn==0.12.2
8
  accelerate==0.16.0
9
+ scikit-learn==1.1.3
utils/attention_utils.py CHANGED
@@ -7,25 +7,44 @@ import torch
7
  import torchvision
8
 
9
  from utils.richtext_utils import seed_everything
10
- from sklearn.cluster import SpectralClustering
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  SelfAttentionLayers = [
13
- 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
14
- 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
15
  'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
16
- 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
17
  'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
18
  'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
19
  'mid_block.attentions.0.transformer_blocks.0.attn1',
20
  'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
21
  'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
22
  'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
23
- 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
24
  'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
25
- 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
26
- 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
27
- 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
28
- 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
29
  ]
30
 
31
 
@@ -48,6 +67,50 @@ CrossAttentionLayers = [
48
  # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
49
  ]
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def split_attention_maps_over_steps(attention_maps):
53
  r"""Function for splitting attention maps over steps.
@@ -75,8 +138,233 @@ def split_attention_maps_over_steps(attention_maps):
75
  return attention_maps_cond, attention_maps_uncond
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
79
- atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
80
  for i, attn_map in enumerate(atten_map_list):
81
  n_obj = len(attn_map)
82
  plt.figure()
@@ -88,7 +376,7 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
88
  fig.set_figheight(3)
89
  fig.set_figwidth(3*n_obj+0.1)
90
 
91
- cmap = plt.get_cmap('OrRd')
92
 
93
  vmax = 0
94
  vmin = 1
@@ -117,18 +405,22 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
117
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
118
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
119
  fig.colorbar(sm, cax=axs[-1])
 
 
 
120
  canvas = fig.canvas
121
  canvas.draw()
122
  width, height = canvas.get_width_height()
123
  img = np.frombuffer(canvas.tostring_rgb(),
124
  dtype='uint8').reshape((height, width, 3))
125
-
126
- fig.tight_layout()
127
- plt.close()
128
  return img
129
 
130
 
131
- def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
 
132
  r"""Function to visualize attention maps.
133
  Args:
134
  save_dir (str): Path to save attention maps
@@ -175,11 +467,11 @@ def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_token
175
  else:
176
  obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
177
  0].permute([3, 0, 1, 2])
 
178
  obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
179
  interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
180
  attention_maps[obj_id].append(obj_attention_map)
181
 
182
- # average attention maps over steps
183
  attention_maps_averaged = []
184
  for obj_id, obj_token in enumerate(obj_tokens):
185
  if obj_id == len(obj_tokens) - 1:
@@ -189,27 +481,114 @@ def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_token
189
  attention_maps_averaged.append(
190
  torch.cat(attention_maps[obj_id]).mean(0))
191
 
192
- # normalize attention maps into [0, 1]
193
  attention_maps_averaged_normalized = []
194
  attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
195
  for obj_id, obj_token in enumerate(obj_tokens):
196
  attention_maps_averaged_normalized.append(
197
  attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
198
 
199
- # softmax
200
- attention_maps_averaged_normalized = (
201
- torch.cat(attention_maps_averaged)/0.001).softmax(0)
202
- attention_maps_averaged_normalized = [
203
- attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
206
- obj_tokens, save_dir, seed, tokens_vis)
207
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
208
  [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
209
- return attention_maps_averaged_normalized, token_maps_vis
 
210
 
211
 
212
- def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
213
  preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
214
  r"""Function to visualize attention maps.
215
  Args:
@@ -218,15 +597,20 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
218
  sampler_order (int): Sampler order
219
  """
220
 
221
- # create the segmentation mask using self-attention maps
222
  resolution = 32
 
 
 
223
  attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
224
  for attn_map in selfattn_maps.values():
225
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
226
  if resolution_map != resolution:
227
  continue
 
 
 
228
  attn_map = attn_map.reshape(
229
- 1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
230
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
231
  mode='bicubic', antialias=True)
232
  attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
@@ -237,7 +621,16 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
237
  print('saving self-attention maps...', attn_maps_1024.shape)
238
  torch.save(torch.from_numpy(attn_maps_1024),
239
  'results/maps/selfattn_maps.pth')
240
- seed_everything(seed)
 
 
 
 
 
 
 
 
 
241
  sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
242
  assign_labels='kmeans')
243
  clusters = sc.fit_predict(attn_maps_1024)
@@ -245,6 +638,8 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
245
  fig = plt.figure()
246
  plt.imshow(clusters)
247
  plt.axis('off')
 
 
248
  if return_vis:
249
  canvas = fig.canvas
250
  canvas.draw()
@@ -258,18 +653,16 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
258
  cross_attn_maps_1024 = []
259
  for attn_map in crossattn_maps.values():
260
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
 
 
261
  attn_map = attn_map.reshape(
262
- 1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2])
263
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
264
  mode='bicubic', antialias=True)
265
  cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
266
 
267
  cross_attn_maps_1024 = torch.cat(
268
  cross_attn_maps_1024).mean(0).cpu().numpy()
269
- if save_attn:
270
- print('saving cross-attention maps...', cross_attn_maps_1024.shape)
271
- torch.save(torch.from_numpy(cross_attn_maps_1024),
272
- 'results/maps/crossattn_maps.pth')
273
  normalized_span_maps = []
274
  for token_ids in obj_tokens:
275
  span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
@@ -277,7 +670,8 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
277
  for i in range(span_token_maps.shape[-1]):
278
  curr_noun_map = span_token_maps[:, :, i]
279
  normalized_span_map[:, :, i] = (
280
- curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
 
281
  normalized_span_maps.append(normalized_span_map)
282
  foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
283
  ) for normalized_span_map in normalized_span_maps]
@@ -308,8 +702,19 @@ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, heigh
308
  0) for token_map in resized_token_maps]
309
  foreground_token_maps = [token_map[None, :, :]
310
  for token_map in foreground_token_maps]
 
 
 
 
 
 
 
 
 
 
 
311
  token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
312
- save_dir, seed, tokens_vis)
313
  resized_token_maps = [token_map.unsqueeze(1).repeat(
314
  [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
315
  if return_vis:
 
7
  import torchvision
8
 
9
  from utils.richtext_utils import seed_everything
10
+ from sklearn.cluster import KMeans, SpectralClustering
11
+
12
+ # SelfAttentionLayers = [
13
+ # # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
14
+ # # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
15
+ # 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
16
+ # # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
17
+ # 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
18
+ # 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
19
+ # 'mid_block.attentions.0.transformer_blocks.0.attn1',
20
+ # 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
21
+ # 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
22
+ # 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
23
+ # # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
24
+ # 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
25
+ # # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
26
+ # # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
27
+ # # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
28
+ # # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
29
+ # ]
30
 
31
  SelfAttentionLayers = [
32
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
33
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
34
  'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
35
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
36
  'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
37
  'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
38
  'mid_block.attentions.0.transformer_blocks.0.attn1',
39
  'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
40
  'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
41
  'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
42
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
43
  'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
44
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
45
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
46
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
47
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
48
  ]
49
 
50
 
 
67
  # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
68
  ]
69
 
70
+ # CrossAttentionLayers = [
71
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
72
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
73
+ # 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
74
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
75
+ # 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
76
+ # 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
77
+ # 'mid_block.attentions.0.transformer_blocks.0.attn2',
78
+ # 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
79
+ # 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
80
+ # 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
81
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
82
+ # 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
83
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
84
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
85
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
86
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
87
+ # ]
88
+
89
+ # CrossAttentionLayers_XL = [
90
+ # 'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
91
+ # 'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
92
+ # 'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
93
+ # 'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
94
+ # 'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
95
+ # 'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
96
+ # 'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
97
+ # ]
98
+ CrossAttentionLayers_XL = [
99
+ 'down_blocks.2.attentions.1.transformer_blocks.3.attn2',
100
+ 'down_blocks.2.attentions.1.transformer_blocks.4.attn2',
101
+ 'mid_block.attentions.0.transformer_blocks.0.attn2',
102
+ 'mid_block.attentions.0.transformer_blocks.1.attn2',
103
+ 'mid_block.attentions.0.transformer_blocks.2.attn2',
104
+ 'mid_block.attentions.0.transformer_blocks.3.attn2',
105
+ 'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
106
+ 'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
107
+ 'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
108
+ 'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
109
+ 'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
110
+ 'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
111
+ 'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
112
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2'
113
+ ]
114
 
115
  def split_attention_maps_over_steps(attention_maps):
116
  r"""Function for splitting attention maps over steps.
 
138
  return attention_maps_cond, attention_maps_uncond
139
 
140
 
141
+ def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
142
+ r"""Function to plot heatmaps for attention maps.
143
+
144
+ Args:
145
+ attention_maps (dict): Dictionary of attention maps per layer
146
+ save_dir (str): Directory to save attention maps
147
+ prefix (str): Filename prefix for html files
148
+
149
+ Returns:
150
+ Heatmaps, one per sample.
151
+ """
152
+
153
+ html_names = []
154
+
155
+ idx = 0
156
+ html_list = []
157
+
158
+ for layer in attention_maps.keys():
159
+ if idx == 0:
160
+ # import ipdb;ipdb.set_trace()
161
+ # create a set of html files.
162
+
163
+ batch_size = attention_maps[layer].shape[0]
164
+
165
+ for sample_num in range(batch_size):
166
+ # html path
167
+ html_rel_path = os.path.join('sample_{}'.format(
168
+ sample_num), '{}.html'.format(prefix))
169
+ html_names.append(html_rel_path)
170
+ html_path = os.path.join(save_dir, html_rel_path)
171
+ os.makedirs(os.path.dirname(html_path), exist_ok=True)
172
+ html_list.append(open(html_path, 'wt'))
173
+ html_list[sample_num].write(
174
+ '<html><head></head><body><table>\n')
175
+
176
+ for sample_num in range(batch_size):
177
+
178
+ save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
179
+ prefix, 'layer_{}'.format(layer)) + '.jpg'
180
+ Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
181
+
182
+ layer_name = 'layer_{}'.format(layer)
183
+ html_list[sample_num].write(
184
+ f'<tr><td><h1>{layer_name}</h1></td></tr>\n')
185
+
186
+ prefix_stem = prefix.split('/')[-1]
187
+ relative_image_path = os.path.join(
188
+ prefix_stem, 'layer_{}'.format(layer)) + '.jpg'
189
+ html_list[sample_num].write(
190
+ f'<tr><td><img src=\"{relative_image_path}\"></td></tr>\n')
191
+
192
+ plt.figure()
193
+ plt.clf()
194
+ nrows = 2
195
+ ncols = 7
196
+ fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
197
+
198
+ fig.set_figheight(8)
199
+ fig.set_figwidth(28.5)
200
+
201
+ # axs[0].set_aspect('equal')
202
+ # axs[1].set_aspect('equal')
203
+ # axs[2].set_aspect('equal')
204
+ # axs[3].set_aspect('equal')
205
+ # axs[4].set_aspect('equal')
206
+ # axs[5].set_aspect('equal')
207
+
208
+ cmap = plt.get_cmap('YlOrRd')
209
+
210
+ for rid in range(nrows):
211
+ for cid in range(ncols):
212
+ tid = rid*ncols + cid
213
+ # import ipdb;ipdb.set_trace()
214
+ attention_map_cur = attention_maps[layer][sample_num, :, :, tid].numpy(
215
+ )
216
+ vmax = float(attention_map_cur.max())
217
+ vmin = float(attention_map_cur.min())
218
+ sns.heatmap(
219
+ attention_map_cur, annot=False, cbar=False, ax=axs[rid, cid],
220
+ cmap=cmap, vmin=vmin, vmax=vmax
221
+ )
222
+ axs[rid, cid].set_xlabel(tokens_vis[tid])
223
+
224
+ # axs[0].set_xlabel('Self attention')
225
+ # axs[1].set_xlabel('Temporal attention')
226
+ # axs[2].set_xlabel('T5 text attention')
227
+ # axs[3].set_xlabel('CLIP text attention')
228
+ # axs[4].set_xlabel('CLIP image attention')
229
+ # axs[5].set_xlabel('Null text token')
230
+
231
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
232
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
233
+ # fig.colorbar(sm, cax=axs[6])
234
+
235
+ fig.tight_layout()
236
+ plt.savefig(save_path, dpi=64)
237
+ plt.close('all')
238
+
239
+ if idx == (len(attention_maps.keys()) - 1):
240
+ for sample_num in range(batch_size):
241
+ html_list[sample_num].write('</table></body></html>')
242
+ html_list[sample_num].close()
243
+
244
+ idx += 1
245
+
246
+ return html_names
247
+
248
+
249
+ def create_recursive_html_link(html_path, save_dir):
250
+ r"""Function for creating recursive html links.
251
+ If the path is dir1/dir2/dir3/*.html,
252
+ we create chained directories
253
+ -dir1
254
+ dir1.html (has links to all children)
255
+ -dir2
256
+ dir2.html (has links to all children)
257
+ -dir3
258
+ dir3.html
259
+
260
+ Args:
261
+ html_path (str): Path to html file.
262
+ save_dir (str): Save directory.
263
+ """
264
+
265
+ html_path_split = os.path.splitext(html_path)[0].split('/')
266
+ if len(html_path_split) == 1:
267
+ return
268
+
269
+ # First create the root directory
270
+ root_dir = html_path_split[0]
271
+ child_dir = html_path_split[1]
272
+
273
+ cur_html_path = os.path.join(save_dir, '{}.html'.format(root_dir))
274
+ if os.path.exists(cur_html_path):
275
+
276
+ fp = open(cur_html_path, 'r')
277
+ lines_written = fp.readlines()
278
+ fp.close()
279
+
280
+ fp = open(cur_html_path, 'a+')
281
+ child_path = os.path.join(root_dir, f'{child_dir}.html')
282
+ line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
283
+
284
+ if line_to_write not in lines_written:
285
+ fp.write('<html><head></head><body><table>\n')
286
+ fp.write(line_to_write)
287
+ fp.write('</table></body></html>')
288
+ fp.close()
289
+
290
+ else:
291
+
292
+ fp = open(cur_html_path, 'w')
293
+
294
+ child_path = os.path.join(root_dir, f'{child_dir}.html')
295
+ line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
296
+
297
+ fp.write('<html><head></head><body><table>\n')
298
+ fp.write(line_to_write)
299
+ fp.write('</table></body></html>')
300
+
301
+ fp.close()
302
+
303
+ child_path = '/'.join(html_path.split('/')[1:])
304
+ save_dir = os.path.join(save_dir, root_dir)
305
+ create_recursive_html_link(child_path, save_dir)
306
+
307
+
308
+ def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis):
309
+ r"""Function to visualize attention maps.
310
+ Args:
311
+ save_dir (str): Path to save attention maps
312
+ batch_size (int): Batch size
313
+ sampler_order (int): Sampler order
314
+ """
315
+
316
+ rand_name = list(attention_maps_all.keys())[0]
317
+ nsteps = len(attention_maps_all[rand_name])
318
+ hw_ori = width * height
319
+
320
+ # html_path = save_dir + '.html'
321
+ text_input = save_dir.split('/')[-1]
322
+ # f = open(html_path, 'wt')
323
+
324
+ all_html_paths = []
325
+
326
+ for step_num in range(0, nsteps, 5):
327
+
328
+ # if cond_id == 'cond':
329
+ # attention_maps_cur = attention_maps_cond[step_num]
330
+ # else:
331
+ # attention_maps_cur = attention_maps_uncond[step_num]
332
+
333
+ attention_maps = dict()
334
+
335
+ for layer in attention_maps_all.keys():
336
+
337
+ attention_ind = attention_maps_all[layer][step_num].cpu()
338
+
339
+ # Attention maps are of shape [batch_size, nkeys, 77]
340
+ # since they are averaged out while collecting from hooks to save memory.
341
+ # Now split the heads from batch dimension
342
+ bs, hw, nclip = attention_ind.shape
343
+ down_ratio = np.sqrt(hw_ori // hw)
344
+ width_cur = int(width // down_ratio)
345
+ height_cur = int(height // down_ratio)
346
+ attention_ind = attention_ind.reshape(
347
+ bs, height_cur, width_cur, nclip)
348
+
349
+ attention_maps[layer] = attention_ind
350
+
351
+ # Obtain heatmaps corresponding to random heads and individual heads
352
+
353
+ html_names = save_attention_heatmaps(
354
+ attention_maps, tokens_vis, save_dir=save_dir, prefix='step_{}/attention_maps_cond'.format(
355
+ step_num)
356
+ )
357
+
358
+ # Write the logic for recursively creating pages
359
+ for html_name_cur in html_names:
360
+ all_html_paths.append(os.path.join(text_input, html_name_cur))
361
+
362
+ save_dir_root = '/'.join(save_dir.split('/')[0:-1])
363
+ for html_pth in all_html_paths:
364
+ create_recursive_html_link(html_pth, save_dir_root)
365
+
366
+
367
  def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
 
368
  for i, attn_map in enumerate(atten_map_list):
369
  n_obj = len(attn_map)
370
  plt.figure()
 
376
  fig.set_figheight(3)
377
  fig.set_figwidth(3*n_obj+0.1)
378
 
379
+ cmap = plt.get_cmap('YlOrRd')
380
 
381
  vmax = 0
382
  vmin = 1
 
405
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
406
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
407
  fig.colorbar(sm, cax=axs[-1])
408
+
409
+ fig.tight_layout()
410
+
411
  canvas = fig.canvas
412
  canvas.draw()
413
  width, height = canvas.get_width_height()
414
  img = np.frombuffer(canvas.tostring_rgb(),
415
  dtype='uint8').reshape((height, width, 3))
416
+ plt.savefig(os.path.join(
417
+ save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
418
+ plt.close('all')
419
  return img
420
 
421
 
422
+ def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
423
+ preprocess=False):
424
  r"""Function to visualize attention maps.
425
  Args:
426
  save_dir (str): Path to save attention maps
 
467
  else:
468
  obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
469
  0].permute([3, 0, 1, 2])
470
+ # obj_attention_map = attention_ind[:, :, :, obj_token].mean(-1, True).permute([3, 0, 1, 2])
471
  obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
472
  interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
473
  attention_maps[obj_id].append(obj_attention_map)
474
 
 
475
  attention_maps_averaged = []
476
  for obj_id, obj_token in enumerate(obj_tokens):
477
  if obj_id == len(obj_tokens) - 1:
 
481
  attention_maps_averaged.append(
482
  torch.cat(attention_maps[obj_id]).mean(0))
483
 
 
484
  attention_maps_averaged_normalized = []
485
  attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
486
  for obj_id, obj_token in enumerate(obj_tokens):
487
  attention_maps_averaged_normalized.append(
488
  attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
489
 
490
+ if obj_tokens[-1][0] != -1:
491
+ attention_maps_averaged_normalized = (
492
+ torch.cat(attention_maps_averaged)/0.001).softmax(0)
493
+ attention_maps_averaged_normalized = [
494
+ attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
495
+
496
+ if preprocess:
497
+ selem = square(5)
498
+ selem = square(3)
499
+ selem = square(1)
500
+ attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
501
+ map[0].numpy()*255), selem) for map in attention_maps_averaged_normalized[:2]]
502
+ attention_maps_averaged_eroded = [(torch.from_numpy(map).unsqueeze(
503
+ 0)/255. > 0.8).float() for map in attention_maps_averaged_eroded]
504
+ attention_maps_averaged_eroded.append(
505
+ 1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
506
+ plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
507
+ attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
508
+ attention_maps_averaged_eroded = [attn_mask.unsqueeze(1).repeat(
509
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_eroded]
510
+ return attention_maps_averaged_eroded
511
+ else:
512
+ plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
513
+ obj_tokens, save_dir, seed, tokens_vis)
514
+ attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
515
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
516
+ return attention_maps_averaged_normalized
517
+
518
+
519
+ def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
520
+ r"""Function to visualize attention maps.
521
+ Args:
522
+ save_dir (str): Path to save attention maps
523
+ batch_size (int): Batch size
524
+ sampler_order (int): Sampler order
525
+ """
526
+
527
+ _EPS = 1e-8
528
+ # Split attention maps over steps
529
+ attention_maps_cond, _ = split_attention_maps_over_steps(
530
+ attention_maps
531
+ )
532
+
533
+ nsteps = len(attention_maps_cond)
534
+ hw_ori = width * height
535
+
536
+ attention_maps = []
537
+ for obj_token in obj_tokens:
538
+ attention_maps.append([])
539
+
540
+ # for each side prompt, get attention maps for all steps and all layers
541
+ for step_num in range(nsteps):
542
+ attention_maps_cur = attention_maps_cond[step_num]
543
+ for layer in attention_maps_cur.keys():
544
+ attention_ind = attention_maps_cur[layer].cpu()
545
+ bs, hw, nclip = attention_ind.shape
546
+ down_ratio = np.sqrt(hw_ori // hw)
547
+ width_cur = int(width // down_ratio)
548
+ height_cur = int(height // down_ratio)
549
+ attention_ind = attention_ind.reshape(
550
+ bs, height_cur, width_cur, nclip)
551
+ for obj_id, obj_token in enumerate(obj_tokens):
552
+ if attention_ind.shape[1] > width//2:
553
+ continue
554
+ if obj_token[0] != -1:
555
+ obj_attention_map = attention_ind[:, :, :,
556
+ obj_token].mean(-1, True).permute([3, 0, 1, 2])
557
+ obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
558
+ interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
559
+ attention_maps[obj_id].append(obj_attention_map)
560
+
561
+ # average of all steps and layers, thresholding
562
+ attention_maps_thres = []
563
+ attention_maps_averaged = []
564
+ for obj_id, obj_token in enumerate(obj_tokens):
565
+ if obj_token[0] != -1:
566
+ average_map = torch.cat(attention_maps[obj_id]).mean(0)
567
+ attention_maps_averaged.append(average_map)
568
+ attention_maps_thres.append((average_map > threshold).float())
569
+
570
+ # get the remaining region except for the original prompt
571
+ attention_maps_averaged_normalized = []
572
+ attention_maps_averaged_sum = torch.cat(attention_maps_thres).sum(0) + _EPS
573
+ for obj_id, obj_token in enumerate(obj_tokens):
574
+ if obj_token[0] != -1:
575
+ attention_maps_averaged_normalized.append(
576
+ attention_maps_thres[obj_id]/attention_maps_averaged_sum)
577
+ else:
578
+ attention_map_prev = torch.stack(
579
+ attention_maps_averaged_normalized).sum(0)
580
+ attention_maps_averaged_normalized.append(1.-attention_map_prev)
581
+
582
+ plot_attention_maps(
583
+ [attention_maps_averaged, attention_maps_averaged_normalized], save_dir, seed)
584
 
 
 
585
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
586
  [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
587
+ # attention_maps_averaged_normalized = attention_maps_averaged_normalized.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
588
+ return attention_maps_averaged_normalized
589
 
590
 
591
+ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
592
  preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
593
  r"""Function to visualize attention maps.
594
  Args:
 
597
  sampler_order (int): Sampler order
598
  """
599
 
 
600
  resolution = 32
601
+ # attn_maps_1024 = [attn_map for attn_map in selfattn_maps.values(
602
+ # ) if attn_map.shape[1] == resolution**2]
603
+ # attn_maps_1024 = torch.cat(attn_maps_1024).mean(0).cpu().numpy()
604
  attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
605
  for attn_map in selfattn_maps.values():
606
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
607
  if resolution_map != resolution:
608
  continue
609
+ # attn_map = torch.nn.functional.interpolate(rearrange(attn_map, '1 c (h w) -> 1 c h w', h=resolution_map), (resolution, resolution),
610
+ # mode='bicubic', antialias=True)
611
+ # attn_map = rearrange(attn_map, '1 (h w) a b -> 1 (a b) h w', h=resolution_map)
612
  attn_map = attn_map.reshape(
613
+ 1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
614
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
615
  mode='bicubic', antialias=True)
616
  attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
 
621
  print('saving self-attention maps...', attn_maps_1024.shape)
622
  torch.save(torch.from_numpy(attn_maps_1024),
623
  'results/maps/selfattn_maps.pth')
624
+ seed_everything(kmeans_seed)
625
+ # import ipdb;ipdb.set_trace()
626
+ # kmeans = KMeans(n_clusters=num_segments,
627
+ # n_init=10).fit(attn_maps_1024)
628
+ # clusters = kmeans.labels_
629
+ # clusters = clusters.reshape(resolution, resolution)
630
+ # mesh = np.array(np.meshgrid(range(resolution), range(resolution), indexing='ij'), dtype=np.float32)/resolution
631
+ # dists = mesh.reshape(2, -1).T
632
+ # delta = 0.01
633
+ # spatial_sim = rbf_kernel(dists, dists)*delta
634
  sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
635
  assign_labels='kmeans')
636
  clusters = sc.fit_predict(attn_maps_1024)
 
638
  fig = plt.figure()
639
  plt.imshow(clusters)
640
  plt.axis('off')
641
+ plt.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
642
+ bbox_inches='tight', pad_inches=0)
643
  if return_vis:
644
  canvas = fig.canvas
645
  canvas.draw()
 
653
  cross_attn_maps_1024 = []
654
  for attn_map in crossattn_maps.values():
655
  resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
656
+ # if resolution_map != 16:
657
+ # continue
658
  attn_map = attn_map.reshape(
659
+ 1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
660
  attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
661
  mode='bicubic', antialias=True)
662
  cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
663
 
664
  cross_attn_maps_1024 = torch.cat(
665
  cross_attn_maps_1024).mean(0).cpu().numpy()
 
 
 
 
666
  normalized_span_maps = []
667
  for token_ids in obj_tokens:
668
  span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
 
670
  for i in range(span_token_maps.shape[-1]):
671
  curr_noun_map = span_token_maps[:, :, i]
672
  normalized_span_map[:, :, i] = (
673
+ # curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
674
+ curr_noun_map - np.abs(curr_noun_map.min())) / (curr_noun_map.max()-curr_noun_map.min())
675
  normalized_span_maps.append(normalized_span_map)
676
  foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
677
  ) for normalized_span_map in normalized_span_maps]
 
702
  0) for token_map in resized_token_maps]
703
  foreground_token_maps = [token_map[None, :, :]
704
  for token_map in foreground_token_maps]
705
+ if preprocess:
706
+ selem = square(5)
707
+ eroded_token_maps = torch.stack([torch.from_numpy(erosion(skimage.img_as_float(
708
+ map[0].numpy()*255), selem))/255. for map in resized_token_maps[:-1]]).clamp(0, 1)
709
+ # import ipdb; ipdb.set_trace()
710
+ eroded_background_maps = (1-eroded_token_maps.sum(0, True)).clamp(0, 1)
711
+ eroded_token_maps = torch.cat([eroded_token_maps, eroded_background_maps])
712
+ eroded_token_maps = eroded_token_maps / (eroded_token_maps.sum(0, True)+1e-8)
713
+ resized_token_maps = [token_map.unsqueeze(
714
+ 0) for token_map in eroded_token_maps]
715
+
716
  token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
717
+ save_dir, kmeans_seed, tokens_vis)
718
  resized_token_maps = [token_map.unsqueeze(1).repeat(
719
  [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
720
  if return_vis: