Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*- | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from torchvision.models import vgg19 | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision.utils import save_image, make_grid | |
from torchvision.transforms import ToTensor | |
import numpy as np | |
import cv2 | |
import glob | |
import random | |
from PIL import Image | |
from tqdm import tqdm | |
# from degradation.degradation_main import degredate_process, preparation | |
from opt import opt | |
class ImageDataset(Dataset):
    """Dataset of aligned (low-res, degraded high-res, high-res) image triplets.

    The three path lists are index-aligned: entry ``i`` of each list belongs to
    the same sample. Images are read with OpenCV (so the channel order is
    whatever ``cv2.imread`` yields -- BGR by default), optionally augmented
    with a shared random flip/transpose, and converted with ``ToTensor``.
    """

    def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
        """Store the path lists and build the tensor transform.

        Args:
            train_lr_paths: Paths of the low-resolution images.
            degrade_hr_paths: Paths of the degraded high-resolution images.
            train_hr_paths: Paths of the ground-truth high-resolution images.

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        # Validate first so a bad call fails fast. An explicit raise is used
        # instead of `assert` because asserts are stripped under `python -O`.
        if len(train_lr_paths) != len(train_hr_paths):
            raise ValueError(
                "train_lr_paths and train_hr_paths must have the same length, "
                "got {} and {}".format(len(train_lr_paths), len(train_hr_paths))
            )
        if len(train_lr_paths) != len(degrade_hr_paths):
            raise ValueError(
                "train_lr_paths and degrade_hr_paths must have the same length, "
                "got {} and {}".format(len(train_lr_paths), len(degrade_hr_paths))
            )

        super().__init__()

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.files_lr = train_lr_paths
        self.files_degrade_hr = degrade_hr_paths
        self.files_hr = train_hr_paths

    @staticmethod
    def _read_image(path):
        """Read one image with OpenCV, raising a clear error on failure."""
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # surface it here with the offending path rather than letting a
        # later call fail on a None image.
        if img is None:
            raise OSError("cv2.imread could not read image: {!r}".format(path))
        return img

    def augment(self, imgs, hflip=True, rotation=True):
        """Apply the same random flip / transpose to every image.

        One set of random decisions is drawn per call and shared by all
        images, so paired images stay spatially aligned. Flips are performed
        in place on the input arrays.

        Args:
            imgs (list[ndarray] | ndarray): Images to be augmented. If the
                input is an ndarray, it will be wrapped in a list.
            hflip (bool): Allow a horizontal flip (applied with probability
                0.5). Default: True.
            rotation (bool): Allow a vertical flip and an H/W axis transpose
                (each applied independently with probability 0.5).
                Default: True.

        Returns:
            list[ndarray] | ndarray: Augmented images. If the result has only
            one element, that ndarray is returned directly.
        """
        # Draw order (hflip, vflip, rot90) and short-circuiting are kept as-is
        # so seeded runs consume the random stream identically.
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal
                cv2.flip(img, 1, img)
            if vflip:  # vertical
                cv2.flip(img, 0, img)
            if rot90:  # swap the first two axes
                img = img.transpose(1, 0, 2)
            return img

        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    def __getitem__(self, index):
        """Load, optionally augment, and tensorize the sample at ``index``.

        ``index`` is wrapped modulo the number of files, so values past the
        end cycle back to the start.

        Returns:
            dict: ``{"lr": ..., "degrade_hr": ..., "hr": ...}`` holding the
            outputs of ``ToTensor`` for the three images.

        Raises:
            IndexError: If the dataset holds no files.
            OSError: If any of the three images cannot be read.
        """
        # Without this guard the modulo below would raise ZeroDivisionError.
        if not self.files_lr or not self.files_degrade_hr or not self.files_hr:
            raise IndexError("ImageDataset is empty")

        # Read files (cv2.imread yields BGR channel order by default)
        img_lr = self._read_image(self.files_lr[index % len(self.files_lr)])
        img_degrade_hr = self._read_image(
            self.files_degrade_hr[index % len(self.files_degrade_hr)]
        )
        img_hr = self._read_image(self.files_hr[index % len(self.files_hr)])

        # Augmentation: all three images share the same random transform
        if random.random() < opt["augment_prob"]:
            img_lr, img_degrade_hr, img_hr = self.augment(
                [img_lr, img_degrade_hr, img_hr]
            )

        # Transform to tensors; ToTensor() output is already in the range [0, 1]
        img_lr = self.transform(img_lr)
        img_degrade_hr = self.transform(img_degrade_hr)
        img_hr = self.transform(img_hr)

        return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr}

    def __len__(self):
        """Return the number of samples (the length of the HR path list).

        Raises:
            ValueError: If the HR and LR path lists no longer match in length.
        """
        if len(self.files_hr) != len(self.files_lr):
            raise ValueError(
                "files_hr and files_lr have different lengths: {} vs {}".format(
                    len(self.files_hr), len(self.files_lr)
                )
            )
        return len(self.files_hr)