File size: 666 Bytes
a39c40f
79ba16e
a39c40f
3c8fdbd
 
 
 
 
 
 
 
 
 
 
 
e01d8e6
3c8fdbd
 
 
 
 
dbe1544
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import numpy as np

def model_predict(model, tokenizer, sentences):
    """
    Predict the label(s) of the sentences using the model and tokenizer.

    Args:
        model: Model (transformers); called as ``model(**inputs)`` and expected
            to return an object exposing ``.logits``.
        tokenizer: Tokenizer (transformers tokenizer); called with the
            sentences and expected to return a mapping that has ``.to(device)``.
        sentences: Sentence(s) to predict: a single string, a list/tuple of
            strings, or an ndarray of strings.
    Returns:
        predictions: The predicted class index as an ``int`` when exactly one
            prediction is produced (single sentence), otherwise a list with
            one class index per sentence.
    """
    # NOTE(review): transformers tokenizers validate their text input as
    # str / list / tuple, so an ndarray is converted to plain Python first.
    if isinstance(sentences, np.ndarray):
        sentences = sentences.tolist()

    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt", max_length=512).to("cpu")

    # Classify sentences; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)  # get the logits

    # Reduce over the last (class) dimension only. An argmax without an axis
    # flattens the whole batch and yields one meaningless flat index as soon
    # as more than one sentence is passed.
    labels = torch.argmax(outputs.logits.to("cpu"), dim=-1)

    # Keep the historical scalar return for the single-prediction case.
    if labels.numel() == 1:
        return int(labels.item())
    return labels.tolist()