# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config


def edit_preprocess(processor, device, edit_image, edit_mask):
    # Run an optional annotator (built from a scepter config) over the edit
    # image, then paste its output back only inside the masked region.
    if edit_image is None or processor is None:
        return edit_image
    processor = Config(cfg_dict=processor, load=False)
    processor = ANNOTATORS.build(processor).to(device)
    new_edit_image = processor(np.asarray(edit_image))
    # Move the annotator off the GPU and drop it to free memory immediately.
    processor = processor.to("cpu")
    del processor
    new_edit_image = Image.fromarray(new_edit_image)
    # Where the mask is white, take the annotated image; elsewhere keep the original.
    return Image.composite(new_edit_image, edit_image, edit_mask)
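
# Usage sketch for edit_preprocess. The config dict below is hypothetical: the
# exact NAME (and any extra keys) depends on which annotators are registered
# in your scepter installation.
#
#   anno_cfg = {"NAME": "CannyAnnotator"}
#   conditioned = edit_preprocess(anno_cfg, "cuda", edit_image=pil_image,
#                                 edit_mask=pil_mask)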


class ACEPlusImageProcessor:
    def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
        # d: all output sides are rounded down to a multiple of d.
        # max_seq_len: budget used to cap the token count (H / d) * (W / d).
        self.max_aspect_ratio = max_aspect_ratio
        self.d = d
        self.max_seq_len = max_seq_len
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def image_check(self, image):
        if image is None:
            return image
        # Clamp extreme aspect ratios by center-cropping the longer side,
        # then convert to a tensor normalized to [-1, 1].
        W, H = image.size
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return self.transforms(image)
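
    # Worked example for image_check: with max_aspect_ratio=4, a 100x500 (WxH)
    # portrait image has H / W = 5 > 4, so it is center-cropped to
    # CenterCrop([int(4 * 100), 100]), i.e. 400x100 (HxW), before normalization.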

    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   height=1024,
                   width=1024,
                   repainting_scale=1.0,
                   keep_pixels=False,
                   keep_pixels_rate=0.8,
                   use_change=False):
        reference_image = self.image_check(reference_image)
        edit_image = self.image_check(edit_image)
        # For pure reference generation there is no edit image: use a blank
        # canvas and an all-ones mask of the requested size.
        if edit_image is None:
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        else:
            if edit_mask is None:
                _, eH, eW = edit_image.shape
                edit_mask = np.ones((eH, eW))
            else:
                # Binarize: values above 128 mark the region to be repainted.
                edit_mask = np.asarray(edit_mask)
                edit_mask = np.where(edit_mask > 128, 1, 0)
            # An empty mask is treated as "repaint everything".
            edit_mask = edit_mask.astype(np.float32) if np.any(edit_mask) \
                else np.ones_like(edit_mask).astype(np.float32)
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)
        # Blank out the masked region; repainting_scale=1.0 erases it entirely.
        edit_image = edit_image * (1 - edit_mask * repainting_scale)
        out_h, out_w = edit_image.shape[-2:]
        assert edit_mask is not None
        if reference_image is not None:
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            if not keep_pixels:
                # Scale the reference so its height matches the edit image.
                scale = eH / H
                tH, tW = eH, int(W * scale)
                reference_image = T.Resize((tH, tW),
                                           interpolation=T.InterpolationMode.BILINEAR,
                                           antialias=True)(reference_image)
            else:
                # Keep the reference close to its original pixel size: shrink
                # it only if it exceeds keep_pixels_rate of the edit height,
                # then zero-pad it vertically to match.
                if H >= keep_pixels_rate * eH:
                    tH = int(eH * keep_pixels_rate)
                    scale = tH / H
                    tW = int(W * scale)
                    reference_image = T.Resize((tH, tW),
                                               interpolation=T.InterpolationMode.BILINEAR,
                                               antialias=True)(reference_image)
                rH, rW = reference_image.shape[-2:]
                delta_w = 0
                delta_h = eH - rH
                padding = (delta_w // 2, delta_h // 2,
                           delta_w - (delta_w // 2), delta_h - (delta_h // 2))
                reference_image = T.Pad(padding, fill=0,
                                        padding_mode="constant")(reference_image)
            # Place the reference to the left of the edit image; its region is
            # excluded from the mask so it is never repainted.
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1],
                                                reference_image.shape[2]]),
                                   edit_mask], dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0
        # Rescale so the token count (H / d) * (W / d) fits the sequence
        # budget, then round every side down to a multiple of d.
        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d
        edit_image = T.Resize((rH, rW),
                              interpolation=T.InterpolationMode.NEAREST_EXACT,
                              antialias=True)(edit_image)
        edit_mask = T.Resize((rH, rW),
                             interpolation=T.InterpolationMode.NEAREST_EXACT,
                             antialias=True)(edit_mask)
        content_image = edit_image
        if use_change:
            change_image = edit_image * edit_mask
            edit_image = edit_image * (1 - edit_mask)
        else:
            change_image = None
        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
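
    # Return layout of preprocess (for reference):
    #   edit_image    - conditioning image (reference | edit), 3xHxW in [-1, 1]
    #   edit_mask     - 1xHxW binary mask, 1 where content should be generated
    #   change_image  - masked-out content when use_change=True, else None
    #   content_image - the conditioning image before the use_change split
    #   out_h, out_w  - original edit-image size, used again by postprocess
    #   slice_w       - width of the reference slice after rescaling (0 if none)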

    def postprocess(self, image, slice_w, out_w, out_h):
        w, h = image.size
        if slice_w > 0:
            # Crop off the reference slice plus a 30-pixel margin, then resize
            # the remainder back to the original edit-image size.
            output_image = image.crop((slice_w + 30, 0, w, h))
            output_image = output_image.resize((out_w, out_h))
        else:
            output_image = image
        return output_image
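

if __name__ == "__main__":
    # Minimal usage sketch under assumed placeholder file names ("edit.png",
    # "mask.png"); substitute your own RGB image and L-mode mask.
    proc = ACEPlusImageProcessor(max_seq_len=1024)
    edit_img = Image.open("edit.png").convert("RGB")
    mask_img = Image.open("mask.png").convert("L")
    (edit_image, edit_mask, change_image, content_image,
     out_h, out_w, slice_w) = proc.preprocess(edit_image=edit_img,
                                              edit_mask=mask_img,
                                              repainting_scale=1.0)
    print(edit_image.shape, edit_mask.shape, out_h, out_w, slice_w)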