zyznull vatolinalex commited on
Commit
ca1f7d7
·
verified ·
1 Parent(s): 581c2b2

Fixed some minor bugs in eval_mteb.py (#26)

Browse files

- Fixed some minor bugs in eval_mteb.py (ca35073e5b5f91721d7a342a9bd29bc4dba4c9a7)


Co-authored-by: Vatolin Alexey <[email protected]>

Files changed (1) hide show
  1. scripts/eval_mteb.py +321 -216
scripts/eval_mteb.py CHANGED
@@ -1,21 +1,18 @@
 
 
1
  import argparse
2
- from collections import defaultdict
3
- import json
4
  import logging
5
  import math
6
- import os
7
- import sys
8
  import queue
9
  from typing import Dict, List, Optional, Union
10
 
11
- from tqdm.autonotebook import trange
12
- import datasets
13
  import numpy as np
14
  import torch
15
  import torch.multiprocessing as mp
 
16
  from transformers import AutoModel, AutoTokenizer
17
- from transformers import AutoModelForCausalLM
18
- from mteb import MTEB, CrosslingualTask, MultilingualTask
19
 
20
  TASK_LIST_CLASSIFICATION = [
21
  "AmazonCounterfactualClassification",
@@ -112,99 +109,179 @@ MTEB_TASK_LIST = (
112
  )
113
 
114
 
115
- CMTEB_TASK_LIST = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai','AmazonReviewsClassification', 'MassiveIntentClassification', 'MassiveScenarioClassification', 'MultilingualSentiment',
116
- 'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
117
- 'Ocnli', 'Cmnli',
118
- 'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
119
- 'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
120
- 'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  MTEB_PL = [
123
- "CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
124
- "SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
125
- "ArguAna-PL","DBPedia-PL","FiQA-PL","HotpotQA-PL","MSMARCO-PL","NFCorpus-PL","NQ-PL","Quora-PL","SCIDOCS-PL","SciFact-PL","TRECCOVID-PL"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  ]
127
 
128
  MTEB_FR = [
129
- "AmazonReviewsClassification","MasakhaNEWSClassification","MassiveIntentClassification",
130
- "MassiveScenarioClassification","MTOPDomainClassification","MTOPIntentClassification","OpusparcusPC","PawsX",
131
- "AlloProfClusteringP2P","AlloProfClusteringS2S","HALClusteringS2S","MasakhaNEWSClusteringP2P","MasakhaNEWSClusteringS2S","MLSUMClusteringP2P","MLSUMClusteringS2S",
132
- "SyntecReranking","AlloprofReranking","AlloprofRetrieval","BSARDRetrieval","SyntecRetrieval","XPQARetrieval","MintakaRetrieval",
133
- "SummEvalFr","STSBenchmarkMultilingualSTS","STS22","SICKFr"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  ]
135
 
136
  logging.basicConfig(
137
- level=logging.INFO,
138
- format='%(asctime)s - %(levelname)s - %(name)s : %(message)s'
139
  )
140
 
141
- logger = logging.getLogger('eval_mteb_qwen.py')
 
142
 
143
  def get_detailed_instruct(task_description: str) -> str:
144
  if not task_description:
145
- return ''
146
 
147
- return 'Instruct: {}\nQuery: '.format(task_description)
148
 
149
- def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_instruct='Given a web search query, retrieve relevant passages that answer the query') -> str:
150
- if task_type in ['STS']:
 
 
 
 
 
151
  return "Retrieve semantically similar text"
152
 
153
- if task_type in ['Summarization']:
154
  return "Given a news summary, retrieve other semantically similar summaries"
155
 
156
- if task_type in ['BitextMining']:
157
  return "Retrieve parallel sentences"
158
 
159
- if task_type in ['Classification']:
160
  task_name_to_instruct: Dict[str, str] = {
161
- 'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual',
162
- 'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment',
163
- 'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category',
164
- 'Banking77Classification': 'Given a online banking query, find the corresponding intents',
165
- 'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise',
166
- 'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
167
- 'MassiveIntentClassification': 'Given a user utterance as query, find the user intents',
168
- 'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios',
169
- 'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation',
170
- 'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation',
171
- 'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic',
172
- 'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
173
  # C-MTEB eval instructions
174
- 'TNews': 'Classify the fine-grained category of the given news title',
175
- 'IFlyTek': 'Given an App description text, find the appropriate fine-grained category',
176
- 'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative',
177
- 'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative',
178
- 'OnlineShopping': 'Classify the customer review for online shopping into positive or negative',
179
- 'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative',
180
  # MTEB-pl eval instructions
181
- "CBD":"Classify the sentiment of polish tweet reviews",
182
  "PolEmo2.0-IN": "Classify the sentiment of in-domain (medicine and hotels) online reviews",
183
- "PolEmo2.0-OUT":"Classify the sentiment of out-of-domain (products and school) online reviews",
184
  "AllegroReviews": "Classify the sentiment of reviews from e-commerce marketplace Allegro",
185
- "PAC": "Classify the sentence into one of the two types: \"BEZPIECZNE_POSTANOWIENIE_UMOWNE\" and \"KLAUZULA_ABUZYWNA\"",
186
-
187
  }
188
  return task_name_to_instruct[task_name]
189
 
190
- if task_type in ['Clustering']:
191
  task_name_to_instruct: Dict[str, str] = {
192
- 'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
193
- 'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles',
194
- 'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts',
195
- 'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles',
196
- 'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts',
197
- 'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles',
198
- 'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles',
199
- 'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts',
200
- 'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles',
201
- 'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs',
202
- 'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles',
203
  # C-MTEB eval instructions
204
- 'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles',
205
- 'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts',
206
- 'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles',
207
- 'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents',
208
  # MTEB-fr eval instructions
209
  "AlloProfClusteringP2P": "Identify the main category of Allo Prof document based on the titles and descriptions",
210
  "AlloProfClusteringS2S": "Identify the main category of Allo Prof document based on the titles",
@@ -212,32 +289,32 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
212
  "MasakhaNEWSClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
213
  "MasakhaNEWSClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
214
  "MLSUMClusteringP2P": "Identify the topic or theme of the given articles based on the titles and contents",
215
- "MLSUMClusteringS2S": "Identify the topic or theme of the given articles based on the titles",
216
  # MTEB-pl eval instructions
217
  "8TagsClustering": "Identify of headlines from social media posts in Polish into 8 categories: film, history, food, medicine, motorization, work, sport and technology",
218
  }
