rusen commited on
Commit
3c8fdbd
·
1 Parent(s): e43e4a1

Created utils.py

Browse files
Files changed (1) hide show
  1. utils.py +19 -0
utils.py CHANGED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def model_predict(model, tokenizer, sentences):
2
+ """
3
+ Predict the labels of the sentences using the model and tokenizer
4
+ Args:
5
+ model: Model (transformers)
6
+ tokenizer: Tokenizer (transformers tokenizer)
7
+ sentences: Sentences to predict (ndarray)
8
+ Returns:
9
+ predictions: Predicted labels
10
+ """
11
+
12
+
13
+ inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
14
+ # Classify sentences
15
+ with torch.no_grad():
16
+ outputs = model(**inputs) # get the logits
17
+ label = np.argmax(outputs.logits.to("cpu"))
18
+
19
+ return str(labels)