|
from typing import AbstractSet, List, Literal, Optional, Union |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from controlnet_aux import ( |
|
HEDdetector, |
|
LineartDetector, |
|
OpenposeDetector, |
|
PidiNetDetector, |
|
) |
|
from diffusers import ( |
|
ControlNetModel, |
|
DiffusionPipeline, |
|
EulerAncestralDiscreteScheduler, |
|
StableDiffusionAdapterPipeline, |
|
StableDiffusionControlNetImg2ImgPipeline, |
|
StableDiffusionControlNetPipeline, |
|
StableDiffusionXLAdapterPipeline, |
|
StableDiffusionXLControlNetPipeline, |
|
T2IAdapter, |
|
UniPCMultistepScheduler, |
|
) |
|
from diffusers.pipelines.controlnet import MultiControlNetModel |
|
from PIL import Image |
|
from pydash import has |
|
from torch.nn import Linear |
|
from tqdm import gui |
|
from transformers import pipeline |
|
|
|
import internals.util.image as ImageUtil |
|
from external.midas import apply_midas |
|
from internals.data.result import Result |
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.util.cache import clear_cuda_and_gc |
|
from internals.util.commons import download_image |
|
from internals.util.config import ( |
|
get_hf_cache_dir, |
|
get_hf_token, |
|
get_is_sdxl, |
|
get_model_dir, |
|
) |
|
|
|
# Task names accepted by ControlNet.load_model. The spellings (including
# "linearart") are runtime keys of the model tables and must not be changed.
CONTROLNET_TYPES = Literal[
    "pose",
    "canny",
    "scribble",
    "linearart",
    "tile_upscaler",
]
|
|
|
|
|
class StableDiffusionNetworkModelPipelineLoader:
    """Factory that builds a diffusers pipeline around a network module.

    The "network module" is either a ControlNet (``pipeline_type ==
    "controlnet"``) or a T2I adapter (``pipeline_type == "t2i"``).

    Calling the class does not create an instance of it: ``__new__`` returns
    the constructed pipeline moved to CUDA, or ``None`` (after printing a
    warning) when the requested configuration is not supported. It never
    raises for an unsupported configuration.
    """

    def __new__(
        cls,
        is_sdxl,
        is_img2img,
        network_model,
        pipeline_type,
        base_pipe: "Optional[AbstractPipeline]" = None,
    ):
        """Build the pipeline matching the requested configuration.

        Args:
            is_sdxl: Select the SDXL pipeline classes instead of the SD ones.
            is_img2img: Select the img2img ControlNet pipeline (SD only).
            network_model: ControlNet / adapter module handed to the pipeline.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.
            base_pipe: Optional wrapper whose ``.pipe.components`` are reused
                to construct the pipeline. When omitted, the weights are
                loaded with ``from_pretrained`` from the configured model dir.

        Returns:
            The pipeline on CUDA, or ``None`` for unsupported configurations.
        """
        if is_sdxl and is_img2img:
            print("Warning: Tile upscale is not supported on SDXL")
            return None

        # Without a base pipeline the weights are read from disk / the hub;
        # with one, its already-loaded components are shared.
        pretrained = base_pipe is None
        if pretrained:
            kwargs = {
                "pretrained_model_name_or_path": get_model_dir(),
                "torch_dtype": torch.float16,
                "use_auth_token": get_hf_token(),
                "cache_dir": get_hf_cache_dir(),
            }
        else:
            kwargs = dict(base_pipe.pipe.components)

        # Pick the pipeline class and the keyword under which it expects the
        # network module.
        if pipeline_type == "controlnet":
            module_kwarg = "controlnet"
            if is_sdxl:
                pipeline_cls = StableDiffusionXLControlNetPipeline
            elif is_img2img:
                pipeline_cls = StableDiffusionControlNetImg2ImgPipeline
            else:
                pipeline_cls = StableDiffusionControlNetPipeline
        elif pipeline_type == "t2i":
            module_kwarg = "adapter"
            if is_sdxl:
                pipeline_cls = StableDiffusionXLAdapterPipeline
            else:
                pipeline_cls = StableDiffusionAdapterPipeline
        else:
            print(
                f"Warning: Unsupported configuration {is_sdxl=}, {is_img2img=}, {pipeline_type=}"
            )
            return None

        build = pipeline_cls.from_pretrained if pretrained else pipeline_cls
        return build(**{module_kwarg: network_model}, **kwargs).to("cuda")