219
  return task_name_to_instruct[task_name]
220
 
221
- if task_type in ['Reranking', 'PairClassification']:
222
  task_name_to_instruct: Dict[str, str] = {
223
- 'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum',
224
- 'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history',
225
- 'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers',
226
- 'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum',
227
- 'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum',
228
- 'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet',
229
- 'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet',
230
  # C-MTEB eval instructions
231
- 'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question',
232
- 'MmarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question',
233
- 'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question',
234
- 'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question',
235
- 'Ocnli': 'Retrieve semantically similar text.',
236
- 'Cmnli': 'Retrieve semantically similar text.',
237
  # MTEB-fr eval instructions
238
  "AlloprofReranking": "Given a question, retrieve passages that answer the question",
239
- "OpusparcusPC":"Retrieve semantically similar text",
240
- "PawsX":"Retrieve semantically similar text",
241
  "SyntecReranking": "Given a question, retrieve passages that answer the question",
242
  # MTEB-pl eval instructions
243
  "SICK-E-PL": "Retrieve semantically similar text",
@@ -247,41 +324,41 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
247
  }
248
  return task_name_to_instruct[task_name]
249
 
250
- if task_type in ['Retrieval']:
251
- if task_name.lower().startswith('cqadupstack'):
252
- return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question'
253
 
254
  task_name_to_instruct: Dict[str, str] = {
255
- 'ArguAna': 'Given a claim, find documents that refute the claim',
256
- 'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
257
- 'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
258
- 'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
259
- 'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
260
- 'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
261
- 'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query',
262
- 'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
263
- 'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
264
- 'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
265
- 'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
266
- 'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
267
- 'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
268
- 'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
269
  # C-MTEB eval instructions
270
- 'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
271
- 'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query',
272
- 'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
273
- 'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question',
274
- 'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question',
275
- 'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products',
276
- 'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question',
277
- 'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos',
278
  # MTEB-fr eval instructions
279
  "AlloprofRetrieval": "Given a question, retrieve passages that answer the question",
280
  "BSARDRetrieval": "Given a question, retrieve passages that answer the question",
281
  "SyntecRetrieval": "Given a question, retrieve passages that answer the question",
282
  "XPQARetrieval": "Given a question, retrieve passages that answer the question",
283
  "MintakaRetrieval": "Given a question, retrieve passages that answer the question",
284
- # MTEB-pl eval instructions
285
  "ArguAna-PL": "Given a claim, find documents that refute the claim",
286
  "DBPedia-PL": "Given a query, retrieve relevant entity descriptions from DBPedia",
287
  "FiQA-PL": "Given a financial question, retrieve user replies that best answer the question",
@@ -292,45 +369,47 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
292
  "Quora-PL": "Given a question, retrieve questions that are semantically equivalent to the given question",
293
  "SCIDOCS-PL": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
294
  "SciFact-PL": "Given a scientific claim, retrieve documents that support or refute the claim",
295
- "TRECCOVID-PL": "Given a query on COVID-19, retrieve documents that answer the query"
296
  }
297
 
298
  # add lower case keys to match some beir names
299
  task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
300
  # other cases where lower case match still doesn't work
301
- task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
302
- task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
303
- task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
304
- task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
305
- task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
306
- task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']
307
 
308
  # for miracl evaluation
309
- task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question'
 
 
310
 
311
  return task_name_to_instruct[task_name]
312
- logging.warning(f"No instruction config for task {task_name} with type {task_type}, use default instruction.")
313
- return default_instruct
 
 
 
314
 
315
  class Encoder(torch.nn.Module):
316
- def __init__(self, name_or_path:str, pooling: str):
317
  super().__init__()
318
  self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
319
  self.model = self.model.half()
320
- self.model.eval()
321
  self.pooling = pooling
322
 
323
  def forward(self, **features) -> torch.Tensor:
324
  output = self.model(**features, output_hidden_states=True, return_dict=True)
325
- hidden_state = output.hidden_states[-1]
326
  embeddings = self.pooler(hidden_state, **features)
327
  return embeddings
328
 
329
  def pooler(
330
- self,
331
- hidden_state: torch.Tensor,
332
- attention_mask: torch.Tensor,
333
- **kwargs
334
  ) -> torch.Tensor:
335
  if attention_mask.ndim == 2:
336
  mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
@@ -341,32 +420,35 @@ class Encoder(torch.nn.Module):
341
 
342
  hidden_state = hidden_state * mask_expanded
343
 
344
- if self.pooling == 'first':
345
  pooled_output = hidden_state[:, 0]
346
 
347
- elif self.pooling == 'last':
348
- left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
349
  if left_padding:
350
  return hidden_state[:, -1]
351
  else:
352
  sequence_lengths = attention_mask.sum(dim=1) - 1
353
  batch_size = hidden_state.shape[0]
354
- return hidden_state[torch.arange(batch_size, device=hidden_state.device), sequence_lengths]
355
- elif self.pooling == 'mean':
 
 
356
  # TODO: weight
357
  lengths = mask_expanded.sum(1).clamp(min=1e-9)
358
  pooled_output = hidden_state.sum(dim=1) / lengths
359
 
360
- elif self.pooling == 'weightedmean':
361
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
362
  # hidden_state shape: bs, seq, hidden_dim
363
  weights = (
364
- torch.arange(start=1, end=hidden_state.shape[1] + 1)
365
- .unsqueeze(0)
366
- .unsqueeze(-1)
367
- .expand(hidden_state.size())
368
- .float().to(hidden_state.device)
369
- )
 
