from typing import Dict, List, Literal, Optional, Tuple, Union

import cv2
import numpy as np
from pydash import has
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
    StableDiffusionXLControlNetPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
    MultiControlNetModel,
)
from PIL import Image
from torch.nn import Linear
from tqdm import gui
from transformers import pipeline

import internals.util.image as ImageUtil
from external.midas import apply_midas
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import (
    StableDiffusionControlNetImg2ImgPipeline,
)
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_model_dir,
    get_is_sdxl,
)

CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]


class ControlNet(AbstractPipeline):
    """ControlNet-conditioned generation on top of SD 1.5 or SDXL.

    Typical use: ``load_model(task)`` to select a ControlNet, then
    ``process(**kwargs)`` to dispatch to the matching ``process_*`` method.

    On SD 1.5 two pipelines are kept:
      * ``pipe``  - ``StableDiffusionControlNetImg2ImgPipeline``, used by
        ``process_tile_upscaler``;
      * ``pipe2`` - ``StableDiffusionControlNetPipeline``, used by the pose,
        canny, scribble and linearart tasks.
    On SDXL only ``pipe2`` (``StableDiffusionXLControlNetPipeline``) exists.
    """

    # ControlNet checkpoints per task for SD 1.5 base models.
    __model_normal: Dict[str, Optional[str]] = {
        "pose": "lllyasviel/control_v11p_sd15_openpose",
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "linearart": "lllyasviel/control_v11p_sd15_lineart",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
    }

    # ControlNet checkpoints per task for SDXL. A value that is itself a key
    # of this dict is an alias ("linearart"/"scribble" reuse the canny model).
    # None marks a task that is not supported on SDXL.
    __model_sdxl: Dict[str, Optional[str]] = {
        "pose": "thibaud/controlnet-openpose-sdxl-1.0",
        "canny": "diffusers/controlnet-canny-sdxl-1.0",
        "linearart": "canny",
        "scribble": "canny",
        "tile_upscaler": None,
    }

    __current_task_name = ""
    __loaded = False
    # Optional base pipeline whose `.pipe.components` are reused by __load().
    # Defaults to None so __load() works even when init() was never called.
    __pipeline: Optional[AbstractPipeline] = None

    def init(self, pipeline: AbstractPipeline):
        """Register a base pipeline whose ``.pipe.components`` will be shared."""
        self.__pipeline = pipeline

    def load_model(self, task_name: CONTROLNET_TYPES):
        """Load the ControlNet for ``task_name`` and attach it to the pipelines.

        The checkpoint table is chosen by ``get_is_sdxl()``. Alias entries are
        resolved before comparing with the currently loaded task, so asking
        for an alias of the loaded task (e.g. "linearart" while "canny" is
        loaded on SDXL) does not reload weights. The first successful call
        also builds the pipelines via ``__load``.

        Raises:
            KeyError: if ``task_name`` is not a known task.
            Exception: if the task has no ControlNet for the current base
                model (e.g. "tile_upscaler" on SDXL).
        """
        config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
        task_name, model = self.__resolve_model(config, task_name)
        if self.__current_task_name == task_name:
            return

        controlnet = ControlNetModel.from_pretrained(
            model,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        self.__current_task_name = task_name
        self.controlnet = controlnet

        self.__load()
        # Swap the freshly loaded ControlNet into pipelines built earlier.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = controlnet

        clear_cuda_and_gc()

    @staticmethod
    def __resolve_model(
        config: Dict[str, Optional[str]], task_name: str
    ) -> Tuple[str, str]:
        """Follow alias entries in ``config``; return ``(task_name, checkpoint)``.

        Raises:
            KeyError: if ``task_name`` is not in ``config``.
            Exception: if the task (or its alias target) has no checkpoint,
                or the aliases form a cycle.
        """
        model = config[task_name]
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")

        seen = {task_name}
        while model in config:
            if model in seen:
                raise Exception(f"Circular ControlNet alias for {task_name}")
            seen.add(model)
            task_name = model
            model = config[task_name]

        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")
        return task_name, model

    def __load(self):
        "Should not be called externally"
        if self.__loaded:
            return

        if not hasattr(self, "controlnet"):
            # load_model() re-enters __load() once weights exist and builds
            # the pipelines there; don't build them a second time here.
            self.load_model("pose")
            if self.__loaded:
                return

        if get_is_sdxl():
            print("Warning: Tile upscale is not supported on SDXL")
            self.pipe2 = self.__build_sdxl_pipeline()
        else:
            self.pipe, self.pipe2 = self.__build_sd15_pipelines()

        self.__loaded = True

    def __build_sdxl_pipeline(self):
        """Build the SDXL ControlNet pipeline (shared components if available)."""
        base = self.__pipeline
        if base is not None:
            pipe = StableDiffusionXLControlNetPipeline(
                controlnet=self.controlnet, **base.pipe.components
            ).to("cuda")
        else:
            pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=self.controlnet,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                cache_dir=get_hf_cache_dir(),
                use_safetensors=True,
            ).to("cuda")
        pipe.enable_vae_tiling()
        pipe.enable_vae_slicing()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe

    def __build_sd15_pipelines(self):
        """Build the SD 1.5 img2img (tile) and txt2img ControlNet pipelines."""
        base = self.__pipeline
        if base is not None:
            # Reuse the injected pipeline's modules. The original
            # hasattr(self, "__pipeline") check could never be true because
            # string literals are not name-mangled.
            pipe = StableDiffusionControlNetImg2ImgPipeline(
                controlnet=self.controlnet, **base.pipe.components
            ).to("cuda")
        else:
            pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                get_model_dir(),
                controlnet=self.controlnet,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")
            # Only offload modules this pipeline owns; offloading shared
            # modules would also move the injected pipeline's weights.
            pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()

        # txt2img ControlNet pipeline (canny, pose, scribble, linearart)
        # sharing every module with the img2img pipeline above.
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(
            pipe2.scheduler.config
        )
        pipe2.enable_xformers_memory_efficient_attention()
        return pipe, pipe2

    def process(self, **kwargs):
        """Dispatch ``kwargs`` to the ``process_*`` method of the loaded task.

        Raises:
            Exception: if no ControlNet task is currently loaded.
        """
        handlers = {
            "pose": self.process_pose,
            "canny": self.process_canny,
            "scribble": self.process_scribble,
            "linearart": self.process_linearart,
            "tile_upscaler": self.process_tile_upscaler,
        }
        handler = handlers.get(self.__current_task_name)
        if handler is None:
            raise Exception("ControlNet is not loaded with any model")
        return handler(**kwargs)

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Generate from Canny edges of the image at ``imageUrl``.

        The image is resized to ``(width, height)`` before edge detection.
        ``seed`` seeds torch's global RNG. Extra ``kwargs`` are forwarded to
        ``pipe2`` and override the defaults built here.
        """
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = ControlNet.canny_detect_edge(init_image)

        call_kwargs = {
            "prompt": prompt,
            "image": init_image,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            **kwargs,
        }
        result = self.pipe2(**call_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from a pose conditioning image.

        Only ``prompt[0]``, ``image[0]`` and ``negative_prompt[0]`` are used.
        Four images are generated unless ``num_images_per_prompt`` is passed
        in ``kwargs``. ``image`` is passed to the pipeline as-is (presumably
        already a pose map, e.g. from ``detect_pose`` — confirm with callers).
        """
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        call_kwargs = {
            "prompt": prompt[0],
            "image": [image[0]],
            "num_images_per_prompt": 4,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt[0],
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
            **kwargs,
        }
        result = self.pipe2(**call_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Upscale the image at ``imageUrl`` with the tile ControlNet (SD 1.5 only).

        The downloaded image is first resized to ``(width, height)``, then
        rescaled so its longer side is about ``resize_dimension`` (sides are
        rounded to multiples of 64). The output size follows that condition
        image, not ``height``/``width``.
        """
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )
        cond_width, cond_height = condition_image.size

        call_kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "controlnet_conditioning_image": condition_image,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": cond_height,
            "width": cond_width,
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe(**call_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        imageUrl: Union[str, Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from an HED scribble map of the input image.

        ``imageUrl`` may be a URL or an already-loaded PIL image. In both
        cases it is resized to ``(width, height)`` first.
        """
        if self.__current_task_name != "scribble":
            raise Exception("ControlNet is not loaded with scribble model")

        torch.manual_seed(seed)

        if isinstance(imageUrl, Image.Image):
            init_image = imageUrl.resize((width, height))
        else:
            init_image = download_image(imageUrl).resize((width, height))

        condition_image = self.__scribble_condition_image(init_image)

        call_kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe2(**call_kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from a line-art map of the image at ``imageUrl``."""
        if self.__current_task_name != "linearart":
            raise Exception("ControlNet is not loaded with linearart model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(init_image)

        call_kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe2(**call_kwargs)
        return Result.from_result(result)

    def cleanup(self):
        """Drop the loaded ControlNet from the pipelines and free GPU memory.

        The pipelines themselves are kept. A later ``load_model`` call
        re-attaches a ControlNet.
        """
        if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"):
            del self.pipe.controlnet
        if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"):
            del self.pipe2.controlnet
        if hasattr(self, "controlnet"):
            del self.controlnet
        self.__current_task_name = ""

        clear_cuda_and_gc()

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Run OpenPose on the image at ``imageUrl`` and return the pose map.

        Note: the detector is re-created via ``from_pretrained`` on every call.
        """
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector(image)

    def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
        """Return an HED-based scribble map (``scribble=True``) of ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image) -> Image.Image:
        """Return the line-art conditioning map of ``image``."""
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image)

    @staticmethod
    def depth_image(image: Image.Image) -> Image.Image:
        """Return a MiDaS depth map of ``image`` (via ``apply_midas``)."""
        depth = np.array(image)
        # HWC3 presumably coerces to a 3-channel HxWxC array — confirm in ImageUtil.
        depth = ImageUtil.HWC3(depth)
        depth, _ = apply_midas(depth)
        depth = ImageUtil.HWC3(depth)
        return Image.fromarray(depth)

    @staticmethod
    def canny_detect_edge(image: Image.Image) -> Image.Image:
        """Return Canny edges (thresholds 100/200) replicated to 3 channels."""
        image_array = np.array(image)

        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(image_array, low_threshold, high_threshold)
        # Canny yields a single-channel map; stack it so the result is RGB-shaped.
        edges = edges[:, :, None]
        edges = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` so its longer side is ~``resolution``.

        Each side is rounded to the nearest multiple of 64 and clamped to at
        least 64, so a small ``resolution`` never yields a zero-sized image.
        The image is converted to RGB and resampled with LANCZOS.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        H *= k
        W *= k
        H = max(64, int(round(H / 64.0)) * 64)
        W = max(64, int(round(W / 64.0)) * 64)
        return input_image.resize((W, H), resample=Image.LANCZOS)