File size: 7,660 Bytes
0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 0c1540a 5cb0966 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from diffusers.models import Transformer2DModel
from diffusers.models.unets import UNet2DConditionModel
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
from diffusers import FluxPipeline
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
JointAttnProcessor2_0,
FluxAttnProcessor2_0
)
from modules import *
def cross_attn_init():
    """Rebind ``__call__`` on every supported attention-processor class.

    Each processor class imported from diffusers gets its ``__call__``
    replaced, at class level, by the matching replacement function. The
    replacement functions are not defined in this file; they are expected to
    arrive through ``from modules import *``.

    The assignment is made on the classes themselves, so it applies to all
    instances of those classes, not just those of one pipeline.
    """
    replacements = (
        (AttnProcessor, attn_call),
        (AttnProcessor2_0, attn_call2_0),
        (LoRAAttnProcessor, lora_attn_call),
        (LoRAAttnProcessor2_0, lora_attn_call2_0),
        (JointAttnProcessor2_0, joint_attn_call2_0),
        (FluxAttnProcessor2_0, flux_attn_call2_0),
    )
    for processor_cls, patched_call in replacements:
        processor_cls.__call__ = patched_call
def hook_function(name, detach=True):
    """Create a forward hook that collects a module's attention map.

    Args:
        name: Key under which the map is stored for the current timestep
            (callers in this file pass the module's qualified name).
        detach: When true the stored tensor is moved to CPU with ``.cpu()``;
            otherwise the tensor is stored as-is.

    Returns:
        A ``forward_hook(module, input, output)`` callable suitable for
        ``register_forward_hook``.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        # Nothing was recorded during this forward pass: leave the store alone.
        if not hasattr(processor, "attn_map"):
            return

        # ``attn_maps`` is the module-level store (not defined in this file;
        # presumably provided by ``from modules import *``), keyed by
        # timestep and then by hook name.
        per_timestep = attn_maps.setdefault(processor.timestep, {})
        captured = processor.attn_map
        per_timestep[name] = captured.cpu() if detach else captured

        # Drop the processor's reference so a stale map is never re-collected.
        del processor.attn_map

    return forward_hook
def register_cross_attention_hook(model, hook_function, target_name):
    """Enable attention-map storage and attach forward hooks on matching modules.

    Every sub-module whose qualified name ends with ``target_name`` is
    visited. If its ``processor`` is one of the supported attention-processor
    classes, ``store_attn_map`` is set to ``True`` on that processor. A
    forward hook built by ``hook_function(name)`` is then registered on the
    module (regardless of the processor class, as before).

    Modules whose name matches but that expose no ``processor`` attribute are
    skipped instead of raising ``AttributeError``.

    Args:
        model: Object exposing ``named_modules()`` (a ``torch.nn.Module``).
        hook_function: Factory called with the module name; must return a
            forward hook.
        target_name: Name suffix selecting the modules to instrument.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for name, module in model.named_modules():
        if not name.endswith(target_name):
            continue

        processor = getattr(module, "processor", None)
        if processor is None:
            # A name match alone does not make this an attention module; the
            # hook would have nothing to read from, so leave it untouched.
            continue

        # All supported processor classes are handled identically, so one
        # tuple check replaces the former chain of duplicated branches.
        if isinstance(
            processor,
            (
                AttnProcessor,
                AttnProcessor2_0,
                LoRAAttnProcessor,
                LoRAAttnProcessor2_0,
                JointAttnProcessor2_0,
                FluxAttnProcessor2_0,
            ),
        ):
            processor.store_attn_map = True

        module.register_forward_hook(hook_function(name))

    return model
