File size: 6,473 Bytes
2c6c92a
 
 
 
 
 
35575bb
99a0484
2c6c92a
35575bb
2c6c92a
 
 
99a0484
22df957
2c6c92a
 
 
 
 
 
 
22df957
35575bb
22df957
99a0484
22df957
 
1cd09a3
 
35575bb
 
 
2c6c92a
 
 
 
 
 
 
 
99a0484
 
2c6c92a
 
35575bb
22df957
 
 
 
 
 
 
 
 
35575bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c6c92a
 
 
 
 
 
 
 
1cd09a3
35575bb
 
2c6c92a
35575bb
 
 
 
 
 
1cd09a3
 
35575bb
1cd09a3
 
35575bb
2c6c92a
35575bb
1cd09a3
2c6c92a
 
 
35575bb
2c6c92a
35575bb
 
 
 
 
 
 
99a0484
 
 
 
 
 
 
35575bb
99a0484
 
 
 
 
 
 
35575bb
 
99a0484
 
 
35575bb
 
 
 
99a0484
 
35575bb
 
2c6c92a
 
 
 
 
 
 
 
 
 
 
 
 
 
35575bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c6c92a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
from diffusers import ControlNetModel
from PIL import Image
from torchvision import transforms

import internals.util.image as ImageUtils
import internals.util.image as ImageUtil
from carvekit.api import high
from internals.data.result import Result
from internals.data.task import TaskType
from internals.pipelines.commons import AbstractPipeline, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
from internals.pipelines.high_res import HighRes
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import get_base_dimension

# Module-level ControlNet detector shared by all pipeline instances; used for
# pose detection in SDXLTileUpscaler.process. NOTE: the local `controlnet`
# inside SDXLTileUpscaler.create deliberately shadows this name with a
# diffusers ControlNetModel — they are different objects.
controlnet = ControlNet()


