File size: 1,127 Bytes
8683813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from .registry import PADDING_LAYERS
# Register the stock torch 2-D padding modules under the short names that a
# padding config refers to through its "type" key.
for _alias, _pad_cls in (
        ('zero', nn.ZeroPad2d),
        ('reflect', nn.ReflectionPad2d),
        ('replicate', nn.ReplicationPad2d),
):
    PADDING_LAYERS.register_module(_alias, module=_pad_cls)
# Drop the loop temporaries so the module namespace stays unchanged.
del _alias, _pad_cls
def build_padding_layer(cfg, *args, **kwargs):
    """Instantiate a padding layer described by a config dict.

    Args:
        cfg (dict): Padding layer config. It must contain:

            - type (str): Name the layer was registered under in
              ``PADDING_LAYERS``.
            - layer args: Any remaining items are forwarded to the layer
              constructor as keyword arguments.
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor.

    Returns:
        nn.Module: Created padding layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or the type is not
            registered in ``PADDING_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a copy so the caller's config keeps its "type" entry.
    layer_args = cfg.copy()
    type_name = layer_args.pop('type')

    if type_name in PADDING_LAYERS:
        layer_cls = PADDING_LAYERS.get(type_name)
        return layer_cls(*args, **kwargs, **layer_args)

    raise KeyError(f'Unrecognized padding type {type_name}.')
|