DinoLiu committed on
Commit 44dc531
1 Parent(s): 829982b

first commit

Files changed (5)
  1. README.md +1 -0
  2. card_selection_modified.pt +3 -0
  3. config.json +5 -0
  4. model.py +43 -0
  5. requirements.txt +2 -0
README.md ADDED
@@ -0,0 +1 @@
+ This is a fine-tuned BERT model for card mapping in genUI.
card_selection_modified.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0aa88bb5edd01ab1336d836512991e6ece84cd38f38f09c55e626063b1286b0a
+ size 438061544
config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "model_type": "improved_bert",
+   "num_classes": 13,
+   "bert_model": "bert-base-uncased"
+ }
model.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # F.softmax is used below; this import was missing
+ from transformers import BertModel
+
+ class AttentionPool(nn.Module):
+     # Pools token embeddings into a single vector with learned attention weights.
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.attention = nn.Linear(hidden_size, 1)
+
+     def forward(self, last_hidden_state):
+         attention_scores = self.attention(last_hidden_state).squeeze(-1)
+         attention_weights = F.softmax(attention_scores, dim=1)
+         pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
+         return pooled_output
+
+ class MultiSampleDropout(nn.Module):
+     # Averages num_samples independent dropout masks for a lower-variance output.
+     def __init__(self, dropout=0.5, num_samples=5):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.num_samples = num_samples
+
+     def forward(self, x):
+         return torch.mean(torch.stack([self.dropout(x) for _ in range(self.num_samples)]), dim=0)
+
+ class ImprovedBERTClass(nn.Module):
+     # bert-base-uncased encoder + attention pooling + multi-sample dropout + linear head.
+     def __init__(self, num_classes=13):
+         super().__init__()
+         self.bert = BertModel.from_pretrained('bert-base-uncased')
+         self.attention_pool = AttentionPool(768)
+         self.dropout = MultiSampleDropout()
+         self.norm = nn.LayerNorm(768)
+         self.classifier = nn.Linear(768, num_classes)
+
+     def forward(self, input_ids, attention_mask, token_type_ids):
+         bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+         pooled_output = self.attention_pool(bert_output.last_hidden_state)
+         pooled_output = self.dropout(pooled_output)
+         pooled_output = self.norm(pooled_output)
+         return self.classifier(pooled_output)
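A quick shape check of the pooling path, as a sketch outside the commit (the batch size and sequence length here are arbitrary):

import torch
from model import AttentionPool

pool = AttentionPool(hidden_size=768)
hidden = torch.randn(2, 8, 768)  # (batch, seq_len, hidden), the shape BERT emits
print(pool(hidden).shape)        # torch.Size([2, 768]): one pooled vector per example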
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==1.9.0
+ transformers==4.11.3
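
For reference, a minimal end-to-end usage sketch, not part of the commit. It assumes card_selection_modified.pt stores a plain state_dict for ImprovedBERTClass, that config.json sits in the working directory, and the input string is a placeholder:

import json
import torch
from transformers import BertTokenizer
from model import ImprovedBERTClass

# Read the shipped config for the class count and tokenizer name.
with open("config.json") as f:
    cfg = json.load(f)

model = ImprovedBERTClass(num_classes=cfg["num_classes"])
# Assumption: the checkpoint holds a state_dict. If it was instead saved as a
# whole module with torch.save(model, ...), load it via torch.load(...) directly.
state = torch.load("card_selection_modified.pt", map_location="cpu")
model.load_state_dict(state)
model.eval()

tokenizer = BertTokenizer.from_pretrained(cfg["bert_model"])
enc = tokenizer("example card description", return_tensors="pt")
with torch.no_grad():
    logits = model(enc["input_ids"], enc["attention_mask"], enc["token_type_ids"])
print(logits.argmax(dim=-1).item())  # predicted card class index, 0-12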