from typing import List, Optional, Union

import torch
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionXLInpaintPipeline,
    UNet2DConditionModel,
)

from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_inpaint_model_revision,
    get_base_inpaint_model_variant,
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_is_sdxl,
    get_model_dir,
    get_num_return_sequences,
)


class InPainter(AbstractPipeline):
    """Inpainting pipeline wrapper; SD or SDXL variant is selected by ``get_is_sdxl()``.

    Call :meth:`init` with the base pipeline before :meth:`load` so that its
    components can be shared. The SDXL load path always needs the base
    pipeline's components.
    """

    __loaded = False
    # Base pipeline whose components may be shared; assigned by init().
    # Declared here so the name-mangled attribute (_InPainter__base) always
    # exists and can be compared against None.
    __base: Optional[AbstractPipeline] = None

    def init(self, pipeline: AbstractPipeline):
        """Remember ``pipeline`` so that :meth:`load` can reuse its components."""
        self.__base = pipeline

    def load(self):
        """Build ``self.pipe`` on CUDA; does nothing if already loaded.

        - If a base pipeline was given via :meth:`init` and the inpaint model
          path equals the base model dir, the base pipeline's components are
          reused via :meth:`create`.
        - Otherwise, for SDXL, only the UNet is loaded from the inpaint model
          repo and combined with the base pipeline's other components.
        - Otherwise (non-SDXL), the full inpaint pipeline is loaded from the
          inpaint model path.
        """
        if self.__loaded:
            return

        # Compare against None rather than hasattr(self, "__base"): the
        # attribute is name-mangled to _InPainter__base, so a string lookup
        # of "__base" would never match and this branch would be unreachable.
        if self.__base is not None and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            if get_is_sdxl():
                # Keep both SDXL load paths producing identically configured pipes.
                self.__install_sdxl_processors()
            self.__loaded = True
            return

        if get_is_sdxl():
            # only take UNet from the repo
            unet = UNet2DConditionModel.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
                subfolder="unet",
                variant=get_base_inpaint_model_variant(),
                revision=get_base_inpaint_model_revision(),
            ).to("cuda")
            # Every component except the UNet comes from the base pipeline;
            # this requires init() to have been called first.
            kwargs = {**self.__base.pipe.components, "unet": unet}
            self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda")
            self.__install_sdxl_processors()
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            ).to("cuda")
            disable_safety_checker(self.pipe)

        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build the inpaint pipeline from ``pipeline.pipe.components`` on CUDA.

        This disables the safety checker and applies :meth:`__patch`. It does
        not mark the instance as loaded.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        disable_safety_checker(self.pipe)
        self.__patch()

    def __install_sdxl_processors(self):
        """Replace the SDXL pipe's mask/image processors with the project's VaeImageProcessor."""
        self.pipe.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.pipe.vae_scale_factor,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.pipe.image_processor = VaeImageProcessor(
            vae_scale_factor=self.pipe.vae_scale_factor
        )

    def __patch(self):
        """Enable VAE tiling (SDXL only), VAE slicing and xformers attention."""
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
        self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the pipeline, mark as not loaded and call ``clear_cuda_and_gc()``."""
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Inpaint the image at ``image_url`` using the mask at ``mask_image_url``.

        Both images are downloaded and resized to ``(width, height)``.

        For SDXL:
        - the generation size is snapped via
          ``HighRes.find_closest_sdxl_aspect_ratio``;
        - the mask is blurred;
        - ``strength`` and ``padding_mask_crop`` are forced.

        Any other ``kwargs`` are forwarded to the pipeline call and override
        the defaults built here.

        Returns:
            A tuple ``(images, mask_img)``:
            - ``images`` is the pipeline output's ``.images``;
            - ``mask_img`` is the mask actually passed to the pipeline
              (blurred for SDXL).
        """
        generator = get_generators(seed, get_num_return_sequences())

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        if get_is_sdxl():
            # Only the generation size is snapped; the input images keep the
            # requested size and are left to the pipeline to rescale.
            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)
            # These replace both the defaults below and any caller-supplied values.
            kwargs["strength"] = 0.999
            kwargs["padding_mask_crop"] = 1000

        call_kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "generator": generator,
            # Spread last so caller kwargs (and SDXL overrides) take precedence.
            **kwargs,
        }
        return self.pipe(**call_kwargs).images, mask_img