from io import BytesIO
from typing import List, Union

import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.commons import download_image


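class ReplaceBackground:
    """Replace the background of a product image using ControlNet-guided inpainting."""

    def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
        """Load the lineart ControlNet and the inpainting pipeline onto the GPU."""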
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16
        ).to("cuda")
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        upscaler.load()

        self.pipe = pipe
        self.upscaler = upscaler
        self.remove_background = remove_background

    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        product_scale_width: float,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        resize_dimension: int,
        seed: int,
        steps: int,
    ):
        # Accept either a URL/path string or an already loaded PIL image.
        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        image = self.remove_background.remove(image)

        width = int(width)
        height = int(height)

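        # Size of the product area: a fraction of the canvas width, with the
        # same aspect ratio as the canvas.
        n_width = int(width * product_scale_width)
        n_height = int(n_width * height // width)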

        print(width, height, n_width, n_height)

        image = ImageUtil.padd_image(image, n_width, n_height)

        f_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        f_image.paste(image, ((width - n_width) // 2, (height - n_height) // 2))
        image = f_image

        # Build the inpainting mask from the alpha channel: fully transparent
        # pixels (background) become white and are repainted; all other pixels
        # (the product) become black and are preserved.
        mask = image.getchannel("A").point(lambda a: 255 if a == 0 else 0)
        mask = mask.convert("RGB")

        condition_image = ControlNet.linearart_condition_image(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            control_image=condition_image,
            guidance_scale=9,
            strength=1,
            height=height,
            width=width,
            # The `steps` argument was previously accepted but never used.
            num_inference_steps=steps,
        )
        result = Result.from_result(result)

        images, has_nsfw = result

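        if not has_nsfw:
            for i, generated in enumerate(images):
                # Paste the original product back over the generated image,
                # using its own alpha channel as the paste mask, so the
                # product pixels are left untouched by inpainting.
                generated.paste(image, (0, 0), image)
                w, h = generated.size
                out_bytes = self.upscaler.upscale(
                    image=generated,
                    width=w,
                    height=h,
                    face_enhance=False,
                    resize_dimension=resize_dimension,
                )
                images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")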

        return (images, has_nsfw)
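

# Illustrative sketch, not part of the original module: the canvas arithmetic
# from ReplaceBackground.replace pulled out as a pure function, so the sizing
# and centring can be checked without a GPU or the internal helpers.
# The name `product_box` is hypothetical.
def product_box(width: int, height: int, product_scale_width: float):
    """Return (n_width, n_height, left, top) for the centred product area."""
    n_width = int(width * product_scale_width)
    # Same aspect ratio as the canvas, with integer division as in replace().
    n_height = int(n_width * height // width)
    left = (width - n_width) // 2
    top = (height - n_height) // 2
    return n_width, n_height, left, top


# Example: a 512x768 canvas with the product at half the canvas width.
# product_box(512, 768, 0.5) -> (256, 384, 128, 192)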