Spaces:
No application file
No application file
import copy | |
from typing import Any, Callable, Dict, Iterable, Union | |
import PIL | |
import cv2 | |
import torch | |
import argparse | |
import datetime | |
import logging | |
import inspect | |
import math | |
import os | |
import shutil | |
from typing import Dict, List, Optional, Tuple | |
from pprint import pprint | |
from collections import OrderedDict | |
from dataclasses import dataclass | |
import gc | |
import time | |
import numpy as np | |
from omegaconf import OmegaConf | |
from omegaconf import SCMode | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from einops import rearrange, repeat | |
import pandas as pd | |
import h5py | |
from diffusers.models.modeling_utils import load_state_dict | |
from diffusers.utils import ( | |
logging, | |
) | |
from diffusers.utils.import_utils import is_xformers_available | |
from mmcm.vision.feature_extractor.clip_vision_extractor import ( | |
ImageClipVisionFeatureExtractor, | |
ImageClipVisionFeatureExtractorV2, | |
) | |
from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor | |
from ip_adapter.resampler import Resampler | |
from ip_adapter.ip_adapter import ImageProjModel | |
from .unet_loader import update_unet_with_sd | |
from .unet_3d_condition import UNet3DConditionModel | |
from .ip_adapter_loader import ip_adapter_keys_list | |
# Module-level logger. NOTE: `logging` here is diffusers' logging wrapper
# (imported above), which shadows the stdlib `logging` module imported earlier.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
# UNet state-dict keys of the FaceIn decoupled cross-attention projections
# (facein_to_k_ip / facein_to_v_ip), enumerated in state-dict order:
# down_blocks 0-2 carry two attentions each, up_blocks 1-3 carry three each,
# and the mid block carries one; each attention contributes a k then a v key.
unet_keys_list = [
    f"{attn}.transformer_blocks.0.attn2.processor.facein_to_{kv}_ip.weight"
    for attn in (
        [f"down_blocks.{blk}.attentions.{idx}" for blk in range(3) for idx in range(2)]
        + [f"up_blocks.{blk}.attentions.{idx}" for blk in range(1, 4) for idx in range(3)]
        + ["mid_block.attentions.0"]
    )
    for kv in ("k", "v")
]
# Positional mapping from each UNet FaceIn cross-attention weight key to the
# corresponding key in the ip_adapter checkpoint (the two lists are aligned).
UNET2IPAadapter_Keys_MAPIING = dict(zip(unet_keys_list, ip_adapter_keys_list))
def load_facein_extractor_and_proj_by_name(
    model_name: str,
    ip_ckpt: Tuple[str, nn.Module],
    ip_image_encoder: Optional[Tuple[str, nn.Module]] = None,
    cross_attention_dim: int = 768,
    clip_embeddings_dim: int = 512,
    clip_extra_context_tokens: int = 1,
    ip_scale: float = 0.0,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    unet: Optional[nn.Module] = None,
) -> nn.Module:
    """Load a FaceIn face-feature extractor and its projection module by name.

    NOTE(review): unimplemented stub — despite the ``-> nn.Module`` annotation
    it currently returns ``None``; confirm callers guard against that before
    relying on the return value.

    Args:
        model_name: identifier selecting which extractor/projection to build.
        ip_ckpt: checkpoint path or an already-loaded module holding the
            ip_adapter weights.
        ip_image_encoder: optional path or module for the image encoder.
        cross_attention_dim: UNet cross-attention feature dimension.
        clip_embeddings_dim: dimension of the face/CLIP embedding input.
        clip_extra_context_tokens: number of extra context tokens to project.
        ip_scale: scale applied to the ip_adapter attention contribution.
        dtype: torch dtype for the loaded modules.
        device: device string the modules should be placed on.
        unet: optional UNet whose attention processors should be updated.
    """
    pass
def update_unet_facein_cross_attn_param(
    unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
) -> None:
    """Use the to_k/to_v weights of an independent ip_adapter attn in the UNet.

    ``ip_adapter_state_dict`` is a dict with keys like
    ``['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']``.

    NOTE(review): unimplemented stub — performs no update yet.

    Args:
        unet (UNet3DConditionModel): target UNet whose FaceIn cross-attention
            parameters should be overwritten.
        ip_adapter_state_dict (Dict): source ip_adapter weights keyed as above.
    """
    pass