File size: 1,763 Bytes
e538b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import math
import os

import PIL.Image as Image
import numpy as np
import torch
import torchvision.transforms as vtrans
import tqdm


def main(fip, fod):
    """Write brightened (over-exposed) copies of one image into ``fod``.

    The image at ``fip`` is decoded to a float tensor, a list of gain
    factors is derived from its brightest pixels, and one copy per factor
    (multiplied by the factor and clamped to 1.0) is saved as
    ``<stem>_<index><ext>``. Outputs that already exist are not rewritten.

    Args:
        fip: Path of the source image.
        fod: Output directory; this function does not create it.

    Raises:
        ValueError: If the sampled highlight gains contain no positive
            step, so no gain range can be derived from the image.
    """
    max_overex_rate = 0.25  # fraction of brightest pixels used to derive gains
    steps = 20  # approximate number of samples kept from the gain curve
    num_gen = 4  # number of factors generated per segment

    # Decode once and close the file right away; ToTensor copies the pixels.
    with Image.open(fip) as src:
        im_raw = vtrans.ToTensor()(src)

    # Brightest channel value of every pixel, flattened to 1-D.
    im_max = torch.flatten(torch.max(im_raw, dim=0, keepdim=True).values)

    # Gain at which each of the brightest pixels reaches 1.0. topk returns
    # values in descending order, so the reciprocals are ascending.
    # NOTE(review): assumes pixel values are in (0, 1] -- a zero among the
    # selected pixels would give an infinite gain.
    num_top = math.floor(len(im_max) * max_overex_rate + 1)
    mag = 1. / torch.topk(im_max, num_top).values

    # Thin the curve to roughly `steps` samples. The stride is clamped to 1
    # because a very small image would otherwise yield a zero stride.
    stride = max(1, int(len(mag) * (1. / steps)))
    mag = mag[::stride]
    mag_diff = torch.diff(mag, 1)
    mag = mag[:-1]  # align each gain with the step that follows it

    # Smallest positive step among the (at most) `num_gen` largest steps.
    top_mag_diff = torch.topk(mag_diff, min(num_gen, len(mag_diff))).values
    positive_gains = top_mag_diff[top_mag_diff > 0]
    if positive_gains.numel() == 0:
        raise ValueError(
            f'{fip}: sampled highlights are flat, cannot derive a gain range')
    min_gain = positive_gains[-1]

    min_mag = mag[0]
    reachable = mag[mag_diff > min_gain]
    if reachable.numel() == 0:
        # All top steps tie with min_gain, so none is strictly larger; fall
        # back to the steps that equal it (min_gain itself always matches).
        reachable = mag[mag_diff >= min_gain]
    max_mag = reachable[-1]

    # splitext copes with names that contain several dots or none at all.
    fn, ext = os.path.splitext(os.path.basename(fip))

    # The progress bar only exists when the file runs as a script; look it
    # up defensively so importing and calling main() does not fail.
    progress = globals().get('bar')
    if progress is not None:
        progress.set_description(f'{fn}: {min_gain}')

    # Plain Python floats keep the NumPy calls free of tensor coercion.
    knee = (min_mag - min_gain).item()  # just below the first clipping gain
    low = min_mag.item()
    high = max_mag.item()
    ma = np.arange(1, knee, min_gain.item() * 2)
    if len(ma) > num_gen:
        # Too many sub-clipping candidates: resample that range evenly.
        mags = np.append(np.linspace(1, knee, num_gen),
                         np.linspace(low, high, num_gen))
    elif len(ma) == num_gen:
        mags = np.append(ma, np.linspace(low, high, num_gen))
    else:
        # Too few candidates: spread all factors over the whole range.
        mags = np.linspace(1, high, num_gen * 2)

    to_pil = vtrans.ToPILImage()
    for i, factor in enumerate(mags):
        fop = os.path.join(fod, f'{fn}_{i}{ext}')
        if os.path.exists(fop):
            continue  # keep results from an earlier run
        im = (im_raw * float(factor)).clamp_max_(1.)
        to_pil(im).save(fop)


if __name__ == '__main__':
    # The dataset is not shipped with this script; download it first so
    # that the input directory below exists.
    src_dir = './data/LOL/train/images'
    dst_dir = './data/LOL/train/images_aug'
    os.makedirs(dst_dir, exist_ok=True)

    # `bar` has to stay a module-level name: main() updates its description.
    bar = tqdm.tqdm(os.listdir(src_dir))
    for entry in bar:
        main(os.path.join(src_dir, entry), dst_dir)