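# APISR / loss / anime_perceptual_loss.py
# Last commit by HikariDawn: 561c629 "feat: initial push" (file size 22.1 kB).
# (Header lines recovered from the hosting web page's file viewer: "raw", "history", "blame".)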
import os, sys
from collections import OrderedDict
import cv2
import torch.nn as nn
import torch
from torchvision import models
import torchvision.transforms as transforms
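'''
Reference: layer-by-layer summary (torchsummary-style printout) of the full
Danbooru ResNet50 tagger for a 3x224x224 input -- the ResNet50 backbone followed
by the fastai head (AdaptiveConcatPool2d -> Flatten -> BatchNorm1d/Dropout/Linear
-> 6000 tag logits). This module keeps only the backbone, i.e. every layer up to
and including Bottleneck-172; the head rows are listed for reference only.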
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 4,096
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 16,384
BatchNorm2d-12 [-1, 256, 56, 56] 512
Conv2d-13 [-1, 256, 56, 56] 16,384
BatchNorm2d-14 [-1, 256, 56, 56] 512
ReLU-15 [-1, 256, 56, 56] 0
Bottleneck-16 [-1, 256, 56, 56] 0
Conv2d-17 [-1, 64, 56, 56] 16,384
BatchNorm2d-18 [-1, 64, 56, 56] 128
ReLU-19 [-1, 64, 56, 56] 0
Conv2d-20 [-1, 64, 56, 56] 36,864
BatchNorm2d-21 [-1, 64, 56, 56] 128
ReLU-22 [-1, 64, 56, 56] 0
Conv2d-23 [-1, 256, 56, 56] 16,384
BatchNorm2d-24 [-1, 256, 56, 56] 512
ReLU-25 [-1, 256, 56, 56] 0
Bottleneck-26 [-1, 256, 56, 56] 0
Conv2d-27 [-1, 64, 56, 56] 16,384
BatchNorm2d-28 [-1, 64, 56, 56] 128
ReLU-29 [-1, 64, 56, 56] 0
Conv2d-30 [-1, 64, 56, 56] 36,864
BatchNorm2d-31 [-1, 64, 56, 56] 128
ReLU-32 [-1, 64, 56, 56] 0
Conv2d-33 [-1, 256, 56, 56] 16,384
BatchNorm2d-34 [-1, 256, 56, 56] 512
ReLU-35 [-1, 256, 56, 56] 0
Bottleneck-36 [-1, 256, 56, 56] 0
Conv2d-37 [-1, 128, 56, 56] 32,768
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
Conv2d-40 [-1, 128, 28, 28] 147,456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
Conv2d-43 [-1, 512, 28, 28] 65,536
BatchNorm2d-44 [-1, 512, 28, 28] 1,024
Conv2d-45 [-1, 512, 28, 28] 131,072
BatchNorm2d-46 [-1, 512, 28, 28] 1,024
ReLU-47 [-1, 512, 28, 28] 0
Bottleneck-48 [-1, 512, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 65,536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 512, 28, 28] 65,536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
Bottleneck-58 [-1, 512, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 65,536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 147,456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 512, 28, 28] 65,536
BatchNorm2d-66 [-1, 512, 28, 28] 1,024
ReLU-67 [-1, 512, 28, 28] 0
Bottleneck-68 [-1, 512, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 65,536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 128, 28, 28] 147,456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
Conv2d-75 [-1, 512, 28, 28] 65,536
BatchNorm2d-76 [-1, 512, 28, 28] 1,024
ReLU-77 [-1, 512, 28, 28] 0
Bottleneck-78 [-1, 512, 28, 28] 0
Conv2d-79 [-1, 256, 28, 28] 131,072
BatchNorm2d-80 [-1, 256, 28, 28] 512
ReLU-81 [-1, 256, 28, 28] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
Conv2d-85 [-1, 1024, 14, 14] 262,144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
Bottleneck-90 [-1, 1024, 14, 14] 0
Conv2d-91 [-1, 256, 14, 14] 262,144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
Conv2d-94 [-1, 256, 14, 14] 589,824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
Bottleneck-100 [-1, 1024, 14, 14] 0
Conv2d-101 [-1, 256, 14, 14] 262,144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
Conv2d-104 [-1, 256, 14, 14] 589,824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
Bottleneck-110 [-1, 1024, 14, 14] 0
Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
Bottleneck-120 [-1, 1024, 14, 14] 0
Conv2d-121 [-1, 256, 14, 14] 262,144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
Conv2d-124 [-1, 256, 14, 14] 589,824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
Bottleneck-130 [-1, 1024, 14, 14] 0
Conv2d-131 [-1, 256, 14, 14] 262,144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
Bottleneck-140 [-1, 1024, 14, 14] 0
Conv2d-141 [-1, 512, 14, 14] 524,288
BatchNorm2d-142 [-1, 512, 14, 14] 1,024
ReLU-143 [-1, 512, 14, 14] 0
Conv2d-144 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-145 [-1, 512, 7, 7] 1,024
ReLU-146 [-1, 512, 7, 7] 0
Conv2d-147 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
Conv2d-149 [-1, 2048, 7, 7] 2,097,152
BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
ReLU-151 [-1, 2048, 7, 7] 0
Bottleneck-152 [-1, 2048, 7, 7] 0
Conv2d-153 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-154 [-1, 512, 7, 7] 1,024
ReLU-155 [-1, 512, 7, 7] 0
Conv2d-156 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-157 [-1, 512, 7, 7] 1,024
ReLU-158 [-1, 512, 7, 7] 0
Conv2d-159 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-160 [-1, 2048, 7, 7] 4,096
ReLU-161 [-1, 2048, 7, 7] 0
Bottleneck-162 [-1, 2048, 7, 7] 0
Conv2d-163 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-164 [-1, 512, 7, 7] 1,024
ReLU-165 [-1, 512, 7, 7] 0
Conv2d-166 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-167 [-1, 512, 7, 7] 1,024
ReLU-168 [-1, 512, 7, 7] 0
Conv2d-169 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-170 [-1, 2048, 7, 7] 4,096
ReLU-171 [-1, 2048, 7, 7] 0
Bottleneck-172 [-1, 2048, 7, 7] 0
AdaptiveMaxPool2d-173 [-1, 2048, 1, 1] 0
AdaptiveAvgPool2d-174 [-1, 2048, 1, 1] 0
AdaptiveConcatPool2d-175 [-1, 4096, 1, 1] 0
Flatten-176 [-1, 4096] 0
BatchNorm1d-177 [-1, 4096] 8,192
Dropout-178 [-1, 4096] 0
Linear-179 [-1, 512] 2,097,664
ReLU-180 [-1, 512] 0
BatchNorm1d-181 [-1, 512] 1,024
Dropout-182 [-1, 512] 0
Linear-183 [-1, 6000] 3,078,000
================================================================
Total params: 28,692,912
Trainable params: 28,692,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.75
Params size (MB): 109.45
Estimated Total Size (MB): 396.78
----------------------------------------------------------------
'''
class AdaptiveConcatPool2d(nn.Module):
    """
    Concatenate adaptive max pooling and adaptive average pooling along channels.

    Source: Fastai. This code was taken from the fastai library at url
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176

    Args:
        sz: target spatial size handed to both adaptive pools. Any falsy value
            (``None``, ``0``) falls back to 1, giving a 1x1 output.

    The result has twice the input's channel count: the max-pooled channels
    come first, followed by the average-pooled channels.
    """

    def __init__(self, sz=None):
        super().__init__()
        # A falsy size collapses to a single output cell.
        size = sz if sz else 1
        self.output_size = size
        self.ap = nn.AdaptiveAvgPool2d(size)
        self.mp = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        max_part = self.mp(x)
        avg_part = self.ap(x)
        return torch.cat((max_part, avg_part), dim=1)
class Flatten(nn.Module):
    """
    Flatten `x` to shape ``(batch, -1)``. Adapted from fastai's Flatten() layer,
    at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L25

    Uses ``reshape`` instead of ``view``: a non-contiguous input (e.g. after
    ``transpose``/``permute``) is flattened instead of raising a RuntimeError,
    while a contiguous input still yields a zero-copy view as before.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keep dim 0 (batch) and merge all remaining dims into one.
        return x.reshape(x.size(0), -1)
def bn_drop_lin(n_in: int, n_out: int, bn: bool = True, p: float = 0., actn=None):
    """
    Build the layer list [BatchNorm1d?] -> [Dropout?] -> Linear -> [activation?].

    Adapted from Fastai at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L44

    Args:
        n_in: input feature count (also the BatchNorm1d width).
        n_out: output feature count of the Linear layer.
        bn: prepend ``nn.BatchNorm1d(n_in)`` when True.
        p: dropout probability; a Dropout layer is added only when ``p != 0``.
        actn: optional module appended after the Linear layer (used as-is).

    Returns:
        list of modules in order, meant to be unpacked into ``nn.Sequential``.
    """
    norm_part = [nn.BatchNorm1d(n_in)] if bn else []
    drop_part = [nn.Dropout(p)] if p != 0 else []
    act_part = [] if actn is None else [actn]
    return [*norm_part, *drop_part, nn.Linear(n_in, n_out), *act_part]
def create_head(top_n_tags, nf, ps=0.5):
    """
    Build the fastai-style classification head of the Danbooru ResNet.

    Layout:
        AdaptiveConcatPool2d -> Flatten
        -> BatchNorm1d(nf) -> Dropout(ps / 2) -> Linear(nf, 512) -> ReLU
        -> BatchNorm1d(512) -> Dropout(ps) -> Linear(512, top_n_tags)

    Args:
        top_n_tags (int): number of output logits (one per tag).
        nf (int): input feature count after concat pooling, i.e. twice the
            backbone's final channel count.
        ps (float): dropout probability before the final layer; the hidden
            layer uses ``ps / 2`` (fastai's convention). The default 0.5 gives
            the previously hard-coded rates 0.25 and 0.5.

    Returns:
        nn.Sequential: the head.
    """
    lin_ftrs = [nf, 512, top_n_tags]
    p1 = ps / 2  # dropout before the hidden linear layer
    p2 = ps      # dropout before the final (logit) layer
    layers = [AdaptiveConcatPool2d(), Flatten()]
    layers += [
        *bn_drop_lin(lin_ftrs[0], lin_ftrs[1], True, p1, nn.ReLU(inplace=True)),
        *bn_drop_lin(lin_ftrs[1], lin_ftrs[2], True, p2),
    ]
    return nn.Sequential(*layers)
def _resnet(base_arch, top_n, **kwargs):
cut = -2
s = base_arch(pretrained=False, **kwargs)
body = nn.Sequential(*list(s.children())[:cut])
if base_arch in [models.resnet18, models.resnet34]:
num_features_model = 512
elif base_arch in [models.resnet50, models.resnet101]:
num_features_model = 2048
nf = num_features_model * 2
nc = top_n
# head = create_head(nc, nf)
model = body # nn.Sequential(body, head)
return model
def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs):
    r"""
    Resnet50 backbone trained on the full Danbooru2018 dataset's top 6000 tags.

    Only the convolutional body is returned (built by ``_resnet``); the
    tag-prediction head is not needed for feature extraction.

    Args:
        pretrained (bool): load the Danbooru pretrained weights into the model.
        progress (bool): show a progress bar while downloading the weights.
        top_n (int): pick to load the model for predicting the top `n` tags;
            currently only top_n=6000 is supported when ``pretrained`` is True.
        **kwargs: forwarded to ``torchvision.models.resnet50``.

    Returns:
        nn.Sequential: the ResNet50 body.

    Raises:
        ValueError: if ``pretrained`` is True and ``top_n`` is not 6000.
    """
    # Take Resnet without the head (we don't care about final FC layers)
    model = _resnet(models.resnet50, top_n, **kwargs)
    if not pretrained:
        return model
    if top_n != 6000:
        raise ValueError("Sorry, the resnet50 model only supports the top-6000 tags at the moment")

    state = torch.hub.load_state_dict_from_url(
        "https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
        progress=progress,
        # The model lives on CPU here; load tensors there too so a checkpoint
        # saved from GPU does not fail on CPU-only machines.
        map_location="cpu",
    )
    # The checkpoint keys are treated as "0.<...>" for the body and "1.<...>"
    # for the head (a Sequential(body, head) layout). Strip the body prefix so
    # keys match our bare body, and drop the head weights. Iterate over a
    # snapshot of the keys because the dict is modified in place.
    for old_key in list(state):
        if old_key.startswith("0."):
            state[old_key[2:]] = state.pop(old_key)
        elif old_key.startswith("1."):
            del state[old_key]
    model.load_state_dict(state)
    return model
class resnet50_Extractor(nn.Module):
    """ResNet50 network for feature extraction via forward hooks.

    Forward hooks are registered on the requested sub-modules of ``model``.
    Calling the extractor runs ``model`` on the (optionally normalized) input
    and returns the activations captured by those hooks.

    Layer labels:
        * ``"<name>"`` (no underscore): hook ``getattr(model, name)``, e.g.
          ``"0"`` for the first child of an ``nn.Sequential`` backbone.
        * ``"<i>_<j>_..._<attr>"``: index ``model[i][j]...`` with the integer
          parts, then take attribute ``attr``, e.g. ``"4_2_conv3"`` hooks
          ``model[4][2].conv3``.

    Args:
        model (nn.Module): backbone to extract features from.
        layer_labels (Iterable[str]): labels in the format described above.
        use_input_norm (bool): normalize the input with the ImageNet mean/std
            for images in range [0, 1]. Default: True.
        range_norm (bool): first map the input from [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): if False, freeze the backbone's parameters and
            keep it in eval mode (BatchNorm uses running statistics), even if
            ``.train()`` is called later. Default: False.
    """

    def get_activation(self, name):
        """Return a forward hook that stores the hooked module's output under ``name``."""
        def hook(model, input, output):
            # Must NOT detach: the perceptual loss back-propagates through these
            # activations to whatever produced the input (e.g. a generator).
            self.activation[name] = output
        return hook

    def __init__(self,
                 model,
                 layer_labels,
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False
                 ):
        super(resnet50_Extractor, self).__init__()
        self.model = model
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm
        self.layer_labels = layer_labels
        self.activation = {}  # label -> latest activation, written by the hooks
        self._frozen = not requires_grad

        # Hook every requested layer
        for layer_label in layer_labels:
            self._resolve_layer(layer_label).register_forward_hook(self.get_activation(layer_label))

        if self._frozen:
            self.model.eval()
            for param in self.parameters():
                param.requires_grad = False

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _resolve_layer(self, layer_label):
        """Map a layer label (see the class docstring) to the sub-module to hook."""
        *indices, attr = layer_label.split('_')
        layer = self.model
        for index in indices:
            # int() raises ValueError for a non-integer path component.
            layer = layer[int(index)]
        return getattr(layer, attr)

    def train(self, mode=True):
        """Set the training mode, but keep a frozen backbone in eval mode.

        Without this, ``.train()`` on this module (or any parent) would switch
        the backbone's BatchNorm layers to batch statistics and let them update
        their running stats, although the backbone is meant to stay fixed.
        """
        super().train(mode)
        if getattr(self, '_frozen', False):
            self.model.eval()
        return self

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
        Returns:
            dict[str, Tensor]: the activation of every label in ``layer_labels``.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        # The backbone's own output is discarded; running it fills
        # self.activation through the registered hooks.
        self.model(x)
        # Fresh dict, so the result survives the next call overwriting self.activation.
        return {layer_label: self.activation[layer_label] for layer_label in self.layer_labels}
class Anime_PerceptualLoss(nn.Module):
    """Anime perceptual loss on features of a Danbooru-pretrained ResNet50 backbone.

    Args:
        layer_weights (dict): maps a backbone layer label to the weight of its
            feature loss. Labels use ``resnet50_Extractor``'s format, e.g.
            ``{"0": 0.5, "4_2_conv3": 20}`` compares the output of backbone
            child 0 (weight 0.5) and of ``backbone[4][2].conv3`` (weight 20).
        perceptual_weight (float): If ``perceptual_weight > 0``, the perceptual
            loss is computed and multiplied by this weight; otherwise
            ``forward`` returns None. Default: 1.0.
        criterion (str): Criterion used for perceptual loss. Only 'l1' is
            supported. Default: 'l1'.

    Raises:
        NotImplementedError: for any criterion other than 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 perceptual_weight=1.0,
                 criterion='l1'):
        super(Anime_PerceptualLoss, self).__init__()
        # Validate before downloading the pretrained weights.
        if criterion != 'l1':
            raise NotImplementedError("We don't support such criterion loss in perceptual loss")
        model = resnet50()
        self.perceptual_weight = perceptual_weight
        self.layer_weights = layer_weights
        self.layer_labels = layer_weights.keys()
        extractor = resnet50_Extractor(model, self.layer_labels)
        # Use the GPU when available (as before), but fall back to CPU instead
        # of crashing on machines without CUDA.
        self.resnet50 = extractor.cuda() if torch.cuda.is_available() else extractor
        self.criterion = torch.nn.L1Loss()

    def forward(self, gen, gt):
        """Compute the weighted perceptual loss between ``gen`` and ``gt``.

        Args:
            gen (Tensor): Generated tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w); it is
                detached, so no gradient flows into it.
        Returns:
            Tensor | None: ``perceptual_weight * sum_k(layer_weights[k] *
            criterion(gen_k, gt_k))``, or None when ``perceptual_weight``
            is not > 0.
        """
        if not self.perceptual_weight > 0:
            # Perceptual term disabled: skip the (expensive) feature extraction.
            return None
        gen_features = self.resnet50(gen)
        gt_features = self.resnet50(gt.detach())
        percep_loss = sum(
            self.criterion(gen_features[k], gt_features[k]) * self.layer_weights[k]
            for k in gen_features
        )
        return percep_loss * self.perceptual_weight
if __name__ == "__main__":
    # Debug script: report the average raw (unweighted) criterion value per
    # feature layer between generated images and their ground truth.
    # NOTE(review): needs a CUDA device and the two dataset folders below.
    import collections

    import cv2
    import torchvision.transforms as transforms

    gen_dir = 'datasets/train_gen/'
    gt_dir = 'datasets/train_hr_anime_usm/'

    def load_image(path):
        """Read an image file as a (1, 3, H, W) RGB float tensor in [0, 1] on the GPU."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {path}")
        # cv2 loads BGR; the extractor's ImageNet mean/std are in RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Add the batch dimension expected by the backbone's BatchNorm2d layers.
        return transforms.ToTensor()(img).unsqueeze(0).cuda()

    loss = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20, "5_3_conv3": 30, "6_5_conv3": 1, "7_2_conv3": 1}).cuda()
    labels = list(loss.layer_labels)
    store = collections.defaultdict(list)
    with torch.no_grad():
        for img_name in sorted(os.listdir(gen_dir)):
            gen_features = loss.resnet50(load_image(os.path.join(gen_dir, img_name)))
            gt_features = loss.resnet50(load_image(os.path.join(gt_dir, img_name)))
            for label in labels:
                raw_comparison = loss.criterion(gen_features[label], gt_features[label])
                store[label].append(raw_comparison.item())

    for idx, label in enumerate(labels):
        values = store[label]
        if values:
            print(f"Average layer{idx} ({label}) has loss {sum(values) / len(values)}")