update repo
- .DS_Store +0 -0
- .gitattributes +2 -0
- dnnlib/__pycache__/util.cpython-38.pyc +0 -0
- id_loss.py +1 -1
- metrics/__init__.py +0 -9
- metrics/frechet_inception_distance.py +0 -41
- metrics/inception_score.py +0 -38
- metrics/kernel_inception_distance.py +0 -46
- metrics/metric_main.py +0 -152
- metrics/metric_utils.py +0 -275
- metrics/perceptual_path_length.py +0 -131
- metrics/precision_recall.py +0 -62
- pretrained/.DS_Store +0 -0
- pretrained/ffhq.pkl +3 -0
- pretrained/metfaces.pkl +3 -0
- model_ir_se50.pth → pretrained/model_ir_se50.pth +0 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
.gitattributes
CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.pth* filter=lfs diff=lfs merge=lfs -text
 filter=lfs diff=lfs merge=lfs -text
+*.pkl* filter=lfs diff=lfs merge=lfs -text
+filter=lfs diff=lfs merge=lfs -text

dnnlib/__pycache__/util.cpython-38.pyc
CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-38.pyc and b/dnnlib/__pycache__/util.cpython-38.pyc differ
id_loss.py
CHANGED
@@ -15,7 +15,7 @@ class IDLoss(nn.Module):
         super(IDLoss, self).__init__()
         print('Loading ResNet ArcFace')
         self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
-        self.facenet.load_state_dict(torch.load("model_ir_se50.pth", map_location=device))
+        self.facenet.load_state_dict(torch.load("./pretrained/model_ir_se50.pth", map_location=device))
         self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
         self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
         self.facenet.eval()
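The only functional change above is the checkpoint path: the ArcFace weights now live under pretrained/. Since the committed line still resolves "./pretrained/..." against the current working directory, here is a sketch of a more defensive variant (hypothetical helper, not part of this commit):

```python
# Hypothetical alternative to the changed line: resolve the checkpoint
# relative to this source file rather than the working directory, so the
# Space still starts when launched from elsewhere.
import os
import torch

WEIGHTS = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                       'pretrained', 'model_ir_se50.pth')
state_dict = torch.load(WEIGHTS, map_location='cpu')
# then: self.facenet.load_state_dict(state_dict)
```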
metrics/__init__.py
DELETED
@@ -1,9 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-# empty

metrics/frechet_inception_distance.py
DELETED
@@ -1,41 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Frechet Inception Distance (FID) from the paper
-"GANs trained by a two time-scale update rule converge to a local Nash
-equilibrium". Matches the original implementation by Heusel et al. at
-https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
-
-import numpy as np
-import scipy.linalg
-from . import metric_utils
-
-#----------------------------------------------------------------------------
-
-def compute_fid(opts, max_real, num_gen):
-    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
-
-    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
-
-    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
-
-    if opts.rank != 0:
-        return float('nan')
-
-    m = np.square(mu_gen - mu_real).sum()
-    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
-    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
-    return float(fid)
-
-#----------------------------------------------------------------------------

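For reference, the deleted compute_fid evaluates the closed-form Fréchet distance between Gaussians fitted to Inception features; in the variable names used above:

```latex
\mathrm{FID} = \lVert \mu_{\mathrm{gen}} - \mu_{\mathrm{real}} \rVert_{2}^{2}
             + \operatorname{Tr}\!\left( \Sigma_{\mathrm{gen}} + \Sigma_{\mathrm{real}}
             - 2 \left( \Sigma_{\mathrm{gen}} \Sigma_{\mathrm{real}} \right)^{1/2} \right)
```

The code's `m` is the first term and `s` is the matrix square root in the trace.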
metrics/inception_score.py
DELETED
@@ -1,38 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Inception Score (IS) from the paper "Improved techniques for training
-GANs". Matches the original implementation by Salimans et al. at
-https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
-
-import numpy as np
-from . import metric_utils
-
-#----------------------------------------------------------------------------
-
-def compute_is(opts, num_gen, num_splits):
-    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
-
-    gen_probs = metric_utils.compute_feature_stats_for_generator(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        capture_all=True, max_items=num_gen).get_all()
-
-    if opts.rank != 0:
-        return float('nan'), float('nan')
-
-    scores = []
-    for i in range(num_splits):
-        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
-        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
-        kl = np.mean(np.sum(kl, axis=1))
-        scores.append(np.exp(kl))
-    return float(np.mean(scores)), float(np.std(scores))
-
-#----------------------------------------------------------------------------

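Compactly, each split in the loop above computes:

```latex
\mathrm{IS} = \exp\!\Bigl( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\bigl( p(y \mid x) \,\Vert\, p(y) \bigr) \Bigr)
```

where p(y|x) are the softmax class probabilities for a generated image and p(y) is their mean over the split; the function returns the mean and standard deviation across num_splits splits.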
metrics/kernel_inception_distance.py
DELETED
@@ -1,46 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
-GANs". Matches the original implementation by Binkowski et al. at
-https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
-
-import numpy as np
-from . import metric_utils
-
-#----------------------------------------------------------------------------
-
-def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
-    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
-
-    real_features = metric_utils.compute_feature_stats_for_dataset(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
-
-    gen_features = metric_utils.compute_feature_stats_for_generator(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
-
-    if opts.rank != 0:
-        return float('nan')
-
-    n = real_features.shape[1]
-    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
-    t = 0
-    for _subset_idx in range(num_subsets):
-        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
-        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
-        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
-        b = (x @ y.T / n + 1) ** 3
-        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
-    kid = t / num_subsets / m
-    return float(kid)
-
-#----------------------------------------------------------------------------

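In the notation of the loop above (n features, random subsets of size m), the deleted function averages the unbiased MMD² estimator with a cubic polynomial kernel over num_subsets subsets:

```latex
k(x, y) = \left( \tfrac{1}{n}\, x^{\mathsf{T}} y + 1 \right)^{3},
\qquad
\widehat{\mathrm{KID}} = \frac{1}{m(m-1)} \sum_{i \neq j} \bigl[ k(x_i, x_j) + k(y_i, y_j) \bigr]
                       - \frac{2}{m^{2}} \sum_{i,j} k(x_i, y_j)
```

The `a` term carries the within-set kernels (diagonal removed for unbiasedness) and `b` the cross-set kernels.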
metrics/metric_main.py
DELETED
@@ -1,152 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import os
-import time
-import json
-import torch
-import dnnlib
-
-from . import metric_utils
-from . import frechet_inception_distance
-from . import kernel_inception_distance
-from . import precision_recall
-from . import perceptual_path_length
-from . import inception_score
-
-#----------------------------------------------------------------------------
-
-_metric_dict = dict() # name => fn
-
-def register_metric(fn):
-    assert callable(fn)
-    _metric_dict[fn.__name__] = fn
-    return fn
-
-def is_valid_metric(metric):
-    return metric in _metric_dict
-
-def list_valid_metrics():
-    return list(_metric_dict.keys())
-
-#----------------------------------------------------------------------------
-
-def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
-    assert is_valid_metric(metric)
-    opts = metric_utils.MetricOptions(**kwargs)
-
-    # Calculate.
-    start_time = time.time()
-    results = _metric_dict[metric](opts)
-    total_time = time.time() - start_time
-
-    # Broadcast results.
-    for key, value in list(results.items()):
-        if opts.num_gpus > 1:
-            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
-            torch.distributed.broadcast(tensor=value, src=0)
-            value = float(value.cpu())
-        results[key] = value
-
-    # Decorate with metadata.
-    return dnnlib.EasyDict(
-        results         = dnnlib.EasyDict(results),
-        metric          = metric,
-        total_time      = total_time,
-        total_time_str  = dnnlib.util.format_time(total_time),
-        num_gpus        = opts.num_gpus,
-    )
-
-#----------------------------------------------------------------------------
-
-def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
-    metric = result_dict['metric']
-    assert is_valid_metric(metric)
-    if run_dir is not None and snapshot_pkl is not None:
-        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
-
-    jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
-    print(jsonl_line)
-    if run_dir is not None and os.path.isdir(run_dir):
-        with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
-            f.write(jsonl_line + '\n')
-
-#----------------------------------------------------------------------------
-# Primary metrics.
-
-@register_metric
-def fid50k_full(opts):
-    opts.dataset_kwargs.update(max_size=None, xflip=False)
-    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
-    return dict(fid50k_full=fid)
-
-@register_metric
-def kid50k_full(opts):
-    opts.dataset_kwargs.update(max_size=None, xflip=False)
-    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
-    return dict(kid50k_full=kid)
-
-@register_metric
-def pr50k3_full(opts):
-    opts.dataset_kwargs.update(max_size=None, xflip=False)
-    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
-    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
-
-@register_metric
-def ppl2_wend(opts):
-    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
-    return dict(ppl2_wend=ppl)
-
-@register_metric
-def is50k(opts):
-    opts.dataset_kwargs.update(max_size=None, xflip=False)
-    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
-    return dict(is50k_mean=mean, is50k_std=std)
-
-#----------------------------------------------------------------------------
-# Legacy metrics.
-
-@register_metric
-def fid50k(opts):
-    opts.dataset_kwargs.update(max_size=None)
-    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
-    return dict(fid50k=fid)
-
-@register_metric
-def kid50k(opts):
-    opts.dataset_kwargs.update(max_size=None)
-    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
-    return dict(kid50k=kid)
-
-@register_metric
-def pr50k3(opts):
-    opts.dataset_kwargs.update(max_size=None)
-    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
-    return dict(pr50k3_precision=precision, pr50k3_recall=recall)
-
-@register_metric
-def ppl_zfull(opts):
-    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
-    return dict(ppl_zfull=ppl)
-
-@register_metric
-def ppl_wfull(opts):
-    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
-    return dict(ppl_wfull=ppl)
-
-@register_metric
-def ppl_zend(opts):
-    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
-    return dict(ppl_zend=ppl)
-
-@register_metric
-def ppl_wend(opts):
-    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
-    return dict(ppl_wend=ppl)
-
-#----------------------------------------------------------------------------

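For context, a sketch of how this registry was driven in upstream stylegan2-ada-pytorch before the commit removed it; `G` (a trained generator) and `dataset_kwargs` are assumed to come from a training run and are not defined here:

```python
# Illustrative only -- this commit deletes the module from the Space.
import torch
from metrics import metric_main

result = metric_main.calc_metric(
    metric='fid50k_full',                 # any name in list_valid_metrics()
    G=G, dataset_kwargs=dataset_kwargs,   # assumed from a training run
    num_gpus=1, rank=0, device=torch.device('cuda'))
metric_main.report_metric(result, run_dir='./runs/00000', snapshot_pkl=None)
```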
metrics/metric_utils.py
DELETED
@@ -1,275 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import os
-import time
-import hashlib
-import pickle
-import copy
-import uuid
-import numpy as np
-import torch
-import dnnlib
-
-#----------------------------------------------------------------------------
-
-class MetricOptions:
-    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
-        assert 0 <= rank < num_gpus
-        self.G = G
-        self.G_kwargs = dnnlib.EasyDict(G_kwargs)
-        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
-        self.num_gpus = num_gpus
-        self.rank = rank
-        self.device = device if device is not None else torch.device('cuda', rank)
-        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
-        self.cache = cache
-
-#----------------------------------------------------------------------------
-
-_feature_detector_cache = dict()
-
-def get_feature_detector_name(url):
-    return os.path.splitext(url.split('/')[-1])[0]
-
-def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
-    assert 0 <= rank < num_gpus
-    key = (url, device)
-    if key not in _feature_detector_cache:
-        is_leader = (rank == 0)
-        if not is_leader and num_gpus > 1:
-            torch.distributed.barrier() # leader goes first
-        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
-            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
-        if is_leader and num_gpus > 1:
-            torch.distributed.barrier() # others follow
-    return _feature_detector_cache[key]
-
-#----------------------------------------------------------------------------
-
-class FeatureStats:
-    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
-        self.capture_all = capture_all
-        self.capture_mean_cov = capture_mean_cov
-        self.max_items = max_items
-        self.num_items = 0
-        self.num_features = None
-        self.all_features = None
-        self.raw_mean = None
-        self.raw_cov = None
-
-    def set_num_features(self, num_features):
-        if self.num_features is not None:
-            assert num_features == self.num_features
-        else:
-            self.num_features = num_features
-            self.all_features = []
-            self.raw_mean = np.zeros([num_features], dtype=np.float64)
-            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
-
-    def is_full(self):
-        return (self.max_items is not None) and (self.num_items >= self.max_items)
-
-    def append(self, x):
-        x = np.asarray(x, dtype=np.float32)
-        assert x.ndim == 2
-        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
-            if self.num_items >= self.max_items:
-                return
-            x = x[:self.max_items - self.num_items]
-
-        self.set_num_features(x.shape[1])
-        self.num_items += x.shape[0]
-        if self.capture_all:
-            self.all_features.append(x)
-        if self.capture_mean_cov:
-            x64 = x.astype(np.float64)
-            self.raw_mean += x64.sum(axis=0)
-            self.raw_cov += x64.T @ x64
-
-    def append_torch(self, x, num_gpus=1, rank=0):
-        assert isinstance(x, torch.Tensor) and x.ndim == 2
-        assert 0 <= rank < num_gpus
-        if num_gpus > 1:
-            ys = []
-            for src in range(num_gpus):
-                y = x.clone()
-                torch.distributed.broadcast(y, src=src)
-                ys.append(y)
-            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
-        self.append(x.cpu().numpy())
-
-    def get_all(self):
-        assert self.capture_all
-        return np.concatenate(self.all_features, axis=0)
-
-    def get_all_torch(self):
-        return torch.from_numpy(self.get_all())
-
-    def get_mean_cov(self):
-        assert self.capture_mean_cov
-        mean = self.raw_mean / self.num_items
-        cov = self.raw_cov / self.num_items
-        cov = cov - np.outer(mean, mean)
-        return mean, cov
-
-    def save(self, pkl_file):
-        with open(pkl_file, 'wb') as f:
-            pickle.dump(self.__dict__, f)
-
-    @staticmethod
-    def load(pkl_file):
-        with open(pkl_file, 'rb') as f:
-            s = dnnlib.EasyDict(pickle.load(f))
-        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
-        obj.__dict__.update(s)
-        return obj
-
-#----------------------------------------------------------------------------
-
-class ProgressMonitor:
-    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
-        self.tag = tag
-        self.num_items = num_items
-        self.verbose = verbose
-        self.flush_interval = flush_interval
-        self.progress_fn = progress_fn
-        self.pfn_lo = pfn_lo
-        self.pfn_hi = pfn_hi
-        self.pfn_total = pfn_total
-        self.start_time = time.time()
-        self.batch_time = self.start_time
-        self.batch_items = 0
-        if self.progress_fn is not None:
-            self.progress_fn(self.pfn_lo, self.pfn_total)
-
-    def update(self, cur_items):
-        assert (self.num_items is None) or (cur_items <= self.num_items)
-        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
-            return
-        cur_time = time.time()
-        total_time = cur_time - self.start_time
-        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
-        if (self.verbose) and (self.tag is not None):
-            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
-        self.batch_time = cur_time
-        self.batch_items = cur_items
-
-        if (self.progress_fn is not None) and (self.num_items is not None):
-            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
-
-    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
-        return ProgressMonitor(
-            tag             = tag,
-            num_items       = num_items,
-            flush_interval  = flush_interval,
-            verbose         = self.verbose,
-            progress_fn     = self.progress_fn,
-            pfn_lo          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
-            pfn_hi          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
-            pfn_total       = self.pfn_total,
-        )
-
-#----------------------------------------------------------------------------
-
-def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
-    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
-    if data_loader_kwargs is None:
-        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
-
-    # Try to lookup from cache.
-    cache_file = None
-    if opts.cache:
-        # Choose cache file name.
-        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
-        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
-        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
-        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
-
-        # Check if the file exists (all processes must agree).
-        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
-        if opts.num_gpus > 1:
-            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
-            torch.distributed.broadcast(tensor=flag, src=0)
-            flag = (float(flag.cpu()) != 0)
-
-        # Load.
-        if flag:
-            return FeatureStats.load(cache_file)
-
-    # Initialize.
-    num_items = len(dataset)
-    if max_items is not None:
-        num_items = min(num_items, max_items)
-    stats = FeatureStats(max_items=num_items, **stats_kwargs)
-    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
-    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
-
-    # Main loop.
-    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
-    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
-        if images.shape[1] == 1:
-            images = images.repeat([1, 3, 1, 1])
-        features = detector(images.to(opts.device), **detector_kwargs)
-        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
-        progress.update(stats.num_items)
-
-    # Save to cache.
-    if cache_file is not None and opts.rank == 0:
-        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
-        temp_file = cache_file + '.' + uuid.uuid4().hex
-        stats.save(temp_file)
-        os.replace(temp_file, cache_file) # atomic
-    return stats
-
-#----------------------------------------------------------------------------
-
-def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
-    if batch_gen is None:
-        batch_gen = min(batch_size, 4)
-    assert batch_size % batch_gen == 0
-
-    # Setup generator and load labels.
-    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
-    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
-
-    # Image generation func.
-    def run_generator(z, c):
-        img = G(z=z, c=c, **opts.G_kwargs)
-        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-        return img
-
-    # JIT.
-    if jit:
-        z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
-        c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
-        run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
-
-    # Initialize.
-    stats = FeatureStats(**stats_kwargs)
-    assert stats.max_items is not None
-    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
-    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
-
-    # Main loop.
-    while not stats.is_full():
-        images = []
-        for _i in range(batch_size // batch_gen):
-            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
-            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
-            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
-            images.append(run_generator(z, c))
-        images = torch.cat(images)
-        if images.shape[1] == 1:
-            images = images.repeat([1, 3, 1, 1])
-        features = detector(images, **detector_kwargs)
-        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
-        progress.update(stats.num_items)
-    return stats
-
-#----------------------------------------------------------------------------

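The FeatureStats class above never materializes the full feature matrix when only mean and covariance are needed: it accumulates raw first and second moments batch by batch and recovers cov = E[xxᵀ] − μμᵀ at the end. A minimal self-contained sketch of that trick:

```python
# Streaming mean/covariance, mirroring FeatureStats.append / get_mean_cov.
import numpy as np

rng = np.random.default_rng(0)
raw_mean = np.zeros(4)
raw_cov = np.zeros((4, 4))
num_items = 0
batches = [rng.normal(size=(32, 4)) for _ in range(10)]
for x in batches:
    raw_mean += x.sum(axis=0)   # accumulate sum(x)
    raw_cov += x.T @ x          # accumulate sum(x x^T)
    num_items += x.shape[0]
mean = raw_mean / num_items
cov = raw_cov / num_items - np.outer(mean, mean)
# Matches the two-pass estimate up to floating-point error:
assert np.allclose(cov, np.cov(np.concatenate(batches).T, bias=True))
```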
metrics/perceptual_path_length.py
DELETED
@@ -1,131 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
-Architecture for Generative Adversarial Networks". Matches the original
-implementation by Karras et al. at
-https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
-
-import copy
-import numpy as np
-import torch
-import dnnlib
-from . import metric_utils
-
-#----------------------------------------------------------------------------
-
-# Spherical interpolation of a batch of vectors.
-def slerp(a, b, t):
-    a = a / a.norm(dim=-1, keepdim=True)
-    b = b / b.norm(dim=-1, keepdim=True)
-    d = (a * b).sum(dim=-1, keepdim=True)
-    p = t * torch.acos(d)
-    c = b - d * a
-    c = c / c.norm(dim=-1, keepdim=True)
-    d = a * torch.cos(p) + c * torch.sin(p)
-    d = d / d.norm(dim=-1, keepdim=True)
-    return d
-
-#----------------------------------------------------------------------------
-
-class PPLSampler(torch.nn.Module):
-    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
-        assert space in ['z', 'w']
-        assert sampling in ['full', 'end']
-        super().__init__()
-        self.G = copy.deepcopy(G)
-        self.G_kwargs = G_kwargs
-        self.epsilon = epsilon
-        self.space = space
-        self.sampling = sampling
-        self.crop = crop
-        self.vgg16 = copy.deepcopy(vgg16)
-
-    def forward(self, c):
-        # Generate random latents and interpolation t-values.
-        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
-        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
-
-        # Interpolate in W or Z.
-        if self.space == 'w':
-            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
-            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
-            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
-        else: # space == 'z'
-            zt0 = slerp(z0, z1, t.unsqueeze(1))
-            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
-            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
-
-        # Randomize noise buffers.
-        for name, buf in self.G.named_buffers():
-            if name.endswith('.noise_const'):
-                buf.copy_(torch.randn_like(buf))
-
-        # Generate images.
-        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
-
-        # Center crop.
-        if self.crop:
-            assert img.shape[2] == img.shape[3]
-            c = img.shape[2] // 8
-            img = img[:, :, c*3 : c*7, c*2 : c*6]
-
-        # Downsample to 256x256.
-        factor = self.G.img_resolution // 256
-        if factor > 1:
-            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
-
-        # Scale dynamic range from [-1,1] to [0,255].
-        img = (img + 1) * (255 / 2)
-        if self.G.img_channels == 1:
-            img = img.repeat([1, 3, 1, 1])
-
-        # Evaluate differential LPIPS.
-        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
-        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
-        return dist
-
-#----------------------------------------------------------------------------
-
-def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
-    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
-    vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
-    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
-
-    # Setup sampler.
-    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
-    sampler.eval().requires_grad_(False).to(opts.device)
-    if jit:
-        c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
-        sampler = torch.jit.trace(sampler, [c], check_trace=False)
-
-    # Sampling loop.
-    dist = []
-    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
-    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
-        progress.update(batch_start)
-        c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
-        c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
-        x = sampler(c)
-        for src in range(opts.num_gpus):
-            y = x.clone()
-            if opts.num_gpus > 1:
-                torch.distributed.broadcast(y, src=src)
-            dist.append(y)
-    progress.update(num_samples)
-
-    # Compute PPL.
-    if opts.rank != 0:
-        return float('nan')
-    dist = torch.cat(dist)[:num_samples].cpu().numpy()
-    lo = np.percentile(dist, 1, interpolation='lower')
-    hi = np.percentile(dist, 99, interpolation='higher')
-    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
-    return float(ppl)
-
-#----------------------------------------------------------------------------

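Compactly, the deleted module estimates:

```latex
\mathrm{PPL} = \mathbb{E}\!\left[ \frac{1}{\epsilon^{2}}\,
  d_{\mathrm{LPIPS}}\bigl( G(f(t)),\; G(f(t + \epsilon)) \bigr) \right]
```

where f is lerp in W-space or slerp in Z-space, t ~ U[0,1] for sampling='full' or t = 0 for 'end', and the bottom and top 1% of distances are trimmed before averaging, exactly as in the percentile filtering above.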
metrics/precision_recall.py
DELETED
@@ -1,62 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto.  Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Precision/Recall (PR) from the paper "Improved Precision and Recall
-Metric for Assessing Generative Models". Matches the original implementation
-by Kynkaanniemi et al. at
-https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
-
-import torch
-from . import metric_utils
-
-#----------------------------------------------------------------------------
-
-def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
-    assert 0 <= rank < num_gpus
-    num_cols = col_features.shape[0]
-    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
-    col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
-    dist_batches = []
-    for col_batch in col_batches[rank :: num_gpus]:
-        dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
-        for src in range(num_gpus):
-            dist_broadcast = dist_batch.clone()
-            if num_gpus > 1:
-                torch.distributed.broadcast(dist_broadcast, src=src)
-            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
-    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
-
-#----------------------------------------------------------------------------
-
-def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
-    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
-    detector_kwargs = dict(return_features=True)
-
-    real_features = metric_utils.compute_feature_stats_for_dataset(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
-
-    gen_features = metric_utils.compute_feature_stats_for_generator(
-        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
-        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
-
-    results = dict()
-    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
-        kth = []
-        for manifold_batch in manifold.split(row_batch_size):
-            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
-            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
-        kth = torch.cat(kth) if opts.rank == 0 else None
-        pred = []
-        for probes_batch in probes.split(row_batch_size):
-            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
-            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
-        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
-    return results['precision'], results['recall']
-
-#----------------------------------------------------------------------------

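A toy, single-GPU sketch of the k-NN manifold test that compute_pr performs above (random stand-in features; nhood_size=3 as in the pr50k3 metrics):

```python
import torch

torch.manual_seed(0)
manifold = torch.randn(500, 16)   # e.g. real features (precision case)
probes = torch.randn(400, 16)     # e.g. generated features
nhood_size = 3

# Radius of each manifold point = distance to its (k+1)-th neighbour,
# since the nearest neighbour within the manifold is the point itself.
d_mm = torch.cdist(manifold, manifold)
radii = d_mm.kthvalue(nhood_size + 1, dim=1).values

# A probe counts as covered if it lies inside at least one hypersphere.
d_pm = torch.cdist(probes, manifold)
covered = (d_pm <= radii.unsqueeze(0)).any(dim=1)
print(float(covered.float().mean()))  # precision-style score
```

Swapping the roles of the two feature sets, as the loop above does, yields recall.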
pretrained/.DS_Store
ADDED
Binary file (6.15 kB).
pretrained/ffhq.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a205a346e86a9ddaae702e118097d014b7b8bd719491396a162cca438f2f524c
+size 381624121

pretrained/metfaces.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:880a460d011a3696c088f58f5844b44271b17903963f2671f96f72dfbce5f76f
+size 381624133

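Both .pkl entries are Git LFS pointer stubs rather than the actual pickles (note the .gitattributes change above, which routes *.pkl* through LFS). A small sketch for detecting a stub before unpickling, assuming only the pointer format visible in this diff:

```python
def is_lfs_pointer(path: str) -> bool:
    # LFS pointers are tiny text files whose first line names the spec.
    with open(path, 'rb') as f:
        first_line = f.readline(100)
    return first_line.startswith(b'version https://git-lfs.github.com/spec/')

# e.g. guard a loader: if is_lfs_pointer('pretrained/ffhq.pkl') is True,
# fetch the real blobs (e.g. via `git lfs pull`) instead of pickle-loading
# a ~130-byte stub.
```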
model_ir_se50.pth → pretrained/model_ir_se50.pth
RENAMED
File without changes