def replace_call_method_for_unet(model):
    """Recursively install the replacement ``forward`` methods on a UNet tree.

    ``model`` itself is patched only when its class is named
    ``UNet2DConditionModel``. Each direct child is patched when its class is
    named ``Transformer2DModel`` or ``BasicTransformerBlock``, and the
    function then recurses into every child. Classes are matched by name,
    not with ``isinstance``.

    Replacement functions are bound to the instance through the descriptor
    protocol (``__get__``), so only that instance's ``forward`` changes.

    Returns:
        The same ``model`` object.
    """
    if model.__class__.__name__ == 'UNet2DConditionModel':
        model.forward = UNet2DConditionModelForward.__get__(model, UNet2DConditionModel)

    # class name of a child -> (replacement forward, owner class used for binding)
    child_overrides = {
        'Transformer2DModel': (Transformer2DModelForward, Transformer2DModel),
        'BasicTransformerBlock': (BasicTransformerBlockForward, BasicTransformerBlock),
    }

    for child in model.children():
        override = child_overrides.get(child.__class__.__name__)
        if override is not None:
            new_forward, owner_cls = override
            child.forward = new_forward.__get__(child, owner_cls)

        replace_call_method_for_unet(child)

    return model
def replace_call_method_for_sd3(model):
    """Recursively install the replacement ``forward`` methods on an SD3 transformer.

    ``model`` is patched when its class is named ``SD3Transformer2DModel``;
    each direct child is patched when its class is named
    ``JointTransformerBlock``. The function then descends into every child.
    Classes are matched by name, not with ``isinstance``.

    Returns:
        The same ``model`` object.
    """
    root_kind = model.__class__.__name__
    if root_kind == 'SD3Transformer2DModel':
        model.forward = SD3Transformer2DModelForward.__get__(model, SD3Transformer2DModel)

    for child in model.children():
        child_kind = child.__class__.__name__
        if child_kind == 'JointTransformerBlock':
            # Bind per instance so other blocks of the class are unaffected.
            child.forward = JointTransformerBlockForward.__get__(child, JointTransformerBlock)

        replace_call_method_for_sd3(child)

    return model
def replace_call_method_for_flux(model):
    """Recursively install the replacement ``forward`` methods on a Flux transformer.

    ``model`` is patched when its class is named ``FluxTransformer2DModel``;
    each direct child is patched when its class is named
    ``FluxTransformerBlock``. Every child is then processed recursively.
    Classes are matched by name, not with ``isinstance``.

    Returns:
        The same ``model`` object.
    """
    is_flux_root = model.__class__.__name__ == 'FluxTransformer2DModel'
    if is_flux_root:
        model.forward = FluxTransformer2DModelForward.__get__(model, FluxTransformer2DModel)

    for submodule in model.children():
        is_flux_block = submodule.__class__.__name__ == 'FluxTransformerBlock'
        if is_flux_block:
            # Instance-level binding via the descriptor protocol.
            submodule.forward = FluxTransformerBlockForward.__get__(submodule, FluxTransformerBlock)

        replace_call_method_for_flux(submodule)

    return model
def init_pipeline(pipeline):
    """Instrument a diffusers pipeline so its attention maps get collected.

    Dispatch is based on the pipeline's instance attributes:

    * If the instance has a ``transformer`` attribute, its class name selects
      the SD3 or Flux path (hooks on modules ending in ``'attn'``). For Flux,
      ``FluxPipeline.__call__`` is additionally replaced at class level.
      Any other transformer class is left untouched.
    * Otherwise ``pipeline.unet`` is inspected, and a ``UNet2DConditionModel``
      gets hooks on modules ending in ``'attn2'``.

    Returns:
        The same ``pipeline`` object.
    """
    if 'transformer' not in vars(pipeline):
        # UNet-based pipelines.
        if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
            pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')
            pipeline.unet = replace_call_method_for_unet(pipeline.unet)
        return pipeline

    # Transformer-based pipelines.
    transformer_kind = pipeline.transformer.__class__.__name__
    if transformer_kind == 'SD3Transformer2DModel':
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)
    elif transformer_kind == 'FluxTransformer2DModel':
        FluxPipeline.__call__ = FluxPipeline_call
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)

    return pipeline
