File size: 4,472 Bytes
19b3da3
 
 
35575bb
 
 
 
 
19b3da3
 
35575bb
 
 
22df957
19b3da3
10230ea
35575bb
22df957
10230ea
 
 
22df957
10230ea
35575bb
10230ea
19b3da3
 
 
fd5252e
 
10230ea
 
 
19b3da3
fd5252e
 
 
10230ea
 
 
 
 
 
35575bb
 
10230ea
 
 
c95142c
35575bb
22df957
35575bb
10230ea
35575bb
 
 
 
 
 
 
 
 
 
 
10230ea
 
 
 
 
c95142c
10230ea
fd5252e
19b3da3
 
10230ea
 
fd5252e
 
0daeeb0
10230ea
 
 
 
 
 
 
 
0daeeb0
 
10230ea
 
 
 
 
 
 
 
22df957
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
f70725b
 
19b3da3
35575bb
19b3da3
 
 
 
35575bb
 
 
 
 
 
 
f70725b
 
 
 
 
 
 
 
22df957
35575bb
f70725b
 
35575bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from typing import List, Union

import torch
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionXLInpaintPipeline,
    UNet2DConditionModel,
)

from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_inpaint_model_revision,
    get_base_inpaint_model_variant,
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_is_sdxl,
    get_model_dir,
    get_num_return_sequences,
)


class InPainter(AbstractPipeline):
    """Inpainting pipeline wrapper.

    Lazily builds a Stable Diffusion (or SDXL) inpaint pipeline, reusing the
    components of an already-loaded base pipeline when the configured inpaint
    model is the same as the base model directory.
    """

    # Guards load() so the pipeline is only constructed once.
    __loaded = False

    def init(self, pipeline: AbstractPipeline):
        """Attach the base pipeline whose components may be reused by load()."""
        self.__base = pipeline

    def load(self):
        """Construct self.pipe on first call; no-op afterwards."""
        if self.__loaded:
            return

        # BUG FIX: ``self.__base`` is stored under the name-mangled attribute
        # ``_InPainter__base``; the raw string "__base" passed to hasattr() is
        # never mangled, so the original check was always False and this
        # component-reuse fast path was dead code (forcing a full model load
        # even when the inpaint model equals the base model).
        if (
            hasattr(self, "_InPainter__base")
            and get_inpaint_model_path() == get_model_dir()
        ):
            self.create(self.__base)
            self.__loaded = True
            return

        if get_is_sdxl():
            # Only the UNet differs for the SDXL inpaint checkpoint; every
            # other component is shared with the base pipeline.
            unet = UNet2DConditionModel.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
                subfolder="unet",
                variant=get_base_inpaint_model_variant(),
                revision=get_base_inpaint_model_revision(),
            ).to("cuda")
            kwargs = {**self.__base.pipe.components, "unet": unet}
            self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda")
            # Mask is binarized grayscale (no normalization) as expected by
            # the SDXL inpaint pipeline.
            self.pipe.mask_processor = VaeImageProcessor(
                vae_scale_factor=self.pipe.vae_scale_factor,
                do_normalize=False,
                do_binarize=True,
                do_convert_grayscale=True,
            )
            self.pipe.image_processor = VaeImageProcessor(
                vae_scale_factor=self.pipe.vae_scale_factor
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)

        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build self.pipe by reusing all components of *pipeline* (no download)."""
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        disable_safety_checker(self.pipe)

        self.__patch()

    def __patch(self):
        """Apply memory optimizations to the constructed pipeline."""
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the pipeline and free CUDA memory; load() may be called again."""
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Run inpainting on an image/mask pair downloaded from URLs.

        Returns a tuple of (list of generated PIL images, the mask image
        actually used — blurred for SDXL).
        """
        generator = get_generators(seed, get_num_return_sequences())

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        if get_is_sdxl():
            # SDXL works best at its trained aspect ratios; images were
            # already resized above, the pipeline rescales to these dims.
            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)

            # These land after the defaults below via **kwargs, so they
            # intentionally override the generic strength=1.0.
            kwargs["strength"] = 0.999
            kwargs["padding_mask_crop"] = 1000

        kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "generator": generator,
            **kwargs,  # caller-supplied / SDXL overrides win over defaults
        }
        return self.pipe.__call__(**kwargs).images, mask_img