|
import os |
|
import random |
|
|
|
import torch.utils.data as data |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
|
|
|
|
class MEFDataset(data.Dataset):
    """Dataset yielding random image pairs drawn from per-scene sub-directories.

    ("MEF" presumably stands for multi-exposure fusion -- inferred from the
    name only.)

    Expected layout, as implied by the path handling below::

        root/
            <scene>/
                <image file>
                <image file>
                ...

    Each access to an index draws two *distinct* images at random from that
    scene's directory, so repeated accesses to the same index may return
    different pairs.

    Args:
        root: Directory whose non-hidden sub-directories are the scenes.
        transform: Optional callable applied to each loaded RGB PIL image.
            Defaults to ``T.Compose([T.ToTensor()])`` (the original behavior).
    """

    def __init__(self, root, transform=None):
        self.img_root = root

        # Only non-hidden directories are scenes. Stray files (README,
        # .DS_Store) or hidden dirs (.ipynb_checkpoints) would otherwise be
        # indexed and make __getitem__ fail when it lists their contents.
        self.numbers = sorted(
            name
            for name in os.listdir(self.img_root)
            if not name.startswith(".")
            and os.path.isdir(os.path.join(self.img_root, name))
        )

        # Kept from the original implementation: reports how many scenes
        # were found.
        print(len(self.numbers))

        # T is only evaluated when no custom transform is supplied.
        self.preproc = (
            transform if transform is not None else T.Compose([T.ToTensor()])
        )

    def _sample_pair(self, im_dir):
        """Return two distinct file names drawn at random from *im_dir*.

        Raises:
            ValueError: if *im_dir* holds fewer than two candidate files.
        """
        # Only non-hidden regular files are candidates. Sorting first makes the
        # draw reproducible under a fixed seed, because os.listdir order
        # depends on the filesystem.
        candidates = sorted(
            name
            for name in os.listdir(im_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(im_dir, name))
        )
        if len(candidates) < 2:
            raise ValueError(
                f"scene directory {im_dir!r} needs at least 2 images, "
                f"found {len(candidates)}"
            )
        fn1, fn2 = random.sample(candidates, k=2)
        return fn1, fn2

    def _load(self, path):
        """Open the image at *path*, convert it to RGB and apply ``self.preproc``."""
        # The context manager closes the file handle deterministically rather
        # than leaving it to garbage collection. convert() returns a separate
        # image, so it stays usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.preproc(rgb)

    def __getitem__(self, idx):
        """Return ``(img1, img2, name1, name2)`` for scene ``self.numbers[idx]``.

        ``img1``/``img2`` are the transformed images (tensors with the default
        ``ToTensor`` transform). ``name1``/``name2`` are ``"<scene>_<file>"``.
        """
        number = self.numbers[idx]
        im_dir = os.path.join(self.img_root, number)

        fn1, fn2 = self._sample_pair(im_dir)
        img1 = self._load(os.path.join(im_dir, fn1))
        img2 = self._load(os.path.join(im_dir, fn2))

        return img1, img2, f"{number}_{fn1}", f"{number}_{fn2}"

    def __len__(self):
        """Return the number of scenes (sub-directories) indexed."""
        return len(self.numbers)
|
|