370
  assert weights.shape == hidden_state.shape == input_mask_expanded.shape
371
  input_mask_expanded = input_mask_expanded * weights
372
 
@@ -392,28 +474,29 @@ class Wrapper:
392
  force_default: bool = False,
393
  sep: str = " ",
394
  mp_tensor_to_cuda: bool = False,
395
- instruction: str = None,
396
- attn_type: str = None
397
  ):
398
  self.tokenizer = tokenizer
399
  self.model = encoder
400
  self.batch_size = batch_size
401
  self.max_seq_len = max_seq_len
402
- self.pool: dict = None
403
  self.normalize_embeddings = normalize_embeddings
404
  self.mp_tensor_to_cuda = mp_tensor_to_cuda
405
  self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406
  self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
407
  self.instruction = instruction
408
- self.default_query = default_query
409
  self.sep = sep
410
  self.force_default = force_default
411
- if self.tokenizer.padding_side != 'right':
412
- logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
413
- self.tokenizer.padding_side = 'right'
 
 
414
  if self.tokenizer.pad_token is None:
415
  logger.warning(f"Set tokenizer.pad_token as eos_token {self.tokenizer.eos_token}")
416
- self.tokenizer.pad_token='<|endoftext|>'
417
 
418
  def start(self, target_devices: Optional[List[str]] = None):
419
  """
@@ -426,14 +509,16 @@ class Wrapper:
426
  """
427
  if target_devices is None:
428
  if torch.cuda.is_available():
429
- target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
430
  else:
431
  logger.info("CUDA is not available. Start 4 CPU worker")
432
- target_devices = ['cpu']*4
433
 
434
- logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
435
- print('multi instruction', self.instruction)
436
- ctx = mp.get_context('spawn')
 
 
437
  input_queue = ctx.Queue()
438
  output_queue = ctx.Queue()
439
  processes = []
@@ -442,26 +527,26 @@ class Wrapper:
442
  p = ctx.Process(
443
  target=self._encode_multi_process_worker,
444
  args=(cuda_id, self, input_queue, output_queue),
445
- daemon=True
446
  )
447
  p.start()
448
  processes.append(p)
449
 
450
- self.pool = {'input': input_queue, 'output': output_queue, 'processes': processes}
451
 
452
  def stop(self):
453
  """
454
  Stops all processes started with start_multi_process_pool
455
  """
456
- for p in self.pool['processes']:
457
  p.terminate()
458
 
459
- for p in self.pool['processes']:
460
  p.join()
461
  p.close()
462
 
463
- self.pool['input'].close()
464
- self.pool['output'].close()
465
 
466
  @staticmethod
467
  def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
@@ -477,11 +562,7 @@ class Wrapper:
477
  except queue.Empty:
478
  break
479
 
480
- def encode_multi_process(
481
- self,
482
- sentences: List[str],
483
- **kwargs
484
- ):
485
  """
486
  This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
487
  and sent to individual processes, which encode these on the different GPUs. This method is only suitable
@@ -496,9 +577,11 @@ class Wrapper:
496
  part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
497
  chunk_size = part_size if part_size < 3200 else 3200 # for retrieval chunk 50000
498
 
499
- logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")
 
 
500
 
501
- input_queue = self.pool['input']
502
  last_chunk_id = 0
503
  chunk = []
504
 
@@ -513,8 +596,10 @@ class Wrapper:
513
  input_queue.put([last_chunk_id, chunk, kwargs])
514
  last_chunk_id += 1
515
 
516
- output_queue = self.pool['output']
517
- results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
 
 
518
  embeddings = np.concatenate([result[1] for result in results_list])
519
  return embeddings
520
 
@@ -535,33 +620,41 @@ class Wrapper:
535
  (representing several text inputs to the model).
536
  """
537
 
538
- if isinstance(text, dict): #{key: value} case
539
  return len(next(iter(text.values())))
540
- elif not hasattr(text, '__len__'): #Object has no len() method
541
  return 1
542
- elif len(text) == 0 or isinstance(text[0], int): #Empty string or list of ints
543
  return len(text)
544
  else:
545
- return sum([len(t) for t in text]) #Sum of length of individual strings
546
 
547
  def _tokenize(self, sentences: List[str], is_query: bool):
548
-
549
- batch_dict = self.tokenizer(sentences, max_length=self.max_seq_len - 1, return_attention_mask=False, padding=False, truncation=True)
550
- batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
551
- batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
552
- batch_dict['is_causal'] = False
 
 
 
 
 
 
 
 
 
553
  return batch_dict
554
 
555
-
556
  def _encode(
557
  self,
558
  sentences: List[str],
559
  is_query: bool,
560
  convert_to_numpy: bool = True,
561
  convert_to_tensor: bool = False,
562
- device: str = None,
563
  show_progress_bar: bool = True,
564
- **kwargs
565
  ):
