wooyeolbaek committed on
Commit
0c1540a
·
1 Parent(s): c620069

Add app.py, utils.py

Browse files
Files changed (2) hide show
  1. app.py +51 -0
  2. utils.py +413 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from diffusers import StableDiffusionXLPipeline
4
+ from utils import (
5
+ cross_attn_init,
6
+ register_cross_attention_hook,
7
+ attn_maps,
8
+ get_net_attn_map,
9
+ resize_net_attn_map,
10
+ return_net_attn_map,
11
+ )
12
+
13
# Swap the diffusers attention-processor ``__call__`` implementations for the
# recording variants defined in utils.
cross_attn_init()

# Fall back to CPU (in float32) when no CUDA device is present instead of
# crashing at start-up; on CUDA hosts this is the original cuda/float16 setup.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=_dtype,
)
# Flag the cross-attention ("attn2") processors and attach the forward hooks
# that collect their attention maps.
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe = pipe.to(_device)
20
+
21
+
22
def inference(prompt):
    """Generate an image for ``prompt`` together with its per-token attention maps.

    Args:
        prompt: Text prompt passed to the pipeline and to the tokenizer.

    Returns:
        A ``(image, net_attn_maps)`` tuple: the first image produced by the
        pipeline and the value returned by ``return_net_attn_map``.
    """
    image = pipe(prompt).images[0]

    # PIL reports ``size`` as (width, height), whereas the attention helpers
    # feed the size to torch-style (height, width) reshapes and interpolation.
    # Reversing is a no-op for square images and fixes the axis order otherwise.
    height_width = image.size[::-1]

    net_attn_maps = get_net_attn_map(height_width)
    net_attn_maps = resize_net_attn_map(net_attn_maps, height_width)
    net_attn_maps = return_net_attn_map(net_attn_maps, pipe.tokenizer, prompt)

    return image, net_attn_maps
29
+
30
+
31
# NOTE(review): the nesting below is reconstructed from a listing that lost its
# indentation; confirm that the gallery belongs inside the row.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown(
        """
        🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
        """)

    # Prompt input and the button that triggers generation.
    prompt = gr.Textbox(
        value="A photo of a black puppy, christmas atmosphere",
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    # Outputs: the generated image next to the gallery of attention maps.
    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # ``inference`` returns (image, list of (attention image, caption)).
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])


if __name__ == "__main__":
    demo.launch(share=True)
51
+
utils.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from diffusers.utils import deprecate
10
+ from diffusers.models.attention_processor import (
11
+ Attention,
12
+ AttnProcessor,
13
+ AttnProcessor2_0,
14
+ LoRAAttnProcessor,
15
+ LoRAAttnProcessor2_0
16
+ )
17
+
18
+
19
# Module-level registry written by the forward hooks created in ``hook_fn``:
# maps a module name to the attention map found on that module's processor.
attn_maps = dict()
20
+
21
+
22
def attn_call(
    self,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
    temb=None,
    scale=1.0,
):
    """Attention-processor ``__call__`` that can keep the attention probabilities.

    Attention is computed explicitly (``attn.get_attention_scores`` followed by
    ``torch.bmm``). When the processor instance carries a ``store_attn_map``
    attribute, the probabilities are stored on ``self.attn_map``.

    Args:
        attn: Attention module whose projections and norms are used.
        hidden_states: Query-side input; a 4-D input is flattened to a sequence
            and restored to 4-D before returning.
        encoder_hidden_states: Key/value-side input; ``hidden_states`` is used
            when this is ``None``.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask``.
        temb: Forwarded to ``attn.spatial_norm`` when that norm is set.
        scale: Forwarded to the ``to_q``/``to_k``/``to_v``/``to_out[0]`` layers.

    Returns:
        The attended hidden states, divided by ``attn.rescale_output_factor``.
    """
    shortcut = hidden_states  # kept for the optional residual connection

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    is_feature_map = hidden_states.ndim == 4
    if is_feature_map:
        # (batch, channel, height, width) -> (batch, height * width, channel)
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    batch_size, sequence_length, _ = context_shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        # No separate context was given: keys/values come from the input itself.
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    key = attn.to_k(encoder_hidden_states, scale=scale)
    value = attn.to_v(encoder_hidden_states, scale=scale)

    query, key, value = (attn.head_to_batch_dim(tensor) for tensor in (query, key, value))

    attention_probs = attn.get_attention_scores(query, key, attention_mask)

    # Original author's shape note: (20,4096,77) or (40,1024,77).
    if hasattr(self, "store_attn_map"):
        self.attn_map = attention_probs

    hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

    # Output projection followed by dropout.
    hidden_states = attn.to_out[0](hidden_states, scale=scale)
    hidden_states = attn.to_out[1](hidden_states)

    if is_feature_map:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        hidden_states = hidden_states + shortcut

    return hidden_states / attn.rescale_output_factor
