# Uploaded by abubakar123456 -- "Upload 750 files" (commit 71d94dd, verified)
import os
import json
import torch
from model.attn_processor import AttnProcessor2_0, SkipAttnProcessor
def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    """Install a fresh attention processor on every attention layer of ``unet``.

    Processor names ending in ``"attn1.processor"`` are treated as
    self-attention and built with ``cross_attention_dim=None``; all other
    names are treated as cross-attention.

    Args:
        unet: Model exposing ``config.cross_attention_dim``,
            ``config.block_out_channels``, ``attn_processors`` and
            ``set_attn_processor``.
        cross_attn_cls: Processor class used for cross-attention layers.
        self_attn_cls: Processor class used for self-attention layers. When
            ``None``, ``AttnProcessor2_0`` is used instead.
        cross_attn_dim: Cross-attention dimension. When ``None``, it is read
            from ``unet.config.cross_attention_dim``.
        **kwargs: Extra keyword arguments forwarded to every processor
            constructor.

    Returns:
        A ``torch.nn.ModuleList`` built from ``unet.attn_processors.values()``
        after the new processors have been installed.

    Raises:
        ValueError: If a processor name does not start with ``mid_block``,
            ``up_blocks`` or ``down_blocks`` (so no hidden size can be
            derived), or if the block index in the name is not an integer.
    """
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim

    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim

        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Parse the whole second dotted component so that indices with
            # more than one digit are read correctly.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # Without this branch `hidden_size` would be unbound on the first
            # iteration or silently reuse the previous layer's value.
            raise ValueError(
                f"Cannot infer hidden size for attention processor {name!r}"
            )

        if cross_attention_dim is not None:
            processor_cls = cross_attn_cls
        elif self_attn_cls is not None:
            processor_cls = self_attn_cls
        else:
            # No custom self-attention class given: use the plain processor.
            processor_cls = AttnProcessor2_0

        attn_procs[name] = processor_cls(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            **kwargs,
        )

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load the text encoder, VAE, tokenizer and (optionally) UNet of a model.

    Args:
        diffusion_model_name_or_path: Passed to each ``from_pretrained`` call
            together with the matching ``subfolder``. For the UNet it is
            joined with ``"unet"`` and read as a local directory holding
            ``config.json`` and ``diffusion_pytorch_model.bin``.
        unet_class: Callable that builds the UNet from the keyword arguments
            stored in ``config.json``. When ``None``, no UNet is loaded.

    Returns:
        Tuple ``(text_encoder, vae, tokenizer, unet)``. ``unet`` is ``None``
        when ``unet_class`` is ``None`` or when building/loading the UNet
        fails; in the failure case a ``RuntimeWarning`` is emitted.
    """
    import warnings

    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")

    unet = None
    if unet_class is not None:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        try:
            # Context manager so the config file handle is always closed.
            with open(os.path.join(unet_folder, "config.json"), "r", encoding="utf-8") as config_file:
                unet_configs = json.load(config_file)
            candidate = unet_class(**unet_configs)
            # NOTE(review): torch.load unpickles the checkpoint, which can
            # execute arbitrary code -- only load weights from trusted sources.
            state_dict = torch.load(
                os.path.join(unet_folder, "diffusion_pytorch_model.bin"),
                map_location="cpu",
            )
            candidate.load_state_dict(state_dict, strict=True)
            unet = candidate
        except Exception as err:
            # Best-effort load: keep returning None on failure, but do not
            # swallow KeyboardInterrupt/SystemExit and do not fail silently.
            warnings.warn(
                f"Could not load UNet from {unet_folder!r}: {err!r}",
                RuntimeWarning,
            )
            unet = None
    return text_encoder, vae, tokenizer, unet
def attn_of_unet(unet):
    """Gather the sub-modules of ``unet`` whose dotted name contains "attn1".

    The match is a plain substring test on the names yielded by
    ``unet.named_modules()``, so descendants of an ``attn1`` module are
    included as well (their names contain the substring too).

    Returns:
        A ``torch.nn.ModuleList`` with the matching modules in traversal order.
    """
    matching = [
        module
        for qualified_name, module in unet.named_modules()
        if "attn1" in qualified_name
    ]
    return torch.nn.ModuleList(matching)
def get_trainable_module(unet, trainable_module_name):
    """Select the part of ``unet`` that should be trained.

    Args:
        unet: Model exposing ``down_blocks``, ``mid_block`` and ``up_blocks``
            (used for ``"transformer"``) and ``named_modules()`` (used for
            ``"attention"``).
        trainable_module_name: One of ``"unet"``, ``"transformer"`` or
            ``"attention"``.

    Returns:
        * ``"unet"``: ``unet`` itself.
        * ``"transformer"``: a ``torch.nn.ModuleList`` holding the
          ``attentions`` attribute of every block that has one.
        * ``"attention"``: a ``torch.nn.ModuleList`` of every sub-module
          whose dotted name contains ``"attn1"``.

    Raises:
        ValueError: If ``trainable_module_name`` is not one of the above.
    """
    if trainable_module_name == "unet":
        return unet

    if trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if blocks is None:
                # A UNet may be built without a mid block; iterating over
                # None would raise TypeError, so there is nothing to collect.
                continue
            if hasattr(blocks, "attentions"):
                # A single block (e.g. the mid block) rather than a container.
                trainable_modules.append(blocks.attentions)
                continue
            for block in blocks:
                if hasattr(block, "attentions"):
                    trainable_modules.append(block.attentions)
        return trainable_modules

    if trainable_module_name == "attention":
        attn_blocks = torch.nn.ModuleList()
        for name, module in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(module)
        return attn_blocks

    raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")