Spaces:
Runtime error
Runtime error
import warnings
from pathlib import Path
from typing import Any, Callable, Optional

import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation tasks.

    Images and masks are read from two folders under ``root`` and paired by
    position after each folder's entries are sorted by path, so the n-th image
    is matched with the n-th mask. The dataset is compatible with torchvision
    transforms: the same ``transforms`` callable is applied to the image and,
    in a separate call, to the mask.
    """

    # Maps the public color-mode names to the PIL mode passed to Image.convert.
    _PIL_MODES = {"rgb": "RGB", "grayscale": "L"}

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder inside ``root`` that contains the images.
            mask_folder (str): Name of the folder inside ``root`` that contains the masks.
            transforms (Optional[Callable], optional): A function/transform that takes in
                a PIL image and returns a transformed version. It is called once on the
                image and once on the mask. E.g. ``transforms.ToTensor()``. Defaults to None.
            seed (int, optional): Seed for the shuffle that precedes the Train/Test split,
                for reproducible results. Any integer, including 0, is honoured. The
                shuffle uses a private ``np.random.RandomState``, so NumPy's global RNG
                is not reseeded. When None, NumPy's global RNG is used. Defaults to None.
            fraction (float, optional): Value in [0, 1] giving the fraction of samples
                placed in the 'Test' subset. When None or 0, no split is made and every
                sample is used. Defaults to None.
            subset (str, optional): 'Train' or 'Test' to select the appropriate set;
                required when ``fraction`` is set. Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'.

        Raises:
            OSError: If the image folder doesn't exist in root.
            OSError: If the mask folder doesn't exist in root.
            ValueError: If image_color_mode or mask_color_mode is not 'rgb' or 'grayscale'.
            ValueError: If the image and mask folders hold a different number of entries.
            ValueError: If fraction is set and subset is not 'Train' or 'Test'.
            ValueError: If fraction is set but lies outside [0, 1].
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")
        if image_color_mode not in self._PIL_MODES:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in self._PIL_MODES:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        image_paths = sorted(image_folder_path.glob("*"))
        mask_paths = sorted(mask_folder_path.glob("*"))
        # Pairs are formed purely by sorted position, so differing counts mean
        # every pair after the first mismatch would silently be wrong.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"{image_folder_path} contains {len(image_paths)} entries but "
                f"{mask_folder_path} contains {len(mask_paths)}; images and masks "
                "are paired by sorted position, so the counts must match."
            )

        if not fraction:
            # No split requested: use every pair, as plain lists of Paths.
            self.image_names = image_paths
            self.mask_names = mask_paths
            return

        if subset not in ["Train", "Test"]:
            raise ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            )
        if not 0 <= fraction <= 1:
            raise ValueError(f"fraction must be between 0 and 1, got {fraction}.")
        if seed is None:
            warnings.warn(
                "fraction was given without a seed: the shuffle uses NumPy's global "
                "RNG, so separately constructed 'Train' and 'Test' datasets can "
                "overlap unless that RNG is seeded identically before each one.",
                stacklevel=2,
            )

        self.fraction = fraction
        self.image_list = np.array(image_paths)
        self.mask_list = np.array(mask_paths)

        # Shuffle images and masks with one shared permutation so pairs stay
        # aligned. RandomState(seed).shuffle yields the same permutation as
        # np.random.seed(seed) + np.random.shuffle, without touching global state.
        indices = np.arange(len(self.image_list))
        rng = np.random if seed is None else np.random.RandomState(seed)
        rng.shuffle(indices)
        self.image_list = self.image_list[indices]
        self.mask_list = self.mask_list[indices]

        # The first ceil(n * (1 - fraction)) shuffled samples form 'Train';
        # the remainder forms 'Test'.
        split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
        if subset == "Train":
            self.image_names = self.image_list[:split]
            self.mask_names = self.mask_list[:split]
        else:
            self.image_names = self.image_list[split:]
            self.mask_names = self.mask_list[split:]

    def __len__(self) -> int:
        """Return the number of image/mask pairs in this dataset (or subset)."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Return ``{"image": ..., "mask": ...}`` for the pair at ``index``.

        Both files are opened in binary mode and converted to the configured
        PIL mode while the handles are still open. ``convert`` forces the lazily
        opened image to be read before the file is closed. If ``self.transforms``
        is set, it is then applied to the image and to the mask in two
        independent calls.
        """
        image_path = self.image_names[index]
        mask_path = self.mask_names[index]
        with open(image_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
            image = Image.open(image_file).convert(self._PIL_MODES[self.image_color_mode])
            mask = Image.open(mask_file).convert(self._PIL_MODES[self.mask_color_mode])
        sample = {"image": image, "mask": mask}
        if self.transforms:
            # NOTE(review): these are two independent calls, so a random
            # transform (e.g. a random crop or flip) may draw different
            # parameters for the image and its mask and misalign them. Only
            # deterministic transforms are guaranteed to keep the pair aligned.
            sample["image"] = self.transforms(sample["image"])
            sample["mask"] = self.transforms(sample["mask"])
        return sample