Spaces:
Sleeping
Sleeping
import torch.optim as optim | |
class LayerDecayOptimizer: | |
def __init__(self, optimizer, layerwise_decay_rate): | |
self.optimizer = optimizer | |
self.layerwise_decay_rate = layerwise_decay_rate | |
self.param_groups = optimizer.param_groups | |
def step(self, *args, **kwargs): | |
for i, group in enumerate(self.optimizer.param_groups): | |
group['lr'] *= self.layerwise_decay_rate[i] | |
self.optimizer.step(*args, **kwargs) | |
def zero_grad(self, *args, **kwargs): | |
self.optimizer.zero_grad(*args, **kwargs) |