import logging
from abc import ABCMeta

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class BertLSTMConfig(PretrainedConfig):
    """Configuration for :class:`BertLSTMForSequenceClassification`.

    Args:
        num_classes: Number of outputs of the final linear classification layer.
        hidden_size: Width of the BERT token states fed into the LSTM. It must
            match the loaded BERT checkpoint (768 for ``bert-base-uncased``).
        num_layers: Number of stacked LSTM layers.
        hidden_dim_lstm: Hidden size of each LSTM layer.
        hidden_dropout_prob: Dropout probability, applied to the BERT states
            before the LSTM and to the selected LSTM state before the classifier.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`. When
            neither ``id2label`` nor ``label2id`` is supplied, the binary
            mapping ``{0: "fake", 1: "true"}`` is used.
    """

    model_type = "bertLSTMForSequenceClassification"

    def __init__(
        self,
        num_classes=2,
        hidden_size=768,
        num_layers=12,
        hidden_dim_lstm=256,
        hidden_dropout_prob=0.1,
        **kwargs,
    ):
        # Fall back to the fake/true mapping only when the caller (or a saved
        # config.json being reloaded) did not provide one. Assigning it after
        # super().__init__ would silently discard a caller-supplied mapping.
        if "id2label" not in kwargs and "label2id" not in kwargs:
            kwargs["id2label"] = {0: "fake", 1: "true"}
            kwargs["label2id"] = {"fake": 0, "true": 1}
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.hidden_dim_lstm = hidden_dim_lstm
        self.hidden_dropout_prob = hidden_dropout_prob


class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT encoder followed by a stacked LSTM and a linear classification head.

    The pretrained ``bert-base-uncased`` weights are loaded in ``__init__``.
    BERT's final-layer token states are run through an LSTM. The LSTM output at
    each sequence's last non-padding token is then classified.
    """

    config_class = BertLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        # Used as the LSTM input width, so it must equal BERT's hidden size.
        self.embed_dim = config.hidden_size
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # output_hidden_states=True so all per-layer states are available and
        # can be returned to the caller in the classifier output.
        self.bert = BertModel.from_pretrained(
            "bert-base-uncased",
            output_hidden_states=True,
            output_attentions=False,
        )
        logger.info("BERT Model Loaded")

        self.lstm = nn.LSTM(
            self.embed_dim,
            self.hidden_dim_lstm,
            batch_first=True,
            num_layers=self.num_layers,
        )
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    @staticmethod
    def _last_token_states(lstm_output, attention_mask):
        """Select, per sequence, the LSTM output at the last non-padding position.

        ``lstm_output`` is batch-first (the LSTM is built with batch_first=True).
        With no mask, the final time step is used for every sequence.
        """
        if attention_mask is None:
            return lstm_output[:, -1, :]
        batch_size, seq_len = lstm_output.size(0), lstm_output.size(1)
        positions = torch.arange(seq_len, device=lstm_output.device)
        # Zero out padded positions, then take the largest remaining index.
        # This handles both right- and left-padded batches. A row whose mask is
        # all zeros falls back to position 0 (argmax returns the first maximum).
        real_token = (attention_mask != 0).to(device=lstm_output.device, dtype=positions.dtype)
        last_idx = (real_token * positions).argmax(dim=1)
        batch_idx = torch.arange(batch_size, device=lstm_output.device)
        return lstm_output[batch_idx, last_idx]

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids, presumably ``(batch, seq_len)`` — forwarded to BERT.
            attention_mask: Non-zero for real tokens, 0 for padding. Forwarded to
                BERT and used to pick the LSTM output at each sequence's last real
                token. If ``None``, the final time step is used.
            token_type_ids: Optional segment ids forwarded to BERT.
            labels: Optional target class indices (presumably shape ``(batch,)``).
                When given, cross-entropy loss against the logits is returned.

        Returns:
            SequenceClassifierOutput with ``loss`` (``None`` without labels),
            ``logits`` of shape ``(batch, num_classes)``, and BERT's
            ``hidden_states`` and ``attentions``.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # The last entry of hidden_states is the output of BERT's final layer.
        last_hidden_states = self.dropout(bert_output["hidden_states"][-1])

        lstm_output, _ = self.lstm(last_hidden_states)

        # Reading a fixed [:, -1, :] would classify a padding position in
        # right-padded batches; select each sequence's last real token instead.
        final_states = self._last_token_states(lstm_output, attention_mask)
        logits = self.fc(self.dropout(final_states))

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )