Spaces:
Build error
Build error
File size: 2,102 Bytes
04daa95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch.nn as nn
from networks.layers.transformer import DualBranchGPM
from networks.models.aot import AOT
from networks.decoders import build_decoder
class DeAOT(AOT):
    """DeAOT variant of the AOT model.

    Reuses the AOT base class, but replaces ``self.LSTT`` with a
    ``DualBranchGPM`` module and rebuilds ``self.decoder`` with an input
    width matching that module's outputs. It also adds a LayerNorm
    (``self.id_norm``), which ``get_id_emb`` applies to the ID embeddings.
    """

    def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
        # NOTE(review): ``self.patch_wise_id_bank``, ``self.id_dropout`` and
        # ``self._init_weight`` are used by this class but not defined in it.
        # Presumably the AOT base constructor provides them. It presumably
        # also builds its own LSTT/decoder, which are overwritten below.
        # TODO confirm.
        super().__init__(cfg, encoder, decoder)

        emb_dim = cfg.MODEL_ENCODER_EMBEDDING_DIM
        use_intermediate = cfg.MODEL_DECODER_INTERMEDIATE_LSTT

        self.LSTT = DualBranchGPM(
            cfg.MODEL_LSTT_NUM,
            emb_dim,
            cfg.MODEL_SELF_HEADS,
            cfg.MODEL_ATT_HEADS,
            emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
            droppath=cfg.TRAIN_LSTT_DROPPATH,
            lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
            st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
            droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
            droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
            intermediate_norm=use_intermediate,
            return_intermediate=True)

        # Decoder input width. Written as an explicit if/else: the previous
        # form relied on `a * b if cond else c` parsing as
        # `(a * b) if cond else c`.
        # NOTE(review): presumably two embedding widths per LSTT layer (dual
        # branch) plus one extra when intermediate outputs are decoded, and
        # two widths otherwise. Confirm against DualBranchGPM/build_decoder.
        if use_intermediate:
            decoder_indim = emb_dim * (cfg.MODEL_LSTT_NUM * 2 + 1)
        else:
            decoder_indim = emb_dim * 2

        self.decoder = build_decoder(
            decoder,
            in_dim=decoder_indim,
            out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
            decode_intermediate_input=use_intermediate,
            hidden_dim=emb_dim,
            shortcut_dims=cfg.MODEL_ENCODER_DIM,
            align_corners=cfg.MODEL_ALIGN_CORNERS)

        # LayerNorm over the embedding dimension of the ID embeddings
        # (see get_id_emb for which tensor dimension that is).
        self.id_norm = nn.LayerNorm(emb_dim)

        # Re-run weight initialisation so the modules created above are
        # covered too.
        self._init_weight()

    def decode_id_logits(self, lstt_emb, shortcuts):
        """Decode ID logits from LSTT embeddings and encoder shortcuts.

        Args:
            lstt_emb: iterable of embedding tensors. Each is reshaped with
                ``view(h, w, n, -1)`` and permuted to ``(n, C, h, w)``, so
                each must hold ``h * w * n * C`` contiguous elements with the
                spatial positions leading. The layout is assumed to be
                ``(h*w, n, C)`` -- TODO confirm.
            shortcuts: sequence of rank-4 tensors. The last one supplies
                ``n, h, w`` and is also passed to the decoder as its first
                input.

        Returns:
            Whatever ``self.decoder(decoder_inputs, shortcuts)`` returns.
        """
        # Only the batch and spatial sizes are needed; channels are unused.
        n, _, h, w = shortcuts[-1].size()
        decoder_inputs = [shortcuts[-1]]
        decoder_inputs.extend(
            emb.view(h, w, n, -1).permute(2, 3, 0, 1) for emb in lstt_emb)
        return self.decoder(decoder_inputs, shortcuts)

    def get_id_emb(self, x):
        """Build ID embeddings for ``x``, normalize them, and apply dropout.

        ``permute(2, 3, 0, 1)`` is its own inverse, so the second permute
        restores the original layout. In between, dim 1 of the (rank-4) ID
        embedding sits last, which is the dimension ``nn.LayerNorm``
        normalizes over. Its size must therefore equal the embedding dim
        that ``self.id_norm`` was built with.
        """
        id_emb = self.patch_wise_id_bank(x)
        # Move dim 1 to the end for LayerNorm, then permute back.
        id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)
        return self.id_dropout(id_emb)
|