import math
from functools import partial

from torch.optim.lr_scheduler import LambdaLR


def get_scheduler(
    optimizer,
    start_lr,
    max_lr,
    min_lr,
    warmup_epochs,
    sustain_epochs,
    total_epochs,
    decay,
    mode="cosine",
):
    # NOTE: LambdaLR multiplies the optimizer's base lr by the value returned
    # from lr_lambda. Because lr_lambda below returns absolute learning rates,
    # the optimizer should be constructed with lr=1.0 for this schedule to
    # apply as written.
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Linear warmup from start_lr to max_lr.
            return (max_lr - start_lr) / warmup_epochs * epoch + start_lr
        elif epoch < warmup_epochs + sustain_epochs:
            # Hold at max_lr during the sustain phase.
            return max_lr
        elif mode == "exponential":
            # Exponential decay from max_lr towards min_lr.
            return (max_lr - min_lr) * decay ** (
                epoch - warmup_epochs - sustain_epochs
            ) + min_lr
        elif mode == "step":
            # Drop the learning rate by a factor of `decay` every 2 epochs.
            return max_lr * decay ** ((epoch - warmup_epochs - sustain_epochs) // 2)
        elif mode == "cosine":
            # Cosine decay from max_lr to min_lr over the remaining epochs.
            decay_total_epochs = total_epochs - warmup_epochs - sustain_epochs + 3
            decay_epoch_index = epoch - warmup_epochs - sustain_epochs
            phase = math.pi * decay_epoch_index / decay_total_epochs
            cosine_decay = 0.5 * (1 + math.cos(phase))
            return (max_lr - min_lr) * cosine_decay + min_lr
        else:
            raise ValueError(
                f"Unsupported mode '{mode}'. Supported modes are "
                "'exponential', 'step', 'cosine'."
            )

    return LambdaLR(optimizer, lr_lambda)


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
):
    if current_step < num_warmup_steps:
        # Linear warmup from 0 to the optimizer's initial lr.
        return float(current_step) / float(max(1, num_warmup_steps))
    # Fraction of the decay phase completed so far, in [0, 1].
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(
        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    )


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values
    of the cosine function between the initial lr set in the optimizer and 0,
    after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just
            decrease from the max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
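

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original schedulers above): the
    # model, optimizer choices, and hyperparameter values below are
    # placeholders chosen purely for illustration.
    import torch

    model = torch.nn.Linear(10, 2)

    # get_scheduler's lr_lambda returns absolute learning rates, and LambdaLR
    # multiplies them by the optimizer's base lr, so the optimizer is created
    # with lr=1.0 here and the scheduler is stepped once per epoch.
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    scheduler = get_scheduler(
        optimizer,
        start_lr=1e-5,
        max_lr=1e-3,
        min_lr=1e-6,
        warmup_epochs=5,
        sustain_epochs=0,
        total_epochs=50,
        decay=0.9,
        mode="cosine",
    )
    for epoch in range(50):
        # ... one epoch of training would go here ...
        optimizer.step()  # placeholder; a real loop calls step() per batch
        scheduler.step()

    # get_cosine_schedule_with_warmup instead scales the optimizer's initial
    # lr between 0 and 1, so set the real learning rate on the optimizer and
    # step the scheduler once per training step (batch).
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=1000
    )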