import torch
from diffusers import ControlNetModel
from PIL import Image
from torchvision import transforms

import internals.util.image as ImageUtil
from internals.data.task import TaskType
from internals.pipelines.commons import AbstractPipeline, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
from internals.pipelines.high_res import HighRes
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image

controlnet = ControlNet()


class SDXLTileUpscaler(AbstractPipeline):
    __loaded = False
    __current_process_mode = None

    def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
        if self.__loaded:
            return
        # Temporary hack for the upscale model until multi-controlnet support is added.
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
        )
        pipe = DemoFusionSDXLControlNetPipeline(
            **pipeline.pipe.components, controlnet=controlnet
        )
        pipe = pipe.to("cuda")
        pipe.enable_vae_tiling()
        pipe.enable_vae_slicing()
        pipe.enable_xformers_memory_efficient_attention()

        self.high_res = high_res
        self.pipe = pipe

        self.__current_process_mode = TaskType.CANNY.name
        self.__loaded = True

    def unload(self):
        self.__loaded = False
        self.pipe = None
        self.high_res = None
        clear_cuda_and_gc()

    def __reload_controlnet(self, process_mode: str):
        # Swap the controlnet only when the requested mode actually changes.
        if self.__current_process_mode == process_mode:
            return

        model = (
            "thibaud/controlnet-openpose-sdxl-1.0"
            if process_mode == TaskType.POSE.name
            else "diffusers/controlnet-canny-sdxl-1.0"
        )
        controlnet = ControlNetModel.from_pretrained(
            model, torch_dtype=torch.float16
        ).to("cuda")
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        self.__current_process_mode = process_mode

        clear_cuda_and_gc()

    def process(
        self,
        prompt: str,
        imageUrl: str,
        resize_dimension: int,
        negative_prompt: str,
        width: int,
        height: int,
        model_id: int,
        seed: int,
        process_mode: str,
    ):
        generator = torch.manual_seed(seed)

        self.__reload_controlnet(process_mode)

        # Build the conditioning image for the active controlnet.
        if process_mode == TaskType.POSE.name:
            print("Running POSE")
            condition_image = controlnet.detect_pose(imageUrl)
        else:
            print("Running CANNY")
            condition_image = download_image(imageUrl)
            condition_image = ControlNet.canny_detect_edge(condition_image)

        width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)

        img = download_image(imageUrl).resize((width, height))
        condition_image = condition_image.resize(img.size)

        img2 = self.__resize_for_condition_image(img, resize_dimension)

        # DemoFusion works on square canvases, so pad the inputs and crop back later.
        img = self.pad_image(img)
        image_lr = self.load_and_process_image(img)
        out_img = self.pad_image(img2)
        condition_image = self.pad_image(condition_image)

        print("img", img.size)
        print("img2", img2.size)
        print("condition", condition_image.size)

        if int(model_id) == 2000173:
            kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "image": img2,
                "strength": 0.3,
                "num_inference_steps": 30,
                "generator": generator,
            }
            images = self.high_res.pipe.__call__(**kwargs).images
        else:
            images = self.pipe.__call__(
                image_lr=image_lr,
                prompt=prompt,
                condition_image=condition_image,
                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic, "
                + negative_prompt,
                guidance_scale=11,
                sigma=0.8,
                num_inference_steps=24,
                controlnet_conditioning_scale=0.5,
                generator=generator,
                width=out_img.size[0],
                height=out_img.size[1],
            )
            # DemoFusion returns one image per scale (smallest first);
            # reverse so the highest-resolution result is at index 0.
            images = images[::-1]

        # Crop the square padding back off to restore the requested aspect ratio.
        iv = ImageUtil.resize_image(img2, images[0].size[0])
        images = [self.unpad_image(images[0], iv.size)]
        return images, False

    def load_and_process_image(self, pil_image):
        # Convert a PIL image to a normalized [-1, 1] fp16 tensor on CUDA.
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        image = transform(pil_image)
        image = image.unsqueeze(0).half()
        image = image.to("cuda")
        return image

    def pad_image(self, image):
        # Pad with black borders to a centered square.
        w, h = image.size
        if w == h:
            return image
        elif w > h:
            new_image = Image.new(image.mode, (w, w), (0, 0, 0))
            new_image.paste(image, (0, (w - h) // 2))
            return new_image
        else:
            new_image = Image.new(image.mode, (h, h), (0, 0, 0))
            new_image.paste(image, ((h - w) // 2, 0))
            return new_image

    def unpad_image(self, padded_image, original_size):
        # Crop the centered square padding back to the original aspect ratio.
        w, h = original_size
        if w == h:
            return padded_image
        elif w > h:
            pad_h = (w - h) // 2
            return padded_image.crop((0, pad_h, w, h + pad_h))
        else:
            pad_w = (h - w) // 2
            return padded_image.crop((pad_w, 0, w + pad_w, h))

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        # Scale so the longest side matches `resolution`, snapping both sides
        # to multiples of 64 as required by the SDXL latent grid.
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        W = int(round(W * k / 64.0)) * 64
        H = int(round(H * k / 64.0)) * 64
        return input_image.resize((W, H), resample=Image.LANCZOS)
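

# Example usage -- a minimal sketch left as a comment, since wiring up the
# internal pipelines is application-specific. It assumes `text2img` and
# `high_res` are already-initialised Text2Img / HighRes instances; the URL,
# model id, prompt, and dimensions below are hypothetical placeholders.
#
#   upscaler = SDXLTileUpscaler()
#   upscaler.create(high_res=high_res, pipeline=text2img, model_id=1)
#   images, _ = upscaler.process(
#       prompt="a detailed matte painting of a lighthouse",
#       imageUrl="https://example.com/input.png",
#       resize_dimension=2048,
#       negative_prompt="",
#       width=1024,
#       height=1024,
#       model_id=1,
#       seed=42,
#       process_mode=TaskType.CANNY.name,
#   )
#   images[0].save("upscaled.png")
#   upscaler.unload()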