|
import os |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
|
|
|
|
class SuperviselyPersonDataset(Dataset):
    """Dataset of image / segmentation-mask pairs read from two directories.

    Samples are paired by position: the i-th entry of the sorted listing of
    ``imgdir`` is matched with the i-th entry of the sorted listing of
    ``segdir``. Only the entry counts are checked, not the names, so both
    directories are assumed to contain exactly the paired files in matching
    sort order (NOTE(review): confirm image and mask filenames line up for
    this dataset).

    Args:
        imgdir: Directory containing the input images.
        segdir: Directory containing the segmentation masks.
        transform: Optional callable invoked as ``transform(img, seg)`` on
            each sample; its result is unpacked into two items and returned
            from ``__getitem__``.

    Raises:
        ValueError: If the two directories contain a different number of
            entries.
    """

    def __init__(self, imgdir, segdir, transform=None):
        self.img_dir = imgdir
        # os.listdir returns every entry (hidden files and subdirectories
        # included, no filtering is done); sorting makes the pairing order
        # deterministic.
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Explicit check rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would silently allow mispaired samples.
        if len(self.img_files) != len(self.seg_files):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_files)} entries in "
                f"{imgdir!r} vs {len(self.seg_files)} entries in {segdir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair.

        Returns:
            ``(img, seg)`` where ``img`` is a PIL image converted to 'RGB'
            and ``seg`` a PIL image converted to single-channel 'L' mode, or
            the two items produced by ``self.transform(img, seg)`` when a
            transform is set.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])
        # convert() returns new, fully loaded images, so the source files can
        # be closed as soon as the ``with`` block exits.
        with Image.open(img_path) as img_src, Image.open(seg_path) as seg_src:
            img = img_src.convert('RGB')
            seg = seg_src.convert('L')
        if self.transform is not None:
            img, seg = self.transform(img, seg)
        return img, seg
|
|