wooyeolbaek committed on
Commit 5cb0966 · verified · 1 Parent(s): 9fc47f2

Add save_attention_maps

Files changed (1)
  1. utils.py +158 -368
utils.py CHANGED
@@ -1,280 +1,53 @@
 import os
-import math
-import numpy as np
-from PIL import Image
 
 import torch
 import torch.nn.functional as F
-
-from diffusers.utils import deprecate
+from torchvision.transforms import ToPILImage
+
+from diffusers.models import Transformer2DModel
+from diffusers.models.unets import UNet2DConditionModel
+from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel
+from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
+from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
+from diffusers import FluxPipeline
 from diffusers.models.attention_processor import (
-    Attention,
     AttnProcessor,
     AttnProcessor2_0,
     LoRAAttnProcessor,
-    LoRAAttnProcessor2_0
+    LoRAAttnProcessor2_0,
+    JointAttnProcessor2_0,
+    FluxAttnProcessor2_0
 )
 
-
-attn_maps = {}
-
-
-def attn_call(
-    self,
-    attn: Attention,
-    hidden_states,
-    encoder_hidden_states=None,
-    attention_mask=None,
-    temb=None,
-    scale=1.0,
-):
-    residual = hidden_states
-
-    if attn.spatial_norm is not None:
-        hidden_states = attn.spatial_norm(hidden_states, temb)
-
-    input_ndim = hidden_states.ndim
-
-    if input_ndim == 4:
-        batch_size, channel, height, width = hidden_states.shape
-        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-    batch_size, sequence_length, _ = (
-        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-    )
-    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-    if attn.group_norm is not None:
-        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-    query = attn.to_q(hidden_states, scale=scale)
-
-    if encoder_hidden_states is None:
-        encoder_hidden_states = hidden_states
-    elif attn.norm_cross:
-        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-    key = attn.to_k(encoder_hidden_states, scale=scale)
-    value = attn.to_v(encoder_hidden_states, scale=scale)
-
-    query = attn.head_to_batch_dim(query)
-    key = attn.head_to_batch_dim(key)
-    value = attn.head_to_batch_dim(value)
-
-    attention_probs = attn.get_attention_scores(query, key, attention_mask)
-    ####################################################################################################
-    # (20,4096,77) or (40,1024,77)
-    if hasattr(self, "store_attn_map"):
-        self.attn_map = attention_probs
-    ####################################################################################################
-    hidden_states = torch.bmm(attention_probs, value)
-    hidden_states = attn.batch_to_head_dim(hidden_states)
-
-    # linear proj
-    hidden_states = attn.to_out[0](hidden_states, scale=scale)
-    # dropout
-    hidden_states = attn.to_out[1](hidden_states)
-
-    if input_ndim == 4:
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-    if attn.residual_connection:
-        hidden_states = hidden_states + residual
-
-    hidden_states = hidden_states / attn.rescale_output_factor
-
-    return hidden_states
-
-
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
-    # Efficient implementation equivalent to the following:
-    L, S = query.size(-2), key.size(-2)
-    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-    attn_bias = torch.zeros(L, S, dtype=query.dtype)
-    if is_causal:
-        assert attn_mask is None
-        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
-        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-        attn_bias.to(query.dtype)
-
-    if attn_mask is not None:
-        if attn_mask.dtype == torch.bool:
-            attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
-        else:
-            attn_bias += attn_mask
-    attn_weight = query @ key.transpose(-2, -1) * scale_factor
-    attn_weight += attn_bias.to(attn_weight.device)
-    attn_weight = torch.softmax(attn_weight, dim=-1)
-
-    return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
-
-
-def attn_call2_0(
-    self,
-    attn: Attention,
-    hidden_states,
-    encoder_hidden_states=None,
-    attention_mask=None,
-    temb=None,
-    scale: float = 1.0,
-):
-    residual = hidden_states
-
-    if attn.spatial_norm is not None:
-        hidden_states = attn.spatial_norm(hidden_states, temb)
-
-    input_ndim = hidden_states.ndim
-
-    if input_ndim == 4:
-        batch_size, channel, height, width = hidden_states.shape
-        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-    batch_size, sequence_length, _ = (
-        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-    )
-
-    if attention_mask is not None:
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        # scaled_dot_product_attention expects attention_mask shape to be
-        # (batch, heads, source_length, target_length)
-        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-    if attn.group_norm is not None:
-        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-    query = attn.to_q(hidden_states, scale=scale)
-
-    if encoder_hidden_states is None:
-        encoder_hidden_states = hidden_states
-    elif attn.norm_cross:
-        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-    key = attn.to_k(encoder_hidden_states, scale=scale)
-    value = attn.to_v(encoder_hidden_states, scale=scale)
-
-    inner_dim = key.shape[-1]
-    head_dim = inner_dim // attn.heads
-
-    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-    # the output of sdp = (batch, num_heads, seq_len, head_dim)
-    # TODO: add support for attn.scale when we move to Torch 2.1
-    ####################################################################################################
-    # if self.store_attn_map:
-    if hasattr(self, "store_attn_map"):
-        hidden_states, attn_map = scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        # (2,10,4096,77) or (2,20,1024,77)
-        self.attn_map = attn_map
-    else:
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-    ####################################################################################################
-
-    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-    hidden_states = hidden_states.to(query.dtype)
-
-    # linear proj
-    hidden_states = attn.to_out[0](hidden_states, scale=scale)
-    # dropout
-    hidden_states = attn.to_out[1](hidden_states)
-
-    if input_ndim == 4:
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-    if attn.residual_connection:
-        hidden_states = hidden_states + residual
-
-    hidden_states = hidden_states / attn.rescale_output_factor
-
-    return hidden_states
-
-
-def lora_attn_call(self, attn: Attention, hidden_states, *args, **kwargs):
-    self_cls_name = self.__class__.__name__
-    deprecate(
-        self_cls_name,
-        "0.26.0",
-        (
-            f"Make sure use {self_cls_name[4:]} instead by setting"
-            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-            " `LoraLoaderMixin.load_lora_weights`"
-        ),
-    )
-    attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-    attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-    attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-    attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-    attn._modules.pop("processor")
-    attn.processor = AttnProcessor()
-
-    if hasattr(self, "store_attn_map"):
-        attn.processor.store_attn_map = True
-
-    return attn.processor(attn, hidden_states, *args, **kwargs)
-
-
-def lora_attn_call2_0(self, attn: Attention, hidden_states, *args, **kwargs):
-    self_cls_name = self.__class__.__name__
-    deprecate(
-        self_cls_name,
-        "0.26.0",
-        (
-            f"Make sure use {self_cls_name[4:]} instead by setting"
-            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-            " `LoraLoaderMixin.load_lora_weights`"
-        ),
-    )
-    attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-    attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-    attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-    attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-    attn._modules.pop("processor")
-    attn.processor = AttnProcessor2_0()
-
-    if hasattr(self, "store_attn_map"):
-        attn.processor.store_attn_map = True
-
-    return attn.processor(attn, hidden_states, *args, **kwargs)
-
+from modules import *
 
 def cross_attn_init():
     AttnProcessor.__call__ = attn_call
-    AttnProcessor2_0.__call__ = attn_call # attn_call is faster
-    # AttnProcessor2_0.__call__ = attn_call2_0
+    AttnProcessor2_0.__call__ = attn_call2_0
     LoRAAttnProcessor.__call__ = lora_attn_call
-    # LoRAAttnProcessor2_0.__call__ = lora_attn_call2_0
-    LoRAAttnProcessor2_0.__call__ = lora_attn_call
+    LoRAAttnProcessor2_0.__call__ = lora_attn_call2_0
+    JointAttnProcessor2_0.__call__ = joint_attn_call2_0
+    FluxAttnProcessor2_0.__call__ = flux_attn_call2_0
 
 
-def reshape_attn_map(attn_map):
-    attn_map = torch.mean(attn_map,dim=0) # mean by head dim: (20,4096,77) -> (4096,77)
-    attn_map = attn_map.permute(1,0) # (4096,77) -> (77,4096)
-    latent_size = int(math.sqrt(attn_map.shape[1]))
-    latent_shape = (attn_map.shape[0],latent_size,-1)
-    attn_map = attn_map.reshape(latent_shape) # (77,4096) -> (77,64,64)
-
-    return attn_map # torch.sum(attn_map,dim=0) = [1,1,...,1]
-
-
-def hook_fn(name):
+def hook_function(name, detach=True):
     def forward_hook(module, input, output):
         if hasattr(module.processor, "attn_map"):
-            attn_maps[name] = module.processor.attn_map
+
+            timestep = module.processor.timestep
+
+            attn_maps[timestep] = attn_maps.get(timestep, dict())
+            attn_maps[timestep][name] = module.processor.attn_map.cpu() if detach \
+                else module.processor.attn_map
+
             del module.processor.attn_map
 
     return forward_hook
 
-def register_cross_attention_hook(unet):
-    for name, module in unet.named_modules():
-        if not name.split('.')[-1].startswith('attn2'):
+
+def register_cross_attention_hook(model, hook_function, target_name):
+    for name, module in model.named_modules():
+        if not name.endswith(target_name):
             continue
 
         if isinstance(module.processor, AttnProcessor):
@@ -285,129 +58,146 @@ def register_cross_attention_hook(unet):
             module.processor.store_attn_map = True
         elif isinstance(module.processor, LoRAAttnProcessor2_0):
             module.processor.store_attn_map = True
+        elif isinstance(module.processor, JointAttnProcessor2_0):
+            module.processor.store_attn_map = True
+        elif isinstance(module.processor, FluxAttnProcessor2_0):
+            module.processor.store_attn_map = True
 
-        hook = module.register_forward_hook(hook_fn(name))
+        hook = module.register_forward_hook(hook_function(name))
 
-    return unet
-
-
-def prompt2tokens(tokenizer, prompt):
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    tokens = []
-    for text_input_id in text_input_ids[0]:
-        token = tokenizer.decoder[text_input_id.item()]
-        tokens.append(token)
-    return tokens
-
-
-# TODO: generalize for rectangle images
-def upscale(attn_map, target_size):
-    attn_map = torch.mean(attn_map, dim=0) # (10, 32*32, 77) -> (32*32, 77)
-    attn_map = attn_map.permute(1,0) # (32*32, 77) -> (77, 32*32)
-
-    if target_size[0]*target_size[1] != attn_map.shape[1]:
-        temp_size = (target_size[0]//2, target_size[1]//2)
-        attn_map = attn_map.view(attn_map.shape[0], *temp_size) # (77, 32,32)
-        attn_map = attn_map.unsqueeze(0) # (77,32,32) -> (1,77,32,32)
-
-        attn_map = F.interpolate(
-            attn_map.to(dtype=torch.float32),
-            size=target_size,
-            mode='bilinear',
-            align_corners=False
-        ).squeeze() # (77,64,64)
-    else:
-        attn_map = attn_map.to(dtype=torch.float32) # (77,64,64)
-
-    attn_map = torch.softmax(attn_map, dim=0)
-    attn_map = attn_map.reshape(attn_map.shape[0],-1) # (77,64*64)
-    return attn_map
-
-
-def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
-    target_size = (image_size[0]//16, image_size[1]//16)
-    idx = 0 if instance_or_negative else 1
-    net_attn_maps = []
-
-    for name, attn_map in attn_maps.items():
-        attn_map = attn_map.cpu() if detach else attn_map
-        attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG
-        if len(attn_map.shape) == 4:
-            attn_map = attn_map.squeeze()
-
-        attn_map = upscale(attn_map, target_size) # (10,32*32,77) -> (77,64*64)
-        net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64)
-
-    net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
-    net_attn_maps = net_attn_maps.reshape(net_attn_maps.shape[0], 64,64) # (77,64*64) -> (77,64,64)
-
-    return net_attn_maps
-
-
-def save_net_attn_map(net_attn_maps, dir_name, tokenizer, prompt):
-    if not os.path.exists(dir_name):
-        os.makedirs(dir_name)
-
-    tokens = prompt2tokens(tokenizer, prompt)
-    total_attn_scores = 0
-    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
-        attn_map_score = torch.sum(attn_map)
-        attn_map = attn_map.cpu().numpy()
-        h,w = attn_map.shape
-        attn_map_total = h*w
-        attn_map_score = attn_map_score / attn_map_total
-        total_attn_scores += attn_map_score
-        token = token.replace('</w>','')
-        save_attn_map(
-            attn_map,
-            f'{token}:{attn_map_score:.2f}',
-            f"{dir_name}/{i}_<{token}>:{int(attn_map_score*100)}.png"
-        )
-    print(f'total_attn_scores: {total_attn_scores}')
-
-
-def resize_net_attn_map(net_attn_maps, target_size):
-    net_attn_maps = F.interpolate(
-        net_attn_maps.to(dtype=torch.float32).unsqueeze(0),
-        size=target_size,
-        mode='bilinear',
-        align_corners=False
-    ).squeeze() # (77,64,64)
-    return net_attn_maps
-
-
-def save_attn_map(attn_map, title, save_path):
-    normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
-    normalized_attn_map = normalized_attn_map.astype(np.uint8)
-    image = Image.fromarray(normalized_attn_map)
-    image.save(save_path, format='PNG', compression=0)
-
-
-def return_net_attn_map(net_attn_maps, tokenizer, prompt):
-
-    tokens = prompt2tokens(tokenizer, prompt)
-    total_attn_scores = 0
-    images = []
-    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
-        attn_map_score = torch.sum(attn_map)
-        h,w = attn_map.shape
-        attn_map_total = h*w
-        attn_map_score = attn_map_score / attn_map_total
-        total_attn_scores += attn_map_score
-
-        attn_map = attn_map.cpu().numpy()
-        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
-        normalized_attn_map = normalized_attn_map.astype(np.uint8)
-        image = Image.fromarray(normalized_attn_map)
-
-        token = token.replace('</w>','')
-        images.append((image,f"{i}_<{token}>"))
-    print(f'total_attn_scores: {total_attn_scores}')
-    return images
+    return model
+
+
+def replace_call_method_for_unet(model):
+    if model.__class__.__name__ == 'UNet2DConditionModel':
+        model.forward = UNet2DConditionModelForward.__get__(model, UNet2DConditionModel)
+
+    for name, layer in model.named_children():
+
+        if layer.__class__.__name__ == 'Transformer2DModel':
+            layer.forward = Transformer2DModelForward.__get__(layer, Transformer2DModel)
+
+        elif layer.__class__.__name__ == 'BasicTransformerBlock':
+            layer.forward = BasicTransformerBlockForward.__get__(layer, BasicTransformerBlock)
+
+        replace_call_method_for_unet(layer)
+
+    return model
+
+
+def replace_call_method_for_sd3(model):
+    if model.__class__.__name__ == 'SD3Transformer2DModel':
+        model.forward = SD3Transformer2DModelForward.__get__(model, SD3Transformer2DModel)
+
+    for name, layer in model.named_children():
+
+        if layer.__class__.__name__ == 'JointTransformerBlock':
+            layer.forward = JointTransformerBlockForward.__get__(layer, JointTransformerBlock)
+
+        replace_call_method_for_sd3(layer)
+
+    return model
+
+
+def replace_call_method_for_flux(model):
+    if model.__class__.__name__ == 'FluxTransformer2DModel':
+        model.forward = FluxTransformer2DModelForward.__get__(model, FluxTransformer2DModel)
+
+    for name, layer in model.named_children():
+
+        if layer.__class__.__name__ == 'FluxTransformerBlock':
+            layer.forward = FluxTransformerBlockForward.__get__(layer, FluxTransformerBlock)
+
+        replace_call_method_for_flux(layer)
+
+    return model
+
+
+def init_pipeline(pipeline):
+    if 'transformer' in vars(pipeline).keys():
+        if pipeline.transformer.__class__.__name__ == 'SD3Transformer2DModel':
+            pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
+            pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)
+
+        elif pipeline.transformer.__class__.__name__ == 'FluxTransformer2DModel':
+            FluxPipeline.__call__ = FluxPipeline_call
+            pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
+            pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)
+
+    else:
+        if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
+            pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')
+            pipeline.unet = replace_call_method_for_unet(pipeline.unet)
+
+    return pipeline
+
+
+def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unconditional=True):
+    to_pil = ToPILImage()
+
+    token_ids = tokenizer(prompts)['input_ids']
+    total_tokens = []
+    for token_id in token_ids:
+        total_tokens.append(tokenizer.convert_ids_to_tokens(token_id))
+
+    if not os.path.exists(base_dir):
+        os.mkdir(base_dir)
+
+    total_attn_map = list(list(attn_maps.values())[0].values())[0].sum(1)
+    if unconditional:
+        total_attn_map = total_attn_map.chunk(2)[1]  # (batch, height, width, attn_dim)
+    total_attn_map = total_attn_map.permute(0, 3, 1, 2)
+    total_attn_map = torch.zeros_like(total_attn_map)
+    total_attn_map_shape = total_attn_map.shape[-2:]
+    total_attn_map_number = 0
+
+    for timestep, layers in attn_maps.items():
+        timestep_dir = os.path.join(base_dir, f'{timestep}')
+        if not os.path.exists(timestep_dir):
+            os.mkdir(timestep_dir)
+
+        for layer, attn_map in layers.items():
+            layer_dir = os.path.join(timestep_dir, f'{layer}')
+            if not os.path.exists(layer_dir):
+                os.mkdir(layer_dir)
+
+            attn_map = attn_map.sum(1).squeeze(1)
+            attn_map = attn_map.permute(0, 3, 1, 2)
+
+            if unconditional:
+                attn_map = attn_map.chunk(2)[1]
+
+            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
+            total_attn_map += resized_attn_map
+            total_attn_map_number += 1
+
+    total_attn_map /= total_attn_map_number
+    final_attn_map = {}
+    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
+        batch_dir = os.path.join(base_dir, f'batch-{batch}')
+        if not os.path.exists(batch_dir):
+            os.mkdir(batch_dir)
+
+        startofword = True
+        for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
+            if '</w>' in token:
+                token = token.replace('</w>', '')
+                if startofword:
+                    token = '<' + token + '>'
+                else:
+                    token = '-' + token + '>'
+                startofword = True
+
+            elif token != '<|startoftext|>' and token != '<|endoftext|>':
+                if startofword:
+                    token = '<' + token + '-'
+                    startofword = False
+                else:
+                    token = '-' + token + '-'
+
+            final_attn_map[f'{i}-{token}.png'] = to_pil(a.to(torch.float32))
+
+    return final_attn_map
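
For orientation, here is a minimal usage sketch of the API this commit settles on (cross_attn_init, init_pipeline, save_attention_maps). It is illustrative and not part of the diff: the checkpoint id, prompt, and device are placeholders, and attn_maps is assumed to be the global store that modules exposes (utils.py pulls it in via the star import).

import torch
from diffusers import StableDiffusionPipeline

# Assumed import paths: this repo's utils.py and modules.py are on the Python path.
from utils import cross_attn_init, init_pipeline, save_attention_maps
from modules import attn_maps  # hypothetical: the global {timestep: {layer: map}} store

cross_attn_init()  # patch the diffusers attention processors so they keep their attention probabilities

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    torch_dtype=torch.float16,
).to("cuda")  # assumes a CUDA device
pipe = init_pipeline(pipe)  # register forward hooks on the UNet's cross-attention ('attn2') modules

prompt = "a photo of a puppy wearing a hat"
image = pipe(prompt).images[0]
image.save("output.png")

# Average the stored maps over timesteps and layers; unconditional=True assumes
# classifier-free guidance doubled the batch, so only the conditional half is kept.
token_maps = save_attention_maps(
    attn_maps, pipe.tokenizer, [prompt], base_dir="attn_maps", unconditional=True
)
# token_maps maps names like '2-<puppy>.png' to PIL images, one per prompt token.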