87
+
88
+
89
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> tuple:
    """Scaled dot-product attention that also returns the attention weights.

    Args:
        query: Tensor whose last two dimensions are (L, E).
        key: Tensor whose last two dimensions are (S, E).
        value: Tensor whose second-to-last dimension is S.
        attn_mask: Optional mask broadcastable to the (..., L, S) score tensor.
            A boolean mask keeps positions that are ``True``; any other dtype
            is added to the scores. The mask itself is never modified.
        dropout_p: Dropout probability applied to the weights used for the output.
        is_causal: Apply a lower-triangular causal mask; requires ``attn_mask``
            to be ``None``.
        scale: Score multiplier; defaults to ``1 / sqrt(E)``.

    Returns:
        ``(output, attn_weight)``: the attended values and the softmax weights
        (the weights are returned without dropout applied).
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the additive bias on the query's device so it can be combined with
    # masks and scores that live on an accelerator.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        attn_mask = attn_mask.to(query.device)
        if attn_mask.dtype == torch.bool:
            # Turn the keep-mask into an additive bias (-inf where masked out)
            # in a fresh tensor, leaving the caller's mask untouched.
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Out-of-place add so masks with leading batch/head dims broadcast.
            attn_bias = attn_bias + attn_mask.to(query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)

    return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
110
+
111
+
112
def attn_call2_0(
    self,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
    temb=None,
    scale: float = 1.0,
):
    """Multi-head ``__call__`` variant built on scaled dot-product attention.

    When the processor instance carries a ``store_attn_map`` attribute, the
    module-level ``scaled_dot_product_attention`` is used and the weights it
    returns are stored on ``self.attn_map``; otherwise
    ``F.scaled_dot_product_attention`` is called and nothing is stored.

    Args:
        attn: Attention module whose projections and norms are used.
        hidden_states: Query-side input; a 4-D input is flattened to a sequence
            and restored to 4-D before returning.
        encoder_hidden_states: Key/value-side input; ``hidden_states`` is used
            when this is ``None``.
        attention_mask: Optional mask, prepared and reshaped per head.
        temb: Forwarded to ``attn.spatial_norm`` when that norm is set.
        scale: Forwarded to the ``to_q``/``to_k``/``to_v``/``to_out[0]`` layers.

    Returns:
        The attended hidden states, divided by ``attn.rescale_output_factor``.
    """
    shortcut = hidden_states  # kept for the optional residual connection

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    is_feature_map = hidden_states.ndim == 4
    if is_feature_map:
        # (batch, channel, height, width) -> (batch, height * width, channel)
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    batch_size, sequence_length, _ = context_shape

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        # No separate context was given: keys/values come from the input itself.
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    key = attn.to_k(encoder_hidden_states, scale=scale)
    value = attn.to_v(encoder_hidden_states, scale=scale)

    head_dim = key.shape[-1] // attn.heads

    # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
    query, key, value = (
        tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        for tensor in (query, key, value)
    )

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    sdpa_kwargs = dict(attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
    if hasattr(self, "store_attn_map"):
        # Original author's shape note: (2,10,4096,77) or (2,20,1024,77).
        hidden_states, self.attn_map = scaled_dot_product_attention(query, key, value, **sdpa_kwargs)
    else:
        hidden_states = F.scaled_dot_product_attention(query, key, value, **sdpa_kwargs)

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # Output projection followed by dropout.
    hidden_states = attn.to_out[0](hidden_states, scale=scale)
    hidden_states = attn.to_out[1](hidden_states)

    if is_feature_map:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        hidden_states = hidden_states + shortcut

    return hidden_states / attn.rescale_output_factor
196
+
197
+
198
def _replace_lora_processor(lora_processor, attn, hidden_states, processor_cls):
    """Move the LoRA layers onto ``attn`` and install a fresh ``processor_cls``.

    Emits the deprecation notice, attaches the four LoRA layers held by
    ``lora_processor`` to the matching projections of ``attn`` (moved to the
    device of ``hidden_states``), replaces ``attn.processor`` with a new
    ``processor_cls()`` and carries the ``store_attn_map`` flag over.
    """
    cls_name = lora_processor.__class__.__name__
    deprecate(
        cls_name,
        "0.26.0",
        (
            f"Make sure to use {cls_name[4:]} instead by setting "
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        ),
    )

    device = hidden_states.device
    attn.to_q.lora_layer = lora_processor.to_q_lora.to(device)
    attn.to_k.lora_layer = lora_processor.to_k_lora.to(device)
    attn.to_v.lora_layer = lora_processor.to_v_lora.to(device)
    attn.to_out[0].lora_layer = lora_processor.to_out_lora.to(device)

    # Drop the registered LoRA processor before assigning its replacement.
    attn._modules.pop("processor")
    attn.processor = processor_cls()

    # Keep recording attention maps if the LoRA processor was flagged for it.
    if hasattr(lora_processor, "store_attn_map"):
        attn.processor.store_attn_map = True


def lora_attn_call(self, attn: Attention, hidden_states, *args, **kwargs):
    """LoRA processor ``__call__``: hand over to a plain ``AttnProcessor``.

    Returns whatever the newly installed processor returns for the same
    arguments.
    """
    _replace_lora_processor(self, attn, hidden_states, AttnProcessor)
    return attn.processor(attn, hidden_states, *args, **kwargs)


def lora_attn_call2_0(self, attn: Attention, hidden_states, *args, **kwargs):
    """LoRA processor ``__call__``: hand over to a plain ``AttnProcessor2_0``.

    Returns whatever the newly installed processor returns for the same
    arguments.
    """
    _replace_lora_processor(self, attn, hidden_states, AttnProcessor2_0)
    return attn.processor(attn, hidden_states, *args, **kwargs)
246
+
247
+
248
def cross_attn_init():
    """Monkey-patch ``__call__`` on the diffusers attention-processor classes.

    Both plain processors are pointed at ``attn_call`` and both LoRA processors
    at ``lora_attn_call``. ``attn_call2_0`` and ``lora_attn_call2_0`` exist in
    this module as alternatives but are not installed here.
    """
    replacements = (
        (AttnProcessor, attn_call),
        (AttnProcessor2_0, attn_call),  # attn_call is faster
        (LoRAAttnProcessor, lora_attn_call),
        (LoRAAttnProcessor2_0, lora_attn_call),
    )
    for processor_cls, call_impl in replacements:
        processor_cls.__call__ = call_impl
255
+
256
+
257
def reshape_attn_map(attn_map):
    """Average a 3-D attention map over dim 0 and unflatten its spatial axis.

    The input is averaged over its first dimension, transposed, and the
    flattened axis is reshaped to ``(side, -1)`` with
    ``side = int(sqrt(flattened_length))``.
    Original author's example: (20,4096,77) -> (77,64,64).
    """
    per_token = attn_map.mean(dim=0).permute(1, 0)  # e.g. (4096,77) -> (77,4096)
    num_tokens, flat_len = per_token.shape
    side = int(math.sqrt(flat_len))
    return per_token.reshape((num_tokens, side, -1))
265
+
266
+
267
def hook_fn(name):
    """Build a forward hook that moves a processor's attention map into ``attn_maps``.

    Args:
        name: Key under which the map is stored in ``attn_maps``.

    Returns:
        A ``(module, input, output)`` hook. If ``module.processor`` has an
        ``attn_map`` attribute, the hook stores it under ``name`` and deletes
        it from the processor; otherwise it does nothing.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        del processor.attn_map

    return forward_hook
