from typing import Optional

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from PIL import Image

import internals.util.image as ImageUtil
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
from internals.util.config import get_base_dimension, get_hf_cache_dir, get_is_sdxl


class RealtimeDraw(AbstractPipeline):
    """Realtime sketch-to-image pipeline.

    Runs a ControlNet-LLLite img2img pipeline on SDXL, or scribble/segmentation
    ControlNet img2img pipelines on SD 1.5.
    """

    def load(self, pipeline: AbstractPipeline):
        # Already loaded: nothing to do.
        if hasattr(self, "pipe"):
            return

        if get_is_sdxl():
            # SDXL: img2img through the ControlNet-LLLite pipeline with the
            # hosted LLLite weights.
            lite_pipe = SDXLLLiteImg2ImgPipeline()
            lite_pipe.load(
                pipeline,
                [
                    "https://s3.ap-south-1.amazonaws.com/autodraft.model.assets/models/replicate-xl-llite.safetensors"
                ],
            )
            self.pipe = lite_pipe
        else:
            # SD 1.5: build two ControlNet img2img pipelines that share the base
            # pipeline's components. `pipe` is guided by segmentation only,
            # `pipe2` by scribble + segmentation.
            self.__controlnet_scribble = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_scribble",
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            )
            self.__controlnet_seg = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_seg",
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            )

            kwargs = {**pipeline.pipe.components}  # pyright: ignore
            kwargs.pop("image_encoder", None)  # image encoder is not used here

            self.pipe = StableDiffusionControlNetImg2ImgPipeline(
                **kwargs, controlnet=self.__controlnet_seg
            ).to("cuda")
            self.pipe.safety_checker = None

            self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
                **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
            ).to("cuda")
            self.pipe2.safety_checker = None

    def process_seg(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str,
        seed: int,
    ):
        """Segmentation-guided img2img: the input image doubles as both the
        init image and the segmentation control image."""
        if get_is_sdxl():
            raise Exception("SDXL is not supported for this method")

        torch.manual_seed(seed)

        image = ImageUtil.resize_image(image, 512)

        img = self.pipe(
            image=image,
            control_image=image,
            prompt=prompt,
            num_inference_steps=15,
            negative_prompt=negative_prompt,
            guidance_scale=10,
            strength=0.8,
        ).images[0]

        return img

    def process_img(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        image: Optional[Image.Image] = None,
        image2: Optional[Image.Image] = None,
    ):
        """Sketch-guided img2img. `image` is the user's sketch, `image2` feeds
        the segmentation ControlNet; a missing input is replaced by a black
        canvas sized to match the other image (or the base dimension)."""
        torch.manual_seed(seed)

        b_dimen = get_base_dimension()

        if image is None:
            size = (b_dimen, b_dimen)
            if image2 is not None:
                size = image2.size
            image = Image.new("RGB", size, color=0)

        if image2 is None:
            size = (b_dimen, b_dimen)
            if image is not None:
                size = image.size
            image2 = Image.new("RGB", size, color=0)

        if get_is_sdxl():
            # SDXL: snap to the closest supported SDXL aspect ratio and run the
            # LLLite pipeline, conditioning on the sketch itself.
            size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
            image = image.resize(size)

            images = self.pipe(
                image=image,
                condition_image=image,
                negative_prompt=negative_prompt,
                prompt=prompt,
                seed=seed,
                num_inference_steps=10,
                width=image.size[0],
                height=image.size[1],
            )
            img = images[0]
        else:
            # SD 1.5: derive a scribble map from the sketch and condition on
            # both the scribble and the segmentation image.
            image = ImageUtil.resize_image(image, b_dimen)
            scribble = ControlNet.scribble_image(image)
            image2 = ImageUtil.resize_image(image2, b_dimen)

            img = self.pipe2(
                image=image,
                control_image=[scribble, image2],
                prompt=prompt,
                num_inference_steps=15,
                negative_prompt=negative_prompt,
                guidance_scale=10,
                strength=0.9,
                width=image.size[0],
                height=image.size[1],
                controlnet_conditioning_scale=[1.0, 0.8],
            ).images[0]

        # Normalize the output size to 512 before returning.
        img = ImageUtil.resize_image(img, 512)

        return img
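

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this class is meant to be driven, assuming some base
# AbstractPipeline (called `base_pipeline` below) has already been constructed
# and loaded elsewhere in the service; that name and its setup are hypothetical
# and stand in for the project's real bootstrapping code.
#
#   drawer = RealtimeDraw()
#   drawer.load(base_pipeline)  # hypothetical, already-loaded AbstractPipeline
#
#   sketch = Image.open("sketch.png").convert("RGB")
#   result = drawer.process_img(
#       prompt="a watercolor landscape, soft lighting",
#       negative_prompt="blurry, low quality",
#       seed=42,
#       image=sketch,
#   )
#   result.save("realtime_draw_out.png")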