File size: 2,109 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from typing import List, Tuple, Dict
import os
import time

from tqdm import tqdm
import torch
import numpy as np
from numpy import ndarray
from PIL import Image
from transformers import BertForSequenceClassification, BertTokenizer, CLIPProcessor, CLIPModel



class TextFeatureExtractor(object):
    """Encode text queries into L2-normalised feature vectors.

    The tokenizer and encoder are loaded with ``BertTokenizer`` and
    ``BertForSequenceClassification``; the encoder's ``logits`` output is
    used as the text embedding.
    """

    # Local directory name and Hugging Face Hub id of the default
    # Taiyi-CLIP text encoder.
    _DEFAULT_LOCAL_PATH = "Taiyi-CLIP-Roberta-large-326M-Chinese"
    _DEFAULT_HUB_PATH = "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese"

    def __init__(self, language_model_path: str, local_file: bool = True, device: str = 'cpu'):
        """Load the tokenizer and text encoder.

        Args:
            language_model_path: Local directory or Hugging Face Hub id of the
                model. An empty value (``''`` / ``None``) falls back to the
                default Taiyi-CLIP model. The default local name is mapped to
                its Hub id when ``local_file`` is False.
            local_file: Passed to ``from_pretrained`` as ``local_files_only``;
                when True nothing is downloaded.
            device: Torch device string. A falsy value selects ``'cuda'``
                when available, otherwise ``'cpu'``.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model_path = self._resolve_model_path(language_model_path, local_file)
        self.text_tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=local_file)
        self.text_encoder = BertForSequenceClassification.from_pretrained(
            model_path, local_files_only=local_file
        ).eval().to(self.device)

    @classmethod
    def _resolve_model_path(cls, language_model_path: str, local_file: bool) -> str:
        """Return the path/id to load, honoring the caller's ``language_model_path``."""
        if not language_model_path:
            language_model_path = cls._DEFAULT_LOCAL_PATH
        # The bare default name only exists as a local directory; on the Hub it
        # lives under the IDEA-CCNL organisation.
        if not local_file and language_model_path == cls._DEFAULT_LOCAL_PATH:
            return cls._DEFAULT_HUB_PATH
        return language_model_path

    @staticmethod
    def _normalize(features: torch.Tensor) -> ndarray:
        """L2-normalise each row (dim=1), drop size-1 axes, and return a NumPy array."""
        features = features / features.norm(dim=1, keepdim=True)
        return features.squeeze().detach().cpu().numpy()

    def text(self, query_texts: List[str]) -> ndarray:
        """Return L2-normalised embeddings for ``query_texts``.

        Args:
            query_texts: Batch of query strings; padded to a common length.

        Returns:
            One normalised row per query. ``squeeze()`` removes size-1 axes,
            so a single query yields a 1-D vector.
        """
        # NOTE(review): config.max_length is a generic PretrainedConfig field
        # (a generation length, commonly 20), not the encoder's position limit;
        # long queries may be truncated earlier than intended — confirm.
        encoded = self.text_tokenizer(
            query_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.text_encoder.config.max_length,
        ).to(self.device)
        with torch.no_grad():
            # The attention mask keeps padded positions of shorter queries from
            # being attended to when the batch has mixed lengths.
            logits = self.text_encoder(
                input_ids=encoded['input_ids'],
                attention_mask=encoded.get('attention_mask'),
            ).logits
        return self._normalize(logits)


class TaiyiFeatureExtractor(TextFeatureExtractor):
    """Text feature extractor whose default model is "Taiyi-CLIP-Roberta-large-326M-Chinese"."""

    def __init__(self, language_model_path: str="Taiyi-CLIP-Roberta-large-326M-Chinese", local_file: bool = True, device: str = 'cpu'):
        """Build the extractor by delegating all loading to ``TextFeatureExtractor``.

        Args:
            language_model_path (str, optional): Model location forwarded to the
                base class, e.g. "Taiyi-CLIP-Roberta-large-326M-Chinese" or
                "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese".
                Defaults to "Taiyi-CLIP-Roberta-large-326M-Chinese".
            local_file (bool, optional): Forwarded as the base class's
                ``local_file`` flag. Defaults to True.
            device (str, optional): Forwarded as the base class's ``device``.
                Defaults to 'cpu'.
        """
        super().__init__(
            language_model_path=language_model_path,
            local_file=local_file,
            device=device,
        )