File size: 4,472 Bytes
19b3da3 35575bb 19b3da3 35575bb 22df957 19b3da3 10230ea 35575bb 22df957 10230ea 22df957 10230ea 35575bb 10230ea 19b3da3 fd5252e 10230ea 19b3da3 fd5252e 10230ea 35575bb 10230ea c95142c 35575bb 22df957 35575bb 10230ea 35575bb 10230ea c95142c 10230ea fd5252e 19b3da3 10230ea fd5252e 0daeeb0 10230ea 0daeeb0 10230ea 22df957 19b3da3 f70725b 19b3da3 35575bb 19b3da3 35575bb f70725b 22df957 35575bb f70725b 35575bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from typing import List, Union
import torch
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionXLInpaintPipeline,
UNet2DConditionModel,
)
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
get_base_inpaint_model_revision,
get_base_inpaint_model_variant,
get_hf_cache_dir,
get_hf_token,
get_inpaint_model_path,
get_is_sdxl,
get_model_dir,
get_num_return_sequences,
)
class InPainter(AbstractPipeline):
    """Inpainting pipeline that can share components with a base pipeline.

    Call ``init()`` with the base pipeline before ``load()``. ``load()`` then
    either reuses the base pipeline's components directly (when the configured
    inpaint model path equals the base model dir), or loads inpaint weights
    separately: for SDXL only the UNet comes from the inpaint repo, for other
    models the whole inpaint pipeline is loaded from the inpaint path.
    """

    __loaded = False
    # Base pipeline registered via init(). Declared at class level so the
    # (name-mangled) attribute always exists and can be tested for None.
    __base = None

    def init(self, pipeline: AbstractPipeline):
        """Remember ``pipeline`` as the base whose components may be shared."""
        self.__base = pipeline

    def load(self):
        """Build the inpaint pipeline on CUDA; a no-op if already loaded."""
        if self.__loaded:
            return
        # Test the attribute directly: hasattr(self, "__base") would never
        # match, because string arguments are not name-mangled while
        # self.__base is stored as "_InPainter__base".
        if self.__base is not None and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return
        if get_is_sdxl():
            # Only the UNet is taken from the inpaint repo; every other
            # component is reused from the base pipeline, so init() must have
            # been called first.
            unet = UNet2DConditionModel.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
                subfolder="unet",
                variant=get_base_inpaint_model_variant(),
                revision=get_base_inpaint_model_revision(),
            ).to("cuda")
            components = {**self.__base.pipe.components, "unet": unet}
            self.pipe = StableDiffusionXLInpaintPipeline(**components).to("cuda")
            self.__install_sdxl_processors()
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            ).to("cuda")
        disable_safety_checker(self.pipe)
        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build the inpaint pipeline on CUDA from ``pipeline.pipe.components``.

        The component objects of ``pipeline`` are passed straight to the new
        inpaint pipeline rather than reloaded from disk.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(
                **pipeline.pipe.components
            ).to("cuda")
            # Configure processors exactly as load() does; process() relies
            # on mask_processor.blur() in SDXL mode.
            self.__install_sdxl_processors()
        else:
            self.pipe = StableDiffusionInpaintPipeline(
                **pipeline.pipe.components
            ).to("cuda")
        disable_safety_checker(self.pipe)
        self.__patch()

    def __install_sdxl_processors(self):
        """Replace the SDXL pipeline's image/mask processors with the project's
        ``VaeImageProcessor`` (binarized, grayscale, un-normalized masks)."""
        self.pipe.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.pipe.vae_scale_factor,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.pipe.image_processor = VaeImageProcessor(
            vae_scale_factor=self.pipe.vae_scale_factor
        )

    def __patch(self):
        """Enable VAE tiling/slicing (SDXL only) and xformers attention."""
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the inpaint pipeline and call clear_cuda_and_gc().

        The registered base pipeline is kept, so a later load() can rebuild.
        """
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Inpaint the image at ``image_url`` inside the mask at ``mask_image_url``.

        Both images are downloaded and resized to ``(width, height)``. In SDXL
        mode the generation size is then snapped with
        ``HighRes.find_closest_sdxl_aspect_ratio``, the mask is blurred, and
        ``strength``/``padding_mask_crop`` are forced to 0.999/1000
        (overriding any caller-supplied values).

        Extra ``kwargs`` are forwarded to the diffusers pipeline and take
        precedence over the defaults below (e.g. ``strength=1.0``).

        Returns:
            A tuple ``(images, mask_img)``: the pipeline output's ``images``
            and the (resized, and in SDXL mode blurred) mask actually used.
        """
        generator = get_generators(seed, get_num_return_sequences())
        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))
        if get_is_sdxl():
            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)
            kwargs["strength"] = 0.999
            kwargs["padding_mask_crop"] = 1000
        params = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "generator": generator,
            # Caller/SDXL overrides win over the defaults above.
            **kwargs,
        }
        return self.pipe(**params).images, mask_img
|