import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return, for each task, the number of labels in its label map.

    Example (illustrative): ``get_info({"ner": ["O", "B-PER", "I-PER"]})`` returns ``{"ner": 3}``.
    """
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ModelForSequenceAndTokenClassification(PreTrainedModel):
    """
    A model with a token classification head and, optionally, a sequence classification head
    on top of a pretrained encoder. Weights initialization and the interface for downloading
    and loading pretrained models are inherited from `PreTrainedModel`.
    """

    config_class = AutoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, num_sequence_labels, num_token_labels, do_classif=False):
        super().__init__(config)
        self.num_token_labels = num_token_labels
        self.num_sequence_labels = num_sequence_labels
        self.config = config
        self.do_classif = do_classif

        self.bert = AutoModel.from_config(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # For token classification
        self.token_classifier = nn.Linear(config.hidden_size, self.num_token_labels)

        if do_classif:
            # For the entire sequence classification
            self.sequence_classifier = nn.Linear(
                config.hidden_size, self.num_sequence_labels
            )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[torch.Tensor] = None,
        sequence_labels: Optional[torch.Tensor] = None,
        offset_mapping: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Union[Tuple[torch.Tensor], SequenceClassifierOutput],
        Union[Tuple[torch.Tensor], TokenClassifierOutput],
    ]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # For token classification
        token_output = outputs[0]
        token_output = self.dropout(token_output)
        token_logits = self.token_classifier(token_output)

        if self.do_classif:
            # For the entire sequence classification
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            sequence_logits = self.sequence_classifier(pooled_output)

        # Compute the loss; when both heads are active it is the sum of the token loss
        # and the sequence loss
        loss = None
        if token_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss_tokens = loss_fct(
                token_logits.view(-1, self.num_token_labels), token_labels.view(-1)
            )

            if self.do_classif:
                # Infer the problem type if it has not been set on the config, following
                # the usual Hugging Face convention
                if self.config.problem_type is None:
                    if self.num_sequence_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_sequence_labels > 1 and (
                        sequence_labels.dtype == torch.long
                        or sequence_labels.dtype == torch.int
                    ):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_sequence_labels == 1:
                        loss_sequence = loss_fct(
                            sequence_logits.squeeze(), sequence_labels.squeeze()
                        )
                    else:
                        loss_sequence = loss_fct(sequence_logits, sequence_labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss_sequence = loss_fct(
                        sequence_logits.view(-1, self.num_sequence_labels),
                        sequence_labels.view(-1),
                    )
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss_sequence = loss_fct(sequence_logits, sequence_labels)

                loss = loss_tokens + loss_sequence
            else:
                loss = loss_tokens

        if not return_dict:
            if self.do_classif:
                output = (
                    sequence_logits,
                    token_logits,
                ) + outputs[2:]
                return ((loss,) + output) if loss is not None else output
            else:
                output = (token_logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

        if self.do_classif:
            return SequenceClassifierOutput(
                loss=loss,
                logits=sequence_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ), TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            return TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
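

# Minimal usage sketch (illustrative, not part of the module's documented API): it shows how
# this model might be instantiated from a Hugging Face config and run on a dummy input.
# The "bert-base-cased" checkpoint name and the label counts are assumptions chosen for the
# example; loading the config/tokenizer requires the checkpoint to be cached or downloadable.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("bert-base-cased")
    model = ModelForSequenceAndTokenClassification(
        config,
        num_sequence_labels=2,  # assumed number of sequence-level classes
        num_token_labels=5,  # assumed number of token-level classes
        do_classif=True,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    encoded = tokenizer("A short example sentence.", return_tensors="pt")

    with torch.no_grad():
        # With do_classif=True and return_dict enabled, forward returns a pair of outputs
        sequence_output, token_output = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    print(sequence_output.logits.shape)  # (1, num_sequence_labels)
    print(token_output.logits.shape)  # (1, sequence_length, num_token_labels)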