Spaces:
Running
Running
File size: 4,823 Bytes
0d80816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)
class ScheduledSampler(Sampler):
    """Yield indices of a ConcatDataset grouped per sub-dataset.

    Every sub-dataset is cut into consecutive groups of ``batch_size``
    indices, the groups (not the individual indices) are shuffled with
    ``random.shuffle``, and the flattened index stream is returned.

    Args:
        concat_dataset (ConcatDataset): the concatenation of all datasets.
            A sub-dataset that is shorter than ``batch_size`` must provide
            ``get_dataset_name()`` so that it can be named in the warning.
        batch_size (int): number of indices per group; must be positive.
        holistic_shuffle (bool): if True the indices inside each sub-dataset
            are drawn with ``RandomSampler``, otherwise with
            ``SequentialSampler``.
        logger (logging.Logger, optional): logger used for the "dataset
            shorter than batch size" warning. Defaults to the module logger.
        loader_type (str): ``"valid"`` keeps the trailing partial group of
            each sub-dataset; any other value drops it.

    Raises:
        ValueError: if an argument has the wrong type or ``batch_size`` is
            not positive.

    Usage:
        With three sub-datasets of three items each, ``batch_size = 3`` and
        ``holistic_shuffle = False``, one possible result of
        ``list(ScheduledSampler(concat_dataset, 3, False))`` is::

            [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self,
        concat_dataset,
        batch_size,
        holistic_shuffle,
        logger=None,
        loader_type="train",
    ):
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        if batch_size <= 0:
            # A zero or negative size would make __len__ divide by zero or
            # report a negative length.
            raise ValueError(
                "batch_size must be a positive integer, but got {}".format(batch_size)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle
        self.type = loader_type

        # `logger` is optional, so fall back to the module logger instead of
        # failing with AttributeError on `None.warning`.
        if logger is None:
            logger = logging.getLogger(__name__)

        # Validation keeps partial groups, so short datasets only matter for
        # the other loader types.
        if loader_type != "valid":
            for dataset in concat_dataset.datasets:
                dataset_len = len(dataset)
                if dataset_len < batch_size:
                    logger.warning(
                        "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                            loader_type,
                            dataset.get_dataset_name(),
                            dataset_len,
                            batch_size,
                        )
                    )

    def __len__(self):
        """Return the number of indices that ``__iter__`` yields."""
        # Validation never drops samples, so every index is emitted.
        if self.type == "valid":
            return len(self.concat_dataset)
        # Otherwise only complete groups survive (drop-last per sub-dataset).
        full_groups = sum(
            len(dataset) // self.batch_size
            for dataset in self.concat_dataset.datasets
        )
        return full_groups * self.batch_size

    def __iter__(self):
        """Return an iterator over the group-shuffled global indices."""
        keep_partial = self.type == "valid"
        sampler_cls = RandomSampler if self.holistic_shuffle else SequentialSampler

        # Global index of the first item of each sub-dataset, e.g. [0, 200, 400].
        offsets = [0] + self.concat_dataset.cumulative_sizes[:-1]

        groups = []
        for offset, dataset in zip(offsets, self.concat_dataset.datasets):
            indices = [offset + idx for idx in sampler_cls(dataset)]
            for start in range(0, len(indices), self.batch_size):
                group = indices[start : start + self.batch_size]
                # Only the last slice can be short; keep it for validation
                # and drop it for every other loader type.
                if keep_partial or len(group) == self.batch_size:
                    groups.append(group)

        # Shuffle whole groups so each one still comes from one sub-dataset.
        random.shuffle(groups)
        return iter([index for group in groups for index in group])
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
    """Create the index sampler and matching batch sampler for one loader.

    Args:
        concat_dataset: dataset handed to ``ScheduledSampler``.
        cfg: configuration object; ``cfg.train.batch_size``,
            ``cfg.train.sampler.holistic_shuffle`` and
            ``cfg.train.sampler.drop_last`` are read.
        logger: logger forwarded to ``ScheduledSampler``.
        loader_type: loader kind; ``"valid"`` disables ``drop_last``.

    Returns:
        A ``(sampler, batch_sampler)`` tuple, where ``batch_sampler`` wraps
        ``sampler``.
    """
    index_sampler = ScheduledSampler(
        concat_dataset=concat_dataset,
        batch_size=cfg.train.batch_size,
        holistic_shuffle=cfg.train.sampler.holistic_shuffle,
        logger=logger,
        loader_type=loader_type,
    )

    batch_size = cfg.train.batch_size
    # Validation never drops the final batch; other loaders follow the config.
    if loader_type == "valid":
        drop_last = False
    else:
        drop_last = cfg.train.sampler.drop_last

    grouped_sampler = BatchSampler(
        index_sampler, batch_size=batch_size, drop_last=drop_last
    )
    return index_sampler, grouped_sampler
|