DifFace / datapipe /datasets.py
Zongsheng
first upload
06f26d7
raw
history blame
5.96 kB
import random
import numpy as np
from pathlib import Path
from ResizeRight.resize_right import resize
from einops import rearrange
import torch
import torchvision as thv
from torch.utils.data import Dataset
from utils import util_sisr
from utils import util_image
from utils import util_common
from basicsr.data.realesrgan_dataset import RealESRGANDataset
from .ffhq_degradation_dataset import FFHQDegradationDataset
def get_transforms(transform_type, out_size, sf):
    """Build the torchvision preprocessing pipeline for a dataset.

    Every pipeline ends with ``ToTensor`` followed by a per-channel
    ``Normalize(mean=0.5, std=0.5)``; the transform type only decides
    what runs before those two steps:

    * ``'default'``: ``util_image.SpatialAug()`` first.
    * ``'face'``: nothing extra.
    * ``'bicubic'``: ``util_sisr.Bicubic(1/sf)`` first.

    Args:
        transform_type: one of ``'default'``, ``'face'``, ``'bicubic'``.
        out_size: accepted for interface compatibility; not used here.
        sf: scale factor; only read for ``'bicubic'`` (must be non-zero).

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        ValueError: if ``transform_type`` is not recognised.
    """
    # Build the type-specific prefix lazily so e.g. ``1/sf`` is only evaluated
    # when the bicubic pipeline is actually requested.
    if transform_type == 'default':
        prefix = [util_image.SpatialAug()]
    elif transform_type == 'face':
        prefix = []
    elif transform_type == 'bicubic':
        prefix = [util_sisr.Bicubic(1/sf)]
    else:
        # Previously formatted an undefined name, raising NameError instead.
        raise ValueError(f'Unexpected transform_type {transform_type}')
    return thv.transforms.Compose(prefix + [
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
def create_dataset(dataset_config):
    """Instantiate the dataset described by a config mapping.

    Args:
        dataset_config: mapping with a ``'type'`` key selecting the dataset
            class and a ``'params'`` entry holding its constructor arguments.
            ``'params'`` is only looked up once the type is recognised.
            For ``'gfpgan'`` and ``'realesrgan'`` the params object is passed
            as one positional argument; for ``'face'``, ``'bicubic'`` and
            ``'folder'`` it is unpacked as keyword arguments.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: with the unrecognised type as its argument.
    """
    kind = dataset_config['type']
    if kind == 'gfpgan':
        return FFHQDegradationDataset(dataset_config['params'])
    if kind == 'face':
        return BaseDatasetFace(**dataset_config['params'])
    if kind == 'bicubic':
        return DatasetBicubic(**dataset_config['params'])
    if kind == 'folder':
        return BaseDataFolder(**dataset_config['params'])
    if kind == 'realesrgan':
        return RealESRGANDataset(dataset_config['params'])
    raise NotImplementedError(kind)
class BaseDatasetFace(Dataset):
    """Face-image dataset built from CelebA and/or FFHQ path lists.

    Each item is ``{'image': x}`` where ``x`` is the image at the indexed
    path, read as uint8 RGB and passed through the pipeline returned by
    ``get_transforms``.

    Args:
        celeba_txt: list file of image paths, read with
            ``util_common.readline_txt``; skipped when ``None``.
        ffhq_txt: same as ``celeba_txt``; its paths come after the CelebA ones.
        out_size: forwarded to ``get_transforms``.
        transform_type: forwarded to ``get_transforms``.
        sf: forwarded to ``get_transforms`` (only used by ``'bicubic'``).
        length: value reported by ``len()``; defaults to the number of paths.
    """

    def __init__(self, celeba_txt=None,
                 ffhq_txt=None,
                 out_size=256,
                 transform_type='face',
                 sf=None,
                 length=None):
        super().__init__()
        # Both list files default to None, so only read the ones provided,
        # keeping CelebA paths before FFHQ paths.
        self.files_names = []
        for txt_path in (celeba_txt, ffhq_txt):
            if txt_path is not None:
                self.files_names.extend(util_common.readline_txt(txt_path))
        # NOTE(review): a ``length`` larger than the number of paths makes
        # __getitem__ raise IndexError for the extra indices.
        self.length = len(self.files_names) if length is None else length
        self.transform = get_transforms(transform_type, out_size, sf)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        im_path = self.files_names[index]
        im = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im = self.transform(im)
        return {'image': im, }
class DatasetBicubic(Dataset):
    """Paired (low-quality, ground-truth) dataset made by downsampling.

    The ground truth is read as float32 RGB and resized with ``resize``
    (ResizeRight) by ``1/sf``; when ``up_back`` is set the result is resized
    back by ``sf``. Both arrays are returned as float32 ``c h w`` tensors.

    Args:
        files_txt: list file of image paths (read with
            ``util_common.readline_txt``); used when ``val_dir`` is None.
        val_dir: directory to scan instead of ``files_txt``.
        ext: extension (str) or list/tuple of extensions matched in
            ``val_dir``.
        sf: scale factor; must be set before items are fetched
            (``1/self.sf`` fails on None).
        up_back: resize the low-quality image back by ``sf``.
        need_gt_path: also return the ground-truth path as ``'gt_path'``.
        length: value reported by ``len()``; defaults to the number of paths.
    """

    def __init__(self,
                 files_txt=None,
                 val_dir=None,
                 ext='png',
                 sf=None,
                 up_back=False,
                 need_gt_path=False,
                 length=None):
        super().__init__()
        if val_dir is None:
            self.files_names = util_common.readline_txt(files_txt)
        else:
            exts = [ext] if isinstance(ext, str) else list(ext)
            # De-duplicate and sort so the index -> file mapping does not
            # depend on the filesystem's directory-listing order.
            self.files_names = sorted({
                str(x)
                for current_ext in exts
                for x in Path(val_dir).glob(f"*.{current_ext}")
            })
        self.sf = sf
        self.up_back = up_back
        self.need_gt_path = need_gt_path
        # NOTE(review): a ``length`` larger than the number of paths makes
        # __getitem__ raise IndexError for the extra indices.
        self.length = len(self.files_names) if length is None else length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        im_path = self.files_names[index]
        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')
        im_lq = resize(im_gt, scale_factors=1/self.sf)
        if self.up_back:
            im_lq = resize(im_lq, scale_factors=self.sf)
        # The arrays are expected to be rank-3 (h, w, c); convert to (c, h, w).
        im_lq = rearrange(im_lq, 'h w c -> c h w')
        im_lq = torch.from_numpy(im_lq).type(torch.float32)
        im_gt = rearrange(im_gt, 'h w c -> c h w')
        im_gt = torch.from_numpy(im_gt).type(torch.float32)
        if self.need_gt_path:
            return {'lq': im_lq, 'gt': im_gt, 'gt_path': im_path}
        else:
            return {'lq': im_lq, 'gt': im_gt}
class BaseDataFolder(Dataset):
    """Dataset over the image files of a directory, with optional GT pairs.

    Each item holds the input image under both ``'image'`` and ``'lq'``,
    normalized with ``util_image.normalize_np`` using ``mean``/``std`` and
    rearranged to ``c h w`` float32. When ``dir_path_gt`` is given, the file
    with the same name in that directory is loaded the same way as ``'gt'``.

    Args:
        dir_path: directory scanned for input images.
        dir_path_gt: directory holding ground-truth images with matching
            file names, or None.
        need_gt_path: also return the input path under ``'path'``.
        length: keep only the first ``length`` files (after sorting).
        ext: extension (str) or list/tuple of extensions to match.
        mean: forwarded to ``util_image.normalize_np``.
        std: forwarded to ``util_image.normalize_np``.

    Raises:
        TypeError: if ``ext`` is not a str, list or tuple.
    """

    def __init__(
            self,
            dir_path,
            dir_path_gt,
            need_gt_path=True,
            length=None,
            ext=('png', 'jpg', 'jpeg', 'JPEG', 'bmp'),
            mean=0.5,
            std=0.5,
            ):
        super().__init__()
        if isinstance(ext, str):
            exts = [ext]
        elif isinstance(ext, (list, tuple)):
            exts = list(ext)
        else:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise TypeError(
                f'ext must be a str, list or tuple, got {type(ext).__name__}')
        # The set drops duplicates when several patterns match the same file
        # (e.g. 'jpeg' and 'JPEG' where glob matching is case-insensitive);
        # sorting makes the order, and thus the ``length`` cut, reproducible.
        files_path = sorted({
            str(x)
            for current_ext in exts
            for x in Path(dir_path).glob(f'*.{current_ext}')
        })
        self.files_path = files_path if length is None else files_path[:length]
        self.dir_path_gt = dir_path_gt
        self.need_gt_path = need_gt_path
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.files_path)

    def __getitem__(self, index):
        im_path = self.files_path[index]
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        im = util_image.normalize_np(im, mean=self.mean, std=self.std, reverse=False)
        im = rearrange(im, 'h w c -> c h w')
        # 'image' and 'lq' receive independent copies of the same input.
        out_dict = {'image': im.astype(np.float32), 'lq': im.astype(np.float32)}
        if self.need_gt_path:
            out_dict['path'] = im_path
        if self.dir_path_gt is not None:
            # Ground truth is matched purely by file name.
            gt_path = str(Path(self.dir_path_gt) / Path(im_path).name)
            im_gt = util_image.imread(gt_path, chn='rgb', dtype='float32')
            im_gt = util_image.normalize_np(im_gt, mean=self.mean, std=self.std, reverse=False)
            im_gt = rearrange(im_gt, 'h w c -> c h w')
            out_dict['gt'] = im_gt.astype(np.float32)
        return out_dict