import sys

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fakeTransformer import FakeTransformer
from .bert import Bert
from ..utils import pairLoss, alignmentLoss, attAlignmentLoss, AlignTripLoss, SimpTripLoss, NCELoss


class ImgLearnableEncoder(nn.Module):

    def __init__(self, model_cfg):
        super(ImgLearnableEncoder, self).__init__()

        self.backbone = timm.create_model(model_cfg.CNN, pretrained=True)
        self.model_cfg = model_cfg

        self.learnable = nn.ModuleDict()
        self.learnable['imgFC'] = FakeTransformer(model_cfg.IMG_FEATURE_DIM,
                                                  model_cfg.IMG_FEATURE_DIM,
                                                  model_cfg.IMG_FEATURE_DIM)
        img_encoder_layer = nn.TransformerEncoderLayer(d_model=model_cfg.IMG_FEATURE_DIM,
                                                       nhead=model_cfg.IMG_TRANSFORMER_HEAD)
        self.learnable['imgAtt'] = nn.TransformerEncoder(img_encoder_layer,
                                                         num_layers=model_cfg.IMG_TRANSFORMER_LAYER)

        self.learnable['max_pool'] = nn.Sequential(
            nn.Conv2d(model_cfg.IMG_FEATURE_DIM, model_cfg.IMG_FEATURE_DIM, kernel_size=1),
            nn.AvgPool2d(model_cfg.GRID_SIZE, stride=1)
        )

        self.init_param()

    def init_param(self):
        # Freeze the CNN backbone except for its last four blocks.
        for name, param in self.backbone.named_parameters():
            param.requires_grad = ('blocks.6' in name or 'blocks.5' in name
                                   or 'blocks.4' in name or 'blocks.3' in name)
        sys.stdout.flush()

    def roi_grid_pool(self, spatial_features_2d, rois):
        """
        Args:
            spatial_features_2d: (B, C, H, W)
            rois: (B, num_rois, 4), boxes as (x1, y1, x2, y2) in image coordinates

        Returns:
            pooled_features: (B, num_rois, C)
        """
        batch_size = spatial_features_2d.size(0)
        rois = rois.detach()
        height, width = spatial_features_2d.size(2), spatial_features_2d.size(3)  # feature-map size

        down_sample_ratio = self.model_cfg.IMG_SIZE / height

        pooled_features_list = []
        torch.backends.cudnn.enabled = False
        for b_id in range(batch_size):
            # TODO: a coordinate-system conversion is needed here.
            # Map global box coordinates to feature-map coordinates.
            x1 = rois[b_id, :, 0] / down_sample_ratio
            y1 = rois[b_id, :, 1] / down_sample_ratio
            x2 = rois[b_id, :, 2] / down_sample_ratio
            y2 = rois[b_id, :, 3] / down_sample_ratio

            angle = torch.zeros((1), device=spatial_features_2d.device)
            cosa = torch.cos(angle)
            sina = torch.sin(angle)

            # Affine transform mapping each box onto the normalized [-1, 1]
            # sampling grid expected by affine_grid (angle is fixed to 0,
            # i.e. axis-aligned boxes).
            theta = torch.stack((
                (x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina),
                (x1 + x2 - width + 1) / (width - 1),
                (y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa,
                (y1 + y2 - height + 1) / (height - 1)
            ), dim=1).view(-1, 2, 3).float()

            grid_size = self.model_cfg.GRID_SIZE
            grid = nn.functional.affine_grid(
                theta,
                torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size))
            )
            pooled_features = nn.functional.grid_sample(
                spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width),
                grid
            )
            pooled_features = self.learnable['max_pool'](pooled_features)  # (num_rois, C, 1, 1)
            pooled_features_list.append(pooled_features.squeeze())
        torch.backends.cudnn.enabled = True

        pooled_features = torch.stack(pooled_features_list, dim=0)
        return pooled_features

    def forward(self, imgFea, maskImages, image_boxs):
        feature_map = self.backbone.forward_features(imgFea)
        imgFea = self.roi_grid_pool(feature_map, image_boxs)

        imgFea = F.normalize(imgFea, p=2, dim=-1)
        imgFea = self.learnable['imgAtt'](imgFea.transpose(0, 1),
                                          src_key_padding_mask=(maskImages == 0)).transpose(0, 1)

        # Masked mean pooling over the ROI dimension.
        tmpMask = torch.where(maskImages == 1,
                              torch.tensor([1.0], device=maskImages.device),
                              torch.tensor([0.0], device=maskImages.device))
        imgFea = (imgFea * tmpMask.unsqueeze(-1)).sum(dim=1) / tmpMask.sum(dim=1).unsqueeze(-1)  # (bs, dim)

        imgFea = self.learnable['imgFC'](imgFea)
        return imgFea
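
# A minimal smoke-test sketch (added; not part of the original file) showing
# the shapes ImgLearnableEncoder expects. The config fields mirror the
# attributes used above; the concrete values, the EfficientNet-B5 backbone
# (whose timm feature map has 2048 channels), and the box/mask contents are
# illustrative assumptions.
def _img_encoder_smoke_test():
    from types import SimpleNamespace
    cfg = SimpleNamespace(CNN='efficientnet_b5', IMG_FEATURE_DIM=2048,
                          IMG_TRANSFORMER_HEAD=8, IMG_TRANSFORMER_LAYER=2,
                          GRID_SIZE=3, IMG_SIZE=224)
    enc = ImgLearnableEncoder(cfg)
    images = torch.randn(2, 3, 224, 224)                     # (bs, 3, H, W)
    boxes = torch.tensor([[[0., 0., 111., 111.]] * 16] * 2)  # (bs, num_rois, 4)
    mask = torch.ones(2, 16, dtype=torch.long)               # (bs, num_rois), 1 = valid ROI
    out = enc(images, mask, boxes)
    # Expected: (2, cfg.IMG_FEATURE_DIM), given FakeTransformer's apparent
    # (in_dim, hidden_dim, out_dim) signature.
    print(out.shape)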

class TextLearnableEncoder(nn.Module):

    def __init__(self, model_cfg):
        super(TextLearnableEncoder, self).__init__()

        self.backbone = Bert(model_cfg)
        self.model_cfg = model_cfg

        self.learnable = nn.ModuleDict()
        self.learnable['textFC'] = FakeTransformer(model_cfg.TEXT_FEATURE_DIM,
                                                   model_cfg.IMG_FEATURE_DIM,
                                                   model_cfg.IMG_FEATURE_DIM)
        text_encoder_layer = nn.TransformerEncoderLayer(d_model=model_cfg.TEXT_FEATURE_DIM,
                                                        nhead=model_cfg.TEXT_TRANSFORMER_HEAD)
        self.learnable['textAtt'] = nn.TransformerEncoder(text_encoder_layer,
                                                          num_layers=model_cfg.TEXT_TRANSFORMER_LAYER)

        self.init_param()

    def init_param(self):
        # Freeze the BERT backbone except for its last four layers
        # (layers 8-11 for base models, 20-23 for large models).
        for name, param in self.backbone.named_parameters():
            if 'large' not in self.model_cfg.ENCODER:
                param.requires_grad = ('layer.11' in name or 'layer.10' in name
                                       or 'layer.9' in name or 'layer.8' in name)
            else:
                param.requires_grad = ('layer.23' in name or 'layer.22' in name
                                       or 'layer.21' in name or 'layer.20' in name)
        sys.stdout.flush()

    def forward(self, textFea, maskTexts):
        textFea = self.backbone(textFea)  # (bs, seq_len, TEXT_FEATURE_DIM)

        textFea = F.normalize(textFea, p=2, dim=-1)
        textFea = self.learnable['textAtt'](textFea.transpose(0, 1),
                                            src_key_padding_mask=(maskTexts == 0)).transpose(0, 1)

        # Masked mean pooling over the token dimension.
        tmpMask = torch.where(maskTexts == 1,
                              torch.tensor([1.0], device=maskTexts.device),
                              torch.tensor([0.0], device=maskTexts.device))
        textFea = (textFea * tmpMask.unsqueeze(-1)).sum(dim=1) / tmpMask.sum(dim=1).unsqueeze(-1)  # (bs, dim)

        textFea = self.learnable['textFC'](textFea)  # project to IMG_FEATURE_DIM
        return textFea
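
# Standalone restatement (added for clarity; not in the original file) of the
# masked mean pooling both encoders perform after self-attention: only valid
# positions, marked 1 in the mask, contribute to the average.
def _masked_mean_pool(seq_feat, mask):
    """seq_feat: (bs, seq_len, dim); mask: (bs, seq_len) with 1 = valid."""
    m = mask.float()
    return (seq_feat * m.unsqueeze(-1)).sum(dim=1) / m.sum(dim=1).unsqueeze(-1)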

class VL_model(nn.Module):

    def __init__(self, model_cfg):
        super(VL_model, self).__init__()

        self.model_cfg = model_cfg

        self.learnable = nn.ModuleDict()
        self.learnable['imgencoder'] = ImgLearnableEncoder(model_cfg)
        self.learnable['imgencoder_mom'] = ImgLearnableEncoder(model_cfg)
        self.learnable['textencoder'] = TextLearnableEncoder(model_cfg)
        self.learnable['textencoder_mom'] = TextLearnableEncoder(model_cfg)
        # self.generator = Generator(model_cfg)

        ############ add new params in .yml config file
        self.K = model_cfg.QUEUE_SIZE    # e.g. 6400
        self.m = model_cfg.MOMENTUM     # e.g. 0.9
        self.T = model_cfg.TEMPERATURE  # e.g. 0.07
        self.topk = model_cfg.TOPK      # e.g. 5
        self.multi_label = False
        ############ add new params in .yml config file

        # init the parameters of the momentum (key) encoders
        self.init_param()

        # create the image queue
        self.register_buffer("img_queue", torch.randn(model_cfg.IMG_FEATURE_DIM, self.K))
        self.img_queue = nn.functional.normalize(self.img_queue, dim=0)
        self.register_buffer("img_queue_ptr", torch.zeros(1, dtype=torch.long))  # image queue pointer

        # create the text queue
        self.register_buffer("text_queue", torch.randn(model_cfg.IMG_FEATURE_DIM, self.K))
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
        self.register_buffer("text_queue_ptr", torch.zeros(1, dtype=torch.long))  # text queue pointer

    def init_param(self):
        for param_q, param_k in zip(self.learnable['imgencoder'].parameters(),
                                    self.learnable['imgencoder_mom'].parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False     # not updated by gradient

        for param_q, param_k in zip(self.learnable['textencoder'].parameters(),
                                    self.learnable['textencoder_mom'].parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False     # not updated by gradient

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoders for both modalities."""
        for param_q, param_k in zip(self.learnable['imgencoder'].parameters(),
                                    self.learnable['imgencoder_mom'].parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.learnable['textencoder'].parameters(),
                                    self.learnable['textencoder_mom'].parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, option='img'):
        # option in ('img', 'text')
        # gather keys before updating the queue
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]

        if option == 'img':
            ptr = int(self.img_queue_ptr)
            assert self.K % batch_size == 0  # for simplicity
            # replace the keys at ptr (dequeue and enqueue)
            self.img_queue[:, ptr:ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.K  # move pointer
            self.img_queue_ptr[0] = ptr
        else:
            ptr = int(self.text_queue_ptr)
            assert self.K % batch_size == 0  # for simplicity
            # replace the keys at ptr (dequeue and enqueue)
            self.text_queue[:, ptr:ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.K  # move pointer
            self.text_queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x, x_mask):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        x_mask_gather = concat_all_gather(x_mask)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], x_mask_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, x_mask, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        x_mask_gather = concat_all_gather(x_mask)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], x_mask_gather[idx_this]
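    # Note (added for clarity): the shuffle/unshuffle pair above is the MoCo
    # trick for BatchNorm under DDP -- keys are encoded on a different GPU's
    # batch slice and then restored to their original order, so per-GPU BN
    # statistics cannot leak which key matches which query. A worked pointer
    # example for _dequeue_and_enqueue, with illustrative sizes K = 6400 and
    # a global batch of 64: at ptr 0, columns 0..63 are overwritten and ptr
    # becomes 64; after 100 steps ptr wraps to 0 and the oldest keys are
    # replaced, so the queue always holds the most recent 6400 keys.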
    def forward(self, imgFea, texts, maskImages, maskTexts, text_lens, image_boxs, is_training=True):
        if self.model_cfg.IS_EXTRACT:
            return self.extract(imgFea, texts, maskImages, maskTexts, image_boxs)

        batch_size = imgFea.size(0)

        imgFea_q = self.learnable['imgencoder'](imgFea, maskImages, image_boxs)
        # imgFea_q = F.normalize(imgFea_q, p=2, dim=-1)
        textFea_q = self.learnable['textencoder'](texts, maskTexts)
        # textFea_q = F.normalize(textFea_q, p=2, dim=-1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            imgFea, image_boxs, idx_unshuffle = self._batch_shuffle_ddp(imgFea, image_boxs)
            imgFea_k = self.learnable['imgencoder_mom'](imgFea, maskImages, image_boxs)
            # imgFea_k = F.normalize(imgFea_k, p=2, dim=-1)
            # undo shuffle
            imgFea_k, image_boxs = self._batch_unshuffle_ddp(imgFea_k, image_boxs, idx_unshuffle)

            # shuffle for making use of BN
            texts, maskTexts, idx_unshuffle = self._batch_shuffle_ddp(texts, maskTexts)
            textFea_k = self.learnable['textencoder_mom'](texts, maskTexts)
            # textFea_k = F.normalize(textFea_k, p=2, dim=-1)
            # undo shuffle
            textFea_k, maskTexts = self._batch_unshuffle_ddp(textFea_k, maskTexts, idx_unshuffle)

        # compute logits for image -> text
        # positive logits: Nx1
        i2t_l_pos = torch.einsum('nc,nc->n', [imgFea_q, textFea_k]).unsqueeze(-1)
        # negative logits: NxK
        i2t_l_neg = torch.einsum('nc,ck->nk', [imgFea_q, self.text_queue.clone().detach()])
        # logits: Nx(1+K)
        i2t_logits = torch.cat([i2t_l_pos, i2t_l_neg], dim=-1)
        i2t_logits /= self.T

        # compute logits for text -> image
        # positive logits: Nx1
        t2i_l_pos = torch.einsum('nc,nc->n', [textFea_q, imgFea_k]).unsqueeze(-1)
        # negative logits: NxK
        t2i_l_neg = torch.einsum('nc,ck->nk', [textFea_q, self.img_queue.clone().detach()])
        # logits: Nx(1+K)
        t2i_logits = torch.cat([t2i_l_pos, t2i_l_neg], dim=-1)
        t2i_logits /= self.T

        ### multi-label
        mask = torch.zeros((batch_size, self.K)).bool().cuda()
        if self.multi_label:
            # Queue entries that are top-k neighbours of the key in *both*
            # modalities are treated as extra positives.
            mask_sim_txt = textFea_k.matmul(self.text_queue.clone().detach())
            mask_sim_img = imgFea_k.matmul(self.img_queue.clone().detach())
            _, topkidx_txt = torch.topk(mask_sim_txt, self.topk, dim=1)
            _, topkidx_img = torch.topk(mask_sim_img, self.topk, dim=1)
            topk_onehot_txt = torch.zeros_like(mask_sim_txt)
            topk_onehot_txt.scatter_(1, topkidx_txt, 1)  # one-hot vector
            topk_onehot_img = torch.zeros_like(mask_sim_img)
            topk_onehot_img.scatter_(1, topkidx_img, 1)  # one-hot vector
            mask[topk_onehot_txt.bool() & topk_onehot_img.bool()] = True
        # prepend the true positive at column 0
        mask = torch.cat([torch.ones((batch_size, 1), dtype=torch.long, device=mask.device).bool(), mask], dim=1)
        ### multi-label

        t2i_loss = -1 * F.log_softmax(t2i_logits, dim=1)
        t2i_loss = torch.masked_select(t2i_loss, mask).sum() / batch_size  # masked_select returns a 1-D tensor
        i2t_loss = -1 * F.log_softmax(i2t_logits, dim=1)
        i2t_loss = torch.masked_select(i2t_loss, mask).sum() / batch_size  # masked_select returns a 1-D tensor

        loss = t2i_loss + i2t_loss

        # dequeue and enqueue
        self._dequeue_and_enqueue(imgFea_k, option='img')
        self._dequeue_and_enqueue(textFea_k, option='text')

        # ---------- caption ----------
        # TODO: update
        '''
        if is_training:
            caption = None
            caption_loss = self.generator(imgFea_q, texts, text_lens, maskTexts, is_training)
        else:
            caption_loss, caption = self.generator(imgFea_q, texts, text_lens, maskTexts, is_training)
        '''

        return loss  # , caption_loss, caption

    def extract(self, imgFea, texts, maskImages, maskTexts, image_boxs):
        imgFea = self.learnable['imgencoder'](imgFea, maskImages, image_boxs)
        textFea = self.learnable['textencoder'](texts, maskTexts)
        imgFea = F.normalize(imgFea, p=2, dim=-1)
        textFea = F.normalize(textFea, p=2, dim=-1)

        retrieval_feat_group = {}
        retrieval_feat_group['img_text'] = (imgFea, textFea)
        return retrieval_feat_group
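
# A toy, self-contained walkthrough (added; the sizes are illustrative) of the
# contrastive logits built in VL_model.forward. With multi_label=False the
# selection mask keeps only column 0, so the loss reduces to standard InfoNCE:
# cross entropy with the positive key at index 0.
def _toy_infonce_logits():
    N, C, K, T = 8, 2048, 6400, 0.07
    q = F.normalize(torch.randn(N, C), dim=-1)     # query features
    k = F.normalize(torch.randn(N, C), dim=-1)     # positive key features
    queue = F.normalize(torch.randn(C, K), dim=0)  # negative keys
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # (N, 1)
    l_neg = torch.einsum('nc,ck->nk', [q, queue])           # (N, K)
    logits = torch.cat([l_pos, l_neg], dim=-1) / T          # (N, 1+K)
    return F.cross_entropy(logits, torch.zeros(N, dtype=torch.long))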

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
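
# concat_all_gather assumes an initialized process group; a hypothetical
# fallback (added; not part of the original file) for single-process debugging:
def _concat_all_gather_or_identity(tensor):
    """Return the gathered tensor under DDP, or the input unchanged when no
    process group is initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return concat_all_gather(tensor)
    return tensor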