import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax in AttentionPool
from transformers import BertTokenizer, BertModel, BertConfig


class AttentionPool(nn.Module):
    """Learns a weighted average over all token embeddings instead of using only the [CLS] vector."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state):
        # (batch, seq_len, hidden) -> (batch, seq_len)
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=1)
        # Weighted sum over the sequence dimension -> (batch, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output


class MultiSampleDropout(nn.Module):
    """Applies dropout num_samples times and averages the results for a smoother regularization signal."""

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        return torch.mean(torch.stack([self.dropout(x) for _ in range(self.num_samples)]), dim=0)


class ImprovedBERTClass(nn.Module):
    """BERT encoder with attention pooling, multi-sample dropout, LayerNorm, and a linear classifier head."""

    def __init__(self, num_classes=13):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.attention_pool = AttentionPool(768)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(768)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
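

# Minimal usage sketch, assuming the 'bert-base-uncased' checkpoint used above.
# The sample sentence, max_length, and the __main__ guard are illustrative
# additions, not part of the original code.
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = ImprovedBERTClass(num_classes=13)
    model.eval()

    encoded = tokenizer(
        "An example sentence to classify.",
        padding='max_length',
        truncation=True,
        max_length=64,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(
            encoded['input_ids'],
            encoded['attention_mask'],
            encoded['token_type_ids'],
        )

    # One raw score per class; apply sigmoid for multi-label or softmax for
    # single-label classification before thresholding.
    print(logits.shape)  # torch.Size([1, 13])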