566
  """
567
  Computes sentence embeddings
@@ -584,7 +677,9 @@ class Wrapper:
584
  convert_to_numpy = False
585
 
586
  input_was_string = False
587
- if isinstance(sentences, str) or not hasattr(sentences, '__len__'): #Cast an individual sentence to a list with length 1
 
 
588
  sentences = [sentences]
589
  input_was_string = True
590
 
@@ -597,8 +692,10 @@ class Wrapper:
597
  length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
598
  sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
599
 
600
- for start_index in trange(0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar):
601
- sentences_batch = sentences_sorted[start_index:start_index + self.batch_size]
 
 
602
  features = self._tokenize(sentences_batch, is_query)
603
  features = self.batch_to_device(features, device)
604
 
@@ -619,7 +716,7 @@ class Wrapper:
619
  if convert_to_tensor:
620
  all_embeddings = torch.stack(all_embeddings)
621
  elif convert_to_numpy:
622
- #all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
623
  all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
624
  if input_was_string:
625
  all_embeddings = all_embeddings[0]
@@ -631,11 +728,11 @@ class Wrapper:
631
  sentences: List[str],
632
  is_query: Optional[bool] = None,
633
  convert_to_tensor: bool = False,
634
- **kwargs
635
  ):
636
  is_query = self.default_query if is_query is None else is_query
637
  if is_query and self.instruction:
638
- sentences = [self.instruction + sent for sent in sentences]
639
  kwargs.update(is_query=is_query)
640
  if self.pool is not None:
641
  kwargs.update(show_progress_bar=False)
@@ -643,7 +740,7 @@ class Wrapper:
643
  if convert_to_tensor:
644
  embeddings = torch.from_numpy(embeddings)
645
  if self.mp_tensor_to_cuda and torch.cuda.is_available():
646
- embeddings = embeddings.to(torch.device('cuda')) # default 0-th gpu
647
  return embeddings
648
 
649
  return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
@@ -663,7 +760,9 @@ class Wrapper:
663
  ]
664
  elif isinstance(corpus[0], dict):
665
  sentences = [
666
- (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
 
 
667
  for doc in corpus
668
  ]
669
  else:
@@ -671,43 +770,46 @@ class Wrapper:
671
  is_query = self.default_query if self.force_default else False
672
  return self.encode(sentences, is_query=is_query, **kwargs)
673
 
 
674
  def main(args):
675
  tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
676
  encoder = Encoder(args.model, args.pooling)
677
- default_query = args.default_type == 'query'
678
  model = Wrapper(
679
- tokenizer, encoder,
 
680
  batch_size=args.batch_size,
681
  max_seq_len=args.max_seq_len,
682
  normalize_embeddings=args.norm,
683
- default_query=default_query
684
  )
685
- sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
686
- if args.task == 'mteb':
687
  task_names = MTEB_TASK_LIST
688
- lang = ['en']
689
- elif args.task == 'cmteb':
690
  task_names = CMTEB_TASK_LIST
691
- lang = ['zh','zh-CN']
692
- elif args.task == 'mteb-fr':
693
- tas_names = MTEB_FR
694
- lang = ['fr']
695
- elif args.task == 'mteb-pl':
696
- lang = ['pl']
 
697
  else:
698
  task_names = [args.task]
699
- lang = ['en','zh','zh-CN','pl','fr']
700
  for task in task_names:
701
  evaluation = MTEB(tasks=[task], task_langs=lang)
702
  task_cls = evaluation.tasks[0]
703
- task_name: str = task_cls.description['name']
704
- task_type: str = task_cls.description['type']
705
  instruction = get_task_def_by_task_name_and_type(task_name, task_type)
706
  model.instruction = get_detailed_instruct(instruction)
707
- if task == 'MSMARCO':
708
  eval_splits = ["dev"]
709
  elif task in CMTEB_TASK_LIST:
710
- eval_splits = task_cls.description['eval_splits']
711
  else:
712
  eval_splits = ["test"]
713
  sym = False
@@ -718,28 +820,31 @@ def main(args):
718
  else:
719
  sym = False
720
  if sym:
721
- logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
 
 
722
  model.force_default = True
723
  evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
724
 
725
  if sym:
726
  logger.info(f"Switch back.")
727
  model.force_default = force_default_ori
728
- print('\n')
729
 
730
 
731
  if __name__ == "__main__":
732
  _PARSER = argparse.ArgumentParser()
733
- _PARSER.add_argument(
734
- "-m", "--model", type=str, default=None
735
- )
736
- _PARSER.add_argument("--pooling", type=str, default='last')
737
  _PARSER.add_argument("--output_dir", type=str, default=None)
738
- _PARSER.add_argument("--default_type", type=str, default='query')
739
  _PARSER.add_argument("--max_seq_len", type=int, default=512)
740
  _PARSER.add_argument("-b", "--batch_size", type=int, default=32)
741
  _PARSER.add_argument(
742
- "-t", "--task", type=str, default=None # None for running default tasks
 
 
 
743
  )
744
  _PARSER.add_argument("--norm", action="store_true")
745
  _ARGS = _PARSER.parse_args()
 
1
+ from __future__ import annotations
2
+
3
  import argparse
 
 
4
  import logging
5
  import math
 
 
6
  import queue
7
  from typing import Dict, List, Optional, Union
8
 
 
 
9
  import numpy as np
10
  import torch
11
  import torch.multiprocessing as mp
12
+ from tqdm.autonotebook import trange
13
  from transformers import AutoModel, AutoTokenizer
14
+
15
+ from mteb import MTEB
16
 
17
  TASK_LIST_CLASSIFICATION = [
18
  "AmazonCounterfactualClassification",
 
109
  )
110
 
111
 
112
+ CMTEB_TASK_LIST = [
113
+ "TNews",
114
+ "IFlyTek",
115
+ "MultilingualSentiment",
116
+ "JDReview",
117
+ "OnlineShopping",
118
+ "Waimai",
119
+ "AmazonReviewsClassification",
120
+ "MassiveIntentClassification",
121
+ "MassiveScenarioClassification",
122
+ "MultilingualSentiment",
123
+ "CLSClusteringS2S",
124
+ "CLSClusteringP2P",
125
+ "ThuNewsClusteringS2S",
126
+ "ThuNewsClusteringP2P",
127
+ "Ocnli",
128
+ "Cmnli",
129
+ "T2Reranking",
130
+ "MmarcoReranking",
131
+ "CMedQAv1",
132
+ "CMedQAv2",
133
+ "T2Retrieval",
134
+ "MMarcoRetrieval",
135
+ "DuRetrieval",
136
+ "CovidRetrieval",
137
+ "CmedqaRetrieval",
138
+ "EcomRetrieval",
139
+ "MedicalRetrieval",
140
+ "VideoRetrieval",
141
+ "ATEC",
142
+ "BQ",
143
+ "LCQMC",
144
+ "PAWSX",
145
+ "STSB",
146
+ "AFQMC",
147
+ "QBQTC",
148
+ "STS22",
149
+ ]
150
 
151
  MTEB_PL = [
152
+ "CBD",
153
+ "PolEmo2.0-IN",
154
+ "PolEmo2.0-OUT",
155
+ "AllegroReviews",
156
+ "PAC",
157
+ "MassiveIntentClassification",
158
+ "MassiveScenarioClassification",
159
+ "SICK-E-PL",
160
+ "PPC",
161
+ "CDSC-E",
162
+ "PSC",
163
+ "8TagsClustering",
164
+ "SICK-R-PL",
165
+ "CDSC-R",
166
+ "STS22",
167
+ "ArguAna-PL",
168
+ "DBPedia-PL",
169
+ "FiQA-PL",
170
+ "HotpotQA-PL",
171
+ "MSMARCO-PL",
172
+ "NFCorpus-PL",
173
+ "NQ-PL",
174
+ "Quora-PL",
175
+ "SCIDOCS-PL",
176
+ "SciFact-PL",
177
+ "TRECCOVID-PL",
178
  ]
179
 
180
  MTEB_FR = [
181
+ "AmazonReviewsClassification",
182
+ "MasakhaNEWSClassification",
183
+ "MassiveIntentClassification",
184
+ "MassiveScenarioClassification",
185
+ "MTOPDomainClassification",
186
+ "MTOPIntentClassification",
187
+ "OpusparcusPC",
188
+ "PawsX",
189
+ "AlloProfClusteringP2P",
190
+ "AlloProfClusteringS2S",
191
+ "HALClusteringS2S",
192
+ "MasakhaNEWSClusteringP2P",
193
+ "MasakhaNEWSClusteringS2S",
194
+ "MLSUMClusteringP2P",
195
+ "MLSUMClusteringS2S",
196
+ "SyntecReranking",
197
+ "AlloprofReranking",
198
+ "AlloprofRetrieval",
199
+ "BSARDRetrieval",
200
+ "SyntecRetrieval",
201
+ "XPQARetrieval",
202
+ "MintakaRetrieval",
203
+ "SummEvalFr",
204
+ "STSBenchmarkMultilingualSTS",
205
+ "STS22",
206
+ "SICKFr",
207
  ]
208
 
209
  logging.basicConfig(
210
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s : %(message)s"
 
211
  )
212
 
213
+ logger = logging.getLogger("eval_mteb_qwen.py")
214
+
215
 
216
  def get_detailed_instruct(task_description: str) -> str:
217
  if not task_description:
218
+ return ""
219
 
220
+ return "Instruct: {}\nQuery: ".format(task_description)
221
 
222
+
223
+ def get_task_def_by_task_name_and_type(
224
+ task_name: str,
225
+ task_type: str,
226
+ default_instruct="Given a web search query, retrieve relevant passages that answer the query",
227
+ ) -> str:
228
+ if task_type in ["STS"]:
229
  return "Retrieve semantically similar text"
230
 
231
+ if task_type in ["Summarization"]:
232
  return "Given a news summary, retrieve other semantically similar summaries"
233
 
234
+ if task_type in ["BitextMining"]:
235
  return "Retrieve parallel sentences"
236
 
237
+ if task_type in ["Classification"]:
238
  task_name_to_instruct: Dict[str, str] = {
239
+ "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual",
240
+ "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment",
241
+ "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category",
242
+ "Banking77Classification": "Given a online banking query, find the corresponding intents",
243
+ "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise",
244
+ "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset",
245
+ "MassiveIntentClassification": "Given a user utterance as query, find the user intents",
246
+ "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios",
247
+ "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation",
248
+ "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation",
249
+ "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic",
250
+ "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral",
251
  # C-MTEB eval instructions
252
+ "TNews": "Classify the fine-grained category of the given news title",
253
+ "IFlyTek": "Given an App description text, find the appropriate fine-grained category",
254
+ "MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative",
255
+ "JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative",
256
+ "OnlineShopping": "Classify the customer review for online shopping into positive or negative",
257
+ "Waimai": "Classify the customer review from a food takeaway platform into positive or negative",
258
  # MTEB-pl eval instructions
259
+ "CBD": "Classify the sentiment of polish tweet reviews",
260
  "PolEmo2.0-IN": "Classify the sentiment of in-domain (medicine and hotels) online reviews",
261
+ "PolEmo2.0-OUT": "Classify the sentiment of out-of-domain (products and school) online reviews",
262
  "AllegroReviews": "Classify the sentiment of reviews from e-commerce marketplace Allegro",
263
+ "PAC": 'Classify the sentence into one of the two types: "BEZPIECZNE_POSTANOWIENIE_UMOWNE" and "KLAUZULA_ABUZYWNA"',
 
264
  }
265
  return task_name_to_instruct[task_name]
266
 
267
+ if task_type in ["Clustering"]:
268
  task_name_to_instruct: Dict[str, str] = {
269
+ "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts",
270
+ "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles",
271
+ "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts",
272
+ "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles",
273
+ "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts",
274
+ "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles",
275
+ "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles",
276
+ "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts",
277
+ "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles",
278
+ "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs",
279
+ "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles",
280
  # C-MTEB eval instructions
281
+ "CLSClusteringS2S": "Identify the main category of scholar papers based on the titles",
282
+ "CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts",
283
+ "ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
284
+ "ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
285
  # MTEB-fr eval instructions
286
  "AlloProfClusteringP2P": "Identify the main category of Allo Prof document based on the titles and descriptions",
287
  "AlloProfClusteringS2S": "Identify the main category of Allo Prof document based on the titles",
 
289
  "MasakhaNEWSClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
290
  "MasakhaNEWSClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
291
  "MLSUMClusteringP2P": "Identify the topic or theme of the given articles based on the titles and contents",
292
+ "MLSUMClusteringS2S": "Identify the topic or theme of the given articles based on the titles",
293
  # MTEB-pl eval instructions
294
  "8TagsClustering": "Identify of headlines from social media posts in Polish into 8 categories: film, history, food, medicine, motorization, work, sport and technology",
295
  }
296
  return task_name_to_instruct[task_name]
297
 
298
+ if task_type in ["Reranking", "PairClassification"]:
299
  task_name_to_instruct: Dict[str, str] = {
300
+ "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum",
301
+ "MindSmallReranking": "Retrieve relevant news articles based on user browsing history",
302
+ "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers",
303
+ "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum",
304
+ "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum",
305
+ "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet",
306
+ "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet",
307
  # C-MTEB eval instructions
308
+ "T2Reranking": "Given a Chinese search query, retrieve web passages that answer the question",
309
+ "MmarcoReranking": "Given a Chinese search query, retrieve web passages that answer the question",
310
+ "CMedQAv1": "Given a Chinese community medical question, retrieve replies that best answer the question",
311
+ "CMedQAv2": "Given a Chinese community medical question, retrieve replies that best answer the question",
312
+ "Ocnli": "Retrieve semantically similar text.",
313
+ "Cmnli": "Retrieve semantically similar text.",
314
  # MTEB-fr eval instructions
315
  "AlloprofReranking": "Given a question, retrieve passages that answer the question",
316
+ "OpusparcusPC": "Retrieve semantically similar text",
317
+ "PawsX": "Retrieve semantically similar text",
318
  "SyntecReranking": "Given a question, retrieve passages that answer the question",
319
  # MTEB-pl eval instructions
320
  "SICK-E-PL": "Retrieve semantically similar text",
 
324
  }
325
  return task_name_to_instruct[task_name]
326
 
327
+ if task_type in ["Retrieval"]:
328
+ if task_name.lower().startswith("cqadupstack"):
329
+ return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"
330
 
331
  task_name_to_instruct: Dict[str, str] = {
332
+ "ArguAna": "Given a claim, find documents that refute the claim",
333
+ "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim",
334
+ "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia",
335
+ "FEVER": "Given a claim, retrieve documents that support or refute the claim",
336
+ "FiQA2018": "Given a financial question, retrieve user replies that best answer the question",
337
+ "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question",
338
+ "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query",
339
+ "NFCorpus": "Given a question, retrieve relevant documents that best answer the question",
340
+ "NQ": "Given a question, retrieve Wikipedia passages that answer the question",
341
+ "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question",
342
+ "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
343
+ "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim",
344
+ "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question",
345
+ "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query",
346
  # C-MTEB eval instructions
347
+ "T2Retrieval": "Given a Chinese search query, retrieve web passages that answer the question",
348
+ "MMarcoRetrieval": "Given a web search query, retrieve relevant passages that answer the query",
349
+ "DuRetrieval": "Given a Chinese search query, retrieve web passages that answer the question",
350
+ "CovidRetrieval": "Given a question on COVID-19, retrieve news articles that answer the question",
351
+ "CmedqaRetrieval": "Given a Chinese community medical question, retrieve replies that best answer the question",
352
+ "EcomRetrieval": "Given a user query from an e-commerce website, retrieve description sentences of relevant products",
353
+ "MedicalRetrieval": "Given a medical question, retrieve user replies that best answer the question",
354
+ "VideoRetrieval": "Given a video search query, retrieve the titles of relevant videos",
355
  # MTEB-fr eval instructions
356
  "AlloprofRetrieval": "Given a question, retrieve passages that answer the question",
357
  "BSARDRetrieval": "Given a question, retrieve passages that answer the question",
358
  "SyntecRetrieval": "Given a question, retrieve passages that answer the question",
359
  "XPQARetrieval": "Given a question, retrieve passages that answer the question",
360
  "MintakaRetrieval": "Given a question, retrieve passages that answer the question",
361
+ # MTEB-pl eval instructions
362
  "ArguAna-PL": "Given a claim, find documents that refute the claim",
363
  "DBPedia-PL": "Given a query, retrieve relevant entity descriptions from DBPedia",
364
  "FiQA-PL": "Given a financial question, retrieve user replies that best answer the question",
 
369
  "Quora-PL": "Given a question, retrieve questions that are semantically equivalent to the given question",
370
  "SCIDOCS-PL": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
371
  "SciFact-PL": "Given a scientific claim, retrieve documents that support or refute the claim",
372
+ "TRECCOVID-PL": "Given a query on COVID-19, retrieve documents that answer the query",
373
  }
374
 
375
  # add lower case keys to match some beir names
376
  task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
377
  # other cases where lower case match still doesn't work
378
+ task_name_to_instruct["trec-covid"] = task_name_to_instruct["TRECCOVID"]
379
+ task_name_to_instruct["climate-fever"] = task_name_to_instruct["ClimateFEVER"]
380
+ task_name_to_instruct["dbpedia-entity"] = task_name_to_instruct["DBPedia"]
381
+ task_name_to_instruct["webis-touche2020"] = task_name_to_instruct["Touche2020"]
382
+ task_name_to_instruct["fiqa"] = task_name_to_instruct["FiQA2018"]
383
+ task_name_to_instruct["quora"] = task_name_to_instruct["QuoraRetrieval"]
384
 
385
  # for miracl evaluation
386
+ task_name_to_instruct["miracl"] = (
387
+ "Given a question, retrieve Wikipedia passages that answer the question"
388
+ )
389
 
390
  return task_name_to_instruct[task_name]
391
+ logging.warning(
392
+ f"No instruction config for task {task_name} with type {task_type}, use default instruction."
393
+ )
394
+ return default_instruct
395
+
396
 
397
  class Encoder(torch.nn.Module):
398
+ def __init__(self, name_or_path: str, pooling: str):
399
  super().__init__()
400
  self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
401
  self.model = self.model.half()
402
+ self.model.eval()
403
  self.pooling = pooling
404
 
405
  def forward(self, **features) -> torch.Tensor:
406
  output = self.model(**features, output_hidden_states=True, return_dict=True)
407
+ hidden_state = output.hidden_states[-1]
408
  embeddings = self.pooler(hidden_state, **features)
409
  return embeddings
410
 
411
  def pooler(
412
+ self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, **kwargs
 
 
 
413
  ) -> torch.Tensor:
414
  if attention_mask.ndim == 2:
415
  mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
 
420
 
421
  hidden_state = hidden_state * mask_expanded
422
 
423
+ if self.pooling == "first":
424
  pooled_output = hidden_state[:, 0]
425
 
426
+ elif self.pooling == "last":
427
+ left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
428
  if left_padding:
429
  return hidden_state[:, -1]
430
  else:
431
  sequence_lengths = attention_mask.sum(dim=1) - 1
432
  batch_size = hidden_state.shape[0]
433
+ return hidden_state[
434
+ torch.arange(batch_size, device=hidden_state.device), sequence_lengths
435
+ ]
436
+ elif self.pooling == "mean":
437
  # TODO: weight
438
  lengths = mask_expanded.sum(1).clamp(min=1e-9)
439
  pooled_output = hidden_state.sum(dim=1) / lengths
440
 
441
+ elif self.pooling == "weightedmean":
442
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
443
  # hidden_state shape: bs, seq, hidden_dim
444
  weights = (
445
+ torch.arange(start=1, end=hidden_state.shape[1] + 1)
446
+ .unsqueeze(0)
447
+ .unsqueeze(-1)
448
+ .expand(hidden_state.size())
449
+ .float()
450
+ .to(hidden_state.device)
451
+ )
452
  assert weights.shape == hidden_state.shape == input_mask_expanded.shape
453
  input_mask_expanded = input_mask_expanded * weights
454
 
 
474
  force_default: bool = False,
475
  sep: str = " ",
476
  mp_tensor_to_cuda: bool = False,
477
+ instruction: Optional[str] = None,
 
478
  ):
479
  self.tokenizer = tokenizer
480
  self.model = encoder
481
  self.batch_size = batch_size
482
  self.max_seq_len = max_seq_len
483
+ self.pool: Optional[dict] = None
484
  self.normalize_embeddings = normalize_embeddings
485
  self.mp_tensor_to_cuda = mp_tensor_to_cuda
486
  self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
487
  self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
488
  self.instruction = instruction
489
+ self.default_query = default_query
490
  self.sep = sep
491
  self.force_default = force_default
492
+ if self.tokenizer.padding_side != "right":
493
+ logger.warning(
494
+ f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right"
495
+ )
496
+ self.tokenizer.padding_side = "right"
497
  if self.tokenizer.pad_token is None:
498
  logger.warning(f"Set tokenizer.pad_token as eos_token {self.tokenizer.eos_token}")
499
+ self.tokenizer.pad_token = "<|endoftext|>"
500
 
501
  def start(self, target_devices: Optional[List[str]] = None):
502
  """
 