|
|
|
|
|
class ControlNet(AbstractPipeline):
    """ControlNet / T2I-Adapter conditioned generation, one task at a time.

    ``load_model`` selects the network module for a task and attaches it to
    two cached pipelines: ``self.pipe`` (img2img variant, used by the tile
    upscaler) and ``self.pipe2`` (text2img variant, used by every other
    task). ``process`` then dispatches to the ``process_*`` method of the
    currently loaded task.
    """

    # Hugging Face repositories per task for non-SDXL checkpoints. A value may
    # be a comma-separated list (loaded as a MultiControlNetModel) or the name
    # of another task in the same table (treated as an alias).
    __model_normal = {
        "pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "linearart": "lllyasviel/control_v11p_sd15_lineart",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
    }
    # Kind of network module ("controlnet" or "t2i") behind each task above.
    __model_normal_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "controlnet",
        "scribble": "controlnet",
        "tile_upscaler": "controlnet",
    }

    # Same tables for SDXL checkpoints; None marks an unsupported task.
    __model_sdxl = {
        "pose": "thibaud/controlnet-openpose-sdxl-1.0",
        "canny": "diffusers/controlnet-canny-sdxl-1.0",
        "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
        "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
        "tile_upscaler": None,
    }
    __model_sdxl_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "t2i",
        "scribble": "t2i",
        "tile_upscaler": None,
    }

    # Name of the task whose network module is currently attached.
    __current_task_name = ""
    # Whether the pipelines have been built, and for which module kind.
    __loaded = False
    __pipe_type = None

    def init(self, pipeline: AbstractPipeline):
        """Remember the base pipeline whose components are reused later."""
        # Stored via setattr with a literal string, so the attribute is NOT
        # name-mangled; it is read back with getattr(self, "__pipeline").
        setattr(self, "__pipeline", pipeline)

    def load_model(self, task_name: CONTROLNET_TYPES):
        """Load the network module and pipelines for ``task_name`` and cache them.

        Does nothing when ``task_name`` is already the current task.

        Raises:
            KeyError: If ``task_name`` is not a key of the model table.
            Exception: If the task has no model for the active base model.
        """
        is_sdxl = get_is_sdxl()
        config = self.__model_sdxl if is_sdxl else self.__model_normal
        if self.__current_task_name == task_name:
            return

        # A table value that is itself a task name is an alias: follow it.
        model = config[task_name]
        while model in config:
            task_name = model
            model = config[task_name]
        # Checked after alias resolution so an alias to an unsupported task is
        # rejected here instead of failing on the string operations below.
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")

        type_table = self.__model_sdxl_types if is_sdxl else self.__model_normal_types
        pipeline_type = type_table[task_name]

        if "," in model:
            model = [name.strip() for name in model.split(",")]

        network_model = self.__load_network_model(model, pipeline_type)
        self.__load_pipeline(network_model, pipeline_type)
        self.__current_task_name = task_name

        clear_cuda_and_gc()

    def __load_network_model(self, model_name, pipeline_type):
        """Load the network module (ControlNet or T2I adapter) onto CUDA.

        Args:
            model_name: A repository name, or a list of names for a
                multi-ControlNet.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.

        Raises:
            Exception: For an unknown pipeline type, a list of T2I adapters,
                or a ``model_name`` that is neither ``str`` nor ``list``.
        """

        def load_controlnet(model):
            return ControlNetModel.from_pretrained(
                model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")

        def load_t2i(model):
            # The keyword is `variant`; the previous misspelling meant the
            # fp16 weight variant was never actually requested.
            return T2IAdapter.from_pretrained(
                model,
                torch_dtype=torch.float16,
                variant="fp16",
            ).to("cuda")

        if isinstance(model_name, str):
            if pipeline_type == "controlnet":
                return load_controlnet(model_name)
            if pipeline_type == "t2i":
                return load_t2i(model_name)
        elif isinstance(model_name, list):
            if pipeline_type == "controlnet":
                controlnets = [load_controlnet(name) for name in model_name]
                return MultiControlNetModel(controlnets).to("cuda")
            if pipeline_type == "t2i":
                raise Exception("Multi T2I adapters are not supported")
        else:
            # Fail loudly rather than silently returning None.
            raise Exception(
                f"Invalid model name type: {type(model_name).__name__}"
            )
        raise Exception("Invalid pipeline type")

    def __load_pipeline(self, network_model, pipeline_type):
        """Build the pipelines if needed and attach ``network_model`` to them.

        The pipelines are (re)built only on first use or when the module kind
        changes; otherwise only the network module is swapped on the cached
        pipelines.
        """
        is_sdxl = get_is_sdxl()

        def patch_pipe(pipe):
            """Apply memory optimisations (and the SDXL scheduler) to ``pipe``."""
            if pipe is None:
                return None
            if is_sdxl:
                pipe.enable_vae_tiling()
                pipe.enable_vae_slicing()
            pipe.enable_xformers_memory_efficient_attention()
            if is_sdxl:
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
                    pipe.scheduler.config
                )
            return pipe

        # "pipe" is the img2img pipeline, "pipe2" the text2img one.
        pipe_slots = (("pipe", True), ("pipe2", False))

        if not self.__loaded or self.__pipe_type != pipeline_type:
            base_pipe = getattr(self, "__pipeline", None)
            for attr_name, is_img2img in pipe_slots:
                built = patch_pipe(
                    StableDiffusionNetworkModelPipelineLoader(
                        is_sdxl=is_sdxl,
                        is_img2img=is_img2img,
                        network_model=network_model,
                        pipeline_type=pipeline_type,
                        base_pipe=base_pipe,
                    )
                )
                # The loader returns None for unsupported configurations; in
                # that case the attribute is simply left untouched.
                if built is not None:
                    setattr(self, attr_name, built)

            self.__loaded = True
            self.__pipe_type = pipeline_type

        # Attribute under which each pipeline kind holds its network module.
        module_attr = {"controlnet": "controlnet", "t2i": "adapter"}.get(
            pipeline_type
        )
        if module_attr is not None:
            for attr_name, _ in pipe_slots:
                if hasattr(self, attr_name):
                    setattr(getattr(self, attr_name), module_attr, network_model)

        for attr_name, _ in pipe_slots:
            if hasattr(self, attr_name):
                setattr(self, attr_name, getattr(self, attr_name).to("cuda"))

        clear_cuda_and_gc()

    def __ensure_task(self, task_name):
        """Raise unless ``task_name`` is the currently loaded task."""
        if self.__current_task_name != task_name:
            raise Exception(f"ControlNet is not loaded with {task_name} model")

    @staticmethod
    def __sdxl_adapter_args():
        """Return the extra pipeline arguments used for SDXL, else ``{}``."""
        if not get_is_sdxl():
            return {}
        return {
            "guidance_scale": 6,
            "adapter_conditioning_scale": 1.0,
            "adapter_conditioning_factor": 1.0,
        }

    def process(self, **kwargs):
        """Run the ``process_*`` method of the currently loaded task.

        Raises:
            Exception: If no known task is loaded.
        """
        handlers = {
            "pose": self.process_pose,
            "canny": self.process_canny,
            "scribble": self.process_scribble,
            "linearart": self.process_linearart,
            "tile_upscaler": self.process_tile_upscaler,
        }
        handler = handlers.get(self.__current_task_name)
        if handler is None:
            raise Exception("ControlNet is not loaded with any model")
        return handler(**kwargs)

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Generate from the Canny edge map of the image at ``imageUrl``.

        The image is downloaded, resized to ``(width, height)`` and converted
        to an edge map that conditions ``self.pipe2``. Extra ``kwargs`` are
        forwarded to the pipeline and override the defaults built here.
        """
        self.__ensure_task("canny")

        torch.manual_seed(seed)

        source_image = download_image(imageUrl).resize((width, height))
        edge_image = ControlNet.canny_detect_edge(source_image)

        pipe_kwargs = {
            "prompt": prompt,
            "image": edge_image,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            **kwargs,
        }

        result = self.pipe2(**pipe_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate images conditioned on the given pose image(s).

        Only the first prompt and first negative prompt are used, and four
        images are requested for that prompt. Extra ``kwargs`` are forwarded
        to ``self.pipe2`` and override the defaults built here.
        """
        self.__ensure_task("pose")

        torch.manual_seed(seed)

        pipe_kwargs = {
            "prompt": prompt[0],
            "image": image,
            "num_images_per_prompt": 4,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt[0],
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
            **kwargs,
        }
        print(pipe_kwargs)
        result = self.pipe2(**pipe_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Upscale the image at ``imageUrl`` with the tile ControlNet.

        The image is resized to ``(width, height)`` and then rescaled so its
        longer side is about ``resize_dimension``. The result is used both as
        the init image and the control image of the img2img pipeline
        (``self.pipe``), and its size is the output size.
        """
        self.__ensure_task("tile_upscaler")

        torch.manual_seed(seed)

        source_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            source_image, resize_dimension
        )

        pipe_kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "control_image": condition_image,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": condition_image.size[1],
            "width": condition_image.size[0],
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe(**pipe_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        image: List[Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate images conditioned on the given scribble image(s).

        On SDXL the adapter arguments (including ``guidance_scale`` 6) take
        precedence over the ``guidance_scale`` parameter; explicit ``kwargs``
        take precedence over both.
        """
        self.__ensure_task("scribble")

        torch.manual_seed(seed)

        pipe_kwargs = {
            "image": image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **self.__sdxl_adapter_args(),
            **kwargs,
        }
        result = self.pipe2(**pipe_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from the line-art rendering of the image at ``imageUrl``.

        The conditioning image is repeated four times in the pipeline call.
        On SDXL the adapter arguments (including ``guidance_scale`` 6) take
        precedence over the ``guidance_scale`` parameter; explicit ``kwargs``
        take precedence over both.
        """
        self.__ensure_task("linearart")

        torch.manual_seed(seed)

        source_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(source_image)

        pipe_kwargs = {
            "image": [condition_image] * 4,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **self.__sdxl_adapter_args(),
            **kwargs,
        }
        result = self.pipe2(**pipe_kwargs)
        return Result.from_result(result)

    def cleanup(self):
        """Do nothing.

        Kept as a no-op because newer diffusers versions clean up by
        themselves after ControlNet generation.
        """
        pass

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image at ``imageUrl`` and return its OpenPose rendering."""
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        return detector(download_image(imageUrl))

    @staticmethod
    def scribble_image(image: Image.Image) -> Image.Image:
        """Return the HED scribble rendering of ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
        """Return the line-art rendering of ``image``.

        On SDXL ``detect_resolution`` defaults to 384 unless given in
        ``kwargs``; all ``kwargs`` are forwarded to the detector.
        """
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        if get_is_sdxl():
            kwargs = {"detect_resolution": 384, **kwargs}
        return processor(input_image=image, **kwargs)

    @staticmethod
    def depth_image(image: Image.Image) -> Image.Image:
        """Return a MiDaS depth map of ``image`` as an 8-bit image.

        The MiDaS model and its transforms are loaded from torch hub on first
        use and cached in the module globals ``midas`` / ``midas_transforms``.
        """
        global midas, midas_transforms
        if "midas" not in globals() or "midas_transforms" not in globals():
            # Load both before publishing either, so a failed download cannot
            # leave a half-initialised cache behind.
            model = torch.hub.load("intel-isl/MiDaS", "MiDaS").to("cuda")
            # Inference must run in eval mode (fixed normalisation statistics).
            model.eval()
            transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
            midas, midas_transforms = model, transforms
        transform = midas_transforms.default_transform

        # PIL already yields RGB, which is what the MiDaS transform expects;
        # applying a BGR->RGB swap here would feed the model BGR data.
        img = np.array(image.convert("RGB"))

        input_batch = transform(img).to("cuda")
        with torch.no_grad():
            prediction = midas(input_batch)
            # Resize the prediction back to the input resolution.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        output = prediction.cpu().numpy()
        # Scale so the largest depth value maps to 255.
        formatted = (output * 255 / np.max(output)).astype("uint8")
        return Image.fromarray(formatted)

    @staticmethod
    def pidinet_image(image: Image.Image) -> Image.Image:
        """Return the PidiNet edge rendering of ``image`` (with filtering)."""
        pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        return pidinet(input_image=image, apply_filter=True)

    @staticmethod
    def canny_detect_edge(image: Image.Image) -> Image.Image:
        """Return the Canny edge map of ``image`` as a 3-channel image."""
        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Replicate the single edge channel into three identical channels.
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` so its longer side is about ``resolution``.

        The aspect ratio is kept approximately; both sides are rounded to the
        nearest multiple of 64. The image is converted to RGB first.
        """
        input_image = image.convert("RGB")
        width, height = input_image.size
        scale = float(resolution) / max(width, height)
        height = int(round(height * scale / 64.0)) * 64
        width = int(round(width * scale / 64.0)) * 64
        return input_image.resize((width, height), resample=Image.LANCZOS)
|
|