# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor
from mmdet.utils import ConfigType, OptConfigType
class DetrTransformerEncoder(BaseModule):
"""Encoder of DETR.
Args:
num_layers (int): Number of encoder layers.
layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
layer. All the layers will share the same config.
init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
the initialization. Defaults to None.
"""
def __init__(self,
num_layers: int,
layer_cfg: ConfigType,
init_cfg: OptConfigType = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_layers = num_layers
self.layer_cfg = layer_cfg
self._init_layers()
def _init_layers(self) -> None:
"""Initialize encoder layers."""
self.layers = ModuleList([
DetrTransformerEncoderLayer(**self.layer_cfg)
for _ in range(self.num_layers)
])
self.embed_dims = self.layers[0].embed_dims
def forward(self, query: Tensor, query_pos: Tensor,
key_padding_mask: Tensor, **kwargs) -> Tensor:
"""Forward function of encoder.
Args:
query (Tensor): Input queries of encoder, has shape
(bs, num_queries, dim).
query_pos (Tensor): The positional embeddings of the queries, has
shape (bs, num_queries, dim).
key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
input. ByteTensor, has shape (bs, num_queries).
Returns:
Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
`True`, otherwise (num_queries, bs, dim).
"""
for layer in self.layers:
query = layer(query, query_pos, key_padding_mask, **kwargs)
return query
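
# A minimal usage sketch for `DetrTransformerEncoder` (illustrative only, not
# part of the module API). It assumes the default 256-dim layer settings from
# this file and a hypothetical batch of 2 images flattened to 100 tokens. The
# encoder consumes flattened backbone features plus their positional
# embeddings and returns a tensor of the same (bs, num_queries, dim) shape:
#
#   encoder = DetrTransformerEncoder(
#       num_layers=6,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
#   feat = torch.randn(2, 100, 256)               # flattened (bs, h*w, c) features
#   pos = torch.randn(2, 100, 256)                # positional embeddings
#   mask = torch.zeros(2, 100, dtype=torch.bool)  # no padded tokens
#   memory = encoder(query=feat, query_pos=pos, key_padding_mask=mask)
#   assert memory.shape == (2, 100, 256)
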
class DetrTransformerDecoder(BaseModule):
"""Decoder of DETR.
Args:
num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
post normalization layer. Defaults to `LN`.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
the initialization. Defaults to None.
"""
def __init__(self,
num_layers: int,
layer_cfg: ConfigType,
post_norm_cfg: OptConfigType = dict(type='LN'),
return_intermediate: bool = True,
init_cfg: Union[dict, ConfigDict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.layer_cfg = layer_cfg
self.num_layers = num_layers
self.post_norm_cfg = post_norm_cfg
self.return_intermediate = return_intermediate
self._init_layers()
def _init_layers(self) -> None:
"""Initialize decoder layers."""
self.layers = ModuleList([
DetrTransformerDecoderLayer(**self.layer_cfg)
for _ in range(self.num_layers)
])
self.embed_dims = self.layers[0].embed_dims
self.post_norm = build_norm_layer(self.post_norm_cfg,
self.embed_dims)[1]
def forward(self, query: Tensor, key: Tensor, value: Tensor,
query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
**kwargs) -> Tensor:
"""Forward function of decoder
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
key (Tensor): The input key, has shape (bs, num_keys, dim).
value (Tensor): The input value with the same shape as `key`.
query_pos (Tensor): The positional encoding for `query`, with the
same shape as `query`.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`.
key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
input. ByteTensor, has shape (bs, num_value).
Returns:
Tensor: The forwarded results will have shape
(num_decoder_layers, bs, num_queries, dim) if
`return_intermediate` is `True` else (1, bs, num_queries, dim).
"""
intermediate = []
for layer in self.layers:
query = layer(
query,
key=key,
value=value,
query_pos=query_pos,
key_pos=key_pos,
key_padding_mask=key_padding_mask,
**kwargs)
if self.return_intermediate:
intermediate.append(self.post_norm(query))
query = self.post_norm(query)
if self.return_intermediate:
return torch.stack(intermediate)
return query.unsqueeze(0)
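
# A minimal usage sketch for `DetrTransformerDecoder` (illustrative only): a
# hypothetical set of 10 object queries attends to an encoder memory of 100
# tokens. With the default `return_intermediate=True`, the output stacks the
# post-norm result of every decoder layer, giving shape
# (num_layers, bs, num_queries, dim):
#
#   decoder = DetrTransformerDecoder(
#       num_layers=6,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8),
#           cross_attn_cfg=dict(embed_dims=256, num_heads=8),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
#   memory = torch.randn(2, 100, 256)       # encoder output (bs, h*w, c)
#   memory_pos = torch.randn(2, 100, 256)   # positional embeddings of memory
#   memory_mask = torch.zeros(2, 100, dtype=torch.bool)
#   query = torch.zeros(2, 10, 256)         # initial content queries
#   query_pos = torch.randn(2, 10, 256)     # learned query embeddings
#   out = decoder(
#       query=query, key=memory, value=memory, query_pos=query_pos,
#       key_pos=memory_pos, key_padding_mask=memory_mask)
#   assert out.shape == (6, 2, 10, 256)
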
class DetrTransformerEncoderLayer(BaseModule):
"""Implements encoder layer in DETR transformer.
Args:
self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
attention.
ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
normalization layers. All the layers will share the same
config. Defaults to `LN`.
init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
the initialization. Defaults to None.
"""
def __init__(self,
self_attn_cfg: OptConfigType = dict(
embed_dims=256, num_heads=8, dropout=0.0),
ffn_cfg: OptConfigType = dict(
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True)),
norm_cfg: OptConfigType = dict(type='LN'),
init_cfg: OptConfigType = None) -> None:
super().__init__(init_cfg=init_cfg)
self.self_attn_cfg = self_attn_cfg
if 'batch_first' not in self.self_attn_cfg:
self.self_attn_cfg['batch_first'] = True
else:
            assert self.self_attn_cfg['batch_first'] is True, \
                'First dimension of all DETRs in mmdet is `batch`, ' \
                'please set `batch_first` flag.'
self.ffn_cfg = ffn_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self) -> None:
"""Initialize self-attention, FFN, and normalization."""
self.self_attn = MultiheadAttention(**self.self_attn_cfg)
self.embed_dims = self.self_attn.embed_dims
self.ffn = FFN(**self.ffn_cfg)
norms_list = [
build_norm_layer(self.norm_cfg, self.embed_dims)[1]
for _ in range(2)
]
self.norms = ModuleList(norms_list)
def forward(self, query: Tensor, query_pos: Tensor,
key_padding_mask: Tensor, **kwargs) -> Tensor:
"""Forward function of an encoder layer.
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
query_pos (Tensor): The positional encoding for query, with
the same shape as `query`.
key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
Returns:
Tensor: forwarded results, has shape (bs, num_queries, dim).
"""
query = self.self_attn(
query=query,
key=query,
value=query,
query_pos=query_pos,
key_pos=query_pos,
key_padding_mask=key_padding_mask,
**kwargs)
query = self.norms[0](query)
query = self.ffn(query)
query = self.norms[1](query)
return query
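
# A single-layer sketch (illustrative only): `DetrTransformerEncoderLayer`
# follows the post-norm ordering used in the original DETR, i.e.
# self-attention -> LayerNorm -> FFN -> LayerNorm, with the residual
# connections handled inside `MultiheadAttention` and `FFN`. Tensor sizes
# below are hypothetical:
#
#   layer = DetrTransformerEncoderLayer()   # all-default 256-dim config
#   x = torch.randn(2, 100, 256)
#   pos = torch.randn(2, 100, 256)
#   mask = torch.zeros(2, 100, dtype=torch.bool)
#   y = layer(x, query_pos=pos, key_padding_mask=mask)
#   assert y.shape == x.shape
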
class DetrTransformerDecoderLayer(BaseModule):
"""Implements decoder layer in DETR transformer.
Args:
self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
attention.
cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
attention.
ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
normalization layers. All the layers will share the same
config. Defaults to `LN`.
init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
the initialization. Defaults to None.
"""
def __init__(self,
self_attn_cfg: OptConfigType = dict(
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=True),
cross_attn_cfg: OptConfigType = dict(
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=True),
ffn_cfg: OptConfigType = dict(
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True),
),
norm_cfg: OptConfigType = dict(type='LN'),
init_cfg: OptConfigType = None) -> None:
super().__init__(init_cfg=init_cfg)
self.self_attn_cfg = self_attn_cfg
self.cross_attn_cfg = cross_attn_cfg
if 'batch_first' not in self.self_attn_cfg:
self.self_attn_cfg['batch_first'] = True
else:
            assert self.self_attn_cfg['batch_first'] is True, \
                'First dimension of all DETRs in mmdet is `batch`, ' \
                'please set `batch_first` flag.'
if 'batch_first' not in self.cross_attn_cfg:
self.cross_attn_cfg['batch_first'] = True
else:
            assert self.cross_attn_cfg['batch_first'] is True, \
                'First dimension of all DETRs in mmdet is `batch`, ' \
                'please set `batch_first` flag.'
self.ffn_cfg = ffn_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self) -> None:
"""Initialize self-attention, FFN, and normalization."""
self.self_attn = MultiheadAttention(**self.self_attn_cfg)
self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
self.embed_dims = self.self_attn.embed_dims
self.ffn = FFN(**self.ffn_cfg)
norms_list = [
build_norm_layer(self.norm_cfg, self.embed_dims)[1]
for _ in range(3)
]
self.norms = ModuleList(norms_list)
def forward(self,
query: Tensor,
key: Tensor = None,
value: Tensor = None,
query_pos: Tensor = None,
key_pos: Tensor = None,
self_attn_mask: Tensor = None,
cross_attn_mask: Tensor = None,
key_padding_mask: Tensor = None,
**kwargs) -> Tensor:
"""
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
key (Tensor, optional): The input key, has shape (bs, num_keys,
dim). If `None`, the `query` will be used. Defaults to `None`.
value (Tensor, optional): The input value, has the same shape as
`key`, as in `nn.MultiheadAttention.forward`. If `None`, the
`key` will be used. Defaults to `None`.
query_pos (Tensor, optional): The positional encoding for `query`,
has the same shape as `query`. If not `None`, it will be added
to `query` before forward function. Defaults to `None`.
key_pos (Tensor, optional): The positional encoding for `key`, has
the same shape as `key`. If not `None`, it will be added to
`key` before forward function. If None, and `query_pos` has the
same shape as `key`, then `query_pos` will be used for
`key_pos`. Defaults to None.
self_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.
Returns:
Tensor: forwarded results, has shape (bs, num_queries, dim).
"""
query = self.self_attn(
query=query,
key=query,
value=query,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=self_attn_mask,
**kwargs)
query = self.norms[0](query)
query = self.cross_attn(
query=query,
key=key,
value=value,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
**kwargs)
query = self.norms[1](query)
query = self.ffn(query)
query = self.norms[2](query)
return query
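
# A single-layer sketch (illustrative only): `DetrTransformerDecoderLayer`
# runs self-attention over the object queries, then cross-attention against
# the encoder memory, then an FFN, each followed by LayerNorm (post-norm).
# `self_attn_mask`/`cross_attn_mask` restrict attention patterns, while
# `key_padding_mask` hides padded memory tokens in the cross-attention.
# Tensor sizes below are hypothetical:
#
#   layer = DetrTransformerDecoderLayer()   # all-default 256-dim config
#   queries = torch.zeros(2, 10, 256)
#   query_pos = torch.randn(2, 10, 256)
#   memory = torch.randn(2, 100, 256)
#   memory_pos = torch.randn(2, 100, 256)
#   pad_mask = torch.zeros(2, 100, dtype=torch.bool)
#   out = layer(
#       queries, key=memory, value=memory, query_pos=query_pos,
#       key_pos=memory_pos, key_padding_mask=pad_mask)
#   assert out.shape == (2, 10, 256)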