import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPool(nn.Module):
    """Learned attention pooling over the token dimension of the encoder's hidden states."""

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar score per token, used as the attention logit for that position.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask=None):
        # (batch, seq_len, hidden) -> (batch, seq_len)
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        if attention_mask is not None:
            # Keep padding tokens out of the softmax.
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=1)
        # Weighted sum over tokens: (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output


class MultiSampleDropout(nn.Module):
    """Applies dropout several times with independent masks and averages the results."""

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        # Average of num_samples independently dropped-out copies of x.
        return torch.mean(torch.stack([self.dropout(x) for _ in range(self.num_samples)]), dim=0)


class ImprovedBERTClass(nn.Module):
    """BERT encoder with attention pooling, multi-sample dropout, LayerNorm, and a linear classification head."""

    def __init__(self, num_classes=13):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.attention_pool = AttentionPool(768)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(768)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Pool token embeddings into one vector per sequence, keeping padded positions out of the attention.
        pooled_output = self.attention_pool(bert_output.last_hidden_state, attention_mask)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
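

# A minimal smoke test, included as a sketch rather than part of the model code: it
# assumes the 'bert-base-uncased' tokenizer/weights can be loaded, and the example
# sentences, batch size of 2, and max_length=64 below are illustrative assumptions.
if __name__ == "__main__":
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    model = ImprovedBERTClass(num_classes=13)
    model.eval()

    encoded = tokenizer(
        ["a short example sentence", "another, slightly longer example sentence"],
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(
            encoded["input_ids"],
            encoded["attention_mask"],
            encoded["token_type_ids"],
        )
    print(logits.shape)  # expected: torch.Size([2, 13])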