Spaces:
No application file
No application file
import copy | |
from typing import Any, Callable, Dict, Iterable, Union | |
import PIL | |
import cv2 | |
import torch | |
import argparse | |
import datetime | |
import logging | |
import inspect | |
import math | |
import os | |
import shutil | |
from typing import Dict, List, Optional, Tuple | |
from pprint import pprint | |
from collections import OrderedDict | |
from dataclasses import dataclass | |
import gc | |
import time | |
import numpy as np | |
from omegaconf import OmegaConf | |
from omegaconf import SCMode | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from einops import rearrange, repeat | |
import pandas as pd | |
import h5py | |
from diffusers.models.modeling_utils import load_state_dict | |
from diffusers.utils import ( | |
logging, | |
) | |
from diffusers.utils.import_utils import is_xformers_available | |
from ..models.unet_3d_condition import UNet3DConditionModel | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def update_unet_with_sd(
    unet: nn.Module, sd_model: Union[str, nn.Module], subfolder: str = "unet"
):
    """Update the T2I parameters inside a T2V model.

    A state dict is obtained from ``sd_model`` and loaded into ``unet`` with
    ``strict=False``: keys missing from the source are tolerated, while keys
    that ``unet`` does not know about are treated as an error.

    Args:
        unet (nn.Module): module whose parameters are updated in place.
        sd_model (Union[str, nn.Module]): source of the weights. Either

            * an ``nn.Module``, whose ``state_dict()`` is used;
            * a directory path, in which case
              ``<sd_model>/<subfolder>/diffusion_pytorch_model.bin`` is read
              with ``load_state_dict``;
            * a file path ending in ``"pth"``, read with ``torch.load`` onto
              the CPU;
            * any other file path, read with ``load_state_dict``.
        subfolder (str, optional): sub-directory used when ``sd_model`` is a
            directory. Defaults to "unet".

    Raises:
        FileNotFoundError: ``sd_model`` is a string that is neither an
            existing directory nor an existing file.
        ValueError: ``sd_model`` is neither a string nor an ``nn.Module``.
        AssertionError: the loaded state dict contains keys unexpected by
            ``unet``.

    Returns:
        nn.Module: the same ``unet`` object that was passed in.
    """
    # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
    if isinstance(sd_model, str):
        if os.path.isdir(sd_model):
            unet_state_dict = load_state_dict(
                os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"),
            )
        elif os.path.isfile(sd_model):
            if sd_model.endswith("pth"):
                # NOTE(review): depending on the torch version's defaults,
                # torch.load may unpickle arbitrary objects; only pass
                # checkpoints from trusted sources.
                unet_state_dict = torch.load(sd_model, map_location="cpu")
                print(f"referencenet successful load ={sd_model} with torch.load")
            else:
                # Let loader errors propagate: swallowing them here used to
                # end in an unrelated UnboundLocalError a few lines later.
                unet_state_dict = load_state_dict(sd_model)
                print(
                    f"referencenet successful load with {sd_model} with load_state_dict"
                )
        else:
            # Previously this fell through with `unet_state_dict` unbound.
            raise FileNotFoundError(
                f"sd_model path does not exist or is not a file/directory: {sd_model}"
            )
    elif isinstance(sd_model, nn.Module):
        unet_state_dict = sd_model.state_dict()
    else:
        raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")
    _missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}"
    return unet
def load_unet(
    sd_unet_model: Union[str, nn.Module],
    sd_model: Optional[Union[str, nn.Module]] = None,
    cross_attention_dim: int = 768,
    temporal_transformer: str = "TransformerTemporalModel",
    temporal_conv_block: str = "TemporalConvLayer",
    need_spatial_position_emb: bool = False,
    need_transformer_in: bool = True,
    need_t2i_ip_adapter: bool = False,
    need_adain_temporal_cond: bool = False,
    t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor",
    keep_vision_condtion: bool = False,
    use_anivv1_cfg: bool = False,
    resnet_2d_skip_time_act: bool = False,
    dtype: torch.dtype = torch.float16,
    need_zero_vis_cond_temb: bool = True,
    norm_spatial_length: bool = True,
    spatial_max_length: int = 2048,
    need_refer_emb: bool = False,
    ip_adapter_cross_attn=False,
    t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
    need_t2i_facein: bool = False,
    need_t2i_ip_adapter_face: bool = False,
    strict: bool = True,
):
    """Initialise a UNet and optionally load pretrained T2I parameters into it.

    This loader is for models defined and trained through
    ``models.unet_3d_condition.py:UNet3DConditionModel``.

    When ``sd_unet_model`` is a string, the UNet is built with
    ``UNet3DConditionModel.from_pretrained_2d(sd_unet_model, subfolder="unet", ...)``
    and every configuration argument below is forwarded to that call
    unchanged (``dtype`` is passed as ``torch_dtype``). When ``sd_unet_model``
    is already an ``nn.Module`` it is used as is and the configuration
    arguments are ignored.

    Args:
        sd_unet_model (Union[str, nn.Module]): pretrained model path/name, or
            an already constructed UNet module.
        sd_model (Union[str, nn.Module], optional): if not None, its weights
            are loaded into the UNet via ``update_unet_with_sd``.
            Defaults to None.
        cross_attention_dim (int, optional): forwarded. Defaults to 768.
        temporal_transformer (str, optional): forwarded.
            Defaults to "TransformerTemporalModel".
        temporal_conv_block (str, optional): forwarded.
            Defaults to "TemporalConvLayer".
        need_spatial_position_emb (bool, optional): forwarded. Defaults to False.
        need_transformer_in (bool, optional): forwarded. Defaults to True.
        need_t2i_ip_adapter (bool, optional): forwarded. Defaults to False.
        need_adain_temporal_cond (bool, optional): forwarded. Defaults to False.
        t2i_ip_adapter_attn_processor (str, optional): forwarded.
            Defaults to "IPXFormersAttnProcessor".
        keep_vision_condtion (bool, optional): forwarded. Defaults to False.
        use_anivv1_cfg (bool, optional): forwarded. Defaults to False.
        resnet_2d_skip_time_act (bool, optional): forwarded. Defaults to False.
        dtype (torch.dtype, optional): forwarded as ``torch_dtype``.
            Defaults to torch.float16.
        need_zero_vis_cond_temb (bool, optional): forwarded. Defaults to True.
        norm_spatial_length (bool, optional): forwarded. Defaults to True.
        spatial_max_length (int, optional): forwarded. Defaults to 2048.
        need_refer_emb (bool, optional): forwarded. Defaults to False.
        ip_adapter_cross_attn (bool, optional): forwarded. Defaults to False.
        t2i_crossattn_ip_adapter_attn_processor (str, optional): forwarded.
            Defaults to "T2IReferencenetIPAdapterXFormersAttnProcessor".
        need_t2i_facein (bool, optional): forwarded. Defaults to False.
        need_t2i_ip_adapter_face (bool, optional): forwarded. Defaults to False.
        strict (bool, optional): forwarded. Defaults to True.

    Raises:
        ValueError: ``sd_unet_model`` is neither a string nor an ``nn.Module``.

    Returns:
        nn.Module: the constructed (or passed-in) UNet.
    """
    if isinstance(sd_unet_model, str):
        unet = UNet3DConditionModel.from_pretrained_2d(
            sd_unet_model,
            subfolder="unet",
            temporal_transformer=temporal_transformer,
            temporal_conv_block=temporal_conv_block,
            cross_attention_dim=cross_attention_dim,
            need_spatial_position_emb=need_spatial_position_emb,
            need_transformer_in=need_transformer_in,
            need_t2i_ip_adapter=need_t2i_ip_adapter,
            need_adain_temporal_cond=need_adain_temporal_cond,
            t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor,
            keep_vision_condtion=keep_vision_condtion,
            use_anivv1_cfg=use_anivv1_cfg,
            resnet_2d_skip_time_act=resnet_2d_skip_time_act,
            torch_dtype=dtype,
            need_zero_vis_cond_temb=need_zero_vis_cond_temb,
            norm_spatial_length=norm_spatial_length,
            spatial_max_length=spatial_max_length,
            need_refer_emb=need_refer_emb,
            ip_adapter_cross_attn=ip_adapter_cross_attn,
            t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor,
            need_t2i_facein=need_t2i_facein,
            strict=strict,
            need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
        )
    elif isinstance(sd_unet_model, nn.Module):
        unet = sd_unet_model
    else:
        # Previously `unet` stayed unbound here and the function failed later
        # with an UnboundLocalError that hid the real cause.
        raise ValueError(
            f"given {type(sd_unet_model)}, but only support nn.Module or str"
        )
    if sd_model is not None:
        unet = update_unet_with_sd(unet, sd_model)
    return unet
