# NOTE(review): removed non-code web-page text ("Spaces: Running on A10G")
# that was pasted above the imports during extraction; it is not Python.
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn

from annotator.mmpkg.mmcv.cnn import normal_init
from annotator.mmpkg.mmcv.runner import auto_fp16, force_fp32
from annotator.mmpkg.mmseg.core import build_pixel_sampler
from annotator.mmpkg.mmseg.ops import resize

from ..builder import build_loss
from ..losses import accuracy
class BaseDecodeHead(nn.Module, metaclass=ABCMeta): | |
"""Base class for BaseDecodeHead. | |
Args: | |
in_channels (int|Sequence[int]): Input channels. | |
channels (int): Channels after modules, before conv_seg. | |
num_classes (int): Number of classes. | |
dropout_ratio (float): Ratio of dropout layer. Default: 0.1. | |
conv_cfg (dict|None): Config of conv layers. Default: None. | |
norm_cfg (dict|None): Config of norm layers. Default: None. | |
act_cfg (dict): Config of activation layers. | |
Default: dict(type='ReLU') | |
in_index (int|Sequence[int]): Input feature index. Default: -1 | |
input_transform (str|None): Transformation type of input features. | |
Options: 'resize_concat', 'multiple_select', None. | |
'resize_concat': Multiple feature maps will be resize to the | |
same size as first one and than concat together. | |
Usually used in FCN head of HRNet. | |
'multiple_select': Multiple feature maps will be bundle into | |
a list and passed into decode head. | |
None: Only one select feature map is allowed. | |
Default: None. | |
loss_decode (dict): Config of decode loss. | |
Default: dict(type='CrossEntropyLoss'). | |
ignore_index (int | None): The label index to be ignored. When using | |
masked BCE loss, ignore_index should be set to None. Default: 255 | |
sampler (dict|None): The config of segmentation map sampler. | |
Default: None. | |
align_corners (bool): align_corners argument of F.interpolate. | |
Default: False. | |
""" | |
def __init__(self, | |
in_channels, | |
channels, | |
*, | |
num_classes, | |
dropout_ratio=0.1, | |
conv_cfg=None, | |
norm_cfg=None, | |
act_cfg=dict(type='ReLU'), | |
in_index=-1, | |
input_transform=None, | |
loss_decode=dict( | |
type='CrossEntropyLoss', | |
use_sigmoid=False, | |
loss_weight=1.0), | |
ignore_index=255, | |
sampler=None, | |
align_corners=False): | |
super(BaseDecodeHead, self).__init__() | |
self._init_inputs(in_channels, in_index, input_transform) | |
self.channels = channels | |
self.num_classes = num_classes | |
self.dropout_ratio = dropout_ratio | |
self.conv_cfg = conv_cfg | |
self.norm_cfg = norm_cfg | |
self.act_cfg = act_cfg | |
self.in_index = in_index | |
self.loss_decode = build_loss(loss_decode) | |
self.ignore_index = ignore_index | |
self.align_corners = align_corners | |
if sampler is not None: | |
self.sampler = build_pixel_sampler(sampler, context=self) | |
else: | |
self.sampler = None | |
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |
if dropout_ratio > 0: | |
self.dropout = nn.Dropout2d(dropout_ratio) | |
else: | |
self.dropout = None | |
self.fp16_enabled = False | |
def extra_repr(self): | |
"""Extra repr.""" | |
s = f'input_transform={self.input_transform}, ' \ | |
f'ignore_index={self.ignore_index}, ' \ | |
f'align_corners={self.align_corners}' | |
return s | |
def _init_inputs(self, in_channels, in_index, input_transform): | |
"""Check and initialize input transforms. | |
The in_channels, in_index and input_transform must match. | |
Specifically, when input_transform is None, only single feature map | |
will be selected. So in_channels and in_index must be of type int. | |
When input_transform | |
Args: | |
in_channels (int|Sequence[int]): Input channels. | |
in_index (int|Sequence[int]): Input feature index. | |
input_transform (str|None): Transformation type of input features. | |
Options: 'resize_concat', 'multiple_select', None. | |
'resize_concat': Multiple feature maps will be resize to the | |
same size as first one and than concat together. | |
Usually used in FCN head of HRNet. | |
'multiple_select': Multiple feature maps will be bundle into | |
a list and passed into decode head. | |
None: Only one select feature map is allowed. | |
""" | |
if input_transform is not None: | |
assert input_transform in ['resize_concat', 'multiple_select'] | |
self.input_transform = input_transform | |
self.in_index = in_index | |
if input_transform is not None: | |
assert isinstance(in_channels, (list, tuple)) | |
assert isinstance(in_index, (list, tuple)) | |
assert len(in_channels) == len(in_index) | |
if input_transform == 'resize_concat': | |
self.in_channels = sum(in_channels) | |
else: | |
self.in_channels = in_channels | |
else: | |
assert isinstance(in_channels, int) | |
assert isinstance(in_index, int) | |
self.in_channels = in_channels | |
def init_weights(self): | |
"""Initialize weights of classification layer.""" | |
normal_init(self.conv_seg, mean=0, std=0.01) | |
def _transform_inputs(self, inputs): | |
"""Transform inputs for decoder. | |
Args: | |
inputs (list[Tensor]): List of multi-level img features. | |
Returns: | |
Tensor: The transformed inputs | |
""" | |
if self.input_transform == 'resize_concat': | |
inputs = [inputs[i] for i in self.in_index] | |
upsampled_inputs = [ | |
resize( | |
input=x, | |
size=inputs[0].shape[2:], | |
mode='bilinear', | |
align_corners=self.align_corners) for x in inputs | |
] | |
inputs = torch.cat(upsampled_inputs, dim=1) | |
elif self.input_transform == 'multiple_select': | |
inputs = [inputs[i] for i in self.in_index] | |
else: | |
inputs = inputs[self.in_index] | |
return inputs | |
def forward(self, inputs): | |
"""Placeholder of forward function.""" | |
pass | |
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): | |
"""Forward function for training. | |
Args: | |
inputs (list[Tensor]): List of multi-level img features. | |
img_metas (list[dict]): List of image info dict where each dict | |
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |
For details on the values of these keys see | |
`mmseg/datasets/pipelines/formatting.py:Collect`. | |
gt_semantic_seg (Tensor): Semantic segmentation masks | |
used if the architecture supports semantic segmentation task. | |
train_cfg (dict): The training config. | |
Returns: | |
dict[str, Tensor]: a dictionary of loss components | |
""" | |
seg_logits = self.forward(inputs) | |
losses = self.losses(seg_logits, gt_semantic_seg) | |
return losses | |
def forward_test(self, inputs, img_metas, test_cfg): | |
"""Forward function for testing. | |
Args: | |
inputs (list[Tensor]): List of multi-level img features. | |
img_metas (list[dict]): List of image info dict where each dict | |
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |
For details on the values of these keys see | |
`mmseg/datasets/pipelines/formatting.py:Collect`. | |
test_cfg (dict): The testing config. | |
Returns: | |
Tensor: Output segmentation map. | |
""" | |
return self.forward(inputs) | |
def cls_seg(self, feat): | |
"""Classify each pixel.""" | |
if self.dropout is not None: | |
feat = self.dropout(feat) | |
output = self.conv_seg(feat) | |
return output | |
def losses(self, seg_logit, seg_label): | |
"""Compute segmentation loss.""" | |
loss = dict() | |
seg_logit = resize( | |
input=seg_logit, | |
size=seg_label.shape[2:], | |
mode='bilinear', | |
align_corners=self.align_corners) | |
if self.sampler is not None: | |
seg_weight = self.sampler.sample(seg_logit, seg_label) | |
else: | |
seg_weight = None | |
seg_label = seg_label.squeeze(1) | |
loss['loss_seg'] = self.loss_decode( | |
seg_logit, | |
seg_label, | |
weight=seg_weight, | |
ignore_index=self.ignore_index) | |
loss['acc_seg'] = accuracy(seg_logit, seg_label) | |
return loss | |