jpohhhh commited on
Commit
8af4da5
·
1 Parent(s): 87fd374

Try GPT4 suggestions

Browse files
Files changed (1) hide show
  1. handler.py +20 -7
handler.py CHANGED
@@ -3,11 +3,22 @@ from transformers import AutoTokenizer, AutoModel
3
  from optimum.pipelines import pipeline
4
  from optimum.onnxruntime import ORTModelForFeatureExtraction
5
  from pathlib import Path
 
6
  import time
7
 
8
  import os
9
  import torch
10
 
 
 
 
 
 
 
 
 
 
 
11
  def mean_pooling(model_output):
12
  # Get dimensions
13
  Z, Y = len(model_output[0]), len(model_output[0][0])
@@ -34,6 +45,12 @@ class EndpointHandler():
34
  self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/msmarco-MiniLM-L-6-v3')
35
  model_regular = ORTModelForFeatureExtraction.from_pretrained("jpohhhh/msmarco-MiniLM-L-6-v3_onnx", from_transformers=False)
36
  self.onnx_extractor = pipeline(task, model=model_regular, tokenizer=self.tokenizer)
 
 
 
 
 
 
37
 
38
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
39
  """
@@ -44,10 +61,6 @@ class EndpointHandler():
44
  A :obj:`list` | `dict`: will be serialized and returned
45
  """
46
  sentences = data.pop("inputs",data)
47
- sentence_embeddings = []
48
- for sentence in sentences:
49
- # Compute token embeddings
50
- with torch.no_grad():
51
- model_output = self.onnx_extractor(sentence)
52
- sentence_embeddings.append(mean_pooling(model_output))
53
- return sentence_embeddings
 
3
  from optimum.pipelines import pipeline
4
  from optimum.onnxruntime import ORTModelForFeatureExtraction
5
  from pathlib import Path
6
+ from multiprocessing import Pool
7
  import time
8
 
9
  import os
10
  import torch
11
 
12
+ def mean_pooling2(model_output):
13
+ """Perform mean pooling on tensor T
14
+ Args:
15
+ model_output: tensor T (elements are 2 dimentional float arrays).
16
+ Returns:
17
+ array of mean values.
18
+ """
19
+ return torch.mean(model_output[0], dim=1)
20
+
21
+
22
  def mean_pooling(model_output):
23
  # Get dimensions
24
  Z, Y = len(model_output[0]), len(model_output[0][0])
 
45
  self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/msmarco-MiniLM-L-6-v3')
46
  model_regular = ORTModelForFeatureExtraction.from_pretrained("jpohhhh/msmarco-MiniLM-L-6-v3_onnx", from_transformers=False)
47
  self.onnx_extractor = pipeline(task, model=model_regular, tokenizer=self.tokenizer)
48
+ self.pool = Pool(4)
49
+
50
+ def process_sentence(self, sentence): # Factored out for parallelization
51
+ with torch.no_grad():
52
+ model_output = self.onnx_extractor(sentence)
53
+ return mean_pooling2(model_output)
54
 
55
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
56
  """
 
61
  A :obj:`list` | `dict`: will be serialized and returned
62
  """
63
  sentences = data.pop("inputs",data)
64
+ # Compute embeddings in parallel
65
+ sentence_embeddings = self.pool.map(self.process_sentence, sentences)
66
+ return sentence_embeddings