gyr66 commited on
Commit
df70323
Β·
1 Parent(s): 0cf85f0

Upload BertCrfForTokenClassification

Browse files
Files changed (2) hide show
  1. config.json +5 -2
  2. model.py +99 -0
config.json CHANGED
@@ -1,9 +1,12 @@
1
  {
2
- "_name_or_path": "./RoBERTa-ext-large-chinese-finetuned-ner",
3
  "architectures": [
4
  "BertCrfForTokenClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
 
 
 
7
  "bos_token_id": 0,
8
  "classifier_dropout": null,
9
  "directionality": "bidi",
@@ -89,7 +92,7 @@
89
  "pooler_type": "first_token_transform",
90
  "position_embedding_type": "absolute",
91
  "torch_dtype": "float32",
92
- "transformers_version": "4.35.2",
93
  "type_vocab_size": 2,
94
  "use_cache": true,
95
  "vocab_size": 21128
 
1
  {
2
+ "_name_or_path": "gyr66/RoBERTa-ext-large-crf-chinese-finetuned-ner",
3
  "architectures": [
4
  "BertCrfForTokenClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoModelForTokenClassification": "model.BertCrfForTokenClassification"
9
+ },
10
  "bos_token_id": 0,
11
  "classifier_dropout": null,
12
  "directionality": "bidi",
 
92
  "pooler_type": "first_token_transform",
93
  "position_embedding_type": "absolute",
94
  "torch_dtype": "float32",
95
+ "transformers_version": "4.36.2",
96
  "type_vocab_size": 2,
97
  "use_cache": true,
98
  "vocab_size": 21128
model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchcrf import CRF
4
+ from transformers import BertPreTrainedModel, BertModel, BertForTokenClassification
5
+ from transformers.modeling_outputs import TokenClassifierOutput
6
+
7
+
8
+ class BertCrfForTokenClassification(BertPreTrainedModel):
9
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
10
+
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+ self.num_labels = config.num_labels
14
+ self.config = config
15
+ self.bert = BertModel(config, add_pooling_layer=False)
16
+ self.dropout = nn.Dropout(
17
+ config.classifier_dropout
18
+ if config.classifier_dropout is not None
19
+ else config.hidden_dropout_prob
20
+ )
21
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
22
+ self.crf = CRF(config.num_labels, batch_first=True)
23
+ self.post_init()
24
+
25
+ def forward(
26
+ self,
27
+ input_ids=None,
28
+ attention_mask=None,
29
+ token_type_ids=None,
30
+ position_ids=None,
31
+ head_mask=None,
32
+ inputs_embeds=None,
33
+ labels=None,
34
+ output_attentions=None,
35
+ output_hidden_states=None,
36
+ return_dict=None,
37
+ ):
38
+ return_dict = (
39
+ return_dict if return_dict is not None else self.config.use_return_dict
40
+ )
41
+
42
+ outputs = self.bert(
43
+ input_ids,
44
+ attention_mask=attention_mask,
45
+ token_type_ids=token_type_ids,
46
+ position_ids=position_ids,
47
+ head_mask=head_mask,
48
+ inputs_embeds=inputs_embeds,
49
+ output_attentions=output_attentions,
50
+ output_hidden_states=output_hidden_states,
51
+ return_dict=return_dict,
52
+ )
53
+
54
+ sequence_output = outputs[0]
55
+
56
+ sequence_output = self.dropout(sequence_output)
57
+ logits = self.classifier(sequence_output)
58
+ dummy_logits = torch.zeros_like(logits).to(logits.device)
59
+
60
+ valid_lens = attention_mask.sum(dim=1) - 2
61
+ logits = logits[:, 1:]
62
+ labels_mask = torch.arange(logits.size(1)).to(
63
+ valid_lens.device
64
+ ) < valid_lens.unsqueeze(1)
65
+
66
+ seq_label_ids = self.crf.decode(logits, mask=labels_mask)
67
+
68
+ loss = None
69
+ if labels is not None:
70
+ labels = labels[:, 1:]
71
+ is_pad = labels == -100
72
+ labels.masked_fill_(is_pad, 0)
73
+ assert torch.eq(~is_pad, labels_mask).all().item(), "mask assertion failed "
74
+ loss = -self.crf(logits, labels, mask=labels_mask, reduction="token_mean")
75
+
76
+ padded_list = torch.nn.utils.rnn.pad_sequence(
77
+ [torch.tensor(lst) for lst in seq_label_ids],
78
+ batch_first=True,
79
+ padding_value=0,
80
+ )
81
+ padded_list = torch.nn.functional.pad(
82
+ padded_list, (0, logits.size(1) - padded_list.shape[1])
83
+ )
84
+ padded_list = torch.nn.functional.one_hot(
85
+ padded_list, num_classes=logits.size(2)
86
+ )
87
+ assert dummy_logits.size(1) == padded_list.size(1) + 1, "size assertion failed"
88
+ dummy_logits[:, 1:] = padded_list
89
+
90
+ if not return_dict:
91
+ output = (dummy_logits,) + outputs[2:]
92
+ return ((loss,) + output) if loss is not None else output
93
+
94
+ return TokenClassifierOutput(
95
+ loss=loss,
96
+ logits=dummy_logits,
97
+ hidden_states=outputs.hidden_states,
98
+ attentions=outputs.attentions,
99
+ )