# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg import utils
from paddleseg.models import layers
from paddleseg.cvlibs import manager


@manager.MODELS.add_component
class MobileSeg(nn.Layer):
    """
    The semantic segmentation model for mobile devices.

    Args:
        num_classes (int): The number of target classes.
        backbone (nn.Layer): Backbone network, such as stdc1net and resnet18. The
            backbone must have a `feat_channels` attribute, whose length is 5.
        backbone_indices (List[int], optional): The indices of the backbone output
            features to use. Default: [1, 2, 3].
        cm_bin_sizes (List[int], optional): The bin sizes of the context module.
            Default: [1, 2].
        cm_out_ch (int, optional): The output channels of the last context module.
            Default: 64.
        arm_type (str, optional): The type of attention refinement module.
            Default: 'UAFMMobile'.
        arm_out_chs (List[int], optional): The output channels of each ARM module.
            Default: [32, 48, 64].
        seg_head_inter_chs (List[int], optional): The intermediate channels of the
            segmentation heads. Default: [32, 32, 32].
        resize_mode (str, optional): The resize mode for the upsampling operations
            in the decoder. Default: 'bilinear'.
        use_last_fuse (bool, optional): Whether to fuse the decoder features at the
            end. Default: False.
        pretrained (str, optional): The path or url of the pretrained model.
            Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=[1, 2, 3],
                 cm_bin_sizes=[1, 2],
                 cm_out_ch=64,
                 arm_type='UAFMMobile',
                 arm_out_chs=[32, 48, 64],
                 seg_head_inter_chs=[32, 32, 32],
                 resize_mode='bilinear',
                 use_last_fuse=False,
                 pretrained=None):
        super().__init__()

        # backbone
        assert hasattr(backbone, 'feat_channels'), \
            "The backbone should have a feat_channels attribute."
        assert len(backbone.feat_channels) >= len(backbone_indices), \
            f"The length of backbone_indices ({len(backbone_indices)}) should not be " \
            f"greater than the length of feat_channels ({len(backbone.feat_channels)})."
        assert len(backbone.feat_channels) > max(backbone_indices), \
            f"The max value ({max(backbone_indices)}) of backbone_indices should be " \
            f"less than the length of feat_channels ({len(backbone.feat_channels)})."
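        # Illustration with hypothetical channel values (not taken from any
        # real backbone): if backbone.feat_channels == [16, 32, 64, 128, 256]
        # and backbone_indices == [1, 2, 3], the selected features have
        # 32, 64 and 128 channels; the asserts above reject indices that
        # fall outside feat_channels.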
        self.backbone = backbone

        assert len(backbone_indices) >= 1, \
            "The length of backbone_indices should not be less than 1."
        self.backbone_indices = backbone_indices  # [..., x16_id, x32_id]
        backbone_out_chs = [backbone.feat_channels[i] for i in backbone_indices]

        # head
        if len(arm_out_chs) == 1:
            arm_out_chs = arm_out_chs * len(backbone_indices)
        assert len(arm_out_chs) == len(backbone_indices), \
            "The length of arm_out_chs and backbone_indices should be equal."
        self.ppseg_head = MobileSegHead(backbone_out_chs, arm_out_chs,
                                        cm_bin_sizes, cm_out_ch, arm_type,
                                        resize_mode, use_last_fuse)

        if len(seg_head_inter_chs) == 1:
            seg_head_inter_chs = seg_head_inter_chs * len(backbone_indices)
        assert len(seg_head_inter_chs) == len(backbone_indices), \
            "The length of seg_head_inter_chs and backbone_indices should be equal."
        self.seg_heads = nn.LayerList()  # [..., head_16, head_32]
        for in_ch, mid_ch in zip(arm_out_chs, seg_head_inter_chs):
            self.seg_heads.append(SegHead(in_ch, mid_ch, num_classes))

        # pretrained
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        x_hw = paddle.shape(x)[2:]

        feats_backbone = self.backbone(x)  # [x4, x8, x16, x32]
        assert len(feats_backbone) >= len(self.backbone_indices), \
            f"The number of backbone features ({len(feats_backbone)}) should be greater " \
            f"than or equal to the number of backbone_indices ({len(self.backbone_indices)})."

        feats_selected = [feats_backbone[i] for i in self.backbone_indices]
        feats_head = self.ppseg_head(feats_selected)  # [..., x8, x16, x32]

        if self.training:
            logit_list = []
            for x, seg_head in zip(feats_head, self.seg_heads):
                x = seg_head(x)
                logit_list.append(x)
            logit_list = [
                F.interpolate(
                    x, x_hw, mode='bilinear', align_corners=False)
                for x in logit_list
            ]
        else:
            x = self.seg_heads[0](feats_head[0])
            x = F.interpolate(x, x_hw, mode='bilinear', align_corners=False)
            logit_list = [x]

        return logit_list

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class MobileSegHead(nn.Layer):
    """
    The head of MobileSeg.

    Args:
        backbone_out_chs (List[int]): The channels of the output tensors of the backbone.
        arm_out_chs (List[int]): The output channels of each ARM module.
        cm_bin_sizes (List[int]): The bin sizes of the context module.
        cm_out_ch (int): The output channels of the last context module.
        arm_type (str): The type of attention refinement module.
        resize_mode (str): The resize mode for the upsampling operations in the decoder.
        use_last_fuse (bool): Whether to fuse the decoder features at the end.
    """

    def __init__(self, backbone_out_chs, arm_out_chs, cm_bin_sizes, cm_out_ch,
                 arm_type, resize_mode, use_last_fuse):
        super().__init__()

        self.cm = MobileContextModule(backbone_out_chs[-1], cm_out_ch,
                                      cm_out_ch, cm_bin_sizes)

        assert hasattr(layers, arm_type), \
            "Unsupported arm_type ({})".format(arm_type)
        arm_class = getattr(layers, arm_type)
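        # Decoder wiring note: the ARM at index i fuses backbone feature i
        # (the "low" input) with the refined feature from index i + 1 (the
        # "high" input); the deepest ARM instead takes the context module
        # output, which has cm_out_ch channels, as its high input.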
        self.arm_list = nn.LayerList()  # [..., arm8, arm16, arm32]
        for i in range(len(backbone_out_chs)):
            low_chs = backbone_out_chs[i]
            high_ch = cm_out_ch if i == len(
                backbone_out_chs) - 1 else arm_out_chs[i + 1]
            out_ch = arm_out_chs[i]
            arm = arm_class(
                low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode)
            self.arm_list.append(arm)

        self.use_last_fuse = use_last_fuse
        if self.use_last_fuse:
            self.fuse_convs = nn.LayerList()
            for i in range(1, len(arm_out_chs)):
                conv = layers.SeparableConvBNReLU(
                    arm_out_chs[i],
                    arm_out_chs[0],
                    kernel_size=3,
                    bias_attr=False)
                self.fuse_convs.append(conv)
            self.last_conv = layers.SeparableConvBNReLU(
                len(arm_out_chs) * arm_out_chs[0],
                arm_out_chs[0],
                kernel_size=3,
                bias_attr=False)

    def forward(self, in_feat_list):
        """
        Args:
            in_feat_list (List[Tensor]): Such as [x2, x4, x8, x16, x32].
                x2, x4 and x8 are optional.
        Returns:
            out_feat_list (List[Tensor]): Such as [x2, x4, x8, x16, x32].
                x2, x4 and x8 are optional. The lengths of in_feat_list and
                out_feat_list are the same.
        """

        high_feat = self.cm(in_feat_list[-1])
        out_feat_list = []

        for i in reversed(range(len(in_feat_list))):
            low_feat = in_feat_list[i]
            arm = self.arm_list[i]
            high_feat = arm(low_feat, high_feat)
            out_feat_list.insert(0, high_feat)

        if self.use_last_fuse:
            x_list = [out_feat_list[0]]
            size = paddle.shape(out_feat_list[0])[2:]
            for x, conv in zip(out_feat_list[1:], self.fuse_convs):
                x = conv(x)
                x = F.interpolate(
                    x, size=size, mode='bilinear', align_corners=False)
                x_list.append(x)
            x = paddle.concat(x_list, axis=1)
            x = self.last_conv(x)
            out_feat_list[0] = x

        return out_feat_list


class MobileContextModule(nn.Layer):
    """
    Context Module for Mobile Model.

    Args:
        in_channels (int): The number of input channels to the pyramid pooling module.
        inter_channels (int): The number of intermediate channels of the pyramid pooling module.
        out_channels (int): The number of output channels of the pyramid pooling module.
        bin_sizes (tuple): The output sizes of the pooled feature maps, e.g. (1, 3).
        align_corners (bool): An argument of F.interpolate. It should be set to
            False when the output size of the feature maps is even, e.g. 1024x512;
            otherwise it is True, e.g. 769x769.
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 bin_sizes,
                 align_corners=False):
        super().__init__()

        self.stages = nn.LayerList([
            self._make_stage(in_channels, inter_channels, size)
            for size in bin_sizes
        ])

        self.conv_out = layers.SeparableConvBNReLU(
            in_channels=inter_channels,
            out_channels=out_channels,
            kernel_size=3,
            bias_attr=False)

        self.align_corners = align_corners

    def _make_stage(self, in_channels, out_channels, size):
        prior = nn.AdaptiveAvgPool2D(output_size=size)
        conv = layers.ConvBNReLU(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        return nn.Sequential(prior, conv)

    def forward(self, input):
        out = None
        input_shape = paddle.shape(input)[2:]

        for stage in self.stages:
            x = stage(input)
            x = F.interpolate(
                x,
                input_shape,
                mode='bilinear',
                align_corners=self.align_corners)
            if out is None:
                out = x
            else:
                out += x

        out = self.conv_out(out)
        return out


class SegHead(nn.Layer):
    def __init__(self, in_chan, mid_chan, n_classes):
        super().__init__()
        self.conv = layers.SeparableConvBNReLU(
            in_chan, mid_chan, kernel_size=3, bias_attr=False)
        self.conv_out = nn.Conv2D(
            mid_chan, n_classes, kernel_size=1, bias_attr=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x
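
# A minimal smoke-test sketch, not part of the original module. It assumes a
# PaddleSeg version that provides the default arm_type ('UAFMMobile') in
# paddleseg.models.layers. `_DummyBackbone` and all channel values below are
# hypothetical, chosen only to satisfy the feat_channels contract.
if __name__ == '__main__':

    class _DummyBackbone(nn.Layer):
        # Five feature maps at strides 2/4/8/16/32, matching the
        # [x2, x4, x8, x16, x32] layout the head expects.
        feat_channels = [16, 32, 64, 128, 256]

        def __init__(self):
            super().__init__()
            in_chs = [3] + self.feat_channels[:-1]
            self.convs = nn.LayerList([
                nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
                for in_ch, out_ch in zip(in_chs, self.feat_channels)
            ])

        def forward(self, x):
            feats = []
            for conv in self.convs:
                x = conv(x)
                feats.append(x)
            return feats

    model = MobileSeg(num_classes=19, backbone=_DummyBackbone())
    model.eval()
    with paddle.no_grad():
        logits = model(paddle.rand([1, 3, 512, 512]))
    # The eval branch returns a single logit map at the input resolution.
    print(logits[0].shape)  # expected: [1, 19, 512, 512]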