Text Classification
Transformers
PyTorch
bert
Inference Endpoints
foody-bert / pipeline.py
rttl's picture
Update pipeline.py
b827343
raw
history blame
1.05 kB
from typing import List
import torch
from transformers import BertTokenizer
from foodybert import FoodyBertForSequenceClassification
class PreTrainedPipeline:
    """Sentiment classifier wrapping FoodyBert for the custom-pipeline entry point.

    Tokenizes a single string with the ``bert-base-uncased`` tokenizer, runs it
    through ``FoodyBertForSequenceClassification`` and returns the label whose
    logit is largest.
    """

    # Label returned for each predicted class id (index i <-> class id i).
    # NOTE(review): assumes the checkpoint was trained with exactly this id
    # order — confirm against the model's config if it is ever retrained.
    LABELS = ("positive", "neutral", "negative")

    def __init__(self, path=""):
        """Load the tokenizer and the fine-tuned model.

        Args:
            path (:obj:`str`, optional):
                Directory holding the FoodyBert checkpoint. When empty (the
                default) the model is loaded from the current directory
                ``"."``, which was the previous hard-coded behaviour.
        """
        # The tokenizer is the stock BERT vocabulary; it is not read from `path`.
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = FoodyBertForSequenceClassification.from_pretrained(path or ".")

    def __call__(self, inputs: str) -> str:
        """Classify the sentiment of ``inputs``.

        Args:
            inputs (:obj:`str`):
                The text to classify.
        Return:
            :obj:`str`: One of ``"positive"``, ``"neutral"`` or ``"negative"``.
        """
        # truncation=True caps the sequence at the tokenizer's model_max_length
        # so overly long texts cannot exceed BERT's position embeddings.
        input_ids = self.bert_tokenizer.encode(
            inputs, add_special_tokens=True, truncation=True
        )
        # Wrap in a list to form a batch of one sequence.
        X = torch.tensor([input_ids])
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = self.model(X).logits
        # Flat argmax over the single-example logits picks the winning class id.
        predicted_class_id = logits.argmax().item()
        return self.LABELS[predicted_class_id]