from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.luke.modeling_luke import (
    EntityPredictionHead,
    LukeLMHead,
    LukeModel,
)
from transformers.utils import ModelOutput

from .configuration_ubke import UbkeConfig


@dataclass
class UbkeMaskedLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    mlm_loss: Optional[torch.FloatTensor] = None
    mep_loss: Optional[torch.FloatTensor] = None
    tep_loss: Optional[torch.FloatTensor] = None
    tcp_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    entity_logits: Optional[torch.FloatTensor] = None
    topic_entity_logits: torch.FloatTensor = None
    topic_category_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_last_hidden_state: torch.FloatTensor = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class UbkePreTrainedModel(PreTrainedModel):
    config_class = UbkeConfig
    base_model_prefix = "luke"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            if module.embedding_dim == 1:  # embedding for bias parameters
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class UbkeForMaskedLM(UbkePreTrainedModel):
    _tied_weights_keys = [
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    ]

    def __init__(self, config: UbkeConfig):
        super().__init__(config)

        self.luke = LukeModel(config)
        if self.config.normalize_entity_embeddings:
            # Re-create the entity embedding table with max_norm=1.0 so that entity
            # vectors are renormalized to unit length whenever they are looked up.
            self.luke.entity_embeddings.entity_embeddings = nn.Embedding(
                config.entity_vocab_size,
                config.entity_emb_size,
                padding_idx=0,
                max_norm=1.0,
            )

        self.lm_head = LukeLMHead(config)
        self.entity_predictions = EntityPredictionHead(config)

        self.loss_fn = nn.CrossEntropyLoss()

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        super().tie_weights()
        # Tie the entity prediction decoder to the entity embedding table.
        self._tie_or_clone_weights(
            self.entity_predictions.decoder,
            self.luke.entity_embeddings.entity_embeddings,
        )

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.LongTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        entity_labels: Optional[torch.LongTensor] = None,
        topic_entity_labels: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, UbkeMaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        loss = None

        # Masked language modeling loss over word tokens.
        mlm_loss = None
        logits = self.lm_head(outputs.last_hidden_state)
        if labels is not None:
            labels = labels.to(logits.device)
            mlm_loss = self.loss_fn(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
            if loss is None:
                loss = mlm_loss

        # Entity prediction loss (mep_loss) over the entity sequence, with
        # temperature-scaled logits.
        mep_loss = None
        entity_logits = None
        if outputs.entity_last_hidden_state is not None:
            entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
            if entity_labels is not None:
                mep_loss = self.loss_fn(
                    entity_logits.view(-1, self.config.entity_vocab_size)
                    / self.config.entity_temperature,
                    entity_labels.view(-1),
                )
                if loss is None:
                    loss = mep_loss
                else:
                    loss = loss + mep_loss

        # Topic entity prediction from the first ([CLS]) token; the last
        # num_category_entities entries of the entity vocabulary are treated as
        # category entities and scored separately.
        topic_entity_logits = self.entity_predictions(outputs.last_hidden_state[:, 0])
        topic_category_logits = None
        if self.config.num_category_entities > 0:
            topic_category_logits = topic_entity_logits[
                :, -self.config.num_category_entities :
            ]
            topic_entity_logits = topic_entity_logits[
                :, : -self.config.num_category_entities
            ]

        topic_category_labels = None
        if topic_entity_labels is not None and self.config.num_category_entities > 0:
            topic_category_labels = topic_entity_labels[
                :, -self.config.num_category_entities :
            ]
            topic_entity_labels = topic_entity_labels[
                :, : -self.config.num_category_entities
            ]

        # Topic entity prediction loss (tep_loss): multi-hot labels are normalized to a
        # probability distribution and used as soft targets for cross-entropy.
        tep_loss = None
        if topic_entity_labels is not None:
            num_topic_entity_labels = topic_entity_labels.sum(dim=1)
            if (num_topic_entity_labels > 0).any():
                topic_entity_labels = topic_entity_labels.to(
                    topic_entity_logits.dtype
                ) / num_topic_entity_labels.unsqueeze(-1)
                tep_loss = self.loss_fn(
                    topic_entity_logits[num_topic_entity_labels > 0]
                    / self.config.entity_temperature,
                    topic_entity_labels[num_topic_entity_labels > 0],
                )
                if loss is None:
                    loss = tep_loss
                else:
                    loss = loss + tep_loss

        # Topic category prediction loss (tcp_loss), computed the same way over the
        # category slice of the entity vocabulary.
        tcp_loss = None
        if topic_category_labels is not None:
            num_topic_category_labels = topic_category_labels.sum(dim=1)
            if (num_topic_category_labels > 0).any():
                topic_category_labels = topic_category_labels.to(
                    topic_category_logits.dtype
                ) / num_topic_category_labels.unsqueeze(-1)
                tcp_loss = self.loss_fn(
                    topic_category_logits[num_topic_category_labels > 0]
                    / self.config.entity_temperature,
                    topic_category_labels[num_topic_category_labels > 0],
                )
                if loss is None:
                    loss = tcp_loss
                else:
                    loss = loss + tcp_loss

        if not return_dict:
            return tuple(
                v
                for v in [
                    logits,
                    entity_logits,
                    topic_entity_logits,
                    topic_category_logits,
                    outputs.last_hidden_state,
                    outputs.entity_last_hidden_state,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return UbkeMaskedLMOutput(
            loss=loss,
            mlm_loss=mlm_loss,
            mep_loss=mep_loss,
            tep_loss=tep_loss,
            tcp_loss=tcp_loss,
            logits=logits,
            entity_logits=entity_logits,
            topic_entity_logits=topic_entity_logits,
            topic_category_logits=topic_category_logits,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            entity_last_hidden_state=outputs.entity_last_hidden_state,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )
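

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model API). It assumes
# that UbkeConfig forwards LukeConfig-style keyword arguments and accepts the
# UBKE-specific fields read above (normalize_entity_embeddings, entity_temperature,
# num_category_entities) as keyword arguments; the tiny sizes below are arbitrary.
# Check configuration_ubke.py for the actual constructor signature.
if __name__ == "__main__":
    config = UbkeConfig(
        vocab_size=100,
        entity_vocab_size=20,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        entity_emb_size=32,
        # UBKE-specific fields (assumed keyword arguments)
        normalize_entity_embeddings=True,
        entity_temperature=1.0,
        num_category_entities=4,
    )
    model = UbkeForMaskedLM(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    entity_ids = torch.randint(1, config.entity_vocab_size, (1, 2))
    # Each entity mention is padded to length 4; -1 marks unused token positions.
    entity_position_ids = torch.full((1, 2, 4), -1, dtype=torch.long)
    entity_position_ids[:, 0, 0] = 1
    entity_position_ids[:, 1, 0] = 3

    outputs = model(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        entity_ids=entity_ids,
        entity_attention_mask=torch.ones_like(entity_ids),
        entity_position_ids=entity_position_ids,
    )
    print(outputs.logits.shape)               # (1, 8, vocab_size)
    print(outputs.entity_logits.shape)        # (1, 2, entity_vocab_size)
    print(outputs.topic_entity_logits.shape)  # (1, entity_vocab_size - num_category_entities)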