File size: 764 Bytes
03b684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch.optim.lr_scheduler import _LRScheduler
import math


class CosineLR(_LRScheduler):
    """Cosine learning-rate schedule indexed by ``last_epoch``.

    On every ``step()`` each parameter group of the wrapped optimizer is set to
    ``0.5 * (1 + cos(pi * last_epoch / total_epochs)) * init_lr``.  The one
    exception is ``last_epoch == 0``, where the optimizer's current learning
    rates are returned unchanged.

    The schedule is not clamped: once ``last_epoch`` exceeds ``total_epochs``
    the cosine term keeps evolving, so the learning rate rises again.

    Args:
        optimizer: Optimizer whose ``param_groups`` receive the learning rate.
        init_lr: Learning rate the cosine factor is multiplied with; the same
            value is applied to every parameter group.
        total_epochs: Number of epochs (or steps) over which the cosine factor
            goes from 1 to 0.  Must be positive.
        last_epoch: Value stored as ``self.last_epoch`` after construction; the
            next ``step()`` call moves the schedule to ``last_epoch + 1``.

    Raises:
        ValueError: If ``total_epochs`` is not positive.
    """

    def __init__(self, optimizer, init_lr, total_epochs, last_epoch=-1):
        # Fail here with a clear message instead of with a ZeroDivisionError
        # deep inside get_lr() on some later step.
        if total_epochs <= 0:
            raise ValueError(
                f'total_epochs must be positive, got {total_epochs!r}')

        # Assigned before the base constructor runs: PyTorch's scheduler base
        # class performs an initial step() on construction, which calls
        # get_lr(), so these attributes must already exist at that point.
        self.init_lr = init_lr
        self.total_epochs = total_epochs

        # The base class is always initialised as a fresh run (-1); the
        # requested position is applied afterwards.
        # NOTE(review): presumably this avoids the base class requiring
        # 'initial_lr' in the param groups when resuming -- confirm.
        super().__init__(optimizer, last_epoch=-1)
        self.last_epoch = last_epoch
        print(f'CosineLR start from epoch(step) {last_epoch} with init_lr {init_lr} ')

    def get_lr(self):
        """Return the learning rate for each parameter group at ``last_epoch``."""
        param_groups = self.optimizer.param_groups

        # Epoch 0 leaves whatever learning rates the optimizer currently has.
        if self.last_epoch == 0:
            return [group['lr'] for group in param_groups]

        # The cosine factor is identical for every group, so compute it once.
        cosine_factor = 0.5 * (1 + math.cos(self.last_epoch * math.pi / self.total_epochs))
        lr = cosine_factor * self.init_lr
        return [lr for _ in param_groups]