import os
import json

import torch

from model.attn_processor import AttnProcessor2_0, SkipAttnProcessor


def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    """Swap the UNet's attention processors and return them as a trainable ModuleList."""
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention and therefore have no cross-attention dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        # Infer the hidden size of the block that owns this processor.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            # Self-attention layers: use the custom class if given, otherwise keep the default processor.
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            # Cross-attention layers: replace with the requested processor (SkipAttnProcessor by default).
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
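
# Usage sketch (illustrative, not part of this module): assuming a standard diffusers
# UNet2DConditionModel checkpoint, init_adapter replaces the cross-attention processors
# with SkipAttnProcessor and returns the new processors as a trainable ModuleList.
# The checkpoint path below is a placeholder.
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained("/path/to/checkpoint", subfolder="unet")
#   adapter_modules = init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
#   print(sum(p.numel() for p in adapter_modules.parameters()))
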
def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        # Rebuild the UNet from its local config and weights so a custom unet_class can be used.
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
    except Exception:
        # Fall back to no UNet (e.g. unet_class is None or the local unet folder is missing).
        unet = None
    return text_encoder, vae, tokenizer, unet
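
# Usage sketch (illustrative): assuming the path points to a diffusers-style checkpoint
# directory containing text_encoder/, vae/, tokenizer/ and unet/ subfolders, a UNet class
# can be passed so the UNet is rebuilt from the local config.json and .bin weights.
# The path and class below are placeholders, not fixed values of this repo.
#
#   from diffusers import UNet2DConditionModel
#   text_encoder, vae, tokenizer, unet = init_diffusion_model(
#       "/path/to/stable-diffusion-checkpoint", unet_class=UNet2DConditionModel)
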
def attn_of_unet(unet):
    # Collect every self-attention ("attn1") module in the UNet.
    attn_blocks = torch.nn.ModuleList()
    for name, module in unet.named_modules():
        if "attn1" in name:
            attn_blocks.append(module)
    return attn_blocks
def get_trainable_module(unet, trainable_module_name):
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        # Gather the transformer (attention) sub-blocks of every UNet stage.
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):  # mid_block is a single block
                trainable_modules.append(blocks.attentions)
            else:  # down_blocks / up_blocks are lists of blocks
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        # Gather only the self-attention ("attn1") modules.
        attn_blocks = torch.nn.ModuleList()
        for name, module in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(module)
        return attn_blocks
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
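
# Usage sketch (illustrative): select which UNet parameters to optimize by name;
# "transformer" trains the attention blocks of every stage, "attention" only the
# self-attention ("attn1") modules. The learning rate below is a placeholder.
#
#   trainable = get_trainable_module(unet, "attention")
#   optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-5)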