RSPrompter / mmpl /engine /hooks /pipeline_switch_hook.py
KyanChen's picture
Upload 159 files
1c3eb47
raw
history blame
1.68 kB
from mmcv.transforms import Compose
from mmpl.registry import HOOKS
from lightning.pytorch.callbacks import Callback
@HOOKS.register_module()
class PipelineSwitchHook(Callback):
    """Switch the training data pipeline once ``switch_epoch`` is reached.

    At the start of the first training epoch whose index
    (``trainer.current_epoch``) is greater than or equal to ``switch_epoch``,
    the training dataset's ``pipeline`` attribute is replaced by
    ``Compose(switch_pipeline)``. The ``>=`` comparison combined with a
    one-shot flag means the switch is still applied when training is resumed
    from a checkpoint saved after ``switch_epoch`` (an ``==`` check would
    never fire in that case).

    Args:
        switch_epoch (int): Epoch index at which the pipeline is switched.
        switch_pipeline (list[dict]): Transform configs passed to
            ``mmcv.transforms.Compose`` to build the new pipeline.
    """

    def __init__(self, switch_epoch: int, switch_pipeline) -> None:
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        # True while the dataloader's private init flag has been cleared and
        # still has to be restored at the start of the next epoch.
        self._restart_dataloader = False
        # Ensures the pipeline is switched at most once per hook instance.
        self._has_switched = False

    def on_train_epoch_start(self, trainer: "pl.Trainer",
                             pl_module: "pl.LightningModule") -> None:
        """Switch the pipeline, or restore the dataloader flag afterwards.

        Args:
            trainer: The Lightning trainer; ``current_epoch``, ``local_rank``
                and ``train_dataloader`` are read from it.
            pl_module: Unused; part of the Lightning callback signature.
        """
        epoch = trainer.current_epoch
        # NOTE(review): assumes a single DataLoader exposing ``.dataset`` is
        # returned here (not a collection of loaders) — confirm per setup.
        train_loader = trainer.train_dataloader
        if epoch >= self.switch_epoch and not self._has_switched:
            if trainer.local_rank == 0:
                print('Switch pipeline now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            train_loader.dataset.pipeline = Compose(self.switch_pipeline)
            if getattr(train_loader, 'persistent_workers', False) is True:
                # Touches private torch DataLoader attributes: dropping the
                # cached iterator makes the next iteration build a new one
                # (and new workers). TODO confirm against the installed torch
                # version, since these names are not public API.
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            self._has_switched = True
        elif self._restart_dataloader:
            # Once the restart is complete, we need to restore the
            # initialization flag; do it only once.
            train_loader._DataLoader__initialized = True
            self._restart_dataloader = False