oguzakif's picture
init repo
d4b77ac
raw
history blame
2.54 kB
import logging
import torch
import torch.utils.data
from importlib import import_module
def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for the given phase.

    Args:
        phase: ``'train'`` builds a training loader; any other value builds a
            validation/test loader (no shuffling, no pinned memory).
        dataset: The dataset to wrap.
        dataset_opt: Options dict; reads ``'n_workers'`` and ``'batch_size'``.
        opt: Global options dict. In the train phase its ``'world_size'``
            entry multiplies the worker count. When ``None`` (the default),
            a world size of 1 is assumed.
        sampler: Optional sampler for the train phase (the log message calls
            this the DDP path). When given, the sampler controls ordering and
            the loader does not shuffle.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    logger = logging.getLogger('base')
    if phase == 'train':
        # Honour the ``opt=None`` default instead of crashing on subscripting None.
        world_size = 1 if opt is None else opt['world_size']
        # NOTE(review): the worker count is scaled by world_size here; if every
        # DDP rank calls this function, the total worker count across ranks is
        # n_workers * world_size**2 — confirm this is intended.
        num_workers = dataset_opt['n_workers'] * world_size
        batch_size = dataset_opt['batch_size']
        if sampler is not None:
            logger.info('N_workers: {}, batch_size: {} DDP train dataloader has been established'.format(num_workers,
                                                                                                         batch_size))
        else:
            logger.info('N_workers: {}, batch_size: {} train dataloader has been established'.format(num_workers,
                                                                                                     batch_size))
        # DataLoader rejects shuffle=True combined with a sampler, so shuffle
        # only when no sampler was supplied.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           num_workers=num_workers, sampler=sampler,
                                           shuffle=sampler is None,
                                           pin_memory=True)
    logger.info(
        'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
            dataset_opt['n_workers'],
            dataset_opt['batch_size']))
    return torch.utils.data.DataLoader(dataset, batch_size=dataset_opt['batch_size'], shuffle=False,
                                       num_workers=dataset_opt['n_workers'],
                                       pin_memory=False)
def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
    """Instantiate the training dataset defined in module ``data.<dataset_name>``.

    Args:
        dataset_opt: Options dict passed to the dataset constructor; its
            ``'mode'``, ``'type'`` and ``'name'`` entries are also read for
            the log message.
        dataInfo: Extra dataset description forwarded unchanged to
            ``VideoBasedDataset``.
        phase: Only ``'train'`` is supported.
        dataset_name: Name of the module under the ``data`` package that
            defines ``VideoBasedDataset``.

    Returns:
        The constructed ``VideoBasedDataset`` instance.

    Raises:
        ValueError: If ``phase`` is not ``'train'`` (validation/test datasets
            are not created here).
    """
    if phase != 'train':  # validate and test dataset
        # Raise (not return) so callers cannot mistake the error for a dataset.
        raise ValueError('No dataset initialized for valdataset')
    dataset_package = import_module('data.{}'.format(dataset_name))
    dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)
    mode = dataset_opt['mode']
    logger = logging.getLogger('base')
    logger.info(
        '{} train dataset [{:s} - {:s} - {:s}] is created.'.format(dataset_opt['type'].upper(),
                                                                    dataset.__class__.__name__,
                                                                    dataset_opt['name'], mode))
    return dataset