import os
import warnings
from typing import Any, Callable, Dict, List, Literal, Tuple, Union

import controlnet_aux
import cv2
import h5py
import numpy as np
import torch
from controlnet_aux.dwpose import draw_pose
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from einops import rearrange
from PIL import Image
from torch import nn

from ..process.image_process import dynamic_crop_resize_image
from ..utils.data_type_util import convert_images
from ...data.emb.h5py_emb import save_value_with_h5py
from ...data.extract_feature.base_extract_feature import BaseFeatureExtractor


def json_serializer(obj):
    """`default=` hook for json.dumps that converts numpy arrays to lists."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


def controlnet_tile_processor(img, **kwargs):
    """The tile controlnet conditions on the input image as-is, so this is an identity processor."""
    return img
def size_calculate(H, W, resolution):
    """Scale (H, W) so the short side equals `resolution`, rounding both sides to multiples of 64."""
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    return H, W
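

# A minimal worked example (values computed by hand, not part of the original
# module): a 720x1280 frame scaled so the short side is 512 becomes 512x896,
# since 1280 * (512 / 720) is roughly 910, which rounds to the nearest
# multiple of 64, i.e. 896.
#
#   >>> size_calculate(720, 1280, 512)
#   (512, 896)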
def HWC3(x):
    """Convert a uint8 image to 3-channel HWC: gray is tiled, RGBA is alpha-blended onto white."""
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y
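

# Quick sanity checks (illustrative only): a single-channel image becomes
# 3-channel, and a fully transparent RGBA image blends to pure white.
#
#   >>> HWC3(np.zeros((64, 64), dtype=np.uint8)).shape
#   (64, 64, 3)
#   >>> HWC3(np.zeros((64, 64, 4), dtype=np.uint8)).max()
#   255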
def pose2map(pose, H_in, W_in, detect_resolution, image_resolution):
    """Rasterize keypoints to an RGB pose map at `detect_resolution`, then resize it to `image_resolution`."""
    H, W = size_calculate(H_in, W_in, detect_resolution)
    detected_map = draw_pose(pose, H, W)
    detected_map = HWC3(detected_map)
    H, W = size_calculate(H, W, image_resolution)
    detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
    return detected_map
def candidate2pose(
    candidate,
    subset,
    include_body: bool = True,
    include_face: bool = False,
    hand_and_face: bool = None,
    include_hand: bool = True,
):
    """Split DWpose keypoints (134 per person: 18 body, 6 foot, 68 face, 2x21 hands) into the dict `draw_pose` expects."""
    if hand_and_face is not None:
        # legacy flag that toggles face and hands together (mirrors controlnet_aux)
        include_face = hand_and_face
        include_hand = hand_and_face
    nums, keys, locs = candidate.shape
    body = candidate[:, :18].copy()
    body = body.reshape(nums * 18, locs)
    score = subset[:, :18]
    # map confident body keypoints to their flattened index; mark the rest invisible
    for i in range(len(score)):
        for j in range(len(score[i])):
            if score[i][j] > 0.3:
                score[i][j] = int(18 * i + j)
            else:
                score[i][j] = -1
    un_visible = subset < 0.3
    candidate[un_visible] = -1
    foot = candidate[:, 18:24]
    faces = candidate[:, 24:92]
    hands = candidate[:, 92:113]
    hands = np.vstack([hands, candidate[:, 113:]])
    bodies = dict(candidate=body, subset=score)
    if not include_body:
        bodies = []
    if not include_face:
        faces = []
    if not include_hand:
        hands = []
    pose = dict(bodies=bodies, hands=hands, faces=faces)
    return pose
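

# Hedged end-to-end sketch: turn one person's raw DWpose output into a pose
# image. The shapes follow the comments elsewhere in this module ((1, 134, 2)
# coordinates plus (1, 134) scores); the random inputs are placeholders, not
# real detections.
def _example_candidate_to_map():
    candidate = np.random.rand(1, 134, 2)  # normalized x, y per keypoint
    subset = np.random.rand(1, 134)  # confidence per keypoint
    pose = candidate2pose(candidate, subset, include_body=True, include_hand=True)
    # draw at a 512 short side, then resize the map to a 512 short side
    return pose2map(pose, H_in=720, W_in=1280, detect_resolution=512, image_resolution=512)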
class ControlnetProcessor(object):
    """Wraps a `controlnet_aux` detector (or the tile identity processor) behind a uniform interface."""

    def __init__(
        self,
        detector_name: str,
        detector_id: str = None,
        filename: str = None,
        cache_dir: str = None,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        processor_params: Dict = None,
        processor_name: str = None,
    ) -> None:
        self.detector_name = detector_name
        self.detector_id = detector_id
        self.processor_name = processor_name
        if detector_name is None:
            self.processor = None
            self.processor_params = {}
            if isinstance(processor_name, str) and "tile" in processor_name:
                self.processor = controlnet_tile_processor
        else:
            processor_cls = controlnet_aux.__dict__[detector_name]
            self.processor_params = (
                processor_params if processor_params is not None else {}
            )
            if not hasattr(processor_cls, "from_pretrained"):
                self.processor = processor_cls()
            else:
                self.processor = processor_cls.from_pretrained(
                    detector_id,
                    cache_dir=cache_dir,
                    filename=filename,
                    **self.processor_params,
                )
            if hasattr(self.processor, "to"):
                self.processor = self.processor.to(device=device)
        self.device = device
        self.dtype = dtype
    def __call__(
        self,
        data: Union[
            Image.Image, List[Image.Image], str, List[str], np.ndarray, torch.Tensor
        ],
        data_channel_order: str,
        target_width: int = None,
        target_height: int = None,
        return_type: Literal["pil", "np", "torch"] = "np",
        return_data_channel_order: str = "b h w c",
        processor_params: Dict = None,
        input_rgb_order: str = "rgb",
        return_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        # TODO: per-call params currently replace the stored ones wholesale;
        # this could be improved to an incremental (merged) update.
        processor_params = processor_params if processor_params is not None else {}
        data = convert_images(
            data,
            return_type="pil",
            input_rgb_order=input_rgb_order,
            return_rgb_order=return_rgb_order,
            data_channel_order=data_channel_order,
        )
        height, width = data[0].height, data[0].width
        if target_width is None:
            target_width = width
        if target_height is None:
            target_height = height
        data = [
            dynamic_crop_resize_image(
                image, target_height=target_height, target_width=target_width
            )
            for image in data
        ]
        if self.processor is not None:
            data = [self.processor(image, **processor_params) for image in data]
        # return_pose_only (bool): if true, only return pose keypoints in array format
        if "return_pose_only" in processor_params:
            if (
                self.detector_name == "DWposeDetector"
                and processor_params["return_pose_only"]
            ):
                item_list = []
                for candidate, subset in data:
                    # candidate shape (1, 134, 2), subset shape (1, 134)
                    subset = np.expand_dims(subset, -1)
                    item = np.concatenate([candidate, subset], -1)
                    # pad / truncate to a fixed candidate count so items stack
                    max_num = 20
                    if item.shape[0] > max_num:
                        item = item[:max_num]
                    if item.shape[0] < max_num:
                        pad_num = max_num - item.shape[0]
                        item = np.pad(item, ((0, pad_num), (0, 0), (0, 0)))
                    item_list.append(item)
                return np.stack(item_list, axis=0)  # b, num_candidates, 134, 3
        if return_type == "pil":
            return data
        data = np.stack([np.asarray(image) for image in data], axis=0)
        if return_data_channel_order != "b h w c":
            data = rearrange(data, "b h w c -> {}".format(return_data_channel_order))
        if return_type == "np":
            return data
        if return_type == "torch":
            data = torch.from_numpy(data)
        return data
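

# Hedged usage sketch: run the Canny detector over a couple of frames. The
# detector holds no pretrained weights, so no download is needed; the blank
# PIL images are placeholders, and `convert_images` is assumed to accept a
# list of PIL images for this input/output combination.
def _example_canny_processor():
    processor = ControlnetProcessor(detector_name="CannyDetector", processor_name="canny")
    frames = [Image.new("RGB", (512, 512)) for _ in range(2)]
    cond = processor(
        frames,
        data_channel_order="b h w c",
        return_type="np",
        return_data_channel_order="b h w c",
    )
    return cond  # stacked condition maps, one per input frame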
class MultiControlnetProcessor(object):
    """Applies several `ControlnetProcessor`s to the same input and returns one result per processor."""

    def __init__(self, processors: List[ControlnetProcessor]) -> None:
        self.processors = processors

    def __call__(
        self,
        data: Union[
            Image.Image, List[Image.Image], str, List[str], np.ndarray, torch.Tensor
        ],
        data_channel_order: str,
        target_width: int = None,
        target_height: int = None,
        return_type: Literal["pil", "np", "torch"] = "np",
        return_data_channel_order: str = "b h w c",
        processor_params: List[Dict] = None,
        input_rgb_order: str = "rgb",
        return_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        if processor_params is None:
            processor_params = [None] * len(self.processors)
        assert isinstance(
            processor_params, list
        ), f"type of processor_params should be list, but given {type(processor_params)}"
        assert len(processor_params) == len(
            self.processors
        ), f"length of processor_params ({len(processor_params)}) should equal the number of processors ({len(self.processors)})"
        datas = [
            processor(
                data=data,
                data_channel_order=data_channel_order,
                target_height=target_height,
                target_width=target_width,
                return_type=return_type,
                return_data_channel_order=return_data_channel_order,
                input_rgb_order=input_rgb_order,
                return_rgb_order=return_rgb_order,
                processor_params=processor_params[i],
            )
            for i, processor in enumerate(self.processors)
        ]
        return datas
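

# Hedged sketch: pair a Canny and a depth processor on the same frames. The
# Midas weights come from `lllyasviel/Annotators` (the repo used by the
# parameter table below); everything else mirrors the single-processor
# example above.
def _example_multi_processor(frames):
    multi = MultiControlnetProcessor(
        [
            ControlnetProcessor(detector_name="CannyDetector", processor_name="canny"),
            ControlnetProcessor(
                detector_name="MidasDetector",
                detector_id="lllyasviel/Annotators",
                processor_name="depth",
            ),
        ]
    )
    # one conditioning batch per processor, aligned with the processor order
    return multi(frames, data_channel_order="b h w c", return_type="np")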
class ControlnetFeatureExtractor(BaseFeatureExtractor):
    """Extracts controlnet conditioning embeddings from images or videos and optionally caches them with h5py."""

    def __init__(
        self,
        model_path: str,
        detector_name: str,
        detector_id: str,
        device: str = "cpu",
        dtype=torch.float32,
        name: str = None,
        # number of entries in the VAE config's block_out_channels, see e.g.
        # /group/30065/users/public/muse/models/stable-diffusion-v1-5/vae/config.json
        vae_config_block_out_channels: int = 4,
        processor_params: Dict = None,
        filename=None,
        cache_dir: str = None,
    ):
        super().__init__(device, dtype, name)
        self.model_path = model_path
        self.processor = ControlnetProcessor(
            detector_name=detector_name,
            detector_id=detector_id,
            filename=filename,
            cache_dir=cache_dir,
            device=device,
            dtype=dtype,
            processor_params=processor_params,
        )
        self.vae_scale_factor = 2 ** (vae_config_block_out_channels - 1)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )
        self.controlnet = ControlNetModel.from_pretrained(
            model_path,
        ).to(device=device, dtype=dtype)
        self.detector_name = detector_name
    def emb_name(self, width, height):
        return "{}_w={}_h={}_emb".format(self.name, width, height)

    def prepare_image(
        self,
        image,  # b c t h w
        width,
        height,
    ):
        """Flatten a video batch to (b*t, c, h, w) and rescale pixel values to [0, 1]."""
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image)
        if image.ndim == 5:
            image = rearrange(image, "b c t h w-> (b t) c h w")
        if height is None:
            height = image.shape[-2]
        if width is None:
            width = image.shape[-1]
        width, height = (
            x - x % self.control_image_processor.vae_scale_factor
            for x in (width, height)
        )
        image = image / 255.0
        do_normalize = self.control_image_processor.config.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range is [0,1] "
                f"for torch tensor or numpy array inputs. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.control_image_processor.normalize(image)
        return image
    def extract_images(
        self,
        data: Union[str, List[str], Image.Image, List[Image.Image], np.ndarray],
        target_width: int = None,
        target_height: int = None,
        return_type: str = "numpy",
        data_channel_order: str = "b h w c",
        processor_params: Dict = None,
        input_rgb_order: str = "rgb",
        return_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        data = self.processor(
            data,
            data_channel_order=data_channel_order,
            target_height=target_height,
            target_width=target_width,
            return_type="torch",
            processor_params=processor_params,
            return_data_channel_order="b c h w",
            input_rgb_order=input_rgb_order,
            return_rgb_order=return_rgb_order,
        )
        # return_pose_only (bool): if true, only return pose keypoints in array format
        if processor_params is not None and "return_pose_only" in processor_params:
            if (
                self.detector_name == "DWposeDetector"
                and processor_params["return_pose_only"]
            ):
                return data
        batch = self.prepare_image(image=data, width=target_width, height=target_height)
        with torch.no_grad():
            batch = batch.to(self.device, dtype=self.dtype)
            emb = self.controlnet.controlnet_cond_embedding(batch)
        if return_type == "numpy":
            emb = emb.cpu().numpy()
        return emb
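
    # Hedged sketch of the pose-only path: with a DWposeDetector-backed
    # extractor and `return_pose_only=True`, `extract_images` skips the
    # controlnet embedding and returns padded keypoint arrays of shape
    # (b, 20, 134, 3): x, y, score per keypoint, up to 20 candidates.
    # The model ids are illustrative, not prescribed by this module.
    #
    #   extractor = ControlnetFeatureExtractor(
    #       model_path="lllyasviel/control_v11p_sd15_openpose",
    #       detector_name="DWposeDetector",
    #       detector_id="lllyasviel/Annotators",
    #   )
    #   kps = extractor.extract_images(
    #       frames, processor_params={"return_pose_only": True}
    #   )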
    def extract_video(
        self,
        video_dataset,
        target_width: int = None,
        target_height: int = None,
        return_type: str = "numpy",
        processor_params: Dict = None,
        input_rgb_order: str = "rgb",
        return_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        embs = []
        sample_indexs = []
        with torch.no_grad():
            for batch, batch_index in video_dataset:
                emb = self.extract_images(
                    data=batch,
                    target_width=target_width,
                    target_height=target_height,
                    return_type=return_type,
                    processor_params=processor_params,
                    input_rgb_order=input_rgb_order,
                    return_rgb_order=return_rgb_order,
                )
                embs.append(emb)
                sample_indexs.extend(batch_index)
        sample_indexs = np.array(sample_indexs)
        # return_pose_only (bool): if true, only return pose keypoints in array format
        if processor_params is not None and "return_pose_only" in processor_params:
            if (
                self.detector_name == "DWposeDetector"
                and processor_params["return_pose_only"]
            ):
                embs = np.concatenate(embs, axis=0)
                return sample_indexs, embs
        if return_type == "numpy":
            embs = np.concatenate(embs, axis=0)
        elif return_type == "torch":
            embs = torch.concat(embs, dim=0)
            sample_indexs = torch.from_numpy(sample_indexs)
        return sample_indexs, embs
    def extract(
        self,
        data: Union[str, List[str]],
        data_type: Literal["image", "video"],
        return_type: str = "numpy",
        save_emb_path: str = None,
        save_type: str = "h5py",
        emb_key: str = "emb",
        sample_index_key: str = "sample_indexs",
        insert_name_to_key: bool = False,
        overwrite: bool = False,
        target_width: int = None,
        target_height: int = None,
        save_sample_index: bool = True,
        processor_params: Dict = None,
        input_rgb_order: str = "rgb",
        return_rgb_order: str = "rgb",
        **kwargs,
    ) -> Union[np.ndarray, torch.Tensor]:
        if self.name is not None and insert_name_to_key:
            emb_key = f"{self.name}_{emb_key}"
            sample_index_key = f"{self.name}_{sample_index_key}"
        if save_emb_path is not None and os.path.exists(save_emb_path):
            with h5py.File(save_emb_path, "r") as f:
                if not overwrite and emb_key in f and sample_index_key in f:
                    return None
        if data_type == "image":
            emb = self.extract_images(
                data=data,
                return_type=return_type,
                target_height=target_height,
                target_width=target_width,
                processor_params=processor_params,
                input_rgb_order=input_rgb_order,
                return_rgb_order=return_rgb_order,
            )
            if save_emb_path is None:
                return emb
            else:
                raise NotImplementedError("save images emb")
        elif data_type == "video":
            sample_indexs, emb = self.extract_video(
                video_dataset=data,
                return_type=return_type,
                processor_params=processor_params,
                input_rgb_order=input_rgb_order,
                return_rgb_order=return_rgb_order,
                target_height=target_height,
                target_width=target_width,
                **kwargs,
            )
            if save_emb_path is None:
                return sample_indexs, emb
            else:
                if save_type == "h5py":
                    self.save_video_emb_with_h5py(
                        save_emb_path=save_emb_path,
                        emb=emb,
                        emb_key=emb_key,
                        sample_indexs=sample_indexs,
                        sample_index_key=sample_index_key,
                        save_sample_index=save_sample_index,
                        overwrite=overwrite,
                    )
                    return sample_indexs, emb
                else:
                    raise ValueError(f"only support save_type={save_type}")
    def save_video_emb_with_h5py(
        self,
        save_emb_path: str,
        emb: np.ndarray = None,
        emb_key: str = "emb",
        sample_indexs: np.ndarray = None,
        sample_index_key: str = "sample_indexs",
        overwrite: bool = False,
        save_sample_index: bool = True,
    ) -> h5py.File:
        save_value_with_h5py(save_emb_path, value=emb, key=emb_key, overwrite=overwrite)
        if save_sample_index:
            save_value_with_h5py(
                save_emb_path,
                value=sample_indexs,
                key=sample_index_key,
                overwrite=overwrite,
                dtype=np.uint32,
            )
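

# Hedged end-to-end sketch: extract canny embeddings for a video and cache
# them under "canny_emb" in an h5py file. `video_dataset` is assumed to be any
# iterable yielding (frames, frame_indices) pairs, matching what
# `extract_video` consumes; the paths and ids are placeholders.
#
#   extractor = ControlnetFeatureExtractor(
#       model_path="lllyasviel/control_v11p_sd15_canny",
#       detector_name="CannyDetector",
#       detector_id=None,
#       name="canny",
#   )
#   sample_indexs, emb = extractor.extract(
#       video_dataset,
#       data_type="video",
#       save_emb_path="features.h5",
#       insert_name_to_key=True,
#   )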
def get_controlnet_params(
    controlnet_names: Union[
        Literal[
            "pose",
            "pose_body",
            "pose_hand",
            "pose_face",
            "pose_hand_body",
            "pose_hand_face",
            "pose_all",
            "dwpose",
            "canny",
            "hed",
            "hed_scribble",
            "depth",
            "pidi",
            "normal_bae",
            "lineart",
            "lineart_anime",
            "zoe",
            "sam",
            "mobile_sam",
            "leres",
            "content",
            "face_detector",
        ],
        List[str],
    ],
    detect_resolution: int = None,
    image_resolution: int = None,
    include_body: bool = False,
    include_hand: bool = False,
    include_face: bool = False,
    hand_and_face: bool = None,
) -> Dict:
    """Select a fully configured controlnet parameter set from a simple string name.

    Args:
        controlnet_names: one of the literal names above, or a list of them.
        detect_resolution (int, optional): resolution controlnet_aux uses for detection; preferably a multiple of 64. Defaults to None.
        image_resolution (int, optional): resolution of the returned condition image; preferably a multiple of 64. Defaults to None.
        include_body (bool, optional): whether the pose controlnet includes the body. Defaults to False.
        include_hand (bool, optional): whether the pose controlnet includes the hands. Defaults to False.
        include_face (bool, optional): whether the pose controlnet includes the face. Defaults to False.
        hand_and_face (bool, optional): legacy flag toggling hands and face together. Defaults to None.

    Returns:
        Dict: the keyword parameters needed by ControlnetProcessor (or a list of them).
    """
    controlnet_cond_maps = {
        "pose": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": include_body,
                "include_hand": include_hand,
                "include_face": include_face,
                "hand_and_face": hand_and_face,
            },
        },
        "pose_body": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": True,
                "include_hand": False,
                "include_face": False,
                "hand_and_face": False,
            },
        },
        "pose_hand": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": False,
                "include_hand": True,
                "include_face": False,
                "hand_and_face": False,
            },
        },
        "pose_face": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": False,
                "include_hand": False,
                "include_face": True,
                "hand_and_face": False,
            },
        },
        "pose_hand_body": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": True,
                "include_hand": True,
                "include_face": False,
                "hand_and_face": False,
            },
        },
        "pose_hand_face": {
            "middle": "pose",
            "detector_name": "OpenposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": False,
                "include_hand": True,
                "include_face": True,
                "hand_and_face": True,
            },
        },
        "dwpose": {
            "middle": "dwpose",
            "detector_name": "DWposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
            },
        },
        "dwpose_face": {
            "middle": "dwpose",
            "detector_name": "DWposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_hand": False,
                "include_body": False,
            },
        },
        "dwpose_hand": {
            "middle": "dwpose",
            "detector_name": "DWposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_face": False,
                "include_body": False,
            },
        },
        "dwpose_body": {
            "middle": "dwpose",
            "detector_name": "DWposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_face": False,
                "include_hand": False,
            },
        },
        "dwpose_body_hand": {
            "middle": "dwpose",
            "detector_name": "DWposeDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_face": False,
                "include_hand": True,
                "include_body": True,
            },
        },
        "canny": {
            "middle": "canny",
            "detector_name": "CannyDetector",
            # "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_canny",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
            },
        },
        "tile": {
            "middle": "tile",
            "detector_name": None,
            "detector_id": None,
            "controlnet_model_path": "lllyasviel/control_v11f1e_sd15_tile",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
                "include_body": include_body,
                "hand_and_face": hand_and_face,
            },
        },
        # line/edge-style detectors
        "hed": {
            "middle": "hed",
            "detector_name": "HEDdetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/sd-controlnet-hed",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
            },
        },
        "hed_scribble": {
            "middle": "hed",
            "detector_name": "HEDdetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11p_sd15_scribble",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
            },
        },
        "depth": {
            "middle": "depth",
            "detector_name": "MidasDetector",
            "detector_id": "lllyasviel/Annotators",
            "controlnet_model_path": "lllyasviel/control_v11f1p_sd15_depth",
            "processor_params": {
                "detect_resolution": detect_resolution,
                "image_resolution": image_resolution,
            },
        },
"pidi": { | |
"middle": "pidi", | |
"detector_name": "PidiNetDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11f1p_sd15_depth", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"normal_bae": { | |
"middle": "normal_bae", | |
"detector_name": "NormalBaeDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15_normalbae", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"lineart": { | |
"middle": "lineart", | |
"detector_name": "LineartDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15_lineart", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
"coarse": True, | |
}, | |
}, | |
"lineart_anime": { | |
"middle": "lineart_anime", | |
"detector_name": "LineartAnimeDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15s2_lineart_anime", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"zoe": { | |
"middle": "zoe", | |
"detector_name": "ZoeDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11f1p_sd15_depth", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"sam": { | |
"middle": "sam", | |
"detector_name": "SamDetector", | |
"detector_id": "ybelkada/segment-anything", | |
"processor_cls_params": {"subfolder": "checkpoints"}, | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15_seg", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"mobile_sam": { | |
"middle": "mobile_sam", | |
"detector_name": "SamDetector", | |
"detector_id": "dhkim2810/MobileSAM", | |
"processor_cls_params": { | |
"subfolder": "checkpoints", | |
"model_type": "vit_t", | |
"filename": "mobile_sam.pt", | |
}, | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15_seg", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"leres": { | |
"middle": "leres", | |
"detector_name": "LeresDetector", | |
"detector_id": "lllyasviel/Annotators", | |
"controlnet_model_path": "lllyasviel/control_v11f1p_sd15_depth", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
# error | |
"content": { | |
"middle": "content", | |
"detector_name": "ContentShuffleDetector", | |
"controlnet_model_path": "lllyasviel/control_v11e_sd15_shuffle", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
}, | |
"face_detector": { | |
"middle": "face_detector", | |
"detector_name": "MediapipeFaceDetector", | |
"processor_params": { | |
"detect_resolution": detect_resolution, | |
"image_resolution": image_resolution, | |
}, | |
"controlnet_model_path": "lllyasviel/control_v11p_sd15_openpose", | |
}, | |
} | |
def complete(dct): | |
if "detector_id" not in dct: | |
dct["detector_id"] = None | |
if "processor_cls_params" not in dct: | |
dct["processor_cls_params"] = None | |
return dct | |
if isinstance(controlnet_names, str): | |
return complete(controlnet_cond_maps[controlnet_names]) | |
else: | |
params = [complete(controlnet_cond_maps[name]) for name in controlnet_names] | |
return params | |
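

# Hedged usage note: the helper returns ready-to-splat keyword sets, e.g.
#
#   params = get_controlnet_params("canny", detect_resolution=512, image_resolution=512)
#   # {'middle': 'canny', 'detector_name': 'CannyDetector',
#   #  'controlnet_model_path': 'lllyasviel/control_v11p_sd15_canny',
#   #  'processor_params': {...}, 'detector_id': None, 'processor_cls_params': None}
#
# Passing a list of names returns a list of such dicts, which is what
# `load_controlnet_model` feeds into MultiControlnetProcessor below.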
def load_controlnet_model(
    controlnet_names: Union[str, List[str]],
    device: str,
    dtype: torch.dtype = torch.float32,
    need_controlnet_processor: bool = True,
    need_controlnet: bool = True,
    detect_resolution: int = None,
    image_resolution: int = None,
    include_body: bool = False,
    include_face: bool = False,
    include_hand: bool = False,
    hand_and_face: bool = None,
) -> Tuple[nn.Module, Callable, Dict]:
    controlnet_params = get_controlnet_params(
        controlnet_names,
        detect_resolution=detect_resolution,
        image_resolution=image_resolution,
        include_body=include_body,
        include_face=include_face,
        hand_and_face=hand_and_face,
        include_hand=include_hand,
    )
    if need_controlnet_processor:
        if not isinstance(controlnet_params, list):
            controlnet_processor = ControlnetProcessor(
                detector_name=controlnet_params["detector_name"],
                detector_id=controlnet_params["detector_id"],
                processor_params=controlnet_params["processor_cls_params"],
                device=device,
                dtype=dtype,
                processor_name=controlnet_params["middle"],
            )
            processor_params = controlnet_params["processor_params"]
        else:
            controlnet_processor = MultiControlnetProcessor(
                [
                    ControlnetProcessor(
                        detector_name=controlnet_param["detector_name"],
                        detector_id=controlnet_param["detector_id"],
                        processor_params=controlnet_param["processor_cls_params"],
                        device=device,
                        dtype=dtype,
                        processor_name=controlnet_param["middle"],
                    )
                    for controlnet_param in controlnet_params
                ]
            )
            processor_params = [
                controlnet_param["processor_params"]
                for controlnet_param in controlnet_params
            ]
    else:
        controlnet_processor = None
        processor_params = None
    if need_controlnet:
        if isinstance(controlnet_params, list):
            # TODO: support MultiControlNetModel.save_pretrained str path
            controlnet = MultiControlNetModel(
                [
                    ControlNetModel.from_pretrained(d["controlnet_model_path"])
                    for d in controlnet_params
                ]
            )
        else:
            controlnet_model_path = controlnet_params["controlnet_model_path"]
            controlnet = ControlNetModel.from_pretrained(controlnet_model_path)
        controlnet = controlnet.to(device=device, dtype=dtype)
    else:
        controlnet = None
    return controlnet, controlnet_processor, processor_params
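

# Hedged sketch: load a multi-controlnet stack in one call. This downloads
# detector and controlnet weights from the Hugging Face Hub on first use, so
# it needs network access; the name list is just an illustration.
#
#   controlnet, processor, proc_params = load_controlnet_model(
#       ["canny", "depth"],
#       device="cpu",
#       dtype=torch.float32,
#       detect_resolution=512,
#       image_resolution=512,
#   )
#   # `controlnet` is a MultiControlNetModel, `processor` a
#   # MultiControlnetProcessor, and `proc_params` the per-processor call kwargs.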
def prepare_image(
    image,  # b c t h w
    image_processor: Callable,
    width=None,
    height=None,
    return_type: Literal["numpy", "torch"] = "numpy",
):
    """Flatten, resize, and rescale an image batch with a `VaeImageProcessor`-style processor."""
    if isinstance(image, list) and isinstance(image[0], str):
        raise NotImplementedError
    if isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate(image, axis=0)
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)
    if image.ndim == 5:
        image = rearrange(image, "b c t h w-> (b t) c h w")
    if height is None:
        height = image.shape[-2]
    if width is None:
        width = image.shape[-1]
    width, height = (x - x % image_processor.vae_scale_factor for x in (width, height))
    # cast before interpolation (bilinear is not implemented for uint8 tensors)
    image = image.to(dtype=torch.float32) / 255.0
    if height != image.shape[-2] or width != image.shape[-1]:
        image = torch.nn.functional.interpolate(
            image, size=(height, width), mode="bilinear"
        )
    do_normalize = image_processor.config.do_normalize
    if image.min() < 0:
        warnings.warn(
            "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range is [0,1] "
            f"for torch tensor or numpy array inputs. You passed `image` with value range [{image.min()},{image.max()}]",
            FutureWarning,
        )
        do_normalize = False
    if do_normalize:
        image = image_processor.normalize(image)
    if return_type == "numpy":
        image = image.numpy()
    return image
class PoseKPs2ImgConverter(object):
    """Converts cached DWpose keypoint arrays back into pose condition images."""

    def __init__(
        self,
        target_width: int,
        target_height: int,
        num_candidates: int = 10,
        image_processor: Callable = None,
        include_body: bool = True,
        include_face: bool = False,
        hand_and_face: bool = None,
        include_hand: bool = True,
    ) -> None:
        self.target_width = target_width
        self.target_height = target_height
        self.num_candidates = num_candidates
        self.image_processor = image_processor
        self.include_body = include_body
        self.include_face = include_face
        self.hand_and_face = hand_and_face
        self.include_hand = include_hand

    def __call__(self, kps: np.ndarray) -> Any:
        # kps: (b, max_num=10, 134, 3); the last dim is x, y, score.
        # Candidate slots are zero-padded, so the first all-zero slot marks
        # the number of real candidates; if no slot is empty, all are real.
        num_candidates = self.num_candidates
        for idx_t in range(self.num_candidates):
            if np.sum(kps[:, idx_t, :, :]) == 0:
                num_candidates = idx_t
                break
        if num_candidates > 0:
            kps = kps[:, 0:num_candidates, :, :]
            candidate = kps[..., :2]
            subset = kps[..., 2]
            poses = [
                candidate2pose(
                    candidate[i],
                    subset[i],
                    include_body=self.include_body,
                    include_face=self.include_face,
                    hand_and_face=self.hand_and_face,
                    include_hand=self.include_hand,
                )
                for i in range(candidate.shape[0])
            ]
            pose_imgs = [
                pose2map(
                    pose,
                    self.target_height,
                    self.target_width,
                    min(self.target_height, self.target_width),
                    min(self.target_height, self.target_width),
                )
                for pose in poses
            ]
            pose_imgs = np.stack(pose_imgs, axis=0)  # b h w c
        else:
            # no detected person in any frame: emit blank condition images
            pose_imgs = np.zeros(
                shape=(kps.shape[0], self.target_height, self.target_width, 3),
                dtype=np.uint8,
            )
        pose_imgs = rearrange(pose_imgs, "b h w c -> b c h w")
        if self.image_processor is not None:
            pose_imgs = prepare_image(
                image=pose_imgs,
                width=self.target_width,
                height=self.target_height,
                image_processor=self.image_processor,
                return_type="numpy",
            )
        return pose_imgs
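

# Hedged usage sketch: round-trip padded keypoints into condition images. The
# zero array stands in for real cached keypoints, so the converter takes the
# blank-image branch and returns zeros.
#
#   converter = PoseKPs2ImgConverter(target_width=512, target_height=512)
#   kps = np.zeros((4, 10, 134, 3), dtype=np.float32)  # b, max_num, 134, (x, y, score)
#   pose_imgs = converter(kps)  # (4, 3, 512, 512) uint8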