# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from torch import Tensor

from .deformable_detr_layers import DeformableDetrTransformerEncoder
from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer


class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder in PixelDecoder of Mask2Former."""

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                reference_points: Tensor, **kwargs) -> Tensor:
"""Forward function of Transformer encoder. | |
Args: | |
query (Tensor): The input query, has shape (bs, num_queries, dim). | |
query_pos (Tensor): The positional encoding for query, has shape | |
(bs, num_queries, dim). If not None, it will be added to the | |
`query` before forward function. Defaults to None. | |
key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` | |
input. ByteTensor, has shape (bs, num_queries). | |
spatial_shapes (Tensor): Spatial shapes of features in all levels, | |
has shape (num_levels, 2), last dimension represents (h, w). | |
level_start_index (Tensor): The start index of each level. | |
A tensor has shape (num_levels, ) and can be represented | |
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. | |
valid_ratios (Tensor): The ratios of the valid width and the valid | |
height relative to the width and the height of features in all | |
levels, has shape (bs, num_levels, 2). | |
reference_points (Tensor): The initial reference, has shape | |
(bs, num_queries, 2) with the last dimension arranged | |
as (cx, cy). | |
Returns: | |
Tensor: Output queries of Transformer encoder, which is also | |
called 'encoder output embeddings' or 'memory', has shape | |
(bs, num_queries, dim) | |
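
        Example:
            An illustrative call sketch; the variable names and shapes are
            assumptions taken from the argument descriptions above, and
            ``encoder`` stands for an already built
            ``Mask2FormerTransformerEncoder``:

            >>> memory = encoder(
            ...     query=query,  # (bs, num_queries, dim)
            ...     query_pos=query_pos,  # (bs, num_queries, dim)
            ...     key_padding_mask=key_padding_mask,  # (bs, num_queries)
            ...     spatial_shapes=spatial_shapes,  # (num_levels, 2)
            ...     level_start_index=level_start_index,  # (num_levels, )
            ...     valid_ratios=valid_ratios,  # (bs, num_levels, 2)
            ...     reference_points=reference_points)  # (bs, num_queries, 2)
            >>> # memory has shape (bs, num_queries, dim)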
""" | |
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query
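

# A minimal construction sketch for ``Mask2FormerTransformerDecoder`` defined
# below. The values and cfg keys are assumptions mirroring a typical
# Mask2Former config rather than defaults defined in this file: ``layer_cfg``
# is unpacked into each ``Mask2FormerTransformerDecoderLayer`` and
# ``post_norm_cfg`` builds the final norm in ``_init_layers``.
#
#   decoder = Mask2FormerTransformerDecoder(
#       num_layers=9,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
#           cross_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)),
#       post_norm_cfg=dict(type='LN'))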


class Mask2FormerTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Mask2Former."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            Mask2FormerTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]


class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Mask2Former transformer."""

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
""" | |
Args: | |
query (Tensor): The input query, has shape (bs, num_queries, dim). | |
key (Tensor, optional): The input key, has shape (bs, num_keys, | |
dim). If `None`, the `query` will be used. Defaults to `None`. | |
value (Tensor, optional): The input value, has the same shape as | |
`key`, as in `nn.MultiheadAttention.forward`. If `None`, the | |
`key` will be used. Defaults to `None`. | |
query_pos (Tensor, optional): The positional encoding for `query`, | |
has the same shape as `query`. If not `None`, it will be added | |
to `query` before forward function. Defaults to `None`. | |
key_pos (Tensor, optional): The positional encoding for `key`, has | |
the same shape as `key`. If not `None`, it will be added to | |
`key` before forward function. If None, and `query_pos` has the | |
same shape as `key`, then `query_pos` will be used for | |
`key_pos`. Defaults to None. | |
self_attn_mask (Tensor, optional): ByteTensor mask, has shape | |
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`. | |
Defaults to None. | |
cross_attn_mask (Tensor, optional): ByteTensor mask, has shape | |
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`. | |
Defaults to None. | |
key_padding_mask (Tensor, optional): The `key_padding_mask` of | |
`self_attn` input. ByteTensor, has shape (bs, num_value). | |
Defaults to None. | |
Returns: | |
Tensor: forwarded results, has shape (bs, num_queries, dim). | |
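
        Example:
            An illustrative call sketch; the variable names and shapes are
            assumptions taken from the argument descriptions above, and
            ``layer`` stands for an already built
            ``Mask2FormerTransformerDecoderLayer``:

            >>> query = layer(
            ...     query=query,  # (bs, num_queries, dim)
            ...     key=memory,  # (bs, num_keys, dim), e.g. encoder output
            ...     value=memory,  # (bs, num_keys, dim)
            ...     query_pos=query_pos,
            ...     key_pos=key_pos,
            ...     cross_attn_mask=cross_attn_mask,  # (num_queries, num_keys)
            ...     key_padding_mask=key_padding_mask)  # (bs, num_value)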
""" | |
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query