from abc import ABCMeta

import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

class BertLSTMConfig(PretrainedConfig):
    model_type = "bertLSTMForSequenceClassification"

    def __init__(self,
                 num_classes=2,
                 hidden_size=768,          # BERT hidden size, used as the LSTM input dimension
                 num_layers=12,            # Number of stacked LSTM layers
                 hidden_dim_lstm=256,      # Hidden dimension of the LSTM
                 hidden_dropout_prob=0.1,  # Dropout probability (matches the BERT parameter name)
                 **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.hidden_dim_lstm = hidden_dim_lstm
        self.hidden_dropout_prob = hidden_dropout_prob
        self.id2label = {
            0: "fake",
            1: "true",
        }
        self.label2id = {
            "fake": 0,
            "true": 1,
        }

class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    config_class = BertLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.hidden_size  # BERT hidden size is used as the embedding dimension
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Pretrained BERT encoder; hidden states are exposed for the LSTM head
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        # LSTM over the token embeddings from BERT's last layer
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm,
                            batch_first=True, num_layers=self.num_layers)
        # Classification head on top of the final LSTM state
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]
        # Extract the token embeddings from the last layer of BERT
        last_hidden_states = hidden_states[-1]
        # Apply dropout
        last_hidden_states = self.dropout(last_hidden_states)
        # Pass through the LSTM (default zero initial state)
        lstm_output, _ = self.lstm(last_hidden_states, None)
        # Take the output from the last time step
        lstm_output = lstm_output[:, -1, :]
        # Apply dropout
        lstm_output = self.dropout(lstm_output)
        # Linear layer for classification
        logits = self.fc(lstm_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
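

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# it assumes the bert-base-uncased tokenizer and the transformers Auto*
# registration API; the example sentence and max_length are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import AutoConfig, AutoModelForSequenceClassification, BertTokenizer

    # Register the custom classes so the Auto* factories can resolve this model_type.
    AutoConfig.register("bertLSTMForSequenceClassification", BertLSTMConfig)
    AutoModelForSequenceClassification.register(BertLSTMConfig, BertLSTMForSequenceClassification)

    config = BertLSTMConfig()
    model = BertLSTMForSequenceClassification(config)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("Example headline to classify.",
                       return_tensors="pt",
                       padding="max_length",
                       truncation=True,
                       max_length=128)

    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        token_type_ids=inputs["token_type_ids"])

    predicted_id = outputs.logits.argmax(dim=-1).item()
    print(config.id2label[predicted_id])  # "fake" or "true"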