import argparse |
import os |
import kornia |
import torch |
import torch.nn.functional as F |
import tqdm |
from torch import nn |
from torch.utils.data import DataLoader |
import models |
from datasets import LowLightDatasetTest |
from tools import saver, mutils |
def get_args(argv=None):
    """Parse command-line options for Bread inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the original behavior.

    Returns:
        argparse.Namespace with the parsed options.

    Note: ``--batch_size``, ``--alpha``, ``--lr``, ``--optim`` and ``--log_path``
    are not read by the inference code visible in this file (``test`` always
    uses batch size 1). They are kept for CLI compatibility.
    """
    parser = argparse.ArgumentParser('Breaking Down the Darkness')
    parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
    parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='The number of images per batch among all devices')

    # Sub-network class names, resolved with getattr(models, name).
    parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 Name')
    parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 Name')
    parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
    parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 Name')

    # Paths to state-dict files for each sub-network.
    parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
    parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
    parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
    parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='Model weight of NFM')

    parser.add_argument('--mef', action='store_true')
    parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
    parser.add_argument('--comment', type=str, default='default',
                        help='Project comment')
    parser.add_argument('--alpha', '-a', type=float, default=0.10)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--optim', type=str, default='adamw',
                        help='select optimizer for training, '
                             'suggest using \'adamw\' until the'
                             ' very final stage then switch to \'sgd\'')
    parser.add_argument('--data_path', type=str, default='./data/test',
                        help='the root folder of dataset')
    parser.add_argument('--log_path', type=str, default='logs/')
    parser.add_argument('--saved_path', type=str, default='logs/')
    return parser.parse_args(argv)
class ModelBreadNet(nn.Module):
    """Inference wrapper chaining the four Bread sub-networks.

    Steps as implemented in ``forward``:
      1. ``model_ianet`` runs on a half-resolution luma (Y) channel. Its
         output (``texture_illumi``) is resized back and clamped to [0, 1].
      2. The luma is divided by that map (``texture_ia``).
      3. ``model_nsnet`` predicts a residual on ``texture_ia`` for three
         strengths (0, 0.05, 0.1) of an ``exp(-illumi)`` attention map.
      4. ``model_fdnet`` maps the three stacked NSNet outputs to one luma channel,
         which is blended with the input luma using ``texture_illumi``.
      5. ``model_canet`` predicts Cb/Cr from the luma/chroma inputs.
      6. The RGB result is blended with the input image using ``texture_illumi``,
         and the final luma is taken from step 4.

    Args:
        model1, model2, model3, model4: Model classes/factories. Each is called
            as ``cls(in_channels=..., out_channels=...)``. ``model4`` is
            mandatory because ``forward`` always uses it.
        opts: Optional namespace providing ``mef`` and ``model1_weight`` ..
            ``model4_weight``. When omitted, the module-level ``opt`` (set by
            the ``__main__`` block) is used, matching the original behavior.

    Raises:
        ValueError: if ``model4`` is None or a required weight path is missing.
    """

    def __init__(self, model1, model2, model3, model4, opts=None):
        super().__init__()
        # Fallback keeps script usage unchanged; importers should pass ``opts``.
        cfg = opts if opts is not None else opt
        if model4 is None:
            raise ValueError('ModelBreadNet requires model4 (--model4): '
                             'forward() always runs model_fdnet')
        self.eps = 1e-6
        # Stored so forward() no longer depends on the global ``opt``.
        self.mef = bool(cfg.mef)
        self.model_ianet = model1(in_channels=1, out_channels=1)
        self.model_nsnet = model2(in_channels=2, out_channels=1)
        # MEF mode feeds 4 channels (Y, Cb, Cr, fused Y). Otherwise Cb/Cr of the
        # illumination-adjusted image are added (6 channels); see forward().
        self.model_canet = model3(in_channels=4 if self.mef else 6, out_channels=2)
        self.model_fdnet = model4(in_channels=3, out_channels=1)
        self.load_weight(self.model_ianet, cfg.model1_weight)
        self.load_weight(self.model_nsnet, cfg.model2_weight)
        self.load_weight(self.model_canet, cfg.model3_weight)
        self.load_weight(self.model_fdnet, cfg.model4_weight)

    def load_weight(self, model, weight_pth):
        """Strictly load a state dict from ``weight_pth`` into ``model``.

        Does nothing when ``model`` is None. Tensors are loaded onto the CPU
        first, so checkpoints saved on a GPU also load on CPU-only machines;
        the caller moves the module to its device afterwards.

        Raises:
            ValueError: if ``model`` is given but ``weight_pth`` is None.
        """
        if model is None:
            return
        if weight_pth is None:
            raise ValueError(f'no weight file given for {type(model).__name__}')
        state_dict = torch.load(weight_pth, map_location='cpu')
        ret = model.load_state_dict(state_dict, strict=True)
        print(ret)

    def noise_syn_exp(self, illumi, strength):
        """Return the attention map ``exp(-illumi) * strength`` (element-wise)."""
        return torch.exp(-illumi) * strength

    def forward(self, image):
        """Enhance ``image``.

        Assumes ``image`` is an RGB batch shaped (B, 3, H, W) with values in
        [0, 1], as expected by ``kornia.color.rgb_to_ycbcr`` (TODO confirm
        value range against the dataset loader).

        Returns:
            Tuple ``(texture_ia, texture_nss, texture_fd, image_out,
            texture_illumi, texture_res)``. ``texture_nss`` stacks the three
            NSNet outputs along dim 1 (detached). ``texture_res`` is the residual
            from the last strength (0.1) only.
        """
        texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)

        # Estimate illumination at half resolution. Resize back to the exact
        # input size rather than using scale_factor=2, so odd heights/widths
        # don't come back one pixel short and break the division below.
        texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
        texture_illumi = self.model_ianet(texture_in_down)
        texture_illumi = F.interpolate(texture_illumi, size=texture_in.shape[-2:],
                                       mode='bicubic', align_corners=True)
        texture_illumi = torch.clamp(texture_illumi, 0., 1.)

        texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
        texture_ia = torch.clamp(texture_ia, 0., 1.)

        texture_nss = []
        for strength in [0., 0.05, 0.1]:
            attention = self.noise_syn_exp(texture_illumi, strength=strength)
            texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
            texture_ns = texture_ia + texture_res
            texture_nss.append(texture_ns)
        texture_nss = torch.cat(texture_nss, dim=1).detach()

        texture_fd = self.model_fdnet(texture_nss)
        # Keep more of the original luma where the illumination map is high.
        texture_fd = texture_illumi * texture_in + (1 - texture_illumi) * texture_fd
        texture_fd = torch.clamp(texture_fd, 0, 1)

        if not self.mef:
            image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
            _, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
            colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
        else:
            colors = self.model_canet(
                torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))

        cb_out, cr_out = torch.split(colors, 1, dim=1)
        cb_out = torch.clamp(cb_out, 0, 1)
        cr_out = torch.clamp(cr_out, 0, 1)
        image_out = kornia.color.ycbcr_to_rgb(
            torch.cat([texture_fd, cb_out, cr_out], dim=1))

        # Blend with the input in RGB, then keep only the blended chroma and
        # re-attach the enhanced luma.
        img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
        _, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
        image_out = kornia.color.ycbcr_to_rgb(
            torch.cat([texture_fd, cb_fuse, cr_fuse], dim=1))
        image_out = torch.clamp(image_out, 0, 1)
        return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res
