# File size: 5,146 bytes; revision 87b62f4.
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import torch.nn as nn
from annotator.mmpkg.mmcv.utils import is_tuple_of
from annotator.mmpkg.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
from .registry import NORM_LAYERS
# Expose the built-in normalization layers under their short config names.
# 'BN' and 'IN' are aliases of the 2d variants ('BN2d' / 'IN2d').
for _alias, _norm_cls in (
        ('BN', nn.BatchNorm2d),
        ('BN1d', nn.BatchNorm1d),
        ('BN2d', nn.BatchNorm2d),
        ('BN3d', nn.BatchNorm3d),
        ('SyncBN', SyncBatchNorm),
        ('GN', nn.GroupNorm),
        ('LN', nn.LayerNorm),
        ('IN', nn.InstanceNorm2d),
        ('IN1d', nn.InstanceNorm1d),
        ('IN2d', nn.InstanceNorm2d),
        ('IN3d', nn.InstanceNorm3d),
):
    NORM_LAYERS.register_module(_alias, module=_norm_cls)
# Do not leave the loop temporaries behind in the module namespace.
del _alias, _norm_cls
def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    When we build a norm layer with `build_norm_layer()`, we want to preserve
    the norm type in variable names, e.g., self.bn1. This method will
    infer the abbreviation to map class types to abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
    InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
    "in" respectively.
    Rule 3: If the class name contains "batch", "group", "layer" or
    "instance", the abbreviation of this layer will be "bn", "gn", "ln" and
    "in" respectively.
    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".

    Args:
        class_type (type): The norm layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')

    # Rule 1: an explicit abbreviation on the class always wins.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_

    # Rule 2: decide by base class. _InstanceNorm is tested before
    # _BatchNorm because IN may itself be a subclass of BN.
    if issubclass(class_type, _InstanceNorm):
        return 'in'
    if issubclass(class_type, _BatchNorm):
        return 'bn'
    if issubclass(class_type, nn.GroupNorm):
        return 'gn'
    if issubclass(class_type, nn.LayerNorm):
        return 'ln'

    # Rule 3: fall back to keywords in the lower-cased class name. The order
    # of this table is the order in which the keywords are tried.
    class_name = class_type.__name__.lower()
    keyword_to_abbr = (
        ('batch', 'bn'),
        ('group', 'gn'),
        ('layer', 'ln'),
        ('instance', 'in'),
    )
    for keyword, abbr in keyword_to_abbr:
        if keyword in class_name:
            return abbr

    # Rule 4: nothing matched.
    return 'norm_layer'
def build_norm_layer(cfg, num_features, postfix=''):
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Value assigned to
              ``requires_grad`` of every parameter of the created layer.
              Defaults to True.
        num_features (int): Number of input channels.
        postfix (int | str): The postfix to be appended into norm abbreviation
            to create named layer.

    Returns:
        (str, nn.Module): The first element is the layer name consisting of
        abbreviation and postfix, e.g., bn1, gn. The second element is the
        created norm layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no "type" key or the type is not registered
            in ``NORM_LAYERS``.
        AssertionError: If ``postfix`` is neither int nor str, or if a 'GN'
            layer is requested without ``num_groups`` in ``cfg``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    # Work on a copy so the caller's config is not modified by the pops below.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in NORM_LAYERS:
        raise KeyError(f'Unrecognized norm type {layer_type}')

    norm_layer = NORM_LAYERS.get(layer_type)
    abbr = infer_abbr(norm_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    # 'requires_grad' is not a constructor argument; strip it before the
    # remaining entries are forwarded as keyword arguments.
    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            # NOTE(review): only layers exposing this hook are touched; the
            # argument value 1 is restored from upstream and should be
            # confirmed.
            layer._specify_ddp_gpu_num(1)
    else:
        # GroupNorm takes the channel count as `num_channels` and
        # additionally needs `num_groups` from the config.
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer
def is_norm(layer, exclude=None):
    """Check if a layer is a normalization layer.

    Args:
        layer (nn.Module): The layer to be checked.
        exclude (type | tuple[type]): Types to be excluded.

    Returns:
        bool: Whether the layer is a norm layer.

    Raises:
        TypeError: If ``exclude`` is neither None, a type, nor a tuple of
            types.
    """
    if exclude is not None:
        # Accept a single type by normalizing it to a one-element tuple.
        if not isinstance(exclude, tuple):
            exclude = (exclude, )
        if not is_tuple_of(exclude, type):
            raise TypeError(
                f'"exclude" must be either None or type or a tuple of types, '
                f'but got {type(exclude)}: {exclude}')

    # Excluded types are reported as "not a norm layer" even if they derive
    # from one of the norm base classes below.
    if exclude and isinstance(layer, exclude):
        return False

    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
    return isinstance(layer, all_norm_bases)