mambazjp's picture
Upload 57 files
8b79d57
raw
history blame contribute delete
898 Bytes
import os
from torch.utils.data import Dataset
from PIL import Image
class SuperviselyPersonDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Images and masks are matched purely by position after sorting each
    directory listing by filename: the i-th image is paired with the i-th
    mask. NOTE(review): this assumes corresponding files sort into the same
    order in both directories (e.g. identical stems) — confirm against the
    dataset layout. Directory listings are used unfiltered, so any extra
    entries (hidden files, subdirectories) are treated as samples too.

    Args:
        imgdir: Directory containing the input images.
        segdir: Directory containing the segmentation masks.
        transform: Optional callable invoked as ``transform(img, seg)``; it
            must return an ``(img, seg)`` pair. Applied in ``__getitem__``.

    Raises:
        ValueError: If the two directories contain a different number of
            entries.
    """

    def __init__(self, imgdir, segdir, transform=None):
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Explicit check instead of ``assert`` so validation still runs under
        # ``python -O``; a count mismatch would otherwise silently misalign
        # every image/mask pair.
        if len(self.img_files) != len(self.seg_files):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_files)} entries in "
                f"{imgdir!r} vs {len(self.seg_files)} entries in {segdir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and its mask in mode ``'L'``.

        Returns:
            ``(img, seg)`` as PIL images, or whatever ``transform`` returns
            when one was supplied.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])
        # ``convert`` returns a converted copy, so the source files can be
        # closed before the (possibly slow) transform runs.
        with Image.open(img_path) as img_file, Image.open(seg_path) as seg_file:
            img = img_file.convert('RGB')
            seg = seg_file.convert('L')
        if self.transform is not None:
            img, seg = self.transform(img, seg)
        return img, seg