# UI_card_mapping / model.py
# DinoLiu's picture
# first commit
# 44dc531
# raw
# history blame
# No virus
# 1.7 kB
import transformers
import torch
from transformers import BertTokenizer, BertModel, BertConfig
import torch.nn as nn
class AttentionPool(nn.Module):
    """Pool a sequence of hidden states into one vector using learned attention.

    A single linear layer scores each sequence position. The scores are
    softmax-normalised over the sequence dimension and used as weights for a
    weighted sum of the hidden states.

    Args:
        hidden_size: Size of the last dimension of the hidden states.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Maps each hidden vector to a single unnormalised attention score.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the attention-weighted sum over the sequence dimension.

        Args:
            last_hidden_state: 3-D tensor ``(batch, seq_len, hidden_size)``
                (``torch.bmm`` below requires a batched 3-D input).
            attention_mask: Optional ``(batch, seq_len)`` tensor. Positions
                where it equals 0 receive zero weight. With ``None`` (the
                default), every position is weighted, as before.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf so that a row
            # whose mask is entirely zero yields uniform weights rather than NaN.
            attention_scores = attention_scores.masked_fill(
                attention_mask == 0, torch.finfo(attention_scores.dtype).min
            )
        # torch.softmax avoids depending on an ``F`` (torch.nn.functional)
        # alias, which the module never imported.
        attention_weights = torch.softmax(attention_scores, dim=1)
        pooled_output = torch.bmm(
            attention_weights.unsqueeze(1), last_hidden_state
        ).squeeze(1)
        return pooled_output
class MultiSampleDropout(nn.Module):
    """Average several independent dropout draws of the same input.

    Args:
        dropout: Drop probability passed to ``nn.Dropout``.
        num_samples: Number of dropout masks drawn and averaged per forward
            pass while dropout is active. Must be at least 1.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        if num_samples < 1:
            # Fail early: torch.stack on an empty list would otherwise raise a
            # less obvious error on the first forward pass.
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        """Return the element-wise mean of ``num_samples`` dropout draws of ``x``.

        When dropout is inactive (eval mode), every draw is just ``x``. In that
        case, and when only one sample is requested, a single dropout call is
        returned instead of stacking and averaging copies.
        """
        if not self.dropout.training or self.num_samples == 1:
            return self.dropout(x)
        samples = torch.stack([self.dropout(x) for _ in range(self.num_samples)])
        return samples.mean(dim=0)
class ImprovedBERTClass(nn.Module):
    """BERT encoder followed by attention pooling and a linear classifier head.

    Pipeline: BERT -> AttentionPool -> MultiSampleDropout -> LayerNorm -> Linear.

    Args:
        num_classes: Number of output logits.
        pretrained_model_name: Checkpoint passed to
            ``transformers.BertModel.from_pretrained``. Defaults to the
            previously hard-coded ``'bert-base-uncased'``.
    """

    def __init__(self, num_classes=13, pretrained_model_name='bert-base-uncased'):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_model_name)
        # Read the width from the loaded checkpoint instead of hard-coding 768,
        # so checkpoints with other hidden sizes also work.
        hidden_size = self.bert.config.hidden_size
        self.attention_pool = AttentionPool(hidden_size)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Compute classification logits.

        Args:
            input_ids: Token ids forwarded to BERT.
            attention_mask: Padding mask forwarded to BERT.
            token_type_ids: Optional segment ids forwarded to BERT. ``None``
                lets ``BertModel`` apply its default segment ids.

        Returns:
            Logits whose last dimension is ``num_classes``.
        """
        bert_output = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # NOTE(review): attention_mask is not passed to the pooling layer, so
        # padded positions still take part in the attention softmax.
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits