# Spaces: Running on A10G
import torch | |
import torch.nn as nn | |
from annotator.mmpkg.mmcv.cnn import ConvModule, build_upsample_layer | |
class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (type | callable): Factory (typically a block class) that
            builds the convolution block. It is called with the keyword
            arguments ``in_channels``, ``out_channels``, ``num_convs``,
            ``stride``, ``dilation``, ``with_cp``, ``conv_cfg``,
            ``norm_cfg``, ``act_cfg``, ``dcn`` and ``plugins``.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map (the ``x`` argument of ``forward``).
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder (the ``skip`` argument
            of ``forward``).
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict | None): The upsample config of the upsample
            module in decoder. Default: dict(type='InterpConv'). If the size
            of the high-level feature map is the same as that of the skip
            feature map (low-level feature map from encoder), there is no
            need to upsample the high-level feature map and upsample_cfg
            should be None; a 1x1 ``ConvModule`` is then used only to
            project ``in_channels`` to ``skip_channels``.
        dcn (None): Deformable convolution config for the convolutional
            layers. Not implemented yet; must be None. Default: None.
        plugins (None): Plugins for convolutional layers. Not implemented
            yet; must be None. Default: None.

    Raises:
        AssertionError: If ``dcn`` or ``plugins`` is not None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        # NOTE: the dict defaults above are shared between calls; this
        # block only passes them through and never mutates them.
        super().__init__()
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``; AssertionError and the message are kept so callers
        # that already catch it keep working.
        if dcn is not None:
            raise AssertionError('Not implemented yet.')
        if plugins is not None:
            raise AssertionError('Not implemented yet.')

        # The upsample branch below maps ``x`` to ``skip_channels`` channels,
        # so after concatenation with ``skip`` in ``forward`` the conv block
        # receives ``2 * skip_channels`` input channels.
        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)

        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            # No upsampling requested: a 1x1 conv only matches the channel
            # count of ``x`` to that of ``skip``.
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Upsample ``x`` and fuse it with the encoder skip feature.

        Args:
            skip (torch.Tensor): Low-level high-resolution feature map from
                the encoder, with ``skip_channels`` channels.
            x (torch.Tensor): High-level low-resolution feature map, with
                ``in_channels`` channels.

        Returns:
            torch.Tensor: Output of ``conv_block`` applied to the
            concatenation of ``skip`` and the upsampled ``x`` along dim 1
            (presumably the channel dim of NCHW tensors -- confirm with
            callers).
        """
        x = self.upsample(x)
        out = torch.cat([skip, x], dim=1)
        return self.conv_block(out)