def _format_token_labels(tokens):
    """Return one display label per token, with word boundaries marked.

    A token containing ``'</w>'`` ends a word: the marker is removed and the
    label closes with ``'>'``. Any other token except ``'<|startoftext|>'``
    and ``'<|endoftext|>'`` is a word fragment and closes with ``'-'``.
    Labels open with ``'<'`` at the start of a word and ``'-'`` inside one.
    The two special tokens are returned unchanged.
    """
    labels = []
    start_of_word = True
    for token in tokens:
        if '</w>' in token:
            token = token.replace('</w>', '')
            if start_of_word:
                token = '<' + token + '>'
            else:
                token = '-' + token + '>'
                start_of_word = True
        elif token != '<|startoftext|>' and token != '<|endoftext|>':
            if start_of_word:
                token = '<' + token + '-'
                start_of_word = False
            else:
                token = '-' + token + '-'
        labels.append(token)
    return labels


def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unconditional=True):
    """Average the collected attention maps and render one image per token.

    Every stored map (all timesteps, all layers) is summed over dim 1,
    permuted with ``(0, 3, 1, 2)``, bilinearly resized to the spatial size of
    the first stored map and accumulated; the sum is divided by the number of
    maps. Each token slice of the averaged map is converted to a PIL image.

    Args:
        attn_maps: Mapping ``timestep -> {layer name -> tensor}``.
            NOTE(review): the original inline comment gives the tensor layout
            after ``sum(1)`` as (batch, height, width, attn_dim); confirm
            against the attention processors.
        tokenizer: Callable returning ``{'input_ids': ...}`` for ``prompts``
            and providing ``convert_ids_to_tokens``.
        prompts: Prompt(s) passed to ``tokenizer``.
        base_dir: Root output directory. Sub-directories per timestep, per
            layer and per batch are created, but this function writes no
            files into them.
        unconditional: When true, each map is split in two along dim 0 and
            only the second half is used (presumably the conditional half of
            a classifier-free-guidance batch; confirm with the caller).

    Returns:
        Dict mapping ``'{index}-{label}.png'`` to a PIL image.

    Raises:
        ValueError: If ``attn_maps`` is empty or its first timestep holds no
            layers. Raised before any directory is created.
    """
    # Fail fast with a clear message instead of an IndexError after the
    # tokenizer call and directory creation have already happened.
    first_layers = next(iter(attn_maps.values()), None) if attn_maps else None
    if not first_layers:
        raise ValueError(
            "attn_maps contains no attention maps; make sure the hooks were "
            "registered and the pipeline was run before calling save_attention_maps"
        )

    to_pil = ToPILImage()

    token_ids = tokenizer(prompts)['input_ids']
    total_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_ids]

    # makedirs(exist_ok=True) supports nested base_dir values and has no
    # check-then-create race, unlike os.path.exists + os.mkdir.
    os.makedirs(base_dir, exist_ok=True)

    # The first stored map fixes the shape of the accumulator.
    reference_map = next(iter(first_layers.values())).sum(1)
    if unconditional:
        reference_map = reference_map.chunk(2)[1]  # (batch, height, width, attn_dim)
    reference_map = reference_map.permute(0, 3, 1, 2)
    total_attn_map = torch.zeros_like(reference_map)
    total_attn_map_shape = total_attn_map.shape[-2:]
    total_attn_map_number = 0

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            os.makedirs(os.path.join(timestep_dir, f'{layer}'), exist_ok=True)

            attn_map = attn_map.sum(1).squeeze(1)
            attn_map = attn_map.permute(0, 3, 1, 2)
            if unconditional:
                attn_map = attn_map.chunk(2)[1]

            # Layers may differ in spatial size, so bring every map to the
            # reference size before accumulating.
            resized_attn_map = F.interpolate(
                attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False
            )
            total_attn_map += resized_attn_map
            total_attn_map_number += 1

    total_attn_map /= total_attn_map_number

    # NOTE(review): keys carry no batch index, so with several prompts a later
    # batch overwrites entries of an earlier one that share position and
    # label. Left unchanged to keep the return contract; confirm whether
    # callers ever pass more than one prompt.
    final_attn_map = {}
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        os.makedirs(os.path.join(base_dir, f'batch-{batch}'), exist_ok=True)

        labels = _format_token_labels(tokens)
        # Only the first len(tokens) slices of the map are paired with tokens.
        for i, (label, token_map) in enumerate(zip(labels, attn_map[:len(tokens)])):
            final_attn_map[f'{i}-{label}.png'] = to_pil(token_map.to(torch.float32))

    return final_attn_map