File size: 9,574 Bytes
778b7b6 28698b8 778b7b6 28698b8 778b7b6 28698b8 778b7b6 28698b8 778b7b6 28698b8 778b7b6 28698b8 778b7b6 28698b8 778b7b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.luke.modeling_luke import (
EntityPredictionHead,
LukeLMHead,
LukeModel,
)
from transformers.utils import ModelOutput
from .configuration_ubke import UbkeConfig
@dataclass
class UbkeMaskedLMOutput(ModelOutput):
    """Outputs of :class:`UbkeForMaskedLM`.

    Every field defaults to ``None``. As with other ``ModelOutput`` subclasses,
    fields left as ``None`` are skipped when the output is converted to a
    tuple or indexed by position, so the field order below matters.
    """

    # Sum of whichever component losses below were computed.
    loss: Optional[torch.FloatTensor] = None
    # Cross-entropy of the word-vocabulary logits against ``labels``.
    mlm_loss: Optional[torch.FloatTensor] = None
    # Cross-entropy of the temperature-scaled entity logits against ``entity_labels``.
    mep_loss: Optional[torch.FloatTensor] = None
    # Soft-label cross-entropy of ``topic_entity_logits`` against the topic entity labels.
    tep_loss: Optional[torch.FloatTensor] = None
    # Soft-label cross-entropy of ``topic_category_logits`` against the topic category labels.
    tcp_loss: Optional[torch.FloatTensor] = None
    # Word-vocabulary prediction logits for every token.
    logits: Optional[torch.FloatTensor] = None
    # Entity-vocabulary prediction logits for every entity position
    # (None when the encoder produced no entity hidden states).
    entity_logits: Optional[torch.FloatTensor] = None
    # Entity logits predicted from the first token's final hidden state,
    # with the category entities removed when ``num_category_entities > 0``.
    topic_entity_logits: Optional[torch.FloatTensor] = None
    # The last ``num_category_entities`` columns of the topic prediction
    # (None when ``num_category_entities`` is 0).
    topic_category_logits: Optional[torch.FloatTensor] = None
    # The remaining fields are passed through from the LUKE encoder output.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_last_hidden_state: Optional[torch.FloatTensor] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class UbkePreTrainedModel(PreTrainedModel):
    """Base class connecting UBKE models to the ``PreTrainedModel`` machinery.

    Declares the config class, the attribute name of the base LUKE encoder
    (``luke``), gradient-checkpointing support, the module classes that must
    not be split across devices, and the weight-initialization scheme.
    """

    config_class = UbkeConfig
    base_model_prefix = "luke"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]

    def _init_weights(self, module: nn.Module):
        """Initialize the parameters of a single submodule in place.

        * ``nn.Linear``: weight ~ N(0, initializer_range); bias zeroed.
        * ``nn.Embedding``: zeroed when ``embedding_dim == 1``, otherwise
          ~ N(0, initializer_range); the padding row, if any, is zeroed.
        * ``nn.LayerNorm``: weight set to 1; bias zeroed.

        Other module types are left untouched.

        Args:
            module: The submodule to initialize.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            if module.embedding_dim == 1:  # embedding for bias parameters
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False (or bias=False)
            # has no weight/bias tensor; skip those instead of crashing.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class UbkeForMaskedLM(UbkePreTrainedModel):
    """LUKE encoder with word, entity and topic prediction heads.

    On top of :class:`LukeModel` this model produces:

    * word-vocabulary logits for every token (``lm_head``),
    * entity-vocabulary logits for every entity position
      (``entity_predictions``),
    * topic entity / topic category logits, obtained by applying
      ``entity_predictions`` to the first token's final hidden state and
      splitting off the last ``config.num_category_entities`` columns.

    Each loss is computed only when its labels are supplied. ``loss`` is the
    sum of the losses that were computed.
    """

    _tied_weights_keys = [
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    ]

    def __init__(self, config: UbkeConfig):
        super().__init__(config)
        self.luke = LukeModel(config)
        if self.config.normalize_entity_embeddings:
            # Replace the entity embedding table with one whose looked-up rows
            # are renormalized to an L2 norm of at most 1.0 (nn.Embedding max_norm).
            self.luke.entity_embeddings.entity_embeddings = nn.Embedding(
                config.entity_vocab_size,
                config.entity_emb_size,
                padding_idx=0,
                max_norm=1.0,
            )
        self.lm_head = LukeLMHead(config)
        self.entity_predictions = EntityPredictionHead(config)
        self.loss_fn = nn.CrossEntropyLoss()
        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """Tie output decoders to their input embedding tables.

        ``super().tie_weights()`` handles the word-level head returned by
        ``get_output_embeddings``. The entity prediction decoder is then
        tied (or cloned) to the LUKE entity embedding table.
        """
        super().tie_weights()
        self._tie_or_clone_weights(
            self.entity_predictions.decoder,
            self.luke.entity_embeddings.entity_embeddings,
        )

    def get_output_embeddings(self) -> nn.Module:
        """Return the word-vocabulary decoder of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the word-vocabulary decoder of the LM head."""
        self.lm_head.decoder = new_embeddings

    def _soft_label_loss(
        self,
        logits: Optional[torch.Tensor],
        labels: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Cross-entropy against row-normalized multi-label targets.

        Each row of ``labels`` is divided by its sum and used as a target
        probability distribution for ``logits / config.entity_temperature``.
        Rows whose labels sum to zero or less are excluded.

        Args:
            logits: Prediction logits; indexed with the same row mask as
                ``labels``.
            labels: Per-row label weights (presumably multi-hot), or None.

        Returns:
            The mean loss over the kept rows, or None when ``labels`` is None
            or no row is kept.
        """
        if labels is None:
            return None
        labels = labels.to(logits.device)
        row_totals = labels.sum(dim=1)
        keep = row_totals > 0
        if not keep.any():
            return None
        # Normalize only the kept rows, so all-zero rows never yield 0/0.
        targets = labels[keep].to(logits.dtype) / row_totals[keep].unsqueeze(-1)
        return self.loss_fn(logits[keep] / self.config.entity_temperature, targets)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.LongTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        entity_labels: Optional[torch.LongTensor] = None,
        topic_entity_labels: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, UbkeMaskedLMOutput]:
        """Run the encoder and all prediction heads, computing requested losses.

        Encoder inputs are forwarded unchanged to :class:`LukeModel`.

        Args:
            labels: Word-token targets for the MLM loss. Positions equal to
                -100 (``nn.CrossEntropyLoss``'s default ``ignore_index``)
                are ignored.
            entity_labels: Entity targets for the masked-entity loss; only
                used when entity inputs produced entity hidden states.
            topic_entity_labels: Per-example label weights over the entity
                vocabulary (presumably ``(batch, entity_vocab_size)``
                multi-hot — TODO confirm). When
                ``config.num_category_entities > 0``, its last columns feed
                the topic category loss and the rest the topic entity loss.
            return_dict: Return a :class:`UbkeMaskedLMOutput` instead of a
                tuple. Defaults to ``config.use_return_dict``.

        Returns:
            A :class:`UbkeMaskedLMOutput`, or a tuple of its non-None values
            with the losses first.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Word-level masked language modeling.
        logits = self.lm_head(outputs.last_hidden_state)
        mlm_loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            mlm_loss = self.loss_fn(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        # Entity-level prediction; only possible when entity inputs were given.
        entity_logits = None
        mep_loss = None
        if outputs.entity_last_hidden_state is not None:
            entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
            if entity_labels is not None:
                entity_labels = entity_labels.to(entity_logits.device)
                mep_loss = self.loss_fn(
                    entity_logits.view(-1, self.config.entity_vocab_size)
                    / self.config.entity_temperature,
                    entity_labels.view(-1),
                )

        # Topic prediction from the first token's final hidden state. The last
        # ``num_category_entities`` columns are split off as categories.
        topic_entity_logits = self.entity_predictions(outputs.last_hidden_state[:, 0])
        topic_category_logits = None
        topic_category_labels = None
        num_categories = self.config.num_category_entities
        if num_categories > 0:
            topic_category_logits = topic_entity_logits[:, -num_categories:]
            topic_entity_logits = topic_entity_logits[:, :-num_categories]
            if topic_entity_labels is not None:
                topic_category_labels = topic_entity_labels[:, -num_categories:]
                topic_entity_labels = topic_entity_labels[:, :-num_categories]

        tep_loss = self._soft_label_loss(topic_entity_logits, topic_entity_labels)
        tcp_loss = self._soft_label_loss(topic_category_logits, topic_category_labels)

        # Sum the computed losses in a fixed order (mlm, mep, tep, tcp).
        loss = None
        for component in (mlm_loss, mep_loss, tep_loss, tcp_loss):
            if component is not None:
                loss = component if loss is None else loss + component

        if not return_dict:
            # Losses come first (matching UbkeMaskedLMOutput's field order) so
            # that outputs[0] is the total loss whenever labels were supplied.
            # NOTE(review): the hidden-state entries below keep their historical
            # order, which differs from the dataclass (entity_last_hidden_state
            # precedes hidden_states here).
            return tuple(
                v
                for v in [
                    loss,
                    mlm_loss,
                    mep_loss,
                    tep_loss,
                    tcp_loss,
                    logits,
                    entity_logits,
                    topic_entity_logits,
                    topic_category_logits,
                    outputs.last_hidden_state,
                    outputs.entity_last_hidden_state,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return UbkeMaskedLMOutput(
            loss=loss,
            mlm_loss=mlm_loss,
            mep_loss=mep_loss,
            tep_loss=tep_loss,
            tcp_loss=tcp_loss,
            logits=logits,
            entity_logits=entity_logits,
            topic_entity_logits=topic_entity_logits,
            topic_category_logits=topic_category_logits,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            entity_last_hidden_state=outputs.entity_last_hidden_state,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )
|