UNCANNY69 committed on
Commit
ee0c352
·
verified ·
1 Parent(s): 1db20fb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +83 -0
model.py CHANGED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+ import numpy as np
3
+ import torch
4
+ from transformers.pytorch_utils import nn
5
+ import torch.nn.functional as F
6
+ from transformers import BertModel, BertForSequenceClassification, PreTrainedModel
7
+ from transformers.modeling_outputs import SequenceClassifierOutput
8
+ from transformers import AutoModelForSequenceClassification
9
+ from transformers import PretrainedConfig
10
+
11
class BertABSAConfig(PretrainedConfig):
    """Configuration for ``BertCNNForSequenceClassification``.

    Holds the hyper-parameters of the CNN classification head placed on
    top of BERT: output size, convolution geometry, dropout, and how many
    of BERT's hidden-state layers the head consumes. Also registers the
    fixed fake/true label mapping used by this binary classifier.
    """

    model_type = "BertCNNForSequenceClassification"

    def __init__(self,
                 num_classes=2,          # size of the classifier output
                 embed_dim=768,          # BERT hidden size (Conv1d in_channels)
                 conv_out_channels=256,  # Conv1d output channels
                 conv_kernel_size=3,     # Conv1d kernel size
                 fc_hidden=128,          # hidden units for an FC layer
                 dropout_rate=0.1,       # dropout before the final FC
                 num_layers=12,          # number of BERT layers fed to the CNN
                 **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.fc_hidden = fc_hidden
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        # Fixed binary label vocabulary; label2id is the exact inverse.
        self.id2label = {0: "fake", 1: "true"}
        self.label2id = {name: idx for idx, name in self.id2label.items()}
39
+
40
+
41
+
42
class BertCNNForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT encoder with a Conv1d head over the per-layer [CLS] embeddings.

    For each of the first ``num_layers`` entries of BERT's ``hidden_states``,
    the [CLS] (position 0) embedding is taken, the per-layer vectors are
    stacked into a (batch, embed_dim, num_layers) tensor, and a Conv1d +
    ReLU + global max-pool + dropout + linear head produces the logits.
    """

    config_class = BertABSAConfig

    def __init__(self, config):
        """Build the model from a ``BertABSAConfig``.

        NOTE(review): the checkpoint name 'bert-base-uncased' is hard-coded
        rather than taken from the config, and ``config.fc_hidden`` is
        declared but never used by this head.
        """
        super().__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.conv_out_channels = config.conv_out_channels
        self.conv_kernel_size = config.conv_kernel_size
        self.dropout = nn.Dropout(config.dropout_rate)
        # output_hidden_states=True is required: forward() reads every
        # layer's hidden state, not just the final one.
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        self.conv1d = nn.Conv1d(in_channels=self.embed_dim,
                                out_channels=self.conv_out_channels,
                                kernel_size=self.conv_kernel_size)
        self.fc = nn.Linear(self.conv_out_channels, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Run a forward pass.

        Args:
            input_ids / attention_mask / token_type_ids: standard BERT
                inputs of shape (batch, seq_len).
            labels: optional class indices; when given, cross-entropy loss
                is computed.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape
            (batch, num_classes), plus BERT's hidden states and attentions.
        """
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]  # tuple of (batch, seq, embed_dim)

        # [:, 0] selects the [CLS] token, giving (batch, embed_dim) per layer;
        # stacking on dim=-1 yields (batch, embed_dim, num_layers) — exactly
        # the (N, C_in, L) layout Conv1d expects, so it is fed in directly.
        #
        # Fix 1: the original applied .squeeze() here, which also removed the
        # batch dimension when batch_size == 1, making single-sample inference
        # go through a different (accidental) layout than batched inference.
        # Fix 2: the original followed with .view(-1, num_layers, embed_dim)
        # and .permute(0, 2, 1); view() reinterprets memory rather than
        # transposing, which scrambled the layer/embedding axes. Both the
        # squeeze and the view/permute pair are removed.
        #
        # NOTE(review): range(num_layers) covers hidden_states[0..num_layers-1],
        # i.e. the embedding output plus the first num_layers-1 encoder layers,
        # skipping the final layer — confirm this selection is intentional.
        cls_stack = torch.stack(
            [hidden_states[layer_i][:, 0] for layer_i in range(self.num_layers)],
            dim=-1)

        conv_output = self.conv1d(cls_stack)
        conv_output = F.relu(conv_output)
        # Global max pooling over the layer axis -> (batch, conv_out_channels, 1).
        conv_output = F.max_pool1d(conv_output, kernel_size=conv_output.size(2))
        conv_output = conv_output.squeeze(-1)  # (batch, conv_out_channels)
        conv_output = self.dropout(conv_output)
        logits = self.fc(conv_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )