File size: 847 Bytes
8683813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
# Copyright (c) OpenMMLab. All rights reserved.
from .hook import HOOKS, Hook
@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """Data-loading sampler for distributed training.

    When distributed training, it is only useful in conjunction with
    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
    purpose with :obj:`IterLoader`.
    """

    def before_epoch(self, runner):
        """Pass the runner's current epoch to the data loader's sampler.

        The sampler is looked up first at ``runner.data_loader.sampler`` and
        then at ``runner.data_loader.batch_sampler.sampler``. The first one
        that exposes a ``set_epoch`` method is called with ``runner.epoch``.
        If neither exposes it, the hook does nothing.

        Args:
            runner: The runner driving training. It must provide
                ``data_loader`` and ``epoch`` attributes.
        """
        data_loader = runner.data_loader

        # Samplers without ``set_epoch`` (e.g. PyTorch's `SequentialSampler`)
        # fall through to the batch-sampler lookup below.
        sampler = getattr(data_loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(runner.epoch)
            return

        # The batch sampler in PyTorch wraps the sampler as its attribute.
        # ``batch_sampler`` can be None (a PyTorch DataLoader created with
        # ``batch_size=None``) and a custom batch sampler may not expose
        # ``sampler`` at all, so neither attribute is assumed to exist.
        batch_sampler = getattr(data_loader, 'batch_sampler', None)
        wrapped_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(wrapped_sampler, 'set_epoch'):
            wrapped_sampler.set_epoch(runner.epoch)
|