import argparse
import os

import kornia
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.utils.data import DataLoader

import models
from datasets import LowLightDataset
from models import PSNR, SSIM
from tools import saver, mutils


def get_args():
    parser = argparse.ArgumentParser('Breaking Down the Darkness')
    parser.add_argument('--num_gpus', type=int, default=1, help='number of GPUs being used')
    parser.add_argument('--num_workers', type=int, default=12, help='num_workers of the dataloader')
    parser.add_argument('--batch_size', type=int, default=4, help='number of images per batch among all devices')
    parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 name')
    parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 name')
    parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 name')
    parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 name')

    parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='model weight of IAN')
    parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='model weight of ANSN')
    parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='model weight of CAN')
    parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='model weight of NFM')

    parser.add_argument('--mef', action='store_true', help='use color-adaptation-based MEF data or not')
    parser.add_argument('--gc', action='store_true', help='use gamma correction or not')
    parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')

    parser.add_argument('--comment', type=str, default='default', help='project comment')

    parser.add_argument('--alpha', '-a', type=float, default=0.10)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--optim', type=str, default='adamw',
                        help='optimizer for training; \'adamw\' is suggested until the '
                             'very final stage, then switch to \'sgd\'')
    parser.add_argument('--data_path', type=str, default='./data/LOL/eval', help='root folder of the dataset')
    parser.add_argument('--log_path', type=str, default='logs/')
    parser.add_argument('--saved_path', type=str, default='logs/')
    args = parser.parse_args()
    return args


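# Overview of the evaluation pipeline, as implemented in forward() below:
#   1. model_ianet estimates an illumination map from the downsampled Y channel.
#   2. model_nsnet predicts a noise residual for the illumination-adjusted Y channel,
#      run at several noise-attention strengths.
#   3. model_fdnet fuses the multi-strength denoised Y channels.
#   4. model_canet restores the Cb/Cr chrominance channels.
# The acronyms IAN, ANSN, CAN and NFM used in the weight arguments above refer to these
# four sub-networks, respectively.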
class ModelBreadNet(nn.Module):
    def __init__(self, model1, model2, model3, model4):
        super().__init__()
        self.eps = 1e-6
        self.model_ianet = model1(in_channels=1, out_channels=1)
        self.model_nsnet = model2(in_channels=2, out_channels=1)
        self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
        self.model_fdnet = model4(in_channels=3, out_channels=1) if opt.model4 else None
        self.load_weight(self.model_ianet, opt.model1_weight)
        self.load_weight(self.model_nsnet, opt.model2_weight)
        self.load_weight(self.model_canet, opt.model3_weight)
        self.load_weight(self.model_fdnet, opt.model4_weight)
    def load_weight(self, model, weight_pth):
        if model is not None:
            state_dict = torch.load(weight_pth)
            ret = model.load_state_dict(state_dict, strict=True)
            print(ret)
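    # noise_syn_exp() builds the noise-attention map fed to model_nsnet: exp(-illumi) * strength
    # is largest where the estimated illumination is low, so denoising is concentrated on dark
    # regions, and a strength of 0 disables the attention entirely.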
    def noise_syn_exp(self, illumi, strength):
        return torch.exp(-illumi) * strength
    def forward(self, image, image_gt):
        # Work in YCbCr: the Y (texture) channel is enhanced, Cb/Cr are restored separately.
        texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
        texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)

        # Illumination estimation on the half-resolution Y channel, then upsampling back.
        texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
        texture_illumi = self.model_ianet(texture_in_down)
        texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)

        # Illumination adjustment: divide Y by the clamped illumination map (eps avoids division by zero).
        texture_illumi = torch.clamp(texture_illumi, 0., 1.)
        texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
        texture_ia = torch.clamp(texture_ia, 0., 1.)

        # Noise suppression at several attention strengths; the fused result comes from model_fdnet.
        texture_nss = []
        for strength in [0., 0.05, 0.1]:
            attention = self.noise_syn_exp(texture_illumi, strength=strength)
            texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
            texture_ns = texture_ia + texture_res
            texture_nss.append(texture_ns)
        texture_nss = torch.cat(texture_nss, dim=1).detach()
        # Note: model_fdnet is None unless --model4 is given, so this path requires a fusion model.
        texture_fd = self.model_fdnet(texture_nss)

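        # Optional reference-based gamma correction (--gc): sweep gamma over [0.1, 2.0) in steps
        # of 0.01 and keep the curve that maximizes PSNR against the ground-truth Y channel.
        # Because it uses the ground truth, this is only meaningful for benchmark evaluation.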
        if opt.gc:
            max_psnr = 0
            best = None
            for ga in np.arange(0.1, 2.0, 0.01):
                tx_en = texture_fd ** ga
                psnr = PSNR(tx_en, texture_gt)
                if psnr > max_psnr:
                    max_psnr = psnr
                    best = tx_en
            texture_fd = torch.clamp(best, 0, 1)

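        # Color adaptation: model_canet predicts the output Cb/Cr channels. Without --mef it also
        # receives the Cb/Cr of the illumination-adjusted RGB image (6 input channels); with --mef
        # it only sees the input YCbCr plus the enhanced Y channel (4 input channels).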
        if not opt.mef:
            image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
            _, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
            colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
        else:
            colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))

        cb_out, cr_out = torch.split(colors, 1, dim=1)
        cb_out = torch.clamp(cb_out, 0, 1)
        cr_out = torch.clamp(cr_out, 0, 1)

        # Recombine the enhanced Y channel with the predicted Cb/Cr and convert back to RGB.
        image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_fd, cb_out, cr_out], dim=1))
        image_out = torch.clamp(image_out, 0, 1)

        psnr = PSNR(image_out, image_gt)
        ssim = SSIM(image_out, image_gt).item()

        return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res, psnr, ssim


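# evaluation() runs the sub-networks over the validation set once, optionally dumping the
# intermediate outputs (--save_extra), and reports per-image and mean PSNR/SSIM.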
def evaluation(opt):
    # Fixed seed for reproducibility.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    timestamp = mutils.get_formatted_time()
    opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
    os.makedirs(opt.saved_path, exist_ok=True)

    val_params = {'batch_size': 1,
                  'shuffle': False,
                  'drop_last': False,
                  'num_workers': opt.num_workers}

    val_set = LowLightDataset(opt.data_path)
    val_generator = DataLoader(val_set, **val_params)
    val_generator = tqdm.tqdm(val_generator)

    model1 = getattr(models, opt.model1)
    model2 = getattr(models, opt.model2)
    model3 = getattr(models, opt.model3)
    model4 = getattr(models, opt.model4) if opt.model4 else None

    model = ModelBreadNet(model1, model2, model3, model4)
    print(model)

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = nn.DataParallel(model)

    model.eval()
    psnrs, ssims, fns = [], [], []
    for data, target, name in val_generator:
        saver.base_url = os.path.join(opt.saved_path, 'results')
        with torch.no_grad():
            if opt.num_gpus == 1:
                data = data.cuda()
                target = target.cuda()
            texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
            texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
            texture_ia, texture_nss, texture_fd, image_out, \
                texture_illumi, texture_res, psnr, ssim = model(data, target)

            if opt.save_extra:
                # Dump inputs, ground truth and all intermediate maps next to the final result.
                saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
                saver.save_image(target, name=os.path.splitext(name[0])[0] + '_im_gt')

                saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
                saver.save_image(texture_gt, name=os.path.splitext(name[0])[0] + '_y_gt')

                saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
                for i in range(texture_nss.shape[1]):
                    saver.save_image(texture_nss[:, i, ...], name=os.path.splitext(name[0])[0] + f'_ns_{i}')
                saver.save_image(texture_fd, name=os.path.splitext(name[0])[0] + '_fd')

                saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
                saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')

                saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
            else:
                saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')

            psnrs.append(psnr)
            ssims.append(ssim)
            fns.append(name[0])

    # Sort per-image results by PSNR (ascending) before printing, then report the averages.
    results = list(zip(psnrs, ssims, fns))
    results.sort(key=lambda item: item[0])
    for r in results:
        print(*r)

    psnr = np.mean(np.array(psnrs))
    ssim = np.mean(np.array(ssims))
    print('psnr: ', psnr, ', ssim: ', ssim)


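# Example invocation (a sketch only: the script filename, checkpoint paths and the Model4 class
# name below are assumptions and must be adapted to your own setup):
#   python eval_Bread.py -m1 IANet -m2 NSNet -m3 FuseNet -m4 FuseNet \
#       -m1w checkpoints/IAN.pth -m2w checkpoints/ANSN.pth \
#       -m3w checkpoints/CAN.pth -m4w checkpoints/NFM.pth \
#       --data_path ./data/LOL/eval --save_extra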
if __name__ == '__main__':
    opt = get_args()
    evaluation(opt)