|
import torch |
|
import torch.nn as nn |
|
from fastai.vision import * |
|
|
|
from modules.model import _default_tfmer_cfg |
|
from modules.resnet import resnet45 |
|
from modules.transformer import (PositionalEncoding, |
|
TransformerEncoder, |
|
TransformerEncoderLayer) |
|
|
|
|
|
class ResTranformer(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.resnet = resnet45() |
|
|
|
self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model']) |
|
nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead']) |
|
d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner']) |
|
dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout']) |
|
activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation']) |
|
num_layers = ifnone(config.model_vision_backbone_ln, 2) |
|
|
|
self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) |
|
encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, |
|
dim_feedforward=d_inner, dropout=dropout, activation=activation) |
|
self.transformer = TransformerEncoder(encoder_layer, num_layers) |
|
|
|
def forward(self, images): |
|
feature = self.resnet(images) |
|
n, c, h, w = feature.shape |
|
feature = feature.view(n, c, -1).permute(2, 0, 1) |
|
feature = self.pos_encoder(feature) |
|
feature = self.transformer(feature) |
|
feature = feature.permute(1, 2, 0).view(n, c, h, w) |
|
return feature |
|
|