import torch.utils.data

from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    """Instantiate and initialize the dataset selected by opt.dataset_mode."""
    dataset = None
    if opt.dataset_mode == 'aligned':
        from data.aligned_dataset import AlignedDataset
        dataset = AlignedDataset()
    elif opt.dataset_mode == 'unaligned':
        from data.unaligned_dataset import UnalignedDataset
        dataset = UnalignedDataset()
    elif opt.dataset_mode == 'unaligned_random_crop':
        from data.unaligned_random_crop import UnalignedDataset
        dataset = UnalignedDataset()
    elif opt.dataset_mode == 'pair':
        from data.pair_dataset import PairDataset
        dataset = PairDataset()
    elif opt.dataset_mode == 'syn':
        # syn_dataset also exposes a class named PairDataset.
        from data.syn_dataset import PairDataset
        dataset = PairDataset()
    elif opt.dataset_mode == 'single':
        from data.single_dataset import SingleDataset
        dataset = SingleDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % dataset.name())
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        # Wrap the dataset in a standard PyTorch DataLoader; shuffling is
        # disabled when opt.serial_batches is set.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader

    def __len__(self):
        # Report at most opt.max_dataset_size samples.
        return min(len(self.dataset), self.opt.max_dataset_size)
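
# Usage sketch (illustrative, not part of the original module): a training
# script would typically build the loader from a parsed options object and
# iterate over the batches it yields. The option names below (dataset_mode,
# batchSize, serial_batches, nThreads, max_dataset_size) mirror the attributes
# this module already reads from `opt`; the argparse setup itself is an
# assumption, and dataset-specific options (e.g. data paths) are omitted.
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dataset_mode', default='aligned')
#     parser.add_argument('--batchSize', type=int, default=1)
#     parser.add_argument('--serial_batches', action='store_true')
#     parser.add_argument('--nThreads', type=int, default=2)
#     parser.add_argument('--max_dataset_size', type=int, default=2 ** 31)
#     opt = parser.parse_args()
#
#     data_loader = CustomDatasetDataLoader()
#     data_loader.initialize(opt)
#     for i, batch in enumerate(data_loader.load_data()):
#         pass  # feed `batch` to the model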