jpohhhh's picture
Update handler.py
a17e571
raw
history blame
2.99 kB
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModel
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction
from pathlib import Path
import time
import os
import torch
def max_pooling(model_output):
    """Max-pool per-token embeddings into one sentence embedding.

    Args:
        model_output: Feature-extraction output for a single input, indexed
            as ``model_output[0][token][dim]``. Presumably nested lists of
            shape (1, num_tokens, hidden_size) as returned by the
            feature-extraction pipeline — TODO confirm for other callers.

    Returns:
        A list holding, for each hidden dimension, the maximum value over
        all tokens.

    Raises:
        ValueError: If ``model_output[0]`` contains no token vectors.
    """
    token_vectors = model_output[0]
    if not token_vectors:
        raise ValueError("max_pooling: model_output contains no token vectors")
    # Transpose to per-dimension columns and take each column's true maximum.
    # Seeding a running max with 0 (as before) clamped all-negative
    # dimensions to 0, which is not max pooling.
    return [max(column) for column in zip(*token_vectors)]
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the mask keeps.

    Args:
        model_output: Sequence whose first element holds the token
            embeddings. Assumed (batch, seq_len, hidden) with
            ``attention_mask`` shaped (batch, seq_len) — TODO confirm.
        attention_mask: Tensor with 1 for real tokens and 0 for padding.

    Returns:
        Tensor of masked means, reduced over dimension 1.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so it zeroes padding rows.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(1)
    # Lower bound keeps an all-padding row from dividing by zero.
    counts = mask.sum(1).clamp(min=1e-9)
    return summed / counts
class EndpointHandler:
    """Inference-endpoint handler that returns one embedding per sentence.

    Runs an ONNX feature-extraction model through an optimum pipeline and
    max-pools the per-token vectors of each sentence with ``max_pooling``.
    """

    # Hub id of the tokenizer that matches the exported model.
    TOKENIZER_ID = "sentence-transformers/msmarco-MiniLM-L-6-v3"
    # Hub id of the already-exported ONNX model (hence from_transformers=False).
    MODEL_ID = "jpohhhh/msmarco-MiniLM-L-6-v3_onnx"

    def __init__(self, path=""):
        """Load the tokenizer and ONNX model and build the pipeline.

        Args:
            path: Model directory supplied by the endpoint runtime. Currently
                unused: the tokenizer and model are loaded by Hub id.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_ID)
        model = ORTModelForFeatureExtraction.from_pretrained(
            self.MODEL_ID, from_transformers=False
        )
        self.onnx_extractor = pipeline(
            "feature-extraction", model=model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """Embed every input sentence.

        Args:
            data: Request payload. ``data["inputs"]`` is a sentence or a list
                of sentences and is popped from ``data``. If the key is
                missing, ``data`` itself is iterated (i.e. its keys are
                treated as sentences), preserving the original fallback.

        Returns:
            One max-pooled embedding per input sentence, in input order.
        """
        sentences = data.pop("inputs", data)
        # A bare string would otherwise be iterated character by character,
        # yielding one embedding per character instead of one per sentence.
        if isinstance(sentences, str):
            sentences = [sentences]
        # The model runs in ONNX Runtime, so torch.no_grad() has no effect here.
        return [max_pooling(self.onnx_extractor(sentence)) for sentence in sentences]