add saicinpainting
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- saicinpainting/__init__.py +0 -0
- saicinpainting/__pycache__/__init__.cpython-38.pyc +0 -0
- saicinpainting/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/__pycache__/utils.cpython-39.pyc +0 -0
- saicinpainting/evaluation/__init__.py +33 -0
- saicinpainting/evaluation/__pycache__/__init__.cpython-38.pyc +0 -0
- saicinpainting/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/evaluation/__pycache__/data.cpython-39.pyc +0 -0
- saicinpainting/evaluation/__pycache__/evaluator.cpython-39.pyc +0 -0
- saicinpainting/evaluation/__pycache__/refinement.cpython-39.pyc +0 -0
- saicinpainting/evaluation/__pycache__/utils.cpython-39.pyc +0 -0
- saicinpainting/evaluation/data.py +168 -0
- saicinpainting/evaluation/evaluator.py +220 -0
- saicinpainting/evaluation/losses/__init__.py +0 -0
- saicinpainting/evaluation/losses/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/__pycache__/base_loss.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/__pycache__/lpips.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/__pycache__/ssim.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/base_loss.py +528 -0
- saicinpainting/evaluation/losses/fid/__init__.py +0 -0
- saicinpainting/evaluation/losses/fid/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/fid/__pycache__/inception.cpython-39.pyc +0 -0
- saicinpainting/evaluation/losses/fid/fid_score.py +328 -0
- saicinpainting/evaluation/losses/fid/inception.py +323 -0
- saicinpainting/evaluation/losses/lpips.py +891 -0
- saicinpainting/evaluation/losses/ssim.py +74 -0
- saicinpainting/evaluation/masks/README.md +27 -0
- saicinpainting/evaluation/masks/__init__.py +0 -0
- saicinpainting/evaluation/masks/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/evaluation/masks/__pycache__/mask.cpython-39.pyc +0 -0
- saicinpainting/evaluation/masks/countless/.gitignore +1 -0
- saicinpainting/evaluation/masks/countless/README.md +25 -0
- saicinpainting/evaluation/masks/countless/__init__.py +0 -0
- saicinpainting/evaluation/masks/countless/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/evaluation/masks/countless/__pycache__/countless2d.cpython-39.pyc +0 -0
- saicinpainting/evaluation/masks/countless/countless2d.py +529 -0
- saicinpainting/evaluation/masks/countless/countless3d.py +356 -0
- saicinpainting/evaluation/masks/countless/requirements.txt +7 -0
- saicinpainting/evaluation/masks/countless/test.py +195 -0
- saicinpainting/evaluation/masks/mask.py +429 -0
- saicinpainting/evaluation/refinement.py +314 -0
- saicinpainting/evaluation/utils.py +28 -0
- saicinpainting/evaluation/vis.py +37 -0
- saicinpainting/training/__init__.py +0 -0
- saicinpainting/training/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/training/data/__init__.py +0 -0
- saicinpainting/training/data/__pycache__/__init__.cpython-39.pyc +0 -0
- saicinpainting/training/data/__pycache__/aug.cpython-39.pyc +0 -0
- saicinpainting/training/data/__pycache__/datasets.cpython-39.pyc +0 -0
- saicinpainting/training/data/__pycache__/masks.cpython-39.pyc +0 -0
saicinpainting/__init__.py
ADDED
Empty file
saicinpainting/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (143 Bytes).

saicinpainting/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (143 Bytes).

saicinpainting/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (6.07 kB).
saicinpainting/evaluation/__init__.py
ADDED
@@ -0,0 +1,33 @@
import logging

import torch

from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore


def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
    logging.info(f'Make evaluator {kind}')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    metrics = {}
    if ssim:
        metrics['ssim'] = SSIMScore()
    if lpips:
        metrics['lpips'] = LPIPSScore()
    if fid:
        metrics['fid'] = FIDScore().to(device)

    if integral_kind is None:
        integral_func = None
    elif integral_kind == 'ssim_fid100_f1':
        integral_func = ssim_fid100_f1
    elif integral_kind == 'lpips_fid100_f1':
        integral_func = lpips_fid100_f1
    else:
        raise ValueError(f'Unexpected integral_kind={integral_kind}')

    if kind == 'default':
        return InpaintingEvaluatorOnline(scores=metrics,
                                         integral_func=integral_func,
                                         integral_title=integral_kind,
                                         **kwargs)
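Below is a hypothetical usage sketch of the factory above (not part of this commit): it builds the online evaluator with only the SSIM score enabled and feeds it one random batch. The tensor shapes, the mask threshold and the batch values are illustrative assumptions; the key names follow the InpaintingEvaluatorOnline contract shown later in evaluator.py.

# Hypothetical usage sketch, assuming the full repo (including models.ade20k,
# which base_loss.py imports) is on the Python path. All data below is random.
import torch
from saicinpainting.evaluation import make_evaluator

evaluator = make_evaluator(kind='default', ssim=True, lpips=False, fid=False,
                           integral_kind=None)

batch = {
    'image': torch.rand(2, 3, 256, 256),                  # ground-truth images in [0, 1]
    'mask': (torch.rand(2, 1, 256, 256) > 0.7).float(),   # 1 = hole region
    'inpainted': torch.rand(2, 3, 256, 256),              # stand-in for model output
}
evaluator.process_batch(batch)        # accumulates per-sample scores and mask-area bins
print(evaluator.evaluation_end())     # e.g. {('ssim', 'total'): {'mean': ..., 'std': ...}, ...}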
saicinpainting/evaluation/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.06 kB).

saicinpainting/evaluation/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.06 kB).

saicinpainting/evaluation/__pycache__/data.cpython-39.pyc
ADDED
Binary file (7.26 kB).

saicinpainting/evaluation/__pycache__/evaluator.cpython-39.pyc
ADDED
Binary file (7.95 kB).

saicinpainting/evaluation/__pycache__/refinement.cpython-39.pyc
ADDED
Binary file (9.64 kB).

saicinpainting/evaluation/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.36 kB).
saicinpainting/evaluation/data.py
ADDED
@@ -0,0 +1,168 @@
import glob
import os

import cv2
import PIL.Image as Image
import numpy as np

from torch.utils.data import Dataset
import torch.nn.functional as F


def load_image(fname, mode='RGB', return_orig=False):
    img = np.array(Image.open(fname).convert(mode))
    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))
    out_img = img.astype('float32') / 255
    if return_orig:
        return out_img, img
    else:
        return out_img


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def pad_img_to_modulo(img, mod):
    channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')


def pad_tensor_to_modulo(img, mod):
    batch_size, channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')


def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    if img.shape[0] == 1:
        img = img[0]
    else:
        img = np.transpose(img, (1, 2, 0))

    img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    if img.ndim == 2:
        img = img[None, ...]
    else:
        img = np.transpose(img, (2, 0, 1))
    return img


class InpaintingDataset(Dataset):
    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        self.datadir = datadir
        self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True)))
        self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        image = load_image(self.img_filenames[i], mode='RGB')
        mask = load_image(self.mask_filenames[i], mode='L')
        result = dict(image=image, mask=mask[None, ...])

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['unpad_to_size'] = result['image'].shape[1:]
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

        return result


class OurInpaintingDataset(Dataset):
    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        self.datadir = datadir
        self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True)))
        self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
                      mask=load_image(self.mask_filenames[i], mode='L')[None, ...])

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            result['mask'] = scale_image(result['mask'], self.scale_factor)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

        return result


class PrecomputedInpaintingResultsDataset(InpaintingDataset):
    def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
        super().__init__(datadir, **kwargs)
        if not datadir.endswith('/'):
            datadir += '/'
        self.predictdir = predictdir
        self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
                               for fname in self.mask_filenames]

    def __getitem__(self, i):
        result = super().__getitem__(i)
        result['inpainted'] = load_image(self.pred_filenames[i])
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
        return result


class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
    def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
        super().__init__(datadir, **kwargs)
        if not datadir.endswith('/'):
            datadir += '/'
        self.predictdir = predictdir
        self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
                               for fname in self.mask_filenames]
        # self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
        #                        for fname in self.mask_filenames]

    def __getitem__(self, i):
        result = super().__getitem__(i)
        result['inpainted'] = self.file_loader(self.pred_filenames[i])

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
        return result


class InpaintingEvalOnlineDataset(Dataset):
    def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
        self.indir = indir
        self.mask_generator = mask_generator
        self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}'), recursive=True)))
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, i):
        img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
        mask = self.mask_generator(img, raw_image=raw_image)
        result = dict(image=img, mask=mask)

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
        return result
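As a hedged illustration of how InpaintingDataset above expects its inputs: each '*mask*.png' is paired with the image sharing its stem, e.g. scene001_mask000.png with scene001.jpg. The directory path and pad_out_to_modulo value in the sketch below are invented for this example.

# Illustrative sketch only; '/data/val_set' and pad_out_to_modulo=8 are assumptions,
# and importing the package assumes the full repo (including models.ade20k) is available.
from torch.utils.data import DataLoader
from saicinpainting.evaluation.data import InpaintingDataset

dataset = InpaintingDataset('/data/val_set', img_suffix='.jpg',
                            pad_out_to_modulo=8, scale_factor=None)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

sample = dataset[0]
print(sample['image'].shape)  # (3, H_padded, W_padded), float32 in [0, 1]
print(sample['mask'].shape)   # (1, H_padded, W_padded), 1 marks the hole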
saicinpainting/evaluation/evaluator.py
ADDED
@@ -0,0 +1,220 @@
import logging
import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader

from saicinpainting.evaluation.utils import move_to_device

LOGGER = logging.getLogger(__name__)


class InpaintingEvaluator():
    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    def _get_bin_edges(self):
        bin_edges = np.linspace(0, 1, self.bins + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for idx_bin in range(self.bins):
            start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
                                         round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups = None

        for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results


def ssim_fid100_f1(metrics, fid_scale=100):
    ssim = metrics[('ssim', 'total')]['mean']
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
    return f1


def lpips_fid100_f1(metrics, fid_scale=100):
    neg_lpips = 1 - metrics[('lpips', 'total')]['mean']  # invert, so bigger is better
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
    return f1


class InpaintingEvaluatorOnline(nn.Module):
    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param device: device to use
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for idx_bin in range(self.bins_num):
            start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
                                         round(100 * self.bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    def _get_bins(self, mask_batch):
        batch_size = mask_batch.shape[0]
        area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
        return bin_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (can be overridden by self.inpainted_key)
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        return self(batch)

    def evaluation_end(self, states=None):
        """:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        self.groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results
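A toy calculation of the joint metric defined above: ssim_fid100_f1 combines SSIM with a rescaled FID term, (fid_scale - fid) / fid_scale, through an F1-style harmonic mean. The metric values below are invented purely for illustration.

# Toy illustration; the numbers are made up, and importing the package assumes
# the full repo (including models.ade20k) is on the Python path.
from saicinpainting.evaluation.evaluator import ssim_fid100_f1

toy_metrics = {
    ('ssim', 'total'): {'mean': 0.90},
    ('fid', 'total'): {'mean': 20.0},
}
# fid_rel = (100 - 20) / 100 = 0.8
# f1 = 2 * 0.90 * 0.8 / (0.90 + 0.8 + 1e-3) ≈ 0.847
print(ssim_fid100_f1(toy_metrics))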
saicinpainting/evaluation/losses/__init__.py
ADDED
Empty file
saicinpainting/evaluation/losses/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (161 Bytes).

saicinpainting/evaluation/losses/__pycache__/base_loss.cpython-39.pyc
ADDED
Binary file (17.6 kB).

saicinpainting/evaluation/losses/__pycache__/lpips.cpython-39.pyc
ADDED
Binary file (29.5 kB).

saicinpainting/evaluation/losses/__pycache__/ssim.cpython-39.pyc
ADDED
Binary file (2.73 kB).
saicinpainting/evaluation/losses/base_loss.py
ADDED
@@ -0,0 +1,528 @@
import logging
from abc import abstractmethod, ABC

import numpy as np
import sklearn
import sklearn.svm
import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import Parallel, delayed
from scipy import linalg

from models.ade20k import SegmentationModule, NUM_CLASS, segm_options
from .fid.inception import InceptionV3
from .lpips import PerceptualLoss
from .ssim import SSIM

LOGGER = logging.getLogger(__name__)


def get_groupings(groups):
    """
    :param groups: group numbers for respective elements
    :return: dict of kind {group_idx: indices of the corresponding group elements}
    """
    label_groups, count_groups = np.unique(groups, return_counts=True)

    indices = np.argsort(groups)

    grouping = dict()
    cur_start = 0
    for label, count in zip(label_groups, count_groups):
        cur_end = cur_start + count
        cur_indices = indices[cur_start:cur_end]
        grouping[label] = cur_indices
        cur_start = cur_end
    return grouping


class EvaluatorScore(nn.Module):
    @abstractmethod
    def forward(self, pred_batch, target_batch, mask):
        pass

    @abstractmethod
    def get_value(self, groups=None, states=None):
        pass

    @abstractmethod
    def reset(self):
        pass


class PairwiseScore(EvaluatorScore, ABC):
    def __init__(self):
        super().__init__()
        self.individual_values = None

    def get_value(self, groups=None, states=None):
        """
        :param groups:
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
        """
        individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \
            else self.individual_values

        total_results = {
            'mean': individual_values.mean(),
            'std': individual_values.std()
        }

        if groups is None:
            return total_results, None

        group_results = dict()
        grouping = get_groupings(groups)
        for label, index in grouping.items():
            group_scores = individual_values[index]
            group_results[label] = {
                'mean': group_scores.mean(),
                'std': group_scores.std()
            }
        return total_results, group_results

    def reset(self):
        self.individual_values = []


class SSIMScore(PairwiseScore):
    def __init__(self, window_size=11):
        super().__init__()
        self.score = SSIM(window_size=window_size, size_average=False).eval()
        self.reset()

    def forward(self, pred_batch, target_batch, mask=None):
        batch_values = self.score(pred_batch, target_batch)
        self.individual_values = np.hstack([
            self.individual_values, batch_values.detach().cpu().numpy()
        ])
        return batch_values


class LPIPSScore(PairwiseScore):
    def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True):
        super().__init__()
        self.score = PerceptualLoss(model=model, net=net, model_path=model_path,
                                    use_gpu=use_gpu, spatial=False).eval()
        self.reset()

    def forward(self, pred_batch, target_batch, mask=None):
        batch_values = self.score(pred_batch, target_batch).flatten()
        self.individual_values = np.hstack([
            self.individual_values, batch_values.detach().cpu().numpy()
        ])
        return batch_values


def fid_calculate_activation_statistics(act):
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6):
    mu1, sigma1 = fid_calculate_activation_statistics(activations_pred)
    mu2, sigma2 = fid_calculate_activation_statistics(activations_target)

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        LOGGER.warning(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


class FIDScore(EvaluatorScore):
    def __init__(self, dims=2048, eps=1e-6):
        LOGGER.info("FIDscore init called")
        super().__init__()
        if getattr(FIDScore, '_MODEL', None) is None:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            FIDScore._MODEL = InceptionV3([block_idx]).eval()
        self.model = FIDScore._MODEL
        self.eps = eps
        self.reset()
        LOGGER.info("FIDscore init done")

    def forward(self, pred_batch, target_batch, mask=None):
        activations_pred = self._get_activations(pred_batch)
        activations_target = self._get_activations(target_batch)

        self.activations_pred.append(activations_pred.detach().cpu())
        self.activations_target.append(activations_target.detach().cpu())

        return activations_pred, activations_target

    def get_value(self, groups=None, states=None):
        LOGGER.info("FIDscore get_value called")
        activations_pred, activations_target = zip(*states) if states is not None \
            else (self.activations_pred, self.activations_target)
        activations_pred = torch.cat(activations_pred).cpu().numpy()
        activations_target = torch.cat(activations_target).cpu().numpy()

        total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
        total_results = dict(mean=total_distance)

        if groups is None:
            group_results = None
        else:
            group_results = dict()
            grouping = get_groupings(groups)
            for label, index in grouping.items():
                if len(index) > 1:
                    group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index],
                                                                eps=self.eps)
                    group_results[label] = dict(mean=group_distance)
                else:
                    group_results[label] = dict(mean=float('nan'))

        self.reset()

        LOGGER.info("FIDscore get_value done")

        return total_results, group_results

    def reset(self):
        self.activations_pred = []
        self.activations_target = []

    def _get_activations(self, batch):
        activations = self.model(batch)[0]
        if activations.shape[2] != 1 or activations.shape[3] != 1:
            assert False, \
                'We should not have got here, because Inception always scales inputs to 299x299'
            # activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
        activations = activations.squeeze(-1).squeeze(-1)
        return activations


class SegmentationAwareScore(EvaluatorScore):
    def __init__(self, weights_path):
        super().__init__()
        self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval()
        self.target_class_freq_by_image_total = []
        self.target_class_freq_by_image_mask = []
        self.pred_class_freq_by_image_mask = []

    def forward(self, pred_batch, target_batch, mask):
        pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
        target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
        mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy()

        batch_target_class_freq_total = []
        batch_target_class_freq_mask = []
        batch_pred_class_freq_mask = []

        for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat):
            cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...]
            cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
            cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...]

            self.target_class_freq_by_image_total.append(cur_target_class_freq_total)
            self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask)
            self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask)

            batch_target_class_freq_total.append(cur_target_class_freq_total)
            batch_target_class_freq_mask.append(cur_target_class_freq_mask)
            batch_pred_class_freq_mask.append(cur_pred_class_freq_mask)

        batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0)
        batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0)
        batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0)
        return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask

    def reset(self):
        super().reset()
        self.target_class_freq_by_image_total = []
        self.target_class_freq_by_image_mask = []
        self.pred_class_freq_by_image_mask = []


def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name):
    assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0]
    total_class_freq = target_class_freq_by_image_mask.sum(0)
    distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0)
    result = distr_values / (total_class_freq + 1e-3)
    return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0}


def get_segmentation_idx2name():
    return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()}


class SegmentationAwarePairwiseScore(SegmentationAwareScore):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.individual_values = []
        self.segm_idx2name = get_segmentation_idx2name()

    def forward(self, pred_batch, target_batch, mask):
        cur_class_stats = super().forward(pred_batch, target_batch, mask)
        score_values = self.calc_score(pred_batch, target_batch, mask)
        self.individual_values.append(score_values)
        return cur_class_stats + (score_values,)

    @abstractmethod
    def calc_score(self, pred_batch, target_batch, mask):
        raise NotImplementedError()

    def get_value(self, groups=None, states=None):
        """
        :param groups:
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
        """
        if states is not None:
            (target_class_freq_by_image_total,
             target_class_freq_by_image_mask,
             pred_class_freq_by_image_mask,
             individual_values) = states
        else:
            target_class_freq_by_image_total = self.target_class_freq_by_image_total
            target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
            pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
            individual_values = self.individual_values

        target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
        target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
        pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
        individual_values = np.concatenate(individual_values, axis=0)

        total_results = {
            'mean': individual_values.mean(),
            'std': individual_values.std(),
            **distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name)
        }

        if groups is None:
            return total_results, None

        group_results = dict()
        grouping = get_groupings(groups)
        for label, index in grouping.items():
            group_class_freq = target_class_freq_by_image_mask[index]
            group_scores = individual_values[index]
            group_results[label] = {
                'mean': group_scores.mean(),
                'std': group_scores.std(),
                **distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name)
            }
        return total_results, group_results

    def reset(self):
        super().reset()
        self.individual_values = []


class SegmentationClassStats(SegmentationAwarePairwiseScore):
    def calc_score(self, pred_batch, target_batch, mask):
        return 0

    def get_value(self, groups=None, states=None):
        """
        :param groups:
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
        """
        if states is not None:
            (target_class_freq_by_image_total,
             target_class_freq_by_image_mask,
             pred_class_freq_by_image_mask,
             _) = states
        else:
            target_class_freq_by_image_total = self.target_class_freq_by_image_total
            target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
            pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask

        target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
        target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
        pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)

        target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32')
        target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum()

        target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32')
        target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum()

        pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3)

        total_results = dict()
        total_results.update({f'total_freq/{self.segm_idx2name[i]}': v
                              for i, v in enumerate(target_class_freq_by_image_total_marginal)
                              if v > 0})
        total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
                              for i, v in enumerate(target_class_freq_by_image_mask_marginal)
                              if v > 0})
        total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
                              for i, v in enumerate(pred_class_freq_diff)
                              if target_class_freq_by_image_total_marginal[i] > 0})

        if groups is None:
            return total_results, None

        group_results = dict()
        grouping = get_groupings(groups)
        for label, index in grouping.items():
            group_target_class_freq_by_image_total = target_class_freq_by_image_total[index]
            group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index]
            group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index]

            group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32')
            group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum()

            group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32')
            group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum()

            group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / (
                    group_target_class_freq_by_image_mask.sum(0) + 1e-3)

            cur_group_results = dict()
            cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v
                                      for i, v in enumerate(group_target_class_freq_by_image_total_marginal)
                                      if v > 0})
            cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
                                      for i, v in enumerate(group_target_class_freq_by_image_mask_marginal)
                                      if v > 0})
            cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
                                      for i, v in enumerate(group_pred_class_freq_diff)
                                      if group_target_class_freq_by_image_total_marginal[i] > 0})

            group_results[label] = cur_group_results
        return total_results, group_results


class SegmentationAwareSSIM(SegmentationAwarePairwiseScore):
    def __init__(self, *args, window_size=11, **kwargs):
        super().__init__(*args, **kwargs)
        self.score_impl = SSIM(window_size=window_size, size_average=False).eval()

    def calc_score(self, pred_batch, target_batch, mask):
        return self.score_impl(pred_batch, target_batch).detach().cpu().numpy()


class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore):
    def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path,
                                         use_gpu=use_gpu, spatial=False).eval()

    def calc_score(self, pred_batch, target_batch, mask):
        return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy()


def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6):
    activations_pred = activations_pred.copy()
    activations_pred[img_i] = activations_target[img_i]
    return calculate_frechet_distance(activations_pred, activations_target, eps=eps)


class SegmentationAwareFID(SegmentationAwarePairwiseScore):
    def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs):
        super().__init__(*args, **kwargs)
        if getattr(FIDScore, '_MODEL', None) is None:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            FIDScore._MODEL = InceptionV3([block_idx]).eval()
        self.model = FIDScore._MODEL
        self.eps = eps
        self.n_jobs = n_jobs

    def calc_score(self, pred_batch, target_batch, mask):
        activations_pred = self._get_activations(pred_batch)
        activations_target = self._get_activations(target_batch)
        return activations_pred, activations_target

    def get_value(self, groups=None, states=None):
        """
        :param groups:
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
        """
        if states is not None:
            (target_class_freq_by_image_total,
             target_class_freq_by_image_mask,
             pred_class_freq_by_image_mask,
             activation_pairs) = states
        else:
            target_class_freq_by_image_total = self.target_class_freq_by_image_total
            target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
            pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
            activation_pairs = self.individual_values

        target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
        target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
        pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
        activations_pred, activations_target = zip(*activation_pairs)
        activations_pred = np.concatenate(activations_pred, axis=0)
        activations_target = np.concatenate(activations_target, axis=0)

        total_results = {
            'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps),
            'std': 0,
            **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target)
        }

        if groups is None:
            return total_results, None

        group_results = dict()
        grouping = get_groupings(groups)
        for label, index in grouping.items():
            if len(index) > 1:
                group_activations_pred = activations_pred[index]
                group_activations_target = activations_target[index]
                group_class_freq = target_class_freq_by_image_mask[index]
                group_results[label] = {
                    'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps),
                    'std': 0,
                    **self.distribute_fid_to_classes(group_class_freq,
                                                     group_activations_pred,
                                                     group_activations_target)
                }
            else:
                group_results[label] = dict(mean=float('nan'), std=0)
        return total_results, group_results

    def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target):
        real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)

        fid_no_images = Parallel(n_jobs=self.n_jobs)(
            delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps)
            for img_i in range(activations_pred.shape[0])
        )
        errors = real_fid - fid_no_images
        return distribute_values_to_classes(class_freq, errors, self.segm_idx2name)

    def _get_activations(self, batch):
        activations = self.model(batch)[0]
        if activations.shape[2] != 1 or activations.shape[3] != 1:
            activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
        activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy()
        return activations
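A small, self-contained check of the grouping helper defined above: get_groupings buckets per-sample scores by group index so that each bucket can be averaged separately. All numbers are invented for illustration.

# Toy illustration; importing base_loss assumes the full repo (including
# models.ade20k) is on the Python path.
import numpy as np
from saicinpainting.evaluation.losses.base_loss import get_groupings

groups = np.array([2, 0, 2, 1, 0])            # e.g. mask-area bin per sample
scores = np.array([0.8, 0.6, 0.9, 0.7, 0.5])  # per-sample score values

for label, idx in get_groupings(groups).items():
    print(label, scores[idx].mean())          # 0 -> 0.55, 1 -> 0.7, 2 -> 0.85 (approx.)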
saicinpainting/evaluation/losses/fid/__init__.py
ADDED
Empty file
saicinpainting/evaluation/losses/fid/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (165 Bytes).

saicinpainting/evaluation/losses/fid/__pycache__/inception.cpython-39.pyc
ADDED
Binary file (9.02 kB).
saicinpainting/evaluation/losses/fid/fid_score.py
ADDED
@@ -0,0 +1,328 @@
#!/usr/bin/env python3
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import numpy as np
import torch
# from scipy.misc import imread
from imageio import imread
from PIL import Image, JpegImagePlugin
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x): return x

try:
    from .inception import InceptionV3
except ModuleNotFoundError:
    from inception import InceptionV3

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('path', type=str, nargs=2,
                    help=('Path to the generated images or '
                          'to .npz statistic files'))
parser.add_argument('--batch-size', type=int, default=50,
                    help='Batch size to use')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('-c', '--gpu', default='', type=str,
                    help='GPU to use (leave blank for CPU only)')
parser.add_argument('--resize', default=256)

transform = Compose([Resize(256), CenterCrop(256), ToTensor()])


def get_activations(files, model, batch_size=50, dims=2048,
                    cuda=False, verbose=False, keep_size=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image file paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True and parameter out_step is given, the number
                     of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))

    for i in tqdm(range(n_batches)):
        if verbose:
            print('\rPropagating batch %d/%d' % (i + 1, n_batches),
                  end='', flush=True)
        start = i * batch_size
        end = start + batch_size

        # # Official code goes below
        # images = np.array([imread(str(f)).astype(np.float32)
        #                    for f in files[start:end]])

        # # Reshape to (n_images, 3, height, width)
        # images = images.transpose((0, 3, 1, 2))
        # images /= 255
        # batch = torch.from_numpy(images).type(torch.FloatTensor)
        # #

        t = transform if not keep_size else ToTensor()

        if isinstance(files[0], pathlib.PosixPath):
            images = [t(Image.open(str(f))) for f in files[start:end]]

        elif isinstance(files[0], Image.Image):
            images = [t(f) for f in files[start:end]]

        else:
            raise ValueError(f"Unknown data type for image: {type(files[0])}")

        batch = torch.stack(images)

        if cuda:
            batch = batch.cuda()

        pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)

        if verbose:
            print(' done')

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
|
164 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
165 |
+
|
166 |
+
Stable version by Dougal J. Sutherland.
|
167 |
+
|
168 |
+
Params:
|
169 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
170 |
+
inception net (like returned by the function 'get_predictions')
|
171 |
+
for generated samples.
|
172 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
173 |
+
representative data set.
|
174 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
175 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
176 |
+
representative data set.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
-- : The Frechet Distance.
|
180 |
+
"""
|
181 |
+
|
182 |
+
mu1 = np.atleast_1d(mu1)
|
183 |
+
mu2 = np.atleast_1d(mu2)
|
184 |
+
|
185 |
+
sigma1 = np.atleast_2d(sigma1)
|
186 |
+
sigma2 = np.atleast_2d(sigma2)
|
187 |
+
|
188 |
+
assert mu1.shape == mu2.shape, \
|
189 |
+
'Training and test mean vectors have different lengths'
|
190 |
+
assert sigma1.shape == sigma2.shape, \
|
191 |
+
'Training and test covariances have different dimensions'
|
192 |
+
|
193 |
+
diff = mu1 - mu2
|
194 |
+
|
195 |
+
# Product might be almost singular
|
196 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
197 |
+
if not np.isfinite(covmean).all():
|
198 |
+
msg = ('fid calculation produces singular product; '
|
199 |
+
'adding %s to diagonal of cov estimates') % eps
|
200 |
+
print(msg)
|
201 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
202 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
203 |
+
|
204 |
+
# Numerical error might give slight imaginary component
|
205 |
+
if np.iscomplexobj(covmean):
|
206 |
+
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
207 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
|
208 |
+
m = np.max(np.abs(covmean.imag))
|
209 |
+
raise ValueError('Imaginary component {}'.format(m))
|
210 |
+
covmean = covmean.real
|
211 |
+
|
212 |
+
tr_covmean = np.trace(covmean)
|
213 |
+
|
214 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
215 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
216 |
+
|
217 |
+
|
218 |
+
def calculate_activation_statistics(files, model, batch_size=50,
|
219 |
+
dims=2048, cuda=False, verbose=False, keep_size=False):
|
220 |
+
"""Calculation of the statistics used by the FID.
|
221 |
+
Params:
|
222 |
+
-- files : List of image files paths
|
223 |
+
-- model : Instance of inception model
|
224 |
+
-- batch_size : The images numpy array is split into batches with
|
225 |
+
batch size batch_size. A reasonable batch size
|
226 |
+
depends on the hardware.
|
227 |
+
-- dims : Dimensionality of features returned by Inception
|
228 |
+
-- cuda : If set to True, use GPU
|
229 |
+
-- verbose : If set to True and parameter out_step is given, the
|
230 |
+
number of calculated batches is reported.
|
231 |
+
Returns:
|
232 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
233 |
+
the inception model.
|
234 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
235 |
+
the inception model.
|
236 |
+
"""
|
237 |
+
act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size)
|
238 |
+
mu = np.mean(act, axis=0)
|
239 |
+
sigma = np.cov(act, rowvar=False)
|
240 |
+
return mu, sigma
|
241 |
+
|
242 |
+
|
243 |
+
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
|
244 |
+
if path.endswith('.npz'):
|
245 |
+
f = np.load(path)
|
246 |
+
m, s = f['mu'][:], f['sigma'][:]
|
247 |
+
f.close()
|
248 |
+
else:
|
249 |
+
path = pathlib.Path(path)
|
250 |
+
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
251 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
252 |
+
dims, cuda)
|
253 |
+
|
254 |
+
return m, s
|
255 |
+
|
256 |
+
|
257 |
+
def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False):
|
258 |
+
if isinstance(images, list): # exact paths to files are provided
|
259 |
+
m, s = calculate_activation_statistics(images, model, batch_size,
|
260 |
+
dims, cuda, keep_size=keep_size)
|
261 |
+
|
262 |
+
return m, s
|
263 |
+
|
264 |
+
else:
|
265 |
+
raise ValueError
|
266 |
+
|
267 |
+
|
268 |
+
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
|
269 |
+
"""Calculates the FID of two paths"""
|
270 |
+
for p in paths:
|
271 |
+
if not os.path.exists(p):
|
272 |
+
raise RuntimeError('Invalid path: %s' % p)
|
273 |
+
|
274 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
275 |
+
|
276 |
+
model = InceptionV3([block_idx])
|
277 |
+
if cuda:
|
278 |
+
model.cuda()
|
279 |
+
|
280 |
+
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
|
281 |
+
dims, cuda)
|
282 |
+
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
|
283 |
+
dims, cuda)
|
284 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
285 |
+
|
286 |
+
return fid_value
|
287 |
+
|
288 |
+
|
289 |
+
def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False):
|
290 |
+
if use_globals:
|
291 |
+
global FID_MODEL # for multiprocessing
|
292 |
+
|
293 |
+
for imgs in images:
|
294 |
+
if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)):
|
295 |
+
pass
|
296 |
+
else:
|
297 |
+
raise RuntimeError('Invalid images')
|
298 |
+
|
299 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
300 |
+
|
301 |
+
if 'FID_MODEL' not in globals() or not use_globals:
|
302 |
+
model = InceptionV3([block_idx])
|
303 |
+
if cuda:
|
304 |
+
model.cuda()
|
305 |
+
|
306 |
+
if use_globals:
|
307 |
+
FID_MODEL = model
|
308 |
+
|
309 |
+
else:
|
310 |
+
model = FID_MODEL
|
311 |
+
|
312 |
+
m1, s1 = _compute_statistics_of_images(images[0], model, batch_size,
|
313 |
+
dims, cuda, keep_size=False)
|
314 |
+
m2, s2 = _compute_statistics_of_images(images[1], model, batch_size,
|
315 |
+
dims, cuda, keep_size=False)
|
316 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
317 |
+
return fid_value
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == '__main__':
|
321 |
+
args = parser.parse_args()
|
322 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
|
323 |
+
|
324 |
+
fid_value = calculate_fid_given_paths(args.path,
|
325 |
+
args.batch_size,
|
326 |
+
args.gpu != '',
|
327 |
+
args.dims)
|
328 |
+
print('FID: ', fid_value)
|
saicinpainting/evaluation/losses/fid/inception.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision import models
|
7 |
+
|
8 |
+
try:
|
9 |
+
from torchvision.models.utils import load_state_dict_from_url
|
10 |
+
except ImportError:
|
11 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
12 |
+
|
13 |
+
# Inception weights ported to Pytorch from
|
14 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
15 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
16 |
+
|
17 |
+
|
18 |
+
LOGGER = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class InceptionV3(nn.Module):
|
22 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
23 |
+
|
24 |
+
# Index of default block of inception to return,
|
25 |
+
# corresponds to output of final average pooling
|
26 |
+
DEFAULT_BLOCK_INDEX = 3
|
27 |
+
|
28 |
+
# Maps feature dimensionality to their output blocks indices
|
29 |
+
BLOCK_INDEX_BY_DIM = {
|
30 |
+
64: 0, # First max pooling features
|
31 |
+
192: 1, # Second max pooling featurs
|
32 |
+
768: 2, # Pre-aux classifier features
|
33 |
+
2048: 3 # Final average pooling features
|
34 |
+
}
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
38 |
+
resize_input=True,
|
39 |
+
normalize_input=True,
|
40 |
+
requires_grad=False,
|
41 |
+
use_fid_inception=True):
|
42 |
+
"""Build pretrained InceptionV3
|
43 |
+
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
output_blocks : list of int
|
47 |
+
Indices of blocks to return features of. Possible values are:
|
48 |
+
- 0: corresponds to output of first max pooling
|
49 |
+
- 1: corresponds to output of second max pooling
|
50 |
+
- 2: corresponds to output which is fed to aux classifier
|
51 |
+
- 3: corresponds to output of final average pooling
|
52 |
+
resize_input : bool
|
53 |
+
If true, bilinearly resizes input to width and height 299 before
|
54 |
+
feeding input to model. As the network without fully connected
|
55 |
+
layers is fully convolutional, it should be able to handle inputs
|
56 |
+
of arbitrary size, so resizing might not be strictly needed
|
57 |
+
normalize_input : bool
|
58 |
+
If true, scales the input from range (0, 1) to the range the
|
59 |
+
pretrained Inception network expects, namely (-1, 1)
|
60 |
+
requires_grad : bool
|
61 |
+
If true, parameters of the model require gradients. Possibly useful
|
62 |
+
for finetuning the network
|
63 |
+
use_fid_inception : bool
|
64 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
65 |
+
FID implementation. If false, uses the pretrained Inception model
|
66 |
+
available in torchvision. The FID Inception model has different
|
67 |
+
weights and a slightly different structure from torchvision's
|
68 |
+
Inception model. If you want to compute FID scores, you are
|
69 |
+
strongly advised to set this parameter to true to get comparable
|
70 |
+
results.
|
71 |
+
"""
|
72 |
+
super(InceptionV3, self).__init__()
|
73 |
+
|
74 |
+
self.resize_input = resize_input
|
75 |
+
self.normalize_input = normalize_input
|
76 |
+
self.output_blocks = sorted(output_blocks)
|
77 |
+
self.last_needed_block = max(output_blocks)
|
78 |
+
|
79 |
+
assert self.last_needed_block <= 3, \
|
80 |
+
'Last possible output block index is 3'
|
81 |
+
|
82 |
+
self.blocks = nn.ModuleList()
|
83 |
+
|
84 |
+
if use_fid_inception:
|
85 |
+
inception = fid_inception_v3()
|
86 |
+
else:
|
87 |
+
inception = models.inception_v3(pretrained=True)
|
88 |
+
|
89 |
+
# Block 0: input to maxpool1
|
90 |
+
block0 = [
|
91 |
+
inception.Conv2d_1a_3x3,
|
92 |
+
inception.Conv2d_2a_3x3,
|
93 |
+
inception.Conv2d_2b_3x3,
|
94 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
95 |
+
]
|
96 |
+
self.blocks.append(nn.Sequential(*block0))
|
97 |
+
|
98 |
+
# Block 1: maxpool1 to maxpool2
|
99 |
+
if self.last_needed_block >= 1:
|
100 |
+
block1 = [
|
101 |
+
inception.Conv2d_3b_1x1,
|
102 |
+
inception.Conv2d_4a_3x3,
|
103 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
104 |
+
]
|
105 |
+
self.blocks.append(nn.Sequential(*block1))
|
106 |
+
|
107 |
+
# Block 2: maxpool2 to aux classifier
|
108 |
+
if self.last_needed_block >= 2:
|
109 |
+
block2 = [
|
110 |
+
inception.Mixed_5b,
|
111 |
+
inception.Mixed_5c,
|
112 |
+
inception.Mixed_5d,
|
113 |
+
inception.Mixed_6a,
|
114 |
+
inception.Mixed_6b,
|
115 |
+
inception.Mixed_6c,
|
116 |
+
inception.Mixed_6d,
|
117 |
+
inception.Mixed_6e,
|
118 |
+
]
|
119 |
+
self.blocks.append(nn.Sequential(*block2))
|
120 |
+
|
121 |
+
# Block 3: aux classifier to final avgpool
|
122 |
+
if self.last_needed_block >= 3:
|
123 |
+
block3 = [
|
124 |
+
inception.Mixed_7a,
|
125 |
+
inception.Mixed_7b,
|
126 |
+
inception.Mixed_7c,
|
127 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
128 |
+
]
|
129 |
+
self.blocks.append(nn.Sequential(*block3))
|
130 |
+
|
131 |
+
for param in self.parameters():
|
132 |
+
param.requires_grad = requires_grad
|
133 |
+
|
134 |
+
def forward(self, inp):
|
135 |
+
"""Get Inception feature maps
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
inp : torch.autograd.Variable
|
140 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
141 |
+
range (0, 1)
|
142 |
+
|
143 |
+
Returns
|
144 |
+
-------
|
145 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
146 |
+
block, sorted ascending by index
|
147 |
+
"""
|
148 |
+
outp = []
|
149 |
+
x = inp
|
150 |
+
|
151 |
+
if self.resize_input:
|
152 |
+
x = F.interpolate(x,
|
153 |
+
size=(299, 299),
|
154 |
+
mode='bilinear',
|
155 |
+
align_corners=False)
|
156 |
+
|
157 |
+
if self.normalize_input:
|
158 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
159 |
+
|
160 |
+
for idx, block in enumerate(self.blocks):
|
161 |
+
x = block(x)
|
162 |
+
if idx in self.output_blocks:
|
163 |
+
outp.append(x)
|
164 |
+
|
165 |
+
if idx == self.last_needed_block:
|
166 |
+
break
|
167 |
+
|
168 |
+
return outp
|
169 |
+
|
170 |
+
|
171 |
+
def fid_inception_v3():
|
172 |
+
"""Build pretrained Inception model for FID computation
|
173 |
+
|
174 |
+
The Inception model for FID computation uses a different set of weights
|
175 |
+
and has a slightly different structure than torchvision's Inception.
|
176 |
+
|
177 |
+
This method first constructs torchvision's Inception and then patches the
|
178 |
+
necessary parts that are different in the FID Inception model.
|
179 |
+
"""
|
180 |
+
LOGGER.info('fid_inception_v3 called')
|
181 |
+
inception = models.inception_v3(num_classes=1008,
|
182 |
+
aux_logits=False,
|
183 |
+
pretrained=False)
|
184 |
+
LOGGER.info('models.inception_v3 done')
|
185 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
186 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
187 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
188 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
189 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
190 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
191 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
192 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
193 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
194 |
+
|
195 |
+
LOGGER.info('fid_inception_v3 patching done')
|
196 |
+
|
197 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
198 |
+
LOGGER.info('fid_inception_v3 weights downloaded')
|
199 |
+
|
200 |
+
inception.load_state_dict(state_dict)
|
201 |
+
LOGGER.info('fid_inception_v3 weights loaded into model')
|
202 |
+
|
203 |
+
return inception
|
204 |
+
|
205 |
+
|
206 |
+
class FIDInceptionA(models.inception.InceptionA):
|
207 |
+
"""InceptionA block patched for FID computation"""
|
208 |
+
def __init__(self, in_channels, pool_features):
|
209 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
branch1x1 = self.branch1x1(x)
|
213 |
+
|
214 |
+
branch5x5 = self.branch5x5_1(x)
|
215 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
216 |
+
|
217 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
218 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
219 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
220 |
+
|
221 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
222 |
+
# its average calculation
|
223 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
224 |
+
count_include_pad=False)
|
225 |
+
branch_pool = self.branch_pool(branch_pool)
|
226 |
+
|
227 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
228 |
+
return torch.cat(outputs, 1)
|
229 |
+
|
230 |
+
|
231 |
+
class FIDInceptionC(models.inception.InceptionC):
|
232 |
+
"""InceptionC block patched for FID computation"""
|
233 |
+
def __init__(self, in_channels, channels_7x7):
|
234 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
branch1x1 = self.branch1x1(x)
|
238 |
+
|
239 |
+
branch7x7 = self.branch7x7_1(x)
|
240 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
241 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
242 |
+
|
243 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
244 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
245 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
246 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
247 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
248 |
+
|
249 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
250 |
+
# its average calculation
|
251 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
252 |
+
count_include_pad=False)
|
253 |
+
branch_pool = self.branch_pool(branch_pool)
|
254 |
+
|
255 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
256 |
+
return torch.cat(outputs, 1)
|
257 |
+
|
258 |
+
|
259 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
260 |
+
"""First InceptionE block patched for FID computation"""
|
261 |
+
def __init__(self, in_channels):
|
262 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
263 |
+
|
264 |
+
def forward(self, x):
|
265 |
+
branch1x1 = self.branch1x1(x)
|
266 |
+
|
267 |
+
branch3x3 = self.branch3x3_1(x)
|
268 |
+
branch3x3 = [
|
269 |
+
self.branch3x3_2a(branch3x3),
|
270 |
+
self.branch3x3_2b(branch3x3),
|
271 |
+
]
|
272 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
273 |
+
|
274 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
275 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
276 |
+
branch3x3dbl = [
|
277 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
278 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
279 |
+
]
|
280 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
281 |
+
|
282 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
283 |
+
# its average calculation
|
284 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
285 |
+
count_include_pad=False)
|
286 |
+
branch_pool = self.branch_pool(branch_pool)
|
287 |
+
|
288 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
289 |
+
return torch.cat(outputs, 1)
|
290 |
+
|
291 |
+
|
292 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
293 |
+
"""Second InceptionE block patched for FID computation"""
|
294 |
+
def __init__(self, in_channels):
|
295 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
296 |
+
|
297 |
+
def forward(self, x):
|
298 |
+
branch1x1 = self.branch1x1(x)
|
299 |
+
|
300 |
+
branch3x3 = self.branch3x3_1(x)
|
301 |
+
branch3x3 = [
|
302 |
+
self.branch3x3_2a(branch3x3),
|
303 |
+
self.branch3x3_2b(branch3x3),
|
304 |
+
]
|
305 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
306 |
+
|
307 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
308 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
309 |
+
branch3x3dbl = [
|
310 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
311 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
312 |
+
]
|
313 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
314 |
+
|
315 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
316 |
+
# pooling. This is likely an error in this specific Inception
|
317 |
+
# implementation, as other Inception models use average pooling here
|
318 |
+
# (which matches the description in the paper).
|
319 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
320 |
+
branch_pool = self.branch_pool(branch_pool)
|
321 |
+
|
322 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
323 |
+
return torch.cat(outputs, 1)
|
saicinpainting/evaluation/losses/lpips.py
ADDED
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
############################################################
|
2 |
+
# The contents below have been combined using files in the #
|
3 |
+
# following repository: #
|
4 |
+
# https://github.com/richzhang/PerceptualSimilarity #
|
5 |
+
############################################################
|
6 |
+
|
7 |
+
############################################################
|
8 |
+
# __init__.py #
|
9 |
+
############################################################
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from skimage.metrics import structural_similarity
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from saicinpainting.utils import get_shape
|
16 |
+
|
17 |
+
|
18 |
+
class PerceptualLoss(torch.nn.Module):
|
19 |
+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True):
|
20 |
+
# VGG using our perceptually-learned weights (LPIPS metric)
|
21 |
+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
|
22 |
+
super(PerceptualLoss, self).__init__()
|
23 |
+
self.use_gpu = use_gpu
|
24 |
+
self.spatial = spatial
|
25 |
+
self.model = DistModel()
|
26 |
+
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace,
|
27 |
+
model_path=model_path, spatial=self.spatial)
|
28 |
+
|
29 |
+
def forward(self, pred, target, normalize=True):
|
30 |
+
"""
|
31 |
+
Pred and target are Variables.
|
32 |
+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
|
33 |
+
If normalize is False, assumes the images are already between [-1,+1]
|
34 |
+
Inputs pred and target are Nx3xHxW
|
35 |
+
Output pytorch Variable N long
|
36 |
+
"""
|
37 |
+
|
38 |
+
if normalize:
|
39 |
+
target = 2 * target - 1
|
40 |
+
pred = 2 * pred - 1
|
41 |
+
|
42 |
+
return self.model(target, pred)
|
43 |
+
|
44 |
+
|
45 |
+
def normalize_tensor(in_feat, eps=1e-10):
|
46 |
+
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
|
47 |
+
return in_feat / (norm_factor + eps)
|
48 |
+
|
49 |
+
|
50 |
+
def l2(p0, p1, range=255.):
|
51 |
+
return .5 * np.mean((p0 / range - p1 / range) ** 2)
|
52 |
+
|
53 |
+
|
54 |
+
def psnr(p0, p1, peak=255.):
|
55 |
+
return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
|
56 |
+
|
57 |
+
|
58 |
+
def dssim(p0, p1, range=255.):
|
59 |
+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
|
60 |
+
|
61 |
+
|
62 |
+
def rgb2lab(in_img, mean_cent=False):
|
63 |
+
from skimage import color
|
64 |
+
img_lab = color.rgb2lab(in_img)
|
65 |
+
if (mean_cent):
|
66 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
67 |
+
return img_lab
|
68 |
+
|
69 |
+
|
70 |
+
def tensor2np(tensor_obj):
|
71 |
+
# change dimension of a tensor object into a numpy array
|
72 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
|
73 |
+
|
74 |
+
|
75 |
+
def np2tensor(np_obj):
|
76 |
+
# change dimenion of np array into tensor array
|
77 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
78 |
+
|
79 |
+
|
80 |
+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
|
81 |
+
# image tensor to lab tensor
|
82 |
+
from skimage import color
|
83 |
+
|
84 |
+
img = tensor2im(image_tensor)
|
85 |
+
img_lab = color.rgb2lab(img)
|
86 |
+
if (mc_only):
|
87 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
88 |
+
if (to_norm and not mc_only):
|
89 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
90 |
+
img_lab = img_lab / 100.
|
91 |
+
|
92 |
+
return np2tensor(img_lab)
|
93 |
+
|
94 |
+
|
95 |
+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
|
96 |
+
from skimage import color
|
97 |
+
import warnings
|
98 |
+
warnings.filterwarnings("ignore")
|
99 |
+
|
100 |
+
lab = tensor2np(lab_tensor) * 100.
|
101 |
+
lab[:, :, 0] = lab[:, :, 0] + 50
|
102 |
+
|
103 |
+
rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
|
104 |
+
if (return_inbnd):
|
105 |
+
# convert back to lab, see if we match
|
106 |
+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
|
107 |
+
mask = 1. * np.isclose(lab_back, lab, atol=2.)
|
108 |
+
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
|
109 |
+
return (im2tensor(rgb_back), mask)
|
110 |
+
else:
|
111 |
+
return im2tensor(rgb_back)
|
112 |
+
|
113 |
+
|
114 |
+
def rgb2lab(input):
|
115 |
+
from skimage import color
|
116 |
+
return color.rgb2lab(input / 255.)
|
117 |
+
|
118 |
+
|
119 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
|
120 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
121 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
122 |
+
return image_numpy.astype(imtype)
|
123 |
+
|
124 |
+
|
125 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
|
126 |
+
return torch.Tensor((image / factor - cent)
|
127 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
128 |
+
|
129 |
+
|
130 |
+
def tensor2vec(vector_tensor):
|
131 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
132 |
+
|
133 |
+
|
134 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
135 |
+
""" ap = voc_ap(rec, prec, [use_07_metric])
|
136 |
+
Compute VOC AP given precision and recall.
|
137 |
+
If use_07_metric is true, uses the
|
138 |
+
VOC 07 11 point method (default:False).
|
139 |
+
"""
|
140 |
+
if use_07_metric:
|
141 |
+
# 11 point metric
|
142 |
+
ap = 0.
|
143 |
+
for t in np.arange(0., 1.1, 0.1):
|
144 |
+
if np.sum(rec >= t) == 0:
|
145 |
+
p = 0
|
146 |
+
else:
|
147 |
+
p = np.max(prec[rec >= t])
|
148 |
+
ap = ap + p / 11.
|
149 |
+
else:
|
150 |
+
# correct AP calculation
|
151 |
+
# first append sentinel values at the end
|
152 |
+
mrec = np.concatenate(([0.], rec, [1.]))
|
153 |
+
mpre = np.concatenate(([0.], prec, [0.]))
|
154 |
+
|
155 |
+
# compute the precision envelope
|
156 |
+
for i in range(mpre.size - 1, 0, -1):
|
157 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
158 |
+
|
159 |
+
# to calculate area under PR curve, look for points
|
160 |
+
# where X axis (recall) changes value
|
161 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
162 |
+
|
163 |
+
# and sum (\Delta recall) * prec
|
164 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
165 |
+
return ap
|
166 |
+
|
167 |
+
|
168 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
|
169 |
+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
170 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
171 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
172 |
+
return image_numpy.astype(imtype)
|
173 |
+
|
174 |
+
|
175 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
|
176 |
+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
177 |
+
return torch.Tensor((image / factor - cent)
|
178 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
179 |
+
|
180 |
+
|
181 |
+
############################################################
|
182 |
+
# base_model.py #
|
183 |
+
############################################################
|
184 |
+
|
185 |
+
|
186 |
+
class BaseModel(torch.nn.Module):
|
187 |
+
def __init__(self):
|
188 |
+
super().__init__()
|
189 |
+
|
190 |
+
def name(self):
|
191 |
+
return 'BaseModel'
|
192 |
+
|
193 |
+
def initialize(self, use_gpu=True):
|
194 |
+
self.use_gpu = use_gpu
|
195 |
+
|
196 |
+
def forward(self):
|
197 |
+
pass
|
198 |
+
|
199 |
+
def get_image_paths(self):
|
200 |
+
pass
|
201 |
+
|
202 |
+
def optimize_parameters(self):
|
203 |
+
pass
|
204 |
+
|
205 |
+
def get_current_visuals(self):
|
206 |
+
return self.input
|
207 |
+
|
208 |
+
def get_current_errors(self):
|
209 |
+
return {}
|
210 |
+
|
211 |
+
def save(self, label):
|
212 |
+
pass
|
213 |
+
|
214 |
+
# helper saving function that can be used by subclasses
|
215 |
+
def save_network(self, network, path, network_label, epoch_label):
|
216 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
217 |
+
save_path = os.path.join(path, save_filename)
|
218 |
+
torch.save(network.state_dict(), save_path)
|
219 |
+
|
220 |
+
# helper loading function that can be used by subclasses
|
221 |
+
def load_network(self, network, network_label, epoch_label):
|
222 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
223 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
224 |
+
print('Loading network from %s' % save_path)
|
225 |
+
network.load_state_dict(torch.load(save_path, map_location='cpu'))
|
226 |
+
|
227 |
+
def update_learning_rate():
|
228 |
+
pass
|
229 |
+
|
230 |
+
def get_image_paths(self):
|
231 |
+
return self.image_paths
|
232 |
+
|
233 |
+
def save_done(self, flag=False):
|
234 |
+
np.save(os.path.join(self.save_dir, 'done_flag'), flag)
|
235 |
+
np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
|
236 |
+
|
237 |
+
|
238 |
+
############################################################
|
239 |
+
# dist_model.py #
|
240 |
+
############################################################
|
241 |
+
|
242 |
+
import os
|
243 |
+
from collections import OrderedDict
|
244 |
+
from scipy.ndimage import zoom
|
245 |
+
from tqdm import tqdm
|
246 |
+
|
247 |
+
|
248 |
+
class DistModel(BaseModel):
|
249 |
+
def name(self):
|
250 |
+
return self.model_name
|
251 |
+
|
252 |
+
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False,
|
253 |
+
model_path=None,
|
254 |
+
use_gpu=True, printNet=False, spatial=False,
|
255 |
+
is_train=False, lr=.0001, beta1=0.5, version='0.1'):
|
256 |
+
'''
|
257 |
+
INPUTS
|
258 |
+
model - ['net-lin'] for linearly calibrated network
|
259 |
+
['net'] for off-the-shelf network
|
260 |
+
['L2'] for L2 distance in Lab colorspace
|
261 |
+
['SSIM'] for ssim in RGB colorspace
|
262 |
+
net - ['squeeze','alex','vgg']
|
263 |
+
model_path - if None, will look in weights/[NET_NAME].pth
|
264 |
+
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
|
265 |
+
use_gpu - bool - whether or not to use a GPU
|
266 |
+
printNet - bool - whether or not to print network architecture out
|
267 |
+
spatial - bool - whether to output an array containing varying distances across spatial dimensions
|
268 |
+
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
|
269 |
+
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
|
270 |
+
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
|
271 |
+
is_train - bool - [True] for training mode
|
272 |
+
lr - float - initial learning rate
|
273 |
+
beta1 - float - initial momentum term for adam
|
274 |
+
version - 0.1 for latest, 0.0 was original (with a bug)
|
275 |
+
'''
|
276 |
+
BaseModel.initialize(self, use_gpu=use_gpu)
|
277 |
+
|
278 |
+
self.model = model
|
279 |
+
self.net = net
|
280 |
+
self.is_train = is_train
|
281 |
+
self.spatial = spatial
|
282 |
+
self.model_name = '%s [%s]' % (model, net)
|
283 |
+
|
284 |
+
if (self.model == 'net-lin'): # pretrained net + linear layer
|
285 |
+
self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
|
286 |
+
use_dropout=True, spatial=spatial, version=version, lpips=True)
|
287 |
+
kw = dict(map_location='cpu')
|
288 |
+
if (model_path is None):
|
289 |
+
import inspect
|
290 |
+
model_path = os.path.abspath(
|
291 |
+
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth'))
|
292 |
+
|
293 |
+
if (not is_train):
|
294 |
+
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
|
295 |
+
|
296 |
+
elif (self.model == 'net'): # pretrained network
|
297 |
+
self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
|
298 |
+
elif (self.model in ['L2', 'l2']):
|
299 |
+
self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
|
300 |
+
self.model_name = 'L2'
|
301 |
+
elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
|
302 |
+
self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
|
303 |
+
self.model_name = 'SSIM'
|
304 |
+
else:
|
305 |
+
raise ValueError("Model [%s] not recognized." % self.model)
|
306 |
+
|
307 |
+
self.trainable_parameters = list(self.net.parameters())
|
308 |
+
|
309 |
+
if self.is_train: # training mode
|
310 |
+
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
|
311 |
+
self.rankLoss = BCERankingLoss()
|
312 |
+
self.trainable_parameters += list(self.rankLoss.net.parameters())
|
313 |
+
self.lr = lr
|
314 |
+
self.old_lr = lr
|
315 |
+
self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999))
|
316 |
+
else: # test mode
|
317 |
+
self.net.eval()
|
318 |
+
|
319 |
+
# if (use_gpu):
|
320 |
+
# self.net.to(gpu_ids[0])
|
321 |
+
# self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
|
322 |
+
# if (self.is_train):
|
323 |
+
# self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
|
324 |
+
|
325 |
+
if (printNet):
|
326 |
+
print('---------- Networks initialized -------------')
|
327 |
+
print_network(self.net)
|
328 |
+
print('-----------------------------------------------')
|
329 |
+
|
330 |
+
def forward(self, in0, in1, retPerLayer=False):
|
331 |
+
''' Function computes the distance between image patches in0 and in1
|
332 |
+
INPUTS
|
333 |
+
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
|
334 |
+
OUTPUT
|
335 |
+
computed distances between in0 and in1
|
336 |
+
'''
|
337 |
+
|
338 |
+
return self.net(in0, in1, retPerLayer=retPerLayer)
|
339 |
+
|
340 |
+
# ***** TRAINING FUNCTIONS *****
|
341 |
+
def optimize_parameters(self):
|
342 |
+
self.forward_train()
|
343 |
+
self.optimizer_net.zero_grad()
|
344 |
+
self.backward_train()
|
345 |
+
self.optimizer_net.step()
|
346 |
+
self.clamp_weights()
|
347 |
+
|
348 |
+
def clamp_weights(self):
|
349 |
+
for module in self.net.modules():
|
350 |
+
if (hasattr(module, 'weight') and module.kernel_size == (1, 1)):
|
351 |
+
module.weight.data = torch.clamp(module.weight.data, min=0)
|
352 |
+
|
353 |
+
def set_input(self, data):
|
354 |
+
self.input_ref = data['ref']
|
355 |
+
self.input_p0 = data['p0']
|
356 |
+
self.input_p1 = data['p1']
|
357 |
+
self.input_judge = data['judge']
|
358 |
+
|
359 |
+
# if (self.use_gpu):
|
360 |
+
# self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
|
361 |
+
# self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
|
362 |
+
# self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
|
363 |
+
# self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
|
364 |
+
|
365 |
+
# self.var_ref = Variable(self.input_ref, requires_grad=True)
|
366 |
+
# self.var_p0 = Variable(self.input_p0, requires_grad=True)
|
367 |
+
# self.var_p1 = Variable(self.input_p1, requires_grad=True)
|
368 |
+
|
369 |
+
def forward_train(self): # run forward pass
|
370 |
+
# print(self.net.module.scaling_layer.shift)
|
371 |
+
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
|
372 |
+
|
373 |
+
assert False, "We shoud've not get here when using LPIPS as a metric"
|
374 |
+
|
375 |
+
self.d0 = self(self.var_ref, self.var_p0)
|
376 |
+
self.d1 = self(self.var_ref, self.var_p1)
|
377 |
+
self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
|
378 |
+
|
379 |
+
self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
|
380 |
+
|
381 |
+
self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.)
|
382 |
+
|
383 |
+
return self.loss_total
|
384 |
+
|
385 |
+
def backward_train(self):
|
386 |
+
torch.mean(self.loss_total).backward()
|
387 |
+
|
388 |
+
def compute_accuracy(self, d0, d1, judge):
|
389 |
+
''' d0, d1 are Variables, judge is a Tensor '''
|
390 |
+
d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
|
391 |
+
judge_per = judge.cpu().numpy().flatten()
|
392 |
+
return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
|
393 |
+
|
394 |
+
def get_current_errors(self):
|
395 |
+
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
|
396 |
+
('acc_r', self.acc_r)])
|
397 |
+
|
398 |
+
for key in retDict.keys():
|
399 |
+
retDict[key] = np.mean(retDict[key])
|
400 |
+
|
401 |
+
return retDict
|
402 |
+
|
403 |
+
def get_current_visuals(self):
|
404 |
+
zoom_factor = 256 / self.var_ref.data.size()[2]
|
405 |
+
|
406 |
+
ref_img = tensor2im(self.var_ref.data)
|
407 |
+
p0_img = tensor2im(self.var_p0.data)
|
408 |
+
p1_img = tensor2im(self.var_p1.data)
|
409 |
+
|
410 |
+
ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
|
411 |
+
p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
|
412 |
+
p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
|
413 |
+
|
414 |
+
return OrderedDict([('ref', ref_img_vis),
|
415 |
+
('p0', p0_img_vis),
|
416 |
+
('p1', p1_img_vis)])
|
417 |
+
|
418 |
+
def save(self, path, label):
|
419 |
+
if (self.use_gpu):
|
420 |
+
self.save_network(self.net.module, path, '', label)
|
421 |
+
else:
|
422 |
+
self.save_network(self.net, path, '', label)
|
423 |
+
self.save_network(self.rankLoss.net, path, 'rank', label)
|
424 |
+
|
425 |
+
def update_learning_rate(self, nepoch_decay):
|
426 |
+
lrd = self.lr / nepoch_decay
|
427 |
+
lr = self.old_lr - lrd
|
428 |
+
|
429 |
+
for param_group in self.optimizer_net.param_groups:
|
430 |
+
param_group['lr'] = lr
|
431 |
+
|
432 |
+
print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
|
433 |
+
self.old_lr = lr
|
434 |
+
|
435 |
+
|
436 |
+
def score_2afc_dataset(data_loader, func, name=''):
|
437 |
+
''' Function computes Two Alternative Forced Choice (2AFC) score using
|
438 |
+
distance function 'func' in dataset 'data_loader'
|
439 |
+
INPUTS
|
440 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
441 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
442 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
443 |
+
OUTPUTS
|
444 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
445 |
+
[1] - dictionary with following elements
|
446 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
447 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
448 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
449 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
450 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
451 |
+
CONSTS
|
452 |
+
N - number of test triplets in data_loader
|
453 |
+
'''
|
454 |
+
|
455 |
+
d0s = []
|
456 |
+
d1s = []
|
457 |
+
gts = []
|
458 |
+
|
459 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
460 |
+
d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
|
461 |
+
d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
|
462 |
+
gts += data['judge'].cpu().numpy().flatten().tolist()
|
463 |
+
|
464 |
+
d0s = np.array(d0s)
|
465 |
+
d1s = np.array(d1s)
|
466 |
+
gts = np.array(gts)
|
467 |
+
scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
|
468 |
+
|
469 |
+
return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
|
470 |
+
|
471 |
+
|
472 |
+
def score_jnd_dataset(data_loader, func, name=''):
|
473 |
+
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
|
474 |
+
INPUTS
|
475 |
+
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
|
476 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
477 |
+
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
|
478 |
+
OUTPUTS
|
479 |
+
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
|
480 |
+
[1] - dictionary with following elements
|
481 |
+
ds - N array containing distances between two patches shown to human evaluator
|
482 |
+
sames - N array containing fraction of people who thought the two patches were identical
|
483 |
+
CONSTS
|
484 |
+
N - number of test triplets in data_loader
|
485 |
+
'''
|
486 |
+
|
487 |
+
ds = []
|
488 |
+
gts = []
|
489 |
+
|
490 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
491 |
+
ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
|
492 |
+
gts += data['same'].cpu().numpy().flatten().tolist()
|
493 |
+
|
494 |
+
sames = np.array(gts)
|
495 |
+
ds = np.array(ds)
|
496 |
+
|
497 |
+
sorted_inds = np.argsort(ds)
|
498 |
+
ds_sorted = ds[sorted_inds]
|
499 |
+
sames_sorted = sames[sorted_inds]
|
500 |
+
|
501 |
+
TPs = np.cumsum(sames_sorted)
|
502 |
+
FPs = np.cumsum(1 - sames_sorted)
|
503 |
+
FNs = np.sum(sames_sorted) - TPs
|
504 |
+
|
505 |
+
precs = TPs / (TPs + FPs)
|
506 |
+
recs = TPs / (TPs + FNs)
|
507 |
+
score = voc_ap(recs, precs)
|
508 |
+
|
509 |
+
return (score, dict(ds=ds, sames=sames))
|
510 |
+
|
511 |
+
|
512 |
+
############################################################
|
513 |
+
# networks_basic.py #
|
514 |
+
############################################################
|
515 |
+
|
516 |
+
import torch.nn as nn
|
517 |
+
from torch.autograd import Variable
|
518 |
+
import numpy as np
|
519 |
+
|
520 |
+
|
521 |
+
def spatial_average(in_tens, keepdim=True):
|
522 |
+
return in_tens.mean([2, 3], keepdim=keepdim)
|
523 |
+
|
524 |
+
|
525 |
+
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
|
526 |
+
in_H = in_tens.shape[2]
|
527 |
+
scale_factor = 1. * out_H / in_H
|
528 |
+
|
529 |
+
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
|
530 |
+
|
531 |
+
|
532 |
+
# Learned perceptual metric
|
533 |
+
class PNetLin(nn.Module):
|
534 |
+
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False,
|
535 |
+
version='0.1', lpips=True):
|
536 |
+
super(PNetLin, self).__init__()
|
537 |
+
|
538 |
+
self.pnet_type = pnet_type
|
539 |
+
self.pnet_tune = pnet_tune
|
540 |
+
self.pnet_rand = pnet_rand
|
541 |
+
self.spatial = spatial
|
542 |
+
self.lpips = lpips
|
543 |
+
self.version = version
|
544 |
+
self.scaling_layer = ScalingLayer()
|
545 |
+
|
546 |
+
if (self.pnet_type in ['vgg', 'vgg16']):
|
547 |
+
net_type = vgg16
|
548 |
+
self.chns = [64, 128, 256, 512, 512]
|
549 |
+
elif (self.pnet_type == 'alex'):
|
550 |
+
net_type = alexnet
|
551 |
+
self.chns = [64, 192, 384, 256, 256]
|
552 |
+
elif (self.pnet_type == 'squeeze'):
|
553 |
+
net_type = squeezenet
|
554 |
+
self.chns = [64, 128, 256, 384, 384, 512, 512]
|
555 |
+
self.L = len(self.chns)
|
556 |
+
|
557 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
558 |
+
|
559 |
+
if (lpips):
|
560 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
561 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
562 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
563 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
564 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
565 |
+
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
566 |
+
if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet
|
567 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
568 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
569 |
+
self.lins += [self.lin5, self.lin6]
|
570 |
+
|
571 |
+
def forward(self, in0, in1, retPerLayer=False):
|
572 |
+
# v0.0 - original release had a bug, where input was not scaled
|
573 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
|
574 |
+
in0, in1)
|
575 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
576 |
+
feats0, feats1, diffs = {}, {}, {}
|
577 |
+
|
578 |
+
for kk in range(self.L):
|
579 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
580 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
581 |
+
|
582 |
+
if (self.lpips):
|
583 |
+
if (self.spatial):
|
584 |
+
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
|
585 |
+
else:
|
586 |
+
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
|
587 |
+
else:
|
588 |
+
if (self.spatial):
|
589 |
+
res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
|
590 |
+
else:
|
591 |
+
res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
|
592 |
+
|
593 |
+
val = res[0]
|
594 |
+
for l in range(1, self.L):
|
595 |
+
val += res[l]
|
596 |
+
|
597 |
+
if (retPerLayer):
|
598 |
+
return (val, res)
|
599 |
+
else:
|
600 |
+
return val
|
601 |
+
|
602 |
+
|
603 |
+
class ScalingLayer(nn.Module):
|
604 |
+
def __init__(self):
|
605 |
+
super(ScalingLayer, self).__init__()
|
606 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
607 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
608 |
+
|
609 |
+
def forward(self, inp):
|
610 |
+
return (inp - self.shift) / self.scale
|
611 |
+
|
612 |
+
|
613 |
+
class NetLinLayer(nn.Module):
|
614 |
+
''' A single linear layer which does a 1x1 conv '''
|
615 |
+
|
616 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
617 |
+
super(NetLinLayer, self).__init__()
|
618 |
+
|
619 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
620 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
621 |
+
self.model = nn.Sequential(*layers)
|
622 |
+
|
623 |
+
|
624 |
+
class Dist2LogitLayer(nn.Module):
|
625 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
626 |
+
|
627 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
628 |
+
super(Dist2LogitLayer, self).__init__()
|
629 |
+
|
630 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
|
631 |
+
layers += [nn.LeakyReLU(0.2, True), ]
|
632 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
|
633 |
+
layers += [nn.LeakyReLU(0.2, True), ]
|
634 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
|
635 |
+
if (use_sigmoid):
|
636 |
+
layers += [nn.Sigmoid(), ]
|
637 |
+
self.model = nn.Sequential(*layers)
|
638 |
+
|
639 |
+
def forward(self, d0, d1, eps=0.1):
|
640 |
+
return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
|
641 |
+
|
642 |
+
|
643 |
+
class BCERankingLoss(nn.Module):
|
644 |
+
def __init__(self, chn_mid=32):
|
645 |
+
super(BCERankingLoss, self).__init__()
|
646 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
647 |
+
# self.parameters = list(self.net.parameters())
|
648 |
+
self.loss = torch.nn.BCELoss()
|
649 |
+
|
650 |
+
def forward(self, d0, d1, judge):
|
651 |
+
per = (judge + 1.) / 2.
|
652 |
+
self.logit = self.net(d0, d1)
|
653 |
+
return self.loss(self.logit, per)
|
654 |
+
|
655 |
+
|
656 |
+
# L2, DSSIM metrics
|
657 |
+
class FakeNet(nn.Module):
|
658 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
659 |
+
super(FakeNet, self).__init__()
|
660 |
+
self.use_gpu = use_gpu
|
661 |
+
self.colorspace = colorspace
|
662 |
+
|
663 |
+
|
664 |
+
class L2(FakeNet):
|
665 |
+
|
666 |
+
def forward(self, in0, in1, retPerLayer=None):
|
667 |
+
assert (in0.size()[0] == 1) # currently only supports batchSize 1
|
668 |
+
|
669 |
+
if (self.colorspace == 'RGB'):
|
670 |
+
(N, C, X, Y) = in0.size()
|
671 |
+
value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
|
672 |
+
dim=3).view(N)
|
673 |
+
return value
|
674 |
+
elif (self.colorspace == 'Lab'):
|
675 |
+
value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
676 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
|
677 |
+
ret_var = Variable(torch.Tensor((value,)))
|
678 |
+
# if (self.use_gpu):
|
679 |
+
# ret_var = ret_var.cuda()
|
680 |
+
return ret_var
|
681 |
+
|
682 |
+
|
683 |
+
class DSSIM(FakeNet):
|
684 |
+
|
685 |
+
def forward(self, in0, in1, retPerLayer=None):
|
686 |
+
assert (in0.size()[0] == 1) # currently only supports batchSize 1
|
687 |
+
|
688 |
+
if (self.colorspace == 'RGB'):
|
689 |
+
value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype('float')
|
690 |
+
elif (self.colorspace == 'Lab'):
|
691 |
+
value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
692 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
|
693 |
+
ret_var = Variable(torch.Tensor((value,)))
|
694 |
+
# if (self.use_gpu):
|
695 |
+
# ret_var = ret_var.cuda()
|
696 |
+
return ret_var
|
697 |
+
|
698 |
+
|
699 |
+
def print_network(net):
|
700 |
+
num_params = 0
|
701 |
+
for param in net.parameters():
|
702 |
+
num_params += param.numel()
|
703 |
+
print('Network', net)
|
704 |
+
print('Total number of parameters: %d' % num_params)
|
705 |
+
|
706 |
+
|
707 |
+
############################################################
|
708 |
+
# pretrained_networks.py #
|
709 |
+
############################################################
|
710 |
+
|
711 |
+
from collections import namedtuple
|
712 |
+
import torch
|
713 |
+
from torchvision import models as tv
|
714 |
+
|
715 |
+
|
716 |
+
class squeezenet(torch.nn.Module):
|
717 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
718 |
+
super(squeezenet, self).__init__()
|
719 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
720 |
+
self.slice1 = torch.nn.Sequential()
|
721 |
+
self.slice2 = torch.nn.Sequential()
|
722 |
+
self.slice3 = torch.nn.Sequential()
|
723 |
+
self.slice4 = torch.nn.Sequential()
|
724 |
+
self.slice5 = torch.nn.Sequential()
|
725 |
+
self.slice6 = torch.nn.Sequential()
|
726 |
+
self.slice7 = torch.nn.Sequential()
|
727 |
+
self.N_slices = 7
|
728 |
+
for x in range(2):
|
729 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
730 |
+
for x in range(2, 5):
|
731 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
732 |
+
for x in range(5, 8):
|
733 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
734 |
+
for x in range(8, 10):
|
735 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
736 |
+
for x in range(10, 11):
|
737 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
738 |
+
for x in range(11, 12):
|
739 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
740 |
+
for x in range(12, 13):
|
741 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
742 |
+
if not requires_grad:
|
743 |
+
for param in self.parameters():
|
744 |
+
param.requires_grad = False
|
745 |
+
|
746 |
+
def forward(self, X):
|
747 |
+
h = self.slice1(X)
|
748 |
+
h_relu1 = h
|
749 |
+
h = self.slice2(h)
|
750 |
+
h_relu2 = h
|
751 |
+
h = self.slice3(h)
|
752 |
+
h_relu3 = h
|
753 |
+
h = self.slice4(h)
|
754 |
+
h_relu4 = h
|
755 |
+
h = self.slice5(h)
|
756 |
+
h_relu5 = h
|
757 |
+
h = self.slice6(h)
|
758 |
+
h_relu6 = h
|
759 |
+
h = self.slice7(h)
|
760 |
+
h_relu7 = h
|
761 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
|
762 |
+
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
763 |
+
|
764 |
+
return out
|
765 |
+
|
766 |
+
|
767 |
+
class alexnet(torch.nn.Module):
|
768 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
769 |
+
super(alexnet, self).__init__()
|
770 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
771 |
+
self.slice1 = torch.nn.Sequential()
|
772 |
+
self.slice2 = torch.nn.Sequential()
|
773 |
+
self.slice3 = torch.nn.Sequential()
|
774 |
+
self.slice4 = torch.nn.Sequential()
|
775 |
+
self.slice5 = torch.nn.Sequential()
|
776 |
+
self.N_slices = 5
|
777 |
+
for x in range(2):
|
778 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
779 |
+
for x in range(2, 5):
|
780 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
781 |
+
for x in range(5, 8):
|
782 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
783 |
+
for x in range(8, 10):
|
784 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
785 |
+
for x in range(10, 12):
|
786 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
787 |
+
if not requires_grad:
|
788 |
+
for param in self.parameters():
|
789 |
+
param.requires_grad = False
|
790 |
+
|
791 |
+
def forward(self, X):
|
792 |
+
h = self.slice1(X)
|
793 |
+
h_relu1 = h
|
794 |
+
h = self.slice2(h)
|
795 |
+
h_relu2 = h
|
796 |
+
h = self.slice3(h)
|
797 |
+
h_relu3 = h
|
798 |
+
h = self.slice4(h)
|
799 |
+
h_relu4 = h
|
800 |
+
h = self.slice5(h)
|
801 |
+
h_relu5 = h
|
802 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
803 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
804 |
+
|
805 |
+
return out
|
806 |
+
|
807 |
+
|
808 |
+
class vgg16(torch.nn.Module):
|
809 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
810 |
+
super(vgg16, self).__init__()
|
811 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
812 |
+
self.slice1 = torch.nn.Sequential()
|
813 |
+
self.slice2 = torch.nn.Sequential()
|
814 |
+
self.slice3 = torch.nn.Sequential()
|
815 |
+
self.slice4 = torch.nn.Sequential()
|
816 |
+
self.slice5 = torch.nn.Sequential()
|
817 |
+
self.N_slices = 5
|
818 |
+
for x in range(4):
|
819 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
820 |
+
for x in range(4, 9):
|
821 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
822 |
+
for x in range(9, 16):
|
823 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
824 |
+
for x in range(16, 23):
|
825 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
826 |
+
for x in range(23, 30):
|
827 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
828 |
+
if not requires_grad:
|
829 |
+
for param in self.parameters():
|
830 |
+
param.requires_grad = False
|
831 |
+
|
832 |
+
def forward(self, X):
|
833 |
+
h = self.slice1(X)
|
834 |
+
h_relu1_2 = h
|
835 |
+
h = self.slice2(h)
|
836 |
+
h_relu2_2 = h
|
837 |
+
h = self.slice3(h)
|
838 |
+
h_relu3_3 = h
|
839 |
+
h = self.slice4(h)
|
840 |
+
h_relu4_3 = h
|
841 |
+
h = self.slice5(h)
|
842 |
+
h_relu5_3 = h
|
843 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
844 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
845 |
+
|
846 |
+
return out
|
847 |
+
|
848 |
+
|
849 |
+
class resnet(torch.nn.Module):
|
850 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
851 |
+
super(resnet, self).__init__()
|
852 |
+
if (num == 18):
|
853 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
854 |
+
elif (num == 34):
|
855 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
856 |
+
elif (num == 50):
|
857 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
858 |
+
elif (num == 101):
|
859 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
860 |
+
elif (num == 152):
|
861 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
862 |
+
self.N_slices = 5
|
863 |
+
|
864 |
+
self.conv1 = self.net.conv1
|
865 |
+
self.bn1 = self.net.bn1
|
866 |
+
self.relu = self.net.relu
|
867 |
+
self.maxpool = self.net.maxpool
|
868 |
+
self.layer1 = self.net.layer1
|
869 |
+
self.layer2 = self.net.layer2
|
870 |
+
self.layer3 = self.net.layer3
|
871 |
+
self.layer4 = self.net.layer4
|
872 |
+
|
873 |
+
def forward(self, X):
|
874 |
+
h = self.conv1(X)
|
875 |
+
h = self.bn1(h)
|
876 |
+
h = self.relu(h)
|
877 |
+
h_relu1 = h
|
878 |
+
h = self.maxpool(h)
|
879 |
+
h = self.layer1(h)
|
880 |
+
h_conv2 = h
|
881 |
+
h = self.layer2(h)
|
882 |
+
h_conv3 = h
|
883 |
+
h = self.layer3(h)
|
884 |
+
h_conv4 = h
|
885 |
+
h = self.layer4(h)
|
886 |
+
h_conv5 = h
|
887 |
+
|
888 |
+
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
|
889 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
890 |
+
|
891 |
+
return out
|
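Not part of the original file: a minimal smoke test showing how the feature-extractor wrappers above (vgg16, alexnet, resnet, squeezenet) are typically used. `pretrained=False` is chosen here only so no weights are downloaded; the import path is an assumption.

```python
import torch

# Assumed import path for this sketch:
# from saicinpainting.evaluation.losses.lpips import vgg16

net = vgg16(requires_grad=False, pretrained=False)  # random weights, just to inspect shapes

x = torch.randn(1, 3, 224, 224)
feats = net(x)  # namedtuple of intermediate activations

print(feats._fields)        # ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')
print(feats.relu3_3.shape)  # torch.Size([1, 256, 56, 56])
```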
saicinpainting/evaluation/losses/ssim.py
ADDED
@@ -0,0 +1,74 @@
import numpy as np
import torch
import torch.nn.functional as F


class SSIM(torch.nn.Module):
    """SSIM. Modified from:
    https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.register_buffer('window', self._create_window(window_size, self.channel))

    def forward(self, img1, img2):
        assert len(img1.shape) == 4

        channel = img1.size()[1]

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = self._create_window(self.window_size, channel)

            # window = window.to(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)

    def _gaussian(self, window_size, sigma):
        gauss = torch.Tensor([
            np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)
        ])
        return gauss / gauss.sum()

    def _create_window(self, window_size, channel):
        _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

    def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
        mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel)
        mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(
            img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(
            img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq
        sigma12 = F.conv2d(
            img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()

        return ssim_map.mean(1).mean(1).mean(1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        return
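Not part of the original file: a minimal usage sketch of the SSIM module above. The tensor shapes and value range are illustrative assumptions; identical inputs should score (close to) 1.

```python
import torch

# Assumed import path for this sketch:
# from saicinpainting.evaluation.losses.ssim import SSIM

ssim = SSIM(window_size=11, size_average=True)

img1 = torch.rand(2, 3, 64, 64)   # two 3-channel images with values in [0, 1]
img2 = img1.clone()

print(ssim(img1, img2).item())                   # ~1.0 for identical images
print(ssim(img1, torch.rand_like(img1)).item())  # noticeably lower for unrelated images
```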
saicinpainting/evaluation/masks/README.md
ADDED
@@ -0,0 +1,27 @@
# Current algorithm

## Choice of mask objects

To identify objects that are suitable for mask generation, a panoptic segmentation model
from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO is used. Categories of the
detected instances belong either to the "stuff" or to the "things" type. We require that object instances have a
category belonging to "things". Besides, we set an upper bound on the area taken up by the object: a very large
area indicates that the instance is either the background or a main object that should not be removed
(see the illustrative sketch below).

## Choice of position for mask

We assume the input image has size 2^n x 2^m. We downsample it using the
[COUNTLESS](https://github.com/william-silversmith/countless) algorithm so that the width is equal to
64 = 2^6 = 2^{downsample_levels}.

### Augmentation

There are several parameters for augmentation:
- Scaling factor. We limit scaling to the case when a mask, after scaling with the pivot point in its center, fits inside
  the image completely.
-

### Shift

## Select
saicinpainting/evaluation/masks/__init__.py
ADDED
File without changes
|
saicinpainting/evaluation/masks/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (160 Bytes). View file
|
|
saicinpainting/evaluation/masks/__pycache__/mask.cpython-39.pyc
ADDED
Binary file (13.8 kB). View file
|
|
saicinpainting/evaluation/masks/countless/.gitignore
ADDED
@@ -0,0 +1 @@
results
saicinpainting/evaluation/masks/countless/README.md
ADDED
@@ -0,0 +1,25 @@
[![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless)

Python COUNTLESS Downsampling
=============================

To install:

`pip install -r requirements.txt`

To test:

`python test.py`

To benchmark countless2d:

`python python/countless2d.py python/images/gray_segmentation.png`

To benchmark countless3d:

`python python/countless3d.py`

Adjust N and the list of algorithms inside each script to modify the run parameters.

Python3 is slightly faster than Python2.
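Not part of the original README: a minimal usage sketch of the 2D routines defined in `countless2d.py` below. The array values are arbitrary, and the import assumes the snippet is run from this directory.

```python
import numpy as np

from countless2d import simplest_countless, zero_corrected_countless

# Each 2x2 block is reduced to a value that occurs at least twice in it,
# falling back to the bottom-right pixel when all four values differ.
labels = np.array([
    [1, 1, 2, 3],
    [1, 4, 3, 3],
    [0, 0, 5, 6],
    [0, 7, 8, 9],
], dtype=np.uint8)

print(zero_corrected_countless(labels))  # [[1 3] [0 9]] -- label 0 is handled correctly
print(simplest_countless(labels))        # [[1 3] [7 9]] -- raw variant mishandles 0
```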
saicinpainting/evaluation/masks/countless/__init__.py
ADDED
File without changes
|
saicinpainting/evaluation/masks/countless/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (170 Bytes). View file
|
|
saicinpainting/evaluation/masks/countless/__pycache__/countless2d.cpython-39.pyc
ADDED
Binary file (11.3 kB). View file
|
|
saicinpainting/evaluation/masks/countless/countless2d.py
ADDED
@@ -0,0 +1,529 @@
1 |
+
from __future__ import print_function, division
|
2 |
+
|
3 |
+
"""
|
4 |
+
COUNTLESS performance test in Python.
|
5 |
+
|
6 |
+
python countless2d.py ./images/NAMEOFIMAGE
|
7 |
+
"""
|
8 |
+
|
9 |
+
import six
|
10 |
+
from six.moves import range
|
11 |
+
from collections import defaultdict
|
12 |
+
from functools import reduce
|
13 |
+
import operator
|
14 |
+
import io
|
15 |
+
import os
|
16 |
+
from PIL import Image
|
17 |
+
import math
|
18 |
+
import numpy as np
|
19 |
+
import random
|
20 |
+
import sys
|
21 |
+
import time
|
22 |
+
from tqdm import tqdm
|
23 |
+
from scipy import ndimage
|
24 |
+
|
25 |
+
def simplest_countless(data):
|
26 |
+
"""
|
27 |
+
Vectorized implementation of downsampling a 2D
|
28 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
29 |
+
|
30 |
+
data is a 2D numpy array with even dimensions.
|
31 |
+
"""
|
32 |
+
sections = []
|
33 |
+
|
34 |
+
# This loop splits the 2D array apart into four arrays that are
|
35 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
36 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
37 |
+
factor = (2,2)
|
38 |
+
for offset in np.ndindex(factor):
|
39 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
40 |
+
sections.append(part)
|
41 |
+
|
42 |
+
a, b, c, d = sections
|
43 |
+
|
44 |
+
ab = a * (a == b) # PICK(A,B)
|
45 |
+
ac = a * (a == c) # PICK(A,C)
|
46 |
+
bc = b * (b == c) # PICK(B,C)
|
47 |
+
|
48 |
+
a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
|
49 |
+
|
50 |
+
return a + (a == 0) * d # AB || AC || BC || D
|
51 |
+
|
52 |
+
def quick_countless(data):
|
53 |
+
"""
|
54 |
+
Vectorized implementation of downsampling a 2D
|
55 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
56 |
+
|
57 |
+
data is a 2D numpy array with even dimensions.
|
58 |
+
"""
|
59 |
+
sections = []
|
60 |
+
|
61 |
+
# This loop splits the 2D array apart into four arrays that are
|
62 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
63 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
64 |
+
factor = (2,2)
|
65 |
+
for offset in np.ndindex(factor):
|
66 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
67 |
+
sections.append(part)
|
68 |
+
|
69 |
+
a, b, c, d = sections
|
70 |
+
|
71 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
72 |
+
bc = b * (b == c) # PICK(B,C)
|
73 |
+
|
74 |
+
a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C)
|
75 |
+
return a + (a == 0) * d # AB || AC || BC || D
|
76 |
+
|
77 |
+
def quickest_countless(data):
|
78 |
+
"""
|
79 |
+
Vectorized implementation of downsampling a 2D
|
80 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
81 |
+
|
82 |
+
data is a 2D numpy array with even dimensions.
|
83 |
+
"""
|
84 |
+
sections = []
|
85 |
+
|
86 |
+
# This loop splits the 2D array apart into four arrays that are
|
87 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
88 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
89 |
+
factor = (2,2)
|
90 |
+
for offset in np.ndindex(factor):
|
91 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
92 |
+
sections.append(part)
|
93 |
+
|
94 |
+
a, b, c, d = sections
|
95 |
+
|
96 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
97 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
98 |
+
return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D
|
99 |
+
|
100 |
+
def quick_countless_xor(data):
|
101 |
+
"""
|
102 |
+
Vectorized implementation of downsampling a 2D
|
103 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
104 |
+
|
105 |
+
data is a 2D numpy array with even dimensions.
|
106 |
+
"""
|
107 |
+
sections = []
|
108 |
+
|
109 |
+
# This loop splits the 2D array apart into four arrays that are
|
110 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
111 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
112 |
+
factor = (2,2)
|
113 |
+
for offset in np.ndindex(factor):
|
114 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
115 |
+
sections.append(part)
|
116 |
+
|
117 |
+
a, b, c, d = sections
|
118 |
+
|
119 |
+
ab = a ^ (a ^ b) # a or b
|
120 |
+
ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c
|
121 |
+
ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d
|
122 |
+
return ab
|
123 |
+
|
124 |
+
def stippled_countless(data):
|
125 |
+
"""
|
126 |
+
Vectorized implementation of downsampling a 2D
|
127 |
+
image by 2 on each side using the COUNTLESS algorithm
|
128 |
+
that treats zero as "background" and inflates lone
|
129 |
+
pixels.
|
130 |
+
|
131 |
+
data is a 2D numpy array with even dimensions.
|
132 |
+
"""
|
133 |
+
sections = []
|
134 |
+
|
135 |
+
# This loop splits the 2D array apart into four arrays that are
|
136 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
137 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
138 |
+
factor = (2,2)
|
139 |
+
for offset in np.ndindex(factor):
|
140 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
141 |
+
sections.append(part)
|
142 |
+
|
143 |
+
a, b, c, d = sections
|
144 |
+
|
145 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
146 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
147 |
+
|
148 |
+
nonzero = a + (a == 0) * (b + (b == 0) * c)
|
149 |
+
return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D
|
150 |
+
|
151 |
+
def zero_corrected_countless(data):
|
152 |
+
"""
|
153 |
+
Vectorized implementation of downsampling a 2D
|
154 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
155 |
+
|
156 |
+
data is a 2D numpy array with even dimensions.
|
157 |
+
"""
|
158 |
+
# allows us to prevent losing 1/2 a bit of information
|
159 |
+
# at the top end by using a bigger type. Without this 255 is handled incorrectly.
|
160 |
+
data, upgraded = upgrade_type(data)
|
161 |
+
|
162 |
+
# offset from zero, raw countless doesn't handle 0 correctly
|
163 |
+
# we'll remove the extra 1 at the end.
|
164 |
+
data += 1
|
165 |
+
|
166 |
+
sections = []
|
167 |
+
|
168 |
+
# This loop splits the 2D array apart into four arrays that are
|
169 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
170 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
171 |
+
factor = (2,2)
|
172 |
+
for offset in np.ndindex(factor):
|
173 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
174 |
+
sections.append(part)
|
175 |
+
|
176 |
+
a, b, c, d = sections
|
177 |
+
|
178 |
+
ab = a * (a == b) # PICK(A,B)
|
179 |
+
ac = a * (a == c) # PICK(A,C)
|
180 |
+
bc = b * (b == c) # PICK(B,C)
|
181 |
+
|
182 |
+
a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
|
183 |
+
|
184 |
+
result = a + (a == 0) * d - 1 # a or d - 1
|
185 |
+
|
186 |
+
if upgraded:
|
187 |
+
return downgrade_type(result)
|
188 |
+
|
189 |
+
# only need to reset data if we weren't upgraded
|
190 |
+
# b/c no copy was made in that case
|
191 |
+
data -= 1
|
192 |
+
|
193 |
+
return result
|
194 |
+
|
195 |
+
def countless_extreme(data):
|
196 |
+
nonzeros = np.count_nonzero(data)
|
197 |
+
# print("nonzeros", nonzeros)
|
198 |
+
|
199 |
+
N = reduce(operator.mul, data.shape)
|
200 |
+
|
201 |
+
if nonzeros == N:
|
202 |
+
print("quick")
|
203 |
+
return quick_countless(data)
|
204 |
+
elif np.count_nonzero(data + 1) == N:
|
205 |
+
print("quick")
|
206 |
+
# print("upper", nonzeros)
|
207 |
+
return quick_countless(data)
|
208 |
+
else:
|
209 |
+
return countless(data)
|
210 |
+
|
211 |
+
|
212 |
+
def countless(data):
|
213 |
+
"""
|
214 |
+
Vectorized implementation of downsampling a 2D
|
215 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
216 |
+
|
217 |
+
data is a 2D numpy array with even dimensions.
|
218 |
+
"""
|
219 |
+
# allows us to prevent losing 1/2 a bit of information
|
220 |
+
# at the top end by using a bigger type. Without this 255 is handled incorrectly.
|
221 |
+
data, upgraded = upgrade_type(data)
|
222 |
+
|
223 |
+
# offset from zero, raw countless doesn't handle 0 correctly
|
224 |
+
# we'll remove the extra 1 at the end.
|
225 |
+
data += 1
|
226 |
+
|
227 |
+
sections = []
|
228 |
+
|
229 |
+
# This loop splits the 2D array apart into four arrays that are
|
230 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
231 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
232 |
+
factor = (2,2)
|
233 |
+
for offset in np.ndindex(factor):
|
234 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
235 |
+
sections.append(part)
|
236 |
+
|
237 |
+
a, b, c, d = sections
|
238 |
+
|
239 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
240 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
241 |
+
result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1
|
242 |
+
|
243 |
+
if upgraded:
|
244 |
+
return downgrade_type(result)
|
245 |
+
|
246 |
+
# only need to reset data if we weren't upgraded
|
247 |
+
# b/c no copy was made in that case
|
248 |
+
data -= 1
|
249 |
+
|
250 |
+
return result
|
251 |
+
|
252 |
+
def upgrade_type(arr):
|
253 |
+
dtype = arr.dtype
|
254 |
+
|
255 |
+
if dtype == np.uint8:
|
256 |
+
return arr.astype(np.uint16), True
|
257 |
+
elif dtype == np.uint16:
|
258 |
+
return arr.astype(np.uint32), True
|
259 |
+
elif dtype == np.uint32:
|
260 |
+
return arr.astype(np.uint64), True
|
261 |
+
|
262 |
+
return arr, False
|
263 |
+
|
264 |
+
def downgrade_type(arr):
|
265 |
+
dtype = arr.dtype
|
266 |
+
|
267 |
+
if dtype == np.uint64:
|
268 |
+
return arr.astype(np.uint32)
|
269 |
+
elif dtype == np.uint32:
|
270 |
+
return arr.astype(np.uint16)
|
271 |
+
elif dtype == np.uint16:
|
272 |
+
return arr.astype(np.uint8)
|
273 |
+
|
274 |
+
return arr
|
275 |
+
|
276 |
+
def odd_to_even(image):
|
277 |
+
"""
|
278 |
+
To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one.
|
279 |
+
Works by mirroring the starting 1 pixel edge of the image on odd shaped sides.
|
280 |
+
|
281 |
+
e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled)
|
282 |
+
|
283 |
+
For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample.
|
284 |
+
|
285 |
+
"""
|
286 |
+
shape = np.array(image.shape)
|
287 |
+
|
288 |
+
offset = (shape % 2)[:2] # x,y offset
|
289 |
+
|
290 |
+
# detect if we're dealing with an even
|
291 |
+
# image. if so it's fine, just return.
|
292 |
+
if not np.any(offset):
|
293 |
+
return image
|
294 |
+
|
295 |
+
oddshape = image.shape[:2] + offset
|
296 |
+
oddshape = np.append(oddshape, shape[2:])
|
297 |
+
oddshape = oddshape.astype(int)
|
298 |
+
|
299 |
+
newimg = np.empty(shape=oddshape, dtype=image.dtype)
|
300 |
+
|
301 |
+
ox,oy = offset
|
302 |
+
sx,sy = oddshape
|
303 |
+
|
304 |
+
newimg[0,0] = image[0,0] # corner
|
305 |
+
newimg[ox:sx,0] = image[:,0] # x axis line
|
306 |
+
newimg[0,oy:sy] = image[0,:] # y axis line
|
307 |
+
|
308 |
+
return newimg
|
309 |
+
|
310 |
+
def counting(array):
|
311 |
+
factor = (2, 2, 1)
|
312 |
+
shape = array.shape
|
313 |
+
|
314 |
+
while len(shape) < 4:
|
315 |
+
array = np.expand_dims(array, axis=-1)
|
316 |
+
shape = array.shape
|
317 |
+
|
318 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
|
319 |
+
output = np.zeros(output_shape, dtype=array.dtype)
|
320 |
+
|
321 |
+
for chan in range(0, shape[3]):
|
322 |
+
for z in range(0, shape[2]):
|
323 |
+
for x in range(0, shape[0], 2):
|
324 |
+
for y in range(0, shape[1], 2):
|
325 |
+
block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block
|
326 |
+
|
327 |
+
hashtable = defaultdict(int)
|
328 |
+
for subx, suby in np.ndindex(block.shape[0], block.shape[1]):
|
329 |
+
hashtable[block[subx, suby]] += 1
|
330 |
+
|
331 |
+
best = (0, 0)
|
332 |
+
for segid, val in six.iteritems(hashtable):
|
333 |
+
if best[1] < val:
|
334 |
+
best = (segid, val)
|
335 |
+
|
336 |
+
output[ x // 2, y // 2, chan ] = best[0]
|
337 |
+
|
338 |
+
return output
|
339 |
+
|
340 |
+
def ndzoom(array):
|
341 |
+
if len(array.shape) == 3:
|
342 |
+
ratio = ( 1 / 2.0, 1 / 2.0, 1.0 )
|
343 |
+
else:
|
344 |
+
ratio = ( 1 / 2.0, 1 / 2.0)
|
345 |
+
return ndimage.interpolation.zoom(array, ratio, order=1)
|
346 |
+
|
347 |
+
def countless_if(array):
|
348 |
+
factor = (2, 2, 1)
|
349 |
+
shape = array.shape
|
350 |
+
|
351 |
+
if len(shape) < 3:
|
352 |
+
array = array[ :,:, np.newaxis ]
|
353 |
+
shape = array.shape
|
354 |
+
|
355 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
|
356 |
+
output = np.zeros(output_shape, dtype=array.dtype)
|
357 |
+
|
358 |
+
for chan in range(0, shape[2]):
|
359 |
+
for x in range(0, shape[0], 2):
|
360 |
+
for y in range(0, shape[1], 2):
|
361 |
+
block = array[ x:x+2, y:y+2, chan ] # 2x2 block
|
362 |
+
|
363 |
+
if block[0,0] == block[1,0]:
|
364 |
+
pick = block[0,0]
|
365 |
+
elif block[0,0] == block[0,1]:
|
366 |
+
pick = block[0,0]
|
367 |
+
elif block[1,0] == block[0,1]:
|
368 |
+
pick = block[1,0]
|
369 |
+
else:
|
370 |
+
pick = block[1,1]
|
371 |
+
|
372 |
+
output[ x // 2, y // 2, chan ] = pick
|
373 |
+
|
374 |
+
return np.squeeze(output)
|
375 |
+
|
376 |
+
def downsample_with_averaging(array):
|
377 |
+
"""
|
378 |
+
Downsample x by factor using averaging.
|
379 |
+
|
380 |
+
@return: The downsampled array, of the same type as x.
|
381 |
+
"""
|
382 |
+
|
383 |
+
if len(array.shape) == 3:
|
384 |
+
factor = (2,2,1)
|
385 |
+
else:
|
386 |
+
factor = (2,2)
|
387 |
+
|
388 |
+
if np.array_equal(factor[:3], np.array([1,1,1])):
|
389 |
+
return array
|
390 |
+
|
391 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
|
392 |
+
temp = np.zeros(output_shape, float)
|
393 |
+
counts = np.zeros(output_shape, np.int)
|
394 |
+
for offset in np.ndindex(factor):
|
395 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
396 |
+
indexing_expr = tuple(np.s_[:s] for s in part.shape)
|
397 |
+
temp[indexing_expr] += part
|
398 |
+
counts[indexing_expr] += 1
|
399 |
+
return np.cast[array.dtype](temp / counts)
|
400 |
+
|
401 |
+
def downsample_with_max_pooling(array):
|
402 |
+
|
403 |
+
factor = (2,2)
|
404 |
+
|
405 |
+
if np.all(np.array(factor, int) == 1):
|
406 |
+
return array
|
407 |
+
|
408 |
+
sections = []
|
409 |
+
|
410 |
+
for offset in np.ndindex(factor):
|
411 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
412 |
+
sections.append(part)
|
413 |
+
|
414 |
+
output = sections[0].copy()
|
415 |
+
|
416 |
+
for section in sections[1:]:
|
417 |
+
np.maximum(output, section, output)
|
418 |
+
|
419 |
+
return output
|
420 |
+
|
421 |
+
def striding(array):
|
422 |
+
"""Downsample x by factor using striding.
|
423 |
+
|
424 |
+
@return: The downsampled array, of the same type as x.
|
425 |
+
"""
|
426 |
+
factor = (2,2)
|
427 |
+
if np.all(np.array(factor, int) == 1):
|
428 |
+
return array
|
429 |
+
return array[tuple(np.s_[::f] for f in factor)]
|
430 |
+
|
431 |
+
def benchmark():
|
432 |
+
filename = sys.argv[1]
|
433 |
+
img = Image.open(filename)
|
434 |
+
data = np.array(img.getdata(), dtype=np.uint8)
|
435 |
+
|
436 |
+
if len(data.shape) == 1:
|
437 |
+
n_channels = 1
|
438 |
+
reshape = (img.height, img.width)
|
439 |
+
else:
|
440 |
+
n_channels = min(data.shape[1], 3)
|
441 |
+
data = data[:, :n_channels]
|
442 |
+
reshape = (img.height, img.width, n_channels)
|
443 |
+
|
444 |
+
data = data.reshape(reshape).astype(np.uint8)
|
445 |
+
|
446 |
+
methods = [
|
447 |
+
simplest_countless,
|
448 |
+
quick_countless,
|
449 |
+
quick_countless_xor,
|
450 |
+
quickest_countless,
|
451 |
+
stippled_countless,
|
452 |
+
zero_corrected_countless,
|
453 |
+
countless,
|
454 |
+
downsample_with_averaging,
|
455 |
+
downsample_with_max_pooling,
|
456 |
+
ndzoom,
|
457 |
+
striding,
|
458 |
+
# countless_if,
|
459 |
+
# counting,
|
460 |
+
]
|
461 |
+
|
462 |
+
formats = {
|
463 |
+
1: 'L',
|
464 |
+
3: 'RGB',
|
465 |
+
4: 'RGBA'
|
466 |
+
}
|
467 |
+
|
468 |
+
if not os.path.exists('./results'):
|
469 |
+
os.mkdir('./results')
|
470 |
+
|
471 |
+
N = 500
|
472 |
+
img_size = float(img.width * img.height) / 1024.0 / 1024.0
|
473 |
+
print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename))
|
474 |
+
print("Algorithm\tMPx/sec\tMB/sec\tSec")
|
475 |
+
for fn in methods:
|
476 |
+
print(fn.__name__, end='')
|
477 |
+
sys.stdout.flush()
|
478 |
+
|
479 |
+
start = time.time()
|
480 |
+
# tqdm is here to show you what's going on the first time you run it.
|
481 |
+
# Feel free to remove it to get slightly more accurate timing results.
|
482 |
+
for _ in tqdm(range(N), desc=fn.__name__, disable=True):
|
483 |
+
result = fn(data)
|
484 |
+
end = time.time()
|
485 |
+
print("\r", end='')
|
486 |
+
|
487 |
+
total_time = (end - start)
|
488 |
+
mpx = N * img_size / total_time
|
489 |
+
mbytes = N * img_size * n_channels / total_time
|
490 |
+
# Output in tab separated format to enable copy-paste into excel/numbers
|
491 |
+
print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
|
492 |
+
outimg = Image.fromarray(np.squeeze(result), formats[n_channels])
|
493 |
+
outimg.save('./results/{}.png'.format(fn.__name__, "PNG"))
|
494 |
+
|
495 |
+
if __name__ == '__main__':
|
496 |
+
benchmark()
|
497 |
+
|
498 |
+
|
499 |
+
# Example results:
|
500 |
+
# N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
|
501 |
+
# Function MPx/sec MB/sec Sec
|
502 |
+
# simplest_countless 752.855 752.855 0.01
|
503 |
+
# quick_countless 920.328 920.328 0.01
|
504 |
+
# zero_corrected_countless 534.143 534.143 0.01
|
505 |
+
# countless 644.247 644.247 0.01
|
506 |
+
# downsample_with_averaging 372.575 372.575 0.01
|
507 |
+
# downsample_with_max_pooling 974.060 974.060 0.01
|
508 |
+
# ndzoom 137.517 137.517 0.04
|
509 |
+
# striding 38550.588 38550.588 0.00
|
510 |
+
# countless_if 4.377 4.377 1.14
|
511 |
+
# counting 0.117 0.117 42.85
|
512 |
+
|
513 |
+
# Run without non-numpy implementations:
|
514 |
+
# N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
|
515 |
+
# Algorithm MPx/sec MB/sec Sec
|
516 |
+
# simplest_countless 800.522 800.522 2.50
|
517 |
+
# quick_countless 945.420 945.420 2.12
|
518 |
+
# quickest_countless 947.256 947.256 2.11
|
519 |
+
# stippled_countless 544.049 544.049 3.68
|
520 |
+
# zero_corrected_countless 575.310 575.310 3.48
|
521 |
+
# countless 646.684 646.684 3.09
|
522 |
+
# downsample_with_averaging 385.132 385.132 5.19
|
523 |
+
# downsample_with_max_poolin 988.361 988.361 2.02
|
524 |
+
# ndzoom 163.104 163.104 12.26
|
525 |
+
# striding 81589.340 81589.340 0.02
|
526 |
+
|
527 |
+
|
528 |
+
|
529 |
+
|
saicinpainting/evaluation/masks/countless/countless3d.py
ADDED
@@ -0,0 +1,356 @@
1 |
+
from six.moves import range
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import io
|
5 |
+
import time
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import sys
|
9 |
+
from collections import defaultdict
|
10 |
+
from copy import deepcopy
|
11 |
+
from itertools import combinations
|
12 |
+
from functools import reduce
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from memory_profiler import profile
|
16 |
+
|
17 |
+
def countless5(a,b,c,d,e):
|
18 |
+
"""First stage of generalizing from countless2d.
|
19 |
+
|
20 |
+
You have five slots: A, B, C, D, E
|
21 |
+
|
22 |
+
You can decide if something is the winner by first checking for
|
23 |
+
matches of three, then matches of two, then picking just one if
|
24 |
+
the other two tries fail. In countless2d, you just check for matches
|
25 |
+
of two and then pick one of them otherwise.
|
26 |
+
|
27 |
+
Unfortunately, you need to check ABC, ABD, ABE, BCD, BDE, & CDE.
|
28 |
+
Then you need to check AB, AC, AD, BC, BD
|
29 |
+
We skip checking E because if none of these match, we pick E. We can
|
30 |
+
skip checking AE, BE, CE, DE since if any of those match, E is our boy
|
31 |
+
so it's redundant.
|
32 |
+
|
33 |
+
So countless grows cominatorially in complexity.
|
34 |
+
"""
|
35 |
+
sections = [ a,b,c,d,e ]
|
36 |
+
|
37 |
+
p2 = lambda q,r: q * (q == r) # q if p == q else 0
|
38 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0
|
39 |
+
|
40 |
+
lor = lambda x,y: x + (x == 0) * y
|
41 |
+
|
42 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
43 |
+
results3 = reduce(lor, results3)
|
44 |
+
|
45 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
46 |
+
results2 = reduce(lor, results2)
|
47 |
+
|
48 |
+
return reduce(lor, (results3, results2, e))
|
49 |
+
|
50 |
+
def countless8(a,b,c,d,e,f,g,h):
|
51 |
+
"""Extend countless5 to countless8. Same deal, except we also
|
52 |
+
need to check for matches of length 4."""
|
53 |
+
sections = [ a, b, c, d, e, f, g, h ]
|
54 |
+
|
55 |
+
p2 = lambda q,r: q * (q == r)
|
56 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
|
57 |
+
p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
|
58 |
+
|
59 |
+
lor = lambda x,y: x + (x == 0) * y
|
60 |
+
|
61 |
+
results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
|
62 |
+
results4 = reduce(lor, results4)
|
63 |
+
|
64 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
65 |
+
results3 = reduce(lor, results3)
|
66 |
+
|
67 |
+
# We can always use our shortcut of omitting the last element
|
68 |
+
# for N choose 2
|
69 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
70 |
+
results2 = reduce(lor, results2)
|
71 |
+
|
72 |
+
return reduce(lor, [ results4, results3, results2, h ])
|
73 |
+
|
74 |
+
def dynamic_countless3d(data):
|
75 |
+
"""countless8 + dynamic programming. ~2x faster"""
|
76 |
+
sections = []
|
77 |
+
|
78 |
+
# shift zeros up one so they don't interfere with bitwise operators
|
79 |
+
# we'll shift down at the end
|
80 |
+
data += 1
|
81 |
+
|
82 |
+
# This loop splits the 2D array apart into four arrays that are
|
83 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
84 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
85 |
+
factor = (2,2,2)
|
86 |
+
for offset in np.ndindex(factor):
|
87 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
88 |
+
sections.append(part)
|
89 |
+
|
90 |
+
pick = lambda a,b: a * (a == b)
|
91 |
+
lor = lambda x,y: x + (x == 0) * y
|
92 |
+
|
93 |
+
subproblems2 = {}
|
94 |
+
|
95 |
+
results2 = None
|
96 |
+
for x,y in combinations(range(7), 2):
|
97 |
+
res = pick(sections[x], sections[y])
|
98 |
+
subproblems2[(x,y)] = res
|
99 |
+
if results2 is not None:
|
100 |
+
results2 += (results2 == 0) * res
|
101 |
+
else:
|
102 |
+
results2 = res
|
103 |
+
|
104 |
+
subproblems3 = {}
|
105 |
+
|
106 |
+
results3 = None
|
107 |
+
for x,y,z in combinations(range(8), 3):
|
108 |
+
res = pick(subproblems2[(x,y)], sections[z])
|
109 |
+
|
110 |
+
if z != 7:
|
111 |
+
subproblems3[(x,y,z)] = res
|
112 |
+
|
113 |
+
if results3 is not None:
|
114 |
+
results3 += (results3 == 0) * res
|
115 |
+
else:
|
116 |
+
results3 = res
|
117 |
+
|
118 |
+
results3 = reduce(lor, (results3, results2, sections[-1]))
|
119 |
+
|
120 |
+
# free memory
|
121 |
+
results2 = None
|
122 |
+
subproblems2 = None
|
123 |
+
res = None
|
124 |
+
|
125 |
+
results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) )
|
126 |
+
results4 = reduce(lor, results4)
|
127 |
+
subproblems3 = None # free memory
|
128 |
+
|
129 |
+
final_result = lor(results4, results3) - 1
|
130 |
+
data -= 1
|
131 |
+
return final_result
|
132 |
+
|
133 |
+
def countless3d(data):
|
134 |
+
"""Now write countless8 in such a way that it could be used
|
135 |
+
to process an image."""
|
136 |
+
sections = []
|
137 |
+
|
138 |
+
# shift zeros up one so they don't interfere with bitwise operators
|
139 |
+
# we'll shift down at the end
|
140 |
+
data += 1
|
141 |
+
|
142 |
+
# This loop splits the 2D array apart into four arrays that are
|
143 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
144 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
145 |
+
factor = (2,2,2)
|
146 |
+
for offset in np.ndindex(factor):
|
147 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
148 |
+
sections.append(part)
|
149 |
+
|
150 |
+
p2 = lambda q,r: q * (q == r)
|
151 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
|
152 |
+
p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
|
153 |
+
|
154 |
+
lor = lambda x,y: x + (x == 0) * y
|
155 |
+
|
156 |
+
results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
|
157 |
+
results4 = reduce(lor, results4)
|
158 |
+
|
159 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
160 |
+
results3 = reduce(lor, results3)
|
161 |
+
|
162 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
163 |
+
results2 = reduce(lor, results2)
|
164 |
+
|
165 |
+
final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1
|
166 |
+
data -= 1
|
167 |
+
return final_result
|
168 |
+
|
169 |
+
def countless_generalized(data, factor):
|
170 |
+
assert len(data.shape) == len(factor)
|
171 |
+
|
172 |
+
sections = []
|
173 |
+
|
174 |
+
mode_of = reduce(lambda x,y: x * y, factor)
|
175 |
+
majority = int(math.ceil(float(mode_of) / 2))
|
176 |
+
|
177 |
+
data += 1
|
178 |
+
|
179 |
+
# This loop splits the 2D array apart into four arrays that are
|
180 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
181 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
182 |
+
for offset in np.ndindex(factor):
|
183 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
184 |
+
sections.append(part)
|
185 |
+
|
186 |
+
def pick(elements):
|
187 |
+
eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) )
|
188 |
+
anded = reduce(lambda p,q: p & q, eq)
|
189 |
+
return elements[0] * anded
|
190 |
+
|
191 |
+
def logical_or(x,y):
|
192 |
+
return x + (x == 0) * y
|
193 |
+
|
194 |
+
result = ( pick(combo) for combo in combinations(sections, majority) )
|
195 |
+
result = reduce(logical_or, result)
|
196 |
+
for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds
|
197 |
+
partial_result = ( pick(combo) for combo in combinations(sections, i) )
|
198 |
+
partial_result = reduce(logical_or, partial_result)
|
199 |
+
result = logical_or(result, partial_result)
|
200 |
+
|
201 |
+
partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) )
|
202 |
+
partial_result = reduce(logical_or, partial_result)
|
203 |
+
result = logical_or(result, partial_result)
|
204 |
+
|
205 |
+
result = logical_or(result, sections[-1]) - 1
|
206 |
+
data -= 1
|
207 |
+
return result
|
208 |
+
|
209 |
+
def dynamic_countless_generalized(data, factor):
|
210 |
+
assert len(data.shape) == len(factor)
|
211 |
+
|
212 |
+
sections = []
|
213 |
+
|
214 |
+
mode_of = reduce(lambda x,y: x * y, factor)
|
215 |
+
majority = int(math.ceil(float(mode_of) / 2))
|
216 |
+
|
217 |
+
data += 1 # offset from zero
|
218 |
+
|
219 |
+
# This loop splits the 2D array apart into four arrays that are
|
220 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
221 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
222 |
+
for offset in np.ndindex(factor):
|
223 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
224 |
+
sections.append(part)
|
225 |
+
|
226 |
+
pick = lambda a,b: a * (a == b)
|
227 |
+
lor = lambda x,y: x + (x == 0) * y # logical or
|
228 |
+
|
229 |
+
subproblems = [ {}, {} ]
|
230 |
+
results2 = None
|
231 |
+
for x,y in combinations(range(len(sections) - 1), 2):
|
232 |
+
res = pick(sections[x], sections[y])
|
233 |
+
subproblems[0][(x,y)] = res
|
234 |
+
if results2 is not None:
|
235 |
+
results2 = lor(results2, res)
|
236 |
+
else:
|
237 |
+
results2 = res
|
238 |
+
|
239 |
+
results = [ results2 ]
|
240 |
+
for r in range(3, majority+1):
|
241 |
+
r_results = None
|
242 |
+
for combo in combinations(range(len(sections)), r):
|
243 |
+
res = pick(subproblems[0][combo[:-1]], sections[combo[-1]])
|
244 |
+
|
245 |
+
if combo[-1] != len(sections) - 1:
|
246 |
+
subproblems[1][combo] = res
|
247 |
+
|
248 |
+
if r_results is not None:
|
249 |
+
r_results = lor(r_results, res)
|
250 |
+
else:
|
251 |
+
r_results = res
|
252 |
+
results.append(r_results)
|
253 |
+
subproblems[0] = subproblems[1]
|
254 |
+
subproblems[1] = {}
|
255 |
+
|
256 |
+
results.reverse()
|
257 |
+
final_result = lor(reduce(lor, results), sections[-1]) - 1
|
258 |
+
data -= 1
|
259 |
+
return final_result
|
260 |
+
|
261 |
+
def downsample_with_averaging(array):
|
262 |
+
"""
|
263 |
+
Downsample x by factor using averaging.
|
264 |
+
|
265 |
+
@return: The downsampled array, of the same type as x.
|
266 |
+
"""
|
267 |
+
factor = (2,2,2)
|
268 |
+
|
269 |
+
if np.array_equal(factor[:3], np.array([1,1,1])):
|
270 |
+
return array
|
271 |
+
|
272 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
|
273 |
+
temp = np.zeros(output_shape, float)
|
274 |
+
counts = np.zeros(output_shape, np.int)
|
275 |
+
for offset in np.ndindex(factor):
|
276 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
277 |
+
indexing_expr = tuple(np.s_[:s] for s in part.shape)
|
278 |
+
temp[indexing_expr] += part
|
279 |
+
counts[indexing_expr] += 1
|
280 |
+
return np.cast[array.dtype](temp / counts)
|
281 |
+
|
282 |
+
def downsample_with_max_pooling(array):
|
283 |
+
|
284 |
+
factor = (2,2,2)
|
285 |
+
|
286 |
+
sections = []
|
287 |
+
|
288 |
+
for offset in np.ndindex(factor):
|
289 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
290 |
+
sections.append(part)
|
291 |
+
|
292 |
+
output = sections[0].copy()
|
293 |
+
|
294 |
+
for section in sections[1:]:
|
295 |
+
np.maximum(output, section, output)
|
296 |
+
|
297 |
+
return output
|
298 |
+
|
299 |
+
def striding(array):
|
300 |
+
"""Downsample x by factor using striding.
|
301 |
+
|
302 |
+
@return: The downsampled array, of the same type as x.
|
303 |
+
"""
|
304 |
+
factor = (2,2,2)
|
305 |
+
if np.all(np.array(factor, int) == 1):
|
306 |
+
return array
|
307 |
+
return array[tuple(np.s_[::f] for f in factor)]
|
308 |
+
|
309 |
+
def benchmark():
|
310 |
+
def countless3d_generalized(img):
|
311 |
+
return countless_generalized(img, (2,8,1))
|
312 |
+
def countless3d_dynamic_generalized(img):
|
313 |
+
return dynamic_countless_generalized(img, (8,8,1))
|
314 |
+
|
315 |
+
methods = [
|
316 |
+
# countless3d,
|
317 |
+
# dynamic_countless3d,
|
318 |
+
countless3d_generalized,
|
319 |
+
# countless3d_dynamic_generalized,
|
320 |
+
# striding,
|
321 |
+
# downsample_with_averaging,
|
322 |
+
# downsample_with_max_pooling
|
323 |
+
]
|
324 |
+
|
325 |
+
data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1
|
326 |
+
|
327 |
+
N = 5
|
328 |
+
|
329 |
+
print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N)
|
330 |
+
|
331 |
+
for fn in methods:
|
332 |
+
start = time.time()
|
333 |
+
for _ in range(N):
|
334 |
+
result = fn(data)
|
335 |
+
end = time.time()
|
336 |
+
|
337 |
+
total_time = (end - start)
|
338 |
+
mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0
|
339 |
+
mbytes = mpx * np.dtype(data.dtype).itemsize
|
340 |
+
# Output in tab separated format to enable copy-paste into excel/numbers
|
341 |
+
print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
|
342 |
+
|
343 |
+
if __name__ == '__main__':
|
344 |
+
benchmark()
|
345 |
+
|
346 |
+
# Algorithm MPx MB/sec Sec N=5
|
347 |
+
# countless3d 10.564 10.564 60.58
|
348 |
+
# dynamic_countless3d 22.717 22.717 28.17
|
349 |
+
# countless3d_generalized 9.702 9.702 65.96
|
350 |
+
# countless3d_dynamic_generalized 22.720 22.720 28.17
|
351 |
+
# striding 253360.506 253360.506 0.00
|
352 |
+
# downsample_with_averaging 224.098 224.098 2.86
|
353 |
+
# downsample_with_max_pooling 690.474 690.474 0.93
|
354 |
+
|
355 |
+
|
356 |
+
|
saicinpainting/evaluation/masks/countless/requirements.txt
ADDED
@@ -0,0 +1,7 @@
Pillow>=6.2.0
numpy>=1.16
scipy
tqdm
memory_profiler
six
pytest
saicinpainting/evaluation/masks/countless/test.py
ADDED
@@ -0,0 +1,195 @@
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import countless2d
|
6 |
+
import countless3d
|
7 |
+
|
8 |
+
def test_countless2d():
|
9 |
+
def test_all_cases(fn, test_zero):
|
10 |
+
case1 = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) # all different
|
11 |
+
case2 = np.array([ [ 1, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same
|
12 |
+
case1z = np.array([ [ 0, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # all different
|
13 |
+
case2z = np.array([ [ 0, 0 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same
|
14 |
+
case3 = np.array([ [ 1, 1 ], [ 2, 2 ] ]).reshape((2,2,1,1)) # two groups are same
|
15 |
+
case4 = np.array([ [ 1, 2 ], [ 2, 2 ] ]).reshape((2,2,1,1)) # 3 are the same
|
16 |
+
case5 = np.array([ [ 5, 5 ], [ 5, 5 ] ]).reshape((2,2,1,1)) # all are the same
|
17 |
+
|
18 |
+
is_255_handled = np.array([ [ 255, 255 ], [ 1, 2 ] ], dtype=np.uint8).reshape((2,2,1,1))
|
19 |
+
|
20 |
+
test = lambda case: fn(case)
|
21 |
+
|
22 |
+
if test_zero:
|
23 |
+
assert test(case1z) == [[[[3]]]] # d
|
24 |
+
assert test(case2z) == [[[[0]]]] # a==b
|
25 |
+
else:
|
26 |
+
assert test(case1) == [[[[4]]]] # d
|
27 |
+
assert test(case2) == [[[[1]]]] # a==b
|
28 |
+
|
29 |
+
assert test(case3) == [[[[1]]]] # a==b
|
30 |
+
assert test(case4) == [[[[2]]]] # b==c
|
31 |
+
assert test(case5) == [[[[5]]]] # a==b
|
32 |
+
|
33 |
+
assert test(is_255_handled) == [[[[255]]]]
|
34 |
+
|
35 |
+
assert fn(case1).dtype == case1.dtype
|
36 |
+
|
37 |
+
test_all_cases(countless2d.simplest_countless, False)
|
38 |
+
test_all_cases(countless2d.quick_countless, False)
|
39 |
+
test_all_cases(countless2d.quickest_countless, False)
|
40 |
+
test_all_cases(countless2d.stippled_countless, False)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
methods = [
|
45 |
+
countless2d.zero_corrected_countless,
|
46 |
+
countless2d.countless,
|
47 |
+
countless2d.countless_if,
|
48 |
+
# countless2d.counting, # counting doesn't respect order so harder to write a test
|
49 |
+
]
|
50 |
+
|
51 |
+
for fn in methods:
|
52 |
+
print(fn.__name__)
|
53 |
+
test_all_cases(fn, True)
|
54 |
+
|
55 |
+
def test_stippled_countless2d():
|
56 |
+
a = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1))
|
57 |
+
b = np.array([ [ 0, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1))
|
58 |
+
c = np.array([ [ 1, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1))
|
59 |
+
d = np.array([ [ 1, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1))
|
60 |
+
e = np.array([ [ 1, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1))
|
61 |
+
f = np.array([ [ 0, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1))
|
62 |
+
g = np.array([ [ 0, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1))
|
63 |
+
  h = np.array([ [ 0, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1))
  i = np.array([ [ 1, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1))
  j = np.array([ [ 1, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1))
  k = np.array([ [ 1, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1))
  l = np.array([ [ 1, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1))
  m = np.array([ [ 0, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1))
  n = np.array([ [ 0, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1))
  o = np.array([ [ 0, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1))
  z = np.array([ [ 0, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1))

  test = countless2d.stippled_countless

  # Note: We only tested non-matching cases above,
  # cases f,g,h,i,j,k prove their duals work as well
  # b/c if two pixels are black, either one can be chosen
  # if they are different or the same.

  assert test(a) == [[[[4]]]]
  assert test(b) == [[[[4]]]]
  assert test(c) == [[[[4]]]]
  assert test(d) == [[[[4]]]]
  assert test(e) == [[[[1]]]]
  assert test(f) == [[[[4]]]]
  assert test(g) == [[[[4]]]]
  assert test(h) == [[[[2]]]]
  assert test(i) == [[[[4]]]]
  assert test(j) == [[[[1]]]]
  assert test(k) == [[[[1]]]]
  assert test(l) == [[[[1]]]]
  assert test(m) == [[[[2]]]]
  assert test(n) == [[[[3]]]]
  assert test(o) == [[[[4]]]]
  assert test(z) == [[[[0]]]]

  bc = np.array([ [ 0, 2 ], [ 2, 4 ] ]).reshape((2,2,1,1))
  bd = np.array([ [ 0, 2 ], [ 3, 2 ] ]).reshape((2,2,1,1))
  cd = np.array([ [ 0, 2 ], [ 3, 3 ] ]).reshape((2,2,1,1))

  assert test(bc) == [[[[2]]]]
  assert test(bd) == [[[[2]]]]
  assert test(cd) == [[[[3]]]]

  ab = np.array([ [ 1, 1 ], [ 0, 4 ] ]).reshape((2,2,1,1))
  ac = np.array([ [ 1, 2 ], [ 1, 0 ] ]).reshape((2,2,1,1))
  ad = np.array([ [ 1, 0 ], [ 3, 1 ] ]).reshape((2,2,1,1))

  assert test(ab) == [[[[1]]]]
  assert test(ac) == [[[[1]]]]
  assert test(ad) == [[[[1]]]]

def test_countless3d():
  def test_all_cases(fn):
    alldifferent = [
      [
        [1,2],
        [3,4],
      ],
      [
        [5,6],
        [7,8]
      ]
    ]
    allsame = [
      [
        [1,1],
        [1,1],
      ],
      [
        [1,1],
        [1,1]
      ]
    ]

    assert fn(np.array(alldifferent)) == [[[8]]]
    assert fn(np.array(allsame)) == [[[1]]]

    twosame = deepcopy(alldifferent)
    twosame[1][1][0] = 2

    assert fn(np.array(twosame)) == [[[2]]]

    threemixed = [
      [
        [3,3],
        [1,2],
      ],
      [
        [2,4],
        [4,3]
      ]
    ]
    assert fn(np.array(threemixed)) == [[[3]]]

    foursame = [
      [
        [4,4],
        [1,2],
      ],
      [
        [2,4],
        [4,3]
      ]
    ]

    assert fn(np.array(foursame)) == [[[4]]]

    fivesame = [
      [
        [5,4],
        [5,5],
      ],
      [
        [2,4],
        [5,5]
      ]
    ]

    assert fn(np.array(fivesame)) == [[[5]]]

  def countless3d_generalized(img):
    return countless3d.countless_generalized(img, (2,2,2))

  def countless3d_dynamic_generalized(img):
    return countless3d.dynamic_countless_generalized(img, (2,2,2))

  methods = [
    countless3d.countless3d,
    countless3d.dynamic_countless3d,
    countless3d_generalized,
    countless3d_dynamic_generalized,
  ]

  for fn in methods:
    test_all_cases(fn)

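The tests above pin down the COUNTLESS downsampling rule: each 2x2 patch reduces to its most frequent value, falling back to the bottom-right pixel when no two pixels agree, with the "stippled" and "zero corrected" variants handling the label 0 correctly. The sketch below (not part of the diff) shows that behaviour with the `zero_corrected_countless` function that `mask.py` later imports; it assumes the repository root is importable as `saicinpainting`.

# Minimal sketch, not part of the diff: 2x2 COUNTLESS downsampling on a tiny label map.
# Assumes the repo is on PYTHONPATH so saicinpainting.evaluation.masks.countless imports.
import numpy as np
from saicinpainting.evaluation.masks.countless import countless2d

# Two of the four pixels carry label 2, so the patch collapses to 2.
patch = np.array([[0, 2],
                  [3, 2]])
print(countless2d.zero_corrected_countless(patch))  # -> [[2]]
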
saicinpainting/evaluation/masks/mask.py
ADDED
@@ -0,0 +1,429 @@
import enum
from copy import deepcopy

import numpy as np
from skimage import img_as_ubyte
from skimage.transform import rescale, resize
try:
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultPredictor
    DETECTRON_INSTALLED = True
except:
    print("Detectron v2 is not installed")
    DETECTRON_INSTALLED = False

from .countless.countless2d import zero_corrected_countless


class ObjectMask():
    def __init__(self, mask):
        self.height, self.width = mask.shape
        (self.up, self.down), (self.left, self.right) = self._get_limits(mask)
        self.mask = mask[self.up:self.down, self.left:self.right].copy()

    @staticmethod
    def _get_limits(mask):
        def indicator_limits(indicator):
            lower = indicator.argmax()
            upper = len(indicator) - indicator[::-1].argmax()
            return lower, upper

        vertical_indicator = mask.any(axis=1)
        vertical_limits = indicator_limits(vertical_indicator)

        horizontal_indicator = mask.any(axis=0)
        horizontal_limits = indicator_limits(horizontal_indicator)

        return vertical_limits, horizontal_limits

    def _clean(self):
        self.up, self.down, self.left, self.right = 0, 0, 0, 0
        self.mask = np.empty((0, 0))

    def horizontal_flip(self, inplace=False):
        if not inplace:
            flipped = deepcopy(self)
            return flipped.horizontal_flip(inplace=True)

        self.mask = self.mask[:, ::-1]
        return self

    def vertical_flip(self, inplace=False):
        if not inplace:
            flipped = deepcopy(self)
            return flipped.vertical_flip(inplace=True)

        self.mask = self.mask[::-1, :]
        return self

    def image_center(self):
        y_center = self.up + (self.down - self.up) / 2
        x_center = self.left + (self.right - self.left) / 2
        return y_center, x_center

    def rescale(self, scaling_factor, inplace=False):
        if not inplace:
            scaled = deepcopy(self)
            return scaled.rescale(scaling_factor, inplace=True)

        scaled_mask = rescale(self.mask.astype(float), scaling_factor, order=0) > 0.5
        (up, down), (left, right) = self._get_limits(scaled_mask)
        self.mask = scaled_mask[up:down, left:right]

        y_center, x_center = self.image_center()
        mask_height, mask_width = self.mask.shape
        self.up = int(round(y_center - mask_height / 2))
        self.down = self.up + mask_height
        self.left = int(round(x_center - mask_width / 2))
        self.right = self.left + mask_width
        return self

    def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False):
        if not inplace:
            cropped = deepcopy(self)
            cropped.crop_to_canvas(vertical=vertical, horizontal=horizontal, inplace=True)
            return cropped

        if vertical:
            if self.up >= self.height or self.down <= 0:
                self._clean()
            else:
                cut_up, cut_down = max(-self.up, 0), max(self.down - self.height, 0)
                if cut_up != 0:
                    self.mask = self.mask[cut_up:]
                    self.up = 0
                if cut_down != 0:
                    self.mask = self.mask[:-cut_down]
                    self.down = self.height

        if horizontal:
            if self.left >= self.width or self.right <= 0:
                self._clean()
            else:
                cut_left, cut_right = max(-self.left, 0), max(self.right - self.width, 0)
                if cut_left != 0:
                    self.mask = self.mask[:, cut_left:]
                    self.left = 0
                if cut_right != 0:
                    self.mask = self.mask[:, :-cut_right]
                    self.right = self.width

        return self

    def restore_full_mask(self, allow_crop=False):
        cropped = self.crop_to_canvas(inplace=allow_crop)
        mask = np.zeros((cropped.height, cropped.width), dtype=bool)
        mask[cropped.up:cropped.down, cropped.left:cropped.right] = cropped.mask
        return mask

    def shift(self, vertical=0, horizontal=0, inplace=False):
        if not inplace:
            shifted = deepcopy(self)
            return shifted.shift(vertical=vertical, horizontal=horizontal, inplace=True)

        self.up += vertical
        self.down += vertical
        self.left += horizontal
        self.right += horizontal
        return self

    def area(self):
        return self.mask.sum()


class RigidnessMode(enum.Enum):
    soft = 0
    rigid = 1


class SegmentationMask:
    def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid,
                 max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4,
                 max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5,
                 max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True,
                 max_vertical_shift=0.1, position_shuffle=True):
        """
        :param confidence_threshold: float; confidence threshold of the panoptic segmenter for an instance to be kept.
        :param rigidness_mode: RigidnessMode object
            when soft, checks intersection only with the object from which the mask_object was produced
            when rigid, checks intersection with any foreground class object
        :param max_object_area: float; upper bound on the object area for it to be considered a mask_object.
        :param min_mask_area: float; lower bound on mask area for the mask to be considered valid
        :param downsample_levels: int; defines width of the resized segmentation used to obtain shifted masks;
        :param num_variants_per_mask: int; maximal number of masks for the same object;
        :param max_mask_intersection: float; maximum allowed area fraction of intersection for 2 masks
            produced by horizontal shift of the same mask_object; higher value -> more diversity
        :param max_foreground_coverage: float; maximum allowed area fraction of a foreground object that may be
            covered by the mask; lower value -> objects are covered less
        :param max_foreground_intersection: float; maximum allowed area of intersection for the mask with a foreground
            object; lower value -> mask lies more on the background than on the objects
        :param max_hidden_area: upper bound on the part of the object hidden by shifting the object outside the screen area;
        :param max_scale_change: allowed scale change for the mask_object;
        :param horizontal_flip: whether horizontal flips are allowed;
        :param max_vertical_shift: amount of vertical movement allowed;
        :param position_shuffle: whether to shuffle the candidate horizontal positions of the mask
        """

        assert DETECTRON_INSTALLED, 'Cannot use SegmentationMask without detectron2'
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
        self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
        self.predictor = DefaultPredictor(self.cfg)

        self.rigidness_mode = RigidnessMode(rigidness_mode)
        self.max_object_area = max_object_area
        self.min_mask_area = min_mask_area
        self.downsample_levels = downsample_levels
        self.num_variants_per_mask = num_variants_per_mask
        self.max_mask_intersection = max_mask_intersection
        self.max_foreground_coverage = max_foreground_coverage
        self.max_foreground_intersection = max_foreground_intersection
        self.max_hidden_area = max_hidden_area
        self.position_shuffle = position_shuffle

        self.max_scale_change = max_scale_change
        self.horizontal_flip = horizontal_flip
        self.max_vertical_shift = max_vertical_shift

    def get_segmentation(self, img):
        im = img_as_ubyte(img)
        panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"]
        return panoptic_seg, segment_info

    @staticmethod
    def _is_power_of_two(n):
        return (n != 0) and (n & (n-1) == 0)

    def identify_candidates(self, panoptic_seg, segments_info):
        potential_mask_ids = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = (panoptic_seg == segment["id"]).int().detach().cpu().numpy()
            area = mask.sum().item() / np.prod(panoptic_seg.shape)
            if area >= self.max_object_area:
                continue
            potential_mask_ids.append(segment["id"])
        return potential_mask_ids

    def downsample_mask(self, mask):
        height, width = mask.shape
        if not (self._is_power_of_two(height) and self._is_power_of_two(width)):
            raise ValueError("Image sides are not power of 2.")

        num_iterations = width.bit_length() - 1 - self.downsample_levels
        if num_iterations < 0:
            raise ValueError(f"Width is lower than 2^{self.downsample_levels}.")

        if height.bit_length() - 1 < num_iterations:
            raise ValueError("Height is too low to perform downsampling")

        downsampled = mask
        for _ in range(num_iterations):
            downsampled = zero_corrected_countless(downsampled)

        return downsampled

    def _augmentation_params(self):
        scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change)
        if self.horizontal_flip:
            horizontal_flip = bool(np.random.choice(2))
        else:
            horizontal_flip = False
        vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift)

        return {
            "scaling_factor": scaling_factor,
            "horizontal_flip": horizontal_flip,
            "vertical_shift": vertical_shift
        }

    def _get_intersection(self, mask_array, mask_object):
        intersection = mask_array[
            mask_object.up:mask_object.down, mask_object.left:mask_object.right
        ] & mask_object.mask
        return intersection

    def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks):
        for existing_mask in prev_masks:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area
            if (intersection_existing > self.max_mask_intersection) or \
                    (intersection_current > self.max_mask_intersection):
                return False
        return True

    def _check_foreground_intersection(self, aug_mask, foreground):
        for existing_mask in foreground:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            if intersection_existing > self.max_foreground_coverage:
                return False
            intersection_mask = intersection_area / aug_mask.area()
            if intersection_mask > self.max_foreground_intersection:
                return False
        return True

    def _move_mask(self, mask, foreground):
        # Obtaining properties of the original mask_object:
        orig_mask = ObjectMask(mask)

        chosen_masks = []
        chosen_parameters = []
        # to fix the case when resizing gives mask_object consisting only of False
        scaling_factor_lower_bound = 0.

        for var_idx in range(self.num_variants_per_mask):
            # Obtaining augmentation parameters and applying them to the downscaled mask_object
            augmentation_params = self._augmentation_params()
            augmentation_params["scaling_factor"] = min([
                augmentation_params["scaling_factor"],
                2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1.,
                2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1.
            ])
            augmentation_params["scaling_factor"] = max([
                augmentation_params["scaling_factor"], scaling_factor_lower_bound
            ])

            aug_mask = deepcopy(orig_mask)
            aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True)
            if augmentation_params["horizontal_flip"]:
                aug_mask.horizontal_flip(inplace=True)
            total_aug_area = aug_mask.area()
            if total_aug_area == 0:
                scaling_factor_lower_bound = 1.
                continue

            # Fix if the element vertical shift is too strong and shown area is too small:
            vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area  # share of area taken by rows
            # number of rows which are allowed to be hidden from upper and lower parts of image respectively
            max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area)
            max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area)
            # correcting vertical shift, so not too much area will be hidden
            augmentation_params["vertical_shift"] = np.clip(
                augmentation_params["vertical_shift"],
                -(aug_mask.up + max_hidden_up) / aug_mask.height,
                (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height
            )
            # Applying vertical shift:
            vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"]))
            aug_mask.shift(vertical=vertical_shift, inplace=True)
            aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True)

            # Choosing horizontal shift:
            max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area)
            horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area
            max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area)
            max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area)
            allowed_shifts = np.arange(-max_hidden_left, aug_mask.width -
                                       (aug_mask.right - aug_mask.left) + max_hidden_right + 1)
            allowed_shifts = - (aug_mask.left - allowed_shifts)

            if self.position_shuffle:
                np.random.shuffle(allowed_shifts)

            mask_is_found = False
            for horizontal_shift in allowed_shifts:
                aug_mask_left = deepcopy(aug_mask)
                aug_mask_left.shift(horizontal=horizontal_shift, inplace=True)
                aug_mask_left.crop_to_canvas(inplace=True)

                prev_masks = [mask] + chosen_masks
                is_mask_suitable = self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks) & \
                                   self._check_foreground_intersection(aug_mask_left, foreground)
                if is_mask_suitable:
                    aug_draw = aug_mask_left.restore_full_mask()
                    chosen_masks.append(aug_draw)
                    augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width
                    chosen_parameters.append(augmentation_params)
                    mask_is_found = True
                    break

            if not mask_is_found:
                break

        return chosen_parameters

    def _prepare_mask(self, mask):
        height, width = mask.shape
        target_width = width if self._is_power_of_two(width) else (1 << width.bit_length())
        target_height = height if self._is_power_of_two(height) else (1 << height.bit_length())

        return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32')

    def get_masks(self, im, return_panoptic=False):
        panoptic_seg, segments_info = self.get_segmentation(im)
        potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info)

        panoptic_seg_scaled = self._prepare_mask(panoptic_seg.detach().cpu().numpy())
        downsampled = self.downsample_mask(panoptic_seg_scaled)
        scene_objects = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = downsampled == segment["id"]
            if not np.any(mask):
                continue
            scene_objects.append(mask)

        mask_set = []
        for mask_id in potential_mask_ids:
            mask = downsampled == mask_id
            if not np.any(mask):
                continue

            if self.rigidness_mode is RigidnessMode.soft:
                foreground = [mask]
            elif self.rigidness_mode is RigidnessMode.rigid:
                foreground = scene_objects
            else:
                raise ValueError(f'Unexpected rigidness_mode: {self.rigidness_mode}')

            masks_params = self._move_mask(mask, foreground)

            full_mask = ObjectMask((panoptic_seg == mask_id).detach().cpu().numpy())

            for params in masks_params:
                aug_mask = deepcopy(full_mask)
                aug_mask.rescale(params["scaling_factor"], inplace=True)
                if params["horizontal_flip"]:
                    aug_mask.horizontal_flip(inplace=True)

                vertical_shift = int(round(aug_mask.height * params["vertical_shift"]))
                horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"]))
                aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True)
                aug_mask = aug_mask.restore_full_mask().astype('uint8')
                if aug_mask.mean() <= self.min_mask_area:
                    continue
                mask_set.append(aug_mask)

        if return_panoptic:
            return mask_set, panoptic_seg.detach().cpu().numpy()
        else:
            return mask_set


def propose_random_square_crop(mask, min_overlap=0.5):
    height, width = mask.shape
    mask_ys, mask_xs = np.where(mask > 0.5)  # mask==0 is known fragment and mask==1 is missing

    if height < width:
        crop_size = height
        obj_left, obj_right = mask_xs.min(), mask_xs.max()
        obj_width = obj_right - obj_left
        left_border = max(0, min(width - crop_size - 1, obj_left + obj_width * min_overlap - crop_size))
        right_border = max(left_border + 1, min(width - crop_size, obj_left + obj_width * min_overlap))
        start_x = np.random.randint(left_border, right_border)
        return start_x, 0, start_x + crop_size, height
    else:
        crop_size = width
        obj_top, obj_bottom = mask_ys.min(), mask_ys.max()
        obj_height = obj_bottom - obj_top
        top_border = max(0, min(height - crop_size - 1, obj_top + obj_height * min_overlap - crop_size))
        bottom_border = max(top_border + 1, min(height - crop_size, obj_top + obj_height * min_overlap))
        start_y = np.random.randint(top_border, bottom_border)
        return 0, start_y, width, start_y + crop_size

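ObjectMask keeps only a tight bounding box plus the cropped boolean mask, so shift, rescale and flip work on the small crop and restore_full_mask pastes the result back onto the full canvas. The sketch below (not part of the diff) walks through that round trip on a synthetic mask; the array values are made up for illustration, and only SegmentationMask needs detectron2, ObjectMask itself does not.

# Minimal sketch, not part of the diff: ObjectMask crop -> shift -> restore round trip.
# Assumes the repo is importable as `saicinpainting`; detectron2 is optional here.
import numpy as np
from saicinpainting.evaluation.masks.mask import ObjectMask

canvas = np.zeros((8, 8), dtype=bool)
canvas[2:5, 3:6] = True                          # a 3x3 object on an 8x8 canvas

obj = ObjectMask(canvas)
print(obj.up, obj.down, obj.left, obj.right)     # -> 2 5 3 6 (tight bounding box)
print(obj.area())                                # -> 9

shifted = obj.shift(vertical=2, horizontal=-1)   # inplace=False, so this is a copy
restored = shifted.restore_full_mask()           # paste the moved crop back on the canvas
print(restored.shape, restored.sum())            # -> (8, 8) 9  (object moved, area kept)
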
saicinpainting/evaluation/refinement.py
ADDED
@@ -0,0 +1,314 @@
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from kornia.filters import gaussian_blur2d
from kornia.geometry.transform import resize
from kornia.morphology import erosion
from torch.nn import functional as F
import numpy as np
import cv2

from saicinpainting.evaluation.data import pad_tensor_to_modulo
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.modules.ffc import FFCResnetBlock
from saicinpainting.training.modules.pix2pixhd import ResnetBlock

from tqdm import tqdm


def _pyrdown(im : torch.Tensor, downsize : tuple=None):
    """downscale the image"""
    if downsize is None:
        downsize = (im.shape[2]//2, im.shape[3]//2)
    assert im.shape[1] == 3, "Expected shape for the input to be (n,3,height,width)"
    im = gaussian_blur2d(im, kernel_size=(5,5), sigma=(1.0,1.0))
    im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False)
    return im


def _pyrdown_mask(mask : torch.Tensor, downsize : tuple=None, eps : float=1e-8, blur_mask : bool=True, round_up : bool=True):
    """downscale the mask tensor

    Parameters
    ----------
    mask : torch.Tensor
        mask of size (B, 1, H, W)
    downsize : tuple, optional
        size to downscale to. If None, image is downscaled to half, by default None
    eps : float, optional
        threshold value for binarizing the mask, by default 1e-8
    blur_mask : bool, optional
        if True, apply gaussian filter before downscaling, by default True
    round_up : bool, optional
        if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True

    Returns
    -------
    torch.Tensor
        downscaled mask
    """

    if downsize is None:
        downsize = (mask.shape[2]//2, mask.shape[3]//2)
    assert mask.shape[1] == 1, "Expected shape for the input to be (n,1,height,width)"
    if blur_mask is True:
        mask = gaussian_blur2d(mask, kernel_size=(5,5), sigma=(1.0,1.0))
        mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False)
    else:
        mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False)
    if round_up:
        mask[mask>=eps] = 1
        mask[mask<eps] = 0
    else:
        mask[mask>=1.0-eps] = 1
        mask[mask<1.0-eps] = 0
    return mask


def _erode_mask(mask : torch.Tensor, ekernel : torch.Tensor=None, eps : float=1e-8):
    """erode the mask, and set gray pixels to 0"""
    if ekernel is not None:
        mask = erosion(mask, ekernel)
        mask[mask>=1.0-eps] = 1
        mask[mask<1.0-eps] = 0
    return mask


def _l1_loss(
    pred : torch.Tensor, pred_downscaled : torch.Tensor, ref : torch.Tensor,
    mask : torch.Tensor, mask_downscaled : torch.Tensor,
    image : torch.Tensor, on_pred : bool=True
):
    """l1 loss on src pixels, and downscaled predictions if on_pred=True"""
    loss = torch.mean(torch.abs(pred[mask<1e-8] - image[mask<1e-8]))
    if on_pred:
        loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))
    return loss


def _infer(
    image : torch.Tensor, mask : torch.Tensor,
    forward_front : nn.Module, forward_rears : nn.Module,
    ref_lower_res : torch.Tensor, orig_shape : tuple, devices : list,
    scale_ind : int, n_iters : int=15, lr : float=0.002):
    """Performs inference with refinement at a given scale.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W)
    forward_front : nn.Module
        the front part of the inpainting network
    forward_rears : nn.Module
        the rear part of the inpainting network
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices
    scale_ind : int
        the scale index
    n_iters : int, optional
        number of iterations of refinement, by default 15
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image
    """
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    mask = mask.repeat(1,3,1,1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        z1,z2 = forward_front(masked_image)
    # Inference
    mask = mask.to(devices[-1])
    ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])
    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
    z1.requires_grad, z2.requires_grad = True, True

    optimizer = Adam([z1,z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
        input_feat = (z1,z2)
        for idd, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if idd < len(devices) - 1:
                midz1, midz2 = output_feat
                midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
                input_feat = (midz1, midz2)
            else:
                pred = output_feat

        if ref_lower_res is None:
            break
        losses = {}
        ######################### multi-scale #############################
        # scaled loss with downsampler
        pred_downscaled = _pyrdown(pred[:,:,:orig_shape[0],:orig_shape[1]])
        mask_downscaled = _pyrdown_mask(mask[:,:1,:orig_shape[0],:orig_shape[1]], blur_mask=False, round_up=False)
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1,3,1,1)
        losses["ms_l1"] = _l1_loss(pred, pred_downscaled, ref_lower_res, mask, mask_downscaled, image, on_pred=True)

        loss = sum(losses.values())
        pbar.set_description("Refining scale {} using scale {} ...current loss: {:.4f}".format(scale_ind+1, scale_ind, loss.item()))
        if idi < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred
    # "pred" is the prediction after Plug-n-Play module
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted


def _get_image_mask_pyramid(batch : dict, min_side : int, max_scales : int, px_budget : int):
    """Build the image mask pyramid

    Parameters
    ----------
    batch : dict
        batch containing image, mask, etc
    min_side : int
        minimum side length to limit the number of scales of the pyramid
    max_scales : int
        maximum number of scales allowed
    px_budget : int
        the product H*W cannot exceed this budget, because of resource constraints

    Returns
    -------
    tuple
        image-mask pyramid in the form of list of images and list of masks
    """

    assert batch['image'].shape[0] == 1, "refiner works on only batches of size 1!"

    h, w = batch['unpad_to_size']
    h, w = h[0].item(), w[0].item()

    image = batch['image'][...,:h,:w]
    mask = batch['mask'][...,:h,:w]
    if h*w > px_budget:
        # resize
        ratio = np.sqrt(px_budget / float(h*w))
        h_orig, w_orig = h, w
        h, w = int(h*ratio), int(w*ratio)
        print(f"Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...")
        image = resize(image, (h,w), interpolation='bilinear', align_corners=False)
        mask = resize(mask, (h,w), interpolation='bilinear', align_corners=False)
        mask[mask>1e-8] = 1
    breadth = min(h,w)
    n_scales = min(1 + int(round(max(0,np.log2(breadth / min_side)))), max_scales)
    ls_images = []
    ls_masks = []

    ls_images.append(image)
    ls_masks.append(mask)

    for _ in range(n_scales - 1):
        image_p = _pyrdown(ls_images[-1])
        mask_p = _pyrdown_mask(ls_masks[-1])
        ls_images.append(image_p)
        ls_masks.append(mask_p)
    # reverse the lists because we want the lowest resolution image as index 0
    return ls_images[::-1], ls_masks[::-1]


def refine_predict(
    batch : dict, inpainter : nn.Module, gpu_ids : str,
    modulo : int, n_iters : int, lr : float, min_side : int,
    max_scales : int, px_budget : int
):
    """Refines the inpainting of the network

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : nn.Module
        the inpainting neural network
    gpu_ids : str
        the GPU ids of the machine to use. If only single GPU, use: "0,"
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W)
    """

    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    gpu_ids = [f'cuda:{gpuid}' for gpuid in gpu_ids.replace(" ","").split(",") if gpuid.isdigit()]
    n_resnet_blocks = 0
    first_resblock_ind = 0
    found_first_resblock = False
    for idl in range(len(inpainter.generator.model)):
        if isinstance(inpainter.generator.model[idl], FFCResnetBlock) or isinstance(inpainter.generator.model[idl], ResnetBlock):
            n_resnet_blocks += 1
            found_first_resblock = True
        elif not found_first_resblock:
            first_resblock_ind += 1
    resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)

    devices = [torch.device(gpu_id) for gpu_id in gpu_ids]

    # split the model into front, and rear parts
    forward_front = inpainter.generator.model[0:first_resblock_ind]
    forward_front.to(devices[0])
    forward_rears = []
    for idd in range(len(gpu_ids)):
        if idd < len(gpu_ids) - 1:
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):first_resblock_ind+resblocks_per_gpu*(idd+1)])
        else:
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):])
        forward_rears[idd].to(devices[idd])

    ls_images, ls_masks = _get_image_mask_pyramid(
        batch,
        min_side,
        max_scales,
        px_budget
    )
    image_inpainted = None

    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = move_to_device(image, devices[0]), move_to_device(mask, devices[0])
        if image_inpainted is not None:
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(image, mask, forward_front, forward_rears, image_inpainted, orig_shape, devices, ids, n_iters, lr)
        image_inpainted = image_inpainted[:,:,:orig_shape[0], :orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted

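The refiner builds a Gaussian image/mask pyramid, runs the network front once, and then optimizes the intermediate features with Adam at each scale so that the downscaled prediction agrees with the inpainting from the previous, coarser scale. The sketch below (not part of the diff) only exercises the pyramid helpers on dummy tensors; `_pyrdown` and `_pyrdown_mask` are module-internal helpers, and the import assumes the full repo plus kornia, cv2 and tqdm are installed.

# Minimal sketch, not part of the diff: the pyramid helpers used by the refiner,
# applied to dummy tensors. Shapes are made up for illustration.
import torch
from saicinpainting.evaluation.refinement import _pyrdown, _pyrdown_mask

image = torch.rand(1, 3, 512, 512)          # (B, 3, H, W) image in [0, 1]
mask = torch.zeros(1, 1, 512, 512)
mask[:, :, 200:300, 200:300] = 1.0          # a rectangular hole

image_half = _pyrdown(image)                # Gaussian blur + bilinear downscale
mask_half = _pyrdown_mask(mask)             # blur, downscale, then re-binarize

print(image_half.shape, mask_half.shape)    # -> [1, 3, 256, 256] for both (1 channel for the mask)
print(torch.unique(mask_half))              # -> tensor([0., 1.]), still a binary mask
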
saicinpainting/evaluation/utils.py
ADDED
@@ -0,0 +1,28 @@
from enum import Enum

import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch


def load_yaml(path):
    with open(path, 'r') as f:
        return edict(yaml.safe_load(f))


def move_to_device(obj, device):
    if isinstance(obj, nn.Module):
        return obj.to(device)
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, (tuple, list)):
        return [move_to_device(el, device) for el in obj]
    if isinstance(obj, dict):
        return {name: move_to_device(val, device) for name, val in obj.items()}
    raise ValueError(f'Unexpected type {type(obj)}')


class SmallMode(Enum):
    DROP = "drop"
    UPSCALE = "upscale"

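move_to_device walks modules, tensors, tuples, lists and dicts recursively, so a whole batch dict can be moved in a single call, which is how refine_predict above uses it. A minimal sketch (not part of the diff), using the CPU device so it runs anywhere:

# Minimal sketch, not part of the diff: moving a nested batch with move_to_device.
# 'cpu' keeps the example hardware-independent; 'cuda:0' works the same way on a GPU box.
import torch
from saicinpainting.evaluation.utils import move_to_device

batch = {
    'image': torch.rand(1, 3, 64, 64),
    'mask': torch.rand(1, 1, 64, 64),
    'unpad_to_size': [torch.tensor([64]), torch.tensor([64])],
}
batch = move_to_device(batch, torch.device('cpu'))
print(batch['image'].device)   # -> cpu; every tensor in the nested structure was moved
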
saicinpainting/evaluation/vis.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np
from skimage import io
from skimage.segmentation import mark_boundaries


def save_item_for_vis(item, out_file):
    mask = item['mask'] > 0.5
    if mask.ndim == 3:
        mask = mask[0]
    img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)),
                          mask,
                          color=(1., 0., 0.),
                          outline_color=(1., 1., 1.),
                          mode='thick')

    if 'inpainted' in item:
        inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)),
                                  mask,
                                  color=(1., 0., 0.),
                                  mode='outer')
        img = np.concatenate((img, inp_img), axis=1)

    img = np.clip(img * 255, 0, 255).astype('uint8')
    io.imsave(out_file, img)


def save_mask_for_sidebyside(item, out_file):
    mask = item['mask']  # > 0.5
    if mask.ndim == 3:
        mask = mask[0]
    mask = np.clip(mask * 255, 0, 255).astype('uint8')
    io.imsave(out_file, mask)


def save_img_for_sidebyside(item, out_file):
    img = np.transpose(item['image'], (1, 2, 0))
    img = np.clip(img * 255, 0, 255).astype('uint8')
    io.imsave(out_file, img)

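save_item_for_vis expects CHW float images in [0, 1], draws the mask boundary on the input, and, when an 'inpainted' key is present, writes the input and the result side by side. The sketch below (not part of the diff) feeds it a synthetic item; the output path is a placeholder.

# Minimal sketch, not part of the diff: one visualization written from random data.
import numpy as np
from saicinpainting.evaluation.vis import save_item_for_vis

item = {
    'image': np.random.rand(3, 64, 64).astype('float32'),
    'mask': np.zeros((1, 64, 64), dtype='float32'),
    'inpainted': np.random.rand(3, 64, 64).astype('float32'),
}
item['mask'][:, 16:48, 16:48] = 1.0               # a square hole
save_item_for_vis(item, '/tmp/vis_example.png')   # original | inpainted, with mask boundary
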
saicinpainting/training/__init__.py
ADDED
File without changes

saicinpainting/training/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes). View file

saicinpainting/training/data/__init__.py
ADDED
File without changes

saicinpainting/training/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (157 Bytes). View file

saicinpainting/training/data/__pycache__/aug.cpython-39.pyc
ADDED
Binary file (3.16 kB). View file

saicinpainting/training/data/__pycache__/datasets.cpython-39.pyc
ADDED
Binary file (8.94 kB). View file

saicinpainting/training/data/__pycache__/masks.cpython-39.pyc
ADDED
Binary file (11.9 kB). View file