# paddleseg/models/mla_transformer.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
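"""MLA Transformer: a transformer-backbone segmentation model with an MLA
(Multi-Level feature Aggregation) decode head."""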
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers
from paddleseg.cvlibs import manager
from paddleseg.utils import utils


class MLAHeads(nn.Layer):
    """Per-level refinement heads of the MLA (Multi-Level feature
    Aggregation) decoder: each head applies two 3x3 ConvBNReLU blocks to one
    aggregated feature map, and the four branches are upsampled 4x and
    concatenated in ``forward``."""

    def __init__(self, mlahead_channels=128):
        super(MLAHeads, self).__init__()
self.head2 = nn.Sequential(
layers.ConvBNReLU(
mlahead_channels * 2,
mlahead_channels,
3,
padding=1,
bias_attr=False),
layers.ConvBNReLU(
mlahead_channels,
mlahead_channels,
3,
padding=1,
bias_attr=False))
self.head3 = nn.Sequential(
layers.ConvBNReLU(
mlahead_channels * 2,
mlahead_channels,
3,
padding=1,
bias_attr=False),
layers.ConvBNReLU(
mlahead_channels,
mlahead_channels,
3,
padding=1,
bias_attr=False))
self.head4 = nn.Sequential(
layers.ConvBNReLU(
mlahead_channels * 2,
mlahead_channels,
3,
padding=1,
bias_attr=False),
layers.ConvBNReLU(
mlahead_channels,
mlahead_channels,
3,
padding=1,
bias_attr=False))
self.head5 = nn.Sequential(
layers.ConvBNReLU(
mlahead_channels * 2,
mlahead_channels,
3,
padding=1,
bias_attr=False),
layers.ConvBNReLU(
mlahead_channels,
mlahead_channels,
3,
padding=1,
bias_attr=False))

    def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
        # Upsample every branch 4x. Height (shape[2]) and width (shape[3])
        # are taken separately so non-square feature maps are handled
        # correctly (the original code used the width for both dimensions).
        head2 = F.interpolate(
            self.head2(mla_p2),
            size=(4 * mla_p2.shape[2], 4 * mla_p2.shape[3]),
            mode='bilinear',
            align_corners=True)
        head3 = F.interpolate(
            self.head3(mla_p3),
            size=(4 * mla_p3.shape[2], 4 * mla_p3.shape[3]),
            mode='bilinear',
            align_corners=True)
        head4 = F.interpolate(
            self.head4(mla_p4),
            size=(4 * mla_p4.shape[2], 4 * mla_p4.shape[3]),
            mode='bilinear',
            align_corners=True)
        head5 = F.interpolate(
            self.head5(mla_p5),
            size=(4 * mla_p5.shape[2], 4 * mla_p5.shape[3]),
            mode='bilinear',
            align_corners=True)
        return paddle.concat([head2, head3, head4, head5], axis=1)


@manager.MODELS.add_component
class MLATransformer(nn.Layer):
    """Transformer-backbone segmentation model with an MLA (Multi-Level
    feature Aggregation) decode head.

    The elementwise fusion in ``forward`` requires the backbone channels to
    follow a ``[C, 2C, 4C, 8C]`` layout with ``C == mlahead_channels``
    (e.g. ``[128, 256, 512, 1024]`` for the default ``mlahead_channels=128``).

    Args:
        num_classes (int): Number of target classes.
        in_channels (list|tuple): Channels of the four backbone feature maps.
        backbone (nn.Layer): Backbone returning four multi-scale feature maps.
        mlahead_channels (int, optional): Channels of each MLA head branch.
            Default: 128.
        aux_channels (int, optional): Channels of the auxiliary head.
            Default: 256.
        norm_layer (nn.Layer, optional): Normalization layer.
            Default: nn.BatchNorm2D.
        pretrained (str, optional): Path or URL of the pretrained model.
            Default: None.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 backbone,
                 mlahead_channels=128,
                 aux_channels=256,
                 norm_layer=nn.BatchNorm2D,
                 pretrained=None,
                 **kwargs):
        super(MLATransformer, self).__init__()
self.BatchNorm = norm_layer
self.mlahead_channels = mlahead_channels
self.num_classes = num_classes
self.in_channels = in_channels
self.backbone = backbone
self.mlahead = MLAHeads(mlahead_channels=self.mlahead_channels)
        # Classifier over the four concatenated MLA branches.
        self.cls = nn.Conv2D(
            4 * self.mlahead_channels, self.num_classes, 3, padding=1)
        # Per-level projection convs applied before the top-down fusion in
        # ``forward``; every level is brought to 2 * mlahead_channels channels.
        self.conv0 = layers.ConvBNReLU(
            self.in_channels[0],
            self.in_channels[0] * 2,
            3,
            padding=1,
            bias_attr=False)
self.conv1 = layers.ConvBNReLU(
self.in_channels[1],
self.in_channels[1],
3,
padding=1,
bias_attr=False)
self.conv21 = layers.ConvBNReLU(
self.in_channels[2],
self.in_channels[2],
3,
padding=1,
bias_attr=False)
self.conv22 = layers.ConvBNReLU(
self.in_channels[2],
self.in_channels[2] // 2,
3,
padding=1,
bias_attr=False)
self.conv31 = layers.ConvBNReLU(
self.in_channels[3],
self.in_channels[3],
3,
padding=1,
bias_attr=False)
self.conv32 = layers.ConvBNReLU(
self.in_channels[3],
self.in_channels[3] // 2,
3,
padding=1,
bias_attr=False)
self.conv33 = layers.ConvBNReLU(
self.in_channels[3] // 2,
self.in_channels[3] // 4,
3,
padding=1,
bias_attr=False)
        # Auxiliary head over the third backbone feature map; its logits are
        # only appended during training (see ``forward``).
        self.aux_head = nn.Sequential(
layers.ConvBN(
in_channels=self.in_channels[2],
out_channels=aux_channels,
kernel_size=3,
padding=1,
bias_attr=False),
nn.Conv2D(
in_channels=aux_channels,
out_channels=self.num_classes,
kernel_size=1, ))
self.pretrained = pretrained
self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, x):
        # The backbone is expected to return four feature maps, ordered from
        # the highest to the lowest resolution (e.g. strides 4/8/16/32 for a
        # Swin-style backbone).
        inputs = self.backbone(x)

        # Project every level and resize it to the resolution of inputs[0].
        inputs0 = self.conv0(inputs[0])
        inputs1 = F.interpolate(
            self.conv1(inputs[1]),
            size=inputs[0].shape[2:],
            mode='bilinear',
            align_corners=True)
        inputs2 = F.interpolate(
            self.conv21(inputs[2]),
            scale_factor=2,
            mode='bilinear',
            align_corners=True)
        inputs2 = F.interpolate(
            self.conv22(inputs2),
            size=inputs[0].shape[2:],
            mode='bilinear',
            align_corners=True)
        inputs3 = F.interpolate(
            self.conv31(inputs[3]),
            scale_factor=2,
            mode='bilinear',
            align_corners=True)
        inputs3 = F.interpolate(
            self.conv32(inputs3),
            scale_factor=2,
            mode='bilinear',
            align_corners=True)
        inputs3 = F.interpolate(
            self.conv33(inputs3),
            size=inputs[0].shape[2:],
            mode='bilinear',
            align_corners=True)

        # Top-down aggregation: each level absorbs all deeper levels.
        inputs2 = inputs2 + inputs3
        inputs1 = inputs1 + inputs2
        inputs0 = inputs0 + inputs1
        # Refine each level with its MLA head, upsample 4x and concatenate,
        # then classify per pixel.
        feats = self.mlahead(inputs0, inputs1, inputs2, inputs3)
        logit = self.cls(feats)
        logit_list = [logit]
        if self.training:
            logit_list.append(self.aux_head(inputs[2]))
        # Resize every logit map to the input resolution.
        logit_list = [
            F.interpolate(
                logit, paddle.shape(x)[2:], mode='bilinear',
                align_corners=True) for logit in logit_list
        ]
        return logit_list
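

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test, assuming a Swin-Base-like channel layout of
# [128, 256, 512, 1024]. ``_DummyBackbone`` is a hypothetical stand-in for a
# real backbone from paddleseg.models.backbones; only its output shapes
# matter here.
if __name__ == '__main__':

    class _DummyBackbone(nn.Layer):
        def __init__(self, channels=(128, 256, 512, 1024)):
            super(_DummyBackbone, self).__init__()
            # One strided conv per level, producing strides 4, 8, 16 and 32.
            self.convs = nn.LayerList([
                nn.Conv2D(3, c, 3, stride=2**(i + 2), padding=1)
                for i, c in enumerate(channels)
            ])

        def forward(self, x):
            return [conv(x) for conv in self.convs]

    model = MLATransformer(
        num_classes=19,
        in_channels=[128, 256, 512, 1024],
        backbone=_DummyBackbone())
    out = model(paddle.randn([1, 3, 256, 256]))
    # In training mode (the default) this prints the main and auxiliary
    # logit shapes, both [1, 19, 256, 256].
    print([o.shape for o in out])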