diff --git a/.gitignore b/.gitignore index 740bd3e54af0b685cbd2e9c0acd29c9e0311a689..9b23c36b27c0163cc56f5f5cb648b4d4b2a36ee6 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -model/__pycache__/ \ No newline at end of file +model/__pycache__/ +models/llflow/LOLdataset.zip +models/llflow/dataset_samples +models/results \ No newline at end of file diff --git a/README.md b/README.md index 50b16593811c0de4324832fd2a6ffc65f3b52310..c2f06ba6650dd49035b2fcaa8435ab1b2eb47e3d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ --- -title: Screen Image Demoireing +title: Image Enhancement emoji: 🔥 sdk: gradio sdk_version: 3.10.1 diff --git a/app.py b/app.py index 7f9a11dcdc2d9702b2c8edfdbfe8a2b4f365d155..a66b279f9abe9001d10f66f7b3c3643c2b9f7c8e 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ import gradio as gr -from model.nets import my_model +from models.demoire.nets import my_model import torch import cv2 import torch.utils.data as data @@ -15,7 +15,9 @@ import torch.nn.functional as F from rich.panel import Panel from rich.columns import Columns from rich.console import Console -from models.gfpgan import gfpgan_predict +# from models.gfpgan import gfpgan_predict +from models.llflow.inference import main + os.environ["CUDA_VISIBLE_DEVICES"] = "1" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -89,13 +91,6 @@ def predict_gfpgan(img): with Console().status("[red] using [green] GFP-GAN v1.4", spinner="aesthetic"): # if image already exists with this name then delete it - if Path("input_image_gfpgan.jpg").exists(): - os.remove("input_image_gfpgan.jpg") - # save incoming PIL image to disk - img.save("input_image_gfpgan.jpg") - - out = gfpgan_predict(img) - Console().print(out) return img diff --git a/image_enhancement.egg-info/PKG-INFO b/image_enhancement.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..b08441bd895b8fb44d8f310b0816729c6217a536 --- /dev/null +++ b/image_enhancement.egg-info/PKG-INFO @@ -0,0 +1,9 @@ +Metadata-Version: 2.1 +Name: image-enhancement +Version: 1.0 +Summary: UNKNOWN +License: UNKNOWN +Platform: UNKNOWN + +UNKNOWN + diff --git a/image_enhancement.egg-info/SOURCES.txt b/image_enhancement.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..6de967bc8ebff33463819de6ecfad254fa7932d1 --- /dev/null +++ b/image_enhancement.egg-info/SOURCES.txt @@ -0,0 +1,38 @@ +README.md +setup.py +image_enhancement.egg-info/PKG-INFO +image_enhancement.egg-info/SOURCES.txt +image_enhancement.egg-info/dependency_links.txt +image_enhancement.egg-info/top_level.txt +model/__init__.py +model/nets.py +models/__init__.py +models/gfpgan.py +models/llflow/Measure.py +models/llflow/__init__.py +models/llflow/imresize.py +models/llflow/inference.py +models/llflow/option_.py +models/llflow/util.py +models/llflow/models/LLFlow_model.py +models/llflow/models/__init__.py +models/llflow/models/base_model.py +models/llflow/models/lr_scheduler.py +models/llflow/models/networks.py +models/llflow/models/modules/ConditionEncoder.py +models/llflow/models/modules/FlowActNorms.py +models/llflow/models/modules/FlowAffineCouplingsAblation.py +models/llflow/models/modules/FlowStep.py +models/llflow/models/modules/FlowUpsamplerNet.py +models/llflow/models/modules/LLFlow_arch.py +models/llflow/models/modules/Permutations.py +models/llflow/models/modules/RRDBNet_arch.py +models/llflow/models/modules/Split.py +models/llflow/models/modules/__init__.py +models/llflow/models/modules/base_layers.py 
+models/llflow/models/modules/color_encoder.py +models/llflow/models/modules/flow.py +models/llflow/models/modules/glow_arch.py +models/llflow/models/modules/loss.py +models/llflow/models/modules/module_util.py +models/llflow/models/modules/thops.py \ No newline at end of file diff --git a/image_enhancement.egg-info/dependency_links.txt b/image_enhancement.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/image_enhancement.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/image_enhancement.egg-info/top_level.txt b/image_enhancement.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..88fff902df174f18f5a9b1f8f15a3e9cc2b86683 --- /dev/null +++ b/image_enhancement.egg-info/top_level.txt @@ -0,0 +1,2 @@ +model +models diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..13faaad117d5edd0e4ab7e9468ba121ef2af254f --- /dev/null +++ b/models/README.md @@ -0,0 +1,16 @@ +# References 📚 + + + +[1] Y. Wang, “[AAAI 2022 Oral] Low-Light Image Enhancement with Normalizing Flow.” Nov. 23, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/wyf0912/LLFlow +[2] L. Wang and K.-J. Yoon, “Deep Learning for HDR Imaging: State-of-the-Art and Future Trends.” arXiv, Nov. 07, 2021. Accessed: Nov. 24, 2022. [Online]. Available: http://arxiv.org/abs/2110.10394 +[3] Y. WANG, “Neural Color Operators for Sequential Image Retouching (ECCV2022).” Nov. 10, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/amberwangyili/neurop +[4] jwhe, “Conditional Sequential Modulation for Efficient Global Image Retouching Paper Link.” Nov. 23, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/hejingwenhejingwen/CSRNet +[5] Why, “Local Color Distributions Prior for Image Enhancement [ECCV2022].” Nov. 21, 2022. Accessed: Nov. 23, 2022. [Online]. Available: https://github.com/onpix/LCDPNet +[6] “Towards Efficient and Scale-Robust Ultra-High-Definition Image Demoiréing.” CVMI Lab, Nov. 21, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/CVMI-Lab/UHDM +[7] Z. Wang, “Uformer: A General U-Shaped Transformer for Image Restoration (CVPR 2022).” Nov. 20, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/ZhendongWang6/Uformer +[8] B. Zheng, “Learnbale_Bandpass_Filter.” Nov. 21, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/zhenngbolun/Learnbale_Bandpass_Filter +[9] K. Team, “Keras documentation: Enhanced Deep Residual Networks for single-image super-resolution.” https://keras.io/examples/vision/edsr/ (accessed Nov. 21, 2022). +[10] B. Lim, S. Son, H. Kim, S. Nah, and K. M. Lee, “Enhanced Deep Residual Networks for Single Image Super-Resolution.” arXiv, Jul. 10, 2017. doi: 10.48550/arXiv.1707.02921. +[11] C. Dong, C. C. Loy, K. He, and X. Tang, “Image Super-Resolution Using Deep Convolutional Networks.” arXiv, Jul. 31, 2015. doi: 10.48550/arXiv.1501.00092. +[12] Z. Anvari and V. Athitsos, “A Survey on Deep learning based Document Image Enhancement.” arXiv, Jan. 03, 2022. doi: 10.48550/arXiv.2112.02719. 
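+
+## Running the LLFlow demo (a sketch)
+
+A minimal sketch of calling the bundled LLFlow entry point the same way `app.py` does. It assumes the repo root is the working directory and a CUDA device is available; the sample file name `1.png` is hypothetical — use any image in `models/llflow/dataset_samples/our485/low` (the `dataroot_unpaired` folder from `LOL_smallNet.yml`):
+
+```python
+from models.llflow.inference import main
+
+# main() reads its config (LOL_smallNet.yml) via its argparse defaults,
+# enhances one low-light image on the GPU and returns a BGR numpy array
+enhanced = main("models/llflow/dataset_samples/our485/low/1.png")  # hypothetical file name
+```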
diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c11f813df0801d0887192c2e950974b0bf5a1d3 Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/demoire/__init__.py b/models/demoire/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/demoire/__pycache__/__init__.cpython-310.pyc b/models/demoire/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d8c855084f738c98cd392c8da6d9e551aed7fc8 Binary files /dev/null and b/models/demoire/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/demoire/__pycache__/nets.cpython-310.pyc b/models/demoire/__pycache__/nets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c8dc7535fd8501a7508aab63d19109be833bf48 Binary files /dev/null and b/models/demoire/__pycache__/nets.cpython-310.pyc differ diff --git a/model/nets.py b/models/demoire/nets.py similarity index 100% rename from model/nets.py rename to models/demoire/nets.py diff --git a/models/llflow/LOL_smallNet.pth b/models/llflow/LOL_smallNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..80b859f5e4680f6732109960b20c1ac4bba0c1f5 --- /dev/null +++ b/models/llflow/LOL_smallNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bf4c9192b401bf7155b2aa0781d9d8eed2e0bcc148286a9e2b224e12777bb38 +size 21874185 diff --git a/models/llflow/LOL_smallNet.yml b/models/llflow/LOL_smallNet.yml new file mode 100644 index 0000000000000000000000000000000000000000..3eb38394da253d8468e0a9ad42fe1cbb77d6b7f3 --- /dev/null +++ b/models/llflow/LOL_smallNet.yml @@ -0,0 +1,125 @@ +#### general settings +name: train_rebuttal_smallNet_ch32_blocks1 +use_tb_logger: true +model: LLFlow +distortion: sr +scale: 1 +gpu_ids: [0] +dataset: LoL +optimize_all_z: false +cond_encoder: ConEncoder1 +train_gt_ratio: 0.2 +avg_color_map: false + +concat_histeq: true +histeq_as_input: false +concat_color_map: false +gray_map: false # concat 1-input.mean(dim=1) to the input + +align_condition_feature: false +align_weight: 0.001 +align_maxpool: true + +to_yuv: false + +encode_color_map: false +le_curve: false +# sigmoid_output: true + +#### datasets +datasets: + train: + root: D:\LOLdataset + quant: 32 + use_shuffle: true + n_workers: 1 # per GPU + batch_size: 16 + use_flip: true + color: RGB + use_crop: true + GT_size: 160 # 192 + noise_prob: 0 + noise_level: 5 + log_low: true + gamma_aug: false + + val: + root: D:\LOLdataset + n_workers: 1 + quant: 32 + n_max: 20 + batch_size: 1 # must be 1 + log_low: true + +#### Test Settings +# dataroot_GT: D:\LOLdataset\eval15\high +# dataroot_LR: D:\LOLdataset\eval15\low +dataroot_unpaired: models/llflow/dataset_samples/our485/low +# dataroot_unpaired: /home/data/Dataset/LOL_test/Fusion +dataroot_GT: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\high +dataroot_LR: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\low +model_path: models/llflow/LOL_smallNet.pth +heat: 0 # This is the standard deviation of the latent vectors + +#### network structures +network_G: + which_model_G: LLFlow + in_nc: 3 + out_nc: 3 + nf: 32 + nb: 4 # 12 for our low light encoder, 
23 for LLFlow + train_RRDB: false + train_RRDB_delay: 0.5 + + flow: + K: 4 # 12 was used for the 24.49 PSNR result # 16 + L: 3 # 4 + noInitialInj: true + coupling: CondAffineSeparatedAndCond + additionalFlowNoAffine: 2 + conditionInFeaDim: 64 + split: + enable: false + fea_up0: true + stackRRDB: + blocks: [1] + concat: true + +#### path +path: + # pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth + strict_load: true + resume_state: auto + +#### training settings: learning rate scheme, loss +train: + manual_seed: 10 + lr_G: !!float 5e-4 # normalizing flow 5e-4; l1 loss train 5e-5 + weight_decay_G: 0 # 1e-5 # 5e-5 # 1e-5 + beta1: 0.9 + beta2: 0.99 + lr_scheme: MultiStepLR + warmup_iter: -1 # no warm up + lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] # [0.2, 0.35, 0.5, 0.65, 0.8, 0.95] # [ 0.5, 0.75, 0.9, 0.95 ] + lr_gamma: 0.5 + + weight_l1: 0 + # flow_warm_up_iter: -1 + weight_fl: 1 + + niter: 45000 #200000 + val_freq: 200 # 200 + +#### validation settings +val: + # heats: [ 0.0, 0.5, 0.75, 1.0 ] + n_sample: 4 + +test: + heats: [ 0.0, 0.7, 0.8, 0.9 ] + +#### logger +logger: + # Debug print_freq: 100 + print_freq: 100 + save_checkpoint_freq: !!float 1e3 diff --git a/models/llflow/Measure.py b/models/llflow/Measure.py new file mode 100644 index 0000000000000000000000000000000000000000..dd750b16c7908d48d23ee915aa1d3170b08237e5 --- /dev/null +++ b/models/llflow/Measure.py @@ -0,0 +1,127 @@ +import glob +import os +import time +from collections import OrderedDict + +import numpy as np +import torch +import cv2 +import argparse + +from natsort import natsort +from skimage.metrics import structural_similarity as ssim +from skimage.metrics import peak_signal_noise_ratio as psnr +import lpips + + +class Measure(): + def __init__(self, net='alex', use_gpu=False): + self.device = 'cuda' if use_gpu else 'cpu' + self.model = lpips.LPIPS(net=net) + self.model.to(self.device) + + def measure(self, imgA, imgB): + return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]] + + def lpips(self, imgA, imgB, model=None): + tA = t(imgA).to(self.device) + tB = t(imgB).to(self.device) + dist01 = self.model.forward(tA, tB).item() + return dist01 + + def ssim(self, imgA, imgB, gray_scale=True): + if gray_scale: + score, diff = ssim(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True) + # the grayscale inputs are 2-D, so no multichannel flag is needed here; in the RGB branch below, multichannel=True treats the last axis as channels and averages the per-channel similarities
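+ # full=True also returns the per-pixel similarity map ("diff"), which is unused here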
+ else: + score, diff = ssim(imgA, imgB, full=True, multichannel=True) + return score + + def psnr(self, imgA, imgB): + psnr_val = psnr(imgA, imgB) + return psnr_val + + +def t(img): + def to_4d(img): + assert len(img.shape) == 3 + assert img.dtype == np.uint8 + img_new = np.expand_dims(img, axis=0) + assert len(img_new.shape) == 4 + return img_new + + def to_CHW(img): + return np.transpose(img, [2, 0, 1]) + + def to_tensor(img): + return torch.Tensor(img) + + return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1 + + +def fiFindByWildcard(wildcard): + return natsort.natsorted(glob.glob(wildcard, recursive=True)) + + +def imread(path): + return cv2.imread(path)[:, :, [2, 1, 0]] + + +def format_result(psnr, ssim, lpips): + return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}' + + +def measure_dirs(dirA, dirB, use_gpu, verbose=False): + if verbose: + def vprint(x): return print(x) + else: + def vprint(x): return None + + t_init = time.time() + + paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}')) + paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}')) + + vprint("Comparing: ") + vprint(dirA) + vprint(dirB) + + measure = Measure(use_gpu=use_gpu) + + results = [] + for pathA, pathB in zip(paths_A, paths_B): + result = OrderedDict() + + t = time.time() + result['psnr'], result['ssim'], result['lpips'] = measure.measure( + imread(pathA), imread(pathB)) + d = time.time() - t + vprint( + f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}") + + results.append(result) + + psnr = np.mean([result['psnr'] for result in results]) + ssim = np.mean([result['ssim'] for result in results]) + lpips = np.mean([result['lpips'] for result in results]) + + vprint( + f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-dirA', default='', type=str) + parser.add_argument('-dirB', default='', type=str) + parser.add_argument('-type', default='png') + parser.add_argument('--use_gpu', action='store_true', default=False) + args = parser.parse_args() + + dirA = args.dirA + dirB = args.dirB + type = args.type + use_gpu = args.use_gpu + + if len(dirA) > 0 and len(dirB) > 0: + measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True) diff --git a/models/llflow/__init__.py b/models/llflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c576a122fb8209d29ce6d006467459df588b2d2 --- /dev/null +++ b/models/llflow/__init__.py @@ -0,0 +1,4 @@ +# from util import get_resume_paths, opt_get +from .Measure import Measure, psnr +from .imresize import imresize +from models import * \ No newline at end of file diff --git a/models/llflow/__pycache__/Measure.cpython-310.pyc b/models/llflow/__pycache__/Measure.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05ff66ba84987edf1397f077147e611b6cfe5045 Binary files /dev/null and b/models/llflow/__pycache__/Measure.cpython-310.pyc differ diff --git a/models/llflow/__pycache__/__init__.cpython-310.pyc b/models/llflow/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aadc328c02daf2bb9b72e7700b678ae6b8d33008 Binary files /dev/null and b/models/llflow/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/llflow/__pycache__/imresize.cpython-310.pyc b/models/llflow/__pycache__/imresize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93cfca04771231e24c43377ccc26473aa9c864fb 
Binary files /dev/null and b/models/llflow/__pycache__/imresize.cpython-310.pyc differ diff --git a/models/llflow/__pycache__/inference.cpython-310.pyc b/models/llflow/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..416301323f272237ff80cf7ebf949ecad3dd6d00 Binary files /dev/null and b/models/llflow/__pycache__/inference.cpython-310.pyc differ diff --git a/models/llflow/__pycache__/option_.cpython-310.pyc b/models/llflow/__pycache__/option_.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db2df51ba88f783da564aeaaec2464df05dd4b94 Binary files /dev/null and b/models/llflow/__pycache__/option_.cpython-310.pyc differ diff --git a/models/llflow/__pycache__/util.cpython-310.pyc b/models/llflow/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d01b6acba178f16880ede82d2abdee1ce4f3421 Binary files /dev/null and b/models/llflow/__pycache__/util.cpython-310.pyc differ diff --git a/models/llflow/imresize.py b/models/llflow/imresize.py new file mode 100644 index 0000000000000000000000000000000000000000..734d48994d927b1377b39b9c9de8ed36f8f5e6e0 --- /dev/null +++ b/models/llflow/imresize.py @@ -0,0 +1,180 @@ +# https://github.com/fatheral/matlab_imresize +# +# MIT License +# +# Copyright (c) 2020 Alex +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
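+# Usage sketch (not part of the original port): `imresize` below mirrors
+# MATLAB's imresize for a numpy image, e.g.
+#   half = imresize(img, scalar_scale=0.5)          # MATLAB-style bicubic
+#   fixed = imresize(img, output_shape=(128, 96))   # resize to an exact shape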
+ + +from __future__ import print_function +import numpy as np +from math import ceil, floor + + +def deriveSizeFromScale(img_shape, scale): + output_shape = [] + for k in range(2): + output_shape.append(int(ceil(scale[k] * img_shape[k]))) + return output_shape + + +def deriveScaleFromSize(img_shape_in, img_shape_out): + scale = [] + for k in range(2): + scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) + return scale + + +def triangle(x): + x = np.array(x).astype(np.float64) + lessthanzero = np.logical_and((x >= -1), x < 0) + greaterthanzero = np.logical_and((x <= 1), x >= 0) + f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero) + return f + + +def cubic(x): + x = np.array(x).astype(np.float64) + absx = np.absolute(x) + absx2 = np.multiply(absx, absx) + absx3 = np.multiply(absx2, absx) + f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, + (1 < absx) & (absx <= 2)) + return f + + +def contributions(in_length, out_length, scale, kernel, k_width): + if scale < 1: + h = lambda x: scale * kernel(scale * x) + kernel_width = 1.0 * k_width / scale + else: + h = kernel + kernel_width = k_width + x = np.arange(1, out_length + 1).astype(np.float64) + u = x / scale + 0.5 * (1 - 1 / scale) + left = np.floor(u - kernel_width / 2) + P = int(ceil(kernel_width)) + 2 + ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 + indices = ind.astype(np.int32) + weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 + weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1)) + aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32) + indices = aux[np.mod(indices, aux.size)] + ind2store = np.nonzero(np.any(weights, axis=0)) + weights = weights[:, ind2store] + indices = indices[:, ind2store] + return weights, indices + + +def imresizemex(inimg, weights, indices, dim): + in_shape = inimg.shape + w_shape = weights.shape + out_shape = list(in_shape) + out_shape[dim] = w_shape[0] + outimg = np.zeros(out_shape) + if dim == 0: + for i_img in range(in_shape[1]): + for i_w in range(w_shape[0]): + w = weights[i_w, :] + ind = indices[i_w, :] + im_slice = inimg[ind, i_img].astype(np.float64) + outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + elif dim == 1: + for i_img in range(in_shape[0]): + for i_w in range(w_shape[0]): + w = weights[i_w, :] + ind = indices[i_w, :] + im_slice = inimg[i_img, ind].astype(np.float64) + outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + if inimg.dtype == np.uint8: + outimg = np.clip(outimg, 0, 255) + return np.around(outimg).astype(np.uint8) + else: + return outimg + + +def imresizevec(inimg, weights, indices, dim): + wshape = weights.shape + if dim == 0: + weights = weights.reshape((wshape[0], wshape[2], 1, 1)) + outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1) + elif dim == 1: + weights = weights.reshape((1, wshape[0], wshape[2], 1)) + outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2) + if inimg.dtype == np.uint8: + outimg = np.clip(outimg, 0, 255) + return np.around(outimg).astype(np.uint8) + else: + return outimg + + +def resizeAlongDim(A, dim, weights, indices, mode="vec"): + if mode == "org": + out = imresizemex(A, weights, indices, dim) + else: + out = imresizevec(A, weights, indices, dim) + return out + + +def 
imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"): + if method == 'bicubic': + kernel = cubic + elif method == 'bilinear': + kernel = triangle + else: + print('Error: Unidentified method supplied') + return + + kernel_width = 4.0 + # Fill scale and output_size + if scalar_scale is not None: + scalar_scale = float(scalar_scale) + scale = [scalar_scale, scalar_scale] + output_size = deriveSizeFromScale(I.shape, scale) + elif output_shape is not None: + scale = deriveScaleFromSize(I.shape, output_shape) + output_size = list(output_shape) + else: + print('Error: scalar_scale OR output_shape should be defined!') + return + scale_np = np.array(scale) + order = np.argsort(scale_np) + weights = [] + indices = [] + for k in range(2): + w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width) + weights.append(w) + indices.append(ind) + B = np.copy(I) + flag2D = False + if B.ndim == 2: + B = np.expand_dims(B, axis=2) + flag2D = True + for k in range(2): + dim = order[k] + B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode) + if flag2D: + B = np.squeeze(B, axis=2) + return B + + +def convertDouble2Byte(I): + B = np.clip(I, 0.0, 1.0) + B = 255 * B + return np.around(B).astype(np.uint8) diff --git a/models/llflow/inference.py b/models/llflow/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1a224166e868683ea0d3fc73af24f042415940 --- /dev/null +++ b/models/llflow/inference.py @@ -0,0 +1,157 @@ +import glob +import sys +from collections import OrderedDict +import tqdm +from natsort import natsort +import argparse +import models.llflow.option_ as option +from models.llflow import Measure, psnr +from models.llflow import imresize +from models import create_model +import torch +from util import opt_get +import numpy as np +import pandas as pd +import os +import cv2 +from rich.console import Console + +def fiFindByWildcard(wildcard): + return natsort.natsorted(glob.glob(wildcard, recursive=True)) + + +def load_model(conf_path): + opt = option.parse(conf_path, is_train=False) + opt['gpu_ids'] = None + opt = option.dict_to_nonedict(opt) + model = create_model(opt) + + model_path = opt_get(opt, ['model_path'], None) + model.load_network(load_path=model_path, network=model.netG) + return model, opt + + +def predict(model, lr): + model.feed_data({"LQ": t(lr)}, need_GT=False) + model.test() + visuals = model.get_current_visuals(need_GT=False) + return visuals.get('rlt', visuals.get('NORMAL')) + + +# HWC uint8 image in [0, 255] -> 1xCxHxW float tensor in [0, 1] +def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255 + + +# CHW (or 1xCxHxW) float tensor in [0, 1] -> HWC uint8 image +def rgb(t): return ( np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype( np.uint8) + + +def imread(path): + return cv2.imread(path)[:, :, [2, 1, 0]] + + +def imwrite(path, img): + os.makedirs(os.path.dirname(path), exist_ok=True) + cv2.imwrite(path, img[:, :, [2, 1, 0]]) + + +def imCropCenter(img, size): + h, w, c = img.shape + + h_start = max(h // 2 - size // 2, 0) + h_end = min(h_start + size, h) + + w_start = max(w // 2 - size // 2, 0) + w_end = min(w_start + size, w) + + return img[h_start:h_end, w_start:w_end] + + +def impad(img, top=0, bottom=0, left=0, right=0, color=255): + return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') + + +def hiseq_color_cv2_img(img): + (b, g, r) = cv2.split(img) + bH = cv2.equalizeHist(b) + gH = cv2.equalizeHist(g) + rH = cv2.equalizeHist(r) + result = cv2.merge((bH, gH, rH)) + return result + + +def 
auto_padding(img, times=16): + # img: numpy image with shape H*W*C + + h, w, _ = img.shape + h1, w1 = (times - h % times) // 2, (times - w % times) // 2 + h2, w2 = (times - h % times) - h1, (times - w % times) - w1 + img = cv2.copyMakeBorder(img, h1, h2, w1, w2, cv2.BORDER_REFLECT) + return img, [h1, h2, w1, w2] + + +def main(path: str = None): + parser = argparse.ArgumentParser() + # parser.add_argument("--opt", default="./confs/LOL_smallNet.yml") + parser.add_argument("--opt", default="./models/llflow/LOL_smallNet.yml") + parser.add_argument("-n", "--name", default="unpaired") + # -i/--input is an added (hypothetical) convenience flag so the script can also run standalone + parser.add_argument("-i", "--input", default=None, help="path to a low-light image, used when no path argument is given") + + args = parser.parse_args() + + Console().log(f"🛠️\tLoading model from {args.opt}") + + conf_path = args.opt + conf = conf_path.split('/')[-1].replace('.yml', '') + model, opt = load_model(conf_path) + model.netG = model.netG.cuda() + + lr_dir = opt['dataroot_unpaired'] + # lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.*')) + lr_paths = path if path is not None else args.input + assert lr_paths is not None, "pass an image path as an argument or via --input" + + this_dir = os.path.dirname(os.path.realpath(__file__)) + test_dir = os.path.join(this_dir, '..', 'results', conf, args.name) + print(f"Out dir: {test_dir}") + + # for lr_path, idx_test in tqdm.tqdm(zip(lr_paths, range(len(lr_paths))), colour='green'): + lr_path = lr_paths + lr = imread(lr_path) + raw_shape = lr.shape + lr, padding_params = auto_padding(lr) + his = hiseq_color_cv2_img(lr) + if opt.get("histeq_as_input", False): + lr = his + + lr_t = t(lr) + if opt["datasets"]["train"].get("log_low", False): + lr_t = torch.log(torch.clamp(lr_t + 1e-3, min=1e-3)) + if opt.get("concat_histeq", False): + his = t(his) + lr_t = torch.cat([lr_t, his], dim=1) + heat = opt['heat'] + with torch.cuda.amp.autocast(): + sr_t = model.get_sr(lq=lr_t.cuda(), heat=None) + + sr = rgb(torch.clamp(sr_t, 0, 1)[:, :, padding_params[0]:sr_t.shape[2] - padding_params[1], padding_params[2]:sr_t.shape[3] - padding_params[3]]) + assert raw_shape == sr.shape + path_out_sr = os.path.join(test_dir, os.path.basename(lr_path)) + # imwrite(path_out_sr, sr) + # cv2.imwrite(path_out_sr, sr[:, :, [2, 1, 0]]) + + return sr[:, :, [2, 1, 0]] + + +def format_measurements(meas): + s_out = [] + for k, v in meas.items(): + v = f"{v:0.2f}" if isinstance(v, float) else v + s_out.append(f"{k}: {v}") + str_out = ", ".join(s_out) + return str_out + + +if __name__ == "__main__": + main() diff --git a/models/llflow/models/LLFlow_model.py b/models/llflow/models/LLFlow_model.py new file mode 100644 index 0000000000000000000000000000000000000000..41fea0267980459c2a2e269940b9905541eaac52 --- /dev/null +++ b/models/llflow/models/LLFlow_model.py @@ -0,0 +1,400 @@ +import logging +from collections import OrderedDict +# from models.llflow.util import get_resume_paths, opt_get +# from models.llflow import get_resume_paths, opt_get +import glob +import os +import natsort +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from torch.cuda.amp import GradScaler, autocast + +logger = logging.getLogger('base') + + + + +def get_resume_paths(opt): + resume_state_path = None + resume_model_path = None + ts = opt_get(opt, ['path', 'training_state']) + if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None: + wildcard = os.path.join(ts, "*") + paths = natsort.natsorted(glob.glob(wildcard)) + if len(paths) > 0: + resume_state_path = paths[-1] + 
resume_model_path = resume_state_path.replace( + 'training_state', 'models').replace('.state', '_G.pth') + else: + resume_state_path = opt.get('path', {}).get('resume_state') + return resume_state_path, resume_model_path + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + + + + + + + + + +class LLFlowModel(BaseModel): + def __init__(self, opt, step): + super(LLFlowModel, self).__init__(opt) + self.opt = opt + + self.already_print_params_num = False + + self.heats = opt['val']['heats'] + self.n_sample = opt['val']['n_sample'] + self.hr_size = opt['datasets']['train']['GT_size'] # opt_get(opt, ['datasets', 'train', 'center_crop_hr_size']) + # self.hr_size = 160 if self.hr_size is None else self.hr_size + self.lr_size = self.hr_size // opt['scale'] + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define network and load pretrained models + self.netG = networks.define_Flow(opt, step).to(self.device) + # + weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0 + if weight_l1 and 1: + missing_keys, unexpected_keys = self.netG.load_state_dict(torch.load( + '/home/yufei/project/LowLightFlow/experiments/to_pretrain_netG/models/1000_G.pth'), + strict=False) + print('missing %d keys, unexpected %d keys' % (len(missing_keys), len(unexpected_keys))) + # if self.device.type != 'cpu': + if opt['gpu_ids'] is not None and len(opt['gpu_ids']) > 0: + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + elif len(opt['gpu_ids']) > 1: + self.netG = DataParallel(self.netG, opt['gpu_ids']) + else: + self.netG.cuda() + # print network + # self.print_network() + + if opt_get(opt, ['path', 'resume_state'], 1) is not None: + self.load() + else: + print("WARNING: skipping initial loading, due to resume_state None") + + if self.is_train: + self.netG.train() + + self.init_optimizer_and_scheduler(train_opt) + self.log_dict = OrderedDict() + + def to(self, device): + self.device = device + self.netG.to(device) + + def init_optimizer_and_scheduler(self, train_opt): + # optimizers + self.optimizers = [] + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + if isinstance(wd_G, str): wd_G = eval(wd_G) + optim_params_RRDB = [] + optim_params_other = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + # print(k, v.requires_grad) + if v.requires_grad: + if '.RRDB.' 
in k: + optim_params_RRDB.append(v) + # print('opt', k) + else: + optim_params_other.append(v) + # if self.rank <= 0: + # logger.warning('Params [{:s}] will not optimize.'.format(k)) + + print('rrdb params', len(optim_params_RRDB)) + + self.optimizer_G = torch.optim.Adam( + [ + {"params": optim_params_other, "lr": train_opt['lr_G'], 'beta1': train_opt['beta1'], + 'beta2': train_opt['beta2'], 'weight_decay': wd_G}, + {"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']), + 'beta1': train_opt['beta1'], + 'beta2': train_opt['beta2'], 'weight_decay': 1e-5} + ] + ) + + self.scaler = GradScaler() + + self.optimizers.append(self.optimizer_G) + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'], + lr_steps_invese=train_opt.get('lr_steps_inverse', []))) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + def add_optimizer_and_scheduler_RRDB(self, train_opt): + # optimizers + assert len(self.optimizers) == 1, self.optimizers + assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + if '.RRDB.' 
in k: + self.optimizer_G.param_groups[1]['params'].append(v) + assert len(self.optimizer_G.param_groups[1]['params']) > 0 + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def get_module(self, model): + if isinstance(model, nn.DataParallel): + return model.module + else: + return model + + def optimize_color_encoder(self, step): + self.netG.train() + self.log_dict = OrderedDict() + self.optimizer_G.zero_grad() + color_lr, color_gt = self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, gt=self.real_H, + get_color_map=True) + losses = {} + total_loss = (color_gt - color_lr).abs().mean() + # try: + total_loss.backward() + self.optimizer_G.step() + mean = total_loss.item() + return mean + + def optimize_parameters(self, step): + train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) + if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \ + and not self.get_module(self.netG).RRDB_training: + if self.get_module(self.netG).set_rrdb_training(True): + self.add_optimizer_and_scheduler_RRDB(self.opt['train']) + + # self.print_rrdb_state() + + self.netG.train() + self.log_dict = OrderedDict() + self.optimizer_G.zero_grad() + # with autocast(): + losses = {} + weight_fl = opt_get(self.opt, ['train', 'weight_fl']) + weight_fl = 1 if weight_fl is None else weight_fl + weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0 + flow_warm_up_iter = opt_get(self.opt, ['train', 'flow_warm_up_iter']) + # print(step, flow_warm_up_iter) + if flow_warm_up_iter is not None: + if step > flow_warm_up_iter: + weight_fl = 0 + else: + weight_l1 = 0 + # print(weight_fl, weight_l1) + if weight_fl > 0: + if self.opt['optimize_all_z']: + if self.opt['gpu_ids'] is not None and len(self.opt['gpu_ids']) > 0: + epses = [[] for _ in range(len(self.opt['gpu_ids']))] + else: + epses = [] + else: + epses = None + z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False, epses=epses, + align_condition_feature=opt_get(self.opt, + ['align_condition_feature']) or False) + nll_loss = torch.mean(nll) + losses['nll_loss'] = nll_loss * weight_fl + + if weight_l1 > 0: + z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True) + sr = sr.clamp(0, 1) + not_nan_mask = ~torch.isnan(sr) + sr[torch.isnan(sr)] = 0 + l1_loss = ((sr - self.real_H) * not_nan_mask).abs().mean() + losses['l1_loss'] = l1_loss * weight_l1 + if flow_warm_up_iter is not None: + print(l1_loss, not_nan_mask.float().mean()) + total_loss = sum(losses.values()) + # try: + self.scaler.scale(total_loss).backward() + if not self.already_print_params_num: + logger.info("Parameters of full network %.4f and encoder %.4f"%(sum([m.numel() for m in self.netG.parameters() if m.grad is not None])/1e6, sum([m.numel() for m in self.netG.RRDB.parameters() if m.grad is not None])/1e6)) + self.already_print_params_num = True + self.scaler.step(self.optimizer_G) + self.scaler.update() + # except Exception as e: + # print(e) + # print(total_loss) + + mean = total_loss.item() + return mean + + def print_rrdb_state(self): + for name, param in self.get_module(self.netG).named_parameters(): + if "RRDB.conv_first.weight" in name: + print(name, param.requires_grad, param.data.abs().sum()) + print('params', [len(p['params']) for p in self.optimizer_G.param_groups]) + + def 
get_color_map(self): + self.netG.eval() + z = self.get_z(0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + with torch.no_grad(): + color_lr, color_gt = self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, gt=self.real_H, + get_color_map=True) + self.netG.train() + return color_lr, color_gt + + def test(self): + self.netG.eval() + self.fake_H = {} + if self.heats is not None: + for heat in self.heats: + for i in range(self.n_sample): + z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + with torch.no_grad(): + self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True) + else: + z = self.get_z(0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) + with torch.no_grad(): + # torch.cuda.reset_peak_memory_stats() + self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, z=z.to(self.var_L.device), eps_std=0, reverse=True) + # from thop import clever_format, profile + # print(clever_format(profile(self.netG, (None,self.var_L, z.to(self.var_L.device), 0 ,True))),"%.4") + # print(torch.cuda.max_memory_allocated()/1024/1024/1024) + # import time + # t = time.time() + # for i in range(15): + # with torch.no_grad(): + # self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, z=z.to(self.var_L.device), eps_std=0, reverse=True) + # print((time.time()-t)/15) + # with torch.no_grad(): + # _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False) + self.netG.train() + return None + # return nll.mean().item() + + def get_encode_nll(self, lq, gt): + self.netG.eval() + with torch.no_grad(): + _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False) + self.netG.train() + return nll.mean().item() + + def get_sr(self, lq, heat=None, seed=None, z=None, epses=None): + return self.get_sr_with_z(lq, heat, seed, z, epses)[0] + + def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True): + self.netG.eval() + with torch.no_grad(): + z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) + self.netG.train() + return z + + def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True): + self.netG.eval() + with torch.no_grad(): + z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) + self.netG.train() + return z, nll + + def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None): + self.netG.eval() + if heat is None: + heat = 0 + z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z + + with torch.no_grad(): + sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses) + self.netG.train() + return sr, z + + def get_z(self, heat, seed=None, batch_size=1, lr_shape=None): + if seed: torch.manual_seed(seed) + if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']): + C = self.get_module(self.netG).flowUpsamplerNet.C + H = int(self.opt['scale'] * lr_shape[2] // self.get_module(self.netG).flowUpsamplerNet.scaleH) + W = int(self.opt['scale'] * lr_shape[3] // self.get_module(self.netG).flowUpsamplerNet.scaleW) + z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros( + (batch_size, C, H, W)) + else: + L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3 + fac = 2 ** L + H = int(self.opt['scale'] * lr_shape[2] // self.get_module(self.netG).flowUpsamplerNet.scaleH) + W = int(self.opt['scale'] * lr_shape[3] // self.get_module(self.netG).flowUpsamplerNet.scaleW) + size = (batch_size, 3 * fac * fac, H, W) + z 
= torch.normal(mean=0, std=heat, size=size) if heat > 0 else torch.zeros(size) + return z + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + if self.heats is not None: + for heat in self.heats: + for i in range(self.n_sample): + out_dict[('NORMAL', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() + else: + out_dict['NORMAL'] = self.fake_H[(0, 0)].detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + _, get_resume_model_path = get_resume_paths(self.opt) + if get_resume_model_path is not None: + self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None) + return + + load_path_G = self.opt['path']['pretrain_model_G'] + load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB' + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True), + submodule=load_submodule) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/models/llflow/models/__init__.py b/models/llflow/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98032587c0ab68b2ee1300b1f7ca4e2b08f4398a --- /dev/null +++ b/models/llflow/models/__init__.py @@ -0,0 +1,52 @@ +import importlib +import logging +import os + +try: + import local_config +except: + local_config = None + + +logger = logging.getLogger('base') + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of torch.nn.Module, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace('_', '') + 'Model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print( + "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." 
% ( + model_filename, target_model_name)) + exit(0) + + return model + + +def create_model(opt, step=0, **opt_kwargs): + if local_config is not None: + opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth')) + + for k, v in opt_kwargs.items(): + opt[k] = v + + model = opt['model'] + + M = find_model_using_name(model) + + m = M(opt, step) + logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) + return m diff --git a/models/llflow/models/__pycache__/LLFlow_model.cpython-310.pyc b/models/llflow/models/__pycache__/LLFlow_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..197f310cc617b6a5a6bf60359ff40b99e8657846 Binary files /dev/null and b/models/llflow/models/__pycache__/LLFlow_model.cpython-310.pyc differ diff --git a/models/llflow/models/__pycache__/__init__.cpython-310.pyc b/models/llflow/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f55ff899c9de6eda0288a8e1bd5d586b5b4c7bc Binary files /dev/null and b/models/llflow/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/llflow/models/__pycache__/base_model.cpython-310.pyc b/models/llflow/models/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051bfb0aaaa740096e9c505f043067cd5816f5f2 Binary files /dev/null and b/models/llflow/models/__pycache__/base_model.cpython-310.pyc differ diff --git a/models/llflow/models/__pycache__/lr_scheduler.cpython-310.pyc b/models/llflow/models/__pycache__/lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4573e4995d6964683b4b1f8b149a03ac21cbf8b0 Binary files /dev/null and b/models/llflow/models/__pycache__/lr_scheduler.cpython-310.pyc differ diff --git a/models/llflow/models/__pycache__/networks.cpython-310.pyc b/models/llflow/models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98d12f000fbe08a4a77056956952c853011ce976 Binary files /dev/null and b/models/llflow/models/__pycache__/networks.cpython-310.pyc differ diff --git a/models/llflow/models/base_model.py b/models/llflow/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..255b8c62574fdc3ccff7eb5b94d3d1b6b1630754 --- /dev/null +++ b/models/llflow/models/base_model.py @@ -0,0 +1,145 @@ + + + +import os +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +import natsort +import glob + + +class BaseModel(): + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + self.scaler = None + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + ''' set learning rate for warmup, + lr_groups_l: list for lr_groups. 
each for an optimizer''' + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + # get the initial lr, which is set by the scheduler + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + #### set up warm up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + # return self.schedulers[0].get_lr()[0] + return self.optimizers[0].param_groups[0]['lr'] + + def get_network_description(self, network): + '''Get the string and total parameters of the network''' + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + s = str(network) + n = sum(map(lambda x: x.numel(), network.parameters())) + return s, n + + def save_network(self, network, network_label, iter_label): + paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))), + reverse=True) + paths = [p for p in paths if + "latest_" not in p and not any([str(i * 10000) in p.split("/")[-1].split("_") for i in range(101)])] + if len(paths) > 2: + for path in paths[2:]: + os.remove(path) + save_filename = '{}_{}.pth'.format(iter_label, network_label) + save_path = os.path.join(self.opt['path']['models'], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True, submodule=None): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + if not (submodule is None or submodule.lower() == 'none'.lower()): + network = network.__getattr__(submodule) + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' 
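+ # ('module.' is the key prefix added by DataParallel/DistributedDataParallel wrappers)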
+ for k, v in load_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + network.load_state_dict(load_net_clean, strict=strict) + + def save_training_state(self, epoch, iter_step): + '''Saves training state during training, which will be used for resuming''' + state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': [], 'scaler': None} + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + state['scaler'] = self.scaler.state_dict() + save_filename = '{}.state'.format(iter_step) + save_path = os.path.join(self.opt['path']['training_state'], save_filename) + + paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")), + reverse=True) + paths = [p for p in paths if "latest_" not in p] + if len(paths) > 2: + for path in paths[2:]: + os.remove(path) + + torch.save(state, save_path) + + def resume_training(self, resume_state): + '''Resume the optimizers and schedulers for training''' + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + resume_scaler = resume_state['scaler'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + self.scaler.load_state_dict(resume_scaler) diff --git a/models/llflow/models/lr_scheduler.py b/models/llflow/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ba87c2a9b0d19836bf608c0a50c161200fc2230a --- /dev/null +++ b/models/llflow/models/lr_scheduler.py @@ -0,0 +1,147 @@ +import math +from collections import Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1, lr_steps_invese=None): + assert lr_steps_invese is not None, "Use empty list" + self.milestones = Counter(milestones) + self.lr_steps_inverse = Counter(lr_steps_invese) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' 
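+ # note: the base _LRScheduler __init__ triggers an initial step()/get_lr() call, so the fields above must be set first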
+ super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones and self.last_epoch not in self.lr_steps_inverse: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * (self.gamma ** self.milestones[self.last_epoch]) * + (self.gamma ** (-self.lr_steps_inverse[self.last_epoch])) + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + ## Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + ## two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + ## four + lr_steps = [ + 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, + 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, + clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + ## two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + ## four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + weights=restart_weights) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = 
list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]['lr'] + lr_l[i] = current_lr + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + + mpl.style.use('default') + import seaborn + + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/models/llflow/models/modules/ConditionEncoder.py b/models/llflow/models/modules/ConditionEncoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba68b3b8237d72fb7a1f8ff946e28dfdd9ac582 --- /dev/null +++ b/models/llflow/models/modules/ConditionEncoder.py @@ -0,0 +1,287 @@ + + +from torchvision.utils import save_image +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.modules.module_util as mutil +# from utils.util import opt_get +from models.modules.flow import Conv2dZeros + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + # gamma = torch.sigmoid(self.conv5(torch.cat((x, x1, x2, x3, x4), 1))) + # x = torch.sigmoid(x) + # return x + gamma * x * (1 - x) + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class ConEncoder1(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): + self.opt = opt + self.gray_map_bool = False + self.concat_color_map = False + if opt['concat_histeq']: + in_nc = in_nc + 3 + if opt['concat_color_map']: + in_nc = in_nc + 3 + self.concat_color_map = True + if opt['gray_map']: + in_nc = in_nc + 1 + self.gray_map_bool = True + in_nc = in_nc + 6 + super(ConEncoder1, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.scale = scale + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv_second = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### downsampling + self.downconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.downconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.downconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + # self.downconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.awb_para = nn.Linear(nf, 3) + self.fine_tune_color_map = nn.Sequential(nn.Conv2d(nf, 3, 1, 1),nn.Sigmoid()) + + def forward(self, x, get_steps=False): + if self.gray_map_bool: + x = torch.cat([x, 1 - x.mean(dim=1, keepdim=True)], dim=1) + if self.concat_color_map: + x = torch.cat([x, x / (x.sum(dim=1, keepdim=True) + 1e-4)], dim=1) + + raw_low_input = x[:, 0:3].exp() + # fea_for_awb = F.adaptive_avg_pool2d(fea_down8, 1).view(-1, 64) + awb_weight = 1 # (1 + self.awb_para(fea_for_awb).unsqueeze(2).unsqueeze(3)) + low_after_awb = raw_low_input * awb_weight + # import pdb + # pdb.set_trace() + color_map = low_after_awb / (low_after_awb.sum(dim=1, keepdims=True) + 1e-4) + dx, dy = self.gradient(color_map) + noise_map = torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0] + # color_map = self.fine_tune_color_map(torch.cat([color_map, noise_map], dim=1)) + + fea = self.conv_first(torch.cat([x, color_map, noise_map], dim=1)) + fea = self.lrelu(fea) + fea = self.conv_second(fea) + fea_head = F.max_pool2d(fea, 2) + + block_idxs = 
opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_results = {} + fea = fea_head + for idx, m in enumerate(self.RRDB_trunk.children()): + fea = m(fea) + for b in block_idxs: + if b == idx: + block_results["block_{}".format(idx)] = fea + trunk = self.trunk_conv(fea) + # fea = F.max_pool2d(fea, 2) + fea_down2 = fea_head + trunk + + fea_down4 = self.downconv1(F.interpolate(fea_down2, scale_factor=1 / 2, mode='bilinear', align_corners=False, + recompute_scale_factor=True)) + fea = self.lrelu(fea_down4) + + fea_down8 = self.downconv2( + F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True)) + # fea = self.lrelu(fea_down8) + + # fea_down16 = self.downconv3( + # F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True)) + # fea = self.lrelu(fea_down16) + + results = {'fea_up0': fea_down8, + 'fea_up1': fea_down4, + 'fea_up2': fea_down2, + 'fea_up4': fea_head, + 'last_lr_fea': fea_down4, + 'color_map': self.fine_tune_color_map(F.interpolate(fea_down2, scale_factor=2)) + } + + # 'color_map': color_map} # raw + + if get_steps: + for k, v in block_results.items(): + results[k] = v + return results + else: + return None + + def gradient(self, x): + def sub_gradient(x): + left_shift_x, right_shift_x, grad = torch.zeros_like( + x), torch.zeros_like(x), torch.zeros_like(x) + left_shift_x[:, :, 0:-1] = x[:, :, 1:] + right_shift_x[:, :, 1:] = x[:, :, 0:-1] + grad = 0.5 * (left_shift_x - right_shift_x) + return grad + + return sub_gradient(x), sub_gradient(torch.transpose(x, 2, 3)).transpose(2, 3) + + +class NoEncoder(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): + self.opt = opt + self.gray_map_bool = False + self.concat_color_map = False + if opt['concat_histeq']: + in_nc = in_nc + 3 + if opt['concat_color_map']: + in_nc = in_nc + 3 + self.concat_color_map = True + if opt['gray_map']: + in_nc = in_nc + 1 + self.gray_map_bool = True + in_nc = in_nc + 6 + super(NoEncoder, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.scale = scale + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv_second = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### downsampling + self.downconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.downconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.downconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + # self.downconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.awb_para = nn.Linear(nf, 3) + self.fine_tune_color_map = nn.Sequential(nn.Conv2d(nf, 3, 1, 1),nn.Sigmoid()) + + def forward(self, x, get_steps=False): + if self.gray_map_bool: + x = torch.cat([x, 1 - x.mean(dim=1, keepdim=True)], dim=1) + if self.concat_color_map: + x = torch.cat([x, x / (x.sum(dim=1, keepdim=True) + 1e-4)], dim=1) + + raw_low_input = x[:, 0:3].exp() + # fea_for_awb = F.adaptive_avg_pool2d(fea_down8, 1).view(-1, 64) + awb_weight = 1 # (1 + self.awb_para(fea_for_awb).unsqueeze(2).unsqueeze(3)) + low_after_awb = raw_low_input * awb_weight + # import pdb + # pdb.set_trace() + color_map = low_after_awb / (low_after_awb.sum(dim=1, keepdims=True) + 1e-4) + dx, dy = self.gradient(color_map) + noise_map = 
torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0] + # color_map = self.fine_tune_color_map(torch.cat([color_map, noise_map], dim=1)) + + fea = self.conv_first(torch.cat([x, color_map, noise_map], dim=1)) + fea = self.lrelu(fea) + fea = self.conv_second(fea) + fea_head = F.max_pool2d(fea, 2) + + block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_results = {} + fea = fea_head + for idx, m in enumerate(self.RRDB_trunk.children()): + fea = m(fea) + for b in block_idxs: + if b == idx: + block_results["block_{}".format(idx)] = fea + trunk = self.trunk_conv(fea) + # fea = F.max_pool2d(fea, 2) + fea_down2 = fea_head + trunk + + fea_down4 = self.downconv1(F.interpolate(fea_down2, scale_factor=1 / 2, mode='bilinear', align_corners=False, + recompute_scale_factor=True)) + fea = self.lrelu(fea_down4) + + fea_down8 = self.downconv2( + F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True)) + # fea = self.lrelu(fea_down8) + + # fea_down16 = self.downconv3( + # F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True)) + # fea = self.lrelu(fea_down16) + + results = {'fea_up0': fea_down8*0, + 'fea_up1': fea_down4*0, + 'fea_up2': fea_down2*0, + 'fea_up4': fea_head*0, + 'last_lr_fea': fea_down4*0, + 'color_map': self.fine_tune_color_map(F.interpolate(fea_down2, scale_factor=2))*0 + } + + # 'color_map': color_map} # raw + + if get_steps: + for k, v in block_results.items(): + results[k] = v + return results + else: + return None + + def gradient(self, x): + def sub_gradient(x): + left_shift_x, right_shift_x, grad = torch.zeros_like( + x), torch.zeros_like(x), torch.zeros_like(x) + left_shift_x[:, :, 0:-1] = x[:, :, 1:] + right_shift_x[:, :, 1:] = x[:, :, 0:-1] + grad = 0.5 * (left_shift_x - right_shift_x) + return grad + + return sub_gradient(x), sub_gradient(torch.transpose(x, 2, 3)).transpose(2, 3) diff --git a/models/llflow/models/modules/FlowActNorms.py b/models/llflow/models/modules/FlowActNorms.py new file mode 100644 index 0000000000000000000000000000000000000000..2e45d0a39f14745b70ccc74ffefe6489d8e6059d --- /dev/null +++ b/models/llflow/models/modules/FlowActNorms.py @@ -0,0 +1,128 @@ + + + +import torch +from torch import nn as nn + +from models.modules import thops + + +class _ActNorm(nn.Module): + """ + Activation Normalization + Initialize the bias and scale with a given minibatch, + so that the output per-channel have zero mean and unit variance for that. + + After initialization, `bias` and `logs` will be trained as parameters. 
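+
+    Data-dependent initialization (as in Glow): on the first training batch,
+    `bias` is set to the negative per-channel mean and `logs` to
+    log(scale / (std + 1e-6)), so the first normalized activations are
+    approximately N(0, scale^2) per channel.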
+ """ + + def __init__(self, num_features, scale=1.): + super().__init__() + # register mean and scale + size = [1, num_features, 1, 1] + self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) + self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) + self.num_features = num_features + self.scale = float(scale) + self.inited = False + + def _check_input_dim(self, input): + return NotImplemented + + def initialize_parameters(self, input): + self._check_input_dim(input) + if not self.training: + return + if (self.bias != 0).any(): + self.inited = True + return + assert input.device == self.bias.device, (input.device, self.bias.device) + with torch.no_grad(): + bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 + vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) + self.bias.data.copy_(bias.data) + self.logs.data.copy_(logs.data) + self.inited = True + + def _center(self, input, reverse=False, offset=None): + bias = self.bias + + if offset is not None: + bias = bias + offset + + if not reverse: + return input + bias + else: + return input - bias + + def _scale(self, input, logdet=None, reverse=False, offset=None): + logs = self.logs + + if offset is not None: + logs = logs + offset + + if not reverse: + input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 + # input = input * torch.exp(logs+logs_offset) + else: + input = input * torch.exp(-logs) + if logdet is not None: + """ + logs is log_std of `mean of channels` + so we need to multiply pixels + """ + dlogdet = thops.sum(logs) * thops.pixels(input) + if reverse: + dlogdet *= -1 + logdet = logdet + dlogdet + return input, logdet + + def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): + if not self.inited: + self.initialize_parameters(input) + self._check_input_dim(input) + + if offset_mask is not None: + logs_offset *= offset_mask + bias_offset *= offset_mask + # no need to permute dims as old version + if not reverse: + # center and scale + + # self.input = input + input = self._center(input, reverse, bias_offset) + input, logdet = self._scale(input, logdet, reverse, logs_offset) + else: + # scale and center + input, logdet = self._scale(input, logdet, reverse, logs_offset) + input = self._center(input, reverse, bias_offset) + return input, logdet + + +class ActNorm2d(_ActNorm): + def __init__(self, num_features, scale=1.): + super().__init__(num_features, scale) + + def _check_input_dim(self, input): + assert len(input.size()) == 4 + assert input.size(1) == self.num_features, ( + "[ActNorm]: input should be in shape as `BCHW`," + " channels should be {} rather than {}".format( + self.num_features, input.size())) + + +class MaskedActNorm2d(ActNorm2d): + def __init__(self, num_features, scale=1.): + super().__init__(num_features, scale) + + def forward(self, input, mask, logdet=None, reverse=False): + + assert mask.dtype == torch.bool + output, logdet_out = super().forward(input, logdet, reverse) + + input[mask] = output[mask] + logdet[mask] = logdet_out[mask] + + return input, logdet + diff --git a/models/llflow/models/modules/FlowAffineCouplingsAblation.py b/models/llflow/models/modules/FlowAffineCouplingsAblation.py new file mode 100644 index 0000000000000000000000000000000000000000..19f1022bb886aac67d71b9af69bea71c1ce4b106 --- /dev/null +++ b/models/llflow/models/modules/FlowAffineCouplingsAblation.py @@ -0,0 +1,169 @@ + +import torch 
+from torch import nn as nn + +from models.modules import thops +from models.modules.flow import Conv2d, Conv2dZeros +# from utils.util import opt_get + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + +class CondAffineSeparatedAndCond(nn.Module): + def __init__(self, in_channels, opt): + super().__init__() + self.need_features = True + self.in_channels = in_channels + self.in_channels_rrdb = opt_get(opt, ['network_G', 'flow', 'conditionInFeaDim'], 320) + self.kernel_hidden = 1 + self.affine_eps = 0.0001 + self.n_hidden_layers = 1 + hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) + self.hidden_channels = 64 if hidden_channels is None else hidden_channels + + self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) + + self.channels_for_nn = self.in_channels // 2 + self.channels_for_co = self.in_channels - self.channels_for_nn + + if self.channels_for_nn is None: + self.channels_for_nn = self.in_channels // 2 + + self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb, + out_channels=self.channels_for_co * 2, + hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, + n_hidden_layers=self.n_hidden_layers) + + self.fFeatures = self.F(in_channels=self.in_channels_rrdb, + out_channels=self.in_channels * 2, + hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, + n_hidden_layers=self.n_hidden_layers) + self.opt = opt + self.le_curve = opt['le_curve'] if opt['le_curve'] is not None else False + if self.le_curve: + self.fCurve = self.F(in_channels=self.in_channels_rrdb, + out_channels=self.in_channels, + hidden_channels=self.hidden_channels, + kernel_hidden=self.kernel_hidden, + n_hidden_layers=self.n_hidden_layers) + + def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None): + if not reverse: + z = input + assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels) + + # Feature Conditional + scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) + z = z + shiftFt + z = z * scaleFt + logdet = logdet + self.get_logdet(scaleFt) + + # Curve conditional + if self.le_curve: + # logdet = logdet + thops.sum(torch.log(torch.sigmoid(z) * (1 - torch.sigmoid(z))), dim=[1, 2, 3]) + # z = torch.sigmoid(z) + # alpha = self.fCurve(ft) + # alpha = (torch.tanh(alpha + 2.) + self.affine_eps) + # logdet = logdet + thops.sum(torch.log((1 + alpha - 2 * z * alpha).abs()), dim=[1, 2, 3]) + # z = z + alpha * z * (1 - z) + + alpha = self.fCurve(ft) + # alpha = (torch.sigmoid(alpha + 2.) 
+ self.affine_eps) + alpha = torch.relu(alpha) + self.affine_eps + logdet = logdet + thops.sum(torch.log(alpha * torch.pow(z.abs(), alpha - 1)) + self.affine_eps) + z = torch.pow(z.abs(), alpha) * z.sign() + + # Self Conditional + z1, z2 = self.split(z) + scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) + self.asserts(scale, shift, z1, z2) + z2 = z2 + shift + z2 = z2 * scale + + logdet = logdet + self.get_logdet(scale) + z = thops.cat_feature(z1, z2) + output = z + else: + z = input + + # Self Conditional + z1, z2 = self.split(z) + scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) + self.asserts(scale, shift, z1, z2) + z2 = z2 / scale + z2 = z2 - shift + z = thops.cat_feature(z1, z2) + logdet = logdet - self.get_logdet(scale) + + # Curve conditional + if self.le_curve: + # alpha = self.fCurve(ft) + # alpha = (torch.sigmoid(alpha + 2.) + self.affine_eps) + # z = (1 + alpha) / alpha - ( + # alpha + torch.pow(2 * alpha - 4 * alpha * z + torch.pow(alpha, 2) + 1, 0.5) + 1) / ( + # 2 * alpha) + # z = torch.log((z / (1 - z)).clamp(1 / 1000, 1000)) + + alpha = self.fCurve(ft) + alpha = torch.relu(alpha) + self.affine_eps + # alpha = (torch.sigmoid(alpha + 2.) + self.affine_eps) + z = torch.pow(z.abs(), 1 / alpha) * z.sign() + + # Feature Conditional + scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) + z = z / scaleFt + z = z - shiftFt + logdet = logdet - self.get_logdet(scaleFt) + + output = z + return output, logdet + + def asserts(self, scale, shift, z1, z2): + assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn) + assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co) + assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1]) + assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1]) + + def get_logdet(self, scale): + return thops.sum(torch.log(scale), dim=[1, 2, 3]) + + def feature_extract(self, z, f): + h = f(z) + shift, scale = thops.split_feature(h, "cross") + scale = (torch.sigmoid(scale + 2.) + self.affine_eps) + return scale, shift + + def feature_extract_aff(self, z1, ft, f): + z = torch.cat([z1, ft], dim=1) + h = f(z) + shift, scale = thops.split_feature(h, "cross") + scale = (torch.sigmoid(scale + 2.) 
+ self.affine_eps) + return scale, shift + + def split(self, z): + z1 = z[:, :self.channels_for_nn] + z2 = z[:, self.channels_for_nn:] + assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1]) + return z1, z2 + + def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1): + layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)] + + for _ in range(n_hidden_layers): + layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden])) + layers.append(nn.ReLU(inplace=False)) + layers.append(Conv2dZeros(hidden_channels, out_channels)) + + return nn.Sequential(*layers) diff --git a/models/llflow/models/modules/FlowStep.py b/models/llflow/models/modules/FlowStep.py new file mode 100644 index 0000000000000000000000000000000000000000..89472079f645a78ea165cd9babee546b77b535dd --- /dev/null +++ b/models/llflow/models/modules/FlowStep.py @@ -0,0 +1,136 @@ + + + +import torch +from torch import nn as nn + +import models.modules +import models.modules.Permutations +from models.modules import flow, thops, FlowAffineCouplingsAblation +# from utils.util import opt_get + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + +def getConditional(rrdbResults, position): + img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position] + return img_ft + + +class FlowStep(nn.Module): + FlowPermutation = { + "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet), + "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet), + "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), + } + + def __init__(self, in_channels, hidden_channels, + actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive", + LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, + position=None): + # check configures + assert flow_permutation in FlowStep.FlowPermutation, \ + "float_permutation should be in `{}`".format( + FlowStep.FlowPermutation.keys()) + super().__init__() + self.flow_permutation = flow_permutation + self.flow_coupling = flow_coupling + self.image_injector = image_injector + + self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d' + self.position = normOpt['position'] if normOpt else None + + self.in_shape = in_shape + self.position = position + self.acOpt = acOpt + + # 1. actnorm + self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) + + # 2. permute + if flow_permutation == "invconv": + self.invconv = models.modules.Permutations.InvertibleConv1x1( + in_channels, LU_decomposed=LU_decomposed) + + # 3. 
coupling + if flow_coupling == "CondAffineSeparatedAndCond": + self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, + opt=opt) + elif flow_coupling == "noCoupling": + pass + else: + raise RuntimeError("coupling not Found:", flow_coupling) + + def forward(self, input, logdet=None, reverse=False, rrdbResults=None): + if not reverse: + return self.normal_flow(input, logdet, rrdbResults) + else: + return self.reverse_flow(input, logdet, rrdbResults) + + def normal_flow(self, z, logdet, rrdbResults=None): + if self.flow_coupling == "bentIdentityPreAct": + z, logdet = self.bentIdentPar(z, logdet, reverse=False) + + # 1. actnorm + if self.norm_type == "ConditionalActNormImageInjector": + img_ft = getConditional(rrdbResults, self.position) + z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False) + elif self.norm_type == "noNorm": + pass + else: + z, logdet = self.actnorm(z, logdet=logdet, reverse=False) + + # 2. permute + z, logdet = FlowStep.FlowPermutation[self.flow_permutation]( + self, z, logdet, False) + + need_features = self.affine_need_features() + + # 3. coupling + if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: + img_ft = getConditional(rrdbResults, self.position) + z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft) + return z, logdet + + def reverse_flow(self, z, logdet, rrdbResults=None): + + need_features = self.affine_need_features() + + # 1.coupling + if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: + img_ft = getConditional(rrdbResults, self.position) + z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft) + + # 2. permute + z, logdet = FlowStep.FlowPermutation[self.flow_permutation]( + self, z, logdet, True) + + # 3. 
actnorm + z, logdet = self.actnorm(z, logdet=logdet, reverse=True) + + return z, logdet + + def affine_need_features(self): + need_features = False + try: + need_features = self.affine.need_features + except: + pass + return need_features diff --git a/models/llflow/models/modules/FlowUpsamplerNet.py b/models/llflow/models/modules/FlowUpsamplerNet.py new file mode 100644 index 0000000000000000000000000000000000000000..40d02adf49305c4335048a7bc12a66827bb0e7e8 --- /dev/null +++ b/models/llflow/models/modules/FlowUpsamplerNet.py @@ -0,0 +1,328 @@ + + + +import numpy as np +import torch +from torch import nn as nn + +import models.modules.Split +from models.modules import flow, thops +from models.modules.Split import Split2d +from models.modules.glow_arch import f_conv2d_bias +from models.modules.FlowStep import FlowStep +# from utils.util import opt_get + + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + + +class FlowUpsamplerNet(nn.Module): + def __init__(self, image_shape, hidden_channels, K, L=None, + actnorm_scale=1.0, + flow_permutation=None, + flow_coupling="affine", + LU_decomposed=False, opt=None): + + super().__init__() + self.hr_size = opt['datasets']['train']['GT_size'] + self.layers = nn.ModuleList() + self.output_shapes = [] + self.sigmoid_output = opt['sigmoid_output'] if opt['sigmoid_output'] is not None else False + self.L = opt_get(opt, ['network_G', 'flow', 'L']) + self.K = opt_get(opt, ['network_G', 'flow', 'K']) + if isinstance(self.K, int): + self.K = [K for K in [K, ] * (self.L + 1)] + + self.opt = opt + H, W, self.C = image_shape + self.check_image_shape() + + if opt['scale'] == 16: + self.levelToName = { + 0: 'fea_up16', + 1: 'fea_up8', + 2: 'fea_up4', + 3: 'fea_up2', + 4: 'fea_up1', + } + + if opt['scale'] == 8: + self.levelToName = { + 0: 'fea_up8', + 1: 'fea_up4', + 2: 'fea_up2', + 3: 'fea_up1', + 4: 'fea_up0' + } + + elif opt['scale'] == 4: + self.levelToName = { + 0: 'fea_up4', + 1: 'fea_up2', + 2: 'fea_up1', + 3: 'fea_up0', + 4: 'fea_up-1' + } + elif opt['scale'] == 1: + self.levelToName = { + # 0: 'fea_up4', + 1: 'fea_up2', + 2: 'fea_up1', + 3: 'fea_up0', + # 4: 'fea_up-1' + } + + affineInCh = self.get_affineInCh(opt_get) + flow_permutation = self.get_flow_permutation(flow_permutation, opt) + + normOpt = opt_get(opt, ['network_G', 'flow', 'norm']) + + conditional_channels = {} + n_rrdb = self.get_n_rrdb_channels(opt, opt_get) + n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels']) + conditional_channels[0] = n_rrdb + for level in range(1, self.L + 1): + # Level 1 gets conditionals from 2, 3, 4 => L - level + # Level 2 gets conditionals from 3, 4 + # Level 3 gets conditionals from 4 + # Level 4 gets conditionals from None + n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels + conditional_channels[level] = n_rrdb + n_bypass + + # Upsampler + for level in range(1, self.L + 1): + # 1. Squeeze + H, W = self.arch_squeeze(H, W) + + # 2. 
K FlowStep + self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt) + self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, + flow_permutation, + hidden_channels, normOpt, opt, opt_get, + n_conditinal_channels=conditional_channels[level]) + # Split + self.arch_split(H, W, level, self.L, opt, opt_get) + + if opt_get(opt, ['network_G', 'flow', 'split', 'enable']): + self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) + else: + self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64) + + self.H = H + self.W = W + self.scaleH = opt['datasets']['train']['GT_size'] / H + self.scaleW = opt['datasets']['train']['GT_size'] / W + + def get_n_rrdb_channels(self, opt, opt_get): + blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) + n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 + return n_rrdb + + def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, + hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None): + condAff = self.get_condAffSetting(opt, opt_get) + if condAff is not None: + condAff['in_channels_rrdb'] = n_conditinal_channels + + for k in range(K): + position_name = get_position_name(H, self.opt['scale'], opt) + if normOpt: normOpt['position'] = position_name + + self.layers.append( + FlowStep(in_channels=self.C, + hidden_channels=hidden_channels, + actnorm_scale=actnorm_scale, + flow_permutation=flow_permutation, + flow_coupling=flow_coupling, + acOpt=condAff, + position=position_name, + LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt)) + self.output_shapes.append( + [-1, self.C, H, W]) + + def get_condAffSetting(self, opt, opt_get): + condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None + condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff + return condAff + + def arch_split(self, H, W, L, levels, opt, opt_get): + correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False) + correction = 0 if correct_splits else 1 + if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction: + logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0 + consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5 + position_name = get_position_name(H, self.opt['scale'], opt) + position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None + cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels']) + cond_channels = 0 if cond_channels is None else cond_channels + + t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d') + + if t == 'Split2d': + split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, + cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt) + self.layers.append(split) + self.output_shapes.append([-1, split.num_channels_pass, H, W]) + self.C = split.num_channels_pass + + def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): + if 'additionalFlowNoAffine' in opt['network_G']['flow']: + n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine']) + for _ in range(n_additionalFlowNoAffine): + self.layers.append( + FlowStep(in_channels=self.C, + hidden_channels=hidden_channels, + actnorm_scale=actnorm_scale, + flow_permutation='invconv', + flow_coupling='noCoupling', + LU_decomposed=LU_decomposed, opt=opt)) + 
self.output_shapes.append( + [-1, self.C, H, W]) + + def arch_squeeze(self, H, W): + self.C, H, W = self.C * 4, H // 2, W // 2 + self.layers.append(flow.SqueezeLayer(factor=2)) + self.output_shapes.append([-1, self.C, H, W]) + return H, W + + def get_flow_permutation(self, flow_permutation, opt): + flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv') + return flow_permutation + + def get_affineInCh(self, opt_get): + affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + affineInCh = (len(affineInCh) + 1) * 64 + return affineInCh + + def check_image_shape(self): + assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3)" + "self.C == 1 or self.C == 3") + + def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None, + y_onehot=None): + + if reverse: + epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses + + sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot) + if self.sigmoid_output: + sr = torch.sigmoid(sr) + return sr, logdet + else: + assert gt is not None + # assert rrdbResults is not None + if self.sigmoid_output: + gt = torch.log((gt / (1 - gt)).clamp(1 / 1000, 1000)) + z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) + + return z, logdet + + def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None): + fl_fea = gt + reverse = False + level_conditionals = {} + bypasses = {} + + L = opt_get(self.opt, ['network_G', 'flow', 'L']) + + for level in range(1, L + 1): + bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', + align_corners=False) + + for layer, shape in zip(self.layers, self.output_shapes): + size = shape[2] + level = int(np.log(self.hr_size / size) / np.log(2)) + if level > 0 and level not in level_conditionals.keys(): + if rrdbResults is None: + level_conditionals[level] = None + else: + level_conditionals[level] = rrdbResults[self.levelToName[level]] + + if isinstance(layer, FlowStep): + fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level]) + elif isinstance(layer, Split2d): + fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level], + y_onehot=y_onehot) + else: + fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse) + + z = fl_fea + + if not isinstance(epses, list): + return z, logdet + + epses.append(z) + return epses, logdet + + def forward_preFlow(self, fl_fea, logdet, reverse): + if hasattr(self, 'preFlow'): + for l in self.preFlow: + fl_fea, logdet = l(fl_fea, logdet, reverse=reverse) + return fl_fea, logdet + + def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None): + ft = None if layer.position is None else rrdbResults[layer.position] + fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) + + if isinstance(epses, list): + epses.append(eps) + return fl_fea, logdet + + def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None): + z = epses.pop() if isinstance(epses, list) else z + + fl_fea = z + # debug.imwrite("fl_fea", fl_fea) + bypasses = {} + level_conditionals = {} + if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: + for level in range(self.L + 1): + if level not in self.levelToName.keys(): + level_conditionals[level] = None + else: + 
level_conditionals[level] = rrdbResults[self.levelToName[level]] if rrdbResults else None + + for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): + size = shape[2] + level = int(np.log(self.hr_size / size) / np.log(2)) + # size = fl_fea.shape[2] + # level = int(np.log(160 / size) / np.log(2)) + + if isinstance(layer, Split2d): + fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, + rrdbResults[self.levelToName[level]], logdet=logdet, + y_onehot=y_onehot) + elif isinstance(layer, FlowStep): + fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level]) + else: + fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True) + + sr = fl_fea + + assert sr.shape[1] == 3 + return sr, logdet + + def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None): + ft = None if layer.position is None else rrdbResults[layer.position] + fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, + eps=epses.pop() if isinstance(epses, list) else None, + eps_std=eps_std, ft=ft, y_onehot=y_onehot) + return fl_fea, logdet + + +def get_position_name(H, scale, opt): + downscale_factor = opt['datasets']['train']['GT_size'] // H + position_name = 'fea_up{}'.format(scale / downscale_factor) + return position_name diff --git a/models/llflow/models/modules/LLFlow_arch.py b/models/llflow/models/modules/LLFlow_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..c36642bfcf2b179999d447fabe2523f23f5f30eb --- /dev/null +++ b/models/llflow/models/modules/LLFlow_arch.py @@ -0,0 +1,248 @@ + + + +import math +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from models.modules.RRDBNet_arch import RRDBNet +from models.modules.ConditionEncoder import ConEncoder1, NoEncoder +from models.modules.FlowUpsamplerNet import FlowUpsamplerNet +import models.modules.thops as thops +import models.modules.flow as flow +from models.modules.color_encoder import ColorEncoder +# from utils.util import opt_get +from models.modules.flow import unsqueeze2d, squeeze2d +from torch.cuda.amp import autocast + + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + + +class LLFlow(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None): + super(LLFlow, self).__init__() + self.crop_size = opt['datasets']['train']['GT_size'] + self.opt = opt + self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ + None else opt_get(opt, ['datasets', 'train', 'quant']) + if opt['cond_encoder'] == 'ConEncoder1': + self.RRDB = ConEncoder1(in_nc, out_nc, nf, nb, gc, scale, opt) + elif opt['cond_encoder'] == 'NoEncoder': + self.RRDB = None # NoEncoder(in_nc, out_nc, nf, nb, gc, scale, opt) + elif opt['cond_encoder'] == 'RRDBNet': + # if self.opt['encode_color_map']: print('Warning: ''encode_color_map'' is not implemented in RRDBNet') + self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) + else: + print('WARNING: Cannot find the conditional encoder %s, select RRDBNet by default.' 
% opt['cond_encoder']) + # if self.opt['encode_color_map']: print('Warning: ''encode_color_map'' is not implemented in RRDBNet') + opt['cond_encoder'] = 'RRDBNet' + self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) + + if self.opt['encode_color_map']: + self.color_map_encoder = ColorEncoder(nf=nf, opt=opt) + + hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) + hidden_channels = hidden_channels or 64 + self.RRDB_training = True # Default is true + + train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) + set_RRDB_to_train = False + if set_RRDB_to_train and self.RRDB: + self.set_rrdb_training(True) + + self.flowUpsamplerNet = \ + FlowUpsamplerNet((self.crop_size, self.crop_size, 3), hidden_channels, K, + flow_coupling=opt['network_G']['flow']['coupling'], opt=opt) + self.i = 0 + if self.opt['to_yuv']: + self.A_rgb2yuv = torch.nn.Parameter(torch.tensor([[0.299, -0.14714119, 0.61497538], + [0.587, -0.28886916, -0.51496512], + [0.114, 0.43601035, -0.10001026]]), requires_grad=False) + self.A_yuv2rgb = torch.nn.Parameter(torch.tensor([[1., 1., 1.], + [0., -0.39465, 2.03211], + [1.13983, -0.58060, 0]]), requires_grad=False) + if self.opt['align_maxpool']: + self.max_pool = torch.nn.MaxPool2d(3) + + def set_rrdb_training(self, trainable): + if self.RRDB_training != trainable: + for p in self.RRDB.parameters(): + p.requires_grad = trainable + self.RRDB_training = trainable + return True + return False + + def rgb2yuv(self, rgb): + rgb_ = rgb.transpose(1, 3) # input is 3*n*n default + yuv = torch.tensordot(rgb_, self.A_rgb2yuv, 1).transpose(1, 3) + return yuv + + def yuv2rgb(self, yuv): + yuv_ = yuv.transpose(1, 3) # input is 3*n*n default + rgb = torch.tensordot(yuv_, self.A_yuv2rgb, 1).transpose(1, 3) + return rgb + + @autocast() + def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, + lr_enc=None, + add_gt_noise=False, step=None, y_label=None, align_condition_feature=False, get_color_map=False): + if get_color_map: + color_lr = self.color_map_encoder(lr) + color_gt = nn.functional.avg_pool2d(gt, 11, 1, 5) + color_gt = color_gt / torch.sum(color_gt, 1, keepdim=True) + return color_lr, color_gt + if not reverse: + if epses is not None and gt.device.index is not None: + epses = epses[gt.device.index] + return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, + y_onehot=y_label, align_condition_feature=align_condition_feature) + else: + # assert lr.shape[0] == 1 + assert lr.shape[1] == 3 or lr.shape[1] == 6 + # assert lr.shape[2] == 20 + # assert lr.shape[3] == 20 + # assert z.shape[0] == 1 + # assert z.shape[1] == 3 * 8 * 8 + # assert z.shape[2] == 20 + # assert z.shape[3] == 20 + if reverse_with_grad: + return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, + add_gt_noise=add_gt_noise) + else: + with torch.no_grad(): + return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, + add_gt_noise=add_gt_noise) + + def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None, + align_condition_feature=False): + if self.opt['to_yuv']: + gt = self.rgb2yuv(gt) + if lr_enc is None and self.RRDB: + lr_enc = self.rrdbPreprocessing(lr) + + logdet = torch.zeros_like(gt[:, 0, 0, 0]) + pixels = thops.pixels(gt) + + z = gt + + if add_gt_noise: + # Setup + noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) + if 
noiseQuant: + z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) + logdet = logdet + float(-np.log(self.quant) * pixels) + + # Encode + epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses, + y_onehot=y_onehot) + + objective = logdet.clone() + + # if isinstance(epses, (list, tuple)): + # z = epses[-1] + # else: + # z = epses + z = epses + if 'avg_color_map' in self.opt.keys() and self.opt['avg_color_map']: + if 'avg_pool_color_map' in self.opt.keys() and self.opt['avg_pool_color_map']: + mean = squeeze2d(F.avg_pool2d(lr_enc['color_map'], 7, 1, 3), 8) if random.random() > self.opt[ + 'train_gt_ratio'] else squeeze2d(F.avg_pool2d( + gt / (gt.sum(dim=1, keepdims=True) + 1e-4), 7, 1, 3), 8) + else: + if self.RRDB is not None: + mean = squeeze2d(lr_enc['color_map'], 8) if random.random() > self.opt['train_gt_ratio'] else squeeze2d( + gt/(gt.sum(dim=1, keepdims=True) + 1e-4), 8) + else: + mean = squeeze2d(lr[:,:3],8) + objective = objective + flow.GaussianDiag.logp(mean, torch.tensor(0.).to(z.device), z) + + nll = (-objective) / float(np.log(2.) * pixels) + if self.opt['encode_color_map']: + color_map = self.color_map_encoder(lr) + color_gt = nn.functional.avg_pool2d(gt, 11, 1, 5) + color_gt = color_gt / torch.sum(color_gt, 1, keepdim=True) + color_loss = (color_gt - color_map).abs().mean() + nll = nll + color_loss + if align_condition_feature: + with torch.no_grad(): + gt_enc = self.rrdbPreprocessing(gt) + for k, v in gt_enc.items(): + if k in ['fea_up-1']: # ['fea_up2','fea_up1','fea_up0','fea_up-1']: + if self.opt['align_maxpool']: + nll = nll + (self.max_pool(gt_enc[k]) - self.max_pool(lr_enc[k])).abs().mean() * ( + self.opt['align_weight'] if self.opt['align_weight'] is not None else 1) + else: + nll = nll + (gt_enc[k] - lr_enc[k]).abs().mean() * ( + self.opt['align_weight'] if self.opt['align_weight'] is not None else 1) + if isinstance(epses, list): + return epses, nll, logdet + return z, nll, logdet + + def rrdbPreprocessing(self, lr): + rrdbResults = self.RRDB(lr, get_steps=True) + block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + if len(block_idxs) > 0: + low_level_features = [rrdbResults["block_{}".format(idx)] for idx in block_idxs] + # low_level_features.append(rrdbResults['color_map']) + concat = torch.cat(low_level_features, dim=1) + + if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: + keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] + if 'fea_up0' in rrdbResults.keys(): + keys.append('fea_up0') + if 'fea_up-1' in rrdbResults.keys(): + keys.append('fea_up-1') + for k in keys: + h = rrdbResults[k].shape[2] + w = rrdbResults[k].shape[3] + rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) + return rrdbResults + + def get_score(self, disc_loss_sigma, z): + score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \ + z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) + return -score_real + + def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): + + logdet = torch.zeros_like(lr[:, 0, 0, 0]) + pixels = thops.pixels(lr) * self.opt['scale'] ** 2 + + if add_gt_noise: + logdet = logdet - float(-np.log(self.quant) * pixels) + + if lr_enc is None and self.RRDB: + lr_enc = self.rrdbPreprocessing(lr) + if self.opt['cond_encoder'] == "NoEncoder": + z = squeeze2d(lr[:,:3],8) + else: + if 'avg_color_map' in self.opt.keys() and 
self.opt['avg_color_map']: + z = squeeze2d(F.avg_pool2d(lr_enc['color_map'], 7, 1, 3), 8) + else: + z = squeeze2d(lr_enc['color_map'], 8) + x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, + logdet=logdet) + if self.opt['encode_color_map']: + color_map = self.color_map_encoder(lr) + color_out = nn.functional.avg_pool2d(x, 11, 1, 5) + color_out = color_out / torch.sum(color_out, 1, keepdim=True) + x = x * (color_map / color_out) + if self.opt['to_yuv']: + x = self.yuv2rgb(x) + return x, logdet diff --git a/models/llflow/models/modules/Permutations.py b/models/llflow/models/modules/Permutations.py new file mode 100644 index 0000000000000000000000000000000000000000..259eb3cf92821628c1ebc31641ce694ac00682d6 --- /dev/null +++ b/models/llflow/models/modules/Permutations.py @@ -0,0 +1,59 @@ + + + +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from models.modules import thops + + +class InvertibleConv1x1(nn.Module): + def __init__(self, num_channels, LU_decomposed=False): + super().__init__() + w_shape = [num_channels, num_channels] + w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) + self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) + self.w_shape = w_shape + self.LU = LU_decomposed + + def get_weight(self, input, reverse): + w_shape = self.w_shape + pixels = thops.pixels(input) + dlogdet = torch.tensor(float('inf')) + while torch.isinf(dlogdet): + try: + dlogdet = torch.slogdet(self.weight)[1] * pixels + except Exception as e: + print(e) + dlogdet = \ + torch.slogdet( + self.weight + (self.weight.mean() * torch.randn(*self.w_shape).to(input.device) * 0.001))[ + 1] * pixels + if not reverse: + weight = self.weight.view(w_shape[0], w_shape[1], 1, 1) + else: + try: + weight = torch.inverse(self.weight.double()).float() \ + .view(w_shape[0], w_shape[1], 1, 1) + except: + weight = torch.inverse(self.weight.double()+ (self.weight.mean() * torch.randn(*self.w_shape).to(input.device) * 0.001).float() \ + .view(w_shape[0], w_shape[1], 1, 1)) + return weight, dlogdet + + def forward(self, input, logdet=None, reverse=False): + """ + log-det = log|abs(|W|)| * pixels + """ + weight, dlogdet = self.get_weight(input, reverse) + if not reverse: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet + dlogdet + return z, logdet + else: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet - dlogdet + return z, logdet diff --git a/models/llflow/models/modules/RRDBNet_arch.py b/models/llflow/models/modules/RRDBNet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..bcebf55ebadc78b124eb787e08a8e58360c9a07f --- /dev/null +++ b/models/llflow/models/modules/RRDBNet_arch.py @@ -0,0 +1,147 @@ + + + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.modules.module_util as mutil +# from utils.util import opt_get + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): + self.opt = opt + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.scale = scale + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 2, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.scale >= 8: + self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.scale >= 16: + self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.scale >= 32: + self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, get_steps=False): + fea = self.conv_first(x) + + block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] + block_results = {} + + for idx, m in enumerate(self.RRDB_trunk.children()): + fea = m(fea) + for b in block_idxs: + if b == idx: + block_results["block_{}".format(idx)] = fea + trunk = self.trunk_conv(fea) + fea = F.max_pool2d(fea, 2) + last_lr_fea = fea + trunk + + fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) + fea = self.lrelu(fea_up2) + + fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea = self.lrelu(fea_up4) + + fea_up8 = None + fea_up16 = None + fea_up32 = None + + if self.scale >= 8: + fea_up8 = self.upconv3(fea) + fea = self.lrelu(fea_up8) + if self.scale >= 16: + fea_up16 = self.upconv4(fea) + fea = self.lrelu(fea_up16) + if self.scale >= 32: + fea_up32 = self.upconv5(fea) + fea = self.lrelu(fea_up32) + + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + results = {'last_lr_fea': last_lr_fea, + 'fea_up1': last_lr_fea, + 'fea_up2': fea_up2, + 'fea_up4': fea_up4, # raw + 'fea_up8': fea_up8, + 'fea_up16': fea_up16, + 'fea_up32': fea_up32, + 'out': out} # raw + + fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False + if fea_up0_en: + results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, 
recompute_scale_factor=True) + fea_upn1_en = True # opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False + if fea_upn1_en: + results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) + + if get_steps: + for k, v in block_results.items(): + results[k] = v + return results + else: + return out diff --git a/models/llflow/models/modules/Split.py b/models/llflow/models/modules/Split.py new file mode 100644 index 0000000000000000000000000000000000000000..82986c35681ed5f933231c1157619bc476244f42 --- /dev/null +++ b/models/llflow/models/modules/Split.py @@ -0,0 +1,88 @@ + + + +import torch +from torch import nn as nn + +from models.modules import thops +from models.modules.FlowStep import FlowStep +from models.modules.flow import Conv2dZeros, GaussianDiag +# from utils.util import opt_get + + + +def opt_get(opt, keys, default=None): + if opt is None: + return default + ret = opt + for k in keys: + ret = ret.get(k, None) + if ret is None: + return default + return ret + + + + + +class Split2d(nn.Module): + def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): + super().__init__() + + self.num_channels_consume = int(round(num_channels * consume_ratio)) + self.num_channels_pass = num_channels - self.num_channels_consume + + self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels, + out_channels=self.num_channels_consume * 2) + self.logs_eps = logs_eps + self.position = position + self.opt = opt + + def split2d_prior(self, z, ft): + if ft is not None: + z = torch.cat([z, ft], dim=1) + h = self.conv(z) + return thops.split_feature(h, "cross") + + def exp_eps(self, logs): + return torch.exp(logs) + self.logs_eps + + def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None): + if not reverse: + # self.input = input + z1, z2 = self.split_ratio(input) + mean, logs = self.split2d_prior(z1, ft) + + eps = (z2 - mean) / self.exp_eps(logs) + + logdet = logdet + self.get_logdet(logs, mean, z2) + + # print(logs.shape, mean.shape, z2.shape) + # self.eps = eps + # print('split, enc eps:', eps) + return z1, logdet, eps + else: + z1 = input + mean, logs = self.split2d_prior(z1, ft) + + if eps is None: + #print("WARNING: eps is None, generating eps untested functionality!") + eps = GaussianDiag.sample_eps(mean.shape, eps_std) + + eps = eps.to(mean.device) + z2 = mean + self.exp_eps(logs) * eps + + z = thops.cat_feature(z1, z2) + logdet = logdet - self.get_logdet(logs, mean, z2) + + return z, logdet + # return z, logdet, eps + + def get_logdet(self, logs, mean, z2): + logdet_diff = GaussianDiag.logp(mean, logs, z2) + # print("Split2D: logdet diff", logdet_diff.item()) + return logdet_diff + + def split_ratio(self, input): + z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] 
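+        # The first `num_channels_pass` channels continue through the flow;
+        # the remaining `num_channels_consume` channels are factored out and
+        # modeled by the conditional Gaussian prior from split2d_prior
+        # (multi-scale split as in RealNVP/Glow).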
+ return z1, z2 \ No newline at end of file diff --git a/models/llflow/models/modules/__init__.py b/models/llflow/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/llflow/models/modules/__pycache__/ConditionEncoder.cpython-310.pyc b/models/llflow/models/modules/__pycache__/ConditionEncoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b4c256cf296d531377fa440bf7c3ba13ebbd2ac Binary files /dev/null and b/models/llflow/models/modules/__pycache__/ConditionEncoder.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/FlowActNorms.cpython-310.pyc b/models/llflow/models/modules/__pycache__/FlowActNorms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9bfc74ab52fad1b0aeb87ea606b9b6550c4ed5a Binary files /dev/null and b/models/llflow/models/modules/__pycache__/FlowActNorms.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/FlowAffineCouplingsAblation.cpython-310.pyc b/models/llflow/models/modules/__pycache__/FlowAffineCouplingsAblation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7ece24a059a4aac1e762478b3e6d9e78d332297 Binary files /dev/null and b/models/llflow/models/modules/__pycache__/FlowAffineCouplingsAblation.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/FlowStep.cpython-310.pyc b/models/llflow/models/modules/__pycache__/FlowStep.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6964f925d0bbd213924d988160693bc9e0de8698 Binary files /dev/null and b/models/llflow/models/modules/__pycache__/FlowStep.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/FlowUpsamplerNet.cpython-310.pyc b/models/llflow/models/modules/__pycache__/FlowUpsamplerNet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b27bb17edcc368bfa38d5e3759c048fe64a38d5 Binary files /dev/null and b/models/llflow/models/modules/__pycache__/FlowUpsamplerNet.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/LLFlow_arch.cpython-310.pyc b/models/llflow/models/modules/__pycache__/LLFlow_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c7f0ab2f52fd16d5dd4bd7ecfdfe63890d8274 Binary files /dev/null and b/models/llflow/models/modules/__pycache__/LLFlow_arch.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/Permutations.cpython-310.pyc b/models/llflow/models/modules/__pycache__/Permutations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df901ea81f993c00a8ef1811af4d2822f2f760fe Binary files /dev/null and b/models/llflow/models/modules/__pycache__/Permutations.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/RRDBNet_arch.cpython-310.pyc b/models/llflow/models/modules/__pycache__/RRDBNet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ec3f62baf2c49ba46d720bc171ff3603448327 Binary files /dev/null and b/models/llflow/models/modules/__pycache__/RRDBNet_arch.cpython-310.pyc differ diff --git a/models/llflow/models/modules/__pycache__/Split.cpython-310.pyc b/models/llflow/models/modules/__pycache__/Split.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d45c62c9442d22b8ae2cb8d216b3ba85e39e1eb5 Binary files /dev/null and 
diff --git a/models/llflow/models/modules/base_layers.py b/models/llflow/models/modules/base_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8ad7bd6ef8e5e620245f4a57d0e84e10bcd6e97
--- /dev/null
+++ b/models/llflow/models/modules/base_layers.py
@@ -0,0 +1,189 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MSIA(nn.Module):
+    def __init__(self, filters, activation='lrelu'):
+        super().__init__()
+        # Down 1
+        self.conv_bn_relu_1 = Conv_BN_Relu(filters, activation)
+        # Down 2
+        self.down_2 = MaxPooling2D(2, 2)
+        self.conv_bn_relu_2 = Conv_BN_Relu(filters, activation)
+        self.deconv_2 = ConvTranspose2D(filters, filters)
+        # Down 4
+        self.down_4 = MaxPooling2D(2, 2)
+        self.conv_bn_relu_4 = Conv_BN_Relu(filters, activation, kernel=1)
+        self.deconv_4_1 = ConvTranspose2D(filters, filters)
+        self.deconv_4_2 = ConvTranspose2D(filters, filters)
+        # output
+        self.out = Conv2D(filters * 4, filters)
+
+    def forward(self, R, I_att):
+        R_att = R * I_att
+        # Down 1
+        msia_1 = self.conv_bn_relu_1(R_att)
+        # Down 2
+        down_2 = self.down_2(R_att)
+        conv_bn_relu_2 = self.conv_bn_relu_2(down_2)
+        msia_2 = self.deconv_2(conv_bn_relu_2)
+        # Down 4
+        down_4 = self.down_4(down_2)
+        conv_bn_relu_4 = self.conv_bn_relu_4(down_4)
+        deconv_4 = self.deconv_4_1(conv_bn_relu_4)
+        msia_4 = self.deconv_4_2(deconv_4)
+        # concat
+        concat = torch.cat([R, msia_1, msia_2, msia_4], dim=1)
+        out = self.out(concat)
+        return out
+
+
+class Conv_BN_Relu(nn.Module):
+    def __init__(self, channels, activation='lrelu', kernel=3):
+        super().__init__()
+        self.ActivationLayer = nn.LeakyReLU(inplace=True)
+        if activation == 'relu':
+            self.ActivationLayer = nn.ReLU(inplace=True)
+        self.conv_bn_relu = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=kernel, padding=kernel // 2),
+            # the original paper used the tf.layers defaults; beware that
+            # PyTorch's momentum is 1 - TF's, so TF's 0.99 corresponds to 0.01 here
+            nn.BatchNorm2d(channels, momentum=0.99),
+            self.ActivationLayer,
+        )
+
+    def forward(self, x):
+        return self.conv_bn_relu(x)
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, activation='lrelu'):
+        super().__init__()
+        self.doubleconv = nn.Sequential(
+            Conv2D(in_channels, out_channels, activation),
+            Conv2D(out_channels, out_channels, activation)
+        )
+
+    def forward(self, x):
+        return self.doubleconv(x)
+
+
+class ResConv(nn.Module):
+    def __init__(self, in_channels, out_channels, activation='lrelu'):
+        super().__init__()
+        self.relu = nn.LeakyReLU(0.2, inplace=True)
+        if activation == 'relu':
+            self.relu = nn.ReLU(inplace=True)
+
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.8)
+        self.cbam = CBAM(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.8)
+
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        bn1 = self.bn1(conv1)
+        x1 = self.relu(bn1)
+        cbam = self.cbam(x1)
+        conv2 = self.conv2(cbam)
+        bn2 = self.bn2(conv2)
+        out = bn2 + x  # residual add: requires in_channels == out_channels
+        return out
+
+
+class Conv2D(nn.Module):
+    def __init__(self, in_channels, out_channels, activation='lrelu', stride=1):
+        super().__init__()
+        self.ActivationLayer = nn.LeakyReLU(inplace=True)
+        if activation == 'relu':
+            self.ActivationLayer = nn.ReLU(inplace=True)
+        self.conv_relu = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
+            self.ActivationLayer,
+        )
+
+    def forward(self, x):
+        return self.conv_relu(x)
+
+
+class ConvTranspose2D(nn.Module):
+    def __init__(self, in_channels, out_channels, activation='lrelu'):
+        super().__init__()
+        self.ActivationLayer = nn.LeakyReLU(inplace=True)
+        if activation == 'relu':
+            self.ActivationLayer = nn.ReLU(inplace=True)
+        self.deconv_relu = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+            self.ActivationLayer,
+        )
+
+    def forward(self, x):
+        return self.deconv_relu(x)
+
+
+class MaxPooling2D(nn.Module):
+    def __init__(self, kernel_size=2, stride=2):
+        super().__init__()
+        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
+
+    def forward(self, x):
+        return self.maxpool(x)
+
+
+class AvgPooling2D(nn.Module):
+    def __init__(self, kernel_size=2, stride=2):
+        super().__init__()
+        self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
+
+    def forward(self, x):
+        return self.avgpool(x)
+
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=16):
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        self.sharedMLP = nn.Sequential(
+            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
+            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avgout = self.sharedMLP(self.avg_pool(x))
+        maxout = self.sharedMLP(self.max_pool(x))
+        return self.sigmoid(avgout + maxout)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=3):
+        super().__init__()
+        # padding tracks kernel_size so the spatial size is preserved
+        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avgout = torch.mean(x, dim=1, keepdim=True)
+        maxout, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avgout, maxout], dim=1)
+        x = self.conv(x)
+        return self.sigmoid(x)
+
+
+class CBAM(nn.Module):
+    def __init__(self, planes):
+        super().__init__()
+        self.ca = ChannelAttention(planes)
+        self.sa = SpatialAttention()
+
+    def forward(self, x):
+        # channel attention first, then spatial attention on the re-weighted map
+        x = self.ca(x) * x
+        out = self.sa(x) * x
+        return out
+
+
+class Concat(nn.Module):
+    def forward(self, x, y):
+        _, _, xh, xw = x.size()
+        _, _, yh, yw = y.size()
+        diffY = xh - yh
+        diffX = xw - yw
+        y = F.pad(y, (diffX // 2, diffX - diffX // 2,
+                      diffY // 2, diffY - diffY // 2))
+        return torch.cat((x, y), dim=1)
\ No newline at end of file
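CBAM above chains channel attention (a shared MLP over average- and max-pooled descriptors) with spatial attention (a small conv over channel-wise average/max maps); it only re-weights features, never resizes them. A short shape-check sketch (assumes the repo's packages are importable, e.g. after an editable install):

# Quick shape check for the CBAM block defined above (illustrative only).
import torch
from models.llflow.models.modules.base_layers import CBAM

x = torch.randn(2, 32, 16, 16)   # (B, C, H, W)
out = CBAM(32)(x)                # channel attention, then spatial attention
assert out.shape == x.shape      # attention re-weights, never resizes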
diff --git a/models/llflow/models/modules/color_encoder.py b/models/llflow/models/modules/color_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9007526edea11f3e0e6fea023b0fd9f4207c52c
--- /dev/null
+++ b/models/llflow/models/modules/color_encoder.py
@@ -0,0 +1,121 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import models.modules.module_util as mutil
+# from utils.util import opt_get
+from models.modules.base_layers import *
+
+
+def opt_get(opt, keys, default=None):
+    if opt is None:
+        return default
+    ret = opt
+    for k in keys:
+        ret = ret.get(k, None)
+        if ret is None:
+            return default
+    return ret
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    def __init__(self, nf=64, gc=32, bias=True):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    '''Residual in Residual Dense Block'''
+
+    def __init__(self, nf, gc=32):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+    def forward(self, x):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class ColorEncoder(nn.Module):
+    def __init__(self, nf, opt=None):
+        self.opt = opt
+        super(ColorEncoder, self).__init__()
+        self.conv_input = Conv2D(3, nf)
+        # top path: build Reflectance map
+        self.maxpool_r1 = MaxPooling2D()
+        self.conv_r1 = Conv2D(nf, nf * 2)
+        self.maxpool_r2 = MaxPooling2D()
+        self.conv_r2 = Conv2D(nf * 2, nf * 4)
+        self.deconv_r1 = ConvTranspose2D(nf * 4, nf * 2)
+        self.concat_r1 = Concat()
+        self.conv_r3 = Conv2D(nf * 4, nf * 2)
+        self.deconv_r2 = ConvTranspose2D(nf * 2, nf)
+        self.concat_r2 = Concat()
+        self.conv_r4 = Conv2D(nf * 2, nf)
+        self.conv_r5 = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
+        self.R_out = nn.Sigmoid()
+        # bottom path (Illumination map), kept for reference but disabled:
+        # self.conv_i1 = Conv2D(nf, nf)
+        # self.concat_i1 = Concat()
+        # self.conv_i2 = nn.Conv2d(nf * 2, 1, kernel_size=3, padding=1)
+        # self.I_out = nn.Sigmoid()
+
+    def forward(self, x, get_steps=False):
+        assert not get_steps
+
+        conv_input = self.conv_input(x)
+        # build the Reflectance map with a small U-Net
+        maxpool_r1 = self.maxpool_r1(conv_input)
+        conv_r1 = self.conv_r1(maxpool_r1)
+        maxpool_r2 = self.maxpool_r2(conv_r1)
+        conv_r2 = self.conv_r2(maxpool_r2)
+        deconv_r1 = self.deconv_r1(conv_r2)
+        concat_r1 = self.concat_r1(conv_r1, deconv_r1)
+        conv_r3 = self.conv_r3(concat_r1)
+        deconv_r2 = self.deconv_r2(conv_r3)
+        concat_r2 = self.concat_r2(conv_input, deconv_r2)
+        conv_r4 = self.conv_r4(concat_r2)
+        conv_r5 = self.conv_r5(conv_r4)
+        R_out = self.R_out(conv_r5)
+        # local average of the reflectance map serves as the color prior
+        color_x = nn.functional.avg_pool2d(R_out, self.opt['avg_kernel_size'], 1, self.opt['avg_kernel_size'] // 2)
+        # color_x = color_x / torch.sum(color_x, 1, keepdim=True)
+
+        return color_x
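ColorEncoder distills the input into a smoothed color prior: the U-Net predicts a reflectance map, and a stride-1 average pool turns it into per-pixel neighbourhood color statistics. A minimal sketch of just that pooling step (the kernel size here is illustrative; the real value comes from opt['avg_kernel_size']):

# Illustrates the color-prior pooling at the end of ColorEncoder.forward:
# a stride-1 average pool with odd kernel keeps the spatial size while
# replacing each pixel with its neighbourhood's mean colour.
import torch
import torch.nn.functional as F

R_out = torch.rand(1, 3, 32, 32)          # stand-in reflectance map in [0, 1]
k = 5                                     # hypothetical avg_kernel_size
color_x = F.avg_pool2d(R_out, k, 1, k // 2)
print(color_x.shape)                      # torch.Size([1, 3, 32, 32])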
diff --git a/models/llflow/models/modules/flow.py b/models/llflow/models/modules/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..701facb962ce9deeac75a1be7635d7a89f5d1cde
--- /dev/null
+++ b/models/llflow/models/modules/flow.py
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from models.modules.FlowActNorms import ActNorm2d
+from . import thops
+
+
+class Conv2d(nn.Conv2d):
+    pad_dict = {
+        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
+        "valid": lambda kernel, stride: [0 for _ in kernel]
+    }
+
+    @staticmethod
+    def get_padding(padding, kernel_size, stride):
+        # make padding
+        if isinstance(padding, str):
+            if isinstance(kernel_size, int):
+                kernel_size = [kernel_size, kernel_size]
+            if isinstance(stride, int):
+                stride = [stride, stride]
+            padding = padding.lower()
+            try:
+                padding = Conv2d.pad_dict[padding](kernel_size, stride)
+            except KeyError:
+                raise ValueError("{} is not supported".format(padding))
+        return padding
+
+    def __init__(self, in_channels, out_channels,
+                 kernel_size=[3, 3], stride=[1, 1],
+                 padding="same", do_actnorm=True, weight_std=0.05):
+        padding = Conv2d.get_padding(padding, kernel_size, stride)
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, bias=(not do_actnorm))
+        # init weight with std
+        self.weight.data.normal_(mean=0.0, std=weight_std)
+        if not do_actnorm:
+            self.bias.data.zero_()
+        else:
+            self.actnorm = ActNorm2d(out_channels)
+        self.do_actnorm = do_actnorm
+
+    def forward(self, input):
+        x = super().forward(input)
+        if self.do_actnorm:
+            x, _ = self.actnorm(x)
+        return x
+
+
+class Conv2dZeros(nn.Conv2d):
+    def __init__(self, in_channels, out_channels,
+                 kernel_size=[3, 3], stride=[1, 1],
+                 padding="same", logscale_factor=3):
+        padding = Conv2d.get_padding(padding, kernel_size, stride)
+        super().__init__(in_channels, out_channels, kernel_size, stride, padding)
+        # logscale_factor
+        self.logscale_factor = logscale_factor
+        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
+        # zero init: the layer starts as the identity contribution (Glow trick)
+        self.weight.data.zero_()
+        self.bias.data.zero_()
+
+    def forward(self, input):
+        output = super().forward(input)
+        return output * torch.exp(self.logs * self.logscale_factor)
+
+
+class GaussianDiag:
+    Log2PI = float(np.log(2 * np.pi))
+
+    @staticmethod
+    def likelihood(mean, logs, x):
+        """
+        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + k ln(2*PI) }
+        k = 1 (independent dimensions)
+        Var = exp(2 * logs)
+        """
+        if mean is None and logs is None:
+            return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
+        else:
+            return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.)
+ GaussianDiag.Log2PI) + + @staticmethod + def logp(mean, logs, x): + likelihood = 0 + if isinstance(x, (list, tuple)): + for x_ in x: likelihood += thops.sum(GaussianDiag.likelihood(mean, logs, x_), dim=[1, 2, 3]) + else: + likelihood = thops.sum(GaussianDiag.likelihood(mean, logs, x), dim=[1, 2, 3]) + return likelihood + # likelihood = GaussianDiag.likelihood(mean, logs, x) + # return thops.sum(likelihood, dim=[1, 2, 3]) + + @staticmethod + def sample(mean, logs, eps_std=None): + eps_std = eps_std or 1 + eps = torch.normal(mean=torch.zeros_like(mean), + std=torch.ones_like(logs) * eps_std) + return mean + torch.exp(logs) * eps + + @staticmethod + def sample_eps(shape, eps_std, seed=None): + if seed is not None: + torch.manual_seed(seed) + eps = torch.normal(mean=torch.zeros(shape), + std=torch.ones(shape) * eps_std) + return eps + + +def squeeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + if factor == 1: + return input + size = input.size() + B = size[0] + C = size[1] + H = size[2] + W = size[3] + assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor)) + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor * factor, H // factor, W // factor) + return x + + +def unsqueeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + factor2 = factor ** 2 + if factor == 1: + return input + size = input.size() + B = size[0] + C = size[1] + H = size[2] + W = size[3] + assert C % (factor2) == 0, "{}".format(C) + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + return x + + +class SqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def forward(self, input, logdet=None, reverse=False): + if not reverse: + output = squeeze2d(input, self.factor) # Squeeze in forward + return output, logdet + else: + output = unsqueeze2d(input, self.factor) + return output, logdet diff --git a/models/llflow/models/modules/glow_arch.py b/models/llflow/models/modules/glow_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1f563e85cbb9f87a70e0a361fa143e49b4fc6338 --- /dev/null +++ b/models/llflow/models/modules/glow_arch.py @@ -0,0 +1,15 @@ + + + +import torch.nn as nn + + +def f_conv2d_bias(in_channels, out_channels): + def padding_same(kernel, stride): + return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)] + + padding = padding_same([3, 3], [1, 1]) + assert padding == [1, 1], padding + return nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1, + bias=True)) diff --git a/models/llflow/models/modules/loss.py b/models/llflow/models/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf7b6de52620d363e4c35f5b3646e0daac94866 --- /dev/null +++ b/models/llflow/models/modules/loss.py @@ -0,0 +1,77 @@ + + + +import torch +import torch.nn as nn + + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + return loss + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + 
self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'gan' or self.gan_type == 'ragan': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, + grad_outputs=grad_outputs, create_graph=True, + retain_graph=True, only_inputs=True)[0] + grad_interp = grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1)**2).mean() + return loss diff --git a/models/llflow/models/modules/module_util.py b/models/llflow/models/modules/module_util.py new file mode 100644 index 0000000000000000000000000000000000000000..471cfa68254915a582f6c0efe4b158aa0d11b9d5 --- /dev/null +++ b/models/llflow/models/modules/module_util.py @@ -0,0 +1,82 @@ + + + +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualBlock_noBN(nn.Module): + '''Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + ''' + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=True) + out = self.conv2(out) + return identity + out + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): + """Warp an image or feature map with optical flow + Args: + x 
(Tensor): size (N, C, H, W) + flow (Tensor): size (N, H, W, 2), normal value + interp_mode (str): 'nearest' or 'bilinear' + padding_mode (str): 'zeros' or 'border' or 'reflection' + + Returns: + Tensor: warped image or feature map + """ + assert x.size()[-2:] == flow.size()[1:3] + B, C, H, W = x.size() + # mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + grid = grid.type_as(x) + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + return output diff --git a/models/llflow/models/modules/thops.py b/models/llflow/models/modules/thops.py new file mode 100644 index 0000000000000000000000000000000000000000..43414698a5689209c6b6ec0b653224635b4e6f21 --- /dev/null +++ b/models/llflow/models/modules/thops.py @@ -0,0 +1,55 @@ + + + +import torch + + +def sum(tensor, dim=None, keepdim=False): + if dim is None: + # sum up all dim + return torch.sum(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.sum(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d-i) + return tensor + + +def mean(tensor, dim=None, keepdim=False): + if dim is None: + # mean all dim + return torch.mean(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.mean(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d-i) + return tensor + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] + elif type == "cross": + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + + +def cat_feature(tensor_a, tensor_b): + return torch.cat((tensor_a, tensor_b), dim=1) + + +def pixels(tensor): + return int(tensor.size(2) * tensor.size(3)) \ No newline at end of file diff --git a/models/llflow/models/networks.py b/models/llflow/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..c6798f4e035124512fc207760c370f2f1f0e2b1e --- /dev/null +++ b/models/llflow/models/networks.py @@ -0,0 +1,37 @@ +import importlib + +import torch +import logging +import models.modules.RRDBNet_arch as RRDBNet_arch + +logger = logging.getLogger('base') + + +def find_model_using_name(model_name): + model_filename = "models.modules." + model_name + "_arch" + modellib = importlib.import_module(model_filename) + + model = None + target_model_name = model_name.replace('_Net', '') + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print( + "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." 
% ( + model_filename, target_model_name)) + exit(0) + + return model + +def define_Flow(opt, step): + opt_net = opt['network_G'] + which_model = opt_net['which_model_G'] + + Arch = find_model_using_name(which_model) + netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step) + + return netG + diff --git a/models/llflow/option_.py b/models/llflow/option_.py new file mode 100644 index 0000000000000000000000000000000000000000..b09794292f17747de3882a7bd5bb4d200154b99e --- /dev/null +++ b/models/llflow/option_.py @@ -0,0 +1,141 @@ +import os +import os.path as osp +import logging +import yaml +from models.llflow.util import OrderedYaml + +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, is_train=True): + with open(opt_path, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', [])) + # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + # print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + opt['is_train'] = is_train + if opt['distortion'] == 'sr': + scale = opt['scale'] + + # datasets + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + dataset['phase'] = phase + if opt['distortion'] == 'sr': + dataset['scale'] = scale + is_lmdb = False + if dataset.get('dataroot_GT', None) is not None: + dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) + if dataset['dataroot_GT'].endswith('lmdb'): + is_lmdb = True + if dataset.get('dataroot_LQ', None) is not None: + dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) + if dataset['dataroot_LQ'].endswith('lmdb'): + is_lmdb = True + dataset['data_type'] = 'lmdb' if is_lmdb else 'img' + # if dataset['mode'].endswith('mc'): # for memcached + # dataset['data_type'] = 'mc' + # dataset['mode'] = dataset['mode'].replace('_mc', '') + + # path + for key, path in opt['path'].items(): + if path and key in opt['path'] and key != 'strict_load': + opt['path'][key] = osp.expanduser(path) + opt['path']['root'] = osp.abspath( + osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join( + opt['path']['root'], 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_state'] = osp.join( + experiments_root, 'training_state') + opt['path']['log'] = experiments_root + opt['path']['val_images'] = osp.join(experiments_root, 'val_images') + + # change some options for debug mode + if 'debug' in opt['name']: + opt['train']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + if not opt['path'].get('results_root', None): + results_root = osp.join( + opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = opt['path']['results_root'] + + # network + if opt['distortion'] == 'sr': + opt['network_G']['scale'] = scale + + # relative learning rate + if 'train' in opt: + niter = opt['train']['niter'] + if 'T_period_rel' in opt['train']: + opt['train']['T_period'] = [int(x * niter) + for x in opt['train']['T_period_rel']] + if 'restarts_rel' in opt['train']: + opt['train']['restarts'] = [int(x * niter) + for x in opt['train']['restarts_rel']] + if 'lr_steps_rel' in opt['train']: + opt['train']['lr_steps'] = [int(x * niter) + for x in opt['train']['lr_steps_rel']] + if 
'lr_steps_inverse_rel' in opt['train']: + opt['train']['lr_steps_inverse'] = [ + int(x * niter) for x in opt['train']['lr_steps_inverse_rel']] + print(opt['train']) + + return opt + + +def dict2str(opt, indent_l=1): + '''dict to string for logger''' + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + '''Check resume states and pretrain_model paths''' + logger = logging.getLogger('base') + if opt['path']['resume_state']: + if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( + 'pretrain_model_D', None) is not None: + logger.warning( + 'pretrain_model path will be ignored when resuming training.') + + opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], + '{}_G.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_G] to ' + + opt['path']['pretrain_model_G']) + if 'gan' in opt['model']: + opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], + '{}_D.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_D] to ' + + opt['path']['pretrain_model_D']) diff --git a/models/llflow/util.py b/models/llflow/util.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6ea372e7f5ed83ad82afb246542327f7c78709 --- /dev/null +++ b/models/llflow/util.py @@ -0,0 +1,262 @@ +import yaml +import glob +import os +import sys +import time +import math +from datetime import datetime +import random +import logging +from collections import OrderedDict + +import natsort +import numpy as np +import cv2 +import torch +from torchvision.utils import make_grid +from shutil import get_terminal_size +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / + float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm( + _1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand( + channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, + padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, + padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, + padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * 
(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1.mean(dim=0, keepdims=True), img2.mean(dim=0, keepdims=True), window, window_size, channel, size_average) + + +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + + +def OrderedYaml(): + '''yaml orderedDict support''' + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +#################### +# miscellaneous +#################### + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + logger = logging.getLogger('base') + logger.info( + 'Path already exists. 
Rename it to [{:s}]'.format(new_name))
+        os.rename(path, new_name)
+    os.makedirs(path)
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
+    '''set up logger'''
+    lg = logging.getLogger(logger_name)
+    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
+                                  datefmt='%y-%m-%d %H:%M:%S')
+    lg.setLevel(level)
+    if tofile:
+        log_file = os.path.join(
+            root, phase + '_{}.log'.format(get_timestamp()))
+        fh = logging.FileHandler(log_file, mode='w')
+        fh.setFormatter(formatter)
+        lg.addHandler(fh)
+    if screen:
+        sh = logging.StreamHandler()
+        sh.setFormatter(formatter)
+        lg.addHandler(sh)
+
+
+####################
+# image convert
+####################
+
+
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+    '''
+    Converts a torch Tensor into an image Numpy array
+    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+    '''
+    if hasattr(tensor, 'detach'):
+        tensor = tensor.detach()
+    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
+    tensor = (tensor - min_max[0]) / \
+        (min_max[1] - min_max[0])  # to range [0,1]
+    n_dim = tensor.dim()
+    if n_dim == 4:
+        n_img = len(tensor)
+        img_np = make_grid(tensor, nrow=int(
+            math.sqrt(n_img)), normalize=False).numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 3:
+        img_np = tensor.numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 2:
+        img_np = tensor.numpy()
+    else:
+        raise TypeError(
+            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+    if out_type == np.uint8:
+        img_np = np.clip((img_np * 255.0).round(), 0, 255)
+        # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+    return img_np.astype(out_type)
+
+
+def save_img(img, img_path, mode='RGB'):
+    # note: cv2.imwrite expects BGR data; the `mode` argument is currently unused
+    cv2.imwrite(img_path, img)
+
+
+####################
+# metric
+####################
+
+
+def calculate_psnr(img1, img2):
+    # img1 and img2 have range [0, 255]
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+def get_resume_paths(opt):
+    resume_state_path = None
+    resume_model_path = None
+    ts = opt_get(opt, ['path', 'training_state'])
+    if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
+        wildcard = os.path.join(ts, "*")
+        paths = natsort.natsorted(glob.glob(wildcard))
+        if len(paths) > 0:
+            resume_state_path = paths[-1]
+            resume_model_path = resume_state_path.replace(
+                'training_state', 'models').replace('.state', '_G.pth')
+    else:
+        resume_state_path = opt.get('path', {}).get('resume_state')
+    return resume_state_path, resume_model_path
+
+
+def opt_get(opt, keys, default=None):
+    if opt is None:
+        return default
+    ret = opt
+    for k in keys:
+        ret = ret.get(k, None)
+        if ret is None:
+            return default
+    return ret
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a447971943fc5982096362a167b461a8f264a5f8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup, find_packages
+
+setup(name='image_enhancement', version='1.0', packages=find_packages())
\ No newline at end of file
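setup.py packages the repo so the absolute imports used throughout (models.llflow..., models.demoire...) resolve after an editable install (pip install -e . from the repo root). A quick smoke test of the packaged utilities (assumes the dependencies pulled in by models/llflow/util.py, such as torch, cv2 and natsort, are installed):

# Run after `pip install -e .`; the package names come from top_level.txt.
import numpy as np
import models.llflow.util as util

a = np.zeros((8, 8), dtype=np.uint8)
b = np.full((8, 8), 10, dtype=np.uint8)
print(util.calculate_psnr(a, a))  # inf: identical images
print(util.calculate_psnr(a, b))  # ~28.1 dB for a constant offset of 10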