Update custom_model.py
Browse files- custom_model.py +10 -8
custom_model.py
CHANGED
@@ -1,19 +1,21 @@
|
|
1 |
-
|
2 |
-
from transformers import BertTokenizerFast, BertForTokenClassification
|
3 |
from peft import PeftModel
|
4 |
import torch
|
5 |
|
6 |
class CustomBertForTokenClassification(BertForTokenClassification):
|
7 |
-
def __init__(self,
|
8 |
-
super().__init__(
|
9 |
-
|
10 |
def predict_entities(self, sentence, label_mapping):
|
11 |
# Tokenize the input sentence
|
12 |
-
tokenizer = BertTokenizerFast.from_pretrained(
|
13 |
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
|
14 |
-
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
15 |
|
16 |
-
#
|
|
|
|
|
|
|
|
|
17 |
with torch.no_grad():
|
18 |
outputs = self(**inputs)
|
19 |
|
|
|
1 |
+
from transformers import BertForTokenClassification, BertTokenizerFast
|
|
|
2 |
from peft import PeftModel
|
3 |
import torch
|
4 |
|
5 |
class CustomBertForTokenClassification(BertForTokenClassification):
|
6 |
+
def __init__(self, config):
|
7 |
+
super().__init__(config)
|
8 |
+
|
9 |
def predict_entities(self, sentence, label_mapping):
|
10 |
# Tokenize the input sentence
|
11 |
+
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
12 |
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
|
|
|
13 |
|
14 |
+
# Move inputs to device
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
inputs = {key: value.to(device) for key, value in inputs.items()}
|
17 |
+
|
18 |
+
# Get model predictions
|
19 |
with torch.no_grad():
|
20 |
outputs = self(**inputs)
|
21 |
|