# -*- encoding: utf-8 -*-
'''
@File    :   text_feat_extractor.py
@Time    :   2021/08/26 10:46:15
@Author  :   Chuhao Jin
@Email   :   [email protected]
'''

# here put the import lib
import os
import sys
import json
import pickle
import argparse

base_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(base_dir)

import torch
import torch.nn as nn
import torch.nn.functional as F  # required by F.normalize() in TextModel.forward
import torchvision.transforms as transforms
import numpy as np
from transformers import AutoTokenizer
from tqdm import tqdm

from utils import getLanMask
from utils.config import cfg_from_yaml_file, cfg
from models.vl_model import *
class TextModel(nn.Module):
    def __init__(self, model_cfg):
        super(TextModel, self).__init__()
        self.model_cfg = model_cfg
        self.learnable = nn.ModuleDict()
        self.learnable['textencoder'] = TextLearnableEncoder(model_cfg)

    def forward(self, texts, maskTexts):
        textFea = self.learnable['textencoder'](texts, maskTexts)  # <bsz, img_dim>
        textFea = F.normalize(textFea, p=2, dim=-1)
        return textFea
class TextFeatureExtractor:
    def __init__(self, cfg_file, model_weights, gpu_id=0):
        self.gpu_id = gpu_id
        self.cfg_file = cfg_file
        self.cfg = cfg_from_yaml_file(self.cfg_file, cfg)
        self.cfg.MODEL.ENCODER = os.path.join(base_dir, self.cfg.MODEL.ENCODER)

        self.text_model = TextModel(model_cfg=self.cfg.MODEL)
        self.text_model = self.text_model.cuda(self.gpu_id)

        # Keep only the text-encoder weights from the full BriVL checkpoint.
        model_component = torch.load(model_weights,
                                     map_location=torch.device('cuda:{}'.format(self.gpu_id)))
        text_model_component = {}
        for key in model_component["learnable"].keys():
            if "textencoder." in key:
                text_model_component[key] = model_component["learnable"][key]
        self.text_model.learnable.load_state_dict(text_model_component)
        self.text_model.eval()

        self.text_transform = AutoTokenizer.from_pretrained('./hfl/chinese-bert-wwm-ext')
    def extract(self, text_input):
        if text_input is None:
            return None
        else:
            text_info = self.text_transform(text_input, padding='max_length', truncation=True,
                                            max_length=self.cfg.MODEL.MAX_TEXT_LEN, return_tensors='pt')
            text = text_info.input_ids.reshape(-1)
            text_len = torch.sum(text_info.attention_mask)
            with torch.no_grad():
                texts = text.unsqueeze(0)
                text_lens = text_len.unsqueeze(0)
                textMask = getLanMask(text_lens, cfg.MODEL.MAX_TEXT_LEN)
                textMask = textMask.cuda(self.gpu_id)
                texts = texts.cuda(self.gpu_id)
                text_lens = text_lens.cuda(self.gpu_id)
                text_fea = self.text_model(texts, textMask)
                text_fea = text_fea.cpu().numpy()
            return text_fea
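

def _example_usage():
    # Minimal programmatic usage sketch; this helper is added purely for
    # illustration and is not part of the original pipeline. The cfg and
    # checkpoint paths below are placeholders and must point to a real cfg
    # file and its matching BriVL checkpoint.
    extractor = TextFeatureExtractor('cfg/test_xyb.yml',
                                     '/path/to/BriVL-1.0-5500w.pth', gpu_id=0)
    feat = extractor.extract('示例中文文本')  # any Chinese sentence
    print(feat.shape)  # numpy array of shape (1, feature_dim)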
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--txt_path', type=str, default=None)
    parser.add_argument('--feat_save_dir', type=str, default=None)
    parser.add_argument('--cfg_file', type=str, default='cfg/test_xyb.yml')
    parser.add_argument('--brivl_checkpoint', type=str,
                        default='/innovation_cfs/mmatch/infguo/weights/BriVL-1.0-5500w.pth')
    args = parser.parse_args()

    cfg_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.cfg_file)
    model_weights = args.brivl_checkpoint
    vfe = TextFeatureExtractor(cfg_file, model_weights)

    save_dir = args.feat_save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for i in os.listdir(args.txt_path):
        with open(os.path.join(args.txt_path, i), encoding='utf-8') as fp:
            clip_data = json.load(fp)
        for clip in clip_data["clips"]:
            clip["multi_factor"] = {"semantics": None}
            if "original_text" in clip and clip["original_text"]:
                text = clip["original_text"]
                fea = vfe.extract(text)
                fea = fea.squeeze(axis=0).tolist()
                clip["multi_factor"]["semantics"] = fea
        with open(os.path.join(save_dir, i), "w", encoding="utf-8") as fp:
            json.dump(clip_data, fp, ensure_ascii=False, indent=4)
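
# Example invocation (all paths below are illustrative placeholders; adjust to
# your own data and checkpoint locations):
#
#   python text_feat_extractor.py \
#       --txt_path /path/to/clip_json_dir \
#       --feat_save_dir /path/to/output_dir \
#       --cfg_file cfg/test_xyb.yml \
#       --brivl_checkpoint /path/to/BriVL-1.0-5500w.pth
#
# Each JSON file under --txt_path is expected to hold a "clips" list; for every
# clip with a non-empty "original_text", the script stores the extracted text
# feature under clip["multi_factor"]["semantics"] and writes the updated JSON
# to --feat_save_dir under the same file name.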