sadjava's picture
changed to pipelines
fd52b7f
from pathlib import Path
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation tasks.

    Images and masks are paired by position: each folder is listed, sorted by
    path, and the i-th image is paired with the i-th mask. The dataset is
    compatible with torchvision transforms; the ``transforms`` callable is
    applied to both the images and the masks.

    Note:
        ``transforms`` is called separately on the image and on the mask, so
        random augmentations (random crops, flips, ...) will generally NOT be
        applied identically to an image and its mask. Prefer deterministic
        transforms here unless pairing is handled elsewhere.
    """

    # Accepted color-mode names mapped to the PIL mode passed to Image.convert.
    _COLOR_MODES = {"rgb": "RGB", "grayscale": "L"}

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder that contains the images in the root directory.
            mask_folder (str): Name of the folder that contains the masks in the root directory.
            transforms (Optional[Callable], optional): A function/transform that takes in
                a sample and returns a transformed version.
                E.g. ``transforms.ToTensor`` for images. Defaults to None.
            seed (int, optional): Seed for the train/test shuffle, for reproducible results.
                Use the same seed for the 'Train' and 'Test' instances; without a seed each
                instance is shuffled independently and the two subsets can overlap.
                Defaults to None.
            fraction (float, optional): A float value from 0 to 1 which specifies the
                validation split fraction. ``None`` or ``0`` disables splitting and the
                whole folder is used. Defaults to None.
            subset (str, optional): 'Train' or 'Test' to select the appropriate set.
                Only used when ``fraction`` is set. Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'.
        Raises:
            OSError: If image folder doesn't exist in root.
            OSError: If mask folder doesn't exist in root.
            ValueError: If image_color_mode or mask_color_mode is not 'rgb' or 'grayscale'.
            ValueError: If the image and mask folders contain a different number of files.
            ValueError: If subset is not either 'Train' or 'Test' (when fraction is set).
            ValueError: If fraction is outside the range [0, 1].
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")
        if image_color_mode not in self._COLOR_MODES:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in self._COLOR_MODES:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        image_paths = self._list_files(image_folder_path)
        mask_paths = self._list_files(mask_folder_path)
        # Pairing is purely positional, so differing counts would silently
        # attach masks to the wrong images (or fail later with IndexError).
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Found {len(image_paths)} files in {image_folder_path} but "
                f"{len(mask_paths)} files in {mask_folder_path}; "
                "images and masks must be paired one-to-one."
            )

        if not fraction:
            self.image_names = image_paths
            self.mask_names = mask_paths
            return

        if subset not in ("Train", "Test"):
            raise ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            )
        if not 0 <= fraction <= 1:
            raise ValueError(
                f"fraction must be between 0 and 1, got {fraction}."
            )
        self.fraction = fraction
        self.image_list = np.array(image_paths)
        self.mask_list = np.array(mask_paths)

        indices = np.arange(len(self.image_list))
        # A private RandomState yields the same permutation that
        # np.random.seed(seed) + np.random.shuffle would, without mutating
        # NumPy's global RNG. `seed=0` is honoured (it is a valid seed);
        # `seed=None` draws fresh entropy.
        np.random.RandomState(seed).shuffle(indices)
        self.image_list = self.image_list[indices]
        self.mask_list = self.mask_list[indices]

        # The first ceil(n * (1 - fraction)) shuffled samples form the
        # 'Train' subset; the remainder forms 'Test'.
        split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
        if subset == "Train":
            self.image_names = self.image_list[:split]
            self.mask_names = self.mask_list[:split]
        else:
            self.image_names = self.image_list[split:]
            self.mask_names = self.mask_list[split:]

    @staticmethod
    def _list_files(folder: Path) -> list:
        """Return the regular files directly inside ``folder``, sorted by path."""
        # Sub-directories are skipped: they cannot be opened as images and
        # would shift the positional image/mask pairing.
        return sorted(p for p in folder.glob("*") if p.is_file())

    @classmethod
    def _load(cls, path, color_mode: str):
        """Open the image at ``path`` and convert it to ``color_mode``."""
        with open(path, "rb") as file:
            # convert() reads the pixel data while the file is still open, so
            # the returned image no longer depends on the file handle.
            return Image.open(file).convert(cls._COLOR_MODES[color_mode])

    def __len__(self) -> int:
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Return ``{"image": ..., "mask": ...}`` for the sample at ``index``.

        Both entries are PIL images converted to the configured color modes,
        then passed through ``transforms`` (each separately) when it is set.
        """
        sample = {
            "image": self._load(self.image_names[index], self.image_color_mode),
            "mask": self._load(self.mask_names[index], self.mask_color_mode),
        }
        if self.transforms is not None:
            sample["image"] = self.transforms(sample["image"])
            sample["mask"] = self.transforms(sample["mask"])
        return sample