gsri-18 committed on
Commit b7a1123 · verified · 1 Parent(s): a71a296

Update custom_model.py

Files changed (1)
  1. custom_model.py +10 -8
custom_model.py CHANGED
@@ -1,19 +1,21 @@
-# custom_model.py
-from transformers import BertTokenizerFast, BertForTokenClassification
+from transformers import BertForTokenClassification, BertTokenizerFast
 from peft import PeftModel
 import torch
 
 class CustomBertForTokenClassification(BertForTokenClassification):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
+    def __init__(self, config):
+        super().__init__(config)
+
     def predict_entities(self, sentence, label_mapping):
         # Tokenize the input sentence
-        tokenizer = BertTokenizerFast.from_pretrained(self.config._name_or_path)
+        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
         inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
-        inputs = {key: value.to(self.device) for key, value in inputs.items()}
 
-        # Get predictions
+        # Move inputs to device
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        inputs = {key: value.to(device) for key, value in inputs.items()}
+
+        # Get model predictions
        with torch.no_grad():
            outputs = self(**inputs)
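For context, a minimal usage sketch of the updated class. The checkpoint path and label mapping below are placeholders, not values taken from this commit, and predict_entities is assumed to return the decoded predictions for the input sentence.

# Minimal usage sketch; MODEL_DIR and label_mapping are hypothetical placeholders.
import torch
from custom_model import CustomBertForTokenClassification

MODEL_DIR = "path/to/finetuned-checkpoint"        # hypothetical checkpoint location
label_mapping = {0: "O", 1: "B-PER", 2: "I-PER"}  # hypothetical id-to-label map

model = CustomBertForTokenClassification.from_pretrained(
    MODEL_DIR, num_labels=len(label_mapping)
)
# Match the device choice made inside predict_entities (cuda if available, else cpu)
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model.eval()

predictions = model.predict_entities("Ada Lovelace wrote the first algorithm.", label_mapping)
print(predictions)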