File size: 1,050 Bytes
320e6ef 362cab8 320e6ef 362cab8 b827343 4d90ca7 362cab8 4d90ca7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import List
import torch
from transformers import BertTokenizer
from foodybert import FoodyBertForSequenceClassification
class PreTrainedPipeline:
    """Sentiment-classification pipeline wrapping a FoodyBert model.

    Text is tokenised with the ``bert-base-uncased`` tokenizer, scored by
    ``FoodyBertForSequenceClassification``, and the highest-scoring class
    index is mapped to a label string.
    """

    # Class index -> label string.
    # NOTE(review): assumes this order matches the label order the FoodyBert
    # checkpoint was trained with; confirm against the model config.
    LABELS = ("positive", "neutral", "negative")

    def __init__(self, path: str = ""):
        """Load the tokenizer and the classification model.

        Args:
            path: Directory holding the model weights/config. An empty string
                (the default) falls back to the current working directory,
                which is where the model used to be loaded from unconditionally.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Honour ``path`` instead of silently assuming cwd is the model dir.
        self.model = FoodyBertForSequenceClassification.from_pretrained(path or ".")
        # Inference only: switch dropout-style layers to evaluation behaviour.
        self.model.eval()

    def __call__(self, inputs: str) -> str:
        """Classify the sentiment of a single text.

        Args:
            inputs: The text to classify.

        Returns:
            One of ``"positive"``, ``"neutral"`` or ``"negative"``.
        """
        # Truncate to the tokenizer's maximum length (model_max_length) so
        # long texts do not exceed what the model can embed.
        input_ids = self.bert_tokenizer.encode(
            inputs, add_special_tokens=True, truncation=True
        )
        batch = torch.tensor([input_ids])  # a batch of exactly one sequence
        with torch.no_grad():
            logits = self.model(batch).logits
        predicted_class_id = logits.argmax(dim=-1).item()
        return self.LABELS[predicted_class_id]