File size: 3,442 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from io import BytesIO
from typing import List, Union

import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.commons import download_image


class ReplaceBackground:
    """Regenerate the background around a product image with ControlNet inpainting.

    The product is cut out with a background remover and centred on a
    transparent canvas of the requested size. Every fully transparent pixel
    is then regenerated by a Stable Diffusion inpainting pipeline conditioned
    on a lineart ControlNet image. Finally, the original product pixels are
    pasted back on top and each result is upscaled.
    """

    def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
        """Build the ControlNet inpainting pipeline on CUDA and keep the helpers.

        Args:
            upscaler: Upscaler applied to every generated image; its ``load()``
                is invoked here.
            remove_background: Background remover applied to input images in
                :meth:`replace`.
        """
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16
        ).to("cuda")
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        upscaler.load()

        self.pipe = pipe
        self.upscaler = upscaler
        self.remove_background = remove_background

    @staticmethod
    def _place_on_canvas(
        product: Image.Image, width: int, height: int, product_scale_width: float
    ) -> Image.Image:
        """Centre ``product`` on a transparent ``width`` x ``height`` RGBA canvas.

        The product slot is ``product_scale_width`` of the canvas width. Its
        height is chosen so the slot has the same aspect ratio as the canvas.
        """
        n_width = int(width * product_scale_width)
        n_height = int(n_width * height // width)

        # NOTE(review): padd_image presumably fits/pads the product into an
        # n_width x n_height box — confirm against internals.util.image.
        product = ImageUtil.padd_image(product, n_width, n_height)

        canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        canvas.paste(product, ((width - n_width) // 2, (height - n_height) // 2))
        return canvas

    @staticmethod
    def _background_mask(canvas: Image.Image) -> Image.Image:
        """Return an RGB inpainting mask for an RGBA ``canvas``.

        Pixels whose alpha is exactly 0 become white (regenerate). Every other
        pixel becomes black (keep). This uses a 256-entry lookup table instead
        of a per-pixel Python loop.
        """
        alpha = canvas.getchannel("A")
        return alpha.point(lambda a: 255 if a == 0 else 0).convert("RGB")

    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        product_scale_width: float,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        resize_dimension: int,
        seed: int,
        steps: int,
    ):
        """Generate new backgrounds around the product in ``image``.

        Args:
            image: A PIL image, or a URL/string that ``download_image`` resolves.
            width: Output canvas width in pixels; coerced with ``int``.
            height: Output canvas height in pixels; coerced with ``int``.
            product_scale_width: Fraction of ``width`` occupied by the product.
            prompt: Prompt(s) forwarded to the diffusion pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.
            resize_dimension: Forwarded to ``Upscaler.upscale``.
            seed: Seed for the global torch CPU and CUDA RNGs.
            steps: Number of denoising steps (``num_inference_steps``).

        Returns:
            ``(images, has_nsfw)`` as unpacked from ``Result.from_result``.
            When ``has_nsfw`` is falsy, each image has the product pasted back
            and is replaced by its upscaled RGB version. Otherwise the raw
            pipeline images are returned.

        Raises:
            ValueError: If ``width`` or ``height`` is not positive.
        """
        if isinstance(image, str):
            image = download_image(image)

        # No explicit generator is passed to the pipeline, so seeding the
        # global RNGs is what makes runs reproducible (assumes the pipeline
        # samples from the default torch RNGs — confirm for the diffusers
        # version in use).
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        width = int(width)
        height = int(height)
        if width <= 0 or height <= 0:
            raise ValueError(
                f"width and height must be positive, got {width}x{height}"
            )

        product = self.remove_background.remove(image.convert("RGB"))
        canvas = self._place_on_canvas(product, width, height, product_scale_width)
        mask = self._background_mask(canvas)
        condition_image = ControlNet.linearart_condition_image(canvas)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            # Drop alpha explicitly: the inpainting VAE expects 3 channels, and
            # the transparent area is fully covered by the mask anyway.
            image=canvas.convert("RGB"),
            mask_image=mask,
            control_image=condition_image,
            guidance_scale=9,
            strength=1,
            height=height,
            width=width,
            num_inference_steps=steps,
        )
        images, has_nsfw = Result.from_result(result)

        if not has_nsfw:
            for index, generated in enumerate(images):
                # Composite the untouched product back over the generation,
                # using the canvas alpha as the paste mask.
                generated.paste(canvas, (0, 0), canvas)
                out_bytes = self.upscaler.upscale(generated, resize_dimension)
                images[index] = Image.open(BytesIO(out_bytes)).convert("RGB")

        return (images, has_nsfw)