"""BERT encoder with an extra self-attention pooling head for sequence classification."""

from abc import ABCMeta

import torch
import torch.nn.functional as F
from transformers import BertConfig
from transformers import BertModel, BertForSequenceClassification, PreTrainedModel
from transformers import PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.pytorch_utils import nn


class BertAttentionConfig(PretrainedConfig):
    """Configuration for :class:`BertAttentionForSequenceClassification`.

    Args:
        num_classes: Number of output classes of the classification head.
        hidden_size: Nominal embedding width. The model defined in this file
            sizes its head from the loaded BERT encoder instead, so this value
            is informational.
        fc_hidden: Stored on the config; not read by the model class defined
            in this file.
        num_layers: Stored on the config; not read by the model class defined
            in this file.
        dropout_rate: Stored on the config; not read by the model class
            defined in this file.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "bertAttentionForSequenceClassification"

    def __init__(self,
                 num_classes=2,
                 hidden_size=768,
                 fc_hidden=128,
                 num_layers=12,
                 dropout_rate=0.1,
                 **kwargs):
        # Remember whether the caller (or a reloaded config.json) supplied label
        # maps, so the defaults below do not silently overwrite them.
        supplied_id2label = kwargs.get("id2label")
        supplied_label2id = kwargs.get("label2id")

        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.fc_hidden = fc_hidden
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        if supplied_id2label is None:
            if num_classes == 2:
                # Historical default of this model: binary fake/true labels.
                self.id2label = {0: "fake", 1: "true"}
            else:
                # Keep the label map the same size as the classifier output.
                self.id2label = {i: f"LABEL_{i}" for i in range(num_classes)}
            self.label2id = {label: idx for idx, label in self.id2label.items()}
        elif supplied_label2id is None:
            # Only id2label was given: derive the inverse mapping from it.
            self.label2id = {label: int(idx) for idx, label in self.id2label.items()}


class BertAttentionForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """Sequence classifier: BERT -> scaled dot-product self-attention -> mean pool -> linear.

    The encoder is always loaded from the ``bert-base-uncased`` checkpoint when
    the model is constructed, which requires the checkpoint to be available
    locally or downloadable.
    """

    config_class = BertAttentionConfig

    def __init__(self, config):
        """Build the encoder and the linear classification head.

        Args:
            config: A :class:`BertAttentionConfig`; ``config.num_classes``
                sets the number of output logits.
        """
        super().__init__(config)
        self.num_classes = config.num_classes
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=True)
        print("BERT Model Loaded")
        # Size the head from the encoder that was actually loaded; relying on
        # config.hidden_size breaks as soon as it differs from the checkpoint.
        self.embed_dim = self.bert.config.hidden_size
        self.fc = nn.Linear(self.embed_dim, self.num_classes)

    def _attention_pool(self, hidden_states, attention_mask=None):
        """Self-attend over the token states and average them into one vector.

        Args:
            hidden_states: Encoder output, indexed as (batch, seq_len, hidden).
            attention_mask: Optional 2-D mask over (batch, seq_len) where
                non-zero marks real tokens (assumed Hugging Face convention).
                Masks that are ``None`` or not 2-D fall back to unmasked
                attention and a plain mean, as in the original implementation.

        Returns:
            Tensor indexed as (batch, hidden).
        """
        scores = torch.matmul(hidden_states, hidden_states.transpose(1, 2))
        scores = scores / (self.embed_dim ** 0.5)

        if attention_mask is None or attention_mask.dim() != 2:
            probs = F.softmax(scores, dim=-1)
            return torch.matmul(probs, hidden_states).mean(dim=1)

        keep = attention_mask.to(torch.bool)
        # Block padded positions as attention keys. A large finite negative is
        # used instead of -inf so a fully padded row yields a uniform softmax
        # rather than NaN.
        scores = scores.masked_fill(~keep.unsqueeze(1), torch.finfo(scores.dtype).min)
        probs = F.softmax(scores, dim=-1)
        attended = torch.matmul(probs, hidden_states)

        # Average over real tokens only so the result does not depend on how
        # much padding the batch happens to contain.
        weights = keep.unsqueeze(-1).to(attended.dtype)
        token_count = weights.sum(dim=1).clamp(min=1.0)
        return (attended * weights).sum(dim=1) / token_count

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Run the classifier.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Optional padding mask passed to the encoder and
                also applied to the extra attention/pooling step.
            token_type_ids: Optional segment ids passed to the encoder.
            labels: Optional targets for ``F.cross_entropy`` (class indices
                are the expected form); when given, a loss is returned.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (or ``None``),
            ``logits``, and the encoder's ``hidden_states`` and ``attentions``.
        """
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = bert_output.last_hidden_state

        pooled_output = self._attention_pool(hidden_states, attention_mask)
        logits = self.fc(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )