ONNX
English
Shing Yee
feat: add files
ba6803f unverified
"""
inference_onnx.py
This script leverages ONNX runtime to perform inference with a pre-trained model.
"""
import json
import torch
import sys
import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
config_path = "config.json"
with open(config_path, 'r') as f:
config = json.load(f)
def predict(sentence1, sentence2):
"""
Predicts the label for a pair of sentences using a fine-tuned ONNX model.
This function tokenizes the input sentences, prepares them as inputs for an ONNX model,
and performs inference to predict the label and probabilities for the given sentence pair.
Args:
- sentence1 (str): The first input sentence.
- sentence2 (str): The second input sentence.
Returns:
tuple:
- predicted_label (int): The predicted label (e.g., 0 or 1).
- probabilities (numpy.ndarray): The probabilities for each class.
"""
# Load model configuration
model_name = config['classifier']['embedding']['model_name']
max_length = config['classifier']['embedding']['max_length']
model_fp = config['classifier']['embedding']['model_fp']
# Set device and load tokenizer
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Get inputs
inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)
input_ids1 = inputs1['input_ids'].to(device)
attention_mask1 = inputs1['attention_mask'].to(device)
input_ids2 = inputs2['input_ids'].to(device)
attention_mask2 = inputs2['attention_mask'].to(device)
# Download the classifier from HuggingFace hub
local_model_fp = model_fp
local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
# Run inference
session = rt.InferenceSession(local_model_fp) # Load the ONNX model
onnx_inputs = {
session.get_inputs()[0].name: input_ids1.cpu().numpy(),
session.get_inputs()[1].name: attention_mask1.cpu().numpy(),
session.get_inputs()[2].name: input_ids2.cpu().numpy(),
session.get_inputs()[3].name: attention_mask2.cpu().numpy(),
}
outputs = session.run(None, onnx_inputs)
probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
predicted_label = torch.argmax(probabilities, dim=1).item()
return predicted_label, probabilities.cpu().numpy()
if __name__ == "__main__":
# Load data
input_data = sys.argv[1]
sentence_pairs = json.loads(input_data)
# Validate input data format
if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
raise ValueError("Each pair must contain two strings.")
for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
# Generate prediction and scores
predicted_label, probabilities = predict(sentence1, sentence2)
# Print the results
print(f"Pair {idx + 1}:")
print(f" Sentence 1: {sentence1}")
print(f" Sentence 2: {sentence2}")
print(f" Predicted Label: {predicted_label}")
print(f" Probabilities: {probabilities}")
print('-' * 50)