Weiyun1025's picture
Upload folder using huggingface_hub
2abfccb verified
from torch.utils.data.dataloader import (
torch, python_multiprocessing, multiprocessing, ExceptionWrapper, threading,
itertools, queue, string_classes, _utils, _DatasetKind, _InfiniteConstantSampler,
IterableDataset, SequentialSampler, RandomSampler, BatchSampler)
from ._utils.pin_memory import _pin_memory_loop
from ._utils.worker import _worker_loop, _ResumeIteration
from petrel_client.utils.profile import profileit
class DataLoader(object):
    """Iterable that combines a dataset with a sampler and yields batches.

    Args:
        dataset: map-style dataset, or an ``IterableDataset`` instance.
        batch_size: samples per batch; ``None`` disables automatic batching.
        shuffle: use a ``RandomSampler`` instead of a ``SequentialSampler``
            when no ``sampler`` is given (map-style datasets only).
        sampler: custom index sampler; mutually exclusive with ``shuffle``.
        batch_sampler: custom sampler yielding batches of indices; mutually
            exclusive with ``batch_size``, ``shuffle``, ``sampler`` and
            ``drop_last``.
        num_workers: number of worker processes; ``0`` loads in-process.
        collate_fn: callable applied to fetched samples; defaults to
            ``default_collate`` with automatic batching and to
            ``default_convert`` without it.
        pin_memory: stored on the loader and read by the iterators.
        drop_last: forwarded to the automatically created ``BatchSampler``.
        timeout: non-negative timeout stored on the loader for the iterators.
        worker_init_fn: stored on the loader and handed to worker processes.
        multiprocessing_context: ``None``, a start-method name, or a
            multiprocessing context object; requires ``num_workers > 0``.
        prefetch_factor: positive number of batches primed per worker; may
            only differ from ``2`` when ``num_workers > 0``.
        persistent_workers: reuse a single multi-process iterator across
            calls to ``iter()``; requires ``num_workers > 0``.

    Raises:
        ValueError: if any option is out of range or options are combined
            in a mutually exclusive way.
    """

    # Class-level default; flipped to True on the instance at the end of
    # `__init__` so that `__setattr__` can lock the core attributes.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None,
                 *, prefetch_factor=2, persistent_workers=False):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing. '
                             'let num_workers > 0 to enable multiprocessing.')
        # Raised explicitly (rather than asserted) so the check survives
        # `python -O` and matches the other argument checks in this method.
        if not prefetch_factor > 0:
            raise ValueError('prefetch_factor option should be positive')

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        # `num_workers` must be assigned before `multiprocessing_context`,
        # whose property setter reads it.
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and `IterableDataset` ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler`.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on `__setattr__` rejects changes to the core attributes.
        self.__initialized = True

        # Cached multi-process iterator, used only with persistent workers.
        self._iterator = None

    def _get_iterator(self):
        """Build a fresh iterator matching the configured ``num_workers``."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def multiprocessing_context(self):
        """Validated multiprocessing context object, or ``None``."""
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        # Accepts None, a start-method name (resolved to a context object),
        # or a context object. Anything non-None requires num_workers > 0.
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                     'support for different start methods')

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {}, but got '
                             'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
                    multiprocessing_context = multiprocessing.get_context(
                        multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise ValueError(('multiprocessing_context option should be a valid context '
                                      'object or a string specifying the start method, but got '
                                      'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        # The sampler/batching configuration is derived once in `__init__`;
        # changing any of these afterwards would leave it inconsistent.
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        # Without persistent workers a new iterator is created on every call
        # so that no state is carried over between passes.
        # With persistent workers (which requires num_workers > 0) the
        # iterator is created once in the lifetime of the DataLoader object
        # and reset on later calls, so that the workers can be reused.
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        """Whether automatic batching is on (a batch sampler is in use)."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        # with iterable-style dataset, this will error
        return len(self._index_sampler)
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
self._persistent_workers = loader.persistent_workers
def __iter__(self):
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def __next__(self):
raise NotImplementedError
def __len__(self):
return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError(
"{} cannot be pickled", self.__class__.__name__)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that fetches every item in the calling process."""

    def __init__(self, loader):
        super().__init__(loader)

        # In-process loading is only valid with no timeout and no workers.
        assert self._timeout == 0
        assert self._num_workers == 0

        fetcher_args = (
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )
        self._dataset_fetcher = _DatasetKind.create_fetcher(*fetcher_args)

    def __next__(self):
        # Either the sampler or the fetcher may raise StopIteration.
        batch = self._dataset_fetcher.fetch(self._next_index())
        if not self._pin_memory:
            return batch
        return _utils.pin_memory.pin_memory(batch)

    next = __next__  # Python 2 compatibility
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that loads data through a pool of worker processes.

    Each worker owns an index queue that the main process feeds with
    ``(task_idx, index)`` pairs; all workers write results to one shared
    result queue. When memory pinning is active, a helper thread moves
    results from that queue into an in-process ``queue.Queue``. Results are
    handed out strictly in task order, buffering out-of-order arrivals.
    """

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        # Round-robin cursor over worker ids used to assign new tasks.
        self._worker_queue_idx_cycle = itertools.cycle(
            range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(
            id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True

        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        """Reset the bookkeeping for a new pass and prime the prefetch queue.

        On every call except the first, each worker is sent a
        ``_ResumeIteration`` message and this method waits until one such
        message per worker has come back before issuing new indices.
        """
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._tasks_outstanding = 0
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Note that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _ResumeIteration):
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            if failed_workers:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError(
                    'DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError(
                    'DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    @profileit
    def __next__(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                # has data or is still active
                if len(info) == 2 or self._workers_status[worker_id]:
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        # Keep the worker process around for the next pass;
                        # only flag it as done for this one.
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    next = __next__  # Python 2 compatibility

    def _try_put_index(self):
        """Send the next sampler index to the next active worker, if any."""
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        """Advance the receive cursor, refill the pipeline and return ``data``.

        If ``data`` is an ``ExceptionWrapper``, the wrapped exception is
        re-raised instead of being returned.
        """
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        # Mark a worker as having finished its work e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id] or (
            self._persistent_workers and shutdown)

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.

        self._workers_status[worker_id] = False

        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.cancel_join_thread()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    # If we are using workers_status with persistent_workers
                    # we have to shut it down because the worker is paused
                    if self._persistent_workers or self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(
                            worker_id, shutdown=True)
                for w in self._workers:
                    # We should be able to join here, but in case anything went
                    # wrong the join is bounded by a timeout: an unbounded join
                    # on a stuck worker would block forever and the `finally`
                    # clause below, which kills workers that are still alive,
                    # would never run.
                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism here in this function, as it
                #        doesn't provide a SIGCHLD handler.
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False
                for w in self._workers:
                    if w.is_alive():
                        # Existing mechanisms try to make the workers exit
                        # peacefully, but in case that we unfortunately reach
                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                        # we kill the worker.
                        w.terminate()

    def __del__(self):
        self._shutdown_workers()