File size: 2,351 Bytes
4016518
 
96355e1
a61e58e
dbdba21
 
df00d4a
4016518
 
 
 
 
 
 
 
 
 
df00d4a
 
b5cf395
 
 
96355e1
b20a254
c1f8c44
e82005a
b5cf395
cb578d2
59199c3
 
4016518
 
 
 
 
 
 
 
 
9ee4b8b
 
dccbfd3
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModel
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction
from pathlib import Path

import os
import torch

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Return the attention-mask-weighted mean of the token embeddings.

    Args:
        model_output: Indexable model result whose first element holds the
            per-token embeddings.
        attention_mask: Mask tensor; it is given a trailing axis and expanded
            to the size of the token embeddings, so masked positions
            contribute nothing to the average.

    Returns:
        The token embeddings summed over dimension 1 and divided by the
        summed mask (clamped to at least 1e-9 to avoid division by zero).
    """
    embeddings = model_output[0]
    # Broadcast the mask across the embedding axis and make it float so it
    # can be used as a multiplicative weight.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_sum = torch.sum(embeddings * mask, 1)
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_counts
    
class EndpointHandler():
    """Callable handler that turns input sentences into mean-pooled embeddings.

    The tokenizer and the ONNX feature-extraction model are loaded once in
    ``__init__``; each call tokenizes the incoming sentences, runs the model
    and mean-pools the token embeddings with the attention mask.
    """

    def __init__(self, path=""):
        """Load the tokenizer and the ONNX model.

        Args:
            path: Directory handed in by the caller. It is only used for the
                startup diagnostics printed below; the tokenizer and model
                are loaded by their hub identifiers.
        """
        print("HELLO THIS IS THE CWD:", os.getcwd())
        print("HELLO THIS IS THE PATH ARG:", path)
        # os.listdir("") raises FileNotFoundError, so only list the directory
        # when a real one was supplied (the default path is "").
        if path and os.path.isdir(path):
            for entry in os.listdir(path):
                print(entry)
        task = "feature-extraction"
        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/msmarco-MiniLM-L-6-v3')
        model_regular = ORTModelForFeatureExtraction.from_pretrained("jpohhhh/msmarco-MiniLM-L-6-v3_onnx", from_transformers=False)
        # Keep the bare model: __call__ feeds it already-tokenized tensors,
        # which a pipeline object does not accept.
        self.model = model_regular
        # Retained so existing users of this attribute keep working.
        self.onnx_extractor = pipeline(task, model=model_regular, tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Embed the sentences found under ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is popped from it and may be
                a single string or an iterable of strings; when the key is
                missing, ``data`` itself is used as the inputs (unchanged
                from the original behavior).

        Returns:
            A list with one entry per sentence, each the ``tolist()`` form of
            that sentence's pooled embedding tensor.
        """
        sentences = data.pop("inputs", data)
        # A bare string would otherwise be iterated character by character.
        if isinstance(sentences, str):
            sentences = [sentences]
        sentence_embeddings = []
        for sentence in sentences:
            encoded_input = self.tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
            # Compute token embeddings with the model itself; calling the
            # pipeline with tokenizer output as keyword arguments fails
            # because it expects raw text as its positional input.
            with torch.no_grad():
                model_output = self.model(**encoded_input)
            # Perform pooling. In this case, mean pooling.
            embedding = mean_pooling(model_output, encoded_input['attention_mask'])
            sentence_embeddings.append(embedding.tolist())
        return sentence_embeddings