Bread / exposure_augment.py
huqiming513's picture
Upload 14 files
e538b68
import math
import os
import PIL.Image as Image
import numpy as np
import torch
import torchvision.transforms as vtrans
import tqdm
def main(fip, fod):
    """Write exposure-augmented copies of one image into ``fod``.

    A list of brightness gains is derived from the brightest pixels of the
    image at ``fip``. For every gain the image is multiplied by it, clipped
    at 1.0 and saved as ``<stem>_<index><ext>`` inside ``fod``. Output files
    that already exist are left untouched.

    Args:
        fip: Path of the input image.
        fod: Directory that receives the augmented images.

    Raises:
        IndexError: If the sampled gain curve has no strictly positive
            increment, or no increment above the selected minimum, because
            the boolean selections below are then empty.
    """
    max_overex_rate = 0.25  # fraction of the brightest pixels that is examined
    steps = 20  # approximate number of samples taken from the gain curve
    num_gen = 4  # number of largest gaps inspected; 2 * num_gen images are written

    # Decode the image once and release the file handle right away.
    with Image.open(fip) as src:
        im_raw = vtrans.ToTensor()(src)

    # Per-pixel maximum over the channel axis, flattened to one dimension.
    im_max = torch.flatten(torch.max(im_raw, dim=0, keepdim=True).values)

    # Gain that would bring each of the brightest pixels exactly to 1.0.
    # topk returns values in descending order, so the gains are ascending.
    num_top = math.floor(len(im_max) * max_overex_rate + 1)
    mag = 1. / torch.topk(im_max, num_top).values

    # Subsample the curve; the stride is clamped to 1 so that small images
    # (fewer than `steps` candidate pixels) do not produce a zero step.
    stride = max(1, int(len(mag) * (1. / steps)))
    mag = mag[::stride]
    mag_diff = torch.diff(mag, 1)
    mag = mag[:-1]  # align mag[i] with mag_diff[i]

    # Smallest strictly positive gap among the num_gen largest gaps.
    top_mag_diff = torch.topk(mag_diff, num_gen).values
    min_gain = top_mag_diff[top_mag_diff > 0][-1]
    min_mag = mag[0]
    # Last sampled gain whose following gap is larger than min_gain.
    max_mag = mag[mag_diff > min_gain][-1]

    # splitext copes with names containing several dots or none at all;
    # the returned extension already carries its leading dot.
    fn, ext = os.path.splitext(os.path.basename(fip))

    # The progress bar only exists when this file is run as a script, so
    # look it up instead of assuming the global is defined.
    progress = globals().get('bar')
    if progress is not None:
        progress.set_description(f'{fn}: {min_gain}')

    ma = np.arange(1, min_mag - min_gain, min_gain * 2)
    if len(ma) > num_gen:
        mags = np.append(np.linspace(1, min_mag - min_gain, num_gen),
                         np.linspace(min_mag, max_mag, num_gen))
    elif len(ma) == num_gen:
        mags = np.append(ma, np.linspace(min_mag, max_mag, num_gen))
    else:
        mags = np.linspace(1, max_mag, num_gen * 2)

    to_pil = vtrans.ToPILImage()
    for i, gain in enumerate(mags):
        fop = os.path.join(fod, f'{fn}_{i}{ext}')
        if os.path.exists(fop):
            # Never overwrite a previously generated image.
            continue
        im = im_raw * gain
        im.clamp_max_(1.)
        to_pil(im).save(fop)
if __name__ == '__main__':
    # The dataset is not shipped with this script: download it separately
    # and place the training images under `fid`.
    fid = './data/LOL/train/images'
    fod = './data/LOL/train/images_aug'
    os.makedirs(fod, exist_ok=True)

    # `bar` has to remain a module-level name because main() updates the
    # description of this progress bar.
    bar = tqdm.tqdm(os.listdir(fid))
    for name in bar:
        main(os.path.join(fid, name), fod)