Spaces:
Build error
Build error
import torch.nn as nn | |
from networks.layers.transformer import DualBranchGPM | |
from networks.models.aot import AOT | |
from networks.decoders import build_decoder | |
class DeAOT(AOT):
    """AOT variant that swaps in a ``DualBranchGPM`` propagation module.

    After ``AOT.__init__`` has run, this class:

    * replaces ``self.LSTT`` with a ``DualBranchGPM`` that returns
      intermediate outputs (``return_intermediate=True``),
    * rebuilds ``self.decoder`` with an input width derived from the config,
    * adds ``self.id_norm``, a LayerNorm over
      ``cfg.MODEL_ENCODER_EMBEDDING_DIM`` features used in ``get_id_emb``,
    * and finally calls ``self._init_weight()``.
    """

    def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
        super().__init__(cfg, encoder, decoder)

        embed_dim = cfg.MODEL_ENCODER_EMBEDDING_DIM
        intermediate = cfg.MODEL_DECODER_INTERMEDIATE_LSTT
        num_layers = cfg.MODEL_LSTT_NUM

        # Keep this construction order (propagation module, decoder, id_norm,
        # then _init_weight): submodule registration order, and any RNG draws
        # made by default parameter initialisation, depend on it.
        self.LSTT = DualBranchGPM(
            num_layers,
            embed_dim,
            cfg.MODEL_SELF_HEADS,
            cfg.MODEL_ATT_HEADS,
            emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
            droppath=cfg.TRAIN_LSTT_DROPPATH,
            lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
            st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
            droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
            droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
            intermediate_norm=intermediate,
            return_intermediate=True,
        )

        # Decoder input width: (2 * MODEL_LSTT_NUM + 1) * embed_dim when
        # intermediate LSTT outputs are decoded, otherwise 2 * embed_dim.
        # NOTE(review): the factor 2 presumably reflects the two branches of
        # DualBranchGPM and the "+1" the shortcuts[-1] tensor that
        # decode_id_logits prepends — confirm against the decoder.
        if intermediate:
            decoder_indim = embed_dim * (num_layers * 2 + 1)
        else:
            decoder_indim = embed_dim * 2

        self.decoder = build_decoder(
            decoder,
            in_dim=decoder_indim,
            # presumably MODEL_MAX_OBJ_NUM objects plus one background id
            out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
            decode_intermediate_input=intermediate,
            hidden_dim=embed_dim,
            shortcut_dims=cfg.MODEL_ENCODER_DIM,
            align_corners=cfg.MODEL_ALIGN_CORNERS,
        )

        self.id_norm = nn.LayerNorm(embed_dim)

        self._init_weight()

    def decode_id_logits(self, lstt_emb, shortcuts):
        """Decode logits from the last shortcut plus every LSTT embedding.

        Args:
            lstt_emb: iterable of tensors. Each is reshaped with
                ``view(h, w, n, -1)`` and permuted to ``(n, -1, h, w)``.
                Assumes each was laid out as ``(h * w, n, C)`` — TODO confirm.
            shortcuts: sequence whose last element is a 4-D tensor. Its size
                is unpacked as ``(n, c, h, w)``; only ``n``, ``h`` and ``w``
                are used.

        Returns:
            Whatever ``self.decoder(decoder_inputs, shortcuts)`` returns, where
            ``decoder_inputs`` is ``[shortcuts[-1], *reshaped_embeddings]``.
        """
        last = shortcuts[-1]
        batch, _, height, width = last.size()
        spatial_embs = [
            emb.view(height, width, batch, -1).permute(2, 3, 0, 1)
            for emb in lstt_emb
        ]
        return self.decoder([last, *spatial_embs], shortcuts)

    def get_id_emb(self, x):
        """Return normalised, dropout-applied ID embeddings for ``x``.

        ``self.patch_wise_id_bank`` (not defined here; assumed to come from
        the base class) must yield a 4-D tensor. The tensor is permuted with
        ``(2, 3, 0, 1)`` and passed through ``self.id_norm``, which normalises
        its last dimension. It is then permuted with ``(2, 3, 0, 1)`` again;
        that permutation is its own inverse, so the original layout is
        restored. Finally ``self.id_dropout`` is applied.
        """
        axes = (2, 3, 0, 1)  # self-inverse permutation
        raw_emb = self.patch_wise_id_bank(x)
        normed_emb = self.id_norm(raw_emb.permute(*axes)).permute(*axes)
        return self.id_dropout(normed_emb)