File size: 2,193 Bytes
2ac90ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=(), layer_decay=.75,
                     layer_id_fn=None):
    """
    Build optimizer parameter groups with layer-wise lr decay.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Every trainable parameter is assigned a layer id by ``layer_id_fn``. It is
    then placed in a group keyed by (layer id, decay / no-decay). With
    ``num_layers = len(model.blocks) + 1``, a group for layer id ``i`` gets
    ``lr_scale = layer_decay ** (num_layers - i)``. The highest id
    (``num_layers``) therefore gets scale 1, and lower ids are scaled down
    geometrically.

    Args:
        model: Object exposing a sized ``blocks`` attribute and a
            ``named_parameters()`` method yielding ``(name, param)`` pairs.
            Each ``param`` must have ``requires_grad`` and ``ndim``.
        weight_decay: Weight decay assigned to the "decay" groups.
        no_weight_decay_list: Parameter names exempt from weight decay, in
            addition to every 1-D parameter.
        layer_decay: Multiplicative lr decay factor per layer.
        layer_id_fn: Optional callable ``(name, num_layers) -> int`` mapping a
            parameter name to a layer id in ``[0, num_layers]``. Defaults to
            ``get_layer_id_for_vit``.

    Returns:
        A list of dicts with keys ``"lr_scale"``, ``"weight_decay"`` and
        ``"params"``, ordered by first appearance of each group. Frozen
        parameters (``requires_grad`` false) are omitted. ``lr_scale`` is
        only recorded here and is not applied.
        NOTE(review): confirm where the caller multiplies it into the lr.

    Raises:
        IndexError: if ``layer_id_fn`` returns an id greater than ``num_layers``.
    """
    if layer_id_fn is None:
        layer_id_fn = get_layer_id_for_vit

    num_layers = len(model.blocks) + 1
    # Index i holds the lr scale for layer id i; ids run from 0 to num_layers inclusive.
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    param_groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not handed to the optimizer

        # No decay for all 1-D parameters (typically biases / norm params)
        # and for explicitly exempted, model-specific names.
        if param.ndim == 1 or name in no_weight_decay_list:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = layer_id_fn(name, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)
        group = param_groups.get(group_name)
        if group is None:
            group = param_groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": this_decay,
                "params": [],
            }
        group["params"].append(param)

    return list(param_groups.values())
def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter its layer id for layer-wise lr decay.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33

    The mapping is:
    - ``cls_token``, ``pos_embed`` and any name starting with ``patch_embed``
      map to 0.
    - Names starting with ``blocks`` map to their second dot-separated field,
      parsed as an int, plus 1 (e.g. ``blocks.3.attn...`` -> 4).
    - Every other name maps to ``num_layers``.

    Args:
        name: Parameter name, as yielded by ``named_parameters()``.
        num_layers: Id returned for parameters that are neither embeddings
            nor inside ``blocks``.

    Returns:
        The integer layer id.

    Raises:
        ValueError / IndexError: if a ``blocks``-prefixed name has no
            integer second field.
    """
    is_embedding = name in ('cls_token', 'pos_embed') or name.startswith('patch_embed')
    if is_embedding:
        return 0
    if name.startswith('blocks'):
        block_index = name.split('.')[1]
        return int(block_index) + 1
    return num_layers