509
  """
510
  if target_devices is None:
511
  if torch.cuda.is_available():
512
+ target_devices = ["cuda:{}".format(i) for i in range(torch.cuda.device_count())]
513
  else:
514
  logger.info("CUDA is not available. Start 4 CPU worker")
515
+ target_devices = ["cpu"] * 4
516
 
517
+ logger.info(
518
+ "Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices)))
519
+ )
520
+ print("multi instruction", self.instruction)
521
+ ctx = mp.get_context("spawn")
522
  input_queue = ctx.Queue()
523
  output_queue = ctx.Queue()
524
  processes = []
 
527
  p = ctx.Process(
528
  target=self._encode_multi_process_worker,
529
  args=(cuda_id, self, input_queue, output_queue),
530
+ daemon=True,
531
  )
532
  p.start()
533
  processes.append(p)
534
 
535
+ self.pool = {"input": input_queue, "output": output_queue, "processes": processes}
536
 
537
  def stop(self):
538
  """
539
  Stops all processes started with start_multi_process_pool
540
  """
541
+ for p in self.pool["processes"]:
542
  p.terminate()
543
 
544
+ for p in self.pool["processes"]:
545
  p.join()
546
  p.close()
547
 
548
+ self.pool["input"].close()
549
+ self.pool["output"].close()
550
 
551
  @staticmethod
552
  def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
 
562
  except queue.Empty:
563
  break
564
 
565
+ def encode_multi_process(self, sentences: List[str], **kwargs):
 
 
 
 
566
  """
