Spaces:
Running
Running
File size: 666 Bytes
a39c40f 79ba16e a39c40f 3c8fdbd e01d8e6 3c8fdbd dbe1544 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import numpy as np
def model_predict(model, tokenizer, sentences, *, device="cpu", max_length=512):
    """
    Predict the class label(s) of one or more sentences.

    Args:
        model: Model (transformers). It is called with the tokenizer outputs
            as keyword arguments and must return an object exposing a
            ``logits`` tensor.
        tokenizer: Tokenizer (transformers tokenizer).
        sentences: A single sentence (str) or a batch of sentences
            (list / tuple / ndarray of str).
        device: Device the tokenized inputs are moved to. It should match the
            device the model lives on. Defaults to "cpu".
        max_length: Maximum number of tokens per sentence; longer inputs are
            truncated. Defaults to 512.

    Returns:
        An int label when exactly one sentence is given. Otherwise, a list of
        int labels, one per sentence, in input order.
    """
    # Tokenizers accept str / list[str] but reject numpy arrays, so hand them
    # plain Python objects (a 0-d array becomes a single str).
    if isinstance(sentences, np.ndarray):
        sentences = sentences.tolist()

    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    ).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Argmax over the last (label) axis yields one label per sentence.
    # A flat argmax over the whole logits matrix would mix rows together
    # whenever more than one sentence is given.
    predictions = outputs.logits.argmax(dim=-1)
    if predictions.numel() == 1:
        # Single sentence: keep the historical scalar int return value.
        return int(predictions.item())
    return predictions.tolist()