import torch
from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            key and query projection.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use ConvModule (conv + norm +
            activation) for the key/query projection.
        value_out_norm (bool): Whether to use ConvModule (conv + norm +
            activation) for the value/out projection.
        matmul_norm (bool): Whether to normalize the attention map with
            sqrt of channels.
        with_out (bool): Whether to use an out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """
    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None
        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm
        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the output projection layer."""
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
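        # Project query features and flatten spatial dims:
        # (B, C_q, H, W) -> (B, channels, H*W) -> (B, H*W, channels).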
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()
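
        # Project key/value from the key features; keys stay (B, channels, H*W)
        # while values are permuted to (B, H*W, C_v) to match the attention
        # output layout.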
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()
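
        # Scaled dot-product attention over spatial positions: sim_map has
        # shape (B, H*W_query, H*W_key) and is scaled by 1/sqrt(channels)
        # when ``matmul_norm`` is enabled, following the Transformer paper.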
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
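        # Fold the attended context back into a feature map with the query's
        # spatial resolution, then apply the optional output projection.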
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
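

# Minimal usage sketch (illustrative, not part of the upstream module): the
# configuration and tensor shapes below are assumptions chosen to exercise the
# block with plain 1x1 Conv2d projections and scaled-dot-product normalization
# enabled. Key/value features may have a different spatial size than the query.
if __name__ == '__main__':
    block = SelfAttentionBlock(
        key_in_channels=64,
        query_in_channels=64,
        channels=32,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=1,
        value_out_num_convs=1,
        key_query_norm=False,
        value_out_norm=False,
        matmul_norm=True,
        with_out=True,
        conv_cfg=None,
        norm_cfg=None,
        act_cfg=None)
    query_feats = torch.rand(2, 64, 16, 16)
    key_feats = torch.rand(2, 64, 32, 32)  # key/value on a larger grid
    context = block(query_feats, key_feats)
    # Output follows the query's spatial size: torch.Size([2, 64, 16, 16]).
    # Note: values are all zeros right after construction because
    # init_weights() zero-initializes the plain-conv out projection.
    print(context.shape)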