567
  This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
568
  and sent to individual processes, which encode these on the different GPUs. This method is only suitable
 
577
  part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
578
  chunk_size = part_size if part_size < 3200 else 3200 # for retrieval chunk 50000
579
 
580
+ logger.debug(
581
+ f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}"
582
+ )
583
 
584
+ input_queue = self.pool["input"]
585
  last_chunk_id = 0
586
  chunk = []
587
 
 
596
  input_queue.put([last_chunk_id, chunk, kwargs])
597
  last_chunk_id += 1
598
 
599
+ output_queue = self.pool["output"]
600
+ results_list = sorted(
601
+ [output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0]
602
+ )
603
  embeddings = np.concatenate([result[1] for result in results_list])
604
  return embeddings
605
 
 
620
  (representing several text inputs to the model).
621
  """
622
 
623
+ if isinstance(text, dict): # {key: value} case
624
  return len(next(iter(text.values())))
625
+ elif not hasattr(text, "__len__"): # Object has no len() method
626
  return 1
627
+ elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
628
  return len(text)
629
  else:
630
+ return sum([len(t) for t in text]) # Sum of length of individual strings
631
 
632
  def _tokenize(self, sentences: List[str], is_query: bool):
633
+ batch_dict = self.tokenizer(
634
+ sentences,
635
+ max_length=self.max_seq_len - 1,
636
+ return_attention_mask=False,
637
+ padding=False,
638
+ truncation=True,
639
+ )
640
+ batch_dict["input_ids"] = [
641
+ input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]
642
+ ]
643
+ batch_dict = self.tokenizer.pad(
644
+ batch_dict, padding=True, return_attention_mask=True, return_tensors="pt"
645
+ )
646
+ batch_dict["is_causal"] = False
647
  return batch_dict
648
 
 
649
  def _encode(
650
  self,
651
  sentences: List[str],
652
  is_query: bool,
653
  convert_to_numpy: bool = True,
654
  convert_to_tensor: bool = False,
655
+ device: Optional[str] = None,
656
  show_progress_bar: bool = True,
657
+ **kwargs,
658
  ):
659
  """
