Spaces:
Running
on
A10G
Running
on
A10G
import torch | |
from annotator.mmpkg.mmcv.cnn import NonLocal2d | |
from ..builder import HEADS | |
from .fcn_head import FCNHead | |
class NLHead(FCNHead):
    """Decode head built around a non-local block.

    Implements the head described in `NLNet
    <https://arxiv.org/abs/1711.07971>`_ (Non-local Neural Networks) on top
    of ``FCNHead``, which is always configured with ``num_convs=2``. The
    non-local block is applied between those two convolutions.

    Args:
        reduction (int): Reduction factor of projection transform.
            Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        **kwargs: Forwarded unchanged to ``FCNHead.__init__``. Must not
            contain ``num_convs``, since this class passes it explicitly.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        # The parent is always built with exactly two convs; forward()
        # indexes self.convs[0] and self.convs[1] directly.
        super().__init__(num_convs=2, **kwargs)

        # Keep the non-local settings on the instance for later inspection.
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode

        # channels / conv_cfg / norm_cfg are read from the instance after
        # the parent constructor has run.
        non_local_cfg = dict(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode,
        )
        self.nl_block = NonLocal2d(**non_local_cfg)

    def forward(self, inputs):
        """Forward function.

        Runs conv -> non-local block -> conv on the transformed inputs,
        optionally fuses the result with the transformed inputs through
        ``conv_cat`` (channel-wise concatenation on ``dim=1``), and returns
        the output of ``cls_seg``.
        """
        feats = self._transform_inputs(inputs)

        # First conv, then the non-local block, then the second conv.
        refined = self.convs[1](self.nl_block(self.convs[0](feats)))

        if self.concat_input:
            # Concatenate the pre-conv features with the refined ones.
            fused = torch.cat([feats, refined], dim=1)
            refined = self.conv_cat(fused)

        return self.cls_seg(refined)