import math
from typing import Callable, Optional, Tuple

import torch
from came_pytorch import CAME
from mmcv import Config
from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer

from .logger import get_root_logger


def auto_scale_lr(effective_bs, optimizer_cfg, rule="linear", base_batch_size=256):
    assert rule in ["linear", "sqrt"]
    logger = get_root_logger()

    if rule == "sqrt":
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == "linear":
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg["lr"] *= scale_ratio
    logger.info(f'Automatically adapted lr to {optimizer_cfg["lr"]:.5f} (using the {rule} scaling rule).')
    return scale_ratio
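
# Example (illustrative numbers): with effective_bs=1024 and the default
# base_batch_size=256, the "linear" rule multiplies lr by 4.0 and the "sqrt"
# rule by 2.0; optimizer_cfg["lr"] is updated in place and the ratio returned.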


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix="", is_dcn_module=None):
        """Add all parameters of ``module`` to the ``params`` list.

        The parameters of the given module are added to the list of param
        groups, following the rules defined by ``paramwise_cfg``.

        Args:
            params (list[dict]): A list of param groups; it will be modified
                in place.
            module (nn.Module): The module whose parameters are added.
            prefix (str): The prefix of the module.
            is_dcn_module (bool|None): Whether the module is (inside) a DCN
                module; bias lr/decay multipliers are not applied in that
                case. Defaults to None.
        """
        custom_keys = self.paramwise_cfg.get("custom_keys", {})

        bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0)
        bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0)
        norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0)
        bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False)

        is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == "bias" and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            base_wd = self.base_wd
            if self.base_wd is not None:
                if is_norm:
                    base_wd *= norm_decay_mult
                elif name == "bias" and not is_dcn_module:
                    base_wd *= bias_decay_mult

            param_group = {"params": [param]}
            if not param.requires_grad:
                # Frozen params are kept in their own group and flagged so that
                # build_optimizer can skip them in the summary below.
                param_group["requires_grad"] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f"{prefix} is duplicate. It is skipped since bypass_duplicate={bypass_duplicate}")
                continue

            is_custom = False
            for key in custom_keys:
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                if scope is not None and scope not in f"{prefix}":
                    continue
                if key_name in f"{prefix}.{name}":
                    is_custom = True
                    if "lr_mult" in custom_keys[key]:
                        param_group["lr"] = self.base_lr * custom_keys[key]["lr_mult"]
                    elif "lr" not in param_group:
                        param_group["lr"] = base_lr
                    if self.base_wd is not None:
                        if "decay_mult" in custom_keys[key]:
                            param_group["weight_decay"] = self.base_wd * custom_keys[key]["decay_mult"]
                        elif "weight_decay" not in param_group:
                            param_group["weight_decay"] = base_wd

            if not is_custom:
                if base_lr != self.base_lr:
                    param_group["lr"] = base_lr
                if base_wd != self.base_wd:
                    param_group["weight_decay"] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
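

# Illustrative paramwise_cfg for MyOptimizerConstructor (keys are hypothetical):
# plain string keys match any "prefix.name" that contains them, while
# (scope, key) tuples additionally require the scope to occur in the module
# prefix, e.g.
#   paramwise_cfg = dict(custom_keys={
#       "pos_embed": dict(decay_mult=0.0),
#       ("backbone", "norm"): dict(lr_mult=0.1),
#   })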


def build_optimizer(model, optimizer_cfg):
    logger = get_root_logger()

    # Unwrap (Distributed)DataParallel-style wrappers.
    if hasattr(model, "module"):
        model = model.module

    optimizer_cfg.setdefault("constructor", "MyOptimizerConstructor")

    # Modules may declare a `zero_weight_decay` collection of parameter names;
    # turn those into (module_name, param_name) custom keys with decay_mult=0.
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, "zero_weight_decay"):
            custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})

    # Merge the auto-generated custom keys with any user-provided paramwise_cfg.
    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    given_cfg = optimizer_cfg.get("paramwise_cfg")
    if given_cfg:
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg["paramwise_cfg"] = paramwise_cfg.cfg

    optimizer = mm_build_optimizer(model, optimizer_cfg)

    # Summarize the resulting param groups by lr and weight decay for logging.
    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get("requires_grad", True):
            continue
        lr_groups.setdefault(group["lr"], []).append(group)
        weight_decay_groups.setdefault(group["weight_decay"], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1
    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
    lr_info = "Lr group: " + ", ".join([f"{len(group)} params with lr {lr:.5f}" for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f"{len(group)} params with weight decay {wd}" for wd, group in weight_decay_groups.items()]
    )
    opt_info = (
        f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, "
        f"{fix_info}. {lr_info}; {wd_info}."
    )
    logger.info(opt_info)

    return optimizer
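

# Usage sketch (config values are illustrative, not defaults of this module):
#   optimizer_cfg = dict(type="AdamW", lr=1e-4, weight_decay=0.03)
#   optimizer = build_optimizer(model, optimizer_cfg)
# The constructor defaults to MyOptimizerConstructor, collects decay_mult=0
# custom keys from any submodule exposing `zero_weight_decay`, and logs a
# per-group summary of learning rates and weight decays.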


@OPTIMIZERS.register_module()
class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)

        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # Decoupled weight decay.
        p.data.mul_(1 - lr * wd)

        # Update direction: sign of the beta1-interpolated momentum.
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # Momentum update with beta2.
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        return val is not None

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group["params"]):
                grad, lr, wd, beta1, beta2, state = (
                    p.grad,
                    group["lr"],
                    group["weight_decay"],
                    *group["betas"],
                    self.state[p],
                )

                # Lazy state initialization.
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]

                self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2)

        return loss
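

# For reference, the update implemented by Lion.update_fn corresponds to
# (Chen et al., 2023, "Symbolic Discovery of Optimization Algorithms"):
#   update  = sign(beta1 * m + (1 - beta1) * g)
#   p      <- (1 - lr * wd) * p - lr * update
#   m      <- beta2 * m + (1 - beta2) * g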


@OPTIMIZERS.register_module()
class CAMEWrapper(CAME):
    """Thin wrapper that exposes came_pytorch's CAME optimizer to the mmcv
    OPTIMIZERS registry so it can be selected by name in configs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
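

# Both custom optimizers are selected by registry name in configs; the values
# below are purely illustrative:
#   optimizer = dict(type="Lion", lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0)
#   optimizer = dict(type="CAMEWrapper", lr=2e-4, weight_decay=0.0)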