|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import functools |
|
|
|
|
|
import kornia |
|
import torch.nn.functional as F |
|
import torchvision.models |
|
|
|
# Support two launch layouts: imported as part of the ``archs`` package, or run
# from inside that directory where the sibling modules are importable directly.
# Only an import failure triggers the fallback; any other error raised while
# loading the package modules is a real bug and is allowed to propagate.
try:
    import archs.arch_util as arch_util
    from archs.NAFBlock import *
except ImportError:
    import arch_util as arch_util
    # NOTE(review): presumably the star import supplies AmplitudeNet_skip and
    # SFNet used further down in this file - confirm against NAFBlock.
    from NAFBlock import *
|
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG-19 ``features`` split into five sequential stages.

    The stages cover feature-layer indices [0, 2), [2, 7), [7, 12), [12, 21)
    and [21, 30). Each layer keeps its original index as its sub-module name.

    Args:
        requires_grad: when False (the default) every parameter of the
            extractor is frozen.
    """

    # Half-open layer-index boundaries of the five stages.
    _STAGE_BOUNDS = (0, 2, 7, 12, 21, 30)

    def __init__(self, requires_grad=False):
        super().__init__()
        backbone = torchvision.models.vgg19(pretrained=True).features

        bounds = self._STAGE_BOUNDS
        for stage_no, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
            stage = torch.nn.Sequential()
            setattr(self, "slice{}".format(stage_no), stage)
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), backbone[layer_idx])

        if requires_grad:
            return
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the stages in order.

        Returns:
            A list with the output of each of the five stages, shallowest first.
        """
        activations = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
