HenryGong committed on
Commit
aba0e05
1 Parent(s): 84beb6b

Upload 84 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full set.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +120 -0
  3. License +58 -0
  4. README.md +67 -3
  5. assets/.DS_Store +0 -0
  6. assets/arch.png +3 -0
  7. assets/arch_2.pdf +3 -0
  8. assets/comparison_3.pdf +3 -0
  9. assets/new_ablation.pdf +3 -0
  10. assets/show_3.png +3 -0
  11. assets/table.pdf +0 -0
  12. configs/unit_gta2city_folder.yaml +54 -0
  13. data/__init__.py +0 -0
  14. data/aligned_dataset.py +56 -0
  15. data/base_data_loader.py +14 -0
  16. data/base_dataset.py +50 -0
  17. data/custom_dataset_data_loader.py +50 -0
  18. data/data_loader.py +7 -0
  19. data/image_folder.py +83 -0
  20. data/pair_dataset.py +95 -0
  21. data/single_dataset.py +36 -0
  22. data/syn_dataset.py +91 -0
  23. data/unaligned_dataset.py +141 -0
  24. data/unaligned_random_crop.py +85 -0
  25. datasets/.DS_Store +0 -0
  26. datasets/bibtex/cityscapes.tex +6 -0
  27. datasets/bibtex/facades.tex +7 -0
  28. datasets/bibtex/handbags.tex +13 -0
  29. datasets/bibtex/shoes.tex +14 -0
  30. datasets/combine_A_and_B.py +49 -0
  31. datasets/download_cyclegan_dataset.sh +14 -0
  32. datasets/download_pix2pix_dataset.sh +8 -0
  33. imgs/edges2cats.jpg +0 -0
  34. imgs/horse2zebra.gif +3 -0
  35. lib/nn/__init__.py +2 -0
  36. lib/nn/modules/__init__.py +12 -0
  37. lib/nn/modules/batchnorm.py +329 -0
  38. lib/nn/modules/comm.py +131 -0
  39. lib/nn/modules/replicate.py +94 -0
  40. lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
  41. lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
  42. lib/nn/modules/unittest.py +29 -0
  43. lib/nn/parallel/__init__.py +1 -0
  44. lib/nn/parallel/data_parallel.py +112 -0
  45. lib/utils/__init__.py +1 -0
  46. lib/utils/data/__init__.py +3 -0
  47. lib/utils/data/dataloader.py +422 -0
  48. lib/utils/data/dataset.py +118 -0
  49. lib/utils/data/distributed.py +58 -0
  50. lib/utils/data/sampler.py +131 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/arch_2.pdf filter=lfs diff=lfs merge=lfs -text
+ assets/arch.png filter=lfs diff=lfs merge=lfs -text
+ assets/comparison_3.pdf filter=lfs diff=lfs merge=lfs -text
+ assets/new_ablation.pdf filter=lfs diff=lfs merge=lfs -text
+ assets/show_3.png filter=lfs diff=lfs merge=lfs -text
+ imgs/horse2zebra.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,120 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ checkpoints/
+ .DS_Store
+ ._.DS_Store
+ .vscode
+ predict/
+ results/
+ model/
+ .pth
+ .png
+ .jpg
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
License ADDED
@@ -0,0 +1,58 @@
+ Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+ --------------------------- LICENSE FOR EnlightenGAN --------------------------------
+ BSD License
+
+ For EnlightenGAN software
+ Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
+ BSD License
+
+ For dcgan.torch software
+
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,3 +1,67 @@
- ---
- license: bsd
- ---
+ # EnlightenGAN: Deep Light Enhancement without Paired Supervision
+ [Yifan Jiang](https://yifanjiang19.github.io/), Xinyu Gong, Ding Liu, Yu Cheng, Chen Fang, Xiaohui Shen, Jianchao Yang, Pan Zhou, Zhangyang Wang
+
+ [[Paper]](https://arxiv.org/abs/1906.06972) [[Supplementary Materials]](https://yifanjiang.net/files/EnlightenGAN_Supplementary.pdf)
+
+ ### Representative Results
+ ![representative_results](/assets/show_3.png)
+
+ ### Overall Architecture
+ ![architecture](/assets/arch.png)
+
+ ## Environment Preparation
+ ```
+ python3.5
+ ```
+ You need at least three 1080 Ti GPUs; otherwise, reduce the batch size.
+
+ ```pip install -r requirement.txt``` <br/>
+ ```mkdir model``` <br/>
+ Download the VGG pretrained model from [[Google Drive 1]](https://drive.google.com/file/d/1IfCeihmPqGWJ0KHmH-mTMi_pn3z3Zo-P/view?usp=sharing) and put it into the `model` directory.
+
+ ### Training process
+ Before starting the training process, launch `visdom.server` for visualization:
+
+ ```nohup python -m visdom.server -port=8097```
+
+ then run the following command:
+
+ ```python scripts/script.py --train```
+
+ ### Testing process
+
+ Download the [pretrained model](https://drive.google.com/file/d/1AkV-n2MdyfuZTFvcon8Z4leyVb0i7x63/view?usp=sharing) and put it into `./checkpoints/enlightening`.
+
+ Create the directories `../test_dataset/testA` and `../test_dataset/testB`. Put your test images in `../test_dataset/testA` (and keep at least one image in `../test_dataset/testB` so the program can start).
+
+ Run
+
+ ```python scripts/script.py --predict```
+
+ ### Dataset preparation
+
+ Training data [[Google Drive]](https://drive.google.com/drive/folders/1fwqz8-RnTfxgIIkebFG2Ej3jQFsYECh0?usp=sharing) (unpaired images collected from multiple datasets)
+
+ Testing data [[Google Drive]](https://drive.google.com/open?id=1PrvL8jShZ7zj2IC3fVdDxBY1oJR72iDf) (including LIME, MEF, NPE, VV, DICP)
+
+ A [[BaiduYun]](https://github.com/TAMU-VITA/EnlightenGAN/issues/28) mirror is also available now, thanks to @YHLelaine!
+
+ ### Faster Inference
+ https://github.com/arsenyinfo/EnlightenGAN-inference from @arsenyinfo
+
+ If you find this work useful, please cite
+ ```
+ @article{jiang2021enlightengan,
+   title={Enlightengan: Deep light enhancement without paired supervision},
+   author={Jiang, Yifan and Gong, Xinyu and Liu, Ding and Cheng, Yu and Fang, Chen and Shen, Xiaohui and Yang, Jianchao and Zhou, Pan and Wang, Zhangyang},
+   journal={IEEE Transactions on Image Processing},
+   volume={30},
+   pages={2340--2349},
+   year={2021},
+   publisher={IEEE}
+ }
+ ```
assets/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/arch.png ADDED

Git LFS Details

  • SHA256: ab49b01eedf35c7325f9cfd98825cfbc96f209b769f4ed369d6835429906132b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
assets/arch_2.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df02a7f2b894d6230a1f120aa7c112962abe39232c648a441879d9dc8cc71756
+ size 1738396
assets/comparison_3.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75a9820f1a978d9f0b6230dbd163efad5f8ca4100afe06bbed90cbe780a341d5
+ size 1753489
assets/new_ablation.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38b13a6ad0682986d9535ad2d908a5124f65d321240d0cca881f84f6ef033892
+ size 1407150
assets/show_3.png ADDED

Git LFS Details

  • SHA256: e87e8ff557f17604ad8f51bcadc52bea21e57f72b880a45fe3ea3e0c42703c76
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB
assets/table.pdf ADDED
Binary file (96 kB).
 
configs/unit_gta2city_folder.yaml ADDED
@@ -0,0 +1,54 @@
+ # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+ # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+ # logger options
+ image_save_iter: 1000         # How often do you want to save output images during training
+ image_display_iter: 10        # How often do you want to display output images during training
+ display_size: 8               # How many images do you want to display each time
+ snapshot_save_iter: 10000     # How often do you want to save trained models
+ log_iter: 1                   # How often do you want to log the training stats
+
+ # optimization options
+ max_iter: 1000000             # maximum number of training iterations
+ batch_size: 1                 # batch size
+ weight_decay: 0.0001          # weight decay
+ beta1: 0.5                    # Adam parameter
+ beta2: 0.999                  # Adam parameter
+ init: kaiming                 # initialization [gaussian/kaiming/xavier/orthogonal]
+ lr: 0.0001                    # initial learning rate
+ lr_policy: step               # learning rate scheduler
+ step_size: 100000             # how often to decay learning rate
+ gamma: 0.5                    # how much to decay learning rate
+ gan_w: 1                      # weight of adversarial loss
+ recon_x_w: 10                 # weight of image reconstruction loss
+ recon_h_w: 0                  # weight of hidden reconstruction loss
+ recon_kl_w: 0.01              # weight of KL loss for reconstruction
+ recon_x_cyc_w: 10             # weight of cycle consistency loss
+ recon_kl_cyc_w: 0.01          # weight of KL loss for cycle consistency
+ vgg_w: 0                      # weight of domain-invariant perceptual loss
+
+ # model options
+ gen:
+   dim: 64                     # number of filters in the bottommost layer
+   activ: relu                 # activation function [relu/lrelu/prelu/selu/tanh]
+   n_downsample: 2             # number of downsampling layers in content encoder
+   n_res: 4                    # number of residual blocks in content encoder/decoder
+   pad_type: reflect           # padding type [zero/reflect]
+ dis:
+   dim: 64                     # number of filters in the bottommost layer
+   norm: none                  # normalization layer [none/bn/in/ln]
+   activ: lrelu                # activation function [relu/lrelu/prelu/selu/tanh]
+   n_layer: 4                  # number of layers in D
+   gan_type: lsgan             # GAN loss [lsgan/nsgan]
+   num_scales: 3               # number of scales
+   pad_type: reflect           # padding type [zero/reflect]
+
+ # data options
+ input_dim_a: 3                # number of image channels [1/3]
+ input_dim_b: 3                # number of image channels [1/3]
+ num_workers: 8                # number of data loading threads
+ new_size: 256                 # first resize the shortest image side to this size
+ crop_image_height: 256        # random crop image of this height
+ crop_image_width: 256         # random crop image of this width
+
+ data_root: ./datasets/lol/    # dataset folder location
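
As a quick orientation for readers of this config, here is a minimal sketch of consuming it with PyYAML (an assumed dependency, not part of this commit); the path and keys mirror the file above:

```python
# Minimal sketch: read the UNIT config shown above. Assumes PyYAML is
# installed and the repo root is the working directory; not repo tooling.
import yaml

with open("configs/unit_gta2city_folder.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["lr"])            # 0.0001, the initial learning rate
print(cfg["gen"]["n_res"])  # 4 residual blocks in the content encoder/decoder
```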
data/__init__.py ADDED
File without changes
data/aligned_dataset.py ADDED
@@ -0,0 +1,56 @@
+ import os.path
+ import random
+ import torchvision.transforms as transforms
+ import torch
+ from data.base_dataset import BaseDataset
+ from data.image_folder import make_dataset
+ from PIL import Image
+
+
+ class AlignedDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+
+         self.AB_paths = sorted(make_dataset(self.dir_AB))
+
+         assert(opt.resize_or_crop == 'resize_and_crop')
+
+         transform_list = [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+
+         self.transform = transforms.Compose(transform_list)
+
+     def __getitem__(self, index):
+         AB_path = self.AB_paths[index]
+         AB = Image.open(AB_path).convert('RGB')
+         AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
+         AB = self.transform(AB)
+
+         w_total = AB.size(2)
+         w = int(w_total / 2)
+         h = AB.size(1)
+         w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+         h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+         A = AB[:, h_offset:h_offset + self.opt.fineSize,
+                w_offset:w_offset + self.opt.fineSize]
+         B = AB[:, h_offset:h_offset + self.opt.fineSize,
+                w + w_offset:w + w_offset + self.opt.fineSize]
+
+         if (not self.opt.no_flip) and random.random() < 0.5:
+             idx = [i for i in range(A.size(2) - 1, -1, -1)]
+             idx = torch.LongTensor(idx)
+             A = A.index_select(2, idx)
+             B = B.index_select(2, idx)
+
+         return {'A': A, 'B': B,
+                 'A_paths': AB_path, 'B_paths': AB_path}
+
+     def __len__(self):
+         return len(self.AB_paths)
+
+     def name(self):
+         return 'AlignedDataset'
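
For context, `AlignedDataset` expects each file to contain A and B side by side; a small illustrative sketch of the split (the tensor and the `loadSize` value here are made up, not repository defaults):

```python
# Illustration of the A|B split above: one wide tensor, A on the left half,
# B on the right half. All values here are placeholders.
import torch

loadSize = 286
AB = torch.rand(3, loadSize, loadSize * 2)  # stand-in for the loaded image
w = AB.size(2) // 2
A, B = AB[:, :, :w], AB[:, :, w:]
print(A.shape, B.shape)  # torch.Size([3, 286, 286]) for both
```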
data/base_data_loader.py ADDED
@@ -0,0 +1,14 @@
+ class BaseDataLoader():
+     def __init__(self):
+         pass
+
+     def initialize(self, opt):
+         self.opt = opt
+
+     def load_data(self):
+         # Overridden by subclasses; the base implementation returns nothing.
+         return None
data/base_dataset.py ADDED
@@ -0,0 +1,50 @@
+ import torch.utils.data as data
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import random
+
+
+ class BaseDataset(data.Dataset):
+     def __init__(self):
+         super(BaseDataset, self).__init__()
+
+     def name(self):
+         return 'BaseDataset'
+
+     def initialize(self, opt):
+         pass
+
+
+ def get_transform(opt):
+     transform_list = []
+     if opt.resize_or_crop == 'resize_and_crop':
+         zoom = 1 + 0.1 * random.randint(0, 4)
+         osize = [int(400 * zoom), int(600 * zoom)]
+         transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     elif opt.resize_or_crop == 'crop':
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     elif opt.resize_or_crop == 'scale_width':
+         transform_list.append(transforms.Lambda(
+             lambda img: __scale_width(img, opt.fineSize)))
+     elif opt.resize_or_crop == 'scale_width_and_crop':
+         transform_list.append(transforms.Lambda(
+             lambda img: __scale_width(img, opt.loadSize)))
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     # elif opt.resize_or_crop == 'no':
+     #     osize = [384, 512]
+     #     transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+
+     if opt.isTrain and not opt.no_flip:
+         transform_list.append(transforms.RandomHorizontalFlip())
+
+     transform_list += [transforms.ToTensor(),
+                        transforms.Normalize((0.5, 0.5, 0.5),
+                                             (0.5, 0.5, 0.5))]
+     return transforms.Compose(transform_list)
+
+
+ def __scale_width(img, target_width):
+     ow, oh = img.size
+     if (ow == target_width):
+         return img
+     w = target_width
+     h = int(target_width * oh / ow)
+     return img.resize((w, h), Image.BICUBIC)
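
A hedged usage sketch for `get_transform`; the option object is a hypothetical stand-in, but the field names (`resize_or_crop`, `fineSize`, `loadSize`, `isTrain`, `no_flip`) are exactly the ones the function reads:

```python
# Sketch: build the 'crop' pipeline from get_transform above.
from types import SimpleNamespace
from data.base_dataset import get_transform

opt = SimpleNamespace(resize_or_crop='crop', fineSize=256, loadSize=286,
                      isTrain=True, no_flip=False)
transform = get_transform(opt)
# => RandomCrop(256) -> RandomHorizontalFlip -> ToTensor -> Normalize
```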
data/custom_dataset_data_loader.py ADDED
@@ -0,0 +1,50 @@
+ import torch.utils.data
+ from data.base_data_loader import BaseDataLoader
+
+
+ def CreateDataset(opt):
+     dataset = None
+     if opt.dataset_mode == 'aligned':
+         from data.aligned_dataset import AlignedDataset
+         dataset = AlignedDataset()
+     elif opt.dataset_mode == 'unaligned':
+         from data.unaligned_dataset import UnalignedDataset
+         dataset = UnalignedDataset()
+     elif opt.dataset_mode == 'unaligned_random_crop':
+         from data.unaligned_random_crop import UnalignedDataset
+         dataset = UnalignedDataset()
+     elif opt.dataset_mode == 'pair':
+         from data.pair_dataset import PairDataset
+         dataset = PairDataset()
+     elif opt.dataset_mode == 'syn':
+         from data.syn_dataset import PairDataset
+         dataset = PairDataset()
+     elif opt.dataset_mode == 'single':
+         from data.single_dataset import SingleDataset
+         dataset = SingleDataset()
+     else:
+         raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+     print("dataset [%s] was created" % (dataset.name()))
+     dataset.initialize(opt)
+     return dataset
+
+
+ class CustomDatasetDataLoader(BaseDataLoader):
+     def name(self):
+         return 'CustomDatasetDataLoader'
+
+     def initialize(self, opt):
+         BaseDataLoader.initialize(self, opt)
+         self.dataset = CreateDataset(opt)
+         self.dataloader = torch.utils.data.DataLoader(
+             self.dataset,
+             batch_size=opt.batchSize,
+             shuffle=not opt.serial_batches,
+             num_workers=int(opt.nThreads))
+
+     def load_data(self):
+         return self.dataloader
+
+     def __len__(self):
+         return min(len(self.dataset), self.opt.max_dataset_size)
data/data_loader.py ADDED
@@ -0,0 +1,7 @@
+ def CreateDataLoader(opt):
+     from data.custom_dataset_data_loader import CustomDatasetDataLoader
+     data_loader = CustomDatasetDataLoader()
+     print(data_loader.name())
+     data_loader.initialize(opt)
+     return data_loader
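
Tying the loader files together, a sketch of the intended call chain (`CreateDataLoader` -> `CustomDatasetDataLoader` -> `CreateDataset`). The `opt` fields below are the ones the code above actually reads; the values and the data path are illustrative, and real images must exist under `trainA`/`trainB` for this to run:

```python
# Sketch only: end-to-end use of the data-loading entry point above.
from types import SimpleNamespace
from data.data_loader import CreateDataLoader

opt = SimpleNamespace(
    dataset_mode='unaligned', dataroot='./datasets/lol', phase='train',
    batchSize=8, serial_batches=False, nThreads=4,
    max_dataset_size=float('inf'),
    resize_or_crop='crop', fineSize=256, loadSize=286,
    isTrain=True, no_flip=False,
    vary=1, low_times=200, high_times=400, lighten=False)

data_loader = CreateDataLoader(opt)
for batch in data_loader.load_data():
    print(batch['A'].shape)  # e.g. torch.Size([8, 3, 256, 256])
    break
```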
data/image_folder.py ADDED
@@ -0,0 +1,83 @@
+ ###############################################################################
+ # Code from
+ # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+ # Modified the original code so that it also loads images from the current
+ # directory as well as the subdirectories
+ ###############################################################################
+
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+ import os.path
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def make_dataset(dir):
+     images = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in fnames:
+             if is_image_file(fname):
+                 path = os.path.join(root, fname)
+                 images.append(path)
+
+     return images
+
+
+ def store_dataset(dir):
+     images = []
+     all_path = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in fnames:
+             if is_image_file(fname):
+                 path = os.path.join(root, fname)
+                 img = Image.open(path).convert('RGB')
+                 images.append(img)
+                 all_path.append(path)
+
+     return images, all_path
+
+
+ def default_loader(path):
+     return Image.open(path).convert('RGB')
+
+
+ class ImageFolder(data.Dataset):
+
+     def __init__(self, root, transform=None, return_paths=False,
+                  loader=default_loader):
+         imgs = make_dataset(root)
+         if len(imgs) == 0:
+             raise(RuntimeError("Found 0 images in: " + root + "\n"
+                                "Supported image extensions are: " +
+                                ",".join(IMG_EXTENSIONS)))
+
+         self.root = root
+         self.imgs = imgs
+         self.transform = transform
+         self.return_paths = return_paths
+         self.loader = loader
+
+     def __getitem__(self, index):
+         path = self.imgs[index]
+         img = self.loader(path)
+         if self.transform is not None:
+             img = self.transform(img)
+         if self.return_paths:
+             return img, path
+         else:
+             return img
+
+     def __len__(self):
+         return len(self.imgs)
data/pair_dataset.py ADDED
@@ -0,0 +1,95 @@
+ import os.path
+ import torchvision.transforms as transforms
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset
+ from PIL import Image
+ import PIL
+ import random
+ import torch
+ from pdb import set_trace as st
+
+
+ class PairDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+         self.A_paths = make_dataset(self.dir_A)
+         self.B_paths = make_dataset(self.dir_B)
+
+         self.A_paths = sorted(self.A_paths)
+         self.B_paths = sorted(self.B_paths)
+         self.A_size = len(self.A_paths)
+         self.B_size = len(self.B_paths)
+
+         transform_list = []
+
+         transform_list += [transforms.ToTensor(),
+                            transforms.Normalize((0.5, 0.5, 0.5),
+                                                 (0.5, 0.5, 0.5))]
+         # transform_list = [transforms.ToTensor()]
+
+         self.transform = transforms.Compose(transform_list)
+         # self.transform = get_transform(opt)
+
+     def __getitem__(self, index):
+         A_path = self.A_paths[index % self.A_size]
+         B_path = self.B_paths[index % self.B_size]
+
+         A_img = Image.open(A_path).convert('RGB')
+         # derive the paired normal-light path from the low-light path
+         B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB')
+
+         A_img = self.transform(A_img)
+         B_img = self.transform(B_img)
+
+         w = A_img.size(2)
+         h = A_img.size(1)
+         w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+         h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+         A_img = A_img[:, h_offset:h_offset + self.opt.fineSize,
+                       w_offset:w_offset + self.opt.fineSize]
+         B_img = B_img[:, h_offset:h_offset + self.opt.fineSize,
+                       w_offset:w_offset + self.opt.fineSize]
+
+         if self.opt.resize_or_crop == 'no':
+             r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
+             A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
+             A_gray = torch.unsqueeze(A_gray, 0)
+             input_img = A_img
+             # A_gray = (1./A_gray)/255.
+         else:
+             # A_gray = (1./A_gray)/255.
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 A_img = A_img.index_select(2, idx)
+                 B_img = B_img.index_select(2, idx)
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 A_img = A_img.index_select(1, idx)
+                 B_img = B_img.index_select(1, idx)
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 times = random.randint(self.opt.low_times, self.opt.high_times) / 100.
+                 input_img = (A_img + 1) / 2. / times
+                 input_img = input_img * 2 - 1
+             else:
+                 input_img = A_img
+             r, g, b = input_img[0] + 1, input_img[1] + 1, input_img[2] + 1
+             A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
+             A_gray = torch.unsqueeze(A_gray, 0)
+         return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                 'A_paths': A_path, 'B_paths': B_path}
+
+     def __len__(self):
+         return self.A_size
+
+     def name(self):
+         return 'PairDataset'
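
The `A_gray` lines above compute the illumination attention map EnlightenGAN feeds to its generator: with inputs normalized to [-1, 1], shifting each channel by +1 gives values in [0, 2], so `(0.299*r + 0.587*g + 0.114*b) / 2` is a BT.601 luma in [0, 1], and `A_gray = 1 - luma` is large exactly where the image is dark. A standalone sketch:

```python
# Standalone sketch of the A_gray computation used in __getitem__ above.
import torch

img = torch.rand(3, 4, 4) * 2 - 1               # fake image in [-1, 1]
r, g, b = img[0] + 1, img[1] + 1, img[2] + 1    # each channel now in [0, 2]
luma = (0.299 * r + 0.587 * g + 0.114 * b) / 2  # BT.601 luma in [0, 1]
A_gray = (1.0 - luma).unsqueeze(0)              # (1, H, W); dark pixels near 1
```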
data/single_dataset.py ADDED
@@ -0,0 +1,36 @@
+ import os.path
+ import torchvision.transforms as transforms
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset
+ from PIL import Image
+
+
+ class SingleDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot)
+
+         self.A_paths = make_dataset(self.dir_A)
+
+         self.A_paths = sorted(self.A_paths)
+
+         self.transform = get_transform(opt)
+
+     def __getitem__(self, index):
+         A_path = self.A_paths[index]
+
+         A_img = Image.open(A_path).convert('RGB')
+         # round each side down to a multiple of 16 so the network's
+         # stride-16 downsampling works on the full image
+         A_size = (A_img.size[0] // 16 * 16, A_img.size[1] // 16 * 16)
+         A_img = A_img.resize(A_size, Image.BICUBIC)
+
+         A_img = self.transform(A_img)
+
+         return {'A': A_img, 'A_paths': A_path}
+
+     def __len__(self):
+         return len(self.A_paths)
+
+     def name(self):
+         return 'SingleImageDataset'
data/syn_dataset.py ADDED
@@ -0,0 +1,91 @@
+ import os.path
+ import torchvision.transforms as transforms
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset
+ from PIL import Image
+ import PIL
+ import random
+ import torch
+ from pdb import set_trace as st
+
+
+ class PairDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+         self.A_paths = make_dataset(self.dir_A)
+         self.B_paths = make_dataset(self.dir_B)
+
+         self.A_paths = sorted(self.A_paths)
+         self.B_paths = sorted(self.B_paths)
+         self.A_size = len(self.A_paths)
+         self.B_size = len(self.B_paths)
+
+         transform_list = []
+
+         transform_list += [transforms.ToTensor(),
+                            transforms.Normalize((0.5, 0.5, 0.5),
+                                                 (0.5, 0.5, 0.5))]
+         # transform_list = [transforms.ToTensor()]
+
+         self.transform = transforms.Compose(transform_list)
+         # self.transform = get_transform(opt)
+
+     def __getitem__(self, index):
+         A_path = self.A_paths[index % self.A_size]
+         B_path = self.B_paths[index % self.B_size]
+
+         B_img = Image.open(B_path).convert('RGB')
+         # B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB')
+
+         B_img = self.transform(B_img)
+
+         w = B_img.size(2)
+         h = B_img.size(1)
+         w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+         h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+         B_img = B_img[:, h_offset:h_offset + self.opt.fineSize,
+                       w_offset:w_offset + self.opt.fineSize]
+
+         if self.opt.resize_or_crop == 'no':
+             pass
+             # r,g,b = A_img[0]+1, A_img[1]+1, A_img[2]+1
+             # A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+             # A_gray = torch.unsqueeze(A_gray, 0)
+             # input_img = A_img
+             # A_gray = (1./A_gray)/255.
+         else:
+             # A_gray = (1./A_gray)/255.
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(B_img.size(2) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 B_img = B_img.index_select(2, idx)
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(B_img.size(1) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 B_img = B_img.index_select(1, idx)
+
+         # synthesize a low-light input by darkening the normal-light image
+         times = random.randint(self.opt.low_times, self.opt.high_times) / 100.
+         input_img = (B_img + 1) / 2. / times
+         input_img = input_img * 2 - 1
+         A_img = input_img
+         r, g, b = input_img[0] + 1, input_img[1] + 1, input_img[2] + 1
+         A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
+         A_gray = torch.unsqueeze(A_gray, 0)
+         return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                 'A_paths': A_path, 'B_paths': B_path}
+
+     def __len__(self):
+         return self.A_size
+
+     def name(self):
+         return 'PairDataset'
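
The darkening step above is simple interval arithmetic: map the normal-light image from [-1, 1] to [0, 1], divide by a random factor `times`, and map back, so `times = 2.0` halves every pixel's brightness. A sketch (the 200/400 bounds stand in for the `low_times`/`high_times` options):

```python
# Sketch of the synthetic darkening used by the 'syn' dataset above.
import random
import torch

B_img = torch.rand(3, 8, 8) * 2 - 1         # fake normal-light image in [-1, 1]
times = random.randint(200, 400) / 100.0    # darkening factor in [2.0, 4.0]
dark = ((B_img + 1) / 2.0 / times) * 2 - 1  # darker image, still in [-1, 1)
```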
data/unaligned_dataset.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ from torch import nn
+ import os.path
+ import torchvision.transforms as transforms
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset, store_dataset
+ import random
+ from PIL import Image
+ import PIL
+ from pdb import set_trace as st
+
+
+ def pad_tensor(input):
+     height_org, width_org = input.shape[2], input.shape[3]
+     divide = 16
+
+     if width_org % divide != 0 or height_org % divide != 0:
+         width_res = width_org % divide
+         height_res = height_org % divide
+         if width_res != 0:
+             width_div = divide - width_res
+             pad_left = int(width_div / 2)
+             pad_right = int(width_div - pad_left)
+         else:
+             pad_left = 0
+             pad_right = 0
+
+         if height_res != 0:
+             height_div = divide - height_res
+             pad_top = int(height_div / 2)
+             pad_bottom = int(height_div - pad_top)
+         else:
+             pad_top = 0
+             pad_bottom = 0
+
+         padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
+         input = padding(input).data
+     else:
+         pad_left = 0
+         pad_right = 0
+         pad_top = 0
+         pad_bottom = 0
+
+     height, width = input.shape[2], input.shape[3]
+     assert width % divide == 0, 'width is not divisible by the stride'
+     assert height % divide == 0, 'height is not divisible by the stride'
+
+     return input, pad_left, pad_right, pad_top, pad_bottom
+
+
+ def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
+     height, width = input.shape[2], input.shape[3]
+     return input[:, :, pad_top: height - pad_bottom, pad_left: width - pad_right]
+
+
+ class UnalignedDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+         # self.A_paths = make_dataset(self.dir_A)
+         # self.B_paths = make_dataset(self.dir_B)
+         self.A_imgs, self.A_paths = store_dataset(self.dir_A)
+         self.B_imgs, self.B_paths = store_dataset(self.dir_B)
+
+         # self.A_paths = sorted(self.A_paths)
+         # self.B_paths = sorted(self.B_paths)
+         self.A_size = len(self.A_paths)
+         self.B_size = len(self.B_paths)
+
+         self.transform = get_transform(opt)
+
+     def __getitem__(self, index):
+         # A_path = self.A_paths[index % self.A_size]
+         # B_path = self.B_paths[index % self.B_size]
+         # A_img = Image.open(A_path).convert('RGB')
+         # B_img = Image.open(B_path).convert('RGB')
+         A_img = self.A_imgs[index % self.A_size]
+         B_img = self.B_imgs[index % self.B_size]
+         A_path = self.A_paths[index % self.A_size]
+         B_path = self.B_paths[index % self.B_size]
+
+         A_img = self.transform(A_img)
+         B_img = self.transform(B_img)
+
+         if self.opt.resize_or_crop == 'no':
+             r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
+             A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
+             A_gray = torch.unsqueeze(A_gray, 0)
+             input_img = A_img
+             # A_gray = (1./A_gray)/255.
+         else:
+             w = A_img.size(2)
+             h = A_img.size(1)
+
+             # A_gray = (1./A_gray)/255.
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 A_img = A_img.index_select(2, idx)
+                 B_img = B_img.index_select(2, idx)
+             if (not self.opt.no_flip) and random.random() < 0.5:
+                 idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+                 idx = torch.LongTensor(idx)
+                 A_img = A_img.index_select(1, idx)
+                 B_img = B_img.index_select(1, idx)
+             if self.opt.vary == 1 and (not self.opt.no_flip) and random.random() < 0.5:
+                 times = random.randint(self.opt.low_times, self.opt.high_times) / 100.
+                 input_img = (A_img + 1) / 2. / times
+                 input_img = input_img * 2 - 1
+             else:
+                 input_img = A_img
+             if self.opt.lighten:
+                 B_img = (B_img + 1) / 2.
+                 B_img = (B_img - torch.min(B_img)) / (torch.max(B_img) - torch.min(B_img))
+                 B_img = B_img * 2. - 1
+             r, g, b = input_img[0] + 1, input_img[1] + 1, input_img[2] + 1
+             A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
+             A_gray = torch.unsqueeze(A_gray, 0)
+         return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                 'A_paths': A_path, 'B_paths': B_path}
+
+     def __len__(self):
+         return max(self.A_size, self.B_size)
+
+     def name(self):
+         return 'UnalignedDataset'
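
`pad_tensor`/`pad_tensor_back` exist so that arbitrary test resolutions survive the network's stride-16 downsampling: the input is reflection-padded up to the next multiple of 16, then cropped back afterwards. A sketch using the two functions above:

```python
# Sketch: pad an NCHW batch to multiples of 16, then undo the padding.
import torch
from data.unaligned_dataset import pad_tensor, pad_tensor_back

x = torch.rand(1, 3, 250, 375)   # neither side is divisible by 16
y, l, r, t, b = pad_tensor(x)    # y has shape (1, 3, 256, 384)
assert pad_tensor_back(y, l, r, t, b).shape == x.shape
```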
data/unaligned_random_crop.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ import os.path
+ import torchvision.transforms as transforms
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset
+ import random
+ from PIL import Image
+ import PIL
+ from pdb import set_trace as st
+
+
+ class UnalignedDataset(BaseDataset):
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+         self.A_paths = make_dataset(self.dir_A)
+         self.B_paths = make_dataset(self.dir_B)
+
+         self.A_paths = sorted(self.A_paths)
+         self.B_paths = sorted(self.B_paths)
+         self.A_size = len(self.A_paths)
+         self.B_size = len(self.B_paths)
+
+         transform_list = [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+
+         self.transform = transforms.Compose(transform_list)
+         # self.transform = get_transform(opt)
+
+     def __getitem__(self, index):
+         A_path = self.A_paths[index % self.A_size]
+         B_path = self.B_paths[index % self.B_size]
+
+         A_img = Image.open(A_path).convert('RGB')
+         B_img = Image.open(B_path).convert('RGB')
+         # round each side down to a multiple of 16
+         A_size = (A_img.size[0] // 16 * 16, A_img.size[1] // 16 * 16)
+         B_size = (B_img.size[0] // 16 * 16, B_img.size[1] // 16 * 16)
+         A_img = A_img.resize(A_size, Image.BICUBIC)
+         B_img = B_img.resize(B_size, Image.BICUBIC)
+
+         A_img = self.transform(A_img)
+         B_img = self.transform(B_img)
+
+         if self.opt.resize_or_crop == 'no':
+             pass
+         else:
+             w = A_img.size(2)
+             h = A_img.size(1)
+             size = [8, 16, 22]
+             from random import randint
+             size_index = randint(0, 2)
+             Cropsize = size[size_index] * 16  # 128, 256, or 352 pixels
+
+             w_offset = random.randint(0, max(0, w - Cropsize - 1))
+             h_offset = random.randint(0, max(0, h - Cropsize - 1))
+
+             A_img = A_img[:, h_offset:h_offset + Cropsize,
+                           w_offset:w_offset + Cropsize]
+
+         if (not self.opt.no_flip) and random.random() < 0.5:
+             idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+             idx = torch.LongTensor(idx)
+             A_img = A_img.index_select(2, idx)
+             B_img = B_img.index_select(2, idx)
+         if (not self.opt.no_flip) and random.random() < 0.5:
+             idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+             idx = torch.LongTensor(idx)
+             A_img = A_img.index_select(1, idx)
+             B_img = B_img.index_select(1, idx)
+
+         return {'A': A_img, 'B': B_img,
+                 'A_paths': A_path, 'B_paths': B_path}
+
+     def __len__(self):
+         return max(self.A_size, self.B_size)
+
+     def name(self):
+         return 'UnalignedDataset'
datasets/.DS_Store ADDED
Binary file (6.15 kB).
 
datasets/bibtex/cityscapes.tex ADDED
@@ -0,0 +1,6 @@
+ @inproceedings{Cordts2016Cityscapes,
+   title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
+   author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
+   booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+   year={2016}
+ }
datasets/bibtex/facades.tex ADDED
@@ -0,0 +1,7 @@
+ @INPROCEEDINGS{Tylecek13,
+   author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
+   title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
+   booktitle = {Proc. GCPR},
+   year = {2013},
+   address = {Saarbrucken, Germany},
+ }
datasets/bibtex/handbags.tex ADDED
@@ -0,0 +1,13 @@
+ @inproceedings{zhu2016generative,
+   title={Generative Visual Manipulation on the Natural Image Manifold},
+   author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
+   booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
+   year={2016}
+ }
+
+ @InProceedings{xie15hed,
+   author = {Xie, Saining and Tu, Zhuowen},
+   title = {Holistically-Nested Edge Detection},
+   booktitle = {Proceedings of IEEE International Conference on Computer Vision},
+   year = {2015},
+ }
datasets/bibtex/shoes.tex ADDED
@@ -0,0 +1,14 @@
+ @InProceedings{fine-grained,
+   author = {A. Yu and K. Grauman},
+   title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
+   booktitle = {Computer Vision and Pattern Recognition (CVPR)},
+   month = {June},
+   year = {2014}
+ }
+
+ @InProceedings{xie15hed,
+   author = {Xie, Saining and Tu, Zhuowen},
+   title = {Holistically-Nested Edge Detection},
+   booktitle = {Proceedings of IEEE International Conference on Computer Vision},
+   year = {2015},
+ }
datasets/combine_A_and_B.py ADDED
@@ -0,0 +1,49 @@
+ from pdb import set_trace as st
+ import os
+ import numpy as np
+ import cv2
+ import argparse
+
+ parser = argparse.ArgumentParser('create image pairs')
+ parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
+ parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
+ parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
+ parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
+ parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
+ args = parser.parse_args()
+
+ for arg in vars(args):
+     print('[%s] = ' % arg, getattr(args, arg))
+
+ splits = os.listdir(args.fold_A)
+
+ for sp in splits:
+     img_fold_A = os.path.join(args.fold_A, sp)
+     img_fold_B = os.path.join(args.fold_B, sp)
+     img_list = os.listdir(img_fold_A)
+     if args.use_AB:
+         img_list = [img_path for img_path in img_list if '_A.' in img_path]
+
+     num_imgs = min(args.num_imgs, len(img_list))
+     print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
+     img_fold_AB = os.path.join(args.fold_AB, sp)
+     if not os.path.isdir(img_fold_AB):
+         os.makedirs(img_fold_AB)
+     print('split = %s, number of images = %d' % (sp, num_imgs))
+     for n in range(num_imgs):
+         name_A = img_list[n]
+         path_A = os.path.join(img_fold_A, name_A)
+         if args.use_AB:
+             name_B = name_A.replace('_A.', '_B.')
+         else:
+             name_B = name_A
+         path_B = os.path.join(img_fold_B, name_B)
+         if os.path.isfile(path_A) and os.path.isfile(path_B):
+             name_AB = name_A
+             if args.use_AB:
+                 name_AB = name_AB.replace('_A.', '.')  # remove _A
+             path_AB = os.path.join(img_fold_AB, name_AB)
+             im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)  # cv2.CV_LOAD_IMAGE_COLOR was removed in OpenCV 3
+             im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
+             im_AB = np.concatenate([im_A, im_B], 1)
+             cv2.imwrite(path_AB, im_AB)
datasets/download_cyclegan_dataset.sh ADDED
@@ -0,0 +1,14 @@
+ FILE=$1
+
+ if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
+     echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+     exit 1
+ fi
+
+ URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ ZIP_FILE=./datasets/$FILE.zip
+ TARGET_DIR=./datasets/$FILE/
+ wget -N $URL -O $ZIP_FILE
+ mkdir $TARGET_DIR
+ unzip $ZIP_FILE -d ./datasets/
+ rm $ZIP_FILE
datasets/download_pix2pix_dataset.sh ADDED
@@ -0,0 +1,8 @@
+ FILE=$1
+ URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
+ TAR_FILE=./datasets/$FILE.tar.gz
+ TARGET_DIR=./datasets/$FILE/
+ wget -N $URL -O $TAR_FILE
+ mkdir $TARGET_DIR
+ tar -zxvf $TAR_FILE -C ./datasets/
+ rm $TAR_FILE
imgs/edges2cats.jpg ADDED
imgs/horse2zebra.gif ADDED

Git LFS Details

  • SHA256: 16a76adedd309c46ba6ed63f89b14130c4a671fd6febc26fb0372a1ccf16c7aa
  • Pointer size: 132 Bytes
  • Size of remote file: 7.69 MB
lib/nn/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .modules import *
+ from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
lib/nn/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # -*- coding: utf-8 -*-
+ # File   : __init__.py
+ # Author : Jiayuan Mao
+ # Email  : [email protected]
+ # Date   : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+ from .replicate import DataParallelWithCallback, patch_replication_callback
lib/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ # customed batch norm statistics
49
+ self._moving_average_fraction = 1. - momentum
50
+ self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51
+ self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52
+ self.register_buffer('_running_iter', torch.ones(1))
53
+ self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54
+ self._tmp_running_var = self.running_var.clone() * self._running_iter
55
+
56
+ def forward(self, input):
57
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58
+ if not (self._is_parallel and self.training):
59
+ return F.batch_norm(
60
+ input, self.running_mean, self.running_var, self.weight, self.bias,
61
+ self.training, self.momentum, self.eps)
62
+
63
+ # Resize the input to (B, C, -1).
64
+ input_shape = input.size()
65
+ input = input.view(input.size(0), self.num_features, -1)
66
+
67
+ # Compute the sum and square-sum.
68
+ sum_size = input.size(0) * input.size(2)
69
+ input_sum = _sum_ft(input)
70
+ input_ssum = _sum_ft(input ** 2)
71
+
72
+ # Reduce-and-broadcast the statistics.
73
+ if self._parallel_id == 0:
74
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75
+ else:
76
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77
+
78
+ # Compute the output.
79
+ if self.affine:
80
+ # MJY:: Fuse the multiplication for speed.
81
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82
+ else:
83
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84
+
85
+ # Reshape it.
86
+ return output.view(input_shape)
87
+
88
+ def __data_parallel_replicate__(self, ctx, copy_id):
89
+ self._is_parallel = True
90
+ self._parallel_id = copy_id
91
+
92
+ # parallel_id == 0 means master device.
93
+ if self._parallel_id == 0:
94
+ ctx.sync_master = self._sync_master
95
+ else:
96
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97
+
98
+ def _data_parallel_master(self, intermediates):
99
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
100
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101
+
102
+ to_reduce = [i[1][:2] for i in intermediates]
103
+ to_reduce = [j for i in to_reduce for j in i] # flatten
104
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
105
+
106
+ sum_size = sum([i[1].sum_size for i in intermediates])
107
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108
+
109
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110
+
111
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112
+
113
+ outputs = []
114
+ for i, rec in enumerate(intermediates):
115
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116
+
117
+ return outputs
118
+
119
+ def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120
+ """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
121
+ return dest * alpha + delta * beta + bias
122
+
123
+ def _compute_mean_std(self, sum_, ssum, size):
124
+ """Compute the mean and standard-deviation with sum and square-sum. This method
125
+ also maintains the moving average on the master device."""
126
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127
+ mean = sum_ / size
128
+ sumvar = ssum - sum_ * mean
129
+ unbias_var = sumvar / (size - 1)
130
+ bias_var = sumvar / size
131
+
132
+ self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133
+ self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134
+ self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135
+
136
+ self.running_mean = self._tmp_running_mean / self._running_iter
137
+ self.running_var = self._tmp_running_var / self._running_iter
138
+
139
+ return mean, bias_var.clamp(self.eps) ** -0.5
140
+
141
+
142
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144
+ mini-batch.
145
+
146
+ .. math::
147
+
148
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149
+
150
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
151
+ standard-deviation are reduced across all devices during training.
152
+
153
+ For example, when one uses `nn.DataParallel` to wrap the network during
154
+ training, PyTorch's implementation normalize the tensor on each device using
155
+ the statistics only on that device, which accelerated the computation and
156
+ is also easy to implement, but the statistics might be inaccurate.
157
+ Instead, in this synchronized version, the statistics will be computed
158
+ over all training samples distributed on multiple devices.
159
+
160
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
161
+ as the built-in PyTorch implementation.
162
+
163
+ The mean and standard-deviation are calculated per-dimension over
164
+ the mini-batches and gamma and beta are learnable parameter vectors
165
+ of size C (where C is the input size).
166
+
167
+ During training, this layer keeps a running estimate of its computed mean
168
+ and variance. The running sum is kept with a default momentum of 0.1.
169
+
170
+ During evaluation, this running mean/variance is used for normalization.
171
+
172
+ Because the BatchNorm is done over the `C` dimension, computing statistics
173
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
174
+
175
+ Args:
176
+ num_features: num_features from an expected input of size
177
+ `batch_size x num_features [x width]`
178
+ eps: a value added to the denominator for numerical stability.
179
+ Default: 1e-5
180
+ momentum: the value used for the running_mean and running_var
181
+ computation. Default: 0.1
182
+ affine: a boolean value that when set to ``True``, gives the layer learnable
183
+ affine parameters. Default: ``True``
184
+
185
+ Shape:
186
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
187
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188
+
189
+ Examples:
190
+ >>> # With Learnable Parameters
191
+ >>> m = SynchronizedBatchNorm1d(100)
192
+ >>> # Without Learnable Parameters
193
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
194
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
195
+ >>> output = m(input)
196
+ """
197
+
198
+ def _check_input_dim(self, input):
199
+ if input.dim() != 2 and input.dim() != 3:
200
+ raise ValueError('expected 2D or 3D input (got {}D input)'
201
+ .format(input.dim()))
202
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203
+
204
+
205
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207
+ of 3d inputs
208
+
209
+ .. math::
210
+
211
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212
+
213
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
214
+ standard-deviation are reduced across all devices during training.
215
+
216
+ For example, when one uses `nn.DataParallel` to wrap the network during
217
+ training, PyTorch's implementation normalizes the tensor on each device using
219
+ the statistics only on that device, which accelerates the computation and
219
+ is also easy to implement, but the statistics might be inaccurate.
220
+ Instead, in this synchronized version, the statistics will be computed
221
+ over all training samples distributed on multiple devices.
222
+
223
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
224
+ as the built-in PyTorch implementation.
225
+
226
+ The mean and standard-deviation are calculated per-dimension over
227
+ the mini-batches and gamma and beta are learnable parameter vectors
228
+ of size C (where C is the input size).
229
+
230
+ During training, this layer keeps a running estimate of its computed mean
231
+ and variance. The running sum is kept with a default momentum of 0.1.
232
+
233
+ During evaluation, this running mean/variance is used for normalization.
234
+
235
+ Because the BatchNorm is done over the `C` dimension, computing statistics
236
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
237
+
238
+ Args:
239
+ num_features: num_features from an expected input of
240
+ size batch_size x num_features x height x width
241
+ eps: a value added to the denominator for numerical stability.
242
+ Default: 1e-5
243
+ momentum: the value used for the running_mean and running_var
244
+ computation. Default: 0.1
245
+ affine: a boolean value that when set to ``True``, gives the layer learnable
246
+ affine parameters. Default: ``True``
247
+
248
+ Shape:
249
+ - Input: :math:`(N, C, H, W)`
250
+ - Output: :math:`(N, C, H, W)` (same shape as input)
251
+
252
+ Examples:
253
+ >>> # With Learnable Parameters
254
+ >>> m = SynchronizedBatchNorm2d(100)
255
+ >>> # Without Learnable Parameters
256
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
257
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258
+ >>> output = m(input)
259
+ """
260
+
261
+ def _check_input_dim(self, input):
262
+ if input.dim() != 4:
263
+ raise ValueError('expected 4D input (got {}D input)'
264
+ .format(input.dim()))
265
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266
+
267
+
268
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270
+ of 4d inputs
271
+
272
+ .. math::
273
+
274
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275
+
276
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
277
+ standard-deviation are reduced across all devices during training.
278
+
279
+ For example, when one uses `nn.DataParallel` to wrap the network during
280
+ training, PyTorch's implementation normalizes the tensor on each device using
281
+ the statistics only on that device, which accelerates the computation and
282
+ is also easy to implement, but the statistics might be inaccurate.
283
+ Instead, in this synchronized version, the statistics will be computed
284
+ over all training samples distributed on multiple devices.
285
+
286
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
287
+ as the built-in PyTorch implementation.
288
+
289
+ The mean and standard-deviation are calculated per-dimension over
290
+ the mini-batches and gamma and beta are learnable parameter vectors
291
+ of size C (where C is the input size).
292
+
293
+ During training, this layer keeps a running estimate of its computed mean
294
+ and variance. The running sum is kept with a default momentum of 0.1.
295
+
296
+ During evaluation, this running mean/variance is used for normalization.
297
+
298
+ Because the BatchNorm is done over the `C` dimension, computing statistics
299
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300
+ or Spatio-temporal BatchNorm
301
+
302
+ Args:
303
+ num_features: num_features from an expected input of
304
+ size batch_size x num_features x depth x height x width
305
+ eps: a value added to the denominator for numerical stability.
306
+ Default: 1e-5
307
+ momentum: the value used for the running_mean and running_var
308
+ computation. Default: 0.1
309
+ affine: a boolean value that when set to ``True``, gives the layer learnable
310
+ affine parameters. Default: ``True``
311
+
312
+ Shape:
313
+ - Input: :math:`(N, C, D, H, W)`
314
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
315
+
316
+ Examples:
317
+ >>> # With Learnable Parameters
318
+ >>> m = SynchronizedBatchNorm3d(100)
319
+ >>> # Without Learnable Parameters
320
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
321
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322
+ >>> output = m(input)
323
+ """
324
+
325
+ def _check_input_dim(self, input):
326
+ if input.dim() != 5:
327
+ raise ValueError('expected 5D input (got {}D input)'
328
+ .format(input.dim()))
329
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
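Taken together, these classes are drop-in replacements for the corresponding `nn.BatchNorm*d` layers; they only behave differently once the module is replicated. A minimal usage sketch (assuming two visible GPUs and that this package is importable as `lib.nn.modules`; `DataParallelWithCallback` is defined in `replicate.py` below):

```python
# Minimal sketch: BatchNorm statistics synchronized across two GPUs.
# Assumes lib/ is on PYTHONPATH and at least two CUDA devices are visible.
import torch
import torch.nn as nn

from lib.nn.modules.batchnorm import SynchronizedBatchNorm2d
from lib.nn.modules.replicate import DataParallelWithCallback

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),  # statistics reduced over all devices
    nn.ReLU(inplace=True),
).cuda()

# DataParallelWithCallback fires __data_parallel_replicate__ on every replica,
# which is how each copy registers itself with the master's SyncMaster.
net = DataParallelWithCallback(net, device_ids=[0, 1])
out = net(torch.randn(8, 3, 32, 32).cuda())
```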
lib/nn/modules/comm.py ADDED
@@ -0,0 +1,131 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
60
+ call `register_slave(identifier)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
62
+ collected and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine the message to be
64
+ passed back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def register_slave(self, identifier):
79
+ """
80
+ Register a slave device.
81
+
82
+ Args:
83
+ identifier: an identifier, usually is the device id.
84
+
85
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
86
+
87
+ """
88
+ if self._activated:
89
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
90
+ self._activated = False
91
+ self._registry.clear()
92
+ future = FutureResult()
93
+ self._registry[identifier] = _MasterRegistry(future)
94
+ return SlavePipe(identifier, self._queue, future)
95
+
96
+ def run_master(self, master_msg):
97
+ """
98
+ Main entry for the master device in each forward pass.
99
+ The messages are first collected from each device (including the master device), and then
100
+ a callback will be invoked to compute the message to be sent back to each device
101
+ (including the master device).
102
+
103
+ Args:
104
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
105
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106
+
107
+ Returns: the message to be sent back to the master device.
108
+
109
+ """
110
+ self._activated = True
111
+
112
+ intermediates = [(0, master_msg)]
113
+ for i in range(self.nr_slaves):
114
+ intermediates.append(self._queue.get())
115
+
116
+ results = self._master_callback(intermediates)
117
+ assert results[0][0] == 0, 'The first result should belong to the master.'
118
+
119
+ for i, res in results:
120
+ if i == 0:
121
+ continue
122
+ self._registry[i].result.put(res)
123
+
124
+ for i in range(self.nr_slaves):
125
+ assert self._queue.get() is True
126
+
127
+ return results[0][1]
128
+
129
+ @property
130
+ def nr_slaves(self):
131
+ return len(self._registry)
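These three primitives compose into a barrier-style exchange: every slave blocks in `run_slave` until the master's callback has seen all messages. A minimal single-slave sketch using plain threads (no GPUs required; the import path is an assumption):

```python
import threading

from lib.nn.modules.comm import SyncMaster  # import path is an assumption

# The callback receives [(identifier, msg), ...], master (0) first, and must
# return one (identifier, result) pair per device.
def master_callback(intermediates):
    total = sum(msg for _, msg in intermediates)
    return [(i, total) for i, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(1)  # one registration per replica

results = {}

def slave():
    results['slave'] = pipe.run_slave(2)  # blocks until the master replies

t = threading.Thread(target=slave)
t.start()
results['master'] = master.run_master(1)  # collects 1 (master) + 2 (slave)
t.join()
assert results['master'] == results['slave'] == 3
```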
lib/nn/modules/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
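Any sub-module can opt into this mechanism by defining `__data_parallel_replicate__`; that is exactly what `_SynchronizedBatchNorm` does. A hypothetical module using the hook (sketch only; assumes two GPUs):

```python
import torch.nn as nn

from lib.nn.modules.replicate import DataParallelWithCallback  # path is an assumption

class ReplicaAware(nn.Module):
    def forward(self, x):
        return x

    def __data_parallel_replicate__(self, ctx, copy_id):
        # ctx is shared by every copy of this sub-module; copy 0 is the master.
        if copy_id == 0:
            ctx.master_copy = self
        self.copy_id = copy_id

model = DataParallelWithCallback(ReplicaAware().cuda(), device_ids=[0, 1])
```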
lib/nn/modules/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,56 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_numeric_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm.unittest import TorchTestCase
16
+
17
+
18
+ def handy_var(a, unbias=True):
19
+ n = a.size(0)
20
+ asum = a.sum(dim=0)
21
+ as_sum = (a ** 2).sum(dim=0) # a square sum
22
+ sumvar = as_sum - asum * asum / n
23
+ if unbias:
24
+ return sumvar / (n - 1)
25
+ else:
26
+ return sumvar / n
27
+
28
+
29
+ class NumericTestCase(TorchTestCase):
30
+ def testNumericBatchNorm(self):
31
+ a = torch.rand(16, 10)
32
+ bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # 1d: the input below is (N, C)
33
+ bn.train()
34
+
35
+ a_var1 = Variable(a, requires_grad=True)
36
+ b_var1 = bn(a_var1)
37
+ loss1 = b_var1.sum()
38
+ loss1.backward()
39
+
40
+ a_var2 = Variable(a, requires_grad=True)
41
+ a_mean2 = a_var2.mean(dim=0, keepdim=True)
42
+ a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43
+ # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44
+ b_var2 = (a_var2 - a_mean2) / a_std2
45
+ loss2 = b_var2.sum()
46
+ loss2.backward()
47
+
48
+ self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49
+ self.assertTensorClose(bn.running_var, handy_var(a))
50
+ self.assertTensorClose(a_var1.data, a_var2.data)
51
+ self.assertTensorClose(b_var1.data, b_var2.data)
52
+ self.assertTensorClose(a_var1.grad, a_var2.grad)
53
+
54
+
55
+ if __name__ == '__main__':
56
+ unittest.main()
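`handy_var` relies on the sum-of-squares identity `Var(x) = (sum(x^2) - sum(x)^2 / n) / (n - 1)`; a quick standalone check of that identity against `torch.var`:

```python
# Standalone check of the identity used by handy_var above.
import torch

a = torch.rand(16, 10)
n = a.size(0)
sumvar = (a ** 2).sum(dim=0) - a.sum(dim=0) ** 2 / n
assert torch.allclose(sumvar / (n - 1), a.var(dim=0, unbiased=True), atol=1e-5)
```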
lib/nn/modules/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,111 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_sync_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16
+ from sync_batchnorm.unittest import TorchTestCase
17
+
18
+
19
+ def handy_var(a, unbias=True):
20
+ n = a.size(0)
21
+ asum = a.sum(dim=0)
22
+ as_sum = (a ** 2).sum(dim=0) # a square sum
23
+ sumvar = as_sum - asum * asum / n
24
+ if unbias:
25
+ return sumvar / (n - 1)
26
+ else:
27
+ return sumvar / n
28
+
29
+
30
+ def _find_bn(module):
31
+ for m in module.modules():
32
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33
+ return m
34
+
35
+
36
+ class SyncTestCase(TorchTestCase):
37
+ def _syncParameters(self, bn1, bn2):
38
+ bn1.reset_parameters()
39
+ bn2.reset_parameters()
40
+ if bn1.affine and bn2.affine:
41
+ bn2.weight.data.copy_(bn1.weight.data)
42
+ bn2.bias.data.copy_(bn1.bias.data)
43
+
44
+ def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45
+ """Check the forward and backward for the customized batch normalization."""
46
+ bn1.train(mode=is_train)
47
+ bn2.train(mode=is_train)
48
+
49
+ if cuda:
50
+ input = input.cuda()
51
+
52
+ self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53
+
54
+ input1 = Variable(input, requires_grad=True)
55
+ output1 = bn1(input1)
56
+ output1.sum().backward()
57
+ input2 = Variable(input, requires_grad=True)
58
+ output2 = bn2(input2)
59
+ output2.sum().backward()
60
+
61
+ self.assertTensorClose(input1.data, input2.data)
62
+ self.assertTensorClose(output1.data, output2.data)
63
+ self.assertTensorClose(input1.grad, input2.grad)
64
+ self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65
+ self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66
+
67
+ def testSyncBatchNormNormalTrain(self):
68
+ bn = nn.BatchNorm1d(10)
69
+ sync_bn = SynchronizedBatchNorm1d(10)
70
+
71
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72
+
73
+ def testSyncBatchNormNormalEval(self):
74
+ bn = nn.BatchNorm1d(10)
75
+ sync_bn = SynchronizedBatchNorm1d(10)
76
+
77
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78
+
79
+ def testSyncBatchNormSyncTrain(self):
80
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83
+
84
+ bn.cuda()
85
+ sync_bn.cuda()
86
+
87
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88
+
89
+ def testSyncBatchNormSyncEval(self):
90
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93
+
94
+ bn.cuda()
95
+ sync_bn.cuda()
96
+
97
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98
+
99
+ def testSyncBatchNorm2DSyncTrain(self):
100
+ bn = nn.BatchNorm2d(10)
101
+ sync_bn = SynchronizedBatchNorm2d(10)
102
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103
+
104
+ bn.cuda()
105
+ sync_bn.cuda()
106
+
107
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ unittest.main()
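The `*Sync*` cases above require at least two visible GPUs and will otherwise fail at the `DataParallelWithCallback(..., device_ids=[0, 1])` calls. One way to let such tests degrade gracefully on smaller machines (an illustrative sketch, not part of this upload):

```python
import unittest
import torch

# Skip multi-GPU tests instead of erroring when fewer than two devices exist.
@unittest.skipIf(torch.cuda.device_count() < 2, 'requires >= 2 CUDA devices')
class TwoGPUTestCase(unittest.TestCase):
    pass
```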
lib/nn/modules/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol, rtol=rtol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
lib/nn/parallel/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
lib/nn/parallel/data_parallel.py ADDED
@@ -0,0 +1,112 @@
1
+ # -*- coding: utf8 -*-
2
+
3
+ import torch.cuda as cuda
4
+ import torch.nn as nn
5
+ import torch
6
+ import collections
7
+ from torch.nn.parallel._functions import Gather
8
+
9
+
10
+ __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11
+
12
+
13
+ def async_copy_to(obj, dev, main_stream=None):
14
+ if torch.is_tensor(obj):
15
+ v = obj.cuda(dev, non_blocking=True)
16
+ if main_stream is not None:
17
+ v.data.record_stream(main_stream)
18
+ return v
19
+ elif isinstance(obj, collections.Mapping):
20
+ return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21
+ elif isinstance(obj, collections.Sequence):
22
+ return [async_copy_to(o, dev, main_stream) for o in obj]
23
+ else:
24
+ return obj
25
+
26
+
27
+ def dict_gather(outputs, target_device, dim=0):
28
+ """
29
+ Gathers variables from different GPUs on a specified device
30
+ (-1 means the CPU), with dictionary support.
31
+ """
32
+ def gather_map(outputs):
33
+ out = outputs[0]
34
+ if torch.is_tensor(out):
35
+ # MJY(20180330) HACK:: force nr_dims > 0
36
+ if out.dim() == 0:
37
+ outputs = [o.unsqueeze(0) for o in outputs]
38
+ return Gather.apply(target_device, dim, *outputs)
39
+ elif out is None:
40
+ return None
41
+ elif isinstance(out, collections.Mapping):
42
+ return {k: gather_map([o[k] for o in outputs]) for k in out}
43
+ elif isinstance(out, collections.Sequence):
44
+ return type(out)(map(gather_map, zip(*outputs)))
45
+ return gather_map(outputs)
46
+
47
+
48
+ class DictGatherDataParallel(nn.DataParallel):
49
+ def gather(self, outputs, output_device):
50
+ return dict_gather(outputs, output_device, dim=self.dim)
51
+
52
+
53
+ class UserScatteredDataParallel(DictGatherDataParallel):
54
+ def scatter(self, inputs, kwargs, device_ids):
55
+ assert len(inputs) == 1
56
+ inputs = inputs[0]
57
+ inputs = _async_copy_stream(inputs, device_ids)
58
+ inputs = [[i] for i in inputs]
59
+ assert len(kwargs) == 0
60
+ kwargs = [{} for _ in range(len(inputs))]
61
+
62
+ return inputs, kwargs
63
+
64
+
65
+ def user_scattered_collate(batch):
66
+ return batch
67
+
68
+
69
+ def _async_copy(inputs, device_ids):
70
+ nr_devs = len(device_ids)
71
+ assert type(inputs) in (tuple, list)
72
+ assert len(inputs) == nr_devs
73
+
74
+ outputs = []
75
+ for i, dev in zip(inputs, device_ids):
76
+ with cuda.device(dev):
77
+ outputs.append(async_copy_to(i, dev))
78
+
79
+ return tuple(outputs)
80
+
81
+
82
+ def _async_copy_stream(inputs, device_ids):
83
+ nr_devs = len(device_ids)
84
+ assert type(inputs) in (tuple, list)
85
+ assert len(inputs) == nr_devs
86
+
87
+ outputs = []
88
+ streams = [_get_stream(d) for d in device_ids]
89
+ for i, dev, stream in zip(inputs, device_ids, streams):
90
+ with cuda.device(dev):
91
+ main_stream = cuda.current_stream()
92
+ with cuda.stream(stream):
93
+ outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94
+ main_stream.wait_stream(stream)
95
+
96
+ return outputs
97
+
98
+
99
+ """Adapted from: torch/nn/parallel/_functions.py"""
100
+ # background streams used for copying
101
+ _streams = None
102
+
103
+
104
+ def _get_stream(device):
105
+ """Gets a background stream for copying between CPU and GPU"""
106
+ global _streams
107
+ if device == -1:
108
+ return None
109
+ if _streams is None:
110
+ _streams = [None] * cuda.device_count()
111
+ if _streams[device] is None: _streams[device] = cuda.Stream(device)
112
+ return _streams[device]
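Unlike the stock `nn.DataParallel`, `UserScatteredDataParallel` performs no automatic scattering: the caller supplies one already-batched element per device (typically from a `DataLoader` built with `collate_fn=user_scattered_collate`), and `scatter` only does the asynchronous device copies. A minimal sketch, assuming two GPUs:

```python
import torch
import torch.nn as nn

from lib.nn.parallel import UserScatteredDataParallel

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, feed):  # receives this device's element of the input list
        return self.fc(feed['x'])

model = UserScatteredDataParallel(Net().cuda(), device_ids=[0, 1])

batches = [{'x': torch.randn(8, 4)} for _ in range(2)]  # one dict per device
out = model(batches)  # dict_gather concatenates the outputs to (16, 2) on GPU 0
```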
lib/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .th import *
lib/utils/data/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+ from .dataset import Dataset, TensorDataset, ConcatDataset
3
+ from .dataloader import DataLoader
lib/utils/data/dataloader.py ADDED
@@ -0,0 +1,422 @@
1
+ import torch
2
+ import torch.multiprocessing as multiprocessing
3
+ from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
4
+ _remove_worker_pids, _error_if_any_worker_fails
5
+ from .sampler import SequentialSampler, RandomSampler, BatchSampler
6
+ import signal
7
+ import functools
8
+ import collections
9
+ import re
10
+ import sys
11
+ import threading
12
+ import traceback
13
+ from torch._six import string_classes, int_classes
14
+ import numpy as np
15
+
16
+ if sys.version_info[0] == 2:
17
+ import Queue as queue
18
+ else:
19
+ import queue
20
+
21
+
22
+ class ExceptionWrapper(object):
23
+ r"Wraps an exception plus traceback to communicate across threads"
24
+
25
+ def __init__(self, exc_info):
26
+ self.exc_type = exc_info[0]
27
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
28
+
29
+
30
+ _use_shared_memory = False
31
+ """Whether to use shared memory in default_collate"""
32
+
33
+
34
+ def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
35
+ global _use_shared_memory
36
+ _use_shared_memory = True
37
+
38
+ # Initialize C-side signal handlers for SIGBUS and SIGSEGV. Python signal
39
+ # module's handlers are executed after Python returns from C low-level
40
+ # handlers, likely when the same fatal signal happened again already.
41
+ # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
42
+ _set_worker_signal_handlers()
43
+
44
+ torch.set_num_threads(1)
45
+ torch.manual_seed(seed)
46
+ np.random.seed(seed)
47
+
48
+ if init_fn is not None:
49
+ init_fn(worker_id)
50
+
51
+ while True:
52
+ r = index_queue.get()
53
+ if r is None:
54
+ break
55
+ idx, batch_indices = r
56
+ try:
57
+ samples = collate_fn([dataset[i] for i in batch_indices])
58
+ except Exception:
59
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
60
+ else:
61
+ data_queue.put((idx, samples))
62
+
63
+
64
+ def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
65
+ if pin_memory:
66
+ torch.cuda.set_device(device_id)
67
+
68
+ while True:
69
+ try:
70
+ r = in_queue.get()
71
+ except Exception:
72
+ if done_event.is_set():
73
+ return
74
+ raise
75
+ if r is None:
76
+ break
77
+ if isinstance(r[1], ExceptionWrapper):
78
+ out_queue.put(r)
79
+ continue
80
+ idx, batch = r
81
+ try:
82
+ if pin_memory:
83
+ batch = pin_memory_batch(batch)
84
+ except Exception:
85
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
86
+ else:
87
+ out_queue.put((idx, batch))
88
+
89
+ numpy_type_map = {
90
+ 'float64': torch.DoubleTensor,
91
+ 'float32': torch.FloatTensor,
92
+ 'float16': torch.HalfTensor,
93
+ 'int64': torch.LongTensor,
94
+ 'int32': torch.IntTensor,
95
+ 'int16': torch.ShortTensor,
96
+ 'int8': torch.CharTensor,
97
+ 'uint8': torch.ByteTensor,
98
+ }
99
+
100
+
101
+ def default_collate(batch):
102
+ "Puts each data field into a tensor with outer dimension batch size"
103
+
104
+ error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
105
+ elem_type = type(batch[0])
106
+ if torch.is_tensor(batch[0]):
107
+ out = None
108
+ if _use_shared_memory:
109
+ # If we're in a background process, concatenate directly into a
110
+ # shared memory tensor to avoid an extra copy
111
+ numel = sum([x.numel() for x in batch])
112
+ storage = batch[0].storage()._new_shared(numel)
113
+ out = batch[0].new(storage)
114
+ return torch.stack(batch, 0, out=out)
115
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
116
+ and elem_type.__name__ != 'string_':
117
+ elem = batch[0]
118
+ if elem_type.__name__ == 'ndarray':
119
+ # array of string classes and object
120
+ if re.search('[SaUO]', elem.dtype.str) is not None:
121
+ raise TypeError(error_msg.format(elem.dtype))
122
+
123
+ return torch.stack([torch.from_numpy(b) for b in batch], 0)
124
+ if elem.shape == (): # scalars
125
+ py_type = float if elem.dtype.name.startswith('float') else int
126
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
127
+ elif isinstance(batch[0], int_classes):
128
+ return torch.LongTensor(batch)
129
+ elif isinstance(batch[0], float):
130
+ return torch.DoubleTensor(batch)
131
+ elif isinstance(batch[0], string_classes):
132
+ return batch
133
+ elif isinstance(batch[0], collections.Mapping):
134
+ return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
135
+ elif isinstance(batch[0], collections.Sequence):
136
+ transposed = zip(*batch)
137
+ return [default_collate(samples) for samples in transposed]
138
+
139
+ raise TypeError((error_msg.format(type(batch[0]))))
140
+
141
+
142
+ def pin_memory_batch(batch):
143
+ if torch.is_tensor(batch):
144
+ return batch.pin_memory()
145
+ elif isinstance(batch, string_classes):
146
+ return batch
147
+ elif isinstance(batch, collections.Mapping):
148
+ return {k: pin_memory_batch(sample) for k, sample in batch.items()}
149
+ elif isinstance(batch, collections.Sequence):
150
+ return [pin_memory_batch(sample) for sample in batch]
151
+ else:
152
+ return batch
153
+
154
+
155
+ _SIGCHLD_handler_set = False
156
+ """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
157
+ handler needs to be set for all DataLoaders in a process."""
158
+
159
+
160
+ def _set_SIGCHLD_handler():
161
+ # Windows doesn't support SIGCHLD handler
162
+ if sys.platform == 'win32':
163
+ return
164
+ # can't set signal in child threads
165
+ if not isinstance(threading.current_thread(), threading._MainThread):
166
+ return
167
+ global _SIGCHLD_handler_set
168
+ if _SIGCHLD_handler_set:
169
+ return
170
+ previous_handler = signal.getsignal(signal.SIGCHLD)
171
+ if not callable(previous_handler):
172
+ previous_handler = None
173
+
174
+ def handler(signum, frame):
175
+ # This following call uses `waitid` with WNOHANG from C side. Therefore,
176
+ # Python can still get and update the process status successfully.
177
+ _error_if_any_worker_fails()
178
+ if previous_handler is not None:
179
+ previous_handler(signum, frame)
180
+
181
+ signal.signal(signal.SIGCHLD, handler)
182
+ _SIGCHLD_handler_set = True
183
+
184
+
185
+ class DataLoaderIter(object):
186
+ "Iterates once over the DataLoader's dataset, as specified by the sampler"
187
+
188
+ def __init__(self, loader):
189
+ self.dataset = loader.dataset
190
+ self.collate_fn = loader.collate_fn
191
+ self.batch_sampler = loader.batch_sampler
192
+ self.num_workers = loader.num_workers
193
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
194
+ self.timeout = loader.timeout
195
+ self.done_event = threading.Event()
196
+
197
+ self.sample_iter = iter(self.batch_sampler)
198
+
199
+ if self.num_workers > 0:
200
+ self.worker_init_fn = loader.worker_init_fn
201
+ self.index_queue = multiprocessing.SimpleQueue()
202
+ self.worker_result_queue = multiprocessing.SimpleQueue()
203
+ self.batches_outstanding = 0
204
+ self.worker_pids_set = False
205
+ self.shutdown = False
206
+ self.send_idx = 0
207
+ self.rcvd_idx = 0
208
+ self.reorder_dict = {}
209
+
210
+ base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
211
+ self.workers = [
212
+ multiprocessing.Process(
213
+ target=_worker_loop,
214
+ args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
215
+ base_seed + i, self.worker_init_fn, i))
216
+ for i in range(self.num_workers)]
217
+
218
+ if self.pin_memory or self.timeout > 0:
219
+ self.data_queue = queue.Queue()
220
+ if self.pin_memory:
221
+ maybe_device_id = torch.cuda.current_device()
222
+ else:
223
+ # do not initialize cuda context if not necessary
224
+ maybe_device_id = None
225
+ self.worker_manager_thread = threading.Thread(
226
+ target=_worker_manager_loop,
227
+ args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
228
+ maybe_device_id))
229
+ self.worker_manager_thread.daemon = True
230
+ self.worker_manager_thread.start()
231
+ else:
232
+ self.data_queue = self.worker_result_queue
233
+
234
+ for w in self.workers:
235
+ w.daemon = True # ensure that the worker exits on process exit
236
+ w.start()
237
+
238
+ _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
239
+ _set_SIGCHLD_handler()
240
+ self.worker_pids_set = True
241
+
242
+ # prime the prefetch loop
243
+ for _ in range(2 * self.num_workers):
244
+ self._put_indices()
245
+
246
+ def __len__(self):
247
+ return len(self.batch_sampler)
248
+
249
+ def _get_batch(self):
250
+ if self.timeout > 0:
251
+ try:
252
+ return self.data_queue.get(timeout=self.timeout)
253
+ except queue.Empty:
254
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
255
+ else:
256
+ return self.data_queue.get()
257
+
258
+ def __next__(self):
259
+ if self.num_workers == 0: # same-process loading
260
+ indices = next(self.sample_iter) # may raise StopIteration
261
+ batch = self.collate_fn([self.dataset[i] for i in indices])
262
+ if self.pin_memory:
263
+ batch = pin_memory_batch(batch)
264
+ return batch
265
+
266
+ # check if the next sample has already been generated
267
+ if self.rcvd_idx in self.reorder_dict:
268
+ batch = self.reorder_dict.pop(self.rcvd_idx)
269
+ return self._process_next_batch(batch)
270
+
271
+ if self.batches_outstanding == 0:
272
+ self._shutdown_workers()
273
+ raise StopIteration
274
+
275
+ while True:
276
+ assert (not self.shutdown and self.batches_outstanding > 0)
277
+ idx, batch = self._get_batch()
278
+ self.batches_outstanding -= 1
279
+ if idx != self.rcvd_idx:
280
+ # store out-of-order samples
281
+ self.reorder_dict[idx] = batch
282
+ continue
283
+ return self._process_next_batch(batch)
284
+
285
+ next = __next__ # Python 2 compatibility
286
+
287
+ def __iter__(self):
288
+ return self
289
+
290
+ def _put_indices(self):
291
+ assert self.batches_outstanding < 2 * self.num_workers
292
+ indices = next(self.sample_iter, None)
293
+ if indices is None:
294
+ return
295
+ self.index_queue.put((self.send_idx, indices))
296
+ self.batches_outstanding += 1
297
+ self.send_idx += 1
298
+
299
+ def _process_next_batch(self, batch):
300
+ self.rcvd_idx += 1
301
+ self._put_indices()
302
+ if isinstance(batch, ExceptionWrapper):
303
+ raise batch.exc_type(batch.exc_msg)
304
+ return batch
305
+
306
+ def __getstate__(self):
307
+ # TODO: add limited pickling support for sharing an iterator
308
+ # across multiple threads for HOGWILD.
309
+ # Probably the best way to do this is by moving the sample pushing
310
+ # to a separate thread and then just sharing the data queue
311
+ # but signalling the end is tricky without a non-blocking API
312
+ raise NotImplementedError("DataLoaderIterator cannot be pickled")
313
+
314
+ def _shutdown_workers(self):
315
+ try:
316
+ if not self.shutdown:
317
+ self.shutdown = True
318
+ self.done_event.set()
319
+ # if worker_manager_thread is waiting to put
320
+ while not self.data_queue.empty():
321
+ self.data_queue.get()
322
+ for _ in self.workers:
323
+ self.index_queue.put(None)
324
+ # done_event should be sufficient to exit worker_manager_thread,
325
+ # but be safe here and put another None
326
+ self.worker_result_queue.put(None)
327
+ finally:
328
+ # removes pids no matter what
329
+ if self.worker_pids_set:
330
+ _remove_worker_pids(id(self))
331
+ self.worker_pids_set = False
332
+
333
+ def __del__(self):
334
+ if self.num_workers > 0:
335
+ self._shutdown_workers()
336
+
337
+
338
+ class DataLoader(object):
339
+ """
340
+ Data loader. Combines a dataset and a sampler, and provides
341
+ single- or multi-process iterators over the dataset.
342
+
343
+ Arguments:
344
+ dataset (Dataset): dataset from which to load the data.
345
+ batch_size (int, optional): how many samples per batch to load
346
+ (default: 1).
347
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
348
+ at every epoch (default: False).
349
+ sampler (Sampler, optional): defines the strategy to draw samples from
350
+ the dataset. If specified, ``shuffle`` must be False.
351
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
352
+ indices at a time. Mutually exclusive with batch_size, shuffle,
353
+ sampler, and drop_last.
354
+ num_workers (int, optional): how many subprocesses to use for data
355
+ loading. 0 means that the data will be loaded in the main process.
356
+ (default: 0)
357
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
358
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
359
+ into CUDA pinned memory before returning them.
360
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
361
+ if the dataset size is not divisible by the batch size. If ``False`` and
362
+ the size of dataset is not divisible by the batch size, then the last batch
363
+ will be smaller. (default: False)
364
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
365
+ from workers. Should always be non-negative. (default: 0)
366
+ worker_init_fn (callable, optional): If not None, this will be called on each
367
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
368
+ input, after seeding and before data loading. (default: None)
369
+
370
+ .. note:: By default, each worker will have its PyTorch seed set to
371
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
372
+ by main process using its RNG. You may use ``torch.initial_seed()`` to access
373
+ this value in :attr:`worker_init_fn`, which can be used to set other seeds
374
+ (e.g. NumPy) before data loading.
375
+
376
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
377
+ unpicklable object, e.g., a lambda function.
378
+ """
379
+
380
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
381
+ num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
382
+ timeout=0, worker_init_fn=None):
383
+ self.dataset = dataset
384
+ self.batch_size = batch_size
385
+ self.num_workers = num_workers
386
+ self.collate_fn = collate_fn
387
+ self.pin_memory = pin_memory
388
+ self.drop_last = drop_last
389
+ self.timeout = timeout
390
+ self.worker_init_fn = worker_init_fn
391
+
392
+ if timeout < 0:
393
+ raise ValueError('timeout option should be non-negative')
394
+
395
+ if batch_sampler is not None:
396
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
397
+ raise ValueError('batch_sampler is mutually exclusive with '
398
+ 'batch_size, shuffle, sampler, and drop_last')
399
+
400
+ if sampler is not None and shuffle:
401
+ raise ValueError('sampler is mutually exclusive with shuffle')
402
+
403
+ if self.num_workers < 0:
404
+ raise ValueError('num_workers cannot be negative; '
405
+ 'use num_workers=0 to disable multiprocessing.')
406
+
407
+ if batch_sampler is None:
408
+ if sampler is None:
409
+ if shuffle:
410
+ sampler = RandomSampler(dataset)
411
+ else:
412
+ sampler = SequentialSampler(dataset)
413
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
414
+
415
+ self.sampler = sampler
416
+ self.batch_sampler = batch_sampler
417
+
418
+ def __iter__(self):
419
+ return DataLoaderIter(self)
420
+
421
+ def __len__(self):
422
+ return len(self.batch_sampler)
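Note that the vendored `_worker_loop` above already seeds both torch and NumPy with `base_seed + worker_id`, so `worker_init_fn` is only needed for other sources of randomness. A minimal sketch using the classes from this diff (`TensorDataset` is defined in `dataset.py` below; import paths are assumptions):

```python
import random

import torch

from lib.utils.data import DataLoader
from lib.utils.data.dataset import TensorDataset

ds = TensorDataset(torch.randn(100, 4), torch.zeros(100))

def seed_stdlib(worker_id):
    # Reuse the per-worker torch seed for Python's stdlib RNG.
    random.seed(torch.initial_seed() % 2 ** 32)

loader = DataLoader(ds, batch_size=16, shuffle=True,
                    num_workers=2, worker_init_fn=seed_stdlib)
for data, target in loader:
    pass  # each worker process is seeded deterministically
```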
lib/utils/data/dataset.py ADDED
@@ -0,0 +1,118 @@
1
+ import bisect
2
+ import warnings
3
+
4
+ from torch._utils import _accumulate
5
+ from torch import randperm
6
+
7
+
8
+ class Dataset(object):
9
+ """An abstract class representing a Dataset.
10
+
11
+ All other datasets should subclass it. All subclasses should override
12
+ ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13
+ supporting integer indexing in range from 0 to len(self) exclusive.
14
+ """
15
+
16
+ def __getitem__(self, index):
17
+ raise NotImplementedError
18
+
19
+ def __len__(self):
20
+ raise NotImplementedError
21
+
22
+ def __add__(self, other):
23
+ return ConcatDataset([self, other])
24
+
25
+
26
+ class TensorDataset(Dataset):
27
+ """Dataset wrapping data and target tensors.
28
+
29
+ Each sample will be retrieved by indexing both tensors along the first
30
+ dimension.
31
+
32
+ Arguments:
33
+ data_tensor (Tensor): contains sample data.
34
+ target_tensor (Tensor): contains sample targets (labels).
35
+ """
36
+
37
+ def __init__(self, data_tensor, target_tensor):
38
+ assert data_tensor.size(0) == target_tensor.size(0)
39
+ self.data_tensor = data_tensor
40
+ self.target_tensor = target_tensor
41
+
42
+ def __getitem__(self, index):
43
+ return self.data_tensor[index], self.target_tensor[index]
44
+
45
+ def __len__(self):
46
+ return self.data_tensor.size(0)
47
+
48
+
49
+ class ConcatDataset(Dataset):
50
+ """
51
+ Dataset to concatenate multiple datasets.
52
+ Purpose: useful to assemble different existing datasets, possibly
53
+ large-scale datasets as the concatenation operation is done in an
54
+ on-the-fly manner.
55
+
56
+ Arguments:
57
+ datasets (iterable): List of datasets to be concatenated
58
+ """
59
+
60
+ @staticmethod
61
+ def cumsum(sequence):
62
+ r, s = [], 0
63
+ for e in sequence:
64
+ l = len(e)
65
+ r.append(l + s)
66
+ s += l
67
+ return r
68
+
69
+ def __init__(self, datasets):
70
+ super(ConcatDataset, self).__init__()
71
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
72
+ self.datasets = list(datasets)
73
+ self.cumulative_sizes = self.cumsum(self.datasets)
74
+
75
+ def __len__(self):
76
+ return self.cumulative_sizes[-1]
77
+
78
+ def __getitem__(self, idx):
79
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80
+ if dataset_idx == 0:
81
+ sample_idx = idx
82
+ else:
83
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84
+ return self.datasets[dataset_idx][sample_idx]
85
+
86
+ @property
87
+ def cummulative_sizes(self):
88
+ warnings.warn("cummulative_sizes attribute is renamed to "
89
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
90
+ return self.cumulative_sizes
91
+
92
+
93
+ class Subset(Dataset):
94
+ def __init__(self, dataset, indices):
95
+ self.dataset = dataset
96
+ self.indices = indices
97
+
98
+ def __getitem__(self, idx):
99
+ return self.dataset[self.indices[idx]]
100
+
101
+ def __len__(self):
102
+ return len(self.indices)
103
+
104
+
105
+ def random_split(dataset, lengths):
106
+ """
107
+ Randomly split a dataset into non-overlapping new datasets of the given
108
+ lengths.
109
+
110
+ Arguments:
111
+ dataset (Dataset): Dataset to be split
112
+ lengths (iterable): lengths of splits to be produced
113
+ """
114
+ if sum(lengths) != len(dataset):
115
+ raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116
+
117
+ indices = randperm(sum(lengths))
118
+ return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
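A short sketch of how `ConcatDataset` maps a global index through `bisect_right`, together with `random_split` (import path is an assumption):

```python
import torch

from lib.utils.data.dataset import TensorDataset, ConcatDataset, random_split

a = TensorDataset(torch.randn(60, 3), torch.zeros(60))
b = TensorDataset(torch.randn(40, 3), torch.ones(40))

both = ConcatDataset([a, b])  # cumulative_sizes == [60, 100]
assert len(both) == 100
x, y = both[75]               # bisect_right maps global index 75 to b[15]
assert float(y) == 1.0

train, val = random_split(both, [80, 20])  # non-overlapping random subsets
assert len(train) == 80 and len(val) == 20
```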
lib/utils/data/distributed.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+ import torch
3
+ from .sampler import Sampler
4
+ from torch.distributed import get_world_size, get_rank
5
+
6
+
7
+ class DistributedSampler(Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset.
9
+
10
+ It is especially useful in conjunction with
11
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
12
+ process can pass a DistributedSampler instance as a DataLoader sampler,
13
+ and load a subset of the original dataset that is exclusive to it.
14
+
15
+ .. note::
16
+ Dataset is assumed to be of constant size.
17
+
18
+ Arguments:
19
+ dataset: Dataset used for sampling.
20
+ num_replicas (optional): Number of processes participating in
21
+ distributed training.
22
+ rank (optional): Rank of the current process within num_replicas.
23
+ """
24
+
25
+ def __init__(self, dataset, num_replicas=None, rank=None):
26
+ if num_replicas is None:
27
+ num_replicas = get_world_size()
28
+ if rank is None:
29
+ rank = get_rank()
30
+ self.dataset = dataset
31
+ self.num_replicas = num_replicas
32
+ self.rank = rank
33
+ self.epoch = 0
34
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35
+ self.total_size = self.num_samples * self.num_replicas
36
+
37
+ def __iter__(self):
38
+ # deterministically shuffle based on epoch
39
+ g = torch.Generator()
40
+ g.manual_seed(self.epoch)
41
+ indices = list(torch.randperm(len(self.dataset), generator=g))
42
+
43
+ # add extra samples to make it evenly divisible
44
+ indices += indices[:(self.total_size - len(indices))]
45
+ assert len(indices) == self.total_size
46
+
47
+ # subsample
48
+ offset = self.num_samples * self.rank
49
+ indices = indices[offset:offset + self.num_samples]
50
+ assert len(indices) == self.num_samples
51
+
52
+ return iter(indices)
53
+
54
+ def __len__(self):
55
+ return self.num_samples
56
+
57
+ def set_epoch(self, epoch):
58
+ self.epoch = epoch
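A typical loop with this sampler; it assumes `torch.distributed` is already initialized (so `get_world_size()`/`get_rank()` resolve) and calls `set_epoch` so that every process reshuffles identically each epoch:

```python
import torch

from lib.utils.data import DataLoader
from lib.utils.data.dataset import TensorDataset
from lib.utils.data.distributed import DistributedSampler  # paths are assumptions

ds = TensorDataset(torch.randn(1000, 8), torch.zeros(1000))
sampler = DistributedSampler(ds)  # rank / world size read from torch.distributed
loader = DataLoader(ds, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # same shuffle on every rank, different each epoch
    for data, target in loader:
        pass
```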
lib/utils/data/sampler.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+
3
+
4
+ class Sampler(object):
5
+ """Base class for all Samplers.
6
+
7
+ Every Sampler subclass has to provide an __iter__ method, providing a way
8
+ to iterate over indices of dataset elements, and a __len__ method that
9
+ returns the length of the returned iterator.
10
+ """
11
+
12
+ def __init__(self, data_source):
13
+ pass
14
+
15
+ def __iter__(self):
16
+ raise NotImplementedError
17
+
18
+ def __len__(self):
19
+ raise NotImplementedError
20
+
21
+
22
+ class SequentialSampler(Sampler):
23
+ """Samples elements sequentially, always in the same order.
24
+
25
+ Arguments:
26
+ data_source (Dataset): dataset to sample from
27
+ """
28
+
29
+ def __init__(self, data_source):
30
+ self.data_source = data_source
31
+
32
+ def __iter__(self):
33
+ return iter(range(len(self.data_source)))
34
+
35
+ def __len__(self):
36
+ return len(self.data_source)
37
+
38
+
39
+ class RandomSampler(Sampler):
40
+ """Samples elements randomly, without replacement.
41
+
42
+ Arguments:
43
+ data_source (Dataset): dataset to sample from
44
+ """
45
+
46
+ def __init__(self, data_source):
47
+ self.data_source = data_source
48
+
49
+ def __iter__(self):
50
+ return iter(torch.randperm(len(self.data_source)).long())
51
+
52
+ def __len__(self):
53
+ return len(self.data_source)
54
+
55
+
56
+ class SubsetRandomSampler(Sampler):
57
+ """Samples elements randomly from a given list of indices, without replacement.
58
+
59
+ Arguments:
60
+ indices (list): a list of indices
61
+ """
62
+
63
+ def __init__(self, indices):
64
+ self.indices = indices
65
+
66
+ def __iter__(self):
67
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
68
+
69
+ def __len__(self):
70
+ return len(self.indices)
71
+
72
+
73
+ class WeightedRandomSampler(Sampler):
74
+ """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75
+
76
+ Arguments:
77
+ weights (list): a list of weights, not necessarily summing up to one
78
+ num_samples (int): number of samples to draw
79
+ replacement (bool): if ``True``, samples are drawn with replacement.
80
+ If not, they are drawn without replacement, which means that when a
81
+ sample index is drawn for a row, it cannot be drawn again for that row.
82
+ """
83
+
84
+ def __init__(self, weights, num_samples, replacement=True):
85
+ self.weights = torch.DoubleTensor(weights)
86
+ self.num_samples = num_samples
87
+ self.replacement = replacement
88
+
89
+ def __iter__(self):
90
+ return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91
+
92
+ def __len__(self):
93
+ return self.num_samples
94
+
95
+
96
+ class BatchSampler(object):
97
+ """Wraps another sampler to yield a mini-batch of indices.
98
+
99
+ Args:
100
+ sampler (Sampler): Base sampler.
101
+ batch_size (int): Size of mini-batch.
102
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
103
+ its size would be less than ``batch_size``
104
+
105
+ Example:
106
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110
+ """
111
+
112
+ def __init__(self, sampler, batch_size, drop_last):
113
+ self.sampler = sampler
114
+ self.batch_size = batch_size
115
+ self.drop_last = drop_last
116
+
117
+ def __iter__(self):
118
+ batch = []
119
+ for idx in self.sampler:
120
+ batch.append(idx)
121
+ if len(batch) == self.batch_size:
122
+ yield batch
123
+ batch = []
124
+ if len(batch) > 0 and not self.drop_last:
125
+ yield batch
126
+
127
+ def __len__(self):
128
+ if self.drop_last:
129
+ return len(self.sampler) // self.batch_size
130
+ else:
131
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
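Finally, a small sketch composing the samplers above (import path is an assumption):

```python
import torch

from lib.utils.data.sampler import RandomSampler, BatchSampler, WeightedRandomSampler

data = torch.arange(10)

# A random permutation of all ten indices, grouped into batches of three.
batches = list(BatchSampler(RandomSampler(data), batch_size=3, drop_last=False))
assert [len(b) for b in batches] == [3, 3, 3, 1]

# Index 0 is drawn roughly twice as often as indices 1 and 2.
draws = list(WeightedRandomSampler([2.0, 1.0, 1.0], num_samples=6, replacement=True))
assert len(draws) == 6
```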