# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config


def edit_preprocess(processor, device, edit_image, edit_mask):
    """Run an annotator over ``edit_image`` and composite the result back,
    keeping the original pixels outside ``edit_mask``."""
    if edit_image is None or processor is None:
        return edit_image
    # Build the annotator from its config dict and run it on the given device.
    processor = Config(cfg_dict=processor, load=False)
    processor = ANNOTATORS.build(processor).to(device)
    new_edit_image = processor(np.asarray(edit_image))
    # Move the annotator back to CPU and drop it to release GPU memory.
    processor = processor.to("cpu")
    del processor
    new_edit_image = Image.fromarray(new_edit_image)
    # Where the mask is white, take the annotated image; elsewhere the original.
    return Image.composite(new_edit_image, edit_image, edit_mask)
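
# Example (a sketch, not part of the original file): running a scepter
# annotator over the masked region. The config below is an assumption about
# the ANNOTATORS registry schema; substitute a real annotator config from
# the repo.
#
#     cfg = {'NAME': 'CannyAnnotator'}  # hypothetical annotator name
#     out = edit_preprocess(cfg, torch.device('cuda'),
#                           Image.open('input.png').convert('RGB'),
#                           Image.open('mask.png').convert('L'))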

class ACEPlusImageProcessor:
    def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
        self.max_aspect_ratio = max_aspect_ratio  # crop inputs exceeding this H/W (or W/H) ratio
        self.d = d  # all output sizes are rounded down to multiples of d
        self.max_seq_len = max_seq_len  # caps the (H/d) x (W/d) latent grid in preprocess
        # PIL image -> float tensor in [-1, 1].
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def image_check(self, image):
        if image is None:
            return image
        # Center-crop images that exceed the maximum aspect ratio. Note that
        # PIL reports size as (W, H) while CenterCrop expects [h, w].
        W, H = image.size
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return self.transforms(image)


    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   height=1024,
                   width=1024,
                   repainting_scale=1.0,
                   keep_pixels=False,
                   keep_pixels_rate=0.8,
                   use_change=False):
        reference_image = self.image_check(reference_image)
        edit_image = self.image_check(edit_image)
        # Pure reference generation: no edit image, so synthesize a blank
        # canvas with a full mask.
        if edit_image is None:
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        else:
            if edit_mask is None:
                # No mask supplied: treat the whole edit image as editable.
                _, eH, eW = edit_image.shape
                edit_mask = np.ones((eH, eW))
            else:
                # Binarize the user-provided mask (values above 128 count as masked).
                edit_mask = np.asarray(edit_mask)
                edit_mask = np.where(edit_mask > 128, 1, 0)
            # An all-zero mask would leave nothing to edit; fall back to a full mask.
            edit_mask = edit_mask.astype(np.float32) if np.any(edit_mask) \
                else np.ones_like(edit_mask).astype(np.float32)
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)

        # Blank out the masked region; repainting_scale=1.0 erases it entirely,
        # 0.0 keeps the original pixels.
        edit_image = edit_image * (1 - edit_mask * repainting_scale)

        # Remember the edit canvas size so postprocess can restore it.
        out_h, out_w = edit_image.shape[-2:]

        assert edit_mask is not None
        if reference_image is not None:
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            if not keep_pixels:
                # Scale the reference so its height matches the edit image.
                scale = eH / H
                tH, tW = eH, int(W * scale)
                reference_image = T.Resize((tH, tW),
                                           interpolation=T.InterpolationMode.BILINEAR,
                                           antialias=True)(reference_image)
            else:
                # Preserve reference pixels: only shrink the reference when it
                # exceeds keep_pixels_rate of the edit height, then pad to eH.
                if H >= keep_pixels_rate * eH:
                    tH = int(eH * keep_pixels_rate)
                    scale = tH / H
                    tW = int(W * scale)
                    reference_image = T.Resize((tH, tW),
                                               interpolation=T.InterpolationMode.BILINEAR,
                                               antialias=True)(reference_image)
                rH, rW = reference_image.shape[-2:]
                delta_w = 0
                delta_h = eH - rH
                padding = (delta_w // 2, delta_h // 2,
                           delta_w - (delta_w // 2), delta_h - (delta_h // 2))
                reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
            # Concatenate reference | edit side by side; the reference half is
            # never masked, so only the edit half can be repainted.
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([
                torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]),
                edit_mask
            ], dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0

        # Downscale so the latent grid (H/d) x (W/d) stays within the sequence
        # budget, then round every size down to a multiple of d. For example,
        # with H = W = 1024, d = 16 and max_seq_len = 1024 the grid is
        # 64 x 64 = 4096 > 2048, so scale = sqrt(2048 / 4096) ~= 0.707 and
        # rH = rW = 720.
        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d  # ensure divisible by self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d

        # NEAREST_EXACT keeps the mask strictly binary after resizing.
        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
        content_image = edit_image  # full canvas, kept before any masking below
        if use_change:
            # Split into the changed (masked) and kept (unmasked) regions.
            change_image = edit_image * edit_mask
            edit_image = edit_image * (1 - edit_mask)
        else:
            change_image = None
        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w


    def postprocess(self, image, slice_w, out_w, out_h):
        w, h = image.size
        if slice_w > 0:
            # Crop off the reference slice (plus a fixed 30-pixel margin) and
            # resize the remainder back to the original edit resolution.
            output_image = image.crop((slice_w + 30, 0, w, h))
            output_image = output_image.resize((out_w, out_h))
        else:
            output_image = image
        return output_image
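

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The image paths are
# hypothetical placeholders and the "generated" image is a stand-in for a
# diffusion model's output; only ACEPlusImageProcessor comes from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    processor = ACEPlusImageProcessor()

    # Hypothetical inputs: a reference image plus an image/mask pair to edit.
    reference = Image.open("reference.png").convert("RGB")
    edit = Image.open("edit.png").convert("RGB")
    mask = Image.open("mask.png").convert("L")

    (edit_image, edit_mask, change_image, content_image,
     out_h, out_w, slice_w) = processor.preprocess(reference_image=reference,
                                                   edit_image=edit,
                                                   edit_mask=mask,
                                                   repainting_scale=1.0)

    # A real pipeline would denoise here; as a stand-in, un-normalize the
    # conditioned canvas back into a PIL image of the concatenated size.
    generated = T.ToPILImage()((edit_image * 0.5 + 0.5).clamp(0, 1))

    # Crop the reference slice off again and restore the original edit size.
    result = processor.postprocess(generated, slice_w, out_w, out_h)
    print(result.size)  # (out_w, out_h)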