# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType


class DetrTransformerEncoder(BaseModule):
    """Encoder of DETR.

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries, has
                shape (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
            `True`, otherwise (num_queries, bs, dim).
        """
        for layer in self.layers:
            query = layer(query, query_pos, key_padding_mask, **kwargs)
        return query
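

# Illustrative usage sketch (not part of the original file): build a small
# encoder from this class and run it on random inputs. The helper name
# `_demo_detr_encoder` and the shapes/config values below are assumptions
# chosen to match the defaults of `DetrTransformerEncoderLayer`.
def _demo_detr_encoder() -> Tensor:
    """Run a 2-layer DetrTransformerEncoder on dummy data."""
    encoder = DetrTransformerEncoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
    bs, num_tokens, dim = 2, 100, 256
    query = torch.rand(bs, num_tokens, dim)             # flattened features
    query_pos = torch.rand(bs, num_tokens, dim)         # positional encoding
    key_padding_mask = torch.zeros(bs, num_tokens, dtype=torch.bool)
    return encoder(query, query_pos, key_padding_mask)  # (bs, num_tokens, dim)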


class DetrTransformerDecoder(BaseModule):
    """Decoder of DETR.

    Args:
        num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
        post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            post normalization layer. Defaults to `LN`.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 post_norm_cfg: OptConfigType = dict(type='LN'),
                 return_intermediate: bool = True,
                 init_cfg: Union[dict, ConfigDict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.layer_cfg = layer_cfg
        self.num_layers = num_layers
        self.post_norm_cfg = post_norm_cfg
        self.return_intermediate = return_intermediate
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
                **kwargs) -> Tensor:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor): The input key, has shape (bs, num_keys, dim).
            value (Tensor): The input value with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).

        Returns:
            Tensor: The forwarded results will have shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True` else (1, bs, num_queries, dim).
        """
        intermediate = []
        for layer in self.layers:
            query = layer(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                **kwargs)
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))
        query = self.post_norm(query)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return query.unsqueeze(0)
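

# Illustrative usage sketch (assumption, not from the original source): a
# decoder consuming encoder memory. Shapes follow the docstring above; with
# the default `return_intermediate=True` the result stacks every layer output.
def _demo_detr_decoder() -> Tensor:
    """Run a 2-layer DetrTransformerDecoder on dummy data."""
    decoder = DetrTransformerDecoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0)))
    bs, num_queries, num_keys, dim = 2, 100, 300, 256
    query = torch.zeros(bs, num_queries, dim)       # object queries start at zero
    memory = torch.rand(bs, num_keys, dim)          # encoder output
    query_pos = torch.rand(bs, num_queries, dim)
    key_pos = torch.rand(bs, num_keys, dim)
    key_padding_mask = torch.zeros(bs, num_keys, dtype=torch.bool)
    # Output shape: (num_layers, bs, num_queries, dim) = (2, 2, 100, 256).
    return decoder(query, memory, memory, query_pos, key_pos, key_padding_mask)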


class DetrTransformerEncoderLayer(BaseModule):
    """Implements encoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256, num_heads=8, dropout=0.0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.self_attn_cfg = self_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, (
                'First dimension of all DETRs in mmdet is `batch`, '
                'please set `batch_first` flag.')
        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, FFN, and normalization."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of an encoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query
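

# Minimal sketch (assumption): a single post-norm encoder layer built from
# the defaults above. `_demo_detr_encoder_layer` is a hypothetical helper,
# not upstream API.
def _demo_detr_encoder_layer() -> Tensor:
    """Run one DetrTransformerEncoderLayer on dummy data."""
    layer = DetrTransformerEncoderLayer()        # default 256-dim, 8 heads
    x = torch.rand(2, 50, 256)                   # (bs, num_queries, dim)
    pos = torch.rand(2, 50, 256)
    mask = torch.zeros(2, 50, dtype=torch.bool)  # no padded positions
    return layer(x, pos, mask)                   # (2, 50, 256)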


class DetrTransformerDecoderLayer(BaseModule):
    """Implements decoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 cross_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.self_attn_cfg = self_attn_cfg
        self.cross_attn_cfg = cross_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, (
                'First dimension of all DETRs in mmdet is `batch`, '
                'please set `batch_first` flag.')
        if 'batch_first' not in self.cross_attn_cfg:
            self.cross_attn_cfg['batch_first'] = True
        else:
            assert self.cross_attn_cfg['batch_first'] is True, (
                'First dimension of all DETRs in mmdet is `batch`, '
                'please set `batch_first` flag.')
        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, cross-attention, FFN, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """Forward function of a decoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If None, and `query_pos` has
                the same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
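

# Minimal sketch (assumption): a single decoder layer with an explicit
# self-attention mask. Mask semantics follow `nn.MultiheadAttention`: a bool
# mask marks positions that are NOT allowed to attend. The helper name and
# all shapes below are illustrative.
def _demo_detr_decoder_layer() -> Tensor:
    """Run one DetrTransformerDecoderLayer on dummy data."""
    layer = DetrTransformerDecoderLayer()        # default 256-dim, 8 heads
    bs, num_queries, num_keys, dim = 2, 100, 300, 256
    query = torch.zeros(bs, num_queries, dim)
    memory = torch.rand(bs, num_keys, dim)       # encoder output
    self_attn_mask = torch.zeros(num_queries, num_queries, dtype=torch.bool)
    return layer(                                # (bs, num_queries, dim)
        query,
        key=memory,
        value=memory,
        query_pos=torch.rand(bs, num_queries, dim),
        key_pos=torch.rand(bs, num_keys, dim),
        self_attn_mask=self_attn_mask,
        key_padding_mask=torch.zeros(bs, num_keys, dtype=torch.bool))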