class SDXLTileUpscaler(AbstractPipeline):
    """ControlNet-guided SDXL upscaler built on the DemoFusion pipeline.

    Reuses the components of an existing ``Text2Img`` pipeline and attaches a
    ControlNet whose condition is either canny edges (default) or openpose,
    switched lazily per task by ``__reload_controlnet``.
    """

    __loaded = False
    # Name of the TaskType whose ControlNet weights are currently attached.
    __current_process_mode = None

    def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
        """Assemble the DemoFusion ControlNet pipeline once; no-op when loaded.

        ``model_id`` is accepted for call-site parity but not consulted here —
        the per-model branch lives in ``process``.
        """
        if self.__loaded:
            return
        # temporal hack for upscale model till multicontrolnet support is added

        # NOTE: this local deliberately shadows the module-level `controlnet`
        # detector — here we need the diffusers model, not the wrapper.
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
        )
        pipe = DemoFusionSDXLControlNetPipeline(
            **pipeline.pipe.components, controlnet=controlnet
        )
        pipe = pipe.to("cuda")
        # Keep VRAM bounded during large tiled decodes.
        pipe.enable_vae_tiling()
        pipe.enable_vae_slicing()
        pipe.enable_xformers_memory_efficient_attention()

        self.high_res = high_res

        self.pipe = pipe

        self.__current_process_mode = TaskType.CANNY.name
        self.__loaded = True

    def unload(self):
        """Drop pipeline references and release CUDA memory."""
        self.__loaded = False
        self.pipe = None
        self.high_res = None

        clear_cuda_and_gc()

    def __reload_controlnet(self, process_mode: str):
        """Swap the attached ControlNet weights when the process mode changes.

        Cheap no-op when ``process_mode`` matches the weights already loaded.
        """
        if self.__current_process_mode == process_mode:
            return

        model = (
            "thibaud/controlnet-openpose-sdxl-1.0"
            if process_mode == TaskType.POSE.name
            else "diffusers/controlnet-canny-sdxl-1.0"
        )
        controlnet = ControlNetModel.from_pretrained(
            model, torch_dtype=torch.float16
        ).to("cuda")

        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet

        self.__current_process_mode = process_mode

        clear_cuda_and_gc()

    def process(
        self,
        prompt: str,
        imageUrl: str,
        resize_dimension: int,
        negative_prompt: str,
        width: int,
        height: int,
        model_id: int,
        seed: int,
        process_mode: str,
    ):
        """Upscale the image at ``imageUrl`` using a pose or canny condition.

        Returns ``(images, False)`` — a list of output PIL images plus a
        constant NSFW flag kept for caller compatibility.
        """
        # NOTE: torch.manual_seed returns a Generator but also seeds the
        # global RNG as a side effect; kept as-is for reproducibility parity.
        generator = torch.manual_seed(seed)

        self.__reload_controlnet(process_mode)

        if process_mode == TaskType.POSE.name:
            print("Running POSE")
            condition_image = controlnet.detect_pose(imageUrl)
        else:
            print("Running CANNY")
            condition_image = download_image(imageUrl)
            condition_image = ControlNet.canny_detect_edge(condition_image)
        width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)

        img = download_image(imageUrl).resize((width, height))
        condition_image = condition_image.resize(img.size)

        img2 = self.__resize_for_condition_image(img, resize_dimension)

        # The pipeline expects square inputs: pad with black, crop back after.
        img = self.pad_image(img)
        image_lr = self.load_and_process_image(img)

        out_img = self.pad_image(img2)
        condition_image = self.pad_image(condition_image)

        print("img", img.size)
        print("img2", img2.size)
        print("condition", condition_image.size)
        if int(model_id) == 2000173:
            # Special-cased model: plain high-res img2img, no ControlNet.
            kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "image": img2,
                "strength": 0.3,
                "num_inference_steps": 30,
                "generator": generator,
            }
            images = self.high_res.pipe(**kwargs).images
        else:
            images = self.pipe(
                image_lr=image_lr,
                prompt=prompt,
                condition_image=condition_image,
                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic, "
                + negative_prompt,
                guidance_scale=11,
                sigma=0.8,
                num_inference_steps=24,
                controlnet_conditioning_scale=0.5,
                generator=generator,
                width=out_img.size[0],
                height=out_img.size[1],
            )
            # Reverse so the highest-resolution output comes first, then crop
            # the padding away using the target (unpadded) aspect ratio.
            images = images[::-1]
            iv = ImageUtil.resize_image(img2, images[0].size[0])
            images = [self.unpad_image(images[0], iv.size)]
        return images, False

    def load_and_process_image(self, pil_image):
        """Convert a PIL image to a normalized fp16 CUDA tensor (1, C, H, W)."""
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Map [0, 1] -> [-1, 1] per channel.
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        image = transform(pil_image)
        image = image.unsqueeze(0).half()
        image = image.to("cuda")
        return image

    def pad_image(self, image):
        """Pad a PIL image with black to a centered square of side max(w, h)."""
        w, h = image.size
        if w == h:
            return image
        if w > h:
            # Landscape: pad top and bottom.
            new_image = Image.new(image.mode, (w, w), (0, 0, 0))
            new_image.paste(image, (0, (w - h) // 2))
            return new_image
        # Portrait: pad left and right.
        new_image = Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image

    def unpad_image(self, padded_image, original_size):
        """Crop the centered content of ``original_size`` out of a padded image.

        Inverse of :meth:`pad_image` for an image of ``original_size``.
        """
        w, h = original_size
        if w == h:
            return padded_image
        elif w > h:
            pad_h = (w - h) // 2
            unpadded_image = padded_image.crop((0, pad_h, w, h + pad_h))
            return unpadded_image
        else:
            pad_w = (h - w) // 2
            unpadded_image = padded_image.crop((pad_w, 0, w + pad_w, h))
            return unpadded_image

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale so the longer side is ~``resolution``, snapped to multiples of 64.

        Each dimension is clamped to at least 64 so extreme aspect ratios
        cannot round down to a zero-sized (invalid) image.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        W = max(64, int(round(W * k / 64.0)) * 64)
        H = max(64, int(round(H * k / 64.0)) * 64)
        img = input_image.resize((W, H), resample=Image.LANCZOS)
        return img