Maksym-Lysyi's picture
initial commit
e3641b1
raw
history blame contribute delete
567 Bytes
import torch.optim as optim
class LayerDecayOptimizer:
    """Wrap an optimizer and apply a fixed per-param-group learning-rate multiplier.

    On every ``step()`` the lr of param group ``i`` is multiplied by
    ``layerwise_decay_rate[i]`` only for the duration of the wrapped
    optimizer's update, and is restored afterwards. The multiplier is
    therefore applied once per update rather than compounding across steps
    (the previous in-place ``*=`` yielded ``lr * rate**k`` after k steps).
    ``group['lr']`` keeps holding the unscaled value, which is what an LR
    scheduler or logger reading/writing the wrapped optimizer will see.

    Args:
        optimizer: The optimizer to wrap; expected to expose ``param_groups``
            (a list of dicts with an ``'lr'`` key), ``step`` and ``zero_grad``
            like ``torch.optim.Optimizer``.
        layerwise_decay_rate: Indexable sequence of multipliers, one per entry
            of ``optimizer.param_groups``, matched by position.
    """

    def __init__(self, optimizer, layerwise_decay_rate):
        self.optimizer = optimizer
        self.layerwise_decay_rate = layerwise_decay_rate
        # Same list object as the wrapped optimizer's, so groups added later
        # via add_param_group() are visible through the wrapper too.
        self.param_groups = optimizer.param_groups

    def step(self, *args, **kwargs):
        """Run one update of the wrapped optimizer with per-group scaled lrs.

        All arguments are forwarded to ``optimizer.step``. Its return value
        (e.g. the loss computed by a closure) is returned.

        Raises:
            IndexError: if ``layerwise_decay_rate`` has fewer entries than
                there are param groups. All lrs are restored before the
                exception propagates.
        """
        groups = self.optimizer.param_groups
        base_lrs = [group['lr'] for group in groups]
        try:
            for i, group in enumerate(groups):
                # Rebind rather than `*=`: if lr is a tensor, an in-place
                # multiply would also mutate the reference saved in base_lrs.
                group['lr'] = group['lr'] * self.layerwise_decay_rate[i]
            return self.optimizer.step(*args, **kwargs)
        finally:
            # Undo the scaling so it is applied exactly once per step
            # instead of compounding across steps.
            for group, lr in zip(groups, base_lrs):
                group['lr'] = lr

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad(*args, **kwargs)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: delegate to the
        # wrapped optimizer so e.g. state_dict()/load_state_dict() work on
        # the wrapper as well.
        if name == 'optimizer':
            # 'optimizer' is not set yet (instance created via __new__ or
            # during unpickling); fail instead of recursing forever.
            raise AttributeError(name)
        return getattr(self.optimizer, name)