import warnings

import torch.optim as optim


class StepLRScheduler(optim.lr_scheduler._LRScheduler):
    """Decays the learning rate of every parameter group by `gamma`
    each time `step_size` epochs elapse (mirrors
    `torch.optim.lr_scheduler.StepLR`)."""

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # The base class sets `_get_lr_called_within_step` while running
        # `step()`; calling `get_lr()` directly returns a chained value
        # rather than the final learning rate, so warn the caller.
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        # Keep the current rate except on epochs that are nonzero exact
        # multiples of `step_size`, where it is scaled by `gamma`.
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        # Closed form: base_lr * gamma ** (last_epoch // step_size).
        # The base class uses this when `last_epoch` is supplied directly,
        # e.g. when resuming from a checkpoint.
        return [
            base_lr * self.gamma ** (self.last_epoch // self.step_size)
            for base_lr in self.base_lrs
        ]
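
# A minimal usage sketch of the scheduler above. The `nn.Linear` model, the
# initial lr of 0.1, and the 90-epoch loop are illustrative assumptions, not
# part of the scheduler itself; the body of the training loop is elided.
import torch.nn as nn

model = nn.Linear(10, 2)  # hypothetical model
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLRScheduler(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # ... forward pass, loss.backward(), then:
    optimizer.step()   # placeholder for a real optimization step
    scheduler.step()   # lr = 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89

# After 90 steps, last_epoch is 90, so lr = 0.1 * 0.1 ** 3 (up to float error).
print(scheduler.get_last_lr())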