|
from typing import List

import torch
from transformers import BertTokenizer

from foody_bert import FoodyBertForSequenceClassification
|
|
|
class PreTrainedPipeline:
    """Feature-extraction pipeline around a FoodyBert model.

    Calling an instance on a string returns whatever
    ``self.model.get_sentence_vector`` produces for it, converted to a plain
    Python list via ``.tolist()``.
    """

    def __init__(self, path=""):
        """
        Initialize the tokenizer and model.

        Args:
            path (:obj:`str`, optional):
                Location passed to
                ``FoodyBertForSequenceClassification.from_pretrained``.
                An empty string (the default) falls back to ``"."``, the
                location the model was always loaded from previously.
        """
        # NOTE(review): ``__call__`` no longer uses the tokenizer; the attribute
        # is kept because code outside this class may rely on ``bert_tokenizer``.
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Honour the caller-supplied path instead of silently depending on the
        # current working directory.
        self.model = FoodyBertForSequenceClassification.from_pretrained(path or ".")

    def __call__(self, inputs: str) -> List[float]:
        """
        Args:
            inputs (:obj:`str`):
                a string to get the features of.

        Return:
            A :obj:`list` of floats: The features computed by the model.
        """
        # Inference only: disable autograd so no computation graph is recorded.
        # (A classification forward pass whose result was discarded used to run
        # here; it had no effect on the returned value and has been removed.)
        with torch.no_grad():
            features = self.model.get_sentence_vector(inputs)
        return features.tolist()
|
|