|
|
|
class VGGLoss(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(VGGLoss, self).__init__() |
|
self.vgg = VGG19().cuda() |
|
|
|
self.criterion = nn.L1Loss(reduction='sum') |
|
self.criterion2 = nn.L1Loss() |
|
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] |
|
|
|
def forward(self, x, y): |
|
|
|
x_vgg, y_vgg = self.vgg(x), self.vgg(y) |
|
|
|
loss = 0 |
|
for i in range(len(x_vgg)): |
|
|
|
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) |
|
|
|
|
|
return loss |
|
|
|
|
|
class FourNet(nn.Module):
    """Two-step network: Fourier-amplitude rescaling, then a masked encoder/decoder.

    ``forward`` first divides the FFT magnitude of the input by the output of
    ``AmpNet`` and transforms back to the spatial domain. The result and the
    original input are then fed to a convolutional encoder/decoder whose
    bottleneck blends two feature branches using the mask from ``get_mask``.

    Args:
        nf: number of feature channels used by the convolutional trunk.
    """

    def __init__(self, nf=64):
        super().__init__()

        self.AmpNet = nn.Sequential(AmplitudeNet_skip(8), nn.Sigmoid())

        self.nf = nf
        make_res_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

        def conv3x3(in_ch, out_ch, stride=1):
            # Every convolution in this model is 3x3 with padding 1 and a bias.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=True)

        # Encoder: the input is the 3-channel enhanced image concatenated with
        # the 3-channel original; two stride-2 convs downsample by 4 overall.
        self.conv_first_1 = conv3x3(3 * 2, nf)
        self.conv_first_2 = conv3x3(nf, nf, stride=2)
        self.conv_first_3 = conv3x3(nf, nf, stride=2)

        self.feature_extraction = arch_util.make_layer(make_res_block, 1)
        self.recon_trunk = arch_util.make_layer(make_res_block, 1)

        # Decoder: each conv sees features concatenated with a skip (2 * nf).
        self.upconv1 = conv3x3(nf * 2, nf * 4)
        self.upconv2 = conv3x3(nf * 2, nf * 4)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = conv3x3(nf * 2, nf)
        self.conv_last = conv3x3(nf, 3)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.transformer = SFNet(nf, n=4)
        self.recon_trunk_light = arch_util.make_layer(make_res_block, 6)

    def get_mask(self, dark):
        """Build a single-channel mask in [0, 1] from ``dark``.

        The mask is the ratio of the blurred grayscale image to the absolute
        difference between the grayscale image and its blurred version,
        normalised by the per-sample maximum and clamped.

        Args:
            dark: 4-D image batch; channels 0..2 are combined with weights
                0.299 / 0.587 / 0.114 (assumes RGB order - TODO confirm).

        Returns:
            Float tensor with one channel and the spatial size of ``dark``.
        """

        def to_gray(image):
            first, second, third = image[:, 0:1, :, :], image[:, 1:2, :, :], image[:, 2:3, :, :]
            return first * 0.299 + second * 0.587 + third * 0.114

        blurred = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
        gray = to_gray(dark)
        gray_blurred = to_gray(blurred)
        deviation = torch.abs(gray - gray_blurred)

        # Small constant keeps the ratio finite where blur changes nothing.
        ratio = gray_blurred / (deviation + 0.0001)

        n_batch, _, n_rows, n_cols = ratio.shape
        peak = ratio.view(n_batch, -1).max(dim=1)[0]
        peak = peak.view(n_batch, 1, 1, 1).repeat(1, 1, n_rows, n_cols)

        normalised = ratio / (peak + 0.0001)
        return torch.clamp(normalised, min=0, max=1.0).float()

    def forward(self, x):
        """Run the amplitude step followed by the masked encoder/decoder.

        Args:
            x: 4-D image batch with 3 channels.

        Returns:
            Tuple ``(out_noise, mag_image, x_center, mask_image)``:
            the final output cropped to the input height/width, the rescaled
            FFT magnitude, the amplitude-adjusted image and the mask. The last
            two keep any bottom/right padding that was added.
        """
        _, _, H, W = x.shape

        # Step 1: rescale the Fourier magnitude, keep the phase untouched.
        spectrum = torch.fft.fft2(x, norm='backward')
        mag_image = spectrum.abs()
        phase = spectrum.angle()
        amp_curve = self.AmpNet(x)
        mag_image = mag_image / (amp_curve + 0.00000001)
        enhanced_spectrum = torch.complex(
            mag_image * torch.cos(phase),
            mag_image * torch.sin(phase),
        )
        x_center = torch.fft.ifft2(enhanced_spectrum, s=(H, W), norm='backward').real

        # Pad bottom/right so both spatial sizes are multiples of 8.
        multiple = 2 ** 3
        pad_h = -H % multiple
        pad_w = -W % multiple
        if pad_h or pad_w:
            x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
            x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")

        # Step 2: encoder with skip features at full, 1/2 and 1/4 resolution.
        skip_full = self.lrelu(self.conv_first_1(torch.cat((x_center, x), dim=1)))
        skip_half = self.lrelu(self.conv_first_2(skip_full))
        skip_quarter = self.lrelu(self.conv_first_3(skip_half))

        fea = self.feature_extraction(skip_quarter)
        fea_light = self.recon_trunk_light(fea)

        # Mask is computed at image resolution and resized to the bottleneck.
        mask_image = self.get_mask(x_center)
        _, n_channels, feat_h, feat_w = fea.shape
        mask = F.interpolate(mask_image, size=[feat_h, feat_w], mode='nearest')

        fea_unfold = self.transformer(fea)

        # Blend: transformer branch where mask is low, residual branch where high.
        mask = mask.repeat(1, n_channels, 1, 1)
        blended = fea_unfold * (1 - mask) + fea_light * mask

        # Decoder: concatenate the matching skip, then upsample by 2, twice.
        out_noise = self.recon_trunk(blended)
        for upconv, skip in ((self.upconv1, skip_quarter), (self.upconv2, skip_half)):
            out_noise = torch.cat([out_noise, skip], dim=1)
            out_noise = self.lrelu(self.pixel_shuffle(upconv(out_noise)))
        out_noise = torch.cat([out_noise, skip_full], dim=1)
        out_noise = self.conv_last(self.lrelu(self.HRconv(out_noise)))

        # Residual connection to the (padded) input, then drop the padding.
        out_noise = (out_noise + x)[:, :, :H, :W]

        return out_noise, mag_image, x_center, mask_image