Spaces:
Running
on
T4
Running
on
T4
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Power by Zongsheng Yue 2022-06-09 14:59:55 | |
import torch | |
import random | |
import numpy as np | |
from einops import rearrange | |
def batch_inpainging_from_grad(im_in, mask, gradx, grady):
    '''
    Recover a batch of masked images from their gradients (torch tensors).

    Every sample is converted to numpy, recovered with ``inpainting_from_grad``
    and written into a new tensor that has the device and dtype of ``im_in``.
    The input tensors are not modified.

    Input:
        im_in: N x c x h x w, torch tensor, masked image
        mask: N x 1 x h x w, torch tensor, 1 represents missing value
        gradx, grady: N x c x h x w, torch tensor, image gradient
    Output:
        im_out: N x c x h x w, torch tensor, recovered images
    '''
    im_out = torch.zeros_like(im_in.data)
    for ii in range(im_in.shape[0]):
        # ``Tensor.numpy()`` shares memory with CPU tensors, and
        # ``inpainting_from_grad`` may write into its image and mask arguments,
        # so hand it private copies. ``detach`` lets tensors that require grad
        # be converted as well.
        im_current = rearrange(im_in[ii,].detach().cpu().numpy(), 'c h w -> h w c').copy()
        gradx_current, grady_current = [
            rearrange(x[ii,].detach().cpu().numpy(), 'c h w -> h w c')
            for x in [gradx, grady]
        ]
        mask_current = mask[ii, 0,].detach().cpu().numpy().copy()
        out_current = inpainting_from_grad(im_current, mask_current, gradx_current, grady_current)
        im_out[ii,] = torch.from_numpy(rearrange(out_current, 'h w c -> c h w')).to(
            device=im_in.device,
            dtype=im_in.dtype,
        )
    return im_out
def inpainting_from_grad(im_in, mask, gradx, grady):
    '''
    Recover the missing pixels of an image by integrating its gradients.

    Strategy (anchor line, then integration):
      * if some interior column is fully known, every row is integrated from it
        with ``gradx``;
      * else, if some interior row is fully known, every column is integrated
        from it with ``grady``;
      * otherwise a random interior column (weighted by its number of known
        pixels) is first completed with ``grady`` and then used as the anchor
        for integrating the rows with ``gradx``.
    Known pixels are always taken from ``im_in``; only missing ones come from
    the integration. The input arrays are not modified.

    Input:
        im_in: h x w x c, masked image, numpy array
        mask: h x w, image mask, 1 represents missing value
        gradx: h x w x c, gradient along x-axis, numpy array
        grady: h x w x c, gradient along y-axis, numpy array
    Output:
        im_out: recovered image, same shape as im_in
    Raises:
        ValueError: if no interior column contains a known pixel.
    '''
    h, w = im_in.shape[:2]
    # Private copy: ``fill_line`` clears the filled entries of the mask column
    # in place, and the final composition below relies on that cleared mask.
    mask = np.array(mask, copy=True)
    # ``mask_b`` is a view of ``mask``, so it sees those in-place updates too.
    mask_b = mask[:, :, None] if im_in.ndim > mask.ndim else mask
    # The 1-D filling and the final composition assume missing pixels are 0;
    # enforce it (this also makes ``im_in`` a private copy).
    im_in = im_in * (1 - mask_b)

    counts_h = np.sum(1-mask, axis=0, keepdims=False)  # known pixels per column
    counts_w = np.sum(1-mask, axis=1, keepdims=False)  # known pixels per row
    if np.any(counts_h[1:-1,] == h):
        # A fully known interior column: integrate every row from it.
        idx = find_first_index(counts_h[1:-1,], h) + 1
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    elif np.any(counts_w[1:-1,] == w):
        # A fully known interior row: integrate every column from it with
        # ``grady`` by swapping the two spatial axes (channels stay last).
        idx = find_first_index(counts_w[1:-1,], w) + 1
        im_out = np.swapaxes(
            fill_image_from_gradx(
                np.swapaxes(im_in, 0, 1),
                mask.T,
                np.swapaxes(grady, 0, 1),
                idx,
            ),
            0, 1,
        )
    else:
        weights = counts_h[1:-1]
        if not np.any(weights > 0):
            raise ValueError('inpainting_from_grad needs a known pixel in some interior column')
        idx = random.choices(list(range(1, w-1)), k=1, weights=weights)[0]
        # Complete the anchor column first; this also zeroes its mask entries.
        line = fill_line(im_in[:, idx, ], mask[:, idx,], grady[:, idx,])
        im_in[:, idx,] = line
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    # Keep known pixels from im_in; take missing ones from the integration.
    return im_in + im_out * mask_b
