huqiming513
commited on
Commit
•
2a2ae9a
1
Parent(s):
03b684c
Upload 13 files
Browse files- checkpoints/FuseNet_CA_MEF_251.pth +3 -0
- checkpoints/FuseNet_FD_297.pth +3 -0
- checkpoints/IANet_335.pth +3 -0
- checkpoints/NSNet_422.pth +3 -0
- datasets/__init__.py +4 -0
- datasets/low_light.py +171 -0
- datasets/low_light_test.py +41 -0
- datasets/mef.py +36 -0
- tools/TensorboardWriter.py +16 -0
- tools/__init__.py +4 -0
- tools/model_utils.py +11 -0
- tools/mutils.py +30 -0
- tools/saver.py +55 -0
checkpoints/FuseNet_CA_MEF_251.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c44373240c845c4a79ad49a7aa3436a53917d012165fc5a94843fc70174b25aa
|
3 |
+
size 895263
|
checkpoints/FuseNet_FD_297.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c2b7d29e8b90ac3f6d386d1d6c927a01ff24a95e14f79fa900d8f8fa45647237
|
3 |
+
size 892959
|
checkpoints/IANet_335.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a551ba226137d681e851b91c6717ce51e731b7050fd68143f05a036e5e8e70a
|
3 |
+
size 3409487
|
checkpoints/NSNet_422.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ce8998e5c9b5b3eb5d0d582cfb5874213073b247c2a577ab3416ed42d9d2a72
|
3 |
+
size 3410639
|
datasets/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .low_light import *
|
2 |
+
from .low_light_test import *
|
3 |
+
from .mef import *
|
4 |
+
|
datasets/low_light.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.utils.data as data
|
6 |
+
import torchvision.transforms as T
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
class LowLightFDataset(data.Dataset):
|
11 |
+
def __init__(self, root, image_split='images_aug', targets_split='targets', training=True):
|
12 |
+
self.root = root
|
13 |
+
self.num_instances = 8
|
14 |
+
self.img_root = os.path.join(root, image_split)
|
15 |
+
self.target_root = os.path.join(root, targets_split)
|
16 |
+
self.training = training
|
17 |
+
print('----', image_split, targets_split, '----')
|
18 |
+
self.imgs = list(sorted(os.listdir(self.img_root)))
|
19 |
+
self.gts = list(sorted(os.listdir(self.target_root)))
|
20 |
+
|
21 |
+
names = [img_name.split('_')[0] + '.' + img_name.split('.')[-1] for img_name in self.imgs]
|
22 |
+
self.imgs = list(
|
23 |
+
filter(lambda img_name: img_name.split('_')[0] + '.' + img_name.split('.')[-1] in self.gts, self.imgs))
|
24 |
+
|
25 |
+
self.gts = list(filter(lambda gt: gt in names, self.gts))
|
26 |
+
|
27 |
+
print(len(self.imgs), len(self.gts))
|
28 |
+
self.preproc = T.Compose(
|
29 |
+
[T.ToTensor()]
|
30 |
+
)
|
31 |
+
self.preproc_gt = T.Compose(
|
32 |
+
[T.ToTensor()]
|
33 |
+
)
|
34 |
+
|
35 |
+
def __getitem__(self, idx):
|
36 |
+
fn, ext = self.gts[idx].split('.')
|
37 |
+
imgs = []
|
38 |
+
for i in range(self.num_instances):
|
39 |
+
img_path = os.path.join(self.img_root, f"{fn}_{i}.{ext}")
|
40 |
+
imgs += [self.preproc(Image.open(img_path).convert("RGB"))]
|
41 |
+
|
42 |
+
if self.training:
|
43 |
+
random.shuffle(imgs)
|
44 |
+
gt_path = os.path.join(self.target_root, self.gts[idx])
|
45 |
+
gt = Image.open(gt_path).convert("RGB")
|
46 |
+
gt = self.preproc_gt(gt)
|
47 |
+
|
48 |
+
# print(img_path, gt_path)
|
49 |
+
return torch.stack(imgs, dim=0), gt, fn
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return len(self.gts)
|
53 |
+
|
54 |
+
|
55 |
+
class LowLightFDatasetEval(data.Dataset):
|
56 |
+
def __init__(self, root, targets_split='targets', training=True):
|
57 |
+
self.root = root
|
58 |
+
self.num_instances = 1
|
59 |
+
self.img_root = os.path.join(root, 'images')
|
60 |
+
self.target_root = os.path.join(root, targets_split)
|
61 |
+
self.training = training
|
62 |
+
|
63 |
+
self.imgs = list(sorted(os.listdir(self.img_root)))
|
64 |
+
self.gts = list(sorted(os.listdir(self.target_root)))
|
65 |
+
|
66 |
+
self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
|
67 |
+
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))
|
68 |
+
|
69 |
+
print(len(self.imgs), len(self.gts))
|
70 |
+
self.preproc = T.Compose(
|
71 |
+
[T.ToTensor()]
|
72 |
+
)
|
73 |
+
self.preproc_gt = T.Compose(
|
74 |
+
[T.ToTensor()]
|
75 |
+
)
|
76 |
+
|
77 |
+
def __getitem__(self, idx):
|
78 |
+
fn, ext = self.gts[idx].split('.')
|
79 |
+
imgs = []
|
80 |
+
for i in range(self.num_instances):
|
81 |
+
img_path = os.path.join(self.img_root, f"{fn}.{ext}")
|
82 |
+
imgs += [self.preproc(Image.open(img_path).convert("RGB"))]
|
83 |
+
|
84 |
+
gt_path = os.path.join(self.target_root, self.gts[idx])
|
85 |
+
gt = Image.open(gt_path).convert("RGB")
|
86 |
+
gt = self.preproc_gt(gt)
|
87 |
+
|
88 |
+
# print(img_path, gt_path)
|
89 |
+
return torch.stack(imgs, dim=0), gt, fn
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
return len(self.gts)
|
93 |
+
|
94 |
+
|
95 |
+
class LowLightDataset(data.Dataset):
|
96 |
+
def __init__(self, root, targets_split='targets', color_tuning=False):
|
97 |
+
self.root = root
|
98 |
+
self.img_root = os.path.join(root, 'images')
|
99 |
+
self.target_root = os.path.join(root, targets_split)
|
100 |
+
self.color_tuning = color_tuning
|
101 |
+
self.imgs = list(sorted(os.listdir(self.img_root)))
|
102 |
+
self.gts = list(sorted(os.listdir(self.target_root)))
|
103 |
+
|
104 |
+
self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
|
105 |
+
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))
|
106 |
+
|
107 |
+
print(len(self.imgs), len(self.gts))
|
108 |
+
self.preproc = T.Compose(
|
109 |
+
[T.ToTensor()]
|
110 |
+
)
|
111 |
+
self.preproc_gt = T.Compose(
|
112 |
+
[T.ToTensor()]
|
113 |
+
)
|
114 |
+
|
115 |
+
def __getitem__(self, idx):
|
116 |
+
fn, ext = self.gts[idx].split('.')
|
117 |
+
|
118 |
+
img_path = os.path.join(self.img_root, self.imgs[idx])
|
119 |
+
img = Image.open(img_path).convert("RGB")
|
120 |
+
img = self.preproc(img)
|
121 |
+
|
122 |
+
gt_path = os.path.join(self.target_root, self.gts[idx])
|
123 |
+
gt = Image.open(gt_path).convert("RGB")
|
124 |
+
gt = self.preproc_gt(gt)
|
125 |
+
|
126 |
+
if self.color_tuning:
|
127 |
+
return img, gt, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
|
128 |
+
else:
|
129 |
+
return img, gt, fn
|
130 |
+
|
131 |
+
def __len__(self):
|
132 |
+
return len(self.imgs)
|
133 |
+
|
134 |
+
|
135 |
+
class LowLightDatasetReverse(data.Dataset):
|
136 |
+
def __init__(self, root, targets_split='targets', color_tuning=False):
|
137 |
+
self.root = root
|
138 |
+
self.img_root = os.path.join(root, 'images')
|
139 |
+
self.target_root = os.path.join(root, targets_split)
|
140 |
+
self.color_tuning = color_tuning
|
141 |
+
self.imgs = list(sorted(os.listdir(self.img_root)))
|
142 |
+
self.gts = list(sorted(os.listdir(self.target_root)))
|
143 |
+
|
144 |
+
self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
|
145 |
+
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))
|
146 |
+
|
147 |
+
print(len(self.imgs), len(self.gts))
|
148 |
+
self.preproc = T.Compose(
|
149 |
+
[T.ToTensor()]
|
150 |
+
)
|
151 |
+
self.preproc_gt = T.Compose(
|
152 |
+
[T.ToTensor()]
|
153 |
+
)
|
154 |
+
|
155 |
+
def __getitem__(self, idx):
|
156 |
+
img_path = os.path.join(self.img_root, self.imgs[idx])
|
157 |
+
img = Image.open(img_path).convert("RGB")
|
158 |
+
img = self.preproc(img)
|
159 |
+
|
160 |
+
gt_path = os.path.join(self.target_root, self.gts[idx])
|
161 |
+
gt = Image.open(gt_path).convert("RGB")
|
162 |
+
gt = self.preproc_gt(gt)
|
163 |
+
|
164 |
+
if self.color_tuning:
|
165 |
+
return gt, img, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
|
166 |
+
else:
|
167 |
+
fn, ext = os.path.splitext(self.imgs[idx])
|
168 |
+
return gt, img, '%03d' % int(fn) + ext
|
169 |
+
|
170 |
+
def __len__(self):
|
171 |
+
return len(self.imgs)
|
datasets/low_light_test.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch.utils.data as data
|
4 |
+
import torchvision.transforms as T
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class LowLightDatasetTest(data.Dataset):
|
9 |
+
def __init__(self, root, reside=False):
|
10 |
+
self.root = root
|
11 |
+
self.items = []
|
12 |
+
|
13 |
+
subsets = os.listdir(root)
|
14 |
+
for subset in subsets:
|
15 |
+
img_root = os.path.join(root, subset)
|
16 |
+
img_names = list(sorted(os.listdir(img_root)))
|
17 |
+
|
18 |
+
for img_name in img_names:
|
19 |
+
self.items.append((
|
20 |
+
os.path.join(img_root, img_name),
|
21 |
+
subset,
|
22 |
+
img_name
|
23 |
+
))
|
24 |
+
|
25 |
+
self.preproc = T.Compose(
|
26 |
+
[T.ToTensor()]
|
27 |
+
)
|
28 |
+
self.preproc_raw = T.Compose(
|
29 |
+
[T.ToTensor()]
|
30 |
+
)
|
31 |
+
|
32 |
+
def __getitem__(self, idx):
|
33 |
+
img_path, subset, img_name = self.items[idx]
|
34 |
+
img = Image.open(img_path).convert("RGB")
|
35 |
+
img = img.resize((img.width // 8 * 8, img.height // 8 * 8), Image.ANTIALIAS)
|
36 |
+
img_raw = self.preproc_raw(img)
|
37 |
+
|
38 |
+
return img_raw, subset, img_name
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.items)
|
datasets/mef.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch.utils.data as data
|
5 |
+
import torchvision.transforms as T
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
class MEFDataset(data.Dataset):
|
10 |
+
def __init__(self, root):
|
11 |
+
self.img_root = root
|
12 |
+
|
13 |
+
self.numbers = list(sorted(os.listdir(self.img_root)))
|
14 |
+
print(len(self.numbers))
|
15 |
+
|
16 |
+
self.preproc = T.Compose(
|
17 |
+
[T.ToTensor()]
|
18 |
+
)
|
19 |
+
|
20 |
+
def __getitem__(self, idx):
|
21 |
+
number = self.numbers[idx]
|
22 |
+
im_dir = os.path.join(self.img_root, number)
|
23 |
+
fn1, fn2 = tuple(random.sample(os.listdir(im_dir), k=2))
|
24 |
+
fp1 = os.path.join(im_dir, fn1)
|
25 |
+
fp2 = os.path.join(im_dir, fn2)
|
26 |
+
img1 = Image.open(fp1).convert("RGB")
|
27 |
+
img2 = Image.open(fp2).convert("RGB")
|
28 |
+
img1 = self.preproc(img1)
|
29 |
+
img2 = self.preproc(img2)
|
30 |
+
|
31 |
+
fn1 = f'{number}_{fn1}'
|
32 |
+
fn2 = f'{number}_{fn2}'
|
33 |
+
return img1, img2, fn1, fn2
|
34 |
+
|
35 |
+
def __len__(self):
|
36 |
+
return len(self.numbers)
|
tools/TensorboardWriter.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
from tensorboardX import SummaryWriter
|
3 |
+
|
4 |
+
|
5 |
+
class SingleSummaryWriter(SummaryWriter):
|
6 |
+
_instance_lock = threading.Lock()
|
7 |
+
|
8 |
+
def __init__(self, logdir=None, **kwargs):
|
9 |
+
super().__init__(logdir, **kwargs)
|
10 |
+
|
11 |
+
def __new__(cls, *args, **kwargs):
|
12 |
+
if not hasattr(SingleSummaryWriter, "_instance"):
|
13 |
+
with SingleSummaryWriter._instance_lock:
|
14 |
+
if not hasattr(SingleSummaryWriter, "_instance"):
|
15 |
+
SingleSummaryWriter._instance = object.__new__(cls)
|
16 |
+
return SingleSummaryWriter._instance
|
tools/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .TensorboardWriter import *
|
2 |
+
from .model_utils import *
|
3 |
+
from .saver import *
|
4 |
+
from .mutils import *
|
tools/model_utils.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def count_parameters(model):
|
2 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
3 |
+
|
4 |
+
|
5 |
+
def count_conv_layers(model):
|
6 |
+
cnt = 0
|
7 |
+
for mo in model().modules():
|
8 |
+
if type(mo).__name__ == 'Conv2d':
|
9 |
+
cnt += 1
|
10 |
+
|
11 |
+
print(model.__name__, cnt, count_parameters(model()))
|
tools/mutils.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import time
|
5 |
+
|
6 |
+
|
7 |
+
def contains(key, lst):
|
8 |
+
flag = False
|
9 |
+
for item in lst:
|
10 |
+
if key == item:
|
11 |
+
flag = True
|
12 |
+
return flag
|
13 |
+
|
14 |
+
|
15 |
+
def make_empty_dir(new_dir):
|
16 |
+
if os.path.exists(new_dir):
|
17 |
+
shutil.rmtree(new_dir)
|
18 |
+
os.makedirs(new_dir, exist_ok=True)
|
19 |
+
|
20 |
+
|
21 |
+
def get_timestamp():
|
22 |
+
return str(time.time()).replace('.', '')
|
23 |
+
|
24 |
+
|
25 |
+
def get_formatted_time():
|
26 |
+
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
pass
|
tools/saver.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from tools import mutils
|
6 |
+
|
7 |
+
saved_grad = None
|
8 |
+
saved_name = None
|
9 |
+
|
10 |
+
base_url = './results'
|
11 |
+
os.makedirs(base_url, exist_ok=True)
|
12 |
+
|
13 |
+
|
14 |
+
def normalize_tensor_mm(tensor):
|
15 |
+
return (tensor - tensor.min()) / (tensor.max() - tensor.min())
|
16 |
+
|
17 |
+
|
18 |
+
def normalize_tensor_sigmoid(tensor):
|
19 |
+
return nn.functional.sigmoid(tensor)
|
20 |
+
|
21 |
+
|
22 |
+
def save_image(tensor, name=None, save_path=None, exit_flag=False, timestamp=False, norm=False):
|
23 |
+
import torchvision.utils as vutils
|
24 |
+
os.makedirs(base_url, exist_ok=True)
|
25 |
+
if norm:
|
26 |
+
tensor = normalize_tensor_mm(tensor)
|
27 |
+
grid = vutils.make_grid(tensor.detach().cpu(), nrow=4)
|
28 |
+
|
29 |
+
if save_path:
|
30 |
+
vutils.save_image(grid, save_path)
|
31 |
+
else:
|
32 |
+
if timestamp:
|
33 |
+
vutils.save_image(grid, f'{base_url}/{name}_{mutils.get_timestamp()}.png')
|
34 |
+
else:
|
35 |
+
vutils.save_image(grid, f'{base_url}/{name}.png')
|
36 |
+
if exit_flag:
|
37 |
+
exit(0)
|
38 |
+
|
39 |
+
|
40 |
+
def save_feature(tensor, name, exit_flag=False, timestamp=False):
|
41 |
+
import torchvision.utils as vutils
|
42 |
+
# tensors = [tensor, normalize_tensor_mm(tensor), normalize_tensor_sigmoid(tensor)]
|
43 |
+
tensors = [tensor]
|
44 |
+
titles = ['original', 'min-max', 'sigmoid']
|
45 |
+
os.makedirs(base_url, exist_ok=True)
|
46 |
+
if timestamp:
|
47 |
+
name += '_' + str(time.time()).replace('.', '')
|
48 |
+
|
49 |
+
for index, tensor in enumerate(tensors):
|
50 |
+
_data = tensor.detach().cpu().squeeze(0).unsqueeze(1)
|
51 |
+
num_per_row = 8
|
52 |
+
grid = vutils.make_grid(_data, nrow=num_per_row)
|
53 |
+
vutils.save_image(grid, f'{base_url}/{name}_{titles[index]}.png')
|
54 |
+
if exit_flag:
|
55 |
+
exit(0)
|