Upload pipeline.py
Browse files- pipeline.py +2 -2
pipeline.py
CHANGED
@@ -11,7 +11,7 @@ class PreTrainedPipeline():
|
|
11 |
Initialize model
|
12 |
"""
|
13 |
self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
14 |
-
self.model = FoodyBertForSequenceClassification.from_pretrained("
|
15 |
def __call__(self, inputs: str) -> List[float]:
|
16 |
|
17 |
"""
|
@@ -27,4 +27,4 @@ class PreTrainedPipeline():
|
|
27 |
predicted_class_id = self.model(X).logits.argmax().item()
|
28 |
labels = ['positive','neutral','negative']
|
29 |
reps = labels[predicted_class_id]
|
30 |
-
return
|
|
|
11 |
Initialize model
|
12 |
"""
|
13 |
self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
14 |
+
self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
|
15 |
def __call__(self, inputs: str) -> List[float]:
|
16 |
|
17 |
"""
|
|
|
27 |
predicted_class_id = self.model(X).logits.argmax().item()
|
28 |
labels = ['positive','neutral','negative']
|
29 |
reps = labels[predicted_class_id]
|
30 |
+
return predicted_class_id
|