def load_unet_custom_unet(
    sd_unet_model: Union[str, nn.Module],
    sd_model: Union[str, nn.Module],
    unet_class: nn.Module,
):
    """Initialise a UNet of a custom class and load pretrained parameters.

    This loader is for models that are NOT defined through
    ``models.unet_3d_condition.py:UNet3DConditionModel``.

    Args:
        sd_unet_model (Union[str, nn.Module]): if a string, the UNet is built
            with ``unet_class.from_pretrained(sd_unet_model, subfolder="unet")``;
            if an ``nn.Module``, it is used as is.
        sd_model (Union[str, nn.Module]): source of the weights loaded into
            the UNet. A string is treated as a directory and
            ``<sd_model>/unet/diffusion_pytorch_model.bin`` is read with
            ``load_state_dict``; an ``nn.Module`` contributes its
            ``state_dict()``.
        unet_class (nn.Module): class providing ``from_pretrained``; only used
            when ``sd_unet_model`` is a string.

    Raises:
        ValueError: ``sd_unet_model`` or ``sd_model`` is neither a string nor
            an ``nn.Module``.
        AssertionError: the loaded state dict contains keys unexpected by the
            UNet.

    Returns:
        nn.Module: the constructed (or passed-in) UNet with weights loaded
        using ``strict=False``.
    """
    if isinstance(sd_unet_model, str):
        unet = unet_class.from_pretrained(
            sd_unet_model,
            subfolder="unet",
        )
    elif isinstance(sd_unet_model, nn.Module):
        unet = sd_unet_model
    else:
        # Previously `unet` stayed unbound and surfaced as UnboundLocalError.
        raise ValueError(
            f"given {type(sd_unet_model)}, but only support nn.Module or str"
        )

    # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
    if isinstance(sd_model, str):
        unet_state_dict = load_state_dict(
            os.path.join(sd_model, "unet", "diffusion_pytorch_model.bin"),
        )
    elif isinstance(sd_model, nn.Module):
        unet_state_dict = sd_model.state_dict()
    else:
        # Same failure mode as above, for `unet_state_dict`.
        raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")

    # strict=False tolerates keys missing from the source; unexpected keys
    # are still rejected below.
    _missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}"
    return unet
def load_unet_by_name(
    model_name: str,
    sd_unet_model: Tuple[str, nn.Module],
    sd_model: Tuple[str, nn.Module] = None,
    cross_attention_dim: int = 768,
    dtype: torch.dtype = torch.float16,
    need_t2i_facein: bool = False,
    need_t2i_ip_adapter_face: bool = False,
    strict: bool = True,
) -> nn.Module:
    """Initialise a UNet from a short model name and load pretrained weights.

    Any pretrained configuration that should be usable through a simple name
    has to be registered in this function.

    Supported names are "musev", "musev_referencenet" and
    "musev_referencenet_pose"; each maps to a fixed set of keyword arguments
    for ``load_unet``.

    Args:
        model_name (str): one of the supported names listed above.
        sd_unet_model (Tuple[str, nn.Module]): passed to ``load_unet``.
        sd_model (Tuple[str, nn.Module]): passed to ``load_unet``.
            Defaults to None.
        cross_attention_dim (int, optional): passed to ``load_unet``.
            Defaults to 768.
        dtype (torch.dtype, optional): passed to ``load_unet``.
            Defaults to torch.float16.
        need_t2i_facein (bool, optional): passed to ``load_unet`` for the
            referencenet names only. Defaults to False.
        need_t2i_ip_adapter_face (bool, optional): passed to ``load_unet`` for
            the referencenet names only. Defaults to False.
        strict (bool, optional): passed to ``load_unet`` for the referencenet
            names only. Defaults to True.

    Raises:
        ValueError: ``model_name`` is not a supported name.

    Returns:
        nn.Module: whatever ``load_unet`` returns for the chosen configuration.
    """
    plain_names = ("musev",)
    referencenet_names = ("musev_referencenet", "musev_referencenet_pose")

    # Reject unknown names up front so the rest only deals with valid ones.
    if model_name not in plain_names and model_name not in referencenet_names:
        raise ValueError(
            f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose"
        )

    # Keyword arguments that both configurations pass to load_unet.
    shared_kwargs = dict(
        sd_unet_model=sd_unet_model,
        sd_model=sd_model,
        cross_attention_dim=cross_attention_dim,
        need_t2i_ip_adapter=True,
        need_adain_temporal_cond=True,
        t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
        dtype=dtype,
    )

    if model_name in plain_names:
        # The plain "musev" configuration does not forward the facein,
        # ip-adapter-face or strict arguments.
        return load_unet(need_spatial_position_emb=False, **shared_kwargs)

    return load_unet(
        temporal_conv_block="TemporalConvLayer",
        need_transformer_in=False,
        temporal_transformer="TransformerTemporalModel",
        use_anivv1_cfg=True,
        resnet_2d_skip_time_act=True,
        keep_vision_condtion=True,
        need_refer_emb=True,
        need_zero_vis_cond_temb=True,
        ip_adapter_cross_attn=True,
        t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
        need_t2i_facein=need_t2i_facein,
        strict=strict,
        need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
        **shared_kwargs,
    )