660
  Computes sentence embeddings
 
677
  convert_to_numpy = False
678
 
679
  input_was_string = False
680
+ if isinstance(sentences, str) or not hasattr(
681
+ sentences, "__len__"
682
+ ): # Cast an individual sentence to a list with length 1
683
  sentences = [sentences]
684
  input_was_string = True
685
 
 
692
  length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
693
  sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
694
 
695
+ for start_index in trange(
696
+ 0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar
697
+ ):
698
+ sentences_batch = sentences_sorted[start_index : start_index + self.batch_size]
699
  features = self._tokenize(sentences_batch, is_query)
700
  features = self.batch_to_device(features, device)
701
 
 
716
  if convert_to_tensor:
717
  all_embeddings = torch.stack(all_embeddings)
718
  elif convert_to_numpy:
719
+ # all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
720
  all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
721
  if input_was_string:
722
  all_embeddings = all_embeddings[0]
 
728
  sentences: List[str],
729
  is_query: Optional[bool] = None,
730
  convert_to_tensor: bool = False,
731
+ **kwargs,
732
  ):
733
  is_query = self.default_query if is_query is None else is_query
734
  if is_query and self.instruction:
735
+ sentences = [self.instruction + sent for sent in sentences]
736
  kwargs.update(is_query=is_query)
