Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
from .nn import mean_flat | |
# Input images are expected to be in the range [-1, 1].
class VGG(nn.Module):
    """Feature-space (perceptual) loss computed with a frozen VGG-19 prefix.

    Both inputs are mapped from [-1, 1] to [0, 1], normalised with the
    ImageNet channel statistics, passed through a truncated copy of
    torchvision's pretrained VGG-19 ``features`` stack, and compared with a
    squared error that is reduced by ``mean_flat``.

    Args:
        conv_index: Selects how much of ``vgg19().features`` is kept.
            A value containing ``'22'`` keeps the first 8 layers; otherwise a
            value containing ``'54'`` keeps the first 35 layers. In
            torchvision's VGG-19 layout these prefixes presumably end at
            conv2_2 and conv5_4 respectively (before their ReLU) -- verify
            against the installed torchvision version.
        rgb_range: Scale of the images fed to the normalisation layer. It
            multiplies both the mean and the std; ``forward`` always feeds
            images in [0, 1], which matches the default of 1.

    Raises:
        ValueError: If ``conv_index`` contains neither ``'22'`` nor ``'54'``.
    """

    def __init__(self, conv_index='22', rgb_range=1):
        super(VGG, self).__init__()

        # Resolve the truncation depth before loading any weights so that an
        # unsupported ``conv_index`` fails immediately with a clear message,
        # rather than leaving ``self.vgg`` undefined until ``forward`` runs.
        if '22' in conv_index:
            num_layers = 8
        elif '54' in conv_index:
            num_layers = 35
        else:
            raise ValueError(
                "Unsupported conv_index {!r}: expected a value containing "
                "'22' or '54'.".format(conv_index)
            )

        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favour of ``weights=...``; kept for compatibility with
        # whatever version this project pins.
        vgg_features = models.vgg19(pretrained=True).features
        self.vgg = nn.Sequential(*list(vgg_features)[:num_layers])

        # ImageNet channel statistics, scaled to the configured value range.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)

        # The network is a fixed feature extractor; nothing here is trained.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        """Return the VGG feature-space squared error between ``sr`` and ``hr``.

        Args:
            sr: Predicted image batch in [-1, 1]; gradients flow back through
                this input.
            hr: Reference image batch in [-1, 1]; it is detached and its
                features are computed without gradient tracking.

        Returns:
            ``mean_flat`` applied to the element-wise squared feature
            difference (presumably one value per batch element -- confirm
            against ``mean_flat`` in ``.nn``).
        """
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        # Map [-1, 1] -> [0, 1], the range the normalisation layer expects.
        sr = (sr + 1.) / 2.
        hr = (hr + 1.) / 2.

        vgg_sr = _forward(sr)
        # The reference branch is a constant target: no graph is needed.
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = mean_flat((vgg_sr - vgg_hr) ** 2)
        return loss
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that shifts and scales the three RGB channels.

    The kernel is a per-channel diagonal ``1 / rgb_std`` and the bias is
    ``sign * rgb_range * rgb_mean / rgb_std``, so with the default
    ``sign=-1`` the layer computes ``(x - rgb_range * rgb_mean) / rgb_std``
    for each channel. Weight and bias are frozen.

    Args:
        rgb_range: Scalar multiplier applied to ``rgb_mean`` in the bias.
        rgb_mean: Per-channel mean (three values).
        rgb_std: Per-channel standard deviation (three values).
        sign: Sign of the bias term; ``-1`` subtracts the mean, ``+1`` adds
            it back.
    """

    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super().__init__(3, 3, kernel_size=1)

        channel_std = torch.Tensor(rgb_std)
        channel_mean = torch.Tensor(rgb_mean)

        # Identity kernel: each output channel reads only its own input
        # channel, divided by that channel's std.
        identity_kernel = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity_kernel / channel_std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * channel_mean / channel_std

        # This layer is a constant normalisation, never a learnable one.
        self.weight.requires_grad = False
        self.bias.requires_grad = False