Spaces:
Sleeping
Sleeping
import logging | |
import torch | |
import torch.utils.data | |
from importlib import import_module | |
def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for the given phase.

    Args:
        phase: ``'train'`` selects the training configuration; any other
            value selects the validate/test configuration.
        dataset: Dataset object handed to the ``DataLoader``.
        dataset_opt: Mapping that provides ``'n_workers'`` and
            ``'batch_size'``.
        opt: Optional mapping providing ``'world_size'``. Only read in the
            train phase, where it scales the worker count. When omitted,
            a world size of 1 is used.
        sampler: Optional sampler for the train phase. When given, it is
            passed to the loader instead of enabling ``shuffle``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    logger = logging.getLogger('base')
    batch_size = dataset_opt['batch_size']

    if phase != 'train':
        # Validate/test: fixed order, worker count is NOT scaled by world size.
        logger.info(
            'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
                dataset_opt['n_workers'], batch_size))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=dataset_opt['n_workers'],
                                           pin_memory=False)

    # `opt` defaults to None in the signature; subscripting it unconditionally
    # raised TypeError. Treat a missing `opt` as a single-process run.
    world_size = 1 if opt is None else opt['world_size']
    num_workers = dataset_opt['n_workers'] * world_size

    if sampler is not None:
        # An explicit sampler decides the ordering, so `shuffle` is left unset.
        logger.info('N_workers: {}, batch_size: {} DDP train dataloader has been established'.format(
            num_workers, batch_size))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           num_workers=num_workers, sampler=sampler,
                                           pin_memory=True)

    logger.info('N_workers: {}, batch_size: {} train dataloader has been established'.format(
        num_workers, batch_size))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       num_workers=num_workers, shuffle=True,
                                       pin_memory=True)
def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
    """Instantiate the training dataset from the module ``data.<dataset_name>``.

    Args:
        dataset_opt: Mapping passed to the dataset constructor; its
            ``'mode'``, ``'type'`` and ``'name'`` entries are also used in
            the log message.
        dataInfo: Second argument forwarded to the dataset constructor.
        phase: Must be ``'train'``; no other phase is supported here.
        dataset_name: Name of the submodule under ``data`` that exposes a
            ``VideoBasedDataset`` class.

    Returns:
        The constructed ``VideoBasedDataset`` instance.

    Raises:
        ValueError: If ``phase`` is not ``'train'``.
    """
    if phase != 'train':  # validate and test dataset
        # The exception used to be *returned*, which handed callers a
        # ValueError instance in place of a dataset. Raise it instead.
        raise ValueError('No dataset initialized for valdataset')

    dataset_package = import_module('data.{}'.format(dataset_name))
    dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)

    mode = dataset_opt['mode']
    logger = logging.getLogger('base')
    logger.info(
        '{} train dataset [{:s} - {:s} - {:s}] is created.'.format(
            dataset_opt['type'].upper(),
            dataset.__class__.__name__,
            dataset_opt['name'],
            mode))
    return dataset