"""ControlNet / T2I-Adapter pipelines for SD 1.5 and SDXL.

The module keeps a small cache of loaded network models (ControlNets and
T2I adapters), builds the matching diffusers pipelines on top of an optional
base pipeline and exposes one ``process_*`` method per supported task.
"""

import os
from typing import AbstractSet, List, Literal, Optional, Union

import cv2
import numpy as np
import torch
from controlnet_aux import (
    HEDdetector,
    LineartDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetPipeline,
    T2IAdapter,
    UniPCMultistepScheduler,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from pydash import has
from torch.nn import Linear
from tqdm import gui
from transformers import pipeline

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_is_sdxl,
    get_model_dir,
    get_num_return_sequences,
)

# Task names accepted by ``ControlNet.load_model``. "depth" is listed because
# ``ControlNet.process`` dispatches it and the SDXL model table defines it.
CONTROLNET_TYPES = Literal[
    "pose",
    "canny",
    "depth",
    "scribble",
    "linearart",
    "tile_upscaler",
    "canny_2x",
]

# repo_id -> loaded network model (ControlNetModel or T2IAdapter).
__CN_MODELS = {}
# Once this many models are cached the whole cache is dropped before loading
# the next one.
MAX_CN_MODELS = 3


def clear_networks():
    """Drop every cached network model."""
    global __CN_MODELS
    __CN_MODELS = {}


def load_network_model_by_key(repo_id: str, pipeline_type: str):
    """Return the network model for ``repo_id``, loading it on first use.

    Args:
        repo_id: Hugging Face repository id of the ControlNet / T2I adapter.
        pipeline_type: ``"controlnet"`` or ``"t2i"``.

    Returns:
        The model moved to CUDA. Repeated calls with the same ``repo_id``
        return the cached instance.

    Raises:
        Exception: If ``pipeline_type`` is neither ``"controlnet"`` nor
            ``"t2i"`` (and ``repo_id`` is not already cached).
    """
    global __CN_MODELS

    if repo_id in __CN_MODELS:
        return __CN_MODELS[repo_id]

    # Validate before evicting so that a bad call does not empty the cache.
    if pipeline_type not in ("controlnet", "t2i"):
        raise Exception("Invalid pipeline type")

    if len(__CN_MODELS) >= MAX_CN_MODELS:
        __CN_MODELS = {}

    if pipeline_type == "controlnet":
        model = ControlNetModel.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
            token=get_hf_token(),
        ).to("cuda")
    else:
        # The keyword is ``variant``; the previous spelling ("varient") was
        # not the documented argument, so the fp16 weights were not requested.
        # NOTE(review): unlike the controlnet branch no cache_dir is passed
        # here - confirm whether that is intentional.
        model = T2IAdapter.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            variant="fp16",
            token=get_hf_token(),
        ).to("cuda")

    __CN_MODELS[repo_id] = model
    return model


