|
from __future__ import annotations

import argparse
import logging
import math
import queue
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm.autonotebook import trange
from transformers import AutoModel, AutoTokenizer

from mteb import MTEB
|
|
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]
|
|
TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]
|
|
TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]
|
|
TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]
|
|
TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "FEVER",
    "FiQA2018",
    "HotpotQA",
    "MSMARCO",
    "NFCorpus",
    "NQ",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
]
|
|
TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]
|
|
MTEB_TASK_LIST = (
    TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_RETRIEVAL
    + TASK_LIST_STS
)
|
|
CMTEB_TASK_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    "Ocnli",
    "Cmnli",
    "T2Reranking",
    "MmarcoReranking",
    "CMedQAv1",
    "CMedQAv2",
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]
|
|
MTEB_PL = [
    "CBD",
    "PolEmo2.0-IN",
    "PolEmo2.0-OUT",
    "AllegroReviews",
    "PAC",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "SICK-E-PL",
    "PPC",
    "CDSC-E",
    "PSC",
    "8TagsClustering",
    "SICK-R-PL",
    "CDSC-R",
    "STS22",
    "ArguAna-PL",
    "DBPedia-PL",
    "FiQA-PL",
    "HotpotQA-PL",
    "MSMARCO-PL",
    "NFCorpus-PL",
    "NQ-PL",
    "Quora-PL",
    "SCIDOCS-PL",
    "SciFact-PL",
    "TRECCOVID-PL",
]
|
|
MTEB_FR = [
    "AmazonReviewsClassification",
    "MasakhaNEWSClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "OpusparcusPC",
    "PawsX",
    "AlloProfClusteringP2P",
    "AlloProfClusteringS2S",
    "HALClusteringS2S",
    "MasakhaNEWSClusteringP2P",
    "MasakhaNEWSClusteringS2S",
    "MLSUMClusteringP2P",
    "MLSUMClusteringS2S",
    "SyntecReranking",
    "AlloprofReranking",
    "AlloprofRetrieval",
    "BSARDRetrieval",
    "SyntecRetrieval",
    "XPQARetrieval",
    "MintakaRetrieval",
    "SummEvalFr",
    "STSBenchmarkMultilingualSTS",
    "STS22",
    "SICKFr",
]
|
|
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s : %(message)s"
)

logger = logging.getLogger("eval_mteb_qwen.py")
|
|
def get_detailed_instruct(task_description: str) -> str:
    if not task_description:
        return ""

    return "Instruct: {}\nQuery: ".format(task_description)
|
|
|
|
|
def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="Given a web search query, retrieve relevant passages that answer the query",
) -> str:
    if task_type in ["STS"]:
        return "Retrieve semantically similar text"

    if task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type in ["BitextMining"]:
        return "Retrieve parallel sentences"

    if task_type in ["Classification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category",
            "Banking77Classification": "Given a online banking query, find the corresponding intents",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral",
            # C-MTEB
            "TNews": "Classify the fine-grained category of the given news title",
            "IFlyTek": "Given an App description text, find the appropriate fine-grained category",
            "MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative",
            "JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative",
            "OnlineShopping": "Classify the customer review for online shopping into positive or negative",
            "Waimai": "Classify the customer review from a food takeaway platform into positive or negative",
            # MTEB-PL
            "CBD": "Classify the sentiment of polish tweet reviews",
            "PolEmo2.0-IN": "Classify the sentiment of in-domain (medicine and hotels) online reviews",
            "PolEmo2.0-OUT": "Classify the sentiment of out-of-domain (products and school) online reviews",
            "AllegroReviews": "Classify the sentiment of reviews from e-commerce marketplace Allegro",
            "PAC": 'Classify the sentence into one of the two types: "BEZPIECZNE_POSTANOWIENIE_UMOWNE" and "KLAUZULA_ABUZYWNA"',
        }
        if task_name in task_name_to_instruct:
            return task_name_to_instruct[task_name]
|
if task_type in ["Clustering"]: |
|
task_name_to_instruct: Dict[str, str] = { |
|
"ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts", |
|
"ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles", |
|
"BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts", |
|
"BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles", |
|
"MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts", |
|
"MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles", |
|
"RedditClustering": "Identify the topic or theme of Reddit posts based on the titles", |
|
"RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts", |
|
"StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles", |
|
"StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs", |
|
"TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles", |
|
|
|
"CLSClusteringS2S": "Identify the main category of scholar papers based on the titles", |
|
"CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts", |
|
"ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles", |
|
"ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents", |
|
|
|
"AlloProfClusteringP2P": "Identify the main category of Allo Prof document based on the titles and descriptions", |
|
"AlloProfClusteringS2S": "Identify the main category of Allo Prof document based on the titles", |
|
"HALClusteringS2S": "Identify the main category of academic passage based on the titles and contents", |
|
"MasakhaNEWSClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents", |
|
"MasakhaNEWSClusteringS2S": "Identify the topic or theme of the given news articles based on the titles", |
|
"MLSUMClusteringP2P": "Identify the topic or theme of the given articles based on the titles and contents", |
|
"MLSUMClusteringS2S": "Identify the topic or theme of the given articles based on the titles", |
|
|
|
"8TagsClustering": "Identify of headlines from social media posts in Polish into 8 categories: film, history, food, medicine, motorization, work, sport and technology", |
|
} |
|
return task_name_to_instruct[task_name] |
|
|
|
if task_type in ["Reranking", "PairClassification"]: |
|
task_name_to_instruct: Dict[str, str] = { |
|
"AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum", |
|
"MindSmallReranking": "Retrieve relevant news articles based on user browsing history", |
|
"SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers", |
|
"StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum", |
|
"SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum", |
|
"TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet", |
|
"TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet", |
|
|
|
"T2Reranking": "Given a Chinese search query, retrieve web passages that answer the question", |
|
"MmarcoReranking": "Given a Chinese search query, retrieve web passages that answer the question", |
|
"CMedQAv1": "Given a Chinese community medical question, retrieve replies that best answer the question", |
|
"CMedQAv2": "Given a Chinese community medical question, retrieve replies that best answer the question", |
|
"Ocnli": "Retrieve semantically similar text.", |
|
"Cmnli": "Retrieve semantically similar text.", |
|
|
|
"AlloprofReranking": "Given a question, retrieve passages that answer the question", |
|
"OpusparcusPC": "Retrieve semantically similar text", |
|
"PawsX": "Retrieve semantically similar text", |
|
"SyntecReranking": "Given a question, retrieve passages that answer the question", |
|
|
|
"SICK-E-PL": "Retrieve semantically similar text", |
|
"PPC": "Retrieve semantically similar text", |
|
"CDSC-E": "Retrieve semantically similar text", |
|
"PSC": "Retrieve semantically similar text", |
|
} |
|
return task_name_to_instruct[task_name] |
|
|
|
if task_type in ["Retrieval"]: |
|
if task_name.lower().startswith("cqadupstack"): |
|
return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question" |
|
|
|
task_name_to_instruct: Dict[str, str] = { |
|
"ArguAna": "Given a claim, find documents that refute the claim", |
|
"ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim", |
|
"DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia", |
|
"FEVER": "Given a claim, retrieve documents that support or refute the claim", |
|
"FiQA2018": "Given a financial question, retrieve user replies that best answer the question", |
|
"HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question", |
|
"MSMARCO": "Given a web search query, retrieve relevant passages that answer the query", |
|
"NFCorpus": "Given a question, retrieve relevant documents that best answer the question", |
|
"NQ": "Given a question, retrieve Wikipedia passages that answer the question", |
|
"QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question", |
|
"SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper", |
|
"SciFact": "Given a scientific claim, retrieve documents that support or refute the claim", |
|
"Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question", |
|
"TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query", |
|
|
|
"T2Retrieval": "Given a Chinese search query, retrieve web passages that answer the question", |
|
"MMarcoRetrieval": "Given a web search query, retrieve relevant passages that answer the query", |
|
"DuRetrieval": "Given a Chinese search query, retrieve web passages that answer the question", |
|
"CovidRetrieval": "Given a question on COVID-19, retrieve news articles that answer the question", |
|
"CmedqaRetrieval": "Given a Chinese community medical question, retrieve replies that best answer the question", |
|
"EcomRetrieval": "Given a user query from an e-commerce website, retrieve description sentences of relevant products", |
|
"MedicalRetrieval": "Given a medical question, retrieve user replies that best answer the question", |
|
"VideoRetrieval": "Given a video search query, retrieve the titles of relevant videos", |
|
|
|
"AlloprofRetrieval": "Given a question, retrieve passages that answer the question", |
|
"BSARDRetrieval": "Given a question, retrieve passages that answer the question", |
|
"SyntecRetrieval": "Given a question, retrieve passages that answer the question", |
|
"XPQARetrieval": "Given a question, retrieve passages that answer the question", |
|
"MintakaRetrieval": "Given a question, retrieve passages that answer the question", |
|
|
|
"ArguAna-PL": "Given a claim, find documents that refute the claim", |
|
"DBPedia-PL": "Given a query, retrieve relevant entity descriptions from DBPedia", |
|
"FiQA-PL": "Given a financial question, retrieve user replies that best answer the question", |
|
"HotpotQA-PL": "Given a multi-hop question, retrieve documents that can help answer the question", |
|
"MSMARCO-PL": "Given a web search query, retrieve relevant passages that answer the query", |
|
"NFCorpus-PL": "Given a question, retrieve relevant documents that best answer the question", |
|
"NQ-PL": "Given a question, retrieve Wikipedia passages that answer the question", |
|
"Quora-PL": "Given a question, retrieve questions that are semantically equivalent to the given question", |
|
"SCIDOCS-PL": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper", |
|
"SciFact-PL": "Given a scientific claim, retrieve documents that support or refute the claim", |
|
"TRECCOVID-PL": "Given a query on COVID-19, retrieve documents that answer the query", |
|
} |
|
|
|
|
|
task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()}) |
|
|
|
task_name_to_instruct["trec-covid"] = task_name_to_instruct["TRECCOVID"] |
|
task_name_to_instruct["climate-fever"] = task_name_to_instruct["ClimateFEVER"] |
|
task_name_to_instruct["dbpedia-entity"] = task_name_to_instruct["DBPedia"] |
|
task_name_to_instruct["webis-touche2020"] = task_name_to_instruct["Touche2020"] |
|
task_name_to_instruct["fiqa"] = task_name_to_instruct["FiQA2018"] |
|
task_name_to_instruct["quora"] = task_name_to_instruct["QuoraRetrieval"] |
|
|
|
|
|
task_name_to_instruct["miracl"] = ( |
|
"Given a question, retrieve Wikipedia passages that answer the question" |
|
) |
|
|
|
return task_name_to_instruct[task_name] |
|
logging.warning( |
|
f"No instruction config for task {task_name} with type {task_type}, use default instruction." |
|
) |
|
return default_instruct |
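

# Illustrative lookups (assumed task names; not executed by the script):
#   get_task_def_by_task_name_and_type("NQ", "Retrieval")
#       -> "Given a question, retrieve Wikipedia passages that answer the question"
#   get_task_def_by_task_name_and_type("trec-covid", "Retrieval")  # BEIR-style alias
#       -> "Given a query on COVID-19, retrieve documents that answer the query"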
|
|
|
|
|
class Encoder(torch.nn.Module):
    def __init__(self, name_or_path: str, pooling: str):
        super().__init__()
        self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
        self.model = self.model.half()
        self.model.eval()
        self.pooling = pooling

    def forward(self, **features) -> torch.Tensor:
        output = self.model(**features, output_hidden_states=True, return_dict=True)
        hidden_state = output.hidden_states[-1]
        embeddings = self.pooler(hidden_state, **features)
        return embeddings
|
    def pooler(
        self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        if attention_mask.ndim == 2:
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
        elif attention_mask.ndim == 3:
            mask_expanded = attention_mask
        else:
            raise RuntimeError(f"Unexpected {attention_mask.ndim=}")

        # Zero out the hidden states of padded positions.
        hidden_state = hidden_state * mask_expanded

        if self.pooling == "first":
            pooled_output = hidden_state[:, 0]
        elif self.pooling == "last":
            # With left padding the last position is valid for every sequence;
            # otherwise gather the hidden state at each sequence's final token.
            left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
            if left_padding:
                return hidden_state[:, -1]
            else:
                sequence_lengths = attention_mask.sum(dim=1) - 1
                batch_size = hidden_state.shape[0]
                return hidden_state[
                    torch.arange(batch_size, device=hidden_state.device), sequence_lengths
                ]
        elif self.pooling == "mean":
            lengths = mask_expanded.sum(1).clamp(min=1e-9)
            pooled_output = hidden_state.sum(dim=1) / lengths
        elif self.pooling == "weightedmean":
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
            # Position weights 1..seq_len give later tokens more influence.
            weights = (
                torch.arange(start=1, end=hidden_state.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(hidden_state.size())
                .float()
                .to(hidden_state.device)
            )
            assert weights.shape == hidden_state.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
            sum_mask = input_mask_expanded.sum(1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            pooled_output = sum_embeddings / sum_mask
        else:
            raise ValueError(f"Unknown pooling mode: {self.pooling}")
        return pooled_output
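
    # A minimal sketch (not called anywhere) of how "last" pooling behaves under
    # right padding: each sequence contributes the hidden state of its final
    # non-padded token. Shapes and values below are illustrative only.
    #
    #   hidden = torch.arange(12, dtype=torch.float).reshape(2, 3, 2)  # (batch=2, seq=3, dim=2)
    #   mask = torch.tensor([[1, 1, 0], [1, 1, 1]])                    # first sequence has one pad
    #   lengths = mask.sum(dim=1) - 1                                  # tensor([1, 2])
    #   hidden[torch.arange(2), lengths]                               # [[2., 3.], [10., 11.]]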
|
|
|
|
|
class Wrapper:
    def __init__(
        self,
        tokenizer,
        encoder: Encoder,
        batch_size: int,
        max_seq_len: int = 512,
        normalize_embeddings: bool = False,
        default_query: bool = False,
        force_default: bool = False,
        sep: str = " ",
        mp_tensor_to_cuda: bool = False,
        instruction: Optional[str] = None,
    ):
        self.tokenizer = tokenizer
        self.model = encoder
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.pool: Optional[dict] = None
        self.normalize_embeddings = normalize_embeddings
        self.mp_tensor_to_cuda = mp_tensor_to_cuda
        self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        self.instruction = instruction
        self.default_query = default_query
        self.sep = sep
        self.force_default = force_default
        if self.tokenizer.padding_side != "right":
            logger.warning(
                f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right"
            )
            self.tokenizer.padding_side = "right"
        if self.tokenizer.pad_token is None:
            logger.warning('Set tokenizer.pad_token to "<|endoftext|>"')
            self.tokenizer.pad_token = "<|endoftext|>"
|
    def start(self, target_devices: Optional[List[str]] = None):
        """
        Starts multiple independent processes for encoding.
        This method is recommended if you want to encode on multiple GPUs. It is advised
        to start only one process per GPU. This method works together with encode_multi_process.

        :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used
        :return: None. The worker processes, input queue, and output queue are stored in self.pool.
        """
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ["cuda:{}".format(i) for i in range(torch.cuda.device_count())]
            else:
                logger.info("CUDA is not available. Starting 4 CPU workers.")
                target_devices = ["cpu"] * 4

        logger.info(
            "Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices)))
        )
        logger.info(f"Multi-process instruction: {self.instruction}")
        ctx = mp.get_context("spawn")
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for cuda_id in target_devices:
            p = ctx.Process(
                target=self._encode_multi_process_worker,
                args=(cuda_id, self, input_queue, output_queue),
                daemon=True,
            )
            p.start()
            processes.append(p)

        self.pool = {"input": input_queue, "output": output_queue, "processes": processes}
|
    def stop(self):
        """
        Stops all processes started with start().
        """
        for p in self.pool["processes"]:
            p.terminate()

        for p in self.pool["processes"]:
            p.join()
            p.close()

        self.pool["input"].close()
        self.pool["output"].close()
|
    @staticmethod
    def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
        """
        Internal worker process that encodes sentences in the multi-process setup.
        """
        while True:
            try:
                chunk_id, sentences, kwargs = input_queue.get()
                kwargs.update(device=target_device, show_progress_bar=False, convert_to_numpy=True)
                embeddings = model._encode(sentences, **kwargs)
                results_queue.put([chunk_id, embeddings])
            except queue.Empty:
                break
|
    def encode_multi_process(self, sentences: List[str], **kwargs):
        """
        Runs encode() on multiple GPUs. The sentences are chunked into smaller packages
        and sent to the individual processes, which encode them on the different GPUs.
        This method is only suitable for encoding large sets of sentences.

        :param sentences: List of sentences
        :param kwargs: other keyword arguments for model.encode() such as batch_size
        :return: Numpy matrix with all embeddings
        """
        # Split the input evenly across workers, capping the chunk size at 3200.
        part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
        chunk_size = part_size if part_size < 3200 else 3200

        logger.debug(
            f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}"
        )

        input_queue = self.pool["input"]
        last_chunk_id = 0
        chunk = []

        for sentence in sentences:
            chunk.append(sentence)
            if len(chunk) >= chunk_size:
                input_queue.put([last_chunk_id, chunk, kwargs])
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, chunk, kwargs])
            last_chunk_id += 1

        output_queue = self.pool["output"]
        results_list = sorted(
            [output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0]
        )
        embeddings = np.concatenate([result[1] for result in results_list])
        return embeddings
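
    # Illustrative multi-GPU flow (assumed two GPUs; not executed by the script):
    #   wrapper.start(["cuda:0", "cuda:1"])
    #   embeddings = wrapper.encode(corpus_sentences)  # dispatched via encode_multi_process
    #   wrapper.stop()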
|
|
|
    @staticmethod
    def batch_to_device(batch, target_device):
        """
        Send a PyTorch batch to a device (CPU/GPU).
        """
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(target_device)
        return batch
|
    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper to get the length of the input text. Text can be either a list of ints
        (a single tokenized text) or a list of lists of ints (several texts); dicts of
        features and objects without __len__ are also handled.
        """
        if isinstance(text, dict):
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):
            return 1
        elif len(text) == 0 or isinstance(text[0], int):
            return len(text)
        else:
            return sum([len(t) for t in text])
|
    def _tokenize(self, sentences: List[str], is_query: bool):
        # Truncate to max_seq_len - 1 to leave room for the appended EOS token.
        batch_dict = self.tokenizer(
            sentences,
            max_length=self.max_seq_len - 1,
            return_attention_mask=False,
            padding=False,
            truncation=True,
        )
        batch_dict["input_ids"] = [
            input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]
        ]
        batch_dict = self.tokenizer.pad(
            batch_dict, padding=True, return_attention_mask=True, return_tensors="pt"
        )
        batch_dict["is_causal"] = False
        return batch_dict
|
    def _encode(
        self,
        sentences: List[str],
        is_query: bool,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
        show_progress_bar: bool = True,
        **kwargs,
    ):
        """
        Computes sentence embeddings.

        :param sentences: the sentences to embed
        :param is_query: whether the sentences are queries (instruction prefixes are added in encode())
        :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation
        :param show_progress_bar: Output a progress bar when encoding sentences

        :return:
            By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
        """
        self.model.eval()

        if convert_to_tensor:
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
            sentences = [sentences]
            input_was_string = True

        if device is None:
            device = self._target_device

        self.model.to(device)

        all_embeddings = []
        # Sort by length (descending) so each batch contains similarly sized inputs.
        length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(
            0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar
        ):
            sentences_batch = sentences_sorted[start_index : start_index + self.batch_size]
            features = self._tokenize(sentences_batch, is_query)
            features = self.batch_to_device(features, device)

            with torch.no_grad():
                embeddings = self.model(**features)

                if self.normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                if convert_to_numpy:
                    embeddings = embeddings.cpu()

            all_embeddings.extend(embeddings)

        # Restore the original input order.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings
|
    def encode(
        self,
        sentences: List[str],
        is_query: Optional[bool] = None,
        convert_to_tensor: bool = False,
        **kwargs,
    ):
        is_query = self.default_query if is_query is None else is_query
        if is_query and self.instruction:
            sentences = [self.instruction + sent for sent in sentences]
        kwargs.update(is_query=is_query)
        if self.pool is not None:
            kwargs.update(show_progress_bar=False)
            embeddings = self.encode_multi_process(sentences, **kwargs)
            if convert_to_tensor:
                embeddings = torch.from_numpy(embeddings)
                if self.mp_tensor_to_cuda and torch.cuda.is_available():
                    embeddings = embeddings.to(torch.device("cuda"))
            return embeddings

        return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
|
    def encode_queries(self, queries: List[str], **kwargs):
        is_query = self.default_query if self.force_default else True
        return self.encode(queries, is_query=is_query, **kwargs)
|
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        # The corpus may be a dict of parallel lists, a list of dicts, or a list of strings.
        if isinstance(corpus, dict):
            sentences = [
                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        elif isinstance(corpus[0], dict):
            sentences = [
                (doc["title"] + self.sep + doc["text"]).strip()
                if "title" in doc
                else doc["text"].strip()
                for doc in corpus
            ]
        else:
            sentences = corpus
        is_query = self.default_query if self.force_default else False
        return self.encode(sentences, is_query=is_query, **kwargs)
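

# A minimal usage sketch (assumed model path; not executed by the script):
#
#   tokenizer = AutoTokenizer.from_pretrained("<model_name_or_path>", trust_remote_code=True)
#   encoder = Encoder("<model_name_or_path>", pooling="last")
#   wrapper = Wrapper(tokenizer, encoder, batch_size=8, normalize_embeddings=True)
#   wrapper.instruction = get_detailed_instruct(
#       "Given a web search query, retrieve relevant passages that answer the query"
#   )
#   embeddings = wrapper.encode(["what is a large language model?"], is_query=True)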
|
|
|
|
|
def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    encoder = Encoder(args.model, args.pooling)
    default_query = args.default_type == "query"
    model = Wrapper(
        tokenizer,
        encoder,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
        normalize_embeddings=args.norm,
        default_query=default_query,
    )
    sym_retrievals = ["QuoraRetrieval", "ArguAna", "CQADupstack"]
    if args.task == "mteb":
        task_names = MTEB_TASK_LIST
        lang = ["en"]
    elif args.task == "cmteb":
        task_names = CMTEB_TASK_LIST
        lang = ["zh", "zh-CN"]
    elif args.task == "mteb-fr":
        task_names = MTEB_FR
        lang = ["fr"]
    elif args.task == "mteb-pl":
        task_names = MTEB_PL
        lang = ["pl"]
    else:
        task_names = [args.task]
        lang = ["en", "zh", "zh-CN", "pl", "fr"]
    for task in task_names:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        if task == "MSMARCO":
            eval_splits = ["dev"]
        elif task in CMTEB_TASK_LIST:
            eval_splits = task_cls.metadata_dict["eval_splits"]
        else:
            eval_splits = ["test"]
        # Symmetric retrieval tasks encode queries and documents the same way.
        sym = any(task.startswith(name) for name in sym_retrievals)
        force_default_ori = model.force_default
        if sym:
            logger.info(
                f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}."
            )
            model.force_default = True
        evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)

        if sym:
            logger.info("Switch back to the original mode.")
            model.force_default = force_default_ori
        print("\n")
|
|
|
|
|
if __name__ == "__main__":
    _PARSER = argparse.ArgumentParser()
    _PARSER.add_argument("-m", "--model", type=str, default=None)
    _PARSER.add_argument("--pooling", type=str, default="last")
    _PARSER.add_argument("--output_dir", type=str, default=None)
    _PARSER.add_argument("--default_type", type=str, default="query")
    _PARSER.add_argument("--max_seq_len", type=int, default=512)
    _PARSER.add_argument("-b", "--batch_size", type=int, default=32)
    _PARSER.add_argument("-t", "--task", type=str, default=None)
    _PARSER.add_argument("--norm", action="store_true")
    _ARGS = _PARSER.parse_args()
    main(_ARGS)
|
|