Text Classification
Transformers
PyTorch
bert
Inference Endpoints
rttl commited on
Commit
97a1414
1 Parent(s): 320e6ef

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -11,7 +11,7 @@ class PreTrainedPipeline():
11
  Initialize model
12
  """
13
  self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
14
- self.model = FoodyBertForSequenceClassification.from_pretrained(".")
15
  def __call__(self, inputs: str) -> List[float]:
16
 
17
  """
@@ -27,4 +27,4 @@ class PreTrainedPipeline():
27
  predicted_class_id = self.model(X).logits.argmax().item()
28
  labels = ['positive','neutral','negative']
29
  reps = labels[predicted_class_id]
30
- return self.model.get_sentence_vector(inputs).tolist()
 
11
  Initialize model
12
  """
13
  self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
14
+ self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
15
  def __call__(self, inputs: str) -> List[float]:
16
 
17
  """
 
27
  predicted_class_id = self.model(X).logits.argmax().item()
28
  labels = ['positive','neutral','negative']
29
  reps = labels[predicted_class_id]
30
+ return predicted_class_id