class StableDiffusionNetworkModelPipelineLoader:
    """Loads the pipeline for network module, eg: controlnet or t2i.

    Does not throw error in case of unsupported configurations, instead it
    returns None.
    """

    def __new__(
        cls,
        is_sdxl,
        is_img2img,
        network_model,
        pipeline_type,
        base_pipe: Optional[AbstractPipeline] = None,
    ):
        """Build the diffusers pipeline matching the requested configuration.

        Args:
            is_sdxl: Select the SDXL pipeline classes.
            is_img2img: Select the img2img variant (only exists for
                controlnet pipelines; ignored for ``"t2i"``).
            network_model: The ControlNet / adapter to attach.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.
            base_pipe: Optional already-loaded pipeline wrapper whose
                ``pipe.components`` are reused instead of loading weights
                from ``get_model_dir()``.

        Returns:
            The pipeline moved to CUDA, or ``None`` (after printing a
            warning) for an unsupported ``pipeline_type``.
        """
        pretrained = base_pipe is None
        if pretrained:
            kwargs = {
                "pretrained_model_name_or_path": get_model_dir(),
                "torch_dtype": torch.float16,
                "token": get_hf_token(),
                "cache_dir": get_hf_cache_dir(),
            }
        else:
            kwargs = {
                **base_pipe.pipe.components,  # pyright: ignore
            }
            # These two components are dropped before the SDXL pipelines are
            # constructed from the shared components.
            if get_is_sdxl():
                kwargs.pop("image_encoder", None)
                kwargs.pop("feature_extractor", None)

        if pipeline_type == "controlnet":
            network_kwarg = "controlnet"
            if is_sdxl and is_img2img:
                pipeline_cls = StableDiffusionXLControlNetImg2ImgPipeline
            elif is_sdxl:
                pipeline_cls = StableDiffusionXLControlNetPipeline
            elif is_img2img:
                pipeline_cls = StableDiffusionControlNetImg2ImgPipeline
            else:
                pipeline_cls = StableDiffusionControlNetPipeline
        elif pipeline_type == "t2i":
            network_kwarg = "adapter"
            pipeline_cls = (
                StableDiffusionXLAdapterPipeline
                if is_sdxl
                else StableDiffusionAdapterPipeline
            )
        else:
            print(
                f"Warning: Unsupported configuration {is_sdxl=}, {is_img2img=}, {pipeline_type=}"
            )
            return None

        # Either load from disk/hub or assemble from the shared components.
        factory = pipeline_cls.from_pretrained if pretrained else pipeline_cls
        return factory(**{network_kwarg: network_model}, **kwargs).to("cuda")


