File size: 1,668 Bytes
44dc531
 
 
 
02c198f
44dc531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a94233d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPool(nn.Module):
    """Collapse a sequence of hidden states into one vector with learned attention.

    A single linear layer assigns a scalar score to every position; the scores
    are softmax-normalised along the sequence axis and used as weights for a
    weighted sum of the hidden states.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Maps each hidden_size-dim position vector to one unnormalised score.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the attention-weighted sum over the sequence axis.

        Args:
            last_hidden_state: 3-D tensor ``(batch, seq_len, hidden_size)``
                (``torch.bmm`` below requires 3-D input).
            attention_mask: optional ``(batch, seq_len)`` tensor; positions
                equal to 0 receive zero weight. ``None`` (the default) attends
                to every position, matching the previous behaviour.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        if attention_mask is not None:
            # Exclude padding from the softmax. The dtype's lowest finite value
            # is used instead of -inf so a row with no valid tokens yields a
            # uniform distribution rather than NaN.
            attention_scores = attention_scores.masked_fill(
                attention_mask == 0, torch.finfo(attention_scores.dtype).min
            )
        attention_weights = F.softmax(attention_scores, dim=1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, hidden)
        return torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)


class MultiSampleDropout(nn.Module):
    """Average several independently-dropped copies of the input.

    In training mode ``num_samples`` dropout masks are drawn and the results
    averaged element-wise. In eval mode dropout is the identity, so the input
    is returned directly instead of averaging identical copies.

    NOTE(review): multi-sample dropout is often described as averaging the
    classifier outputs/losses over masks; here the *features* are averaged
    before the head — confirm this is the intended design.
    """

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        if num_samples < 1:
            # Otherwise forward() would fail on torch.stack of an empty list.
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        """Return the mean of ``num_samples`` dropout applications to ``x``."""
        if not self.dropout.training:
            # Dropout is a no-op outside training; skip the redundant work.
            return x
        samples = torch.stack([self.dropout(x) for _ in range(self.num_samples)])
        return samples.mean(dim=0)


class ImprovedBERTClass(nn.Module):
    """BERT encoder + attention pooling + multi-sample dropout + linear head.

    Args:
        num_classes: number of output logits produced by the classifier.
        model_name: checkpoint passed to ``BertModel.from_pretrained``. The
            pooling/norm/classifier widths are taken from its config.
    """

    def __init__(self, num_classes=13, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(model_name)
        # Size the head from the encoder instead of hard-coding 768, so other
        # BERT checkpoints work without edits.
        hidden_size = self.bert.config.hidden_size
        self.attention_pool = AttentionPool(hidden_size)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return unnormalised class scores of shape ``(batch, num_classes)``.

        ``attention_mask`` and ``token_type_ids`` are forwarded to BERT. The
        mask is also applied in attention pooling so padding does not
        contribute to the pooled representation.
        """
        bert_output = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        pooled_output = self.attention_pool(bert_output.last_hidden_state, attention_mask)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits