import math
from typing import Callable, Optional

from .util import classify_blocks


def identify_blocks(block_list, name):
    """Return the first block identifier in `block_list` that occurs in `name`, or None."""
    for block in block_list:
        if block in name:
            return block
    return None


class MySelfAttnProcessor:
    """Attention processor that records queries/keys/values and hidden states
    so they can be read back later for feature extraction."""

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn, hidden_states, query, key, value, attention_mask):
        self.record_qkv(attn, hidden_states, query, key, value, attention_mask)
        # Pass the hidden states through unchanged; recording is a side effect.
        return hidden_states

    def record_qkv(self, attn, hidden_states, query, key, value, attention_mask):
        # Buffer the projected tensors so get_self_attn_feat() can read them back.
        self.query = query
        self.key = key
        self.value = value
        self.hidden_state = hidden_states.detach()

    def record_attn_mask(self, attn, hidden_states, query, key, value, attention_mask):
        self.attn = attn
        self.attention_mask = attention_mask


def prep_unet_attention(unet, motion_guidance_blocks):
    # Attach the recording processor to the temporal attention modules
    # inside the guidance blocks.
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if "VersatileAttention" in module_name and classify_blocks(motion_guidance_blocks, name):
            module.set_processor(MySelfAttnProcessor())
    return unet


def get_self_attn_feat(unet, injection_config, config):
    """Collect the recorded hidden states and Q/K/V from the self-attention
    (`attn1`) modules listed in `injection_config.blocks`, reshaped into
    (batch, channels, res, res) feature maps on the CPU."""
    hidden_state_dict = dict()
    query_dict = dict()
    key_dict = dict()
    value_dict = dict()
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if "CrossAttention" in module_name and 'attn1' in name and classify_blocks(injection_config.blocks, name=name):
            # Recover the square spatial resolution from the flattened token axis.
            res = int(math.sqrt(module.processor.hidden_state.shape[1]))
            bs = module.processor.hidden_state.shape[0]
            hidden_state_dict[name] = module.processor.hidden_state.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            res = int(math.sqrt(module.processor.query.shape[1]))
            query_dict[name] = module.processor.query.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            key_dict[name] = module.processor.key.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            value_dict[name] = module.processor.value.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
    return hidden_state_dict, query_dict, key_dict, value_dict


def clean_attn_buffer(unet):
    # Reset any per-step state left on the attention processors.
    buffered_attrs = (
        'injection_config', 'injection_mask', 'obj_index', 'pca_weight',
        'pca_weight_changed', 'pca_info', 'step',
    )
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "Attention" and 'attn' in name:
            for attr in buffered_attrs:
                if attr in module.processor.__dict__:
                    setattr(module.processor, attr, None)
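
# ---------------------------------------------------------------------------
# NOTE: `MySelfAttnProcessor.__call__` / `record_qkv` do not follow the
# standard diffusers processor signature, so they are presumably invoked
# manually from a patched attention forward. A minimal sketch of that call
# site, assuming an AnimateDiff-style `VersatileAttention.forward` that has
# already computed the projected tensors (all names below are illustrative,
# not confirmed by this module):
#
#     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
#         query = self.to_q(hidden_states)
#         key = self.to_k(hidden_states)
#         value = self.to_v(hidden_states)
#         # Record Q/K/V and the input hidden states for later feature extraction.
#         self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
#         ...  # the regular attention computation continues here
# ---------------------------------------------------------------------------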
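
# ---------------------------------------------------------------------------
# Example wiring (a hedged sketch, not part of the original pipeline).
# Assumes a diffusers/AnimateDiff-style setup; `pipeline`, `latents`,
# `timestep`, `text_emb`, and `injection_config` are placeholders supplied
# by the surrounding code, and block names follow module-name prefixes
# such as "up_blocks.1".
#
#     guidance_blocks = ["up_blocks.1", "up_blocks.2"]        # assumed format
#     unet = prep_unet_attention(pipeline.unet, guidance_blocks)
#
#     _ = unet(latents, timestep, encoder_hidden_states=text_emb)  # one step;
#     # the processors buffer Q/K/V as a side effect of the forward pass.
#
#     hidden, q, k, v = get_self_attn_feat(unet, injection_config, config)
#     clean_attn_buffer(unet)  # reset buffers before the next denoising step
# ---------------------------------------------------------------------------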