sadjava's picture
changed to pipelines
fd52b7f
raw
history blame
No virus
5.43 kB
from pathlib import Path
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation tasks.

    Images and masks are paired by position in the sorted listings of
    ``root/image_folder`` and ``root/mask_folder``. Corresponding files must
    therefore sort into the same order in both folders.

    The dataset is compatible with torchvision transforms. The single
    ``transforms`` callable is applied to both the images and the masks, as
    two separate calls. NOTE: a random transform therefore draws independent
    random parameters for an image and its mask.
    """

    # Maps the accepted color-mode names to PIL ``Image.convert`` modes.
    _PIL_MODES = {"rgb": "RGB", "grayscale": "L"}

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder inside ``root`` that contains the images.
            mask_folder (str): Name of the folder inside ``root`` that contains the masks.
            transforms (Optional[Callable], optional): A function/transform that takes in
                a single image (or mask) and returns a transformed version,
                e.g. ``transforms.ToTensor()``. It is called once for the image
                and once for the mask. Defaults to None.
            seed (int, optional): Seed for the train/test shuffle. Pass the same
                seed when building the 'Train' and 'Test' datasets so that both
                use the same permutation. The seed drives a private RNG, so the
                global numpy random state is not modified. Defaults to None,
                which uses numpy's global RNG; separately built 'Train' and
                'Test' instances may then overlap.
            fraction (float, optional): Fraction of the pairs, in (0, 1], that
                is assigned to the 'Test' subset. If None or 0, no split is made
                and ``subset`` is ignored. Defaults to None.
            subset (str, optional): 'Train' or 'Test'. Required when
                ``fraction`` is set. Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'.

        Raises:
            OSError: If the image folder doesn't exist in root.
            OSError: If the mask folder doesn't exist in root.
            ValueError: If image_color_mode or mask_color_mode is not 'rgb' or 'grayscale'.
            ValueError: If fraction is set and subset is not 'Train' or 'Test'.
            ValueError: If fraction is set and lies outside (0, 1].
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")
        if image_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        if not fraction:
            # No split requested (None or 0): use every pair; `subset` is ignored.
            # NOTE(review): pairing is purely by sorted order; a stray extra file
            # in either folder shifts every later pair.
            self.image_names = sorted(image_folder_path.glob("*"))
            self.mask_names = sorted(mask_folder_path.glob("*"))
            return

        if subset not in ["Train", "Test"]:
            raise ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            )
        if not 0 < fraction <= 1:
            raise ValueError(f"fraction must be in the range (0, 1], got {fraction}.")

        self.fraction = fraction
        self.image_list = np.array(sorted(image_folder_path.glob("*")))
        self.mask_list = np.array(sorted(mask_folder_path.glob("*")))

        indices = np.arange(len(self.image_list))
        if seed is not None:
            # A private legacy RandomState gives the same permutation as
            # np.random.seed(seed) followed by np.random.shuffle, but it leaves
            # the global RNG untouched. `is not None` also honours seed=0.
            np.random.RandomState(seed).shuffle(indices)
        else:
            np.random.shuffle(indices)
        # The same permutation is applied to both lists to keep image/mask pairs aligned.
        self.image_list = self.image_list[indices]
        self.mask_list = self.mask_list[indices]

        # After the shared fancy indexing, both lists have len(indices)
        # entries, so a single split index serves both.
        split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
        if subset == "Train":
            self.image_names = self.image_list[:split]
            self.mask_names = self.mask_list[:split]
        else:
            self.image_names = self.image_list[split:]
            self.mask_names = self.mask_list[split:]

    def __len__(self) -> int:
        """Return the number of image/mask pairs in this dataset (or subset)."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Load the ``index``-th image/mask pair.

        Returns:
            dict: ``{"image": ..., "mask": ...}``. Each value is a PIL image in
            its configured color mode, or the result of ``transforms`` when
            one was given.
        """
        image_path = self.image_names[index]
        mask_path = self.mask_names[index]
        with open(image_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
            # Image.open is lazy. convert() decodes the pixels while the
            # file handles are still open.
            image = Image.open(image_file).convert(self._PIL_MODES[self.image_color_mode])
            mask = Image.open(mask_file).convert(self._PIL_MODES[self.mask_color_mode])
        sample = {"image": image, "mask": mask}
        if self.transforms is not None:
            sample["image"] = self.transforms(sample["image"])
            sample["mask"] = self.transforms(sample["mask"])
        return sample