def test(opt):
    """Run Bread inference over ``opt.data_path`` and save the outputs.

    Images are written through ``saver`` under
    ``<opt.saved_path>/<opt.comment>/<timestamp>/results/<subset>/``.
    Note that ``opt.saved_path`` is updated in place to the timestamped directory.
    With ``opt.save_extra``, intermediate tensors are saved as well.

    Raises:
        ValueError: if ``opt.model4`` is not set (the model cannot run without it).
    """
    # torch.manual_seed seeds the CPU and all CUDA devices.
    torch.manual_seed(42)

    if not opt.model4:
        raise ValueError('--model4 is required: ModelBreadNet.forward always runs model4')

    timestamp = mutils.get_formatted_time()
    opt.saved_path = os.path.join(opt.saved_path, opt.comment, timestamp)
    os.makedirs(opt.saved_path, exist_ok=True)

    # batch_size must stay 1: the loop below only uses subset[0] / name[0].
    test_set = LowLightDatasetTest(opt.data_path)
    test_generator = tqdm.tqdm(DataLoader(test_set, batch_size=1, shuffle=False,
                                          drop_last=False, num_workers=opt.num_workers))

    model = ModelBreadNet(getattr(models, opt.model1),
                          getattr(models, opt.model2),
                          getattr(models, opt.model3),
                          getattr(models, opt.model4))
    print(model)

    use_cuda = opt.num_gpus > 0 and torch.cuda.is_available()
    if opt.num_gpus > 0 and not use_cuda:
        print('CUDA is not available; running on CPU')
    if use_cuda:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = nn.DataParallel(model)
    model.eval()

    for data, subset, name in test_generator:
        saver.base_url = os.path.join(opt.saved_path, 'results', subset[0])
        stem = os.path.splitext(name[0])[0]
        with torch.no_grad():
            if use_cuda:
                data = data.cuda()
            texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res = model(data)
            if opt.save_extra:
                texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
                saver.save_image(data, name=stem + '_im_in')
                saver.save_image(texture_in, name=stem + '_y_in')
                saver.save_image(texture_ia, name=stem + '_ia')
                for i in range(texture_nss.shape[1]):
                    saver.save_image(texture_nss[:, i, ...], name=stem + f'_ns_{i}')
                saver.save_image(texture_fd, name=stem + '_fd')
                saver.save_image(texture_illumi, name=stem + '_illumi')
                saver.save_image(texture_res, name=stem + '_res')
                saver.save_image(image_out, name=stem + '_out')
            else:
                saver.save_image(image_out, name=stem + '_Bread')
def save_checkpoint(model, name, saved_path=None):
    """Save the ``model_nsnet`` sub-network's state dict to ``saved_path/name``.

    Args:
        model: A ``ModelBreadNet``-like module, optionally wrapped in
            ``nn.DataParallel``.
        name: File name of the checkpoint.
        saved_path: Target directory. Defaults to the module-level
            ``opt.saved_path`` (the original behavior).
    """
    if saved_path is None:
        saved_path = opt.saved_path
    # nn.DataParallel exposes the wrapped network as ``.module``. The previous
    # ``.module3`` attribute does not exist and raised AttributeError.
    net = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(net.model_nsnet.state_dict(), os.path.join(saved_path, name))
if __name__ == '__main__':
    # Script entry point. ``opt`` is deliberately bound at module level because
    # ModelBreadNet and save_checkpoint read the global ``opt``.
    opt = get_args()
    test(opt=opt)