class ControlNet(AbstractPipeline):
    """Runs ControlNet / T2I-Adapter guided generation for a loaded task.

    ``load_model`` selects the task; ``process`` (or the task specific
    ``process_*`` method) runs it. ``self.pipe`` is the img2img flavoured
    pipeline (tile upscaler, canny_2x) and ``self.pipe2`` the text2img
    flavoured one (canny, pose, scribble, linearart, depth).
    """

    # Task currently loaded; "" when nothing is loaded.
    __current_task_name = ""
    # Whether the pipelines have been built.
    __loaded = False
    # "controlnet" or "t2i" for the pipelines currently built.
    __pipe_type = None

    def init(self, pipeline: AbstractPipeline):
        """Remember the base pipeline whose components should be reused."""
        # The literal attribute name is used on purpose: setattr/getattr with
        # a string bypass private-name mangling, and __load_pipeline reads the
        # attribute back with the same literal.
        setattr(self, "__pipeline", pipeline)

    def unload(self):
        "Unloads the network module, pipelines and clears the cache."
        if not self.__loaded:
            return

        self.__loaded = False
        self.__pipe_type = None
        self.__current_task_name = ""

        for attr in ("pipe", "pipe2"):
            if hasattr(self, attr):
                delattr(self, attr)

        clear_cuda_and_gc()

    def load_model(self, task_name: CONTROLNET_TYPES):
        """Appropriately loads the network module, pipelines and cache it for reuse.

        Args:
            task_name: One of the keys of the model table for the active
                base model (SDXL or SD 1.5).

        Raises:
            Exception: If the task has no model configured for the active
                base model.
        """
        if self.__current_task_name == task_name:
            return

        config = self.__model_sdxl if get_is_sdxl() else self.__model_normal

        # .get() so that an unknown task reaches the descriptive error below
        # instead of surfacing as a bare KeyError.
        model = config.get(task_name)
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")

        # A table value may name another task (alias); follow it.
        while model in config:
            task_name = model  # pyright: ignore
            model = config[task_name]

        pipeline_type = (
            self.__model_sdxl_types[task_name]
            if get_is_sdxl()
            else self.__model_normal_types[task_name]
        )

        # Comma separated repo ids mean several networks used together.
        if "," in model:
            model = [m.strip() for m in model.split(",")]

        model = self.__load_network_model(model, pipeline_type)
        self.__load_pipeline(model, pipeline_type)

        self.__current_task_name = task_name

        clear_cuda_and_gc()

    def __load_network_model(self, model_name, pipeline_type):
        """Loads the network module, eg: ControlNet or T2I Adapters.

        Args:
            model_name: A repo id, or a list of repo ids for multi-controlnet.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.

        Raises:
            Exception: For a list of T2I adapters, or any other unsupported
                combination.
        """
        if isinstance(model_name, str):
            return load_network_model_by_key(model_name, pipeline_type)

        if isinstance(model_name, list):
            if pipeline_type == "controlnet":
                cns = [
                    load_network_model_by_key(name, pipeline_type)
                    for name in model_name
                ]
                return MultiControlNetModel(cns).to("cuda")
            if pipeline_type == "t2i":
                raise Exception("Multi T2I adapters are not supported")

        raise Exception("Invalid pipeline type")

    def __load_pipeline(self, network_model, pipeline_type):
        """Load the base pipeline(s) (if not loaded already) based on pipeline
        type and attaches the network module to the pipeline.

        Args:
            network_model: The loaded ControlNet / MultiControlNet / adapter.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.
        """

        def patch_pipe(pipe):
            # The loader returns None for unsupported configurations.
            if pipe is None:
                return None

            if get_is_sdxl():
                pipe.enable_vae_tiling()
                pipe.enable_vae_slicing()
                pipe.enable_xformers_memory_efficient_attention()
                # this scheduler produces good outputs for t2i adapters
                if pipeline_type == "t2i":
                    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
                        pipe.scheduler.config
                    )
            else:
                pipe.enable_xformers_memory_efficient_attention()
            return pipe

        # If the pipeline type is changed we should reload all
        # the pipelines
        if not self.__loaded or self.__pipe_type != pipeline_type:
            base_pipe = getattr(self, "__pipeline", None)

            # controlnet pipeline for tile upscaler or any pipeline with
            # img2img + network support
            pipe = patch_pipe(
                StableDiffusionNetworkModelPipelineLoader(
                    is_sdxl=get_is_sdxl(),
                    is_img2img=True,
                    network_model=network_model,
                    pipeline_type=pipeline_type,
                    base_pipe=base_pipe,
                )
            )
            if pipe is not None:
                self.pipe = pipe

            # controlnet pipeline for canny and pose
            pipe2 = patch_pipe(
                StableDiffusionNetworkModelPipelineLoader(
                    is_sdxl=get_is_sdxl(),
                    is_img2img=False,
                    network_model=network_model,
                    pipeline_type=pipeline_type,
                    base_pipe=base_pipe,
                )
            )
            if pipe2 is not None:
                self.pipe2 = pipe2

            self.__loaded = True
            self.__pipe_type = pipeline_type

        # Set the network module in the pipeline
        network_attr = {"controlnet": "controlnet", "t2i": "adapter"}.get(
            pipeline_type
        )
        if network_attr is not None:
            for attr in ("pipe", "pipe2"):
                if hasattr(self, attr):
                    setattr(getattr(self, attr), network_attr, network_model)

        for attr in ("pipe", "pipe2"):
            if hasattr(self, attr):
                setattr(self, attr, getattr(self, attr).to("cuda"))

        clear_cuda_and_gc()

    def process(self, **kwargs):
        """Run the ``process_*`` method matching the currently loaded task.

        Raises:
            Exception: If no task has been loaded.
        """
        handlers = {
            "pose": self.process_pose,
            "depth": self.process_depth,
            "canny": self.process_canny,
            "scribble": self.process_scribble,
            "linearart": self.process_linearart,
            "tile_upscaler": self.process_tile_upscaler,
            "canny_2x": self.process_canny_2x,
        }
        handler = handlers.get(self.__current_task_name)
        if handler is None:
            raise Exception("ControlNet is not loaded with any model")
        return handler(**kwargs)

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        apply_preprocess: bool = True,
        **kwargs,
    ):
        """Generate images guided by the canny edges of ``imageUrl``.

        Returns:
            ``(Result, control image)``; extra ``kwargs`` are forwarded to
            the pipeline and override the defaults assembled here.
        """
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        generator = get_generators(seed, get_num_return_sequences())

        init_image = self.preprocess_image(imageUrl, width, height)
        if apply_preprocess:
            init_image = ControlNet.canny_detect_edge(init_image)
        init_image = init_image.resize((width, height))

        kwargs = {
            "prompt": prompt,
            "image": init_image,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            "generator": generator,
            **kwargs,
        }

        print(kwargs)
        result = self.pipe2(**kwargs)
        return Result.from_result(result), init_image

    @torch.inference_mode()
    def process_canny_2x(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 8.5,
        **kwargs,
    ):
        """Img2img generation guided by canny edges plus a depth map.

        Returns:
            ``(Result, canny image)``.
        """
        if self.__current_task_name != "canny_2x":
            raise Exception("ControlNet is not loaded with canny model")

        generator = get_generators(seed, get_num_return_sequences())

        init_image = self.preprocess_image(imageUrl, width, height)
        canny_image = ControlNet.canny_detect_edge(init_image).resize((width, height))
        depth_image = ControlNet.depth_image(init_image).resize((width, height))

        # Caller supplied values only steer the canny network (first entry);
        # the depth network (second entry) uses fixed values.
        condition_scale = kwargs.get("controlnet_conditioning_scale", None)
        condition_factor = kwargs.get("control_guidance_end", None)
        print("condition_scale", condition_scale)

        if not get_is_sdxl():
            kwargs["guidance_scale"] = 7.5
            kwargs["strength"] = 0.8
            kwargs["controlnet_conditioning_scale"] = [condition_scale or 1.0, 0.3]
        else:
            kwargs["controlnet_conditioning_scale"] = [condition_scale or 0.8, 0.3]
            # NOTE(review): control_guidance_end is only overridden on the
            # SDXL path here - confirm that matches the intended scope.
            kwargs["control_guidance_end"] = [condition_factor or 1.0, 1.0]

        # Entries coming from **kwargs (including the overrides above) win
        # over the defaults listed first.
        kwargs = {
            "prompt": prompt[0],
            "image": [init_image] * get_num_return_sequences(),
            "control_image": [canny_image, depth_image],
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": get_num_return_sequences(),
            "negative_prompt": negative_prompt[0],
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "height": height,
            "width": width,
            "generator": generator,
            **kwargs,
        }
        print(kwargs)

        result = self.pipe(**kwargs)
        return Result.from_result(result), canny_image

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate images guided by the given pose condition image(s).

        Returns:
            ``(Result, image)`` where ``image`` is the input passed through.
        """
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        generator = get_generators(seed, get_num_return_sequences())

        kwargs = {
            "prompt": prompt[0],
            "image": image,
            "num_images_per_prompt": get_num_return_sequences(),
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt[0],
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
            "generator": generator,
            **kwargs,
        }
        print(kwargs)
        result = self.pipe2(**kwargs)
        return Result.from_result(result), image

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Upscale ``imageUrl`` with the tile ControlNet.

        The trailing ``_<n>`` of the file name is added to ``seed``; when it
        cannot be parsed the passed seed and ``width``/``height`` are used.

        Returns:
            ``(Result, condition image)``.
        """
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        init_image = None
        # find the correct seed and imageUrl from imageUrl
        try:
            p = os.path.splitext(imageUrl)[0]
            p = p.split("/")[-1]
            p = p.split("_")[-1]
            seed = seed + int(p)

            # The previous condition (`"_canny_2x" or "_linearart" in ...`)
            # was always true because a non-empty literal is truthy. Both
            # markers are now really tested; the rewrite is a no-op for other
            # URLs, so the effective behaviour is unchanged.
            if "_canny_2x" in imageUrl or "_linearart" in imageUrl:
                imageUrl = imageUrl.replace("_canny_2x", "_canny_2x_highres").replace(
                    "_linearart_highres", ""
                )

            # As before, every URL with a parsable seed suffix is downloaded
            # at its native size, which replaces the requested width/height.
            init_image = download_image(imageUrl)
            width, height = init_image.size

            print("Setting imageUrl with width and height", imageUrl, width, height)
        except Exception as e:
            # Deliberate best-effort: fall back to the caller's seed and size.
            print("Failed to extract seed from imageUrl", e)

        print("Setting seed", seed)
        generator = get_generators(seed)

        if init_image is None:
            init_image = download_image(imageUrl).resize((width, height))

        if get_is_sdxl():
            condition_image = ImageUtil.resize_image(init_image, 1024)
            condition_image = condition_image.resize(init_image.size)
        else:
            condition_image = self.__resize_for_condition_image(
                init_image, resize_dimension
            )

        if get_is_sdxl():
            kwargs["strength"] = 1.0
            kwargs["controlnet_conditioning_scale"] = 1.0
            kwargs["image"] = init_image
        else:
            kwargs["image"] = condition_image
            # NOTE(review): guidance_scale is only forwarded on the non-SDXL
            # path here - confirm that matches the intended scope.
            kwargs["guidance_scale"] = guidance_scale

        kwargs = {
            "prompt": prompt,
            "control_image": condition_image,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": condition_image.size[1],
            "width": condition_image.size[0],
            "generator": generator,
            **kwargs,
        }
        result = self.pipe(**kwargs)
        return Result.from_result(result), condition_image

    @torch.inference_mode()
    def process_scribble(
        self,
        image: List[Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        apply_preprocess: bool = True,
        **kwargs,
    ):
        """Generate images guided by a scribble / sketch.

        Only ``image[0]`` is preprocessed; the result is repeated
        ``len(image)`` times.

        Returns:
            ``(Result, first condition image)``.
        """
        if self.__current_task_name != "scribble":
            raise Exception("ControlNet is not loaded with scribble model")

        generator = get_generators(seed, get_num_return_sequences())

        if apply_preprocess:
            # We use sketch in SDXL
            detect = (
                ControlNet.pidinet_image
                if get_is_sdxl()
                else ControlNet.scribble_image
            )
            image = [detect(image[0]).resize((width, height))] * len(image)

        sdxl_args = (
            {
                "guidance_scale": guidance_scale,
                "adapter_conditioning_scale": 1.0,
                "adapter_conditioning_factor": 1.0,
            }
            if get_is_sdxl()
            else {}
        )

        kwargs = {
            "image": image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "generator": generator,
            **sdxl_args,
            **kwargs,
        }
        result = self.pipe2(**kwargs)
        return Result.from_result(result), image[0]

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        apply_preprocess: bool = True,
        **kwargs,
    ):
        """Generate images guided by the line art of ``imageUrl``.

        Returns:
            ``(Result, condition image)``.
        """
        if self.__current_task_name != "linearart":
            raise Exception("ControlNet is not loaded with linearart model")

        generator = get_generators(seed, get_num_return_sequences())

        init_image = self.preprocess_image(imageUrl, width, height)

        if apply_preprocess:
            condition_image = ControlNet.linearart_condition_image(init_image)
            condition_image = condition_image.resize(init_image.size)
        else:
            condition_image = init_image

        # SDXL uses a t2i adapter here; scale and factor default to 1.0 and
        # can be overridden through **kwargs.
        sdxl_args = (
            {
                "guidance_scale": guidance_scale,
                "adapter_conditioning_scale": 1.0,
                "adapter_conditioning_factor": 1.0,
            }
            if get_is_sdxl()
            else {}
        )

        kwargs = {
            "image": [condition_image] * get_num_return_sequences(),
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "generator": generator,
            **sdxl_args,
            **kwargs,
        }
        result = self.pipe2(**kwargs)
        return Result.from_result(result), condition_image

    @torch.inference_mode()
    def process_depth(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        apply_preprocess: bool = True,
        **kwargs,
    ):
        """Generate images guided by the depth map of ``imageUrl``.

        Returns:
            ``(Result, condition image)``.
        """
        if self.__current_task_name != "depth":
            raise Exception("ControlNet is not loaded with depth model")

        generator = get_generators(seed, get_num_return_sequences())

        init_image = self.preprocess_image(imageUrl, width, height)

        if apply_preprocess:
            condition_image = ControlNet.depth_image(init_image)
            condition_image = condition_image.resize(init_image.size)
        else:
            condition_image = init_image

        # for using the depth controlnet in this SDXL model, these
        # hyperparameters are optimal
        sdxl_args = (
            {"controlnet_conditioning_scale": 0.2, "control_guidance_end": 0.2}
            if get_is_sdxl()
            else {}
        )

        kwargs = {
            "image": [condition_image] * get_num_return_sequences(),
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "generator": generator,
            **sdxl_args,
            **kwargs,
        }
        result = self.pipe2(**kwargs)
        return Result.from_result(result), condition_image

    def cleanup(self):
        """Doesn't do anything considering new diffusers has itself a cleanup
        mechanism after controlnet generation"""
        pass

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download ``imageUrl`` and return its OpenPose detection image."""
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector(image)

    @staticmethod
    def scribble_image(image: Image.Image) -> Image.Image:
        """Return the HED scribble rendition of ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
        """Return the line-art rendition of ``image``.

        Extra ``kwargs`` are forwarded to the detector on SDXL only; on the
        non-SDXL path they are discarded and the detector defaults are used.
        """
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        if get_is_sdxl():
            kwargs = {"detect_resolution": 384, "image_resolution": 1024, **kwargs}
        else:
            kwargs = {}

        return processor(input_image=image, **kwargs)

    @staticmethod
    @torch.inference_mode()
    def depth_image(image: Image.Image) -> Image.Image:
        """Return a single channel uint8 depth map of ``image`` using MiDaS.

        The MiDaS model and its transforms are loaded through ``torch.hub``
        on first use and kept in module globals afterwards.
        """
        global midas, midas_transforms
        if "midas" not in globals():
            midas = torch.hub.load("intel-isl/MiDaS", "MiDaS").to("cuda")
            midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        transform = midas_transforms.default_transform

        cv_image = np.array(image)
        # NOTE(review): a PIL image converted with np.array is normally
        # already RGB, so this swaps the channels - confirm it is intended.
        img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

        input_batch = transform(img).to("cuda")
        with torch.no_grad():
            prediction = midas(input_batch)

            # Bring the prediction back to the input resolution.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        output = prediction.cpu().numpy()
        # Scale so that the maximum maps to 255.
        formatted = (output * 255 / np.max(output)).astype("uint8")
        return Image.fromarray(formatted)

    @staticmethod
    def pidinet_image(image: Image.Image) -> Image.Image:
        """Return the PidiNet sketch rendition of ``image``."""
        pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        return pidinet(input_image=image, apply_filter=True)

    @staticmethod
    def canny_detect_edge(image: Image.Image) -> Image.Image:
        """Return a 3 channel image holding the canny edges of ``image``."""
        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # cv2.Canny returns a single channel map; repeat it three times.
        edges = edges[:, :, None]
        edges = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def preprocess_image(self, imageUrl, width, height) -> Image.Image:
        """Download ``imageUrl`` as RGBA, resize it to ``(width, height)`` and
        pass it through ``ImageUtil.alpha_to_white``."""
        image = download_image(imageUrl, mode="RGBA").resize((width, height))
        return ImageUtil.alpha_to_white(image)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Resize ``image`` so that its longer side is about ``resolution``.

        Both sides are rounded to the nearest multiple of 64.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        H = int(round(H * k / 64.0)) * 64
        W = int(round(W * k / 64.0)) * 64
        return input_image.resize((W, H), resample=Image.LANCZOS)

    # task -> repo id(s) for SD 1.5; comma separated ids form a multi-controlnet.
    __model_normal = {
        "pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "linearart": "lllyasviel/control_v11p_sd15_lineart",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
        "canny_2x": "lllyasviel/control_v11p_sd15_canny, lllyasviel/control_v11f1p_sd15_depth",
    }
    # task -> pipeline type for SD 1.5.
    __model_normal_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "controlnet",
        "scribble": "controlnet",
        "tile_upscaler": "controlnet",
        "canny_2x": "controlnet",
    }

    # task -> repo id(s) for SDXL.
    __model_sdxl = {
        "pose": "thibaud/controlnet-openpose-sdxl-1.0",
        "canny": "Autodraft/controlnet-canny-sdxl-1.0",
        "depth": "Autodraft/controlnet-depth-sdxl-1.0",
        "canny_2x": "Autodraft/controlnet-canny-sdxl-1.0, Autodraft/controlnet-depth-sdxl-1.0",
        "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
        "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
        "tile_upscaler": "Autodraft/ControlNet_SDXL_tile_upscale",
    }
    # task -> pipeline type for SDXL.
    __model_sdxl_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "canny_2x": "controlnet",
        "depth": "controlnet",
        "linearart": "t2i",
        "scribble": "t2i",
        "tile_upscaler": "controlnet",
    }