|
import torch |
|
import torch.nn as nn |
|
from fastai.vision import * |
|
|
|
from .model_vision import BaseVision |
|
from .model_language import BCNLanguage |
|
from .model_alignment import BaseAlignment |
|
|
|
|
|
class ABINetIterModel(nn.Module):
    """ABINet model: a vision backbone whose character predictions are
    iteratively refined by a language model plus a vision/language
    alignment head (the "iterative correction" of ABINet)."""

    def __init__(self, config):
        super().__init__()
        # Number of language -> alignment refinement passes (default 1).
        self.iter_size = ifnone(config.model_iter_size, 1)
        # Maximum decode length; +1 reserves a slot for the end token.
        self.max_length = config.dataset_max_length + 1
        self.vision = BaseVision(config)
        self.language = BCNLanguage(config)
        self.alignment = BaseAlignment(config)

    def forward(self, images, *args):
        """Run the vision stage once, then `iter_size` refinement rounds.

        While training, returns the per-iteration alignment and language
        results (as lists) together with the vision result; in eval mode,
        returns only the final alignment/language results plus the vision
        result.
        """
        vis_out = self.vision(images)
        lang_outs, align_outs = [], []
        cur = vis_out  # each round refines the previous (alignment) output
        for _ in range(self.iter_size):
            probs = torch.softmax(cur['logits'], dim=-1)
            lens = cur['pt_lengths']
            # NOTE(review): in-place clamp mutates `cur` — and therefore
            # `vis_out` on the first pass — kept from the original code.
            lens.clamp_(2, self.max_length)
            lang_out = self.language(probs, lens)
            lang_outs.append(lang_out)
            cur = self.alignment(lang_out['feature'], vis_out['feature'])
            align_outs.append(cur)
        if self.training:
            return align_outs, lang_outs, vis_out
        return cur, lang_outs[-1], vis_out
|
|