Spaces:
Running
on
A10G
Running
on
A10G
import numpy as np | |
import torch.nn as nn | |
from annotator.mmpkg.mmcv.cnn import ConvModule | |
from annotator.mmpkg.mmseg.ops import resize | |
from ..builder import HEADS | |
from .decode_head import BaseDecodeHead | |
class FPNHead(BaseDecodeHead):
    """Decode head implementing the Semantic FPN segmentation branch.

    See `Panoptic Feature Pyramid Networks
    <https://arxiv.org/abs/1901.02446>`_.

    One "scale head" is built per input feature level. Each scale head is a
    stack of 3x3 ``ConvModule`` layers; for every level other than the first,
    each conv is followed by a 2x bilinear ``nn.Upsample``. In ``forward`` the
    per-level results are resized to the spatial size of the first level's
    result, summed, and handed to ``cls_seg``.

    Args:
        feature_strides (tuple[int]): Stride of each input feature map, one
            entry per entry of ``in_channels``. The strides are expected to
            be powers of 2, and the first entry must be the smallest stride
            (i.e. the highest-resolution level).
        **kwargs: Forwarded to ``BaseDecodeHead.__init__`` together with
            ``input_transform='multiple_select'``.
    """

    # NOTE(review): ``HEADS`` is imported at file level but no registration
    # decorator is visible on this class -- confirm whether a
    # ``@HEADS.register_module()`` line was dropped.

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        base_stride = feature_strides[0]
        self.scale_heads = nn.ModuleList()
        for level, stride in enumerate(feature_strides):
            # Number of conv stages equals the number of 2x steps separating
            # this level from the base level, with a minimum of one stage.
            num_stages = max(
                1, int(np.log2(stride) - np.log2(base_stride)))
            needs_upsample = stride != base_stride

            layers = []
            conv_in = self.in_channels[level]
            for _ in range(num_stages):
                layers.append(
                    ConvModule(
                        conv_in,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if needs_upsample:
                    layers.append(
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
                # Only the first conv of a head sees the backbone channels;
                # every later conv consumes the head's own channel width.
                conv_in = self.channels
            self.scale_heads.append(nn.Sequential(*layers))

    def forward(self, inputs):
        """Fuse all feature levels and return the ``cls_seg`` result.

        Args:
            inputs: Multi-level features; passed through
                ``_transform_inputs`` to select the levels used by this head.

        Returns:
            The output of ``self.cls_seg`` applied to the summed features.
        """
        feats = self._transform_inputs(inputs)

        fused = self.scale_heads[0](feats[0])
        for level in range(1, len(self.feature_strides)):
            branch = self.scale_heads[level](feats[level])
            # Out-of-place add: ``fused`` is rebound rather than modified.
            # ``shape[2:]`` is taken as the spatial size -- presumably
            # (N, C, H, W) tensors; confirm against the caller.
            fused = fused + resize(
                branch,
                size=fused.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        return self.cls_seg(fused)