File size: 260 Bytes
b3f324b
 
 
 
 
 
 
1
2
3
4
5
6
7
import torch

def cosine_scheduler(step, max_steps, value_base=1, value_end=0):
    step = torch.tensor(step)
    cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps))
    value = value_end + (value_base - value_end) * cosine_value
    return value