from transformers.file_utils import cached_path, hf_bucket_url
from importlib.machinery import SourceFileLoader
import os
from transformers import EncoderDecoderModel, AutoConfig, AutoModel, EncoderDecoderConfig, RobertaForCausalLM, \
    RobertaModel
from transformers.modeling_utils import PreTrainedModel, logging
import torch
from torch.nn import CrossEntropyLoss, Parameter
from transformers.modeling_outputs import Seq2SeqLMOutput, CausalLMOutputWithCrossAttentions, \
    ModelOutput
from attentions import ScaledDotProductAttention, MultiHeadAttention
from collections import namedtuple
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
import random
from model_config_handling import EncoderDecoderSpokenNormConfig, DecoderSpokenNormConfig, PretrainedConfig

cache_dir = './cache'
model_name = 'nguyenvulebinh/envibert'
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)
logger = logging.get_logger(__name__)


@dataclass
class SpokenNormOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    logits_spoken_tagging: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


def collect_spoken_phrases_features(encoder_hidden_states, word_src_lengths, spoken_label):
    """Gather the encoder features of every tagged spoken phrase and pad them to a common length."""
    list_features = []
    list_features_mask = []
    max_length = word_src_lengths.max()
    feature_pad = torch.zeros_like(encoder_hidden_states[0, :1, :])
    for hidden_state, word_length, list_idx in zip(encoder_hidden_states, word_src_lengths, spoken_label):
        for idx in list_idx:
            if idx > 0:
                start = sum(word_length[:idx])
                end = start + word_length[idx]
                remain_length = max_length - word_length[idx]
                list_features_mask.append(torch.cat([torch.ones_like(spoken_label[0, 0]).expand(word_length[idx]),
                                                     torch.zeros_like(spoken_label[0, 0].expand(remain_length))]).unsqueeze(0))
                spoken_phrases_feature = hidden_state[start: end]
                list_features.append(torch.cat([spoken_phrases_feature,
                                                feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0))
    return torch.cat(list_features), torch.cat(list_features_mask)


def collect_spoken_phrases_labels(decoder_input_ids, labels, labels_bias, word_tgt_lengths, spoken_idx):
    """Align decoder inputs, labels and bias labels with the spoken phrases selected by ``spoken_idx``."""
    list_decoder_input_ids = []
    list_labels = []
    list_labels_bias = []
    max_length = word_tgt_lengths.max()
    # special token ids: bos=0, pad=1, eos=2; -100 is the ignore index for the loss
    init_decoder_ids = torch.tensor([0], device=labels.device, dtype=labels.dtype)
    pad_decoder_ids = torch.tensor([1], device=labels.device, dtype=labels.dtype)
    eos_decoder_ids = torch.tensor([2], device=labels.device, dtype=labels.dtype)
    none_labels_bias = torch.tensor([0], device=labels.device, dtype=labels.dtype)
    ignore_labels_bias = torch.tensor([-100], device=labels.device, dtype=labels.dtype)
    for decoder_inputs, decoder_label, decoder_label_bias, word_length, list_idx in zip(decoder_input_ids, labels,
                                                                                        labels_bias, word_tgt_lengths,
                                                                                        spoken_idx):
        for idx in list_idx:
            if idx > 0:
                start = sum(word_length[:idx - 1])
                end = start + word_length[idx - 1]
                remain_length = max_length - word_length[idx - 1]
                remain_decoder_input_ids = max_length - len(decoder_inputs[start + 1:end + 1])
                list_decoder_input_ids.append(torch.cat([init_decoder_ids,
                                                         decoder_inputs[start + 1:end + 1],
                                                         pad_decoder_ids.expand(remain_decoder_input_ids)]).unsqueeze(0))
                list_labels.append(torch.cat([decoder_label[start:end],
                                              eos_decoder_ids,
                                              ignore_labels_bias.expand(remain_length)]).unsqueeze(0))
                list_labels_bias.append(torch.cat([decoder_label_bias[start:end],
                                                   none_labels_bias,
                                                   ignore_labels_bias.expand(remain_length)]).unsqueeze(0))
    decoder_input_ids = torch.cat(list_decoder_input_ids)
    labels = torch.cat(list_labels)
    labels_bias = torch.cat(list_labels_bias)
    return decoder_input_ids, labels, labels_bias


class EncoderDecoderSpokenNorm(EncoderDecoderModel):
    config_class = EncoderDecoderSpokenNormConfig

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
                    "it has to be equal to the encoder's `hidden_size`. "
                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
                )

        # initialize with config
        super().__init__(config)

        if encoder is None:
            from transformers.models.auto.modeling_auto import AutoModel
            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            decoder = DecoderSpokenNorm._from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # encoder outputs might need to be projected to different dimension for decoder
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = torch.nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )
        # spoken tagging head; labels: 0 -> "O", 1 -> "B", 2 -> "I"
        self.dropout = torch.nn.Dropout(0.3)
        self.spoken_tagging_classifier = torch.nn.Linear(config.encoder.hidden_size, 3)

        # tie encoder, decoder weights if config set accordingly
        self.tie_weights()

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs
    ) -> PreTrainedModel:
        kwargs_encoder = {
            argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )
            if "config" not in kwargs_encoder:
                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
                        "from a decoder model. Cross-attention and causal mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False
                kwargs_encoder["config"] = encoder_config
            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )
            if "config" not in kwargs_decoder:
                decoder_config = DecoderSpokenNormConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                        f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                        f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                        "cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True
                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )
" f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = DecoderSpokenNorm.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) # instantiate config with corresponding kwargs config = EncoderDecoderSpokenNormConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) return cls(encoder=encoder, decoder=decoder, config=config) def get_encoder(self): def forward(input_ids=None, attention_mask=None, bias_input_ids=None, bias_attention_mask=None, return_dict=True, output_attentions=False, output_hidden_states=False, word_src_lengths=None, spoken_idx=None, **kwargs_encoder): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs_encoder, ) encoder_outputs.word_src_lengths = word_src_lengths encoder_outputs.spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0])) if spoken_idx is not None: encoder_outputs.spoken_idx = spoken_idx else: pass encoder_bias_outputs = self.forward_bias(bias_input_ids, bias_attention_mask, output_attentions=output_attentions, return_dict=return_dict, output_hidden_states=output_hidden_states, **kwargs_encoder) # d = { # "encoder_bias_outputs": None, # "bias_attention_mask": None, # "last_hidden_state": None, # "pooler_output": None # # } # encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values()) # if bias_input_ids is not None: # encoder_bias_outputs = self.encoder( # input_ids=bias_input_ids, # attention_mask=bias_attention_mask, # inputs_embeds=None, # output_attentions=output_attentions, # output_hidden_states=output_hidden_states, # return_dict=return_dict, # **kwargs_encoder, # ) # encoder_bias_outputs.bias_attention_mask = bias_attention_mask return encoder_outputs, encoder_bias_outputs return forward def forward_bias(self, bias_input_ids, bias_attention_mask, output_attentions=False, return_dict=True, output_hidden_states=False, **kwargs_encoder): d = { "encoder_bias_outputs": None, "bias_attention_mask": None, "last_hidden_state": None, "pooler_output": None } encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values()) if bias_input_ids is not None: encoder_bias_outputs = self.encoder( input_ids=bias_input_ids, attention_mask=bias_attention_mask, inputs_embeds=None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs_encoder, ) encoder_bias_outputs.bias_attention_mask = bias_attention_mask return encoder_bias_outputs def _prepare_encoder_decoder_kwargs_for_generation( self, input_ids: torch.LongTensor, model_kwargs, model_input_name ) -> Dict[str, Any]: if "encoder_outputs" not in model_kwargs: # retrieve encoder hidden states encoder = self.get_encoder() encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) } encoder_outputs, encoder_bias_outputs = encoder(input_ids, return_dict=True, **encoder_kwargs) model_kwargs["encoder_outputs"]: ModelOutput = encoder_outputs model_kwargs["encoder_bias_outputs"]: ModelOutput = encoder_bias_outputs return model_kwargs def 
    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            # one decoder sequence is started per detected spoken phrase, not per batch item
            num_spoken_phrases = (model_kwargs['encoder_outputs'].spoken_idx >= 0).view(-1).sum()
            return torch.ones((num_spoken_phrases, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "encoder_bias_outputs": kwargs["encoder_bias_outputs"],
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        bias_input_ids=None,
        bias_attention_mask=None,
        labels_bias=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        encoder_bias_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        spoken_label=None,
        word_src_lengths=None,
        word_tgt_lengths=None,
        spoken_idx=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        inputs_length=None,
        outputs=None,
        outputs_length=None,
        text=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        spoken_tagging_output = None
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))

        if encoder_bias_outputs is None:
            encoder_bias_outputs = self.encoder(
                input_ids=bias_input_ids,
                attention_mask=bias_attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            encoder_bias_outputs.bias_attention_mask = bias_attention_mask

        encoder_hidden_states = encoder_outputs[0]

        # When spoken_idx is given (training), gather per-phrase encoder features and the matching
        # decoder inputs/labels. At inference time spoken_idx could instead be derived from
        # spoken_tagging_output; that path is not implemented here.
        if spoken_idx is not None:
            encoder_hidden_states, attention_mask = collect_spoken_phrases_features(encoder_hidden_states,
                                                                                    word_src_lengths,
                                                                                    spoken_idx)
            decoder_input_ids, labels, labels_bias = collect_spoken_phrases_labels(decoder_input_ids,
                                                                                   labels, labels_bias,
                                                                                   word_tgt_lengths,
                                                                                   spoken_idx)

        # optionally project encoder_hidden_states
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_bias_pooling=encoder_bias_outputs.pooler_output,
            encoder_bias_hidden_states=encoder_bias_outputs.last_hidden_state,
            bias_attention_mask=encoder_bias_outputs.bias_attention_mask,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels_bias=labels_bias,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
            # add the bias-ranking loss computed inside the decoder
            loss = loss + decoder_outputs.loss

        if spoken_label is not None:
            loss_fct = CrossEntropyLoss()
            spoken_tagging_loss = loss_fct(spoken_tagging_output.reshape(-1, 3), spoken_label.view(-1))
            loss = loss + spoken_tagging_loss

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return SpokenNormOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            logits_spoken_tagging=spoken_tagging_output,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class DecoderSpokenNorm(RobertaForCausalLM):
    config_class = DecoderSpokenNormConfig

    def __init__(self, config):
        super().__init__(config)
        self.dense_query_copy = torch.nn.Linear(config.hidden_size, config.hidden_size)
        # learnable "no bias entry" slot, prepended to the list of bias candidates
        self.mem_no_entry = Parameter(torch.randn(config.hidden_size).unsqueeze(0))
        self.bias_attention_layer = MultiHeadAttention(config.hidden_size)
        self.copy_attention_layer = MultiHeadAttention(config.hidden_size)

    def forward_bias_attention(self, query, values, values_mask):
        """
        :param query: batch * output_steps * hidden_state
        :param values: batch * output_steps * max_bias_steps * hidden_state
        :param values_mask: batch * output_steps * max_bias_steps
        :return: batch * output_steps * hidden_state
        """
        batch, output_steps, hidden_state = query.size()
        _, _, max_bias_steps, _ = values.size()
        query = query.view(batch * output_steps, 1, hidden_state)
        values = values.view(-1, max_bias_steps, hidden_state)
        values_mask = 1 - values_mask.view(-1, max_bias_steps)
        result_attention, attention_score = self.bias_attention_layer(query=query, key=values, value=values,
                                                                      mask=values_mask.bool())
        result_attention = result_attention.squeeze(1).view(batch, output_steps, hidden_state)
        return result_attention
    def forward_copy_attention(self, query, values, values_mask):
        """
        :param query: batch * output_steps * hidden_state
        :param values: batch * max_encoder_steps * hidden_state
        :param values_mask: batch * max_encoder_steps
        :return: batch * output_steps * hidden_state
        """
        dot_attn_score = torch.bmm(query, values.transpose(2, 1))
        attn_mask = (1 - values_mask.clone().unsqueeze(1)).bool()
        dot_attn_score.masked_fill_(attn_mask, -float('inf'))
        dot_attn_score = torch.softmax(dot_attn_score, dim=-1)
        result_attention = torch.bmm(dot_attn_score, values)
        return result_attention

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_bias_pooling=None,
        encoder_bias_hidden_states=None,
        bias_attention_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        labels_bias=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        # attention with input encoded
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Query for bias
        sequence_output = outputs[0]
        bias_indicate_output = None

        # output copy attention over the encoder states
        query_copy = torch.relu(self.dense_query_copy(sequence_output))
        sequence_atten_copy_output = self.forward_copy_attention(query_copy, encoder_hidden_states,
                                                                 encoder_attention_mask)

        if encoder_bias_pooling is not None:
            # Make bias features: prepend the learnable "no entry" slot
            encoder_bias_pooling = torch.cat([self.mem_no_entry, encoder_bias_pooling], dim=0)
            mem_no_entry_feature = torch.zeros_like(encoder_bias_hidden_states[0]).unsqueeze(0)
            mem_no_entry_mask = torch.ones_like(bias_attention_mask[0]).unsqueeze(0)
            encoder_bias_hidden_states = torch.cat([mem_no_entry_feature, encoder_bias_hidden_states], dim=0)
            bias_attention_mask = torch.cat([mem_no_entry_mask, bias_attention_mask], dim=0)

            # Compute ranking score
            b, s, h = sequence_output.size()
            bias_ranking_score = sequence_output.view(b * s, h).mm(encoder_bias_pooling.T)
            bias_ranking_score = bias_ranking_score.view(b, s, encoder_bias_pooling.size(0))

            # teacher forcing with the bias label
            if not self.training:
                bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)
            else:
                if random.random() < 0.5:
                    bias_indicate_output = labels_bias.clone()
                    bias_indicate_output[torch.where(bias_indicate_output < 0)] = 0
                else:
                    bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)

            # Bias encoder hidden state
            _, max_len, _ = encoder_bias_hidden_states.size()
            bias_encoder_hidden_states = torch.index_select(input=encoder_bias_hidden_states, dim=0,
                                                            index=bias_indicate_output.view(b * s)).view(b, s, max_len, h)
            bias_encoder_attention_mask = torch.index_select(input=bias_attention_mask, dim=0,
                                                             index=bias_indicate_output.view(b * s)).view(b, s, max_len)
            sequence_atten_bias_output = self.forward_bias_attention(sequence_output,
                                                                     bias_encoder_hidden_states,
                                                                     bias_encoder_attention_mask)

            # Find output words
            prediction_scores = self.lm_head(sequence_output + sequence_atten_bias_output + sequence_atten_copy_output)
        else:
            prediction_scores = self.lm_head(sequence_output + sequence_atten_copy_output)

        # bias ranking loss (only defined when bias candidates were provided, i.e. encoder_bias_pooling is not None)
        bias_ranking_loss = None
        if labels_bias is not None:
            loss_fct = CrossEntropyLoss()
            bias_ranking_loss = loss_fct(bias_ranking_score.view(-1, encoder_bias_pooling.size(0)),
                                         labels_bias.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((bias_ranking_loss,) + output) if bias_ranking_loss is not None else output

        result = CausalLMOutputWithCrossAttentions(
            loss=bias_ranking_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
        result.bias_indicate_output = bias_indicate_output
        return result


def download_tokenizer_files():
    resources = ['envibert_tokenizer.py', 'dict.txt', 'sentencepiece.bpe.model']
    for item in resources:
        if not os.path.exists(os.path.join(cache_dir, item)):
            tmp_file = hf_bucket_url(model_name, filename=item)
            tmp_file = cached_path(tmp_file, cache_dir=cache_dir)
            os.rename(tmp_file, os.path.join(cache_dir, item))


def init_tokenizer():
    download_tokenizer_files()
    tokenizer = SourceFileLoader("envibert.tokenizer",
                                 os.path.join(cache_dir, 'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)
    tokenizer.model_input_names = ["input_ids", "attention_mask", "bias_input_ids", "bias_attention_mask",
                                   "labels", "labels_bias"]
    return tokenizer


def init_model():
    tokenizer = init_tokenizer()

    # encoder and decoder are both initialized from the envibert checkpoint; their weights are not tied
    roberta_shared = EncoderDecoderSpokenNorm.from_encoder_decoder_pretrained(model_name, model_name,
                                                                              tie_encoder_decoder=False)

    # set special tokens
    roberta_shared.config.decoder_start_token_id = tokenizer.bos_token_id
    roberta_shared.config.eos_token_id = tokenizer.eos_token_id
    roberta_shared.config.pad_token_id = tokenizer.pad_token_id

    # set decoding params
    roberta_shared.config.max_length = 50
    roberta_shared.config.early_stopping = True
    roberta_shared.config.no_repeat_ngram_size = 3
    roberta_shared.config.length_penalty = 2.0
    roberta_shared.config.num_beams = 1
    roberta_shared.config.vocab_size = roberta_shared.config.encoder.vocab_size

    return roberta_shared, tokenizer
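

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original pipeline).
# It only shows how init_model() is intended to be called; the full spoken-norm
# inference additionally needs the feature pipeline that produces
# word_src_lengths, spoken_idx and bias_input_ids, which lives outside this
# file, so only model and tokenizer construction is exercised here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model, tokenizer = init_model()
    model.eval()
    # the shared envibert vocabulary is used on both sides of the model
    print(model.config.encoder.vocab_size, model.config.vocab_size)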