Spaces:
Runtime error
Runtime error
import math | |
from typing import Optional | |
import numpy as np | |
from typing import Dict, Optional, Union | |
from mmengine.registry import HOOKS | |
from .param_scheduler_hook import ParamSchedulerHook | |
# Type alias for a data batch: a dict, tuple or list, or None.
DATA_BATCH = Union[dict, tuple, list, None]
def linear_fn(lr_factor: float, max_epochs: int):
    """Build a linear learning-rate decay schedule.

    Args:
        lr_factor: Multiplier returned when ``x == max_epochs``.
        max_epochs: Length of the schedule.

    Returns:
        A callable mapping an epoch index ``x`` to a multiplicative factor
        that moves linearly from ``1.0`` (at ``x == 0``) to ``lr_factor``
        (at ``x == max_epochs``).
    """

    def _linear_factor(x):
        remaining = 1 - x / max_epochs
        return remaining * (1.0 - lr_factor) + lr_factor

    return _linear_factor
def cosine_fn(lr_factor: float, max_epochs: int):
    """Build a half-cosine learning-rate decay schedule.

    Args:
        lr_factor: Multiplier returned when ``x == max_epochs``.
        max_epochs: Length of the schedule.

    Returns:
        A callable mapping an epoch index ``x`` to a multiplicative factor
        that follows half a cosine period from ``1.0`` (at ``x == 0``) down
        to ``lr_factor`` (at ``x == max_epochs``).
    """

    def _cosine_factor(x):
        angle = x * math.pi / max_epochs
        # Fraction of the way from 1.0 to lr_factor, in [0, 1].
        weight = (1 - math.cos(angle)) / 2
        return weight * (lr_factor - 1) + 1

    return _cosine_factor
class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of YOLOv5.

    Only the first optimizer (``trainer.optimizers[0]``) is touched.

    Warmup: while ``trainer.global_step`` is at most
    ``max(round(warmup_epochs * len(trainer.train_dataloader)),
    warmup_mim_iter)``, each call to ``on_before_backward`` linearly
    interpolates the learning rate of every param group towards
    ``base_lr * scheduler_fn(trainer.current_epoch)``. The starting value is
    ``0.0``, or ``warmup_bias_lr`` for the group at ``bias_group_idx``.
    Momentum, for groups that have a ``'momentum'`` key, is interpolated
    from ``warmup_momentum`` to its recorded initial value.

    After warmup: once the step counter has passed the warmup length,
    ``on_train_epoch_end`` sets each group's learning rate to
    ``base_lr * scheduler_fn(trainer.current_epoch)``.

    Args:
        scheduler_type: Key into ``scheduler_maps`` (``'linear'`` or
            ``'cosine'``).
        lr_factor: Multiplier the schedule reaches at ``max_epochs``.
        max_epochs: Number of epochs the decay schedule spans.
        warmup_epochs: Warmup length expressed in epochs.
        warmup_bias_lr: Starting learning rate of the group at
            ``bias_group_idx``.
        warmup_momentum: Starting momentum during warmup.
        warmup_mim_iter: Lower bound on the number of warmup iterations.
        bias_group_idx: Index of the param group whose warmup starts from
            ``warmup_bias_lr`` instead of ``0.0`` (expected to hold the bias
            parameters; defaults to ``2`` as in YOLOv5).
        **kwargs: Extra keyword arguments forwarded to the scheduler factory
            together with ``lr_factor`` and ``max_epochs``.
    """
    # NOTE(review): ``priority`` comes from the mmengine Hook API; presumably
    # unused by Lightning-style callbacks -- confirm before removing.
    priority = 9
    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}

    def __init__(self,
                 scheduler_type: str = 'linear',
                 lr_factor: float = 0.01,
                 max_epochs: int = 300,
                 warmup_epochs: int = 3,
                 warmup_bias_lr: float = 0.1,
                 warmup_momentum: float = 0.8,
                 warmup_mim_iter: int = 500,
                 bias_group_idx: int = 2,
                 **kwargs):
        assert scheduler_type in self.scheduler_maps, (
            f'Unsupported scheduler_type {scheduler_type!r}; expected one '
            f'of {sorted(self.scheduler_maps)}')
        self.warmup_epochs = warmup_epochs
        self.warmup_bias_lr = warmup_bias_lr
        self.warmup_momentum = warmup_momentum
        self.warmup_mim_iter = warmup_mim_iter
        self.bias_group_idx = bias_group_idx
        # Any extra kwargs go to the scheduler factory alongside the two
        # arguments every factory takes.
        kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
        self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)
        # Set to True by on_before_backward once the step counter has
        # passed the warmup length; gates the per-epoch update.
        self._warmup_end = False
        # Per-param-group initial lr / momentum, filled by
        # _record_base_values.
        self._base_lr = None
        self._base_momentum = None

    def _record_base_values(self, optimizer) -> None:
        """Snapshot the initial lr and momentum of every param group."""
        for group in optimizer.param_groups:
            # If the param has never been scheduled, record the current
            # value as the initial value.
            group.setdefault('initial_lr', group['lr'])
            # -1 is a placeholder for groups without a 'momentum' key.
            group.setdefault('initial_momentum', group.get('momentum', -1))
        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._base_momentum = [
            group['initial_momentum'] for group in optimizer.param_groups
        ]

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record the base learning rates/momenta of the first optimizer."""
        self._record_base_values(trainer.optimizers[0])

    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss) -> None:
        """Apply the per-step warmup interpolation, or mark warmup as over."""
        cur_iters = trainer.global_step
        cur_epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]
        if self._base_lr is None:
            # on_fit_start has not run; capture the base values now instead
            # of failing on ``None[group_idx]`` below.
            self._record_base_values(optimizer)
        # The minimum warmup is self.warmup_mim_iter
        warmup_total_iters = max(
            round(self.warmup_epochs * len(trainer.train_dataloader)),
            self.warmup_mim_iter)
        if cur_iters > warmup_total_iters:
            self._warmup_end = True
            return

        xp = [0, warmup_total_iters]
        # The schedule factor depends only on the epoch: evaluate it once.
        epoch_scale = self.scheduler_fn(cur_epoch)
        for group_idx, param in enumerate(optimizer.param_groups):
            # The bias learning rate is handled specially: it warms up from
            # warmup_bias_lr instead of from 0.
            if group_idx == self.bias_group_idx:
                start_lr = self.warmup_bias_lr
            else:
                start_lr = 0.0
            yp = [start_lr, self._base_lr[group_idx] * epoch_scale]
            param['lr'] = np.interp(cur_iters, xp, yp)
            if 'momentum' in param:
                param['momentum'] = np.interp(
                    cur_iters, xp,
                    [self.warmup_momentum, self._base_momentum[group_idx]])

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """After warmup, set each group's lr to ``base_lr * scheduler_fn(epoch)``."""
        if not self._warmup_end:
            return
        epoch_scale = self.scheduler_fn(trainer.current_epoch)
        optimizer = trainer.optimizers[0]
        for group_idx, param in enumerate(optimizer.param_groups):
            param['lr'] = self._base_lr[group_idx] * epoch_scale