def fill_image_from_gradx(im_in, mask, gradx, idx):
    '''
    Rebuild a whole image by integrating ``gradx`` along each row, starting
    from the anchor column ``idx``.

    Input:
        im_in: h x w (x c) numpy array; only column ``idx`` is read
        mask: unused, kept for interface compatibility
        gradx: same shape as im_in, gradient along x-axis; the integration
               assumes gradx[:, j] = im[:, j] - im[:, j-1]
        idx: int, index of the anchor column
    Output:
        im_out: array with the same shape as im_in
    '''
    anchor = im_in[:, idx, ]
    seed = np.zeros_like(im_in)
    seed[:, idx, ] = anchor
    # Step sequences for both directions; ``seed`` only contributes the anchor
    # value to the first step of each sequence.
    steps_right = seed[:, idx:-1, ] + gradx[:, idx+1:, ]
    steps_left = seed[:, idx:0:-1, ] - gradx[:, idx:0:-1, ]
    right_part = np.cumsum(steps_right, axis=1)                 # columns idx+1 .. w-1
    left_part = np.flip(np.cumsum(steps_left, axis=1), axis=1)  # columns 0 .. idx-1
    return np.concatenate((left_part, np.expand_dims(anchor, axis=1), right_part), axis=1)
def fill_line(xx, mm, grad):
    '''
    Fill the missing entries of one line in place by integrating its gradient.

    The integration uses grad[j] = x[j] - x[j-1]. Missing runs after a known
    value are integrated forwards from that value; a leading missing run is
    integrated backwards from the first known value.

    Input:
        xx: n x c (or (n,)) array, masked line; missing entries must be 0
            because they are summed into the integration
        mm: (n,) array, mask, 1 represents missing value
        grad: array with the same shape as xx, gradient along the line
    Output:
        xx (the same object), with every missing entry filled.

    Side effects: filled values are written into ``xx`` and the corresponding
    entries of ``mm`` are set to 0.

    Raises:
        ValueError: if ``mm`` leaves no known entry to integrate from.
    '''
    n = xx.shape[0]
    if mm.sum() >= n:
        raise ValueError('fill_line needs at least one known value (mm == 0) in the line')
    # Fill the missing runs one at a time from left to right. A loop replaces
    # the former one-call-per-gap recursion, so long lines with many gaps
    # cannot exhaust the recursion limit.
    while True:
        idx1 = find_first_index(mm, 1)
        if idx1 >= n:  # nothing left to fill
            break
        if idx1 == 0:
            # Leading run: walk backwards from the first known value xx[idx2].
            idx2 = find_first_index(mm, 0)
            subx = xx[idx2::-1,].copy()
            subgrad = grad[idx2::-1, ].copy()
            subx -= subgrad
            xx[:idx2,] = np.cumsum(subx, axis=0)[idx2-1::-1,]
        else:
            # Run [idx1, idx2): walk forwards from the known value xx[idx1-1].
            idx2 = find_first_index(mm[idx1:,], 0) + idx1
            subx = xx[idx1-1:idx2-1,].copy()
            subgrad = grad[idx1:idx2,].copy()
            subx += subgrad
            xx[idx1:idx2,] = np.cumsum(subx, axis=0)
        mm[idx1:idx2,] = 0
    return xx
def find_first_index(mm, value):
    '''
    Return the index of the first entry of ``mm`` equal to ``value``.

    Input:
        mm: (n, ) array (any array-like is accepted)
        value: scalar
    Output:
        int index of the first match along the first axis, or ``mm.shape[0]``
        (one past the end) when nothing matches.
    '''
    arr = np.asarray(mm)
    # Vectorized search instead of a Python-level scan with np.ndenumerate.
    # np.nonzero reports matches in C (row-major) order, so for N-d input the
    # first-axis index of the first match is the same as before.
    hits = np.nonzero(arr == value)[0]
    return int(hits[0]) if hits.size else arr.shape[0]
if __name__ == '__main__':
    # Self-check: mask each test image, recover it from its true gradients and
    # report the largest absolute reconstruction error per image.
    import sys
    from pathlib import Path
    sys.path.append(str(Path(__file__).resolve().parents[1]))
    from utils import util_image
    from datapipe.masks.train import process_mask

    # Alternative test set: Path('../lama/LaMa_test_images')
    data_dir = Path('./testdata/inpainting/val/places/')
    # Sorted so the images are processed in a deterministic order.
    mask_file_names = sorted(data_dir.glob('*mask*.png'))
    # 'xxx_mask....png' -> 'xxx.png' in the same directory.
    file_names = [x.with_name(x.stem.rsplit('_mask', 1)[0] + '.png') for x in mask_file_names]
    for im_path, mask_path in zip(file_names, mask_file_names):
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        mask = process_mask(util_image.imread(mask_path, chn='rgb', dtype='float32')[:, :, 0])
        grad_dict = util_image.imgrad(im)
        im_masked = im * (1 - mask[:, :, None])
        im_recover = inpainting_from_grad(im_masked, mask, grad_dict['gradx'], grad_dict['grady'])
        error_max = np.abs(im_recover - im).max()
        print('{:s}: Error Max: {:.2e}'.format(im_path.name, error_max))