first commit
Browse files- README.md +1 -0
- card_selection_modified.pt +3 -0
- config.json +5 -0
- model.py +43 -0
- requirements.txt +2 -0
README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This is a fine-tuned BERT model for card mapping in genUI.
|
card_selection_modified.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0aa88bb5edd01ab1336d836512991e6ece84cd38f38f09c55e626063b1286b0a
|
3 |
+
size 438061544
|
config.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "improved_bert",
|
3 |
+
"num_classes": 13,
|
4 |
+
"bert_model": "bert-base-uncased"
|
5 |
+
}
|
model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import transformers
|
3 |
+
import torch
|
4 |
+
from transformers import BertTokenizer, BertModel, BertConfig
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
class AttentionPool(nn.Module):
    """Pool a sequence of hidden states into one vector using learned attention.

    A single linear layer scores every sequence position; the scores are
    softmax-normalised over the sequence dimension and used as the weights of
    a weighted sum of the hidden states.
    """

    def __init__(self, hidden_size):
        """Create the per-position scoring layer.

        Args:
            hidden_size: Size of the last dimension of the states to pool.
        """
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the attention-weighted sum over the sequence dimension.

        Args:
            last_hidden_state: Tensor of shape (batch, seq_len, hidden_size).
            attention_mask: Optional (batch, seq_len) tensor. Positions equal
                to 0 are excluded from the pooling. When omitted, every
                position takes part.

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf so that a
            # fully masked row gives uniform weights rather than NaN.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(attention_mask == 0, fill_value)
        # torch.softmax is used because the alias `F` (torch.nn.functional)
        # is never imported in this file; `F.softmax` raised NameError.
        attention_weights = torch.softmax(attention_scores, dim=1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output
|
17 |
+
|
18 |
+
class MultiSampleDropout(nn.Module):
    """Average several independent dropout passes over the same input."""

    def __init__(self, dropout=0.5, num_samples=5):
        """Store the dropout layer and the number of passes to average.

        Args:
            dropout: Drop probability handed to ``nn.Dropout``.
            num_samples: How many dropout passes are averaged per call.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        """Apply dropout ``num_samples`` times to ``x`` and return the mean."""
        draws = []
        for _ in range(self.num_samples):
            draws.append(self.dropout(x))
        # Stack along a new leading axis, then average that axis away.
        stacked = torch.stack(draws)
        return stacked.mean(dim=0)
|
26 |
+
|
27 |
+
|
28 |
+
class ImprovedBERTClass(nn.Module):
    """BERT encoder followed by attention pooling and a linear classifier.

    Pipeline: BERT last hidden states -> attention pooling -> multi-sample
    dropout -> layer norm -> linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=13, bert_model='bert-base-uncased'):
        """Build the encoder and the classification head.

        Args:
            num_classes: Number of output logits.
            bert_model: Name or path passed to ``BertModel.from_pretrained``.
                Defaults to the checkpoint that was previously hard-coded.
        """
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(bert_model)
        # Read the width from the loaded encoder instead of hard-coding 768,
        # so the head always matches the chosen checkpoint.
        hidden_size = self.bert.config.hidden_size
        # Attribute names are left unchanged so state_dict keys stay the same.
        self.attention_pool = AttentionPool(hidden_size)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return classification logits for a batch of tokenised inputs.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            token_type_ids: Segment ids forwarded to the BERT encoder.

        Returns:
            The output of the final linear layer (``num_classes`` logits per
            example).
        """
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # NOTE: the pooling layer receives only the hidden states, so every
        # sequence position (padding included) takes part in the pooling.
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch==1.9.0
|
2 |
+
transformers==4.11.3
|