# Uploaded by ozgurkara in commit eb9a9b4 ("first commit"); file size 3.09 kB.
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv import is_tuple_of
from annotator.mmpkg.mmcv.cnn import ConvModule
from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Data flow:
    - The last selected feature map goes through a 1x1 conv branch
      (``aspp_conv``).
    - That branch is gated element-wise by a sigmoid image-pooling branch
      (``image_pool``).
    - The result is progressively resized to each shallower selected
      feature map, from last to first.
    - At each of those levels it is concatenated with a 1x1 projection of
      that feature map and fused by ``conv_ups[i]``.

    Args:
        branch_channels (tuple[int]): The number of output channels in each
            branch. Its length must equal ``len(in_channels) - 1``.
            Default: (32, 64).
        pool_kernel_size (int | tuple[int]): Kernel size of the average
            pooling in the image-pooling branch. Keyword-only. Default: 49.
        pool_stride (int | tuple[int]): Stride of the average pooling in the
            image-pooling branch. Keyword-only. Default: (16, 20).
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` must
            be ``'multiple_select'``.

    Raises:
        ValueError: If ``input_transform`` is not ``'multiple_select'``.
    """

    def __init__(self,
                 branch_channels=(32, 64),
                 *,
                 pool_kernel_size=49,
                 pool_stride=(16, 20),
                 **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             "must be 'multiple_select'. But received "
                             f"'{self.input_transform}'")
        assert is_tuple_of(branch_channels, int), (
            f'branch_channels must be a tuple of int, got {branch_channels!r}')
        # One branch per selected feature map except the last, which feeds
        # the ASPP / image-pooling path instead.
        assert len(branch_channels) == len(self.in_channels) - 1, (
            f'len(branch_channels) ({len(branch_channels)}) must equal '
            f'len(in_channels) - 1 ({len(self.in_channels) - 1})')
        self.branch_channels = branch_channels

        # NOTE: submodules are registered by name ('conv{i}', 'conv_up{i}')
        # inside nn.Sequential containers. Those names become parameter keys,
        # so they must stay stable for existing checkpoints to load. The
        # containers are only indexed in forward(), never called as a chain.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, (in_ch, branch_ch) in enumerate(
                zip(self.in_channels, branch_channels)):
            # 1x1 projection of the i-th selected feature map.
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(in_ch, branch_ch, 1, bias=False))
            # Fuses the upsampled head features with the projection above.
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_ch,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))
        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        # Main branch applied to the last selected feature map.
        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # Image-pooling branch. Its sigmoid output gates aspp_conv in forward().
        # It consumes the same tensor as aspp_conv (the last selected feature
        # map), so its input channel count is in_channels[-1].
        # NOTE: AvgPool2d has no padding here, so the last feature map must be
        # at least pool_kernel_size in each spatial dimension.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=pool_kernel_size, stride=pool_stride),
            ConvModule(
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level feature maps. The subset used
                by this head is chosen by ``self._transform_inputs``
                (``'multiple_select'`` mode). The tensors are assumed to be
                NCHW (TODO confirm against the backbone).

        Returns:
            Tensor: Result of ``self.cls_seg`` on the fused features,
            presumably per-pixel class logits.
        """
        inputs = self._transform_inputs(inputs)
        x = inputs[-1]
        # Gate the 1x1 ASPP features with the (upsampled) sigmoid
        # image-pooling response; both are computed from the same tensor.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)
        # Walk the branches from the last selected level to the first: resize
        # to that level, concatenate its projection on the channel axis, fuse.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], dim=1)
            x = self.conv_ups[i](x)
        return self.cls_seg(x)