from io import BytesIO
from typing import List, Optional, Union

import torch
from cv2 import inpaint
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.data.task import ModelType
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet, load_network_model_by_key
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpainter import InPainter
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_model_dir,
    get_num_return_sequences,
)


class ReplaceBackground(AbstractPipeline):
    """Generate a new background behind the foreground subject of an image.

    The subject is cut out with a background remover, its canny edges are used
    as the ControlNet condition for a Stable Diffusion generation, and the
    cut-out subject is finally pasted back on top of every generated image.
    """

    # Guards against loading the (expensive) models more than once.
    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        base: Optional[AbstractPipeline] = None,
        high_res: Optional[HighRes] = None,
    ):
        """Build the ControlNet pipeline and attach the helper pipelines.

        Does nothing if the pipeline has already been loaded.

        Args:
            upscaler: Upscaler to reuse; a new one is created and loaded when
                omitted.
            remove_background: Background remover to reuse; a new one is
                created when omitted.
            base: Pipeline whose ``pipe.components`` are reused to construct
                the ControlNet pipeline. When omitted, the pipeline is loaded
                from ``get_model_dir()``.
            high_res: HighRes helper to reuse; a new one is created and loaded
                when omitted.
        """
        if self.__loaded:
            return

        controlnet_model = load_network_model_by_key(
            "lllyasviel/control_v11p_sd15_canny", "controlnet"
        )

        if base:
            # Reuse the already-loaded components of the base pipeline and
            # only swap in the canny ControlNet.
            pipe = StableDiffusionControlNetPipeline(
                **base.pipe.components,
                controlnet=controlnet_model,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            )

        # Applied on both branches so the freshly loaded ControlNet is also
        # moved to the GPU when the remaining components come from `base`.
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_vae_slicing()
        pipe.to("cuda")
        self.pipe = pipe

        if not high_res:
            high_res = HighRes()
            high_res.load()
        self.high_res = high_res

        if not upscaler:
            upscaler = Upscaler()
            upscaler.load()
        self.upscaler = upscaler

        if not remove_background:
            remove_background = RemoveBackgroundV2()
        self.remove_background = remove_background

        self.__loaded = True

    def unload(self):
        """Drop all model references and free CUDA memory."""
        self.__loaded = False
        self.pipe = None
        self.high_res = None
        self.upscaler = None
        self.remove_background = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        prompt: List[str],
        negative_prompt: List[str],
        conditioning_scale: float,
        seed: int,
        steps: int,
        apply_high_res: bool = False,
        model_type: ModelType = ModelType.REAL,
    ):
        """Replace the background of ``image`` with a generated one.

        Args:
            image: Source image, or a string that is passed to
                ``download_image`` to obtain it.
            width: Output width; coerced with ``int()``.
            height: Output height; coerced with ``int()``.
            prompt: Prompts forwarded to the pipeline.
            negative_prompt: Negative prompts forwarded to the pipeline.
            conditioning_scale: Forwarded as ``controlnet_conditioning_scale``.
            seed: Seed passed to ``get_generators``.
            steps: Number of inference steps.
            apply_high_res: Accepted for interface compatibility; not used by
                this method.
            model_type: Forwarded to the background remover.

        Returns:
            A ``(images, has_nsfw)`` tuple as unpacked from
            ``Result.from_result``. When ``has_nsfw`` is falsy, the cut-out
            foreground has been pasted onto every image.
        """
        if isinstance(image, str):
            image = download_image(image)

        generator = get_generators(seed, get_num_return_sequences())

        image = image.convert("RGB")
        # Cap the longest side before background removal.
        if max(image.size) > 1024:
            image = ImageUtil.resize_image(image, dimension=1024)
        # NOTE(review): presumably returns an RGBA cut-out, since the result
        # is later used as its own paste mask -- confirm against
        # RemoveBackgroundV2.remove.
        image = self.remove_background.remove(image, model_type=model_type)

        width = int(width)
        height = int(height)

        # Fit the cut-out to the larger target side, then pad it to the exact
        # output size so it lines up with the generated images.
        resolution = max(width, height)
        image = ImageUtil.resize_image(image, resolution)
        image = ImageUtil.padd_image(image, width, height)

        condition_image = ControlNet.canny_detect_edge(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            height=height,
            num_inference_steps=steps,
            generator=generator,
            width=width,
        )
        result = Result.from_result(result)

        images, has_nsfw = result
        if not has_nsfw:
            # Composite the original subject over each generated background,
            # using the cut-out itself as the paste mask.
            for generated in images:
                generated.paste(image, (0, 0), image)

        return (images, has_nsfw)