# RSPrompter: mmpl/engine/hooks/yolov5_param_scheduler_hook.py
import math
from typing import Dict, Optional, Union

import numpy as np

from mmengine.registry import HOOKS

from .param_scheduler_hook import ParamSchedulerHook

# Shape of a batch yielded by the dataloader, as passed to mmengine-style
# hook callbacks (kept for API parity even if unused in this module).
DATA_BATCH = Optional[Union[dict, tuple, list]]
def linear_fn(lr_factor: float, max_epochs: int):
    """Return a function mapping an epoch index to a linear lr scale.

    The returned callable decays linearly from 1.0 at epoch 0 down to
    ``lr_factor`` at epoch ``max_epochs``.
    """

    def _linear_scale(epoch):
        progress = epoch / max_epochs
        return (1 - progress) * (1.0 - lr_factor) + lr_factor

    return _linear_scale
def cosine_fn(lr_factor: float, max_epochs: int):
    """Return a function mapping an epoch index to a cosine lr scale.

    The returned callable follows half a cosine wave: it starts at 1.0
    at epoch 0 and anneals down to ``lr_factor`` at epoch ``max_epochs``.
    """

    def _cosine_scale(epoch):
        half_wave = (1 - math.cos(epoch * math.pi / max_epochs)) / 2
        return half_wave * (lr_factor - 1) + 1

    return _cosine_scale
@HOOKS.register_module()
class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of YOLOv5.

    Warmup phase: for the first ``warmup_epochs`` epochs — but never fewer
    than ``warmup_mim_iter`` iterations — each param group's lr is linearly
    interpolated per iteration from a warmup start value up to the value
    prescribed by the epoch schedule. After warmup, the lr is updated once
    per epoch via ``scheduler_fn``.

    Args:
        scheduler_type (str): Key into ``scheduler_maps``: ``'linear'`` or
            ``'cosine'``. Defaults to ``'linear'``.
        lr_factor (float): Final lr scale; lr decays from ``initial_lr`` to
            ``initial_lr * lr_factor``. Defaults to 0.01.
        max_epochs (int): Total epochs the schedule spans. Defaults to 300.
        warmup_epochs (int): Length of warmup in epochs. Defaults to 3.
        warmup_bias_lr (float): Warmup start lr for param group index 2
            (bias group, by YOLOv5 optimizer convention — confirm against
            the optimizer constructor). Defaults to 0.1.
        warmup_momentum (float): Warmup start momentum. Defaults to 0.8.
        warmup_mim_iter (int): Minimum number of warmup iterations; the
            name's "mim" is a historical typo for "min", kept for
            backward compatibility. Defaults to 500.
        **kwargs: Extra keyword arguments forwarded to the scheduler
            factory function.

    Raises:
        ValueError: If ``scheduler_type`` is not a key of
            ``scheduler_maps``.
    """

    priority = 9

    # Maps ``scheduler_type`` to a factory returning an
    # epoch-index -> lr-scale function.
    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}

    def __init__(self,
                 scheduler_type: str = 'linear',
                 lr_factor: float = 0.01,
                 max_epochs: int = 300,
                 warmup_epochs: int = 3,
                 warmup_bias_lr: float = 0.1,
                 warmup_momentum: float = 0.8,
                 warmup_mim_iter: int = 500,
                 **kwargs):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, silently skipping config validation.
        if scheduler_type not in self.scheduler_maps:
            raise ValueError(
                f'scheduler_type must be one of '
                f'{list(self.scheduler_maps)}, but got {scheduler_type!r}')
        self.warmup_epochs = warmup_epochs
        self.warmup_bias_lr = warmup_bias_lr
        self.warmup_momentum = warmup_momentum
        self.warmup_mim_iter = warmup_mim_iter
        # Build the epoch -> lr-scale function once; reused every
        # iteration/epoch below.
        kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
        self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)

        # Set to True the first time an iteration past warmup is seen.
        self._warmup_end = False
        # Per-group base values recorded in ``on_fit_start``.
        self._base_lr = None
        self._base_momentum = None

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record each param group's initial lr/momentum before training."""
        optimizer = trainer.optimizers[0]
        for group in optimizer.param_groups:
            # If the param is never scheduled, record the current value
            # as the initial value. -1 marks groups without momentum.
            group.setdefault('initial_lr', group['lr'])
            group.setdefault('initial_momentum', group.get('momentum', -1))
        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._base_momentum = [
            group['initial_momentum'] for group in optimizer.param_groups
        ]

    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss) -> None:
        """Apply per-iteration warmup to lr (and momentum, if present).

        During warmup each group's lr is interpolated from its warmup
        start value to the epoch-scheduled value; group index 2 starts
        from ``warmup_bias_lr``, all other groups from 0.
        """
        if self._warmup_end:
            # Warmup finished: skip the per-iteration work entirely; the
            # epoch-level schedule is applied in ``on_train_epoch_end``.
            return
        cur_iters = trainer.global_step
        cur_epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]
        # The warmup lasts ``warmup_epochs`` epochs, but at least
        # ``warmup_mim_iter`` iterations.
        warmup_total_iters = max(
            round(self.warmup_epochs * len(trainer.train_dataloader)),
            self.warmup_mim_iter)
        if cur_iters <= warmup_total_iters:
            xp = [0, warmup_total_iters]
            for group_idx, param in enumerate(optimizer.param_groups):
                # Bias group (index 2) warms up from ``warmup_bias_lr``
                # instead of 0 — special-cased per YOLOv5 practice.
                warmup_start_lr = (
                    self.warmup_bias_lr if group_idx == 2 else 0.0)
                yp = [
                    warmup_start_lr,
                    self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
                ]
                param['lr'] = np.interp(cur_iters, xp, yp)
                if 'momentum' in param:
                    param['momentum'] = np.interp(
                        cur_iters, xp,
                        [self.warmup_momentum,
                         self._base_momentum[group_idx]])
        else:
            self._warmup_end = True

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """After warmup, set each group's lr from the epoch schedule."""
        if not self._warmup_end:
            # Still in warmup; lr is handled per-iteration instead.
            return
        cur_epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]
        for group_idx, param in enumerate(optimizer.param_groups):
            param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
                cur_epoch)