from typing import Optional, Tuple, Union

import torch
from torch import nn

from diffusers.utils import logging

from .referencenet import ReferenceNet2D
from .unet_loader import update_unet_with_sd

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def load_referencenet(
    sd_referencenet_model: Union[str, nn.Module],
    sd_model: Optional[nn.Module] = None,
    need_self_attn_block_embs: bool = False,
    need_block_embs: bool = False,
    dtype: torch.dtype = torch.float16,
    cross_attention_dim: int = 768,
    subfolder: str = "unet",
) -> nn.Module:
""" | |
Loads the ReferenceNet model. | |
Args: | |
sd_referencenet_model (Tuple[str, nn.Module] or str): The pretrained ReferenceNet model or the path to the model. | |
sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None. | |
need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False. | |
need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False. | |
dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16. | |
cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768. | |
subfolder (str, optional): The subfolder of the model. Defaults to "unet". | |
Returns: | |
nn.Module: The loaded ReferenceNet model. | |
""" | |
    if isinstance(sd_referencenet_model, str):
        referencenet = ReferenceNet2D.from_pretrained(
            sd_referencenet_model,
            subfolder=subfolder,
            need_self_attn_block_embs=need_self_attn_block_embs,
            need_block_embs=need_block_embs,
            torch_dtype=dtype,
            cross_attention_dim=cross_attention_dim,
        )
    elif isinstance(sd_referencenet_model, nn.Module):
        referencenet = sd_referencenet_model
    else:
        raise TypeError(
            f"sd_referencenet_model must be a str path or nn.Module, got {type(sd_referencenet_model)}"
        )
    if sd_model is not None:
        referencenet = update_unet_with_sd(referencenet, sd_model)
    return referencenet
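

# Usage sketch for load_referencenet (not part of the original module): the
# checkpoint path below is a hypothetical placeholder for a diffusers-style
# directory containing ReferenceNet weights under the given subfolder.
#
#   referencenet = load_referencenet(
#       sd_referencenet_model="path/to/referencenet_checkpoint",  # placeholder path
#       need_block_embs=True,
#       dtype=torch.float16,
#       subfolder="referencenet",
#   )
#   referencenet.to("cuda")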
def load_referencenet_by_name(
    model_name: str,
    sd_referencenet_model: Union[str, nn.Module],
    sd_model: Optional[nn.Module] = None,
    cross_attention_dim: int = 768,
    dtype: torch.dtype = torch.float16,
) -> nn.Module:
"""通过模型名字 初始化 referencenet,载入预训练参数, | |
如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义 | |
init referencenet with model_name. | |
if you want to use pretrained model with simple name, you need to define it here. | |
Args: | |
model_name (str): _description_ | |
sd_unet_model (Tuple[str, nn.Module]): _description_ | |
sd_model (Tuple[str, nn.Module]): _description_ | |
cross_attention_dim (int, optional): _description_. Defaults to 768. | |
dtype (torch.dtype, optional): _description_. Defaults to torch.float16. | |
Raises: | |
ValueError: _description_ | |
Returns: | |
nn.Module: _description_ | |
""" | |
    if model_name in [
        "musev_referencenet",
    ]:
        unet = load_referencenet(
            sd_referencenet_model=sd_referencenet_model,
            sd_model=sd_model,
            cross_attention_dim=cross_attention_dim,
            dtype=dtype,
            need_self_attn_block_embs=False,
            need_block_embs=True,
            subfolder="referencenet",
        )
    else:
        raise ValueError(
            f"unsupported model_name={model_name}, only support musev_referencenet"
        )
    return unet
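

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: loads the
    # "musev_referencenet" configuration and prints the parameter count.
    # "path/to/musev_referencenet" is a placeholder; replace it with a real
    # diffusers-style directory that contains a "referencenet" subfolder.
    referencenet = load_referencenet_by_name(
        model_name="musev_referencenet",
        sd_referencenet_model="path/to/musev_referencenet",  # placeholder path
        dtype=torch.float16,
    )
    num_params = sum(p.numel() for p in referencenet.parameters())
    print(f"loaded referencenet with {num_params / 1e6:.1f}M parameters")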