UNCANNY69 committed
Commit f3ed620 · verified · 1 Parent(s): 7364b3f

Create model.py

Files changed (1)
  1. model.py +66 -0
model.py ADDED
@@ -0,0 +1,66 @@
+ from transformers import PretrainedConfig
+ from abc import ABCMeta
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import BertModel, BertConfig
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+ class BertLSTMConfig(PretrainedConfig):
+     model_type = "bertLSTMForSequenceClassification"
+
+     def __init__(self,
+                  num_classes=2,
+                  embed_dim=768,
+                  num_layers=12,
+                  hidden_dim_lstm=256,  # New parameter for the LSTM head
+                  dropout_rate=0.1,
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.num_classes = num_classes
+         self.embed_dim = embed_dim
+         self.num_layers = num_layers
+         self.hidden_dim_lstm = hidden_dim_lstm  # Assign LSTM hidden dimension
+         self.dropout_rate = dropout_rate
+         self.id2label = {
+             0: "fake",
+             1: "true",
+         }
+         self.label2id = {
+             "fake": 0,
+             "true": 1,
+         }
+
+
+ class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
+     config_class = BertLSTMConfig
+
+     def __init__(self, config):
+         super(BertLSTMForSequenceClassification, self).__init__(config)
+         self.num_classes = config.num_classes
+         self.embed_dim = config.embed_dim
+         self.num_layers = config.num_layers
+         self.hidden_dim_lstm = config.hidden_dim_lstm
+         self.dropout = nn.Dropout(config.dropout_rate)
+         # Pretrained BERT encoder backbone
+         self.bert = BertModel.from_pretrained('bert-base-uncased')
+         print("BERT Model Loaded")
+         # LSTM head stacked on top of the pooled BERT output
+         self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm, batch_first=True, num_layers=3)
+         self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)
+
+     def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
+         bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+         pooled_output = bert_output.pooler_output  # Use the pooled output for classification
+         # Treat the pooled vector as a length-1 sequence for the LSTM
+         out, _ = self.lstm(pooled_output.unsqueeze(1))
+         out = self.dropout(out[:, -1, :])
+         logits = self.fc(out)
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(logits, labels)
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=bert_output.hidden_states,
+             attentions=bert_output.attentions,
+         )
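
For reference, a minimal usage sketch of the classes added in this commit. This is an assumption about how the file would be consumed, not part of the commit itself: it assumes the file is importable as model, that the bert-base-uncased tokenizer matches the backbone loaded in __init__, and that the sample sentence is hypothetical.

import torch
from transformers import BertTokenizer
from model import BertLSTMConfig, BertLSTMForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertLSTMConfig()
model = BertLSTMForSequenceClassification(config)  # loads bert-base-uncased internally
model.eval()

# Hypothetical input text; the tokenizer returns input_ids, attention_mask, token_type_ids,
# which match the forward() signature above.
inputs = tokenizer("An example headline to classify", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
pred = outputs.logits.argmax(dim=-1).item()
print(config.id2label[pred])  # "fake" or "true"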