import torch
from transformers import BertTokenizer

from foodybert import FoodyBertForSequenceClassification


class PreTrainedPipeline():

    def __init__(self, path=""):
        """
        Initialize the tokenizer and the fine-tuned sentiment classification model.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")

    def __call__(self, inputs: str) -> str:
        """
        Args:
            inputs (:obj:`str`):
                a string to classify.
        Return:
            A :obj:`str`: the predicted sentiment label, one of
            "positive", "neutral", or "negative".
        """
        # Tokenize the input and add the [CLS]/[SEP] special tokens.
        input_ids = self.bert_tokenizer.encode(inputs, add_special_tokens=True)
        X = torch.tensor([input_ids])
        # Run inference without gradient tracking and take the highest-scoring class.
        with torch.no_grad():
            predicted_class_id = self.model(X).logits.argmax().item()
        labels = ['positive', 'neutral', 'negative']
        return labels[predicted_class_id]
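

# Example usage (a minimal sketch; assumes the foodybert package is installed
# and the "rttl-ai/foody-bert" weights are reachable from the Hugging Face Hub
# or a local cache):
if __name__ == "__main__":
    pipeline = PreTrainedPipeline()
    print(pipeline("The food was great but the service was painfully slow."))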