diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c1113c05878b29d468bb240c6b1b64a022fda8bf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/arch_2.pdf filter=lfs diff=lfs merge=lfs -text +assets/arch.png filter=lfs diff=lfs merge=lfs -text +assets/comparison_3.pdf filter=lfs diff=lfs merge=lfs -text +assets/new_ablation.pdf filter=lfs diff=lfs merge=lfs -text +assets/show_3.png filter=lfs diff=lfs merge=lfs -text +imgs/horse2zebra.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3f54bce11e6c18652b298f1112e71030541e3531 --- /dev/null +++ b/.gitignore @@ -0,0 +1,120 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +checkpoints/ +.DS_Store +._.DS_Store +.vscode +predict/ +results/ +model/ +.pth +.png +.jpg +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json diff --git a/License b/License new file mode 100644 index 0000000000000000000000000000000000000000..2920867529a5e5ba455d2954101823b94222d771 --- /dev/null +++ b/License @@ -0,0 +1,58 @@ +Copyright (c) 2019, Yifan Jiang and Zhangyang Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR EnlightenGAN -------------------------------- +BSD License + +For EnlightenGAN software +Copyright (c) 2019, Yifan Jiang and Zhangyang Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 3c414fcd8fb2aa2ffce16bee38ac69e30c371d34..e73a748b6811dc87da7193ee940a6800ea3ac6a0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,67 @@ ---- -license: bsd ---- +# EnlightenGAN: Deep Light Enhancement without Paired Supervision +[Yifan Jiang](https://yifanjiang19.github.io/), Xinyu Gong, Ding Liu, Yu Cheng, Chen Fang, Xiaohui Shen, Jianchao Yang, Pan Zhou, Zhangyang Wang + +[[Paper]](https://arxiv.org/abs/1906.06972) [[Supplementary Materials]](https://yifanjiang.net/files/EnlightenGAN_Supplementary.pdf) + + +### Representitive Results +![representive_results](/assets/show_3.png) + +### Overal Architecture +![architecture](/assets/arch.png) + +## Environment Preparing +``` +python3.5 +``` +You should prepare at least 3 1080ti gpus or change the batch size. + + +```pip install -r requirement.txt```
+```mkdir model```
+Download VGG pretrained model from [[Google Drive 1]](https://drive.google.com/file/d/1IfCeihmPqGWJ0KHmH-mTMi_pn3z3Zo-P/view?usp=sharing), and then put it into the directory `model`. + +### Training process +Before starting training process, you should launch the `visdom.server` for visualizing. + +```nohup python -m visdom.server -port=8097``` + +then run the following command + +```python scripts/script.py --train``` + +### Testing process + +Download [pretrained model](https://drive.google.com/file/d/1AkV-n2MdyfuZTFvcon8Z4leyVb0i7x63/view?usp=sharing) and put it into `./checkpoints/enlightening` + +Create directories `../test_dataset/testA` and `../test_dataset/testB`. Put your test images on `../test_dataset/testA` (And you should keep whatever one image in `../test_dataset/testB` to make sure program can start.) + +Run + +```python scripts/script.py --predict ``` + +### Dataset preparing + +Training data [[Google Drive]](https://drive.google.com/drive/folders/1fwqz8-RnTfxgIIkebFG2Ej3jQFsYECh0?usp=sharing) (unpaired images collected from multiple datasets) + +Testing data [[Google Drive]](https://drive.google.com/open?id=1PrvL8jShZ7zj2IC3fVdDxBY1oJR72iDf) (including LIME, MEF, NPE, VV, DICP) + +And [[BaiduYun]](https://github.com/TAMU-VITA/EnlightenGAN/issues/28) is available now thanks to @YHLelaine! + +### Faster Inference +https://github.com/arsenyinfo/EnlightenGAN-inference from @arsenyinfo + + + +If you find this work useful for you, please cite +``` +@article{jiang2021enlightengan, + title={Enlightengan: Deep light enhancement without paired supervision}, + author={Jiang, Yifan and Gong, Xinyu and Liu, Ding and Cheng, Yu and Fang, Chen and Shen, Xiaohui and Yang, Jianchao and Zhou, Pan and Wang, Zhangyang}, + journal={IEEE Transactions on Image Processing}, + volume={30}, + pages={2340--2349}, + year={2021}, + publisher={IEEE} +} +``` diff --git a/assets/.DS_Store b/assets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/assets/.DS_Store differ diff --git a/assets/arch.png b/assets/arch.png new file mode 100644 index 0000000000000000000000000000000000000000..7f800d31fc3927ee170a4571d0f447a2f4a1d2e8 --- /dev/null +++ b/assets/arch.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab49b01eedf35c7325f9cfd98825cfbc96f209b769f4ed369d6835429906132b +size 1178967 diff --git a/assets/arch_2.pdf b/assets/arch_2.pdf new file mode 100644 index 0000000000000000000000000000000000000000..f9c8bc950197d8f2519ed87e39195099d43aa02f --- /dev/null +++ b/assets/arch_2.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df02a7f2b894d6230a1f120aa7c112962abe39232c648a441879d9dc8cc71756 +size 1738396 diff --git a/assets/comparison_3.pdf b/assets/comparison_3.pdf new file mode 100644 index 0000000000000000000000000000000000000000..bb62ffc1cd9db000f46e42387ee55703cb24484c --- /dev/null +++ b/assets/comparison_3.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75a9820f1a978d9f0b6230dbd163efad5f8ca4100afe06bbed90cbe780a341d5 +size 1753489 diff --git a/assets/new_ablation.pdf b/assets/new_ablation.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4ae1f990c065d35191f55abdb7407c0f7531cf10 --- /dev/null +++ b/assets/new_ablation.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38b13a6ad0682986d9535ad2d908a5124f65d321240d0cca881f84f6ef033892 +size 1407150 diff --git a/assets/show_3.png b/assets/show_3.png new file mode 100644 index 0000000000000000000000000000000000000000..b93ab45cddabac4c3c969b74d04f6026c124d21f --- /dev/null +++ b/assets/show_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e87e8ff557f17604ad8f51bcadc52bea21e57f72b880a45fe3ea3e0c42703c76 +size 3738934 diff --git a/assets/table.pdf b/assets/table.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c403bb27188bc8ed86d6393c1cae722f7e7693d1 Binary files /dev/null and b/assets/table.pdf differ diff --git a/configs/unit_gta2city_folder.yaml b/configs/unit_gta2city_folder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63c5b642c17ff63c1b0c3a3b06c5ec4e6db29697 --- /dev/null +++ b/configs/unit_gta2city_folder.yaml @@ -0,0 +1,54 @@ +# Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). + +# logger options +image_save_iter: 1000 # How often do you want to save output images during training +image_display_iter: 10 # How often do you want to display output images during training +display_size: 8 # How many images do you want to display each time +snapshot_save_iter: 10000 # How often do you want to save trained models +log_iter: 1 # How often do you want to log the training stats + +# optimization options +max_iter: 1000000 # maximum number of training iterations +batch_size: 1 # batch size +weight_decay: 0.0001 # weight decay +beta1: 0.5 # Adam parameter +beta2: 0.999 # Adam parameter +init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] +lr: 0.0001 # initial learning rate +lr_policy: step # learning rate scheduler +step_size: 100000 # how often to decay learning rate +gamma: 0.5 # how much to decay learning rate +gan_w: 1 # weight of adversarial loss +recon_x_w: 10 # weight of image reconstruction loss +recon_h_w: 0 # weight of hidden reconstruction loss +recon_kl_w: 0.01 # weight of KL loss for reconstruction +recon_x_cyc_w: 10 # weight of cycle consistency loss +recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency +vgg_w: 0 # weight of domain-invariant perceptual loss + +# model options +gen: + dim: 64 # number of filters in the bottommost layer + activ: relu # activation function [relu/lrelu/prelu/selu/tanh] + n_downsample: 2 # number of downsampling layers in content encoder + n_res: 4 # number of residual blocks in content encoder/decoder + pad_type: reflect # padding type [zero/reflect] +dis: + dim: 64 # number of filters in the bottommost layer + norm: none # normalization layer [none/bn/in/ln] + activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] + n_layer: 4 # number of layers in D + gan_type: lsgan # GAN loss [lsgan/nsgan] + num_scales: 3 # number of scales + pad_type: reflect # padding type [zero/reflect] + +# data options +input_dim_a: 3 # number of image channels [1/3] +input_dim_b: 3 # number of image channels [1/3] +num_workers: 8 # number of data loading threads +new_size: 256 # first resize the shortest image side to this size +crop_image_height: 256 # random crop image of this height +crop_image_width: 256 # random crop image of this width + +data_root: ./datasets/lol/ # dataset folder location \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0f45c4030c7a3284e8e8d3e94d6e257e7184757d --- /dev/null +++ b/data/aligned_dataset.py @@ -0,0 +1,56 @@ +import os.path +import random +import torchvision.transforms as transforms +import torch +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image + + +class AlignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_AB = os.path.join(opt.dataroot, opt.phase) + + self.AB_paths = sorted(make_dataset(self.dir_AB)) + + assert(opt.resize_or_crop == 'resize_and_crop') + + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + + self.transform = transforms.Compose(transform_list) + + def __getitem__(self, index): + AB_path = self.AB_paths[index] + AB = Image.open(AB_path).convert('RGB') + AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC) + AB = self.transform(AB) + + w_total = AB.size(2) + w = int(w_total / 2) + h = AB.size(1) + w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) + + A = AB[:, h_offset:h_offset + self.opt.fineSize, + w_offset:w_offset + self.opt.fineSize] + B = AB[:, h_offset:h_offset + self.opt.fineSize, + w + w_offset:w + w_offset + self.opt.fineSize] + + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(2, idx) + B = B.index_select(2, idx) + + return {'A': A, 'B': B, + 'A_paths': AB_path, 'B_paths': AB_path} + + def __len__(self): + return len(self.AB_paths) + + def name(self): + return 'AlignedDataset' diff --git a/data/base_data_loader.py b/data/base_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0e1deb55e5eef353f379ce63f1223e26121248c8 --- /dev/null +++ b/data/base_data_loader.py @@ -0,0 +1,14 @@ + +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(): + return None + + + diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7acac20a00f22ae13b987420e92e739bb87dc7 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,50 @@ +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import random + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + zoom = 1 + 0.1*radom.randint(0,4) + osize = [int(400*zoom), int(600*zoom)] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + # elif opt.resize_or_crop == 'no': + # osize = [384, 512] + # transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def __scale_width(img, target_width): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), Image.BICUBIC) diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..2bbfd4c4fa9220651052db637aec681f79f7ecb3 --- /dev/null +++ b/data/custom_dataset_data_loader.py @@ -0,0 +1,50 @@ +import torch.utils.data +from data.base_data_loader import BaseDataLoader + + +def CreateDataset(opt): + dataset = None + if opt.dataset_mode == 'aligned': + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() + elif opt.dataset_mode == 'unaligned': + from data.unaligned_dataset import UnalignedDataset + dataset = UnalignedDataset() + elif opt.dataset_mode == 'unaligned_random_crop': + from data.unaligned_random_crop import UnalignedDataset + dataset = UnalignedDataset() + elif opt.dataset_mode == 'pair': + from data.pair_dataset import PairDataset + dataset = PairDataset() + elif opt.dataset_mode == 'syn': + from data.syn_dataset import PairDataset + dataset = PairDataset() + elif opt.dataset_mode == 'single': + from data.single_dataset import SingleDataset + dataset = SingleDataset() + else: + raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) + + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4433a29eb1e0a2be75c477f292c9e784ce5b6b --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,7 @@ + +def CreateDataLoader(opt): + from data.custom_dataset_data_loader import CustomDatasetDataLoader + data_loader = CustomDatasetDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader diff --git a/data/image_folder.py b/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..88864fa46c2c8ad03b59aa39a8b92147c7414e5f --- /dev/null +++ b/data/image_folder.py @@ -0,0 +1,83 @@ +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + +def store_dataset(dir): + images = [] + all_path = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + img = Image.open(path).convert('RGB') + images.append(img) + all_path.append(path) + + return images, all_path + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/data/pair_dataset.py b/data/pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..39192e13df4a2620e8ccc01564741418321e45a2 --- /dev/null +++ b/data/pair_dataset.py @@ -0,0 +1,95 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import PIL +import random +import torch +from pdb import set_trace as st + + +class PairDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + + transform_list = [] + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + # transform_list = [transforms.ToTensor()] + + self.transform = transforms.Compose(transform_list) + # self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] + + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB') + + + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + w = A_img.size(2) + h = A_img.size(1) + w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) + + A_img = A_img[:, h_offset:h_offset + self.opt.fineSize, + w_offset:w_offset + self.opt.fineSize] + B_img = B_img[:, h_offset:h_offset + self.opt.fineSize, + w_offset:w_offset + self.opt.fineSize] + + + if self.opt.resize_or_crop == 'no': + r,g,b = A_img[0]+1, A_img[1]+1, A_img[2]+1 + A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + A_gray = torch.unsqueeze(A_gray, 0) + input_img = A_img + # A_gray = (1./A_gray)/255. + else: + + + # A_gray = (1./A_gray)/255. + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(2, idx) + B_img = B_img.index_select(2, idx) + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(1) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(1, idx) + B_img = B_img.index_select(1, idx) + if (not self.opt.no_flip) and random.random() < 0.5: + times = random.randint(self.opt.low_times,self.opt.high_times)/100. + input_img = (A_img+1)/2./times + input_img = input_img*2-1 + else: + input_img = A_img + r,g,b = input_img[0]+1, input_img[1]+1, input_img[2]+1 + A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + A_gray = torch.unsqueeze(A_gray, 0) + return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img':input_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return self.A_size + + def name(self): + return 'PairDataset' diff --git a/data/single_dataset.py b/data/single_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c2f1c929c480fdb3c3a4079fb93a70eb45e940 --- /dev/null +++ b/data/single_dataset.py @@ -0,0 +1,36 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class SingleDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot) + + self.A_paths = make_dataset(self.dir_A) + + self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index] + + A_img = Image.open(A_path).convert('RGB') + A_size = A_img.size + A_size = A_size = (A_size[0]//16*16, A_size[1]//16*16) + A_img = A_img.resize(A_size, Image.BICUBIC) + + A_img = self.transform(A_img) + + return {'A': A_img, 'A_paths': A_path} + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'SingleImageDataset' diff --git a/data/syn_dataset.py b/data/syn_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9e5c7040a4b42ea8cb495605fe2c38b32bf7f4 --- /dev/null +++ b/data/syn_dataset.py @@ -0,0 +1,91 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import PIL +import random +import torch +from pdb import set_trace as st + + +class PairDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + + transform_list = [] + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + # transform_list = [transforms.ToTensor()] + + self.transform = transforms.Compose(transform_list) + # self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] + + B_img = Image.open(B_path).convert('RGB') + # B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB') + + + # A_img = self.transform(A_img) + B_img = self.transform(B_img) + + w = B_img.size(2) + h = B_img.size(1) + w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) + + B_img = B_img[:, h_offset:h_offset + self.opt.fineSize, + w_offset:w_offset + self.opt.fineSize] + + + if self.opt.resize_or_crop == 'no': + pass + # r,g,b = A_img[0]+1, A_img[1]+1, A_img[2]+1 + # A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + # A_gray = torch.unsqueeze(A_gray, 0) + # input_img = A_img + # A_gray = (1./A_gray)/255. + else: + + + # A_gray = (1./A_gray)/255. + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(B_img.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + B_img = B_img.index_select(2, idx) + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(B_img.size(1) - 1, -1, -1)] + idx = torch.LongTensor(idx) + B_img = B_img.index_select(1, idx) + + times = random.randint(self.opt.low_times,self.opt.high_times)/100. + input_img = (B_img+1)/2./times + input_img = input_img*2-1 + A_img = input_img + r,g,b = input_img[0]+1, input_img[1]+1, input_img[2]+1 + A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + A_gray = torch.unsqueeze(A_gray, 0) + return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img':input_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return self.A_size + + def name(self): + return 'PairDataset' diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..826fd594ee8793f65b8a905c6a9f535a47673d9f --- /dev/null +++ b/data/unaligned_dataset.py @@ -0,0 +1,141 @@ +import torch +from torch import nn +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset, store_dataset +import random +from PIL import Image +import PIL +from pdb import set_trace as st + +def pad_tensor(input): + + height_org, width_org = input.shape[2], input.shape[3] + divide = 16 + + if width_org % divide != 0 or height_org % divide != 0: + + width_res = width_org % divide + height_res = height_org % divide + if width_res != 0: + width_div = divide - width_res + pad_left = int(width_div / 2) + pad_right = int(width_div - pad_left) + else: + pad_left = 0 + pad_right = 0 + + if height_res != 0: + height_div = divide - height_res + pad_top = int(height_div / 2) + pad_bottom = int(height_div - pad_top) + else: + pad_top = 0 + pad_bottom = 0 + + padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom)) + input = padding(input).data + else: + pad_left = 0 + pad_right = 0 + pad_top = 0 + pad_bottom = 0 + + height, width = input.shape[2], input.shape[3] + assert width % divide == 0, 'width cant divided by stride' + assert height % divide == 0, 'height cant divided by stride' + + return input, pad_left, pad_right, pad_top, pad_bottom + +def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom): + height, width = input.shape[2], input.shape[3] + return input[:,:, pad_top: height - pad_bottom, pad_left: width - pad_right] + + +class UnalignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + # self.A_paths = make_dataset(self.dir_A) + # self.B_paths = make_dataset(self.dir_B) + self.A_imgs, self.A_paths = store_dataset(self.dir_A) + self.B_imgs, self.B_paths = store_dataset(self.dir_B) + + # self.A_paths = sorted(self.A_paths) + # self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + # A_path = self.A_paths[index % self.A_size] + # B_path = self.B_paths[index % self.B_size] + + # A_img = Image.open(A_path).convert('RGB') + # B_img = Image.open(B_path).convert('RGB') + A_img = self.A_imgs[index % self.A_size] + B_img = self.B_imgs[index % self.B_size] + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] + # A_size = A_img.size + # B_size = B_img.size + # A_size = A_size = (A_size[0]//16*16, A_size[1]//16*16) + # B_size = B_size = (B_size[0]//16*16, B_size[1]//16*16) + # A_img = A_img.resize(A_size, Image.BICUBIC) + # B_img = B_img.resize(B_size, Image.BICUBIC) + # A_gray = A_img.convert('LA') + # A_gray = 255.0-A_gray + + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + + if self.opt.resize_or_crop == 'no': + r,g,b = A_img[0]+1, A_img[1]+1, A_img[2]+1 + A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + A_gray = torch.unsqueeze(A_gray, 0) + input_img = A_img + # A_gray = (1./A_gray)/255. + else: + w = A_img.size(2) + h = A_img.size(1) + + # A_gray = (1./A_gray)/255. + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(2, idx) + B_img = B_img.index_select(2, idx) + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(1) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(1, idx) + B_img = B_img.index_select(1, idx) + if self.opt.vary == 1 and (not self.opt.no_flip) and random.random() < 0.5: + times = random.randint(self.opt.low_times,self.opt.high_times)/100. + input_img = (A_img+1)/2./times + input_img = input_img*2-1 + else: + input_img = A_img + if self.opt.lighten: + B_img = (B_img + 1)/2. + B_img = (B_img - torch.min(B_img))/(torch.max(B_img) - torch.min(B_img)) + B_img = B_img*2. -1 + r,g,b = input_img[0]+1, input_img[1]+1, input_img[2]+1 + A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2. + A_gray = torch.unsqueeze(A_gray, 0) + return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return max(self.A_size, self.B_size) + + def name(self): + return 'UnalignedDataset' + + diff --git a/data/unaligned_random_crop.py b/data/unaligned_random_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..9df3ef0ed672f2be94c9d05d1b48a6ed708ed9a0 --- /dev/null +++ b/data/unaligned_random_crop.py @@ -0,0 +1,85 @@ +import torch +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +import random +from PIL import Image +import PIL +from pdb import set_trace as st + + +class UnalignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) + + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + + self.transform = transforms.Compose(transform_list) + # self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] + + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + A_size = A_img.size + B_size = B_img.size + A_size = A_size = (A_size[0]//16*16, A_size[1]//16*16) + B_size = B_size = (B_size[0]//16*16, B_size[1]//16*16) + A_img = A_img.resize(A_size, Image.BICUBIC) + B_img = B_img.resize(B_size, Image.BICUBIC) + + + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + if self.opt.resize_or_crop == 'no': + pass + else: + w = A_img.size(2) + h = A_img.size(1) + size = [8,16,22] + from random import randint + size_index = randint(0,2) + Cropsize = size[size_index]*16 + + w_offset = random.randint(0, max(0, w - Cropsize - 1)) + h_offset = random.randint(0, max(0, h - Cropsize - 1)) + + A_img = A_img[:, h_offset:h_offset + Cropsize, + w_offset:w_offset + Cropsize] + + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(2, idx) + B_img = B_img.index_select(2, idx) + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A_img.size(1) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A_img = A_img.index_select(1, idx) + B_img = B_img.index_select(1, idx) + + return {'A': A_img, 'B': B_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return max(self.A_size, self.B_size) + + def name(self): + return 'UnalignedDataset' diff --git a/datasets/.DS_Store b/datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..75255b7007599c68d652328edb83418389fb0e52 Binary files /dev/null and b/datasets/.DS_Store differ diff --git a/datasets/bibtex/cityscapes.tex b/datasets/bibtex/cityscapes.tex new file mode 100644 index 0000000000000000000000000000000000000000..a87bdbf54fe9a5453fc8cf929299ef06d2f47691 --- /dev/null +++ b/datasets/bibtex/cityscapes.tex @@ -0,0 +1,6 @@ +@inproceedings{Cordts2016Cityscapes, +title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, +author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, +booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, +year={2016} +} diff --git a/datasets/bibtex/facades.tex b/datasets/bibtex/facades.tex new file mode 100644 index 0000000000000000000000000000000000000000..08b773e1188a9cfe8ce55b34616cead59c4d9243 --- /dev/null +++ b/datasets/bibtex/facades.tex @@ -0,0 +1,7 @@ +@INPROCEEDINGS{Tylecek13, + author = {Radim Tyle{\v c}ek, Radim {\v S}{\' a}ra}, + title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, + booktitle = {Proc. GCPR}, + year = {2013}, + address = {Saarbrucken, Germany}, +} diff --git a/datasets/bibtex/handbags.tex b/datasets/bibtex/handbags.tex new file mode 100644 index 0000000000000000000000000000000000000000..b79710c7b5344b181a534e04696dd2e75c744ecf --- /dev/null +++ b/datasets/bibtex/handbags.tex @@ -0,0 +1,13 @@ +@inproceedings{zhu2016generative, + title={Generative Visual Manipulation on the Natural Image Manifold}, + author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.}, + booktitle={Proceedings of European Conference on Computer Vision (ECCV)}, + year={2016} +} + +@InProceedings{xie15hed, + author = {"Xie, Saining and Tu, Zhuowen"}, + Title = {Holistically-Nested Edge Detection}, + Booktitle = "Proceedings of IEEE International Conference on Computer Vision", + Year = {2015}, +} diff --git a/datasets/bibtex/shoes.tex b/datasets/bibtex/shoes.tex new file mode 100644 index 0000000000000000000000000000000000000000..e67e158b945e456b9613f0effb06784cd6682c20 --- /dev/null +++ b/datasets/bibtex/shoes.tex @@ -0,0 +1,14 @@ +@InProceedings{fine-grained, + author = {A. Yu and K. Grauman}, + title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning}, + booktitle = {Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2014} +} + +@InProceedings{xie15hed, + author = {"Xie, Saining and Tu, Zhuowen"}, + Title = {Holistically-Nested Edge Detection}, + Booktitle = "Proceedings of IEEE International Conference on Computer Vision", + Year = {2015}, +} diff --git a/datasets/combine_A_and_B.py b/datasets/combine_A_and_B.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1e2a2a8f5e0a4cb9b729df777fd853e03fad22 --- /dev/null +++ b/datasets/combine_A_and_B.py @@ -0,0 +1,49 @@ +from pdb import set_trace as st +import os +import numpy as np +import cv2 +import argparse + +parser = argparse.ArgumentParser('create image pairs') +parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') +parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') +parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') +parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) +parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') +args = parser.parse_args() + +for arg in vars(args): + print('[%s] = ' % arg, getattr(args, arg)) + +splits = os.listdir(args.fold_A) + +for sp in splits: + img_fold_A = os.path.join(args.fold_A, sp) + img_fold_B = os.path.join(args.fold_B, sp) + img_list = os.listdir(img_fold_A) + if args.use_AB: + img_list = [img_path for img_path in img_list if '_A.' in img_path] + + num_imgs = min(args.num_imgs, len(img_list)) + print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) + img_fold_AB = os.path.join(args.fold_AB, sp) + if not os.path.isdir(img_fold_AB): + os.makedirs(img_fold_AB) + print('split = %s, number of images = %d' % (sp, num_imgs)) + for n in range(num_imgs): + name_A = img_list[n] + path_A = os.path.join(img_fold_A, name_A) + if args.use_AB: + name_B = name_A.replace('_A.', '_B.') + else: + name_B = name_A + path_B = os.path.join(img_fold_B, name_B) + if os.path.isfile(path_A) and os.path.isfile(path_B): + name_AB = name_A + if args.use_AB: + name_AB = name_AB.replace('_A.', '.') # remove _A + path_AB = os.path.join(img_fold_AB, name_AB) + im_A = cv2.imread(path_A, cv2.CV_LOAD_IMAGE_COLOR) + im_B = cv2.imread(path_B, cv2.CV_LOAD_IMAGE_COLOR) + im_AB = np.concatenate([im_A, im_B], 1) + cv2.imwrite(path_AB, im_AB) diff --git a/datasets/download_cyclegan_dataset.sh b/datasets/download_cyclegan_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..1f0b163185554efe599623379776924a4058c650 --- /dev/null +++ b/datasets/download_cyclegan_dataset.sh @@ -0,0 +1,14 @@ +FILE=$1 + +if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then + echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" + exit 1 +fi + +URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip +ZIP_FILE=./datasets/$FILE.zip +TARGET_DIR=./datasets/$FILE/ +wget -N $URL -O $ZIP_FILE +mkdir $TARGET_DIR +unzip $ZIP_FILE -d ./datasets/ +rm $ZIP_FILE diff --git a/datasets/download_pix2pix_dataset.sh b/datasets/download_pix2pix_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d28e4f38ebb2ba33631329a64b4b1b879f51853 --- /dev/null +++ b/datasets/download_pix2pix_dataset.sh @@ -0,0 +1,8 @@ +FILE=$1 +URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz +TAR_FILE=./datasets/$FILE.tar.gz +TARGET_DIR=./datasets/$FILE/ +wget -N $URL -O $TAR_FILE +mkdir $TARGET_DIR +tar -zxvf $TAR_FILE -C ./datasets/ +rm $TAR_FILE \ No newline at end of file diff --git a/imgs/edges2cats.jpg b/imgs/edges2cats.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9586bcfb318a5a978a3d79a7283466f9bc48c44 Binary files /dev/null and b/imgs/edges2cats.jpg differ diff --git a/imgs/horse2zebra.gif b/imgs/horse2zebra.gif new file mode 100644 index 0000000000000000000000000000000000000000..4ded4d1ec2f5438765418c8a32d5e6f401b7d5cf --- /dev/null +++ b/imgs/horse2zebra.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16a76adedd309c46ba6ed63f89b14130c4a671fd6febc26fb0372a1ccf16c7aa +size 7686299 diff --git a/lib/nn/__init__.py b/lib/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98a96370ef04570f516052bb73f568d0ebc346c3 --- /dev/null +++ b/lib/nn/__init__.py @@ -0,0 +1,2 @@ +from .modules import * +from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/lib/nn/modules/__init__.py b/lib/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/lib/nn/modules/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/lib/nn/modules/batchnorm.py b/lib/nn/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..18318965335b37cc671004a6aceda3229dc7b477 --- /dev/null +++ b/lib/nn/modules/batchnorm.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + # customed batch norm statistics + self._moving_average_fraction = 1. - momentum + self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) + self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) + self.register_buffer('_running_iter', torch.ones(1)) + self._tmp_running_mean = self.running_mean.clone() * self._running_iter + self._tmp_running_var = self.running_var.clone() * self._running_iter + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): + """return *dest* by `dest := dest*alpha + delta*beta + bias`""" + return dest * alpha + delta * beta + bias + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) + self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) + self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) + + self.running_mean = self._tmp_running_mean / self._running_iter + self.running_var = self._tmp_running_var / self._running_iter + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/lib/nn/modules/comm.py b/lib/nn/modules/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..b64bf6ba3b3e7abbab375c6dd4a87d8239e62138 --- /dev/null +++ b/lib/nn/modules/comm.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/lib/nn/modules/replicate.py b/lib/nn/modules/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/lib/nn/modules/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/lib/nn/modules/tests/test_numeric_batchnorm.py b/lib/nn/modules/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd45a930d3dc84912e58659ee575be08e9038f0 --- /dev/null +++ b/lib/nn/modules/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/nn/modules/tests/test_sync_batchnorm.py b/lib/nn/modules/tests/test_sync_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..45bb3c8cfd36d8f668e6fde756b17587eab72082 --- /dev/null +++ b/lib/nn/modules/tests/test_sync_batchnorm.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/nn/modules/unittest.py b/lib/nn/modules/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524 --- /dev/null +++ b/lib/nn/modules/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/lib/nn/parallel/__init__.py b/lib/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52f49cc0755562218a460483cbf02514ddd773 --- /dev/null +++ b/lib/nn/parallel/__init__.py @@ -0,0 +1 @@ +from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/lib/nn/parallel/data_parallel.py b/lib/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..376fc038919aa2a5bd696141e7bb6025d4981306 --- /dev/null +++ b/lib/nn/parallel/data_parallel.py @@ -0,0 +1,112 @@ +# -*- coding: utf8 -*- + +import torch.cuda as cuda +import torch.nn as nn +import torch +import collections +from torch.nn.parallel._functions import Gather + + +__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + v = obj.cuda(dev, non_blocking=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def dict_gather(outputs, target_device, dim=0): + """ + Gathers variables from different GPUs on a specified device + (-1 means the CPU), with dictionary support. + """ + def gather_map(outputs): + out = outputs[0] + if torch.is_tensor(out): + # MJY(20180330) HACK:: force nr_dims > 0 + if out.dim() == 0: + outputs = [o.unsqueeze(0) for o in outputs] + return Gather.apply(target_device, dim, *outputs) + elif out is None: + return None + elif isinstance(out, collections.Mapping): + return {k: gather_map([o[k] for o in outputs]) for k in out} + elif isinstance(out, collections.Sequence): + return type(out)(map(gather_map, zip(*outputs))) + return gather_map(outputs) + + +class DictGatherDataParallel(nn.DataParallel): + def gather(self, outputs, output_device): + return dict_gather(outputs, output_device, dim=self.dim) + + +class UserScatteredDataParallel(DictGatherDataParallel): + def scatter(self, inputs, kwargs, device_ids): + assert len(inputs) == 1 + inputs = inputs[0] + inputs = _async_copy_stream(inputs, device_ids) + inputs = [[i] for i in inputs] + assert len(kwargs) == 0 + kwargs = [{} for _ in range(len(inputs))] + + return inputs, kwargs + + +def user_scattered_collate(batch): + return batch + + +def _async_copy(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + for i, dev in zip(inputs, device_ids): + with cuda.device(dev): + outputs.append(async_copy_to(i, dev)) + + return tuple(outputs) + + +def _async_copy_stream(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + streams = [_get_stream(d) for d in device_ids] + for i, dev, stream in zip(inputs, device_ids, streams): + with cuda.device(dev): + main_stream = cuda.current_stream() + with cuda.stream(stream): + outputs.append(async_copy_to(i, dev, main_stream=main_stream)) + main_stream.wait_stream(stream) + + return outputs + + +"""Adapted from: torch/nn/parallel/_functions.py""" +# background streams used for copying +_streams = None + + +def _get_stream(device): + """Gets a background stream for copying between CPU and GPU""" + global _streams + if device == -1: + return None + if _streams is None: + _streams = [None] * cuda.device_count() + if _streams[device] is None: _streams[device] = cuda.Stream(device) + return _streams[device] diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe3cbe49477fe37d4fc16249de8a10f4fb4a013 --- /dev/null +++ b/lib/utils/__init__.py @@ -0,0 +1 @@ +from .th import * diff --git a/lib/utils/data/__init__.py b/lib/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b008fb13c5e8a84b1b785056e8c4f5226dc976 --- /dev/null +++ b/lib/utils/data/__init__.py @@ -0,0 +1,3 @@ + +from .dataset import Dataset, TensorDataset, ConcatDataset +from .dataloader import DataLoader diff --git a/lib/utils/data/dataloader.py b/lib/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..220f98dffd73b8eb3fbee04766cddb8e83f4689e --- /dev/null +++ b/lib/utils/data/dataloader.py @@ -0,0 +1,422 @@ +import torch +import torch.multiprocessing as multiprocessing +from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ + _remove_worker_pids, _error_if_any_worker_fails +from .sampler import SequentialSampler, RandomSampler, BatchSampler +import signal +import functools +import collections +import re +import sys +import threading +import traceback +from torch._six import string_classes, int_classes +import numpy as np + +if sys.version_info[0] == 2: + import Queue as queue +else: + import queue + + +class ExceptionWrapper(object): + r"Wraps an exception plus traceback to communicate across threads" + + def __init__(self, exc_info): + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +_use_shared_memory = False +"""Whether to use shared memory in default_collate""" + + +def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): + global _use_shared_memory + _use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal happened again already. + # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 + _set_worker_signal_handlers() + + torch.set_num_threads(1) + torch.manual_seed(seed) + np.random.seed(seed) + + if init_fn is not None: + init_fn(worker_id) + + while True: + r = index_queue.get() + if r is None: + break + idx, batch_indices = r + try: + samples = collate_fn([dataset[i] for i in batch_indices]) + except Exception: + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + + +def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): + if pin_memory: + torch.cuda.set_device(device_id) + + while True: + try: + r = in_queue.get() + except Exception: + if done_event.is_set(): + return + raise + if r is None: + break + if isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + continue + idx, batch = r + try: + if pin_memory: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) + +numpy_type_map = { + 'float64': torch.DoubleTensor, + 'float32': torch.FloatTensor, + 'float16': torch.HalfTensor, + 'int64': torch.LongTensor, + 'int32': torch.IntTensor, + 'int16': torch.ShortTensor, + 'int8': torch.CharTensor, + 'uint8': torch.ByteTensor, +} + + +def default_collate(batch): + "Puts each data field into a tensor with outer dimension batch size" + + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" + elem_type = type(batch[0]) + if torch.is_tensor(batch[0]): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if re.search('[SaUO]', elem.dtype.str) is not None: + raise TypeError(error_msg.format(elem.dtype)) + + return torch.stack([torch.from_numpy(b) for b in batch], 0) + if elem.shape == (): # scalars + py_type = float if elem.dtype.name.startswith('float') else int + return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) + elif isinstance(batch[0], int_classes): + return torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], string_classes): + return batch + elif isinstance(batch[0], collections.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], collections.Sequence): + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + +def pin_memory_batch(batch): + if torch.is_tensor(batch): + return batch.pin_memory() + elif isinstance(batch, string_classes): + return batch + elif isinstance(batch, collections.Mapping): + return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, collections.Sequence): + return [pin_memory_batch(sample) for sample in batch] + else: + return batch + + +_SIGCHLD_handler_set = False +"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if sys.platform == 'win32': + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True + + +class DataLoaderIter(object): + "Iterates once over the DataLoader's dataset, as specified by the sampler" + + def __init__(self, loader): + self.dataset = loader.dataset + self.collate_fn = loader.collate_fn + self.batch_sampler = loader.batch_sampler + self.num_workers = loader.num_workers + self.pin_memory = loader.pin_memory and torch.cuda.is_available() + self.timeout = loader.timeout + self.done_event = threading.Event() + + self.sample_iter = iter(self.batch_sampler) + + if self.num_workers > 0: + self.worker_init_fn = loader.worker_init_fn + self.index_queue = multiprocessing.SimpleQueue() + self.worker_result_queue = multiprocessing.SimpleQueue() + self.batches_outstanding = 0 + self.worker_pids_set = False + self.shutdown = False + self.send_idx = 0 + self.rcvd_idx = 0 + self.reorder_dict = {} + + base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] + self.workers = [ + multiprocessing.Process( + target=_worker_loop, + args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, + base_seed + i, self.worker_init_fn, i)) + for i in range(self.num_workers)] + + if self.pin_memory or self.timeout > 0: + self.data_queue = queue.Queue() + if self.pin_memory: + maybe_device_id = torch.cuda.current_device() + else: + # do not initialize cuda context if not necessary + maybe_device_id = None + self.worker_manager_thread = threading.Thread( + target=_worker_manager_loop, + args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, + maybe_device_id)) + self.worker_manager_thread.daemon = True + self.worker_manager_thread.start() + else: + self.data_queue = self.worker_result_queue + + for w in self.workers: + w.daemon = True # ensure that the worker exits on process exit + w.start() + + _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _set_SIGCHLD_handler() + self.worker_pids_set = True + + # prime the prefetch loop + for _ in range(2 * self.num_workers): + self._put_indices() + + def __len__(self): + return len(self.batch_sampler) + + def _get_batch(self): + if self.timeout > 0: + try: + return self.data_queue.get(timeout=self.timeout) + except queue.Empty: + raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) + else: + return self.data_queue.get() + + def __next__(self): + if self.num_workers == 0: # same-process loading + indices = next(self.sample_iter) # may raise StopIteration + batch = self.collate_fn([self.dataset[i] for i in indices]) + if self.pin_memory: + batch = pin_memory_batch(batch) + return batch + + # check if the next sample has already been generated + if self.rcvd_idx in self.reorder_dict: + batch = self.reorder_dict.pop(self.rcvd_idx) + return self._process_next_batch(batch) + + if self.batches_outstanding == 0: + self._shutdown_workers() + raise StopIteration + + while True: + assert (not self.shutdown and self.batches_outstanding > 0) + idx, batch = self._get_batch() + self.batches_outstanding -= 1 + if idx != self.rcvd_idx: + # store out-of-order samples + self.reorder_dict[idx] = batch + continue + return self._process_next_batch(batch) + + next = __next__ # Python 2 compatibility + + def __iter__(self): + return self + + def _put_indices(self): + assert self.batches_outstanding < 2 * self.num_workers + indices = next(self.sample_iter, None) + if indices is None: + return + self.index_queue.put((self.send_idx, indices)) + self.batches_outstanding += 1 + self.send_idx += 1 + + def _process_next_batch(self, batch): + self.rcvd_idx += 1 + self._put_indices() + if isinstance(batch, ExceptionWrapper): + raise batch.exc_type(batch.exc_msg) + return batch + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("DataLoaderIterator cannot be pickled") + + def _shutdown_workers(self): + try: + if not self.shutdown: + self.shutdown = True + self.done_event.set() + # if worker_manager_thread is waiting to put + while not self.data_queue.empty(): + self.data_queue.get() + for _ in self.workers: + self.index_queue.put(None) + # done_event should be sufficient to exit worker_manager_thread, + # but be safe here and put another None + self.worker_result_queue.put(None) + finally: + # removes pids no matter what + if self.worker_pids_set: + _remove_worker_pids(id(self)) + self.worker_pids_set = False + + def __del__(self): + if self.num_workers > 0: + self._shutdown_workers() + + +class DataLoader(object): + """ + Data loader. Combines a dataset and a sampler, and provides + single- or multi-process iterators over the dataset. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: 1). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: False). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, ``shuffle`` must be False. + batch_sampler (Sampler, optional): like sampler, but returns a batch of + indices at a time. Mutually exclusive with batch_size, shuffle, + sampler, and drop_last. + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means that the data will be loaded in the main process. + (default: 0) + collate_fn (callable, optional): merges a list of samples to form a mini-batch. + pin_memory (bool, optional): If ``True``, the data loader will copy tensors + into CUDA pinned memory before returning them. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: False) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: 0) + worker_init_fn (callable, optional): If not None, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: None) + + .. note:: By default, each worker will have its PyTorch seed set to + ``base_seed + worker_id``, where ``base_seed`` is a long generated + by main process using its RNG. You may use ``torch.initial_seed()`` to access + this value in :attr:`worker_init_fn`, which can be used to set other seeds + (e.g. NumPy) before data loading. + + .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an + unpicklable object, e.g., a lambda function. + """ + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, + num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + self.pin_memory = pin_memory + self.drop_last = drop_last + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + if batch_sampler is not None: + if batch_size > 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler is mutually exclusive with ' + 'batch_size, shuffle, sampler, and drop_last') + + if sampler is not None and shuffle: + raise ValueError('sampler is mutually exclusive with shuffle') + + if self.num_workers < 0: + raise ValueError('num_workers cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') + + if batch_sampler is None: + if sampler is None: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.sampler = sampler + self.batch_sampler = batch_sampler + + def __iter__(self): + return DataLoaderIter(self) + + def __len__(self): + return len(self.batch_sampler) diff --git a/lib/utils/data/dataset.py b/lib/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..605aa877f7031a5cd2b98c0f831410aa80fddefa --- /dev/null +++ b/lib/utils/data/dataset.py @@ -0,0 +1,118 @@ +import bisect +import warnings + +from torch._utils import _accumulate +from torch import randperm + + +class Dataset(object): + """An abstract class representing a Dataset. + + All other datasets should subclass it. All subclasses should override + ``__len__``, that provides the size of the dataset, and ``__getitem__``, + supporting integer indexing in range from 0 to len(self) exclusive. + """ + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def __add__(self, other): + return ConcatDataset([self, other]) + + +class TensorDataset(Dataset): + """Dataset wrapping data and target tensors. + + Each sample will be retrieved by indexing both tensors along the first + dimension. + + Arguments: + data_tensor (Tensor): contains sample data. + target_tensor (Tensor): contains sample targets (labels). + """ + + def __init__(self, data_tensor, target_tensor): + assert data_tensor.size(0) == target_tensor.size(0) + self.data_tensor = data_tensor + self.target_tensor = target_tensor + + def __getitem__(self, index): + return self.data_tensor[index], self.target_tensor[index] + + def __len__(self): + return self.data_tensor.size(0) + + +class ConcatDataset(Dataset): + """ + Dataset to concatenate multiple datasets. + Purpose: useful to assemble different existing datasets, possibly + large-scale datasets as the concatenation operation is done in an + on-the-fly manner. + + Arguments: + datasets (iterable): List of datasets to be concatenated + """ + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + self.datasets = list(datasets) + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def cummulative_sizes(self): + warnings.warn("cummulative_sizes attribute is renamed to " + "cumulative_sizes", DeprecationWarning, stacklevel=2) + return self.cumulative_sizes + + +class Subset(Dataset): + def __init__(self, dataset, indices): + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + +def random_split(dataset, lengths): + """ + Randomly split a dataset into non-overlapping new datasets of given lengths + ds + + Arguments: + dataset (Dataset): Dataset to be split + lengths (iterable): lengths of splits to be produced + """ + if sum(lengths) != len(dataset): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + indices = randperm(sum(lengths)) + return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] diff --git a/lib/utils/data/distributed.py b/lib/utils/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d890e28fd2b9e044bdd9494de4a43ad2471eed --- /dev/null +++ b/lib/utils/data/distributed.py @@ -0,0 +1,58 @@ +import math +import torch +from .sampler import Sampler +from torch.distributed import get_world_size, get_rank + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + num_replicas = get_world_size() + if rank is None: + rank = get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = list(torch.randperm(len(self.dataset), generator=g)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/lib/utils/data/sampler.py b/lib/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..62a9a43bd1d4c21fbdcb262db7da8d4fe27b26de --- /dev/null +++ b/lib/utils/data/sampler.py @@ -0,0 +1,131 @@ +import torch + + +class Sampler(object): + """Base class for all Samplers. + + Every Sampler subclass has to provide an __iter__ method, providing a way + to iterate over indices of dataset elements, and a __len__ method that + returns the length of the returned iterators. + """ + + def __init__(self, data_source): + pass + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class SequentialSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + """Samples elements randomly, without replacement. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(torch.randperm(len(self.data_source)).long()) + + def __len__(self): + return len(self.data_source) + + +class SubsetRandomSampler(Sampler): + """Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (list): a list of indices + """ + + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in torch.randperm(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class WeightedRandomSampler(Sampler): + """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). + + Arguments: + weights (list) : a list of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + """ + + def __init__(self, weights, num_samples, replacement=True): + self.weights = torch.DoubleTensor(weights) + self.num_samples = num_samples + self.replacement = replacement + + def __iter__(self): + return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) + + def __len__(self): + return self.num_samples + + +class BatchSampler(object): + """Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__(self, sampler, batch_size, drop_last): + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/lib/utils/th.py b/lib/utils/th.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6ef9385e3b5c0a439579d3fd7aa73b5dc62758 --- /dev/null +++ b/lib/utils/th.py @@ -0,0 +1,41 @@ +import torch +from torch.autograd import Variable +import numpy as np +import collections + +__all__ = ['as_variable', 'as_numpy', 'mark_volatile'] + +def as_variable(obj): + if isinstance(obj, Variable): + return obj + if isinstance(obj, collections.Sequence): + return [as_variable(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_variable(v) for k, v in obj.items()} + else: + return Variable(obj) + +def as_numpy(obj): + if isinstance(obj, collections.Sequence): + return [as_numpy(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_numpy(v) for k, v in obj.items()} + elif isinstance(obj, Variable): + return obj.data.cpu().numpy() + elif torch.is_tensor(obj): + return obj.cpu().numpy() + else: + return np.array(obj) + +def mark_volatile(obj): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + obj.no_grad = True + return obj + elif isinstance(obj, collections.Mapping): + return {k: mark_volatile(o) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [mark_volatile(o) for o in obj] + else: + return obj diff --git a/models/Unet_L1.py b/models/Unet_L1.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe38d6e30ba05bf80e6688a26190214cebf86e3 --- /dev/null +++ b/models/Unet_L1.py @@ -0,0 +1,162 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import sys + + +class PairModel(BaseModel): + def name(self): + return 'CycleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + self.input_img = self.Tensor(nb, opt.input_nc, size, size) + self.input_A_gray = self.Tensor(nb, 1, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model") + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + if opt.isTrain: + self.netG_A.train() + else: + self.netG_A.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + input_img = input['input_img'] + input_A_gray = input['A_gray'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.input_img.resize_(input_img.size()).copy_(input_img) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + self.real_A_gray = Variable(self.input_A_gray) + self.real_img = Variable(self.input_img) + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + + self.real_B = Variable(self.input_B, volatile=True) + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + if self.opt.skip == 1: + latent_real_A = util.tensor2im(self.latent_real_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_G(self): + + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + # = self.latent_real_A + self.opt.skip * self.real_A + self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 + self.loss_G = self.L1_AB + self.loss_G.backward() + + + def optimize_parameters(self, epoch): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + + def get_current_errors(self, epoch): + L1 = self.L1_AB.data[0] + loss_G = self.loss_G.data[0] + return OrderedDict([('L1', L1), ('loss_G', loss_G)]) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/base_model.py b/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8edc7c46a4a3a9e6b11240d01e4c98e7bc5861 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,56 @@ +import os +import torch + + +class BaseModel(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda(device=gpu_ids[0]) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed74b16d39215148a065c963e7db80a4dce9c52 --- /dev/null +++ b/models/cycle_gan_model.py @@ -0,0 +1,325 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import sys + + +class CycleGANModel(BaseModel): + def name(self): + return 'CycleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model") + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + self.netD_B = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + self.netG_B.train() + else: + self.netG_A.eval() + self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + self.fake_A = self.netG_B.forward(self.real_B) + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + rec_A = util.tensor2im(self.rec_A.data) + if self.opt.skip == 1: + latent_real_A = util.tensor2im(self.latent_real_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A), ("rec_A", rec_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake): + # Real + pred_real = netD.forward(real) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + else: + loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_fake = pred_fake.mean() + else: + loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss + if self.opt.use_wgan: + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, real.data, fake.data) + else: + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # backward + loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + + def backward_D_B(self): + fake_A = self.fake_A_pool.query(self.fake_A) + self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + + def backward_G(self, epoch): + lambda_idt = self.opt.identity + lambda_A = self.opt.lambda_A + lambda_B = self.opt.lambda_B + # Identity loss + if lambda_idt > 0: + # G_A should be identity if real_B is fed. + self.idt_A = self.netG_A.forward(self.real_B) + self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + # G_B should be identity if real_A is fed. + self.idt_B = self.netG_B.forward(self.real_A) + self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + else: + self.loss_idt_A = 0 + self.loss_idt_B = 0 + + # GAN loss + # D_A(G_A(A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + # = self.latent_real_A + self.opt.skip * self.real_A + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + if self.opt.l1 > 0: + self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 + else: + self.L1_AB = 0 + # D_B(G_B(B)) + self.fake_A = self.netG_B.forward(self.real_B) + pred_fake = self.netD_B.forward(self.fake_A) + if self.opt.l1 > 0: + self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1 + else: + self.L1_BA = 0 + if self.opt.use_wgan: + self.loss_G_B = -pred_fake.mean() + else: + self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss + + if lambda_A > 0: + self.rec_A = self.netG_B.forward(self.fake_B) + self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + else: + self.loss_cycle_A = 0 + # Backward cycle loss + + # = self.latent_fake_A + self.opt.skip * self.fake_A + if lambda_B > 0: + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + else: + self.loss_cycle_B = 0 + self.loss_vgg_a = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_A, self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0 + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + if epoch <= 10: + self.loss_vgg_a = 0 + self.loss_vgg_b = 0 + # combined loss + self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_vgg_a + self.loss_vgg_b + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + def optimize_parameters(self, epoch): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G(epoch) + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + self.optimizer_D_A.step() + # D_B + self.optimizer_D_B.zero_grad() + self.backward_D_B() + self.optimizer_D_B.step() + + def get_current_errors(self, epoch): + D_A = self.loss_D_A.data[0] + G_A = self.loss_G_A.data[0] + Cyc_A = self.loss_cycle_A.data[0] + D_B = self.loss_D_B.data[0] + G_B = self.loss_G_B.data[0] + Cyc_B = self.loss_cycle_B.data[0] + if epoch <= 10: + vgg = 0 + else: + vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0]) / self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.lambda_A > 0.0: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), + ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg)) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + + real_B = util.tensor2im(self.real_B.data) + fake_A = util.tensor2im(self.fake_A.data) + + if self.opt.lambda_A > 0.0: + rec_A = util.tensor2im(self.rec_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.skip > 0: + latent_fake_A = util.tensor2im(self.latent_fake_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + else: + if self.opt.skip > 0: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('real_B', real_B), ('fake_A', fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), + ('real_B', real_B), ('fake_A', fake_A)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_D_B.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8365f4b0a80744cbdf70042db758f04fc0481a13 --- /dev/null +++ b/models/models.py @@ -0,0 +1,38 @@ + +def create_model(opt): + model = None + print(opt.model) + if opt.model == 'cycle_gan': + assert(opt.dataset_mode == 'unaligned') + from .cycle_gan_model import CycleGANModel + model = CycleGANModel() + elif opt.model == 'pix2pix': + assert(opt.dataset_mode == 'pix2pix') + from .pix2pix_model import Pix2PixModel + model = Pix2PixModel() + elif opt.model == 'pair': + # assert(opt.dataset_mode == 'pair') + # from .pair_model import PairModel + from .Unet_L1 import PairModel + model = PairModel() + elif opt.model == 'single': + # assert(opt.dataset_mode == 'unaligned') + from .single_model import SingleModel + model = SingleModel() + elif opt.model == 'temp': + # assert(opt.dataset_mode == 'unaligned') + from .temp_model import TempModel + model = TempModel() + elif opt.model == 'UNIT': + assert(opt.dataset_mode == 'unaligned') + from .unit_model import UNITModel + model = UNITModel() + elif opt.model == 'test': + assert(opt.dataset_mode == 'single') + from .test_model import TestModel + model = TestModel() + else: + raise ValueError("Model [%s] not recognized." % opt.model) + model.initialize(opt) + print("model [%s] was created" % (model.name())) + return model diff --git a/models/multi_model.py b/models/multi_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6b64ababb9d0bec9211235294d788834355e6db8 --- /dev/null +++ b/models/multi_model.py @@ -0,0 +1,359 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import sys + + +class MultiModel(BaseModel): + def name(self): + return 'MultiGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model") + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + self.netD_B = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + self.netG_B.train() + else: + self.netG_A.eval() + self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + self.fake_A = self.netG_B.forward(self.real_B) + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + rec_A = util.tensor2im(self.rec_A.data) + if self.opt.skip == 1: + latent_real_A = util.tensor2im(self.latent_real_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A), ("rec_A", rec_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake): + # Real + pred_real = netD.forward(real) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + else: + loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_fake = pred_fake.mean() + else: + loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss + if self.opt.use_wgan: + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, real.data, fake.data) + else: + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # backward + loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + + def backward_D_B(self): + fake_A = self.fake_A_pool.query(self.fake_A) + self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + + def backward_G(self): + lambda_idt = self.opt.identity + lambda_A = self.opt.lambda_A + lambda_B = self.opt.lambda_B + # Identity loss + if lambda_idt > 0: + # G_A should be identity if real_B is fed. + if self.opt.skip == 1: + self.idt_A, _ = self.netG_A.forward(self.real_B) + else: + self.idt_A = self.netG_A.forward(self.real_B) + self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + # G_B should be identity if real_A is fed. + self.idt_B = self.netG_B.forward(self.real_A) + self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + else: + self.loss_idt_A = 0 + self.loss_idt_B = 0 + + # GAN loss + # D_A(G_A(A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + # = self.latent_real_A + self.opt.skip * self.real_A + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 + # D_B(G_B(B)) + self.fake_A = self.netG_B.forward(self.real_B) + pred_fake = self.netD_B.forward(self.fake_A) + self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1 + if self.opt.use_wgan: + self.loss_G_B = -pred_fake.mean() + else: + self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss + + if lambda_A > 0: + self.rec_A = self.netG_B.forward(self.fake_B) + self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + else: + self.loss_cycle_A = 0 + # Backward cycle loss + + # = self.latent_fake_A + self.opt.skip * self.fake_A + if lambda_B > 0: + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + else: + self.loss_cycle_B = 0 + self.loss_vgg_a = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_A, self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0 + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + # combined loss + self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \ + self.loss_vgg_a + self.loss_vgg_b + \ + self.loss_idt_A + self.loss_idt_B + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + + def optimize_parameters(self): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + self.optimizer_D_A.step() + # D_B + self.optimizer_D_B.zero_grad() + self.backward_D_B() + self.optimizer_D_B.step() + + + def get_current_errors(self): + D_A = self.loss_D_A.data[0] + G_A = self.loss_G_A.data[0] + L1 = (self.L1_AB + self.L1_BA).data[0] + Cyc_A = self.loss_cycle_A.data[0] + D_B = self.loss_D_B.data[0] + G_B = self.loss_G_B.data[0] + Cyc_B = self.loss_cycle_B.data[0] + vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0])/self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.identity > 0: + idt = self.loss_idt_A.data[0] + self.loss_idt_B.data[0] + if self.opt.lambda_A > 0.0: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg), ("idt", idt)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), + ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg), ("idt", idt)) + else: + if self.opt.lambda_A > 0.0: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), + ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg)) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + + real_B = util.tensor2im(self.real_B.data) + fake_A = util.tensor2im(self.fake_A.data) + + if self.opt.identity > 0: + idt_A = util.tensor2im(self.idt_A.data) + idt_B = util.tensor2im(self.idt_B.data) + if self.opt.lambda_A > 0.0: + rec_A = util.tensor2im(self.rec_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.skip > 0: + latent_fake_A = util.tensor2im(self.latent_fake_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A), + ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + if self.opt.skip > 0: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), + ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + if self.opt.lambda_A > 0.0: + rec_A = util.tensor2im(self.rec_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.skip > 0: + latent_fake_A = util.tensor2im(self.latent_fake_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + else: + if self.opt.skip > 0: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('real_B', real_B), ('fake_A', fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), + ('real_B', real_B), ('fake_A', fake_A)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_D_B.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..85c141db6169c32d5ba94b18f84ed07a5f247a0e --- /dev/null +++ b/models/networks.py @@ -0,0 +1,1181 @@ +import torch +import os +import math +import torch.nn as nn +from torch.nn import init +import functools +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +# from torch.utils.serialization import load_lua +from lib.nn import SynchronizedBatchNorm2d as SynBN2d +############################################################################### +# Functions +############################################################################### + +def pad_tensor(input): + + height_org, width_org = input.shape[2], input.shape[3] + divide = 16 + + if width_org % divide != 0 or height_org % divide != 0: + + width_res = width_org % divide + height_res = height_org % divide + if width_res != 0: + width_div = divide - width_res + pad_left = int(width_div / 2) + pad_right = int(width_div - pad_left) + else: + pad_left = 0 + pad_right = 0 + + if height_res != 0: + height_div = divide - height_res + pad_top = int(height_div / 2) + pad_bottom = int(height_div - pad_top) + else: + pad_top = 0 + pad_bottom = 0 + + padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom)) + input = padding(input) + else: + pad_left = 0 + pad_right = 0 + pad_top = 0 + pad_bottom = 0 + + height, width = input.data.shape[2], input.data.shape[3] + assert width % divide == 0, 'width cant divided by stride' + assert height % divide == 0, 'height cant divided by stride' + + return input, pad_left, pad_right, pad_top, pad_bottom + +def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom): + height, width = input.shape[2], input.shape[3] + return input[:,:, pad_top: height - pad_bottom, pad_left: width - pad_right] + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'synBN': + norm_layer = functools.partial(SynBN2d, affine=True) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm) + return norm_layer + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], skip=False, opt=None): + netG = None + use_gpu = len(gpu_ids) > 0 + norm_layer = get_norm_layer(norm_type=norm) + + if use_gpu: + assert(torch.cuda.is_available()) + + if which_model_netG == 'resnet_9blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids) + elif which_model_netG == 'resnet_6blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids) + elif which_model_netG == 'unet_128': + netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids) + elif which_model_netG == 'unet_256': + netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt) + elif which_model_netG == 'unet_512': + netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt) + elif which_model_netG == 'sid_unet': + netG = Unet(opt, skip) + elif which_model_netG == 'sid_unet_shuffle': + netG = Unet_pixelshuffle(opt, skip) + elif which_model_netG == 'sid_unet_resize': + netG = Unet_resize_conv(opt, skip) + elif which_model_netG == 'DnCNN': + netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + if len(gpu_ids) >= 0: + netG.cuda(device=gpu_ids[0]) + netG = torch.nn.DataParallel(netG, gpu_ids) + netG.apply(weights_init) + return netG + + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False): + netD = None + use_gpu = len(gpu_ids) > 0 + norm_layer = get_norm_layer(norm_type=norm) + + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netD == 'basic': + netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'n_layers': + netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_norm': + netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_norm_4': + netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_patchgan': + netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % + which_model_netD) + if use_gpu: + netD.cuda(device=gpu_ids[0]) + netD = torch.nn.DataParallel(netD, gpu_ids) + netD.apply(weights_init) + return netD + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + +############################################################################## +# Classes +############################################################################## + + +# Defines the GAN loss which uses either LSGAN or the regular GAN. +# When LSGAN is used, it is basically same as MSELoss, +# but it abstracts away the need to create the target label tensor +# that has the same size as the input +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + + +class DiscLossWGANGP(): + def __init__(self): + self.LAMBDA = 10 + + def name(self): + return 'DiscLossWGAN-GP' + + def initialize(self, opt, tensor): + # DiscLossLS.initialize(self, opt, tensor) + self.LAMBDA = 10 + + # def get_g_loss(self, net, realA, fakeB): + # # First, G(A) should fake the discriminator + # self.D_fake = net.forward(fakeB) + # return -self.D_fake.mean() + + def calc_gradient_penalty(self, netD, real_data, fake_data): + alpha = torch.rand(1, 1) + alpha = alpha.expand(real_data.size()) + alpha = alpha.cuda() + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + + interpolates = interpolates.cuda() + interpolates = Variable(interpolates, requires_grad=True) + + disc_interpolates = netD.forward(interpolates) + + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).cuda(), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA + return gradient_penalty + +# Defines the generator that consists of Resnet blocks between a few +# downsampling/upsampling operations. +# Code and idea originally from Justin Johnson's architecture. +# https://github.com/jcjohnson/fast-neural-style/ +class ResnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'): + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + self.gpu_ids = gpu_ids + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, + stride=2, padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +# Defines the Unet generator. +# |num_downs|: number of downsamplings in UNet. For example, +# if |num_downs| == 7, image of size 128x128 will become of size 1x1 +# at the bottleneck +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, num_downs, ngf=64, + norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None): + super(UnetGenerator, self).__init__() + self.gpu_ids = gpu_ids + self.opt = opt + # currently support only input_nc == output_nc + assert(input_nc == output_nc) + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt) + for i in range(num_downs - 5): + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer, opt=opt) + + if skip == True: + skipmodule = SkipModule(unet_block, opt) + self.model = skipmodule + else: + self.model = unet_block + + def forward(self, input): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + +class SkipModule(nn.Module): + def __init__(self, submodule, opt): + super(SkipModule, self).__init__() + self.submodule = submodule + self.opt = opt + + def forward(self, x): + latent = self.submodule(x) + return self.opt.skip*x + latent, latent + + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class UnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + + downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, + stride=2, padding=1) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if opt.use_norm == 0: + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + else: + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([self.model(x), x], 1) + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): + super(NLayerDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + # return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + # else: + return self.model(input) + +class NoNormDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]): + super(NoNormDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + # return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + # else: + return self.model(input) + +class FCDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False): + super(FCDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + self.use_sigmoid = use_sigmoid + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + if patch: + self.linear = nn.Linear(7*7,1) + else: + self.linear = nn.Linear(13*13,1) + if use_sigmoid: + self.sigmoid = nn.Sigmoid() + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + batchsize = input.size()[0] + output = self.model(input) + output = output.view(batchsize,-1) + # print(output.size()) + output = self.linear(output) + if self.use_sigmoid: + print("sigmoid") + output = self.sigmoid(output) + return output + + +class Unet_resize_conv(nn.Module): + def __init__(self, opt, skip): + super(Unet_resize_conv, self).__init__() + + self.opt = opt + self.skip = skip + p = 1 + # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p) + if opt.self_attention: + self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p) + # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p) + self.downsample_1 = nn.MaxPool2d(2) + self.downsample_2 = nn.MaxPool2d(2) + self.downsample_3 = nn.MaxPool2d(2) + self.downsample_4 = nn.MaxPool2d(2) + else: + self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p) + self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p) + self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p) + self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p) + self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p) + self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p) + self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p) + self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p) + self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p) + self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p) + self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512) + + # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2) + self.deconv5 = nn.Conv2d(512, 256, 3, padding=p) + self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p) + self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p) + self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + + # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2) + self.deconv6 = nn.Conv2d(256, 128, 3, padding=p) + self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p) + self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p) + self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + + # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2) + self.deconv7 = nn.Conv2d(128, 64, 3, padding=p) + self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p) + self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p) + self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + + # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2) + self.deconv8 = nn.Conv2d(64, 32, 3, padding=p) + self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p) + self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p) + self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True) + + self.conv10 = nn.Conv2d(32, 3, 1) + if self.opt.tanh: + self.tanh = nn.Tanh() + + def depth_to_space(self, input, block_size): + block_size_sq = block_size*block_size + output = input.permute(0, 2, 3, 1) + (batch_size, d_height, d_width, d_depth) = output.size() + s_depth = int(d_depth / block_size_sq) + s_width = int(d_width * block_size) + s_height = int(d_height * block_size) + t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth) + spl = t_1.split(block_size, 3) + stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl] + output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).resize(batch_size, s_height, s_width, s_depth) + output = output.permute(0, 3, 1, 2) + return output + + def forward(self, input, gray): + flag = 0 + if input.size()[3] > 2200: + avg = nn.AvgPool2d(2) + input = avg(input) + gray = avg(gray) + flag = 1 + # pass + input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input) + gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray) + if self.opt.self_attention: + gray_2 = self.downsample_1(gray) + gray_3 = self.downsample_2(gray_2) + gray_4 = self.downsample_3(gray_3) + gray_5 = self.downsample_4(gray_4) + if self.opt.use_norm == 1: + if self.opt.self_attention: + x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))) + # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input))) + else: + x = self.bn1_1(self.LReLU1_1(self.conv1_1(input))) + conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x))) + x = self.max_pool1(conv1) + + x = self.bn2_1(self.LReLU2_1(self.conv2_1(x))) + conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x))) + x = self.max_pool2(conv2) + + x = self.bn3_1(self.LReLU3_1(self.conv3_1(x))) + conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x))) + x = self.max_pool3(conv3) + + x = self.bn4_1(self.LReLU4_1(self.conv4_1(x))) + conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x))) + x = self.max_pool4(conv4) + + x = self.bn5_1(self.LReLU5_1(self.conv5_1(x))) + x = x*gray_5 if self.opt.self_attention else x + conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x))) + + conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear') + conv4 = conv4*gray_4 if self.opt.self_attention else conv4 + up6 = torch.cat([self.deconv5(conv5), conv4], 1) + x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6))) + conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x))) + + conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear') + conv3 = conv3*gray_3 if self.opt.self_attention else conv3 + up7 = torch.cat([self.deconv6(conv6), conv3], 1) + x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7))) + conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x))) + + conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear') + conv2 = conv2*gray_2 if self.opt.self_attention else conv2 + up8 = torch.cat([self.deconv7(conv7), conv2], 1) + x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8))) + conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x))) + + conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear') + conv1 = conv1*gray if self.opt.self_attention else conv1 + up9 = torch.cat([self.deconv8(conv8), conv1], 1) + x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9))) + conv9 = self.LReLU9_2(self.conv9_2(x)) + + latent = self.conv10(conv9) + + if self.opt.times_residual: + latent = latent*gray + + # output = self.depth_to_space(conv10, 2) + if self.opt.tanh: + latent = self.tanh(latent) + if self.skip: + if self.opt.linear_add: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + input = (input - torch.min(input))/(torch.max(input) - torch.min(input)) + output = latent + input*self.opt.skip + output = output*2 - 1 + else: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + output = latent + input*self.opt.skip + else: + output = latent + + if self.opt.linear: + output = output/torch.max(torch.abs(output)) + + + elif self.opt.use_norm == 0: + if self.opt.self_attention: + x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))) + else: + x = self.LReLU1_1(self.conv1_1(input)) + conv1 = self.LReLU1_2(self.conv1_2(x)) + x = self.max_pool1(conv1) + + x = self.LReLU2_1(self.conv2_1(x)) + conv2 = self.LReLU2_2(self.conv2_2(x)) + x = self.max_pool2(conv2) + + x = self.LReLU3_1(self.conv3_1(x)) + conv3 = self.LReLU3_2(self.conv3_2(x)) + x = self.max_pool3(conv3) + + x = self.LReLU4_1(self.conv4_1(x)) + conv4 = self.LReLU4_2(self.conv4_2(x)) + x = self.max_pool4(conv4) + + x = self.LReLU5_1(self.conv5_1(x)) + x = x*gray_5 if self.opt.self_attention else x + conv5 = self.LReLU5_2(self.conv5_2(x)) + + conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear') + conv4 = conv4*gray_4 if self.opt.self_attention else conv4 + up6 = torch.cat([self.deconv5(conv5), conv4], 1) + x = self.LReLU6_1(self.conv6_1(up6)) + conv6 = self.LReLU6_2(self.conv6_2(x)) + + conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear') + conv3 = conv3*gray_3 if self.opt.self_attention else conv3 + up7 = torch.cat([self.deconv6(conv6), conv3], 1) + x = self.LReLU7_1(self.conv7_1(up7)) + conv7 = self.LReLU7_2(self.conv7_2(x)) + + conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear') + conv2 = conv2*gray_2 if self.opt.self_attention else conv2 + up8 = torch.cat([self.deconv7(conv7), conv2], 1) + x = self.LReLU8_1(self.conv8_1(up8)) + conv8 = self.LReLU8_2(self.conv8_2(x)) + + conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear') + conv1 = conv1*gray if self.opt.self_attention else conv1 + up9 = torch.cat([self.deconv8(conv8), conv1], 1) + x = self.LReLU9_1(self.conv9_1(up9)) + conv9 = self.LReLU9_2(self.conv9_2(x)) + + latent = self.conv10(conv9) + + if self.opt.times_residual: + latent = latent*gray + + if self.opt.tanh: + latent = self.tanh(latent) + if self.skip: + if self.opt.linear_add: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + input = (input - torch.min(input))/(torch.max(input) - torch.min(input)) + output = latent + input*self.opt.skip + output = output*2 - 1 + else: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + output = latent + input*self.opt.skip + else: + output = latent + + if self.opt.linear: + output = output/torch.max(torch.abs(output)) + + output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom) + latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom) + gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom) + if flag == 1: + output = F.upsample(output, scale_factor=2, mode='bilinear') + gray = F.upsample(gray, scale_factor=2, mode='bilinear') + if self.skip: + return output, latent + else: + return output + +class DnCNN(nn.Module): + def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3): + super(DnCNN, self).__init__() + kernel_size = 3 + padding = 1 + layers = [] + + layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True)) + layers.append(nn.ReLU(inplace=True)) + for _ in range(depth-2): + layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False)) + layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False)) + self.dncnn = nn.Sequential(*layers) + self._initialize_weights() + + def forward(self, x): + y = x + out = self.dncnn(x) + return y+out + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.orthogonal_(m.weight) + print('init weight') + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + +class Vgg16(nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X, opt): + h = F.relu(self.conv1_1(X), inplace=True) + h = F.relu(self.conv1_2(h), inplace=True) + # relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h), inplace=True) + h = F.relu(self.conv2_2(h), inplace=True) + # relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h), inplace=True) + h = F.relu(self.conv3_2(h), inplace=True) + h = F.relu(self.conv3_3(h), inplace=True) + # relu3_3 = h + if opt.vgg_choose != "no_maxpool": + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h), inplace=True) + relu4_1 = h + h = F.relu(self.conv4_2(h), inplace=True) + relu4_2 = h + conv4_3 = self.conv4_3(h) + h = F.relu(conv4_3, inplace=True) + relu4_3 = h + + if opt.vgg_choose != "no_maxpool": + if opt.vgg_maxpooling: + h = F.max_pool2d(h, kernel_size=2, stride=2) + + relu5_1 = F.relu(self.conv5_1(h), inplace=True) + relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True) + conv5_3 = self.conv5_3(relu5_2) + h = F.relu(conv5_3, inplace=True) + relu5_3 = h + if opt.vgg_choose == "conv4_3": + return conv4_3 + elif opt.vgg_choose == "relu4_2": + return relu4_2 + elif opt.vgg_choose == "relu4_1": + return relu4_1 + elif opt.vgg_choose == "relu4_3": + return relu4_3 + elif opt.vgg_choose == "conv5_3": + return conv5_3 + elif opt.vgg_choose == "relu5_1": + return relu5_1 + elif opt.vgg_choose == "relu5_2": + return relu5_2 + elif opt.vgg_choose == "relu5_3" or "maxpool": + return relu5_3 + +def vgg_preprocess(batch, opt): + tensortype = type(batch.data) + (r, g, b) = torch.chunk(batch, 3, dim = 1) + batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR + batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] + if opt.vgg_mean: + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + batch = batch.sub(Variable(mean)) # subtract mean + return batch + +class PerceptualLoss(nn.Module): + def __init__(self, opt): + super(PerceptualLoss, self).__init__() + self.opt = opt + self.instancenorm = nn.InstanceNorm2d(512, affine=False) + + def compute_vgg_loss(self, vgg, img, target): + img_vgg = vgg_preprocess(img, self.opt) + target_vgg = vgg_preprocess(target, self.opt) + img_fea = vgg(img_vgg, self.opt) + target_fea = vgg(target_vgg, self.opt) + if self.opt.no_vgg_instance: + return torch.mean((img_fea - target_fea) ** 2) + else: + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) + +def load_vgg16(model_dir, gpu_ids): + """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ + if not os.path.exists(model_dir): + os.mkdir(model_dir) + # if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): + # if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): + # os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) + # vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) + # vgg = Vgg16() + # for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + # dst.data[:] = src + # torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) + vgg = Vgg16() + # vgg.cuda() + vgg.cuda(device=gpu_ids[0]) + vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) + vgg = torch.nn.DataParallel(vgg, gpu_ids) + return vgg + + + +class FCN32s(nn.Module): + def __init__(self, n_class=21): + super(FCN32s, self).__init__() + # conv1 + self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) + self.relu1_1 = nn.ReLU(inplace=True) + self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.relu1_2 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 + + # conv2 + self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) + self.relu2_1 = nn.ReLU(inplace=True) + self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.relu2_2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 + + # conv3 + self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) + self.relu3_1 = nn.ReLU(inplace=True) + self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.relu3_2 = nn.ReLU(inplace=True) + self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.relu3_3 = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 + + # conv4 + self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) + self.relu4_1 = nn.ReLU(inplace=True) + self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.relu4_2 = nn.ReLU(inplace=True) + self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.relu4_3 = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 + + # conv5 + self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_1 = nn.ReLU(inplace=True) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_2 = nn.ReLU(inplace=True) + self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_3 = nn.ReLU(inplace=True) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 + + # fc6 + self.fc6 = nn.Conv2d(512, 4096, 7) + self.relu6 = nn.ReLU(inplace=True) + self.drop6 = nn.Dropout2d() + + # fc7 + self.fc7 = nn.Conv2d(4096, 4096, 1) + self.relu7 = nn.ReLU(inplace=True) + self.drop7 = nn.Dropout2d() + + self.score_fr = nn.Conv2d(4096, n_class, 1) + self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, + bias=False) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.weight.data.zero_() + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.ConvTranspose2d): + assert m.kernel_size[0] == m.kernel_size[1] + initial_weight = get_upsampling_weight( + m.in_channels, m.out_channels, m.kernel_size[0]) + m.weight.data.copy_(initial_weight) + + def forward(self, x): + h = x + h = self.relu1_1(self.conv1_1(h)) + h = self.relu1_2(self.conv1_2(h)) + h = self.pool1(h) + + h = self.relu2_1(self.conv2_1(h)) + h = self.relu2_2(self.conv2_2(h)) + h = self.pool2(h) + + h = self.relu3_1(self.conv3_1(h)) + h = self.relu3_2(self.conv3_2(h)) + h = self.relu3_3(self.conv3_3(h)) + h = self.pool3(h) + + h = self.relu4_1(self.conv4_1(h)) + h = self.relu4_2(self.conv4_2(h)) + h = self.relu4_3(self.conv4_3(h)) + h = self.pool4(h) + + h = self.relu5_1(self.conv5_1(h)) + h = self.relu5_2(self.conv5_2(h)) + h = self.relu5_3(self.conv5_3(h)) + h = self.pool5(h) + + h = self.relu6(self.fc6(h)) + h = self.drop6(h) + + h = self.relu7(self.fc7(h)) + h = self.drop7(h) + + h = self.score_fr(h) + + h = self.upscore(h) + h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() + return h + +def load_fcn(model_dir): + fcn = FCN32s() + fcn.load_state_dict(torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth'))) + fcn.cuda() + return fcn + +class SemanticLoss(nn.Module): + def __init__(self, opt): + super(SemanticLoss, self).__init__() + self.opt = opt + self.instancenorm = nn.InstanceNorm2d(21, affine=False) + + def compute_fcn_loss(self, fcn, img, target): + img_fcn = vgg_preprocess(img, self.opt) + target_fcn = vgg_preprocess(target, self.opt) + img_fea = fcn(img_fcn) + target_fea = fcn(target_fcn) + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) diff --git a/models/pair_model.py b/models/pair_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f98eded01d3d19a08b06c1c392a5289e9ecfe8 --- /dev/null +++ b/models/pair_model.py @@ -0,0 +1,359 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import sys + + +class PairModel(BaseModel): + def name(self): + return 'CycleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model") + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + self.netD_B = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + self.netG_B.train() + else: + self.netG_A.eval() + self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + self.fake_A = self.netG_B.forward(self.real_B) + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + rec_A = util.tensor2im(self.rec_A.data) + if self.opt.skip == 1: + latent_real_A = util.tensor2im(self.latent_real_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("latent_real_A", latent_real_A), ("rec_A", rec_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake): + # Real + pred_real = netD.forward(real) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + else: + loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_fake = pred_fake.mean() + else: + loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss + if self.opt.use_wgan: + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, real.data, fake.data) + else: + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # backward + loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + + def backward_D_B(self): + fake_A = self.fake_A_pool.query(self.fake_A) + self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + + def backward_G(self): + lambda_idt = self.opt.identity + lambda_A = self.opt.lambda_A + lambda_B = self.opt.lambda_B + # Identity loss + if lambda_idt > 0: + # G_A should be identity if real_B is fed. + if self.opt.skip == 1: + self.idt_A, _ = self.netG_A.forward(self.real_B) + else: + self.idt_A = self.netG_A.forward(self.real_B) + self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + # G_B should be identity if real_A is fed. + self.idt_B = self.netG_B.forward(self.real_A) + self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + else: + self.loss_idt_A = 0 + self.loss_idt_B = 0 + + # GAN loss + # D_A(G_A(A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A) + else: + self.fake_B = self.netG_A.forward(self.real_A) + # = self.latent_real_A + self.opt.skip * self.real_A + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1 + # D_B(G_B(B)) + self.fake_A = self.netG_B.forward(self.real_B) + pred_fake = self.netD_B.forward(self.fake_A) + self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1 + if self.opt.use_wgan: + self.loss_G_B = -pred_fake.mean() + else: + self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss + + if lambda_A > 0: + self.rec_A = self.netG_B.forward(self.fake_B) + self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + else: + self.loss_cycle_A = 0 + # Backward cycle loss + + # = self.latent_fake_A + self.opt.skip * self.fake_A + if lambda_B > 0: + if self.opt.skip == 1: + self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A) + else: + self.rec_B = self.netG_A.forward(self.fake_A) + self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + else: + self.loss_cycle_B = 0 + self.loss_vgg_a = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_A, self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0 + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + # combined loss + self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \ + self.loss_vgg_a + self.loss_vgg_b + \ + self.loss_idt_A + self.loss_idt_B + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + + def optimize_parameters(self): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + self.optimizer_D_A.step() + # D_B + self.optimizer_D_B.zero_grad() + self.backward_D_B() + self.optimizer_D_B.step() + + + def get_current_errors(self): + D_A = self.loss_D_A.data[0] + G_A = self.loss_G_A.data[0] + L1 = (self.L1_AB + self.L1_BA).data[0] + Cyc_A = self.loss_cycle_A.data[0] + D_B = self.loss_D_B.data[0] + G_B = self.loss_G_B.data[0] + Cyc_B = self.loss_cycle_B.data[0] + vgg = (self.loss_vgg_a.data[0] + self.loss_vgg_b.data[0])/self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.identity > 0: + idt = self.loss_idt_A.data[0] + self.loss_idt_B.data[0] + if self.opt.lambda_A > 0.0: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg), ("idt", idt)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), + ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg), ("idt", idt)) + else: + if self.opt.lambda_A > 0.0: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ("vgg", vgg)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('L1', L1), + ('D_B', D_B), ('G_B', G_B)], ("vgg", vgg)) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + + real_B = util.tensor2im(self.real_B.data) + fake_A = util.tensor2im(self.fake_A.data) + + if self.opt.identity > 0: + idt_A = util.tensor2im(self.idt_A.data) + idt_B = util.tensor2im(self.idt_B.data) + if self.opt.lambda_A > 0.0: + rec_A = util.tensor2im(self.rec_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.skip > 0: + latent_fake_A = util.tensor2im(self.latent_fake_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A), + ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + if self.opt.skip > 0: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), + ('real_B', real_B), ('fake_A', fake_A), ("idt_A", idt_A), ("idt_B", idt_B)]) + else: + if self.opt.lambda_A > 0.0: + rec_A = util.tensor2im(self.rec_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.skip > 0: + latent_fake_A = util.tensor2im(self.latent_fake_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('latent_fake_A', latent_fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + else: + if self.opt.skip > 0: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('real_B', real_B), ('fake_A', fake_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), + ('real_B', real_B), ('fake_A', fake_A)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_D_B.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a524f2ceda4575373a7475816b3d99e8a33579b9 --- /dev/null +++ b/models/pix2pix_model.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks + + +class Pix2PixModel(BaseModel): + def name(self): + return 'Pix2PixModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + self.isTrain = opt.isTrain + # define tensors + self.input_A = self.Tensor(opt.batchSize, opt.input_nc, + opt.fineSize, opt.fineSize) + self.input_B = self.Tensor(opt.batchSize, opt.output_nc, + opt.fineSize, opt.fineSize) + + # load/define networks + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, + opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids) + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + self.load_network(self.netG, 'G', opt.which_epoch) + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch) + + if self.isTrain: + self.fake_AB_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionL1 = torch.nn.L1Loss() + + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG) + if self.isTrain: + networks.print_network(self.netD) + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.fake_B = self.netG.forward(self.real_A) + self.real_B = Variable(self.input_B) + + # no backprop gradients + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.fake_B = self.netG.forward(self.real_A) + self.real_B = Variable(self.input_B, volatile=True) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D(self): + # Fake + # stop backprop to the generator by detaching fake_B + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) + self.pred_fake = self.netD.forward(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + self.pred_real = self.netD.forward(real_AB) + self.loss_D_real = self.criterionGAN(self.pred_real, True) + + # Combined loss + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + + self.loss_D.backward() + + def backward_G(self): + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD.forward(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A + + self.loss_G = self.loss_G_GAN + self.loss_G_L1 + + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + + self.optimizer_D.zero_grad() + self.backward_D() + self.optimizer_D.step() + + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + def get_current_errors(self): + return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]), + ('G_L1', self.loss_G_L1.data[0]), + ('D_real', self.loss_D_real.data[0]), + ('D_fake', self.loss_D_fake.data[0]) + ]) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + + def save(self, label): + self.save_network(self.netG, 'G', label, self.gpu_ids) + self.save_network(self.netD, 'D', label, self.gpu_ids) + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/single_model.py b/models/single_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b1042c872f57389a662f5c57041d3d195c1656ae --- /dev/null +++ b/models/single_model.py @@ -0,0 +1,496 @@ +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +import random +from . import networks +import sys + + +class SingleModel(BaseModel): + def name(self): + return 'SingleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + self.input_img = self.Tensor(nb, opt.input_nc, size, size) + self.input_A_gray = self.Tensor(nb, 1, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss(opt) + if self.opt.IN_vgg: + self.vgg_patch_loss = networks.PerceptualLoss(opt) + self.vgg_patch_loss.cuda() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model", self.gpu_ids) + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + elif opt.fcn > 0: + self.fcn_loss = networks.SemanticLoss(opt) + self.fcn_loss.cuda() + self.fcn = networks.load_fcn("./model") + self.fcn.eval() + for param in self.fcn.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + # self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + # opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, False) + if self.opt.patchD: + self.netD_P = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_patchD, opt.norm, use_sigmoid, self.gpu_ids, True) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + # self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + if self.opt.patchD: + self.load_network(self.netD_P, 'D_P', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + # self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + if self.opt.patchD: + self.optimizer_D_P = torch.optim.Adam(self.netD_P.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + # networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + if self.opt.patchD: + networks.print_network(self.netD_P) + # networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + # self.netG_B.train() + else: + self.netG_A.eval() + # self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + input_img = input['input_img'] + input_A_gray = input['A_gray'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.input_img.resize_(input_img.size()).copy_(input_img) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + A_gray = util.atten2im(self.real_A_gray.data) + # rec_A = util.tensor2im(self.rec_A.data) + # if self.opt.skip == 1: + # latent_real_A = util.tensor2im(self.latent_real_A.data) + # latent_show = util.latent2im(self.latent_real_A.data) + # max_image = util.max2im(self.fake_B.data, self.latent_real_A.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('max_image', max_image), ('A_gray', A_gray)]) + # else: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + # return OrderedDict([('fake_B', fake_B)]) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake, use_ragan): + # Real + pred_real = netD.forward(real) + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + loss_D_fake = pred_fake.mean() + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, + real.data, fake.data) + elif self.opt.use_ragan and use_ragan: + loss_D = (self.criterionGAN(pred_real - torch.mean(pred_fake), True) + + self.criterionGAN(pred_fake - torch.mean(pred_real), False)) / 2 + else: + loss_D_real = self.criterionGAN(pred_real, True) + loss_D_fake = self.criterionGAN(pred_fake, False) + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + fake_B = self.fake_B + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, True) + self.loss_D_A.backward() + + def backward_D_P(self): + if self.opt.hybrid_loss: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, False) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], False) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + else: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, True) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], True) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + if self.opt.D_P_times2: + self.loss_D_P = self.loss_D_P*2 + self.loss_D_P.backward() + + # def backward_D_B(self): + # fake_A = self.fake_A_pool.query(self.fake_A) + # self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + self.real_A_gray = Variable(self.input_A_gray) + self.real_img = Variable(self.input_img) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_img, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_img, self.real_A_gray) + if self.opt.patchD: + w = self.real_A.size(3) + h = self.real_A.size(2) + w_offset = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.patchSize - 1)) + + self.fake_patch = self.fake_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.real_patch = self.real_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.input_patch = self.real_A[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + if self.opt.patchD_3 > 0: + self.fake_patch_1 = [] + self.real_patch_1 = [] + self.input_patch_1 = [] + w = self.real_A.size(3) + h = self.real_A.size(2) + for i in range(self.opt.patchD_3): + w_offset_1 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset_1 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + self.fake_patch_1.append(self.fake_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.real_patch_1.append(self.real_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.input_patch_1.append(self.real_A[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + + # w_offset_2 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + # h_offset_2 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + # self.fake_patch_2 = self.fake_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + # self.real_patch_2 = self.real_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + # self.input_patch_2 = self.real_A[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + + def backward_G(self, epoch): + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + elif self.opt.use_ragan: + pred_real = self.netD_A.forward(self.real_B) + + self.loss_G_A = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) + + self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2 + + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + + loss_G_A = 0 + if self.opt.patchD: + pred_fake_patch = self.netD_P.forward(self.fake_patch) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch, True) + else: + pred_real_patch = self.netD_P.forward(self.real_patch) + + loss_G_A += (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False) + + self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2 + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i]) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch_1, True) + else: + pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i]) + + loss_G_A += (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False) + + self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2 + + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1) + else: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)*2 + else: + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A + else: + self.loss_G_A += loss_G_A*2 + + if epoch < 0: + vgg_w = 0 + else: + vgg_w = 1 + if self.opt.vgg > 0: + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.patch_vgg: + if not self.opt.IN_vgg: + loss_vgg_patch = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + else: + loss_vgg_patch = self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + if not self.opt.IN_vgg: + loss_vgg_patch += self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg + else: + loss_vgg_patch += self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg + self.loss_vgg_b += loss_vgg_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_vgg_b += loss_vgg_patch + self.loss_G = self.loss_G_A + self.loss_vgg_b*vgg_w + elif self.opt.fcn > 0: + self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(self.fcn, + self.fake_B, self.real_A) * self.opt.fcn if self.opt.fcn > 0 else 0 + if self.opt.patchD: + loss_fcn_patch = self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch, self.input_patch) * self.opt.fcn + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_fcn_patch += self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.fcn + self.loss_fcn_b += loss_fcn_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_fcn_b += loss_fcn_patch + self.loss_G = self.loss_G_A + self.loss_fcn_b*vgg_w + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + + # def optimize_parameters(self, epoch): + # # forward + # self.forward() + # # G_A and G_B + # self.optimizer_G.zero_grad() + # self.backward_G(epoch) + # self.optimizer_G.step() + # # D_A + # self.optimizer_D_A.zero_grad() + # self.backward_D_A() + # self.optimizer_D_A.step() + # if self.opt.patchD: + # self.forward() + # self.optimizer_D_P.zero_grad() + # self.backward_D_P() + # self.optimizer_D_P.step() + # D_B + # self.optimizer_D_B.zero_grad() + # self.backward_D_B() + # self.optimizer_D_B.step() + def optimize_parameters(self, epoch): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G(epoch) + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + if not self.opt.patchD: + self.optimizer_D_A.step() + else: + # self.forward() + self.optimizer_D_P.zero_grad() + self.backward_D_P() + self.optimizer_D_A.step() + self.optimizer_D_P.step() + + + def get_current_errors(self, epoch): + D_A = self.loss_D_A.data[0] + D_P = self.loss_D_P.data[0] if self.opt.patchD else 0 + G_A = self.loss_G_A.data[0] + if self.opt.vgg > 0: + vgg = self.loss_vgg_b.data[0]/self.opt.vgg if self.opt.vgg > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("vgg", vgg), ("D_P", D_P)]) + elif self.opt.fcn > 0: + fcn = self.loss_fcn_b.data[0]/self.opt.fcn if self.opt.fcn > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("fcn", fcn), ("D_P", D_P)]) + + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + latent_show = util.latent2im(self.latent_real_A.data) + if self.opt.patchD: + fake_patch = util.tensor2im(self.fake_patch.data) + real_patch = util.tensor2im(self.real_patch.data) + if self.opt.patch_vgg: + input_patch = util.tensor2im(self.input_patch.data) + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), ('input_patch', input_patch)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), ('input_patch', input_patch), ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + ('latent_real_A', latent_real_A), ('latent_show', latent_show), + ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + ('self_attention', self_attention)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + if self.opt.patchD: + self.save_network(self.netD_P, 'D_P', label, self.gpu_ids) + # self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + # self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + if self.opt.patchD: + for param_group in self.optimizer_D_P.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/temp_model.py b/models/temp_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0f7f87c455c290fffb29e775e0702c458b1b5a --- /dev/null +++ b/models/temp_model.py @@ -0,0 +1,499 @@ +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +import random +from . import networks +import sys + + +class TempModel(BaseModel): + def name(self): + return 'TempModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + self.input_img = self.Tensor(nb, opt.input_nc, size, size) + self.input_A_gray = self.Tensor(nb, 1, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss(opt) + # if self.opt.IN_vgg: + # self.vgg_patch_loss = networks.PerceptualLoss(opt) + # self.vgg_patch_loss.cuda() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model", self.gpu_ids) + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + elif opt.fcn > 0: + self.fcn_loss = networks.SemanticLoss(opt) + self.fcn_loss.cuda() + self.fcn = networks.load_fcn("./model") + self.fcn.eval() + for param in self.fcn.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + # self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + # opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, False) + if self.opt.patchD: + self.netD_P = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_patchD, opt.norm, use_sigmoid, self.gpu_ids, True) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + # self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + if self.opt.patchD: + self.load_network(self.netD_P, 'D_P', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + # self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + # if self.opt.patchD: + # self.optimizer_D_P = torch.optim.Adam(self.netD_P.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + # networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + # if self.opt.patchD: + # networks.print_network(self.netD_P) + # networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + # self.netG_B.train() + else: + self.netG_A.eval() + # self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + input_img = input['input_img'] + input_A_gray = input['A_gray'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.input_img.resize_(input_img.size()).copy_(input_img) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + # if self.opt.noise > 0: + # self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + # self.real_A = self.real_A + self.noise + # if self.opt.input_linear: + # self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + A_gray = util.atten2im(self.real_A_gray.data) + # rec_A = util.tensor2im(self.rec_A.data) + # if self.opt.skip == 1: + # latent_real_A = util.tensor2im(self.latent_real_A.data) + # latent_show = util.latent2im(self.latent_real_A.data) + # max_image = util.max2im(self.fake_B.data, self.latent_real_A.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('max_image', max_image), ('A_gray', A_gray)]) + # else: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + # return OrderedDict([('fake_B', fake_B)]) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake, use_ragan): + # Real + pred_real = netD.forward(real) + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + loss_D_fake = pred_fake.mean() + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, + real.data, fake.data) + elif self.opt.use_ragan and use_ragan: + loss_D = (self.criterionGAN(pred_real - torch.mean(pred_fake), True) + + self.criterionGAN(pred_fake - torch.mean(pred_real), False)) / 2 + else: + loss_D_real = self.criterionGAN(pred_real, True) + loss_D_fake = self.criterionGAN(pred_fake, False) + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + fake_B = self.fake_B + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, True) + self.loss_D_A.backward() + + def backward_D_P(self): + if self.opt.hybrid_loss: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, False) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], False) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + else: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, True) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], True) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + if self.opt.D_P_times2: + self.loss_D_P = self.loss_D_P*2 + self.loss_D_P.backward() + + # def backward_D_B(self): + # fake_A = self.fake_A_pool.query(self.fake_A) + # self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + self.real_A_gray = Variable(self.input_A_gray) + self.real_img = Variable(self.input_img) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_img, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_img, self.real_A_gray) + if self.opt.patchD: + w = self.real_A.size(3) + h = self.real_A.size(2) + w_offset = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.patchSize - 1)) + + self.fake_patch = self.fake_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.real_patch = self.real_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.input_patch = self.real_A[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + if self.opt.patchD_3 > 0: + self.fake_patch_1 = [] + self.real_patch_1 = [] + self.input_patch_1 = [] + w = self.real_A.size(3) + h = self.real_A.size(2) + for i in range(self.opt.patchD_3): + w_offset_1 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset_1 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + self.fake_patch_1.append(self.fake_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.real_patch_1.append(self.real_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.input_patch_1.append(self.real_A[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + + w_offset_2 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset_2 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + self.fake_patch_2 = self.fake_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + w_offset_2:w_offset_2 + self.opt.patchSize] + self.real_patch_2 = self.real_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + w_offset_2:w_offset_2 + self.opt.patchSize] + self.input_patch_2 = self.real_A[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + w_offset_2:w_offset_2 + self.opt.patchSize] + + def backward_G(self, epoch): + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + elif self.opt.use_ragan: + pred_real = self.netD_A.forward(self.real_B) + + self.loss_G_A = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) + + self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2 + + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + + loss_G_A = 0 + if self.opt.patchD: + pred_fake_patch = self.netD_P.forward(self.fake_patch) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch, True) + else: + pred_real_patch = self.netD_P.forward(self.real_patch) + + loss_G_A += (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False) + + self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2 + self.loss_G_A += loss_G_A + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i]) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch_1, True) + else: + pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i]) + + loss_G_A += (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False) + + self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2 + + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1) + else: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)*2 + else: + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A + else: + self.loss_G_A += loss_G_A*2 + + if epoch < 0: + vgg_w = 0 + else: + vgg_w = 1 + if self.opt.vgg > 0: + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.patch_vgg: + if not self.opt.IN_vgg: + loss_vgg_patch = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + else: + loss_vgg_patch = self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + if not self.opt.IN_vgg: + loss_vgg_patch += self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg + else: + loss_vgg_patch += self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg + self.loss_vgg_b += loss_vgg_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_vgg_b += loss_vgg_patch + self.loss_G = self.loss_G_A + self.loss_vgg_b*vgg_w + elif self.opt.fcn > 0: + self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(self.fcn, + self.fake_B, self.real_A) * self.opt.fcn if self.opt.fcn > 0 else 0 + if self.opt.patchD: + loss_fcn_patch = self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch, self.input_patch) * self.opt.fcn + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_fcn_patch += self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.fcn + self.loss_fcn_b += loss_fcn_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_fcn_b += loss_fcn_patch + self.loss_G = self.loss_G_A + self.loss_fcn_b*vgg_w + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + + # def optimize_parameters(self, epoch): + # # forward + # self.forward() + # # G_A and G_B + # self.optimizer_G.zero_grad() + # self.backward_G(epoch) + # self.optimizer_G.step() + # # D_A + # self.optimizer_D_A.zero_grad() + # self.backward_D_A() + # self.optimizer_D_A.step() + # if self.opt.patchD: + # self.forward() + # self.optimizer_D_P.zero_grad() + # self.backward_D_P() + # self.optimizer_D_P.step() + # D_B + # self.optimizer_D_B.zero_grad() + # self.backward_D_B() + # self.optimizer_D_B.step() + def optimize_parameters(self, epoch): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G(epoch) + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + if not self.opt.patchD: + self.optimizer_D_A.step() + else: + # self.forward() + self.optimizer_D_P.zero_grad() + self.backward_D_P() + self.optimizer_D_A.step() + self.optimizer_D_P.step() + + + def get_current_errors(self, epoch): + D_A = self.loss_D_A.data[0] + D_P = self.loss_D_P.data[0] if self.opt.patchD else 0 + G_A = self.loss_G_A.data[0] + if self.opt.vgg > 0: + vgg = self.loss_vgg_b.data[0]/self.opt.vgg if self.opt.vgg > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("vgg", vgg)]) + elif self.opt.fcn > 0: + fcn = self.loss_fcn_b.data[0]/self.opt.fcn if self.opt.fcn > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("fcn", fcn)]) + + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + latent_show = util.latent2im(self.latent_real_A.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B)]) + # if self.opt.patchD: + # fake_patch = util.tensor2im(self.fake_patch.data) + # real_patch = util.tensor2im(self.real_patch.data) + # if self.opt.patch_vgg: + # input_patch = util.tensor2im(self.input_patch.data) + # if not self.opt.self_attention: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + # ('fake_patch', fake_patch), ('input_patch', input_patch)]) + # else: + # self_attention = util.atten2im(self.real_A_gray.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + # ('fake_patch', fake_patch), ('input_patch', input_patch), ('self_attention', self_attention)]) + # else: + # if not self.opt.self_attention: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + # ('fake_patch', fake_patch)]) + # else: + # self_attention = util.atten2im(self.real_A_gray.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + # ('fake_patch', fake_patch), ('self_attention', self_attention)]) + # else: + # if not self.opt.self_attention: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('real_B', real_B)]) + # else: + # self_attention = util.atten2im(self.real_A_gray.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + # ('latent_real_A', latent_real_A), ('latent_show', latent_show), + # ('self_attention', self_attention)]) + # else: + # if not self.opt.self_attention: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + # else: + # self_attention = util.atten2im(self.real_A_gray.data) + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + # ('self_attention', self_attention)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + if self.opt.patchD: + self.save_network(self.netD_P, 'D_P', label, self.gpu_ids) + # self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + # self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + if self.opt.patchD: + for param_group in self.optimizer_D_P.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/test_model.py b/models/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03aef655aacba302e5ca6fc963b09da213f9911d --- /dev/null +++ b/models/test_model.py @@ -0,0 +1,45 @@ +from torch.autograd import Variable +from collections import OrderedDict +import util.util as util +from .base_model import BaseModel +from . import networks + + +class TestModel(BaseModel): + def name(self): + return 'TestModel' + + def initialize(self, opt): + assert(not opt.isTrain) + BaseModel.initialize(self, opt) + self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) + + self.netG = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, + opt.norm, not opt.no_dropout, + self.gpu_ids) + which_epoch = opt.which_epoch + self.load_network(self.netG, 'G', which_epoch) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG) + print('-----------------------------------------------') + + def set_input(self, input): + # we need to use single_dataset mode + input_A = input['A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.image_paths = input['A_paths'] + + def test(self): + self.real_A = Variable(self.input_A) + self.fake_B = self.netG.forward(self.real_A) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) diff --git a/models/unit_model.py b/models/unit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7e85091823ba075940f5df6efdfe6bc05a1265 --- /dev/null +++ b/models/unit_model.py @@ -0,0 +1,277 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.util import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +from .unit_network import * +import sys + +def get_config(config): + import yaml + with open(config, 'r') as stream: + return yaml.load(stream) + +class UNITModel(BaseModel): + def name(self): + return 'UNITModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + self.config = get_config(opt.config) + nb = opt.batchSize + size = opt.fineSize + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + self.gen_a = VAEGen(self.config['input_dim_a'], self.config['gen']) + self.gen_b = VAEGen(self.config['input_dim_a'], self.config['gen']) + + if self.isTrain: + self.dis_a = MsImageDis(self.config['input_dim_a'], self.config['dis']) # discriminator for domain a + self.dis_b = MsImageDis(self.config['input_dim_b'], self.config['dis']) # discriminator for domain b + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.gen_a, 'G_A', which_epoch) + self.load_network(self.gen_b, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.dis_a, 'D_A', which_epoch) + self.load_network(self.dis_b, 'D_B', which_epoch) + + if self.isTrain: + self.old_lr = self.config['lr'] + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + # Setup the optimizers + beta1 = self.config['beta1'] + beta2 = self.config['beta2'] + dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) + gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) + self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], + lr=self.config['lr'], betas=(beta1, beta2), weight_decay=self.config['weight_decay']) + self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], + lr=self.config['lr'], betas=(beta1, beta2), weight_decay=self.config['weight_decay']) + self.dis_scheduler = get_scheduler(self.dis_opt, self.config) + self.gen_scheduler = get_scheduler(self.gen_opt, self.config) + + # Network weight initialization + # self.apply(weights_init(self.config['init'])) + self.dis_a.apply(weights_init('gaussian')) + self.dis_b.apply(weights_init('gaussian')) + + # Load VGG model if needed + if 'vgg_w' in self.config.keys() and self.config['vgg_w'] > 0: + self.vgg = load_vgg16(self.config['vgg_model_path'] + '/models') + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + self.gen_a.cuda() + self.gen_b.cuda() + self.dis_a.cuda() + self.dis_b.cuda() + + print('---------- Networks initialized -------------') + networks.print_network(self.gen_a) + networks.print_network(self.gen_b) + if self.isTrain: + networks.print_network(self.dis_a) + networks.print_network(self.dis_b) + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.real_A = Variable(self.input_A.cuda()) + self.real_B = Variable(self.input_B.cuda()) + + # def forward(self): + # self.real_A = Variable(self.input_A) + # self.real_B = Variable(self.input_B) + + def test(self): + self.real_A = Variable(self.input_A.cuda(), volatile=True) + self.real_B = Variable(self.input_B.cuda(), volatile=True) + h_a, n_a = self.gen_a.encode(self.real_A) + h_b, n_b = self.gen_b.encode(self.real_B) + x_a_recon = self.gen_a.decode(h_a + n_a) + x_a*1 + x_b_recon = self.gen_b.decode(h_b + n_b) + x_b*1 + x_ba = self.gen_a.decode(h_b + n_b) + x_b*1 + x_ab = self.gen_b.decode(h_a + n_a) + x_a*1 + h_b_recon, n_b_recon = self.gen_a.encode(x_ba) + h_a_recon, n_a_recon = self.gen_b.encode(x_ab) + x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + x_ab*1 if self.config['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + x_ba*1 if self.config['recon_x_cyc_w'] > 0 else None + self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba + self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab + + # get image paths + def get_image_paths(self): + return self.image_paths + + + def optimize_parameters(self): + self.gen_update(self.real_A, self.real_B) + self.dis_update(self.real_A, self.real_B) + + def recon_criterion(self, input, target): + return torch.mean(torch.abs(input - target)) + + def forward(self, x_a, x_b): + self.eval() + x_a.volatile = True + x_b.volatile = True + h_a, _ = self.gen_a.encode(x_a) + h_b, _ = self.gen_b.encode(x_b) + x_ba = self.gen_a.decode(h_b) + x_ab = self.gen_b.decode(h_a) + self.train() + return x_ab, x_ba + + def __compute_kl(self, mu): + # def _compute_kl(self, mu, sd): + # mu_2 = torch.pow(mu, 2) + # sd_2 = torch.pow(sd, 2) + # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0) + # return encoding_loss + mu_2 = torch.pow(mu, 2) + encoding_loss = torch.mean(mu_2) + return encoding_loss + + def gen_update(self, x_a, x_b): + self.gen_opt.zero_grad() + # encode + h_a, n_a = self.gen_a.encode(x_a) + h_b, n_b = self.gen_b.encode(x_b) + # decode (within domain) + x_a_recon = self.gen_a.decode(h_a + n_a) + 0*x_a + x_b_recon = self.gen_b.decode(h_b + n_b) + 0*x_b + # decode (cross domain) + x_ba = self.gen_a.decode(h_b + n_b) + 0*x_b + x_ab = self.gen_b.decode(h_a + n_a) + 0*x_a + # encode again + h_b_recon, n_b_recon = self.gen_a.encode(x_ba) + h_a_recon, n_a_recon = self.gen_b.encode(x_ab) + # decode again (if needed) + x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + 0*x_ab if self.config['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + 0*x_ba if self.config['recon_x_cyc_w'] > 0 else None + + # reconstruction loss + self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) + self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) + self.loss_gen_recon_kl_a = self.__compute_kl(h_a) + self.loss_gen_recon_kl_b = self.__compute_kl(h_b) + self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) + self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) + self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon) + self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon) + # GAN loss + self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) + self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) + # domain-invariant perceptual loss + self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if self.config['vgg_w'] > 0 else 0 + self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if self.config['vgg_w'] > 0 else 0 + # total loss + self.loss_gen_total = self.config['gan_w'] * self.loss_gen_adv_a + \ + self.config['gan_w'] * self.loss_gen_adv_b + \ + self.config['recon_x_w'] * self.loss_gen_recon_x_a + \ + self.config['recon_kl_w'] * self.loss_gen_recon_kl_a + \ + self.config['recon_x_w'] * self.loss_gen_recon_x_b + \ + self.config['recon_kl_w'] * self.loss_gen_recon_kl_b + \ + self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \ + self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \ + self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \ + self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \ + self.config['vgg_w'] * self.loss_gen_vgg_a + \ + self.config['vgg_w'] * self.loss_gen_vgg_b + self.loss_gen_total.backward() + self.gen_opt.step() + self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba + self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab + + def compute_vgg_loss(self, vgg, img, target): + img_vgg = vgg_preprocess(img) + target_vgg = vgg_preprocess(target) + img_fea = vgg(img_vgg) + target_fea = vgg(target_vgg) + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) + + def dis_update(self, x_a, x_b): + self.dis_opt.zero_grad() + # encode + h_a, n_a = self.gen_a.encode(x_a) + h_b, n_b = self.gen_b.encode(x_b) + # decode (cross domain) + x_ba = self.gen_a.decode(h_b + n_b) + x_ab = self.gen_b.decode(h_a + n_a) + # D loss + self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) + self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) + self.loss_dis_total = self.config['gan_w'] * self.loss_dis_a + self.config['gan_w'] * self.loss_dis_b + self.loss_dis_total.backward() + self.dis_opt.step() + + def get_current_errors(self): + D_A = self.loss_dis_a.data[0] + G_A = self.loss_gen_adv_a.data[0] + kl_A = self.loss_gen_recon_kl_a.data[0] + Cyc_A = self.loss_gen_cyc_x_a.data[0] + D_B = self.loss_dis_b.data[0] + G_B = self.loss_gen_adv_b.data[0] + kl_B = self.loss_gen_recon_kl_b.data[0] + Cyc_B = self.loss_gen_cyc_x_b.data[0] + if self.config['vgg_w'] > 0: + vgg_A = self.loss_gen_vgg_a + vgg_B = self.loss_gen_vgg_b + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('kl_A', kl_A), ('vgg_A', vgg_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('kl_B', kl_B), ('vgg_B', vgg_B)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('kl_A', kl_A), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('kl_B', kl_B), ('Cyc_B', Cyc_B)]) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + recon_A = util.tensor2im(self.x_a_recon.data) + A_B = util.tensor2im(self.x_ab.data) + ABA = util.tensor2im(self.x_aba.data) + real_B = util.tensor2im(self.real_B.data) + recon_B = util.tensor2im(self.x_b_recon.data) + B_A = util.tensor2im(self.x_ba.data) + BAB = util.tensor2im(self.x_b_recon.data) + return OrderedDict([('real_A', real_A), ('A_B', A_B), ('recon_A', recon_A), ('ABA', ABA), + ('real_B', real_B), ('B_A', B_A), ('recon_B', recon_B), ('BAB', BAB)]) + + def save(self, label): + self.save_network(self.gen_a, 'G_A', label, self.gpu_ids) + self.save_network(self.dis_a, 'D_A', label, self.gpu_ids) + self.save_network(self.gen_b, 'G_B', label, self.gpu_ids) + self.save_network(self.dis_b, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + lrd = self.config['lr'] / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.gen_a.param_groups: + param_group['lr'] = lr + for param_group in self.gen_b.param_groups: + param_group['lr'] = lr + for param_group in self.dis_a.param_groups: + param_group['lr'] = lr + for param_group in self.dis_b.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr \ No newline at end of file diff --git a/models/unit_network.py b/models/unit_network.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a23fb2292703eb6513f7dcac8f4974daa64a34 --- /dev/null +++ b/models/unit_network.py @@ -0,0 +1,497 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" +from torch import nn +from torch.autograd import Variable +import torch +import torch.nn.functional as F +try: + from itertools import izip as zip +except ImportError: # will be 3.x series + pass + +################################################################################## +# Discriminator +################################################################################## + +class MsImageDis(nn.Module): + # Multi-scale discriminator architecture + def __init__(self, input_dim, params): + super(MsImageDis, self).__init__() + self.n_layer = params['n_layer'] + self.gan_type = params['gan_type'] + self.dim = params['dim'] + self.norm = params['norm'] + self.activ = params['activ'] + self.num_scales = params['num_scales'] + self.pad_type = params['pad_type'] + self.input_dim = input_dim + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + self.cnns = nn.ModuleList() + for _ in range(self.num_scales): + self.cnns.append(self._make_net()) + + def _make_net(self): + dim = self.dim + cnn_x = [] + cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)] + for i in range(self.n_layer - 1): + cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + dim *= 2 + cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)] + cnn_x = nn.Sequential(*cnn_x) + return cnn_x + + def forward(self, x): + outputs = [] + for model in self.cnns: + outputs.append(model(x)) + x = self.downsample(x) + return outputs + + def calc_dis_loss(self, input_fake, input_real): + # calculate the loss to train D + outs0 = self.forward(input_fake) + outs1 = self.forward(input_real) + loss = 0 + + for it, (out0, out1) in enumerate(zip(outs0, outs1)): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) + elif self.gan_type == 'nsgan': + all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) + all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + + F.binary_cross_entropy(F.sigmoid(out1), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + return loss + + def calc_gen_loss(self, input_fake): + # calculate the loss to train G + outs0 = self.forward(input_fake) + loss = 0 + for it, (out0) in enumerate(outs0): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 1)**2) # LSGAN + elif self.gan_type == 'nsgan': + all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + return loss + +################################################################################## +# Generator +################################################################################## + +class AdaINGen(nn.Module): + # AdaIN auto-encoder architecture + def __init__(self, input_dim, params): + super(AdaINGen, self).__init__() + dim = params['dim'] + style_dim = params['style_dim'] + n_downsample = params['n_downsample'] + n_res = params['n_res'] + activ = params['activ'] + pad_type = params['pad_type'] + mlp_dim = params['mlp_dim'] + + # style encoder + self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type) + + # content encoder + self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) + self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type) + + # MLP to generate AdaIN parameters + self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ) + + def forward(self, images): + # reconstruct an image + content, style_fake = self.encode(images) + images_recon = self.decode(content, style_fake) + return images_recon + + def encode(self, images): + # encode an image to its content and style codes + style_fake = self.enc_style(images) + content = self.enc_content(images) + return content, style_fake + + def decode(self, content, style): + # decode content and style codes to an image + adain_params = self.mlp(style) + self.assign_adain_params(adain_params, self.dec) + images = self.dec(content) + return images + + def assign_adain_params(self, adain_params, model): + # assign the adain_params to the AdaIN layers in model + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + mean = adain_params[:, :m.num_features] + std = adain_params[:, m.num_features:2*m.num_features] + m.bias = mean.contiguous().view(-1) + m.weight = std.contiguous().view(-1) + if adain_params.size(1) > 2*m.num_features: + adain_params = adain_params[:, 2*m.num_features:] + + def get_num_adain_params(self, model): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + num_adain_params += 2*m.num_features + return num_adain_params + + +class VAEGen(nn.Module): + # VAE architecture + def __init__(self, input_dim, params): + super(VAEGen, self).__init__() + dim = params['dim'] + n_downsample = params['n_downsample'] + n_res = params['n_res'] + activ = params['activ'] + pad_type = params['pad_type'] + + # content encoder + self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) + self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type) + + def forward(self, images): + # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. + hiddens = self.encode(images) + if self.training == True: + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + images_recon = self.decode(hiddens + noise) + else: + images_recon = self.decode(hiddens) + return images_recon, hiddens + + def encode(self, images): + hiddens = self.enc(images) + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + return hiddens, noise + + def decode(self, hiddens): + images = self.dec(hiddens) + return images + + +################################################################################## +# Encoder and Decoders +################################################################################## + +class StyleEncoder(nn.Module): + def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): + super(StyleEncoder, self).__init__() + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + for i in range(2): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + for i in range(n_downsample - 2): + self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling + self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] + self.model = nn.Sequential(*self.model) + self.output_dim = dim + + def forward(self, x): + return self.model(x) + +class ContentEncoder(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder, self).__init__() + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + # residual blocks + self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + self.output_dim = dim + + def forward(self, x): + return self.model(x) + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'): + super(Decoder, self).__init__() + + self.model = [] + # AdaIN residual blocks + self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)] + # upsampling blocks + for i in range(n_upsample): + self.model += [nn.Upsample(scale_factor=2), + Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)] + dim //= 2 + # use reflection padding in the last conv layer + self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x) + +################################################################################## +# Sequential Models +################################################################################## +class ResBlocks(nn.Module): + def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlocks, self).__init__() + self.model = [] + for i in range(num_blocks): + self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x) + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): + + super(MLP, self).__init__() + self.model = [] + self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] + for i in range(n_blk - 2): + self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] + self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm2d(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + # initialize convolution + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + +class LinearBlock(nn.Module): + def __init__(self, input_dim, output_dim, norm='none', activation='relu'): + super(LinearBlock, self).__init__() + use_bias = True + # initialize fully connected layer + self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm1d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm1d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + def forward(self, x): + out = self.fc(x) + if self.norm: + out = self.norm(out) + if self.activation: + out = self.activation(out) + return out + +################################################################################## +# VGG network definition +################################################################################## +class Vgg16(nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X): + h = F.relu(self.conv1_1(X), inplace=True) + h = F.relu(self.conv1_2(h), inplace=True) + # relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h), inplace=True) + h = F.relu(self.conv2_2(h), inplace=True) + # relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h), inplace=True) + h = F.relu(self.conv3_2(h), inplace=True) + h = F.relu(self.conv3_3(h), inplace=True) + # relu3_3 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h), inplace=True) + h = F.relu(self.conv4_2(h), inplace=True) + h = F.relu(self.conv4_3(h), inplace=True) + # relu4_3 = h + + h = F.relu(self.conv5_1(h), inplace=True) + h = F.relu(self.conv5_2(h), inplace=True) + h = F.relu(self.conv5_3(h), inplace=True) + relu5_3 = h + + return relu5_3 + # return [relu1_2, relu2_2, relu3_3, relu4_3] + +################################################################################## +# Normalization layers +################################################################################## +class AdaptiveInstanceNorm2d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm2d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + # weight and bias are dynamically assigned + self.weight = None + self.bias = None + # just dummy buffers, not used + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x): + assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" + b, c = x.size(0), x.size(1) + running_mean = self.running_mean.repeat(b) + running_var = self.running_var.repeat(b) + + # Apply instance norm + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + + out = F.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + + return out.view(b, c, *x.size()[2:]) + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' + +class LayerNorm(nn.Module): + def __init__(self, num_features, eps=1e-5, affine=True): + super(LayerNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + if self.affine: + self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + shape = [-1] + [1] * (x.dim() - 1) + mean = x.view(x.size(0), -1).mean(1).view(*shape) + std = x.view(x.size(0), -1).std(1).view(*shape) + x = (x - mean) / (std + self.eps) + + if self.affine: + shape = [1, -1] + [1] * (x.dim() - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x \ No newline at end of file diff --git a/options/._single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py b/options/._single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..eba79442330d3c0b20bedbc935a5f16845b1634e Binary files /dev/null and b/options/._single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py differ diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/options/base_options.py b/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..783671544a05bc79b665da5f57cdd497eeec8b4f --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,118 @@ +import argparse +import os +from util import util +import torch + +class BaseOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + self.initialized = False + + def initialize(self): + self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') + self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') + self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') + self.parser.add_argument('--patchSize', type=int, default=64, help='then crop to this size') + self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') + self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') + self.parser.add_argument('--which_model_netG', type=str, default='unet_256', help='selects model to use for netG') + self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') + self.parser.add_argument('--n_layers_patchD', type=int, default=3, help='only used if which_model_netD==n_layers') + self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') + self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') + self.parser.add_argument('--model', type=str, default='cycle_gan', + help='chooses which model to use. cycle_gan, pix2pix, test') + self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') + self.parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data') + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') + self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') + self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') + self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') + self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + self.parser.add_argument('--resize_or_crop', type=str, default='crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') + self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + self.parser.add_argument('--skip', type=float, default=0.8, help='B = net.forward(A) + skip*A') + self.parser.add_argument('--use_mse', action='store_true', help='MSELoss') + self.parser.add_argument('--l1', type=float, default=10.0, help='L1 loss weight is 10.0') + self.parser.add_argument('--use_norm', type=float, default=1, help='L1 loss weight is 10.0') + self.parser.add_argument('--use_wgan', type=float, default=0, help='use wgan-gp') + self.parser.add_argument('--use_ragan', action='store_true', help='use ragan') + self.parser.add_argument('--vgg', type=float, default=0, help='use perceptrual loss') + self.parser.add_argument('--vgg_mean', action='store_true', help='substract mean in vgg loss') + self.parser.add_argument('--vgg_choose', type=str, default='relu5_3', help='choose layer for vgg') + self.parser.add_argument('--no_vgg_instance', action='store_true', help='vgg instance normalization') + self.parser.add_argument('--vgg_maxpooling', action='store_true', help='normalize attention map') + self.parser.add_argument('--IN_vgg', action='store_true', help='patch vgg individual') + self.parser.add_argument('--fcn', type=float, default=0, help='use semantic loss') + self.parser.add_argument('--use_avgpool', type=float, default=0, help='use perceptrual loss') + self.parser.add_argument('--instance_norm', type=float, default=0, help='use instance normalization') + self.parser.add_argument('--syn_norm', action='store_true', help='use synchronize batch normalization') + self.parser.add_argument('--tanh', action='store_true', help='tanh') + self.parser.add_argument('--linear', action='store_true', help='tanh') + self.parser.add_argument('--new_lr', action='store_true', help='tanh') + self.parser.add_argument('--multiply', action='store_true', help='tanh') + self.parser.add_argument('--noise', type=float, default=0, help='variance of noise') + self.parser.add_argument('--input_linear', action='store_true', help='lieanr scaling input') + self.parser.add_argument('--linear_add', action='store_true', help='lieanr scaling input') + self.parser.add_argument('--latent_threshold', action='store_true', help='lieanr scaling input') + self.parser.add_argument('--latent_norm', action='store_true', help='lieanr scaling input') + self.parser.add_argument('--patchD', action='store_true', help='use patch discriminator') + self.parser.add_argument('--patchD_3', type=int, default=0, help='choose the number of crop for patch discriminator') + self.parser.add_argument('--D_P_times2', action='store_true', help='loss_D_P *= 2') + self.parser.add_argument('--patch_vgg', action='store_true', help='use vgg loss between each patch') + self.parser.add_argument('--hybrid_loss', action='store_true', help='use lsgan and ragan separately') + self.parser.add_argument('--self_attention', action='store_true', help='adding attention on the input of generator') + self.parser.add_argument('--times_residual', action='store_true', help='output = input + residual*attention') + self.parser.add_argument('--low_times', type=int, default=200, help='choose the number of crop for patch discriminator') + self.parser.add_argument('--high_times', type=int, default=400, help='choose the number of crop for patch discriminator') + self.parser.add_argument('--norm_attention', action='store_true', help='normalize attention map') + self.parser.add_argument('--vary', type=int, default=1, help='use light data augmentation') + self.parser.add_argument('--lighten', action='store_true', help='normalize attention map') + self.initialized = True + + def parse(self): + if not self.initialized: + self.initialize() + self.opt = self.parser.parse_args() + self.opt.isTrain = self.isTrain # train or test + + str_ids = self.opt.gpu_ids.split(',') + self.opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + self.opt.gpu_ids.append(id) + + # set gpu ids + if len(self.opt.gpu_ids) > 0: + torch.cuda.set_device(self.opt.gpu_ids[0]) + + args = vars(self.opt) + + print('------------ Options -------------') + for k, v in sorted(args.items()): + print('%s: %s' % (str(k), str(v))) + print('-------------- End ----------------') + + # save to the disk + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') + return self.opt diff --git a/options/single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py b/options/single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..7a52be2e74136a08364a11d6f6175fb53413c005 --- /dev/null +++ b/options/single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg.py @@ -0,0 +1,64 @@ +import os +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--port", type=str, default="8097") +parser.add_argument("--train", action='store_true') +parser.add_argument("--test", action='store_true') +parser.add_argument("--predict", action='store_true') +opt = parser.parse_args() + +if opt.train: + os.system("python train.py \ + --dataroot /vita1_ssd1/yifan/final_dataset \ + --no_dropout \ + --name single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg \ + --model single \ + --dataset_mode unaligned \ + --which_model_netG sid_unet_resize \ + --which_model_netD no_norm_4 \ + --n_layers_D 5 \ + --n_layers_patchD 3 \ + --patchD \ + --fineSize 320 \ + --patchSize 64 \ + --skip 1 \ + --batchSize 30 \ + --use_norm 1 \ + --use_wgan 0 \ + --instance_norm 0 \ + --vgg 1 \ + --gpu_ids 0,1,2,3 \ + --display_port=" + opt.port) + +elif opt.test: + for i in range(20): + os.system("python test.py \ + --dataroot /vita1_ssd1/yifan/compete_LOL \ + --name single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg \ + --model single \ + --which_direction AtoB \ + --no_dropout \ + --dataset_mode pair \ + --which_model_netG sid_unet_resize \ + --skip 1 \ + --use_norm 1 \ + --use_wgan 0 \ + --instance_norm 0 \ + --which_epoch " + str(i*5+100)) + +elif opt.predict: + for i in range(20): + os.system("python predict.py \ + --dataroot /vita1_ssd1/yifan/common_dataset \ + --name single_unet_conv_add_bs32_BN_nonormDlayer5_3_final_lsgan_64patchD_P_vgg \ + --model single \ + --which_direction AtoB \ + --no_dropout \ + --dataset_mode unaligned \ + --which_model_netG sid_unet_resize \ + --skip 1 \ + --use_norm 1 \ + --use_wgan 0 \ + --instance_norm 0 --resize_or_crop='no'\ + --which_epoch " + str(200 - i*10)) diff --git a/options/test_options.py b/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..6b79860fd50f5d11b6c25432aff1b051d5388170 --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,13 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') + self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') + self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') + self.isTrain = False diff --git a/options/train_options.py b/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..83084a25345b24104b82dbee6b758ee51ae2afc3 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,22 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + self.parser.add_argument('--display_freq', type=int, default=30, help='frequency of showing training results on screen') + self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') + self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') + self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') + self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + self.parser.add_argument('--config', type=str, default='configs/unit_gta2city_folder.yaml', help='Path to the config file.') + self.isTrain = True diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..877ab2f920a264b821f368bf143b8dfb485ab138 --- /dev/null +++ b/predict.py @@ -0,0 +1,32 @@ +import time +import os +from options.test_options import TestOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +from util.visualizer import Visualizer +from pdb import set_trace as st +from util import html + +opt = TestOptions().parse() +opt.nThreads = 1 # test code only supports nThreads = 1 +opt.batchSize = 1 # test code only supports batchSize = 1 +opt.serial_batches = True # no shuffle +opt.no_flip = True # no flip + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +model = create_model(opt) +visualizer = Visualizer(opt) +# create website +web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) +webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) +# test +print(len(dataset)) +for i, data in enumerate(dataset): + model.set_input(data) + visuals = model.predict() + img_path = model.get_image_paths() + print('process image... %s' % img_path) + visualizer.save_images(webpage, visuals, img_path) + +webpage.save() diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..80db8eb17f8f6cc27d289d2a85388c4f3d058eae --- /dev/null +++ b/requirement.txt @@ -0,0 +1,4 @@ +torch==0.3.1 +torchvision==0.2.0 +visdom +dominate diff --git a/scripts/script.py b/scripts/script.py new file mode 100644 index 0000000000000000000000000000000000000000..4d982a11dd1f08141d28e5592d30eba4170aa48e --- /dev/null +++ b/scripts/script.py @@ -0,0 +1,56 @@ +import os +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--port", type=str, default="8097") +parser.add_argument("--train", action='store_true') +parser.add_argument("--predict", action='store_true') +opt = parser.parse_args() + +if opt.train: + os.system("python train.py \ + --dataroot ../final_dataset \ + --no_dropout \ + --name enlightening \ + --model single \ + --dataset_mode unaligned \ + --which_model_netG sid_unet_resize \ + --which_model_netD no_norm_4 \ + --patchD \ + --patch_vgg \ + --patchD_3 5 \ + --n_layers_D 5 \ + --n_layers_patchD 4 \ + --fineSize 320 \ + --patchSize 32 \ + --skip 1 \ + --batchSize 32 \ + --self_attention \ + --use_norm 1 \ + --use_wgan 0 \ + --use_ragan \ + --hybrid_loss \ + --times_residual \ + --instance_norm 0 \ + --vgg 1 \ + --vgg_choose relu5_1 \ + --gpu_ids 0,1,2 \ + --display_port=" + opt.port) + +elif opt.predict: + for i in range(1): + os.system("python predict.py \ + --dataroot ../test_dataset \ + --name enlightening \ + --model single \ + --which_direction AtoB \ + --no_dropout \ + --dataset_mode unaligned \ + --which_model_netG sid_unet_resize \ + --skip 1 \ + --use_norm 1 \ + --use_wgan 0 \ + --self_attention \ + --times_residual \ + --instance_norm 0 --resize_or_crop='no'\ + --which_epoch " + str(200 - i*5)) \ No newline at end of file diff --git a/seg/PSPNet.py b/seg/PSPNet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/seg/resnet.py b/seg/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/seg/resnext b/seg/resnext new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6781b8b0e1198f8bcffac2fd8194849c29deb989 --- /dev/null +++ b/train.py @@ -0,0 +1,71 @@ +import time +from options.train_options import TrainOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +from util.visualizer import Visualizer + +def get_config(config): + import yaml + with open(config, 'r') as stream: + return yaml.load(stream) + +opt = TrainOptions().parse() +config = get_config(opt.config) +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +dataset_size = len(data_loader) +print('#training images = %d' % dataset_size) + +model = create_model(opt) +visualizer = Visualizer(opt) + +total_steps = 0 + +for epoch in range(1, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + for i, data in enumerate(dataset): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter = total_steps - dataset_size * (epoch - 1) + model.set_input(data) + model.optimize_parameters(epoch) + + if total_steps % opt.display_freq == 0: + visualizer.display_current_results(model.get_current_visuals(), epoch) + + if total_steps % opt.print_freq == 0: + errors = model.get_current_errors(epoch) + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_errors(epoch, epoch_iter, errors, t) + if opt.display_id > 0: + visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) + + if total_steps % opt.save_latest_freq == 0: + print('saving the latest model (epoch %d, total_steps %d)' % + (epoch, total_steps)) + model.save('latest') + + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % + (epoch, total_steps)) + model.save('latest') + model.save(epoch) + + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + + if opt.new_lr: + if epoch == opt.niter: + model.update_learning_rate() + elif epoch == (opt.niter + 20): + model.update_learning_rate() + elif epoch == (opt.niter + 70): + model.update_learning_rate() + elif epoch == (opt.niter + 90): + model.update_learning_rate() + model.update_learning_rate() + model.update_learning_rate() + model.update_learning_rate() + else: + if epoch > opt.niter: + model.update_learning_rate() diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/util/get_data.py b/util/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6325605bc68ec3b4036a4e0f42c28e0b8965867d --- /dev/null +++ b/util/get_data.py @@ -0,0 +1,115 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """ + + Download CycleGAN or Pix2Pix Data. + + Args: + technique : str + One of: 'cyclegan' or 'pix2pix'. + verbose : bool + If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Args: + save_path : str + A directory to save the data to. + dataset : str, optional + A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full : str + The absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/util/html.py b/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..c7956f1353fd25aee253e39a6178481b0b330621 --- /dev/null +++ b/util/html.py @@ -0,0 +1,64 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..152ef5be2b40c249927100ef7fff7306986c8feb --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import numpy as np +import torch +from torch.autograd import Variable +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size-1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/util/png.py b/util/png.py new file mode 100644 index 0000000000000000000000000000000000000000..0936cf08d7f3d307aa44d8e3ccb548251067f53d --- /dev/null +++ b/util/png.py @@ -0,0 +1,33 @@ +import struct +import zlib + +def encode(buf, width, height): + """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """ + assert (width * height * 3 == len(buf)) + bpp = 3 + + def raw_data(): + # reverse the vertical line order and add null bytes at the start + row_bytes = width * bpp + for row_start in range((height - 1) * width * bpp, -1, -row_bytes): + yield b'\x00' + yield buf[row_start:row_start + row_bytes] + + def chunk(tag, data): + return [ + struct.pack("!I", len(data)), + tag, + data, + struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) + ] + + SIGNATURE = b'\x89PNG\r\n\x1a\n' + COLOR_TYPE_RGB = 2 + COLOR_TYPE_RGBA = 6 + bit_depth = 8 + return b''.join( + [ SIGNATURE ] + + chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + + chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + + chunk(b'IEND', b'') + ) diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5d499e1f761cb8f6af3725227be07a77e80b2622 --- /dev/null +++ b/util/util.py @@ -0,0 +1,182 @@ +# from __future__ import print_function +import numpy as np +from PIL import Image +import inspect, re +import numpy as np +import torch +import os +import collections +from torch.optim import lr_scheduler +import torch.nn.init as init + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + image_numpy = np.maximum(image_numpy, 0) + image_numpy = np.minimum(image_numpy, 255) + return image_numpy.astype(imtype) + +def atten2im(image_tensor, imtype=np.uint8): + image_tensor = image_tensor[0] + image_tensor = torch.cat((image_tensor, image_tensor, image_tensor), 0) + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 + image_numpy = image_numpy/(image_numpy.max()/255.0) + return image_numpy.astype(imtype) + +def latent2im(image_tensor, imtype=np.uint8): + # image_tensor = (image_tensor - torch.min(image_tensor))/(torch.max(image_tensor)-torch.min(image_tensor)) + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 + image_numpy = np.maximum(image_numpy, 0) + image_numpy = np.minimum(image_numpy, 255) + return image_numpy.astype(imtype) + +def max2im(image_1, image_2, imtype=np.uint8): + image_1 = image_1[0].cpu().float().numpy() + image_2 = image_2[0].cpu().float().numpy() + image_1 = (np.transpose(image_1, (1, 2, 0)) + 1) / 2.0 * 255.0 + image_2 = (np.transpose(image_2, (1, 2, 0))) * 255.0 + output = np.maximum(image_1, image_2) + output = np.maximum(output, 0) + output = np.minimum(output, 255) + return output.astype(imtype) + +def variable2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].data.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def info(object, spacing=10, collapse=1): + """Print methods and doc strings. + Takes module, class, list, dictionary, or string.""" + methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] + processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) + print( "\n".join(["%s %s" % + (method.ljust(spacing), + processFunc(str(getattr(object, method).__doc__))) + for method in methodList]) ) + +def varname(p): + for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: + m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) + if m: + return m.group(1) + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def get_model_list(dirname, key): + if os.path.exists(dirname) is False: + return None + gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if + os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] + if gen_models is None: + return None + gen_models.sort() + last_model_name = gen_models[-1] + return last_model_name + + +def load_vgg16(model_dir): + """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ + if not os.path.exists(model_dir): + os.mkdir(model_dir) + if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): + if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): + os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) + vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) + vgg = Vgg16() + for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + dst.data[:] = src + torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) + vgg = Vgg16() + vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) + return vgg + + +def vgg_preprocess(batch): + tensortype = type(batch.data) + (r, g, b) = torch.chunk(batch, 3, dim = 1) + batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR + batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + batch = batch.sub(Variable(mean)) # subtract mean + return batch + + +def get_scheduler(optimizer, hyperparameters, iterations=-1): + if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant': + scheduler = None # constant scheduler + elif hyperparameters['lr_policy'] == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'], + gamma=hyperparameters['gamma'], last_epoch=iterations) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy']) + return scheduler + + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): + # print m.__class__.__name__ + if init_type == 'gaussian': + init.normal(m.weight.data, 0.0, 0.02) + elif init_type == 'xavier': + init.xavier_normal(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'kaiming': + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant(m.bias.data, 0.0) + + return init_fun \ No newline at end of file diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c898a0243433134ce0e2495d12d1a8843c35fd0 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,160 @@ +import numpy as np +import os +import ntpath +import time +from . import util +from . import html + +class Visualizer(): + def __init__(self, opt): + # self.opt = opt + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if self.display_id > 0: + import visdom + self.vis = visdom.Visdom(port = opt.display_port) + self.display_single_pane_ncols = opt.display_single_pane_ncols + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch): + if self.display_id > 0: # show images in the browser + if self.display_single_pane_ncols > 0: + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) + ncols = self.display_single_pane_ncols + title = self.name + label_html = '' + label_html_row = '' + nrows = int(np.ceil(len(visuals.items()) / ncols)) + images = [] + idx = 0 + for label, image_numpy in visuals.items(): + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + # pane col = image row + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, win = self.display_id + 2, + opts=dict(title=title + ' labels')) + else: + idx = 1 + for label, image_numpy in visuals.items(): + #image_numpy = np.flipud(image_numpy) + self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t): + message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) + for k, v in errors.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + # save image to the disk + def save_images(self, webpage, visuals, image_path): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) + + + def save_images_demo(self, webpage, visuals, image_path): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = '%s.jpg' % (name) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size)