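# NOTE(review): the next two lines appear to be status text captured from the
# page this file was copied from, not part of the module - confirm and drop.
# Spaces: Runtime error
# Runtime error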
from typing import Any, Dict, List, Optional, Union

from lightning import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from mmengine.optim import _ParamScheduler
from mmengine.utils import is_list_of

from mmpl.registry import HOOKS
# Type alias: a dict, tuple or list, or None.
# NOTE(review): presumably describes a batch coming from a dataloader; it is
# not referenced by the code visible in this file - confirm external users.
DATA_BATCH = Optional[Union[dict, tuple, list]]
class ParamSchedulerHook(Callback):
    """A hook to update some hyper-parameters in optimizer, e.g., learning
    rate and momentum.

    The schedulers are read from ``pl_module.lr_schedulers()`` each time a
    hook fires. Schedulers whose ``by_epoch`` flag is false are stepped after
    every training batch; schedulers whose ``by_epoch`` flag is true are
    stepped after every training epoch. After every validation epoch,
    epoch-based schedulers that set ``need_val_args`` are additionally
    stepped with ``trainer.callback_metrics``.
    """

    priority = 'LOW'

    @staticmethod
    def _as_groups(param_schedulers: Any) -> List[Any]:
        """Normalize the schedulers object into a list of scheduler groups.

        Args:
            param_schedulers: ``None``, a single ``_ParamScheduler``, a list
                of schedulers, or a dict whose values are scheduler groups.

        Returns:
            list: One entry per group. ``None`` gives an empty list, a single
            scheduler or a list gives one group, and a dict gives one group
            per value (in the dict's iteration order).

        Raises:
            TypeError: If ``param_schedulers`` is none of the accepted kinds.
        """
        if param_schedulers is None:
            return []
        if isinstance(param_schedulers, _ParamScheduler):
            return [[param_schedulers]]
        if isinstance(param_schedulers, list):
            return [param_schedulers]
        if isinstance(param_schedulers, dict):
            return list(param_schedulers.values())
        # The message wording is kept from the original implementation.
        raise TypeError(
            'runner.param_schedulers should be list of ParamScheduler or '
            'a dict containing list of ParamScheduler, '
            f'but got {param_schedulers}')

    @staticmethod
    def _require_list(group: Any) -> None:
        """Raise ``TypeError`` unless ``group`` is a list.

        An explicit raise is used instead of ``assert`` so the check is not
        stripped when Python runs with ``-O``.
        """
        if not isinstance(group, list):
            raise TypeError(
                'each group of param schedulers should be a list, '
                f'but got {type(group).__name__}')

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Call step function for each iteration-based scheduler after each
        training iteration.

        Args:
            trainer (pl.Trainer): Not used here; kept for the callback
                interface.
            pl_module (pl.LightningModule): Module whose ``lr_schedulers()``
                provides the schedulers to step.
            outputs (STEP_OUTPUT): Not used here; kept for the callback
                interface.
            batch (Any): Not used here; kept for the callback interface.
            batch_idx (int): Not used here; kept for the callback interface.

        Raises:
            TypeError: If the schedulers object, or one of its groups, has an
                unsupported type.
        """
        for group in self._as_groups(pl_module.lr_schedulers()):
            self._require_list(group)
            for scheduler in group:
                if not scheduler.by_epoch:
                    scheduler.step()

    def on_train_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
    ) -> None:
        """Call step function for each epoch-based scheduler after each
        training epoch.

        Args:
            trainer (pl.Trainer): Not used here; kept for the callback
                interface.
            pl_module (pl.LightningModule): Module whose ``lr_schedulers()``
                provides the schedulers to step.

        Raises:
            TypeError: If the schedulers object, or one of its groups, has an
                unsupported type.
        """
        for group in self._as_groups(pl_module.lr_schedulers()):
            self._require_list(group)
            for scheduler in group:
                if scheduler.by_epoch:
                    scheduler.step()

    def on_validation_epoch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
    ) -> None:
        """Call step function for each scheduler which has attribute
        ``need_val_args`` after each validation epoch.

        Args:
            trainer (pl.Trainer): Its ``callback_metrics`` are passed to
                ``scheduler.step``.
            pl_module (pl.LightningModule): Module whose ``lr_schedulers()``
                provides the schedulers to step.

        Note:
            A group that is not a list of ``_ParamScheduler`` is skipped
            silently instead of raising.

        Raises:
            TypeError: If the schedulers object has an unsupported type.
        """
        param_schedulers = pl_module.lr_schedulers()
        if param_schedulers is None:
            return

        # Original note: avoid counting scheduler._global_step here, it has
        # already been counted in the training hooks.
        # NOTE(review): this relies on how ``step(metrics)`` is implemented
        # by the scheduler - confirm against the scheduler class in use.
        metrics = trainer.callback_metrics
        if metrics is None:
            return

        for group in self._as_groups(param_schedulers):
            # Only step groups that are lists of built schedulers.
            if not is_list_of(group, _ParamScheduler):
                continue
            for scheduler in group:
                if (scheduler.by_epoch
                        and getattr(scheduler, 'need_val_args', False)):
                    scheduler.step(metrics)