Spaces:
No application file
No application file
from typing import List, Tuple, Dict | |
import os | |
import time | |
from tqdm import tqdm | |
import torch | |
import numpy as np | |
from numpy import ndarray | |
from PIL import Image | |
from transformers import BertForSequenceClassification, BertTokenizer, CLIPProcessor, CLIPModel | |
class TextFeatureExtractor(object):
    """Turn text queries into L2-normalised feature vectors.

    A ``BertTokenizer`` converts the queries to token ids and a
    ``BertForSequenceClassification`` model (put in eval mode and moved to
    ``self.device``) produces the features, which are read from the model's
    ``logits`` output.
    """

    def __init__(self, language_model_path: str, local_file: bool = True, device: str = 'cpu'):
        """Load the tokenizer and the text encoder.

        Args:
            language_model_path: Accepted for interface compatibility; it is
                currently overwritten below (see the NOTE in the body).
            local_file: When true, load from the local directory name and pass
                ``local_files_only=True``; otherwise use the
                ``IDEA-CCNL/...`` identifier.
            device: Device string the encoder and inputs are moved to. A falsy
                value (``None`` or ``''``) selects ``'cuda'`` when available,
                else ``'cpu'``.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # NOTE(review): the caller-supplied ``language_model_path`` is replaced
        # unconditionally, so the argument has no effect. This is kept as-is
        # because existing callers may rely on ``local_file`` alone choosing
        # the model source -- confirm before honouring the argument.
        language_model_path = "Taiyi-CLIP-Roberta-large-326M-Chinese" if local_file else "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese"

        self.text_tokenizer = BertTokenizer.from_pretrained(language_model_path, local_files_only=local_file)
        self.text_encoder = BertForSequenceClassification.from_pretrained(language_model_path, local_files_only=local_file).eval().to(self.device)

    def text(self, query_texts: List[str]) -> ndarray:
        """Encode ``query_texts`` and return their normalised features.

        Args:
            query_texts: The strings to encode.

        Returns:
            A NumPy array of row-wise L2-normalised features. Size-1
            dimensions are squeezed away, so a single query yields a 1-D
            vector rather than a one-row matrix.
        """
        # NOTE(review): truncation uses ``config.max_length`` of the encoder;
        # confirm this is the intended token limit for queries.
        token_ids = self.text_tokenizer(
            query_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.text_encoder.config.max_length,
        )['input_ids']
        token_ids = token_ids.to(self.device)

        with torch.no_grad():
            text_features = self.text_encoder(token_ids).logits
            # Scale every row to unit L2 norm.
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            # The original assigned the bound method (`.squeeze` without
            # parentheses), which made the `.detach()` below raise
            # AttributeError on every call.
            text_features = text_features.squeeze()

        return text_features.detach().cpu().numpy()
class TaiyiFeatureExtractor(TextFeatureExtractor):
    """Text feature extractor with Taiyi-CLIP defaults.

    Adds no behaviour of its own: it only supplies default constructor
    arguments and delegates everything to ``TextFeatureExtractor``.
    """

    def __init__(self, language_model_path: str="Taiyi-CLIP-Roberta-large-326M-Chinese", local_file: bool = True, device: str = 'cpu'):
        """Build the extractor by forwarding all arguments to the base class.

        Args:
            language_model_path (str, optional): Model directory name or hub
                identifier, e.g. Taiyi-CLIP-Roberta-large-326M-Chinese or
                IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese. Forwarded
                unchanged to the base constructor. Defaults to
                "Taiyi-CLIP-Roberta-large-326M-Chinese".
            local_file (bool, optional): Forwarded to the base constructor,
                which uses it to choose between local files and the hub
                identifier. Defaults to True.
            device (str, optional): Device string forwarded to the base
                constructor. Defaults to 'cpu'.
        """
        super().__init__(
            language_model_path=language_model_path,
            local_file=local_file,
            device=device,
        )