from typing import List, Union

import torch
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionXLInpaintPipeline,
)

from internals.pipelines.commons import AbstractPipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_is_sdxl,
    get_inpaint_model_path,
    get_model_dir,
)


class InPainter(AbstractPipeline):
    __loaded = False

    def init(self, pipeline: AbstractPipeline):
        self.__base = pipeline

    def load(self):
        if self.__loaded:
            return

        # If a base pipeline was provided and the inpaint model matches the
        # currently loaded model, reuse its components instead of downloading
        # a separate checkpoint. Note: name mangling stores self.__base as
        # self._InPainter__base, so the attribute is checked under that name.
        if hasattr(self, "_InPainter__base") and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return

        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            ).to("cuda")
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)
        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        # Build the inpaint pipeline from an existing pipeline's components
        # (shared UNet, VAE, text encoder) to avoid loading the model twice.
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )

        disable_safety_checker(self.pipe)
        self.__patch()

    def __patch(self):
        # SDXL needs tiled/sliced VAE decoding to keep memory usage in check.
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        torch.manual_seed(seed)

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            **kwargs,
        }
        return self.pipe(**kwargs).images
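

# Minimal usage sketch, not part of the pipeline itself: it assumes the config
# already points at a valid inpaint checkpoint, a CUDA device is available, and
# the URLs and output path below are placeholders rather than real assets.
if __name__ == "__main__":
    inpainter = InPainter()
    inpainter.load()
    images = inpainter.process(
        image_url="https://example.com/input.png",
        mask_image_url="https://example.com/mask.png",
        width=512,
        height=512,
        seed=42,
        prompt="a wooden bench in a park",
        negative_prompt="blurry, low quality",
        num_inference_steps=30,
    )
    images[0].save("inpaint_result.png")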