Spaces:
Runtime error
Runtime error
import glob | |
import h5py | |
from PIL import Image | |
from torchvision.transforms import RandomCrop | |
from torchvision.transforms.functional import to_tensor | |
from tqdm import tqdm | |
from Dataloader import ImageAugment | |
# --- Configuration ---------------------------------------------------------
patch_size = 128  # side length (pixels) of every high-resolution crop
shrink_size = 2  # passed to ImageAugment; LR rows are patch_size // shrink_size wide
noise_level = 1  # passed to ImageAugment
patches_per_img = 20  # number of random crops drawn from each source image

# Sorted so the row order of the HDF5 datasets is reproducible between runs
# (glob.glob returns entries in arbitrary, filesystem-dependent order).
images = sorted(glob.glob("dataset/train/*"))

# --- Output storage --------------------------------------------------------
# Mode 'w' creates train_images.hdf5, truncating any existing file.
database = h5py.File("train_images.hdf5", 'w')
# Derive the group name from the settings so it cannot drift out of sync
# with them (for shrink 2 / noise 1 this is the original literal name).
dat_group = database.create_group(
    f"shrink_{shrink_size}_noise_level_{noise_level}_downsample_random_rgb"
)

n_patches = patches_per_img * len(images)
lr_size = patch_size // shrink_size
# One float32 (3, H, W) row per patch, matching what to_tensor produces for an
# RGB image. Pass compression='lzf' to trade write speed for file size.
storage_lr = dat_group.create_dataset(
    "train_lr",
    shape=(n_patches, 3, lr_size, lr_size),
    dtype='float32',
)
storage_hr = dat_group.create_dataset(
    "train_hr",
    shape=(n_patches, 3, patch_size, patch_size),
    dtype='float32',
)

# pad_if_needed=True zero-pads images smaller than patch_size instead of
# letting RandomCrop raise ValueError and abort the run half-way through.
random_cropper = RandomCrop(size=patch_size, pad_if_needed=True)
img_augmenter = ImageAugment(shrink_size, noise_level, down_sample_method=None)
def get_img_patches(img_pil):
    """Crop one random patch from ``img_pil`` and augment it.

    The crop comes from the module-level ``random_cropper``; the result is
    whatever ``img_augmenter.process`` returns for that crop. The loop below
    stores element ``[0]`` in the LR dataset and element ``[1]`` in the HR
    dataset.
    """
    return img_augmenter.process(random_cropper(img_pil))
counter = 0  # index of the next row to write in storage_lr / storage_hr
try:
    for img in tqdm(images):
        # The context manager releases the source file promptly; convert()
        # returns a new, fully loaded image independent of that file.
        with Image.open(img) as src:
            img_pil = src.convert("RGB")
        for _ in range(patches_per_img):
            patch = get_img_patches(img_pil)
            # patch[0] goes to the low-res dataset, patch[1] to the high-res one.
            storage_lr[counter] = to_tensor(patch[0].convert("RGB")).numpy()
            storage_hr[counter] = to_tensor(patch[1].convert("RGB")).numpy()
            counter += 1
finally:
    # Close (and flush) the HDF5 file even if an image fails to load or crop,
    # so the file is not left open after an unclean exit.
    database.close()