274
+
275
def register_cross_attention_hook(unet):
    """Flag the ``attn2`` modules of ``unet`` and attach map-collecting hooks.

    For every sub-module whose last name component starts with ``attn2``:
    if its processor is one of the four known processor classes, the
    processor gets ``store_attn_map = True``; a forward hook from ``hook_fn``
    is registered on the module in either case.

    Returns:
        The same ``unet`` object.
    """
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            processor = module.processor
            if isinstance(
                processor,
                (AttnProcessor, AttnProcessor2_0, LoRAAttnProcessor, LoRAAttnProcessor2_0),
            ):
                processor.store_attn_map = True

            module.register_forward_hook(hook_fn(name))

    return unet
292
+
293
+
294
def prompt2tokens(tokenizer, prompt):
    """Return the token strings of ``prompt``, padded to the tokenizer's max length.

    The prompt is tokenized with ``padding="max_length"`` and truncation, and
    every id of the first sequence is looked up in ``tokenizer.decoder``.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    first_sequence = encoded.input_ids[0]
    return [tokenizer.decoder[token_id.item()] for token_id in first_sequence]
308
+
309
+
310
+ # TODO: generalize for rectangle images
311
def _infer_spatial_size(seq_len, target_size):
    """Guess the 2-D layout of a flattened map of ``seq_len`` positions.

    Half of ``target_size`` is used whenever it matches ``seq_len`` (the
    original assumption); otherwise a layout with the target's aspect ratio is
    derived from ``seq_len``.
    """
    half = (target_size[0] // 2, target_size[1] // 2)
    if half[0] * half[1] == seq_len or min(target_size) <= 0:
        return half
    rows = max(1, round(math.sqrt(seq_len * target_size[0] / target_size[1])))
    return rows, seq_len // rows


def upscale(attn_map, target_size):
    """Average an attention map over dim 0 and resample it to ``target_size``.

    Args:
        attn_map: 3-D tensor; original author's example is (10, 32*32, 77).
        target_size: Two-element spatial size the flattened axis is resized to.

    Returns:
        A float32 tensor of shape ``(tokens, target_size[0] * target_size[1])``
        that has been soft-maxed over the token axis.

    Raises:
        RuntimeError: If the flattened axis cannot be laid out as a 2-D grid
            matching the inferred size.
    """
    attn_map = torch.mean(attn_map, dim=0)  # e.g. (10, 32*32, 77) -> (32*32, 77)
    attn_map = attn_map.permute(1, 0)       # e.g. (32*32, 77) -> (77, 32*32)
    num_tokens, seq_len = attn_map.shape

    if target_size[0] * target_size[1] != seq_len:
        source_size = _infer_spatial_size(seq_len, target_size)
        attn_map = attn_map.reshape(num_tokens, *source_size)  # e.g. (77, 32, 32)

        # Interpolate with a leading batch axis, then remove only that axis so
        # a single-token map keeps its token dimension.
        attn_map = F.interpolate(
            attn_map.unsqueeze(0).to(dtype=torch.float32),
            size=target_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
    else:
        attn_map = attn_map.to(dtype=torch.float32)

    # Normalise across tokens so every spatial position sums to one.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map.reshape(attn_map.shape[0], -1)
332
+
333
+
334
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average all captured attention maps into one map per token.

    Args:
        image_size: Two-element image size; each entry is divided by 16 to get
            the working resolution of the maps.
        batch_size: Number of chunks each stored map is split into along dim 0.
        instance_or_negative: Use chunk 0 when true, chunk 1 otherwise
            (presumably the two halves of a classifier-free-guidance batch —
            confirm against the pipeline).
        detach: Move each stored map to the CPU before processing.

    Returns:
        Tensor of shape ``(tokens, image_size[0] // 16, image_size[1] // 16)``.

    Raises:
        RuntimeError: If ``attn_maps`` is empty.
    """
    target_size = (image_size[0] // 16, image_size[1] // 16)
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for attn_map in attn_maps.values():
        attn_map = attn_map.cpu() if detach else attn_map
        # e.g. (20, 32*32, 77) -> (10, 32*32, 77)
        attn_map = torch.chunk(attn_map, batch_size)[idx]
        if len(attn_map.shape) == 4:
            attn_map = attn_map.squeeze()

        # e.g. (10, 32*32, 77) -> (77, 64*64)
        net_attn_maps.append(upscale(attn_map, target_size))

    if not net_attn_maps:
        raise RuntimeError(
            "No attention maps have been captured; run the hooked model before calling get_net_attn_map."
        )

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
    # Restore the spatial layout from the requested size rather than a fixed
    # 64x64, which only matched 1024x1024 images.
    return net_attn_maps.reshape(net_attn_maps.shape[0], *target_size)
352
+
353
+
354
def save_net_attn_map(net_attn_maps, dir_name, tokenizer, prompt):
    """Write one PNG per prompt token into ``dir_name``.

    Each map's mean value is embedded in the file name and the sum of those
    means is printed at the end.

    Args:
        net_attn_maps: Iterable of 2-D tensors, paired in order with the tokens
            of ``prompt``.
        dir_name: Output directory; created if missing.
        tokenizer: Tokenizer passed to ``prompt2tokens``.
        prompt: Prompt whose tokens label the files.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dir_name, exist_ok=True)

    tokens = prompt2tokens(tokenizer, prompt)
    total_attn_scores = 0
    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
        attn_map_score = torch.sum(attn_map)
        attn_map = attn_map.cpu().numpy()
        h, w = attn_map.shape
        attn_map_score = attn_map_score / (h * w)  # mean value of the map
        total_attn_scores += attn_map_score

        token = token.replace('</w>', '')
        save_attn_map(
            attn_map,
            f'{token}:{attn_map_score:.2f}',
            f"{dir_name}/{i}_<{token}>:{int(attn_map_score*100)}.png"
        )
    print(f'total_attn_scores: {total_attn_scores}')
374
+
375
+
376
def resize_net_attn_map(net_attn_maps, target_size):
    """Bilinearly resize a stack of maps to ``target_size`` as float32.

    A batch axis is added for the interpolation and all size-one axes are
    squeezed away from the result.
    Original author's example shape: (77,64,64).
    """
    batched = net_attn_maps.to(dtype=torch.float32).unsqueeze(0)
    resized = F.interpolate(
        batched,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze()
384
+
385
+
386
def save_attn_map(attn_map, title, save_path):
    """Min-max normalise ``attn_map`` to 0-255 and save it as a PNG.

    Args:
        attn_map: NumPy array holding the map.
        title: Unused; kept for interface compatibility.
        save_path: Destination file path.
    """
    lowest = np.min(attn_map)
    value_range = np.max(attn_map) - lowest
    if value_range > 0:
        normalized_attn_map = (attn_map - lowest) / value_range * 255
    else:
        # A constant map has no range to stretch; emit an all-zero image
        # instead of dividing by zero.
        normalized_attn_map = np.zeros_like(attn_map)
    normalized_attn_map = normalized_attn_map.astype(np.uint8)

    image = Image.fromarray(normalized_attn_map)
    # NOTE(review): Pillow documents ``compress_level`` for PNG output, so
    # ``compression=0`` presumably has no effect here — confirm the intent
    # before changing it.
    image.save(save_path, format='PNG', compression=0)
391
+
392
+
393
def return_net_attn_map(net_attn_maps, tokenizer, prompt):
    """Turn per-token attention maps into captioned greyscale images.

    Each map is min-max normalised to 0-255. The sum of the maps' mean values
    is printed before returning.

    Args:
        net_attn_maps: Iterable of 2-D tensors, paired in order with the tokens
            of ``prompt``.
        tokenizer: Tokenizer passed to ``prompt2tokens``.
        prompt: Prompt whose tokens label the images.

    Returns:
        List of ``(PIL image, "<index>_<token>")`` tuples.
    """
    tokens = prompt2tokens(tokenizer, prompt)
    total_attn_scores = 0
    images = []
    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
        h, w = attn_map.shape
        attn_map_score = torch.sum(attn_map) / (h * w)  # mean value of the map
        total_attn_scores += attn_map_score

        attn_map = attn_map.cpu().numpy()
        lowest = np.min(attn_map)
        value_range = np.max(attn_map) - lowest
        if value_range > 0:
            normalized_attn_map = (attn_map - lowest) / value_range * 255
        else:
            # A constant map has no range to stretch; emit an all-zero image
            # instead of dividing by zero.
            normalized_attn_map = np.zeros_like(attn_map)
        image = Image.fromarray(normalized_attn_map.astype(np.uint8))

        token = token.replace('</w>', '')
        images.append((image, f"{i}_<{token}>"))

    print(f'total_attn_scores: {total_attn_scores}')
    return images