|
""" |
|
inference_onnx.py |
|
|
|
This script leverages ONNX runtime to perform inference with a pre-trained model. |
|
""" |
|
import json |
|
import torch |
|
import sys |
|
import numpy as np |
|
import onnxruntime as rt |
|
|
|
from huggingface_hub import hf_hub_download |
|
from transformers import AutoTokenizer |
|
|
|
# Hugging Face Hub repository that hosts the model files and its config.json.
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"

# Prefer a config.json in the current working directory. The hub copy is only
# fetched when no local file exists, so a local config no longer triggers a
# download whose result is thrown away, and a missing local file no longer
# aborts the script.
config_path = "config.json"
try:
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
except FileNotFoundError:
    config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
|
|
|
# Per-process caches: predict() is called once per sentence pair, and reloading
# the tokenizer / re-creating the ONNX session on every call is wasted work.
_TOKENIZER_CACHE = {}
_SESSION_CACHE = {}


def _get_tokenizer(model_name):
    """Return the tokenizer for *model_name*, loading it on first use."""
    if model_name not in _TOKENIZER_CACHE:
        _TOKENIZER_CACHE[model_name] = AutoTokenizer.from_pretrained(model_name)
    return _TOKENIZER_CACHE[model_name]


def _get_session(model_fp):
    """Return an ONNX Runtime session for *model_fp*, creating it on first use.

    The model file is fetched from the hub repository named by ``repo_path``.
    """
    if model_fp not in _SESSION_CACHE:
        local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
        _SESSION_CACHE[model_fp] = rt.InferenceSession(local_model_fp)
    return _SESSION_CACHE[model_fp]


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned ONNX model.

    Each sentence is tokenized separately (truncated and padded to the
    configured ``max_length``), the four resulting arrays are fed to the ONNX
    model, and a softmax over the model's first output gives the class
    probabilities.

    Args:
        sentence1 (str): The first input sentence.
        sentence2 (str): The second input sentence.

    Returns:
        tuple:
            - predicted_label (int): Index of the highest-probability class.
            - probabilities (numpy.ndarray): Softmax (over axis 1) of the
              model's first output.

    Raises:
        KeyError: If the loaded config lacks the
            ``classifier.embedding`` entries read below.
    """
    embedding_cfg = config['classifier']['embedding']
    model_name = embedding_cfg['model_name']
    max_length = embedding_cfg['max_length']
    model_fp = embedding_cfg['model_fp']

    tokenizer = _get_tokenizer(model_name)

    # ONNX Runtime consumes NumPy arrays, so ask the tokenizer for them
    # directly instead of building torch tensors, moving them to a device and
    # converting them straight back.
    inputs1 = tokenizer(sentence1, return_tensors="np", truncation=True, padding="max_length", max_length=max_length)
    inputs2 = tokenizer(sentence2, return_tensors="np", truncation=True, padding="max_length", max_length=max_length)

    session = _get_session(model_fp)

    # Inputs are bound by position: ids/mask of sentence 1, then ids/mask of
    # sentence 2. NOTE(review): assumes the exported model declares its inputs
    # in this order - confirm against the ONNX graph.
    model_inputs = session.get_inputs()
    onnx_inputs = {
        model_inputs[0].name: inputs1['input_ids'],
        model_inputs[1].name: inputs1['attention_mask'],
        model_inputs[2].name: inputs2['input_ids'],
        model_inputs[3].name: inputs2['attention_mask'],
    }

    outputs = session.run(None, onnx_inputs)
    probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
    # .item() requires a single prediction, i.e. one sentence pair per call.
    predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.numpy()
|
|
|
def is_valid_sentence_pairs(sentence_pairs):
    """Return True when *sentence_pairs* is a list of two-string pairs.

    Each element must itself be a list (or tuple) of exactly two ``str``
    values. Checking the container type and length matters: a bare string such
    as ``"ab"`` would otherwise pass an index-only check and be unpacked into
    single characters, and a three-element pair would fail later at unpacking.
    An empty list is considered valid (there is simply nothing to predict).
    """
    if not isinstance(sentence_pairs, list):
        return False
    return all(
        isinstance(pair, (list, tuple))
        and len(pair) == 2
        and isinstance(pair[0], str)
        and isinstance(pair[1], str)
        for pair in sentence_pairs
    )


if __name__ == "__main__":
    # The single CLI argument is a JSON array of [sentence1, sentence2] pairs.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} '<JSON list of [sentence1, sentence2] pairs>'")

    input_data = sys.argv[1]
    sentence_pairs = json.loads(input_data)

    if not is_valid_sentence_pairs(sentence_pairs):
        raise ValueError("Each pair must contain two strings.")

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs, start=1):
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Human-readable report for this pair.
        print(f"Pair {idx}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)
|
|