737
  if self.pool is not None:
738
  kwargs.update(show_progress_bar=False)
 
740
  if convert_to_tensor:
741
  embeddings = torch.from_numpy(embeddings)
742
  if self.mp_tensor_to_cuda and torch.cuda.is_available():
743
+ embeddings = embeddings.to(torch.device("cuda")) # default 0-th gpu
744
  return embeddings
745
 
746
  return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
 
760
  ]
761
  elif isinstance(corpus[0], dict):
762
  sentences = [
763
+ (doc["title"] + self.sep + doc["text"]).strip()
764
+ if "title" in doc
765
+ else doc["text"].strip()
766
  for doc in corpus
767
  ]
768
  else:
 
770
  is_query = self.default_query if self.force_default else False
771
  return self.encode(sentences, is_query=is_query, **kwargs)
772
 
773
+
774
  def main(args):
775
  tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
776
  encoder = Encoder(args.model, args.pooling)
777
+ default_query = args.default_type == "query"
778
  model = Wrapper(
779
+ tokenizer,
780
+ encoder,
781
  batch_size=args.batch_size,
782
  max_seq_len=args.max_seq_len,
783
  normalize_embeddings=args.norm,
784
+ default_query=default_query,
785
  )
786
+ sym_retrievals = ["QuoraRetrieval", "ArguAna", "CQADupstack"]
787
+ if args.task == "mteb":
788
  task_names = MTEB_TASK_LIST
789
+ lang = ["en"]
790
+ elif args.task == "cmteb":
791
  task_names = CMTEB_TASK_LIST
792
+ lang = ["zh", "zh-CN"]
793
+ elif args.task == "mteb-fr":
794
+ task_names = MTEB_FR
795
+ lang = ["fr"]
796
+ elif args.task == "mteb-pl":
797
+ task_names = MTEB_PL
798
+ lang = ["pl"]
799
  else:
800
  task_names = [args.task]
801
+ lang = ["en", "zh", "zh-CN", "pl", "fr"]
802
  for task in task_names:
803
  evaluation = MTEB(tasks=[task], task_langs=lang)
804
  task_cls = evaluation.tasks[0]
805
+ task_name: str = task_cls.metadata_dict["name"]
806
+ task_type: str = task_cls.metadata_dict["type"]
807
  instruction = get_task_def_by_task_name_and_type(task_name, task_type)
808
  model.instruction = get_detailed_instruct(instruction)
809
+ if task == "MSMARCO":
810
  eval_splits = ["dev"]
811
  elif task in CMTEB_TASK_LIST:
812
+ eval_splits = task_cls.metadata_dict["eval_splits"]
813
  else:
814
  eval_splits = ["test"]
815
  sym = False
 
820
  else:
821
  sym = False
822
  if sym:
823
+ logger.info(
824
+ f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}."
825
+ )
826
  model.force_default = True
827
  evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
828
 
829
  if sym:
830
  logger.info(f"Switch back.")
831
  model.force_default = force_default_ori
832
+ print("\n")
833
 
834
 
835
  if __name__ == "__main__":
836
  _PARSER = argparse.ArgumentParser()
837
+ _PARSER.add_argument("-m", "--model", type=str, default=None)
838
+ _PARSER.add_argument("--pooling", type=str, default="last")
 
 
839
  _PARSER.add_argument("--output_dir", type=str, default=None)
840
+ _PARSER.add_argument("--default_type", type=str, default="query")
841
  _PARSER.add_argument("--max_seq_len", type=int, default=512)
842
  _PARSER.add_argument("-b", "--batch_size", type=int, default=32)
843
  _PARSER.add_argument(
844
+ "-t",
845
+ "--task",
846
+ type=str,
847
+ default=None, # None for running default tasks
848
  )
849
  _PARSER.add_argument("--norm", action="store_true")
850
  _ARGS = _PARSER.parse_args()