from __future__ import annotations

import logging
from typing import Dict, Optional, Union

import torch
from einops import rearrange
from torch import nn

from ..data.data_util import align_repeat_tensor_single_dim
from .referencenet import ReferenceNet2D
from .unet_3d_condition import UNet3DConditionModel

logger = logging.getLogger(__name__)


class SuperUNet3DConditionModel(nn.Module):
"""封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。 | |
主要作用 | |
1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些; | |
2. 便于 accelerator 的分布式训练; | |
wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj | |
1. support controlnet, referencenet, etc. | |
2. support accelerator distributed training | |
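
    Example (a minimal sketch; the sub-modules are assumed to be constructed
    elsewhere, e.g. via their own `from_pretrained` helpers):
        >>> model = SuperUNet3DConditionModel(
        ...     unet=unet,
        ...     referencenet=referencenet,
        ...     ip_adapter_image_proj=image_proj,
        ... )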
""" | |
    _supports_gradient_checkpointing = True
    print_idx = 0

    def __init__(
        self,
        unet: nn.Module,
        referencenet: Optional[nn.Module] = None,
        controlnet: Optional[nn.Module] = None,
        vae: Optional[nn.Module] = None,
        text_encoder: Optional[nn.Module] = None,
        tokenizer: Optional[nn.Module] = None,
        text_emb_extractor: Optional[nn.Module] = None,
        clip_vision_extractor: Optional[nn.Module] = None,
        ip_adapter_image_proj: Optional[nn.Module] = None,
    ) -> None:
"""_summary_ | |
Args: | |
unet (nn.Module): _description_ | |
referencenet (nn.Module, optional): _description_. Defaults to None. | |
controlnet (nn.Module, optional): _description_. Defaults to None. | |
vae (nn.Module, optional): _description_. Defaults to None. | |
text_encoder (nn.Module, optional): _description_. Defaults to None. | |
tokenizer (nn.Module, optional): _description_. Defaults to None. | |
text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None. | |
clip_vision_extractor (nn.Module, optional): _description_. Defaults to None. | |
""" | |
        super().__init__()
        self.unet = unet
        self.referencenet = referencenet
        self.controlnet = controlnet
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.text_emb_extractor = text_emb_extractor
        self.clip_vision_extractor = clip_vision_extractor
        self.ip_adapter_image_proj = ip_adapter_image_proj

    def forward(
        self,
        unet_params: Dict,
        encoder_hidden_states: torch.Tensor,
        referencenet_params: Optional[Dict] = None,
        controlnet_params: Optional[Dict] = None,
        controlnet_scale: float = 1.0,
        vision_clip_emb: Union[torch.Tensor, None] = None,
        prompt_only_use_image_prompt: bool = False,
    ):
"""_summary_ | |
Args: | |
unet_params (Dict): _description_ | |
encoder_hidden_states (torch.Tensor): b t n d | |
referencenet_params (Dict, optional): _description_. Defaults to None. | |
controlnet_params (Dict, optional): _description_. Defaults to None. | |
controlnet_scale (float, optional): _description_. Defaults to 1.0. | |
vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None. | |
prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False. | |
Returns: | |
_type_: _description_ | |
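
        Example (illustrative sketch; `model`, `sample`, `timesteps`, `text_emb`
        and `clip_emb` are assumed to be built elsewhere):
            >>> out = model(
            ...     unet_params={"sample": sample, "timestep": timesteps},
            ...     encoder_hidden_states=text_emb,  # b t n d
            ...     vision_clip_emb=clip_emb,        # b t d
            ... )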
""" | |
        batch_size = unet_params["sample"].shape[0]
        time_size = unet_params["sample"].shape[2]

        # ip_adapter cross attention: prepare the image prompt
        if vision_clip_emb is not None:
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )
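            # normalize the embedding to b t n d: add a token axis if it is
            # missing, then (optionally) project it with the IP-Adapter image proj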
            if vision_clip_emb.ndim == 3:
                vision_clip_emb = rearrange(vision_clip_emb, "b t d -> b t 1 d")
            if self.ip_adapter_image_proj is not None:
                vision_clip_emb = rearrange(vision_clip_emb, "b t n d -> (b t) n d")
                vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
                if self.print_idx == 0:
                    logger.debug(
                        f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                    )
                if vision_clip_emb.ndim == 2:
                    vision_clip_emb = rearrange(vision_clip_emb, "b d -> b 1 d")
                vision_clip_emb = rearrange(
                    vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
                )
            vision_clip_emb = align_repeat_tensor_single_dim(
                vision_clip_emb, target_length=time_size, dim=1
            )
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )
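        # fall back to whichever of the text / image embedding is available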
        if vision_clip_emb is None and encoder_hidden_states is not None:
            vision_clip_emb = encoder_hidden_states
        if vision_clip_emb is not None and encoder_hidden_states is None:
            encoder_hidden_states = vision_clip_emb
        # When prompt_only_use_image_prompt is True:
        # 1. referencenet always uses vision_clip_emb;
        # 2. unet uses vision_clip_emb if it has no dual cross attention, and is sometimes not updated;
        # 3. controlnet currently uses the text prompt.
        # extract referencenet embeddings
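        # referencenet is always conditioned on the image (CLIP) embedding,
        # repeated to num_frames and flattened to (b t) n d before the call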
        if self.referencenet is not None and referencenet_params is not None:
            referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
                vision_clip_emb,
                target_length=referencenet_params["num_frames"],
                dim=1,
            )
            referencenet_params["encoder_hidden_states"] = rearrange(
                referencenet_encoder_hidden_states, "b t n d -> (b t) n d"
            )
            referencenet_out = self.referencenet(**referencenet_params)
            (
                down_block_refer_embs,
                mid_block_refer_emb,
                refer_self_attn_emb,
            ) = referencenet_out
            if down_block_refer_embs is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
                    )
                for i, down_emb in enumerate(down_block_refer_embs):
                    if self.print_idx == 0:
                        logger.debug(
                            f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug("down_block_refer_embs is None")
            if mid_block_refer_emb is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
                    )
            else:
                if self.print_idx == 0:
                    logger.debug("mid_block_refer_emb is None")
            if refer_self_attn_emb is not None:
                if self.print_idx == 0:
                    logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
                for i, self_attn_emb in enumerate(refer_self_attn_emb):
                    if self.print_idx == 0:
                        logger.debug(
                            f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug("refer_self_attn_emb is None")
        else:
            down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
                None,
                None,
                None,
            )
        # extract controlnet residuals
        if self.controlnet is not None and controlnet_params is not None:
            controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
                encoder_hidden_states,
                target_length=unet_params["sample"].shape[2],
                dim=1,
            )
            controlnet_params["encoder_hidden_states"] = rearrange(
                controlnet_encoder_hidden_states, "b t n d -> (b t) n d"
            )
            (
                down_block_additional_residuals,
                mid_block_additional_residual,
            ) = self.controlnet(**controlnet_params)
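            # scale the controlnet residuals to control conditioning strength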
            if controlnet_scale != 1.0:
                down_block_additional_residuals = [
                    x * controlnet_scale for x in down_block_additional_residuals
                ]
                mid_block_additional_residual = (
                    mid_block_additional_residual * controlnet_scale
                )
            for i, down_block_additional_residual in enumerate(
                down_block_additional_residuals
            ):
                if self.print_idx == 0:
                    logger.debug(
                        f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
                    )
            if self.print_idx == 0:
                logger.debug(
                    f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
                )
        else:
            down_block_additional_residuals = None
            mid_block_additional_residual = None
        if prompt_only_use_image_prompt and vision_clip_emb is not None:
            encoder_hidden_states = vision_clip_emb

        # run unet
        out = self.unet(
            **unet_params,
            down_block_refer_embs=down_block_refer_embs,
            mid_block_refer_emb=mid_block_refer_emb,
            refer_self_attn_emb=refer_self_attn_emb,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            encoder_hidden_states=encoder_hidden_states,
            vision_clip_emb=vision_clip_emb,
        )
        self.print_idx += 1
        return out
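
    # Hook in the style of diffusers' ModelMixin._set_gradient_checkpointing:
    # only the wrapped unet and referencenet support gradient checkpointing.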
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
module.gradient_checkpointing = value | |