File size: 3,877 Bytes
2c6c92a
 
 
 
 
 
99a0484
2c6c92a
 
 
 
99a0484
2c6c92a
 
 
 
 
 
 
99a0484
1cd09a3
 
 
 
 
2c6c92a
1cd09a3
 
2c6c92a
 
 
 
 
 
 
 
99a0484
 
2c6c92a
 
 
 
 
 
 
 
 
 
1cd09a3
2c6c92a
1cd09a3
 
 
 
 
2c6c92a
 
 
1cd09a3
2c6c92a
 
 
 
 
99a0484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c6c92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from diffusers import ControlNetModel
from PIL import Image
from torchvision import transforms

import internals.util.image as ImageUtils
from carvekit.api import high
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
from internals.pipelines.high_res import HighRes
from internals.util.commons import download_image
from internals.util.config import get_base_dimension

# Module-level ControlNet helper shared by SDXLTileUpscaler.process
# (used there for pose detection on the input image URL).
controlnet = ControlNet()


class SDXLTileUpscaler(AbstractPipeline):
    """Tile-based SDXL upscaler.

    Wraps a DemoFusion SDXL ControlNet pipeline (canny or openpose
    conditioned, selected by ``model_id``) to upscale an input image,
    with a high-res img2img fallback for one specific model.
    """

    def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
        """Build and cache the DemoFusion SDXL ControlNet pipeline.

        Reuses the components of ``pipeline.pipe`` and attaches a
        ControlNet chosen by ``model_id``. Stores the result on
        ``self.pipe`` and keeps ``high_res`` for the fallback path.
        """
        # temporal hack for upscale model till multicontrolnet support is added
        model = (
            "thibaud/controlnet-openpose-sdxl-1.0"
            if int(model_id) == 2000293
            else "diffusers/controlnet-canny-sdxl-1.0"
        )

        # Renamed from `controlnet`: the original local shadowed the
        # module-level ControlNet helper that `process` relies on.
        controlnet_model = ControlNetModel.from_pretrained(
            model, torch_dtype=torch.float16
        )
        pipe = DemoFusionSDXLControlNetPipeline(
            **pipeline.pipe.components, controlnet=controlnet_model
        )
        pipe = pipe.to("cuda")
        # Reduce peak VRAM during VAE decode and attention.
        pipe.enable_vae_tiling()
        pipe.enable_vae_slicing()
        pipe.enable_xformers_memory_efficient_attention()

        self.high_res = high_res
        self.pipe = pipe

    def process(
        self,
        prompt: str,
        imageUrl: str,
        resize_dimension: int,
        negative_prompt: str,
        width: int,
        height: int,
        model_id: int,
    ):
        """Upscale the image at ``imageUrl``.

        Returns a tuple ``(images, has_nsfw)`` where ``has_nsfw`` is
        always ``False``. For model 2000173 a plain high-res img2img
        pass is used; otherwise the DemoFusion ControlNet pipeline runs
        with a pose or canny condition image depending on ``model_id``.
        """
        # Condition image: pose map for the openpose model, canny edges
        # otherwise (uses the module-level `controlnet` helper).
        if int(model_id) == 2000293:
            condition_image = controlnet.detect_pose(imageUrl)
        else:
            condition_image = download_image(imageUrl)
            condition_image = ControlNet.canny_detect_edge(condition_image)
        img = download_image(imageUrl).resize((width, height))

        img = ImageUtils.resize_image(img, get_base_dimension())
        condition_image = condition_image.resize(img.size)

        # Target output size, snapped to a multiple of 64.
        img2 = self.__resize_for_condition_image(img, resize_dimension)

        image_lr = self.load_and_process_image(img)
        print("img", img2.size, img.size)
        if int(model_id) == 2000173:
            # Fallback: simple high-res img2img pass for this model.
            kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "image": img2,
                "strength": 0.3,
                "num_inference_steps": 30,
            }
            images = self.high_res.pipe(**kwargs).images
        else:
            images = self.pipe(
                image_lr=image_lr,
                prompt=prompt,
                condition_image=condition_image,
                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                guidance_scale=11,
                sigma=0.8,
                num_inference_steps=24,
                width=img2.size[0],
                height=img2.size[1],
            )
            # DemoFusion yields progressively larger results; reverse so
            # the highest-resolution image comes first.
            images = images[::-1]
        return images, False

    def load_and_process_image(self, pil_image):
        """Convert a PIL image to a normalized (1, 3, 1024, 1024) fp16
        CUDA tensor in [-1, 1], as expected by the DemoFusion pipeline."""
        transform = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        image = transform(pil_image)
        image = image.unsqueeze(0).half()
        image = image.to("cuda")
        return image

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` so its longest side is ``resolution``, rounding
        each dimension to the nearest multiple of 64 (min 64 to avoid a
        zero-sized resize on very small inputs)."""
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        H *= k
        W *= k
        # Clamp to >= 64: rounding could otherwise produce 0 and make
        # PIL raise on an empty target size.
        H = max(64, int(round(H / 64.0)) * 64)
        W = max(64, int(round(W / 64.0)) * 64)
        img = input_image.resize((W, H), resample=Image.LANCZOS)
        return img