import os import torch.utils.data as data import torchvision.transforms as T from PIL import Image class LowLightDatasetTest(data.Dataset): def __init__(self, root, reside=False): self.root = root self.items = [] subsets = os.listdir(root) for subset in subsets: img_root = os.path.join(root, subset) img_names = list(sorted(os.listdir(img_root))) for img_name in img_names: self.items.append(( os.path.join(img_root, img_name), subset, img_name )) self.preproc = T.Compose( [T.ToTensor()] ) self.preproc_raw = T.Compose( [T.ToTensor()] ) def __getitem__(self, idx): img_path, subset, img_name = self.items[idx] img = Image.open(img_path).convert("RGB") img = img.resize((img.width // 8 * 8, img.height // 8 * 8), Image.ANTIALIAS) img_raw = self.preproc_raw(img) return img_raw, subset, img_name def __len__(self): return len(self.items)