import contextlib
import numpy as np
import torch
from torch import nn
from enum import Enum, auto

from .model import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED

from transformers import BartModel, BartForConditionalGeneration, \
    T5Model, T5ForConditionalGeneration, \
    LEDModel, LEDForConditionalGeneration, \
    AutoModelForSeq2SeqLM


class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
    """Training-time variant of Unlimiformer: the full input is encoded in chunks, and each
    decoder layer cross-attends over a random sample of the encoded states rather than the
    retrieved ones. When ``unlimiformer_training`` is set, every other training step falls
    back to the parent class's retrieval-based training hooks.
    """

    def __init__(self, model: ModelType, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.training_hooks_injected = False
        self.train_step = 0

    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        # model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
        # model_clone.load_state_dict(model.state_dict()).to(args.device)
        type_to_class = {
            BartModel: RandomUnlimiformerBART,
            BartForConditionalGeneration: RandomUnlimiformerBART,
            T5Model: RandomUnlimiformerT5,
            T5ForConditionalGeneration: RandomUnlimiformerT5,
            LEDModel: RandomUnlimiformerLED,
            LEDForConditionalGeneration: RandomUnlimiformerLED,
        }
        # The constructor attaches its hooks to `model` in place, so the wrapper instance
        # itself does not need to be kept.
        type_to_class[type(model)](model, *args, **kwargs)
        return model

    def pre_eval_hook(self):
        self.remove_training_hooks(self.model)
        self.inject_hooks(self.model)
        self.original_model_eval_func()

    def pre_train_hook(self, mode=True):
        # mode=True means model.train() is called
        # mode=False means model.eval() is called
        torch.cuda.empty_cache()
        if mode is True:
            self.break_out(self.model)
            self.remove_training_hooks(self.model)
            # Alternate between the parent class's retrieval-based training hooks and the
            # random-sampling hooks defined below.
            if self.unlimiformer_training and self.train_step % 2 == 0:
                super().inject_training_hooks(self.model)
            else:
                self.inject_training_hooks(self.model)
            self.train_step += 1
        self.original_model_train_func(mode)

    def inject_training_hooks(self, model):
        if self.training_hooks_injected:
            return
        # self.original_forward_func = model.forward
        model.forward = self.random_inputs_forward_hook

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)

        self.original_decoder_layer_self_attn_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            attention = self.self_attention(decoder_layer)
            self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
            attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)

        self.original_decoder_layer_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)

        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)

        self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
        self.training_hooks_injected = True

    def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
        def self_attention_pre_forward_hook(*args, **kwargs):
            kwargs['past_key_value'] = None
            return original_self_attn_forward_func(*args, **kwargs)

        return self_attention_pre_forward_hook

    def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):
            def sample_and_forward(hidden_states, attention_mask, encoder_hidden_states,
                                   encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                   past_key_value, output_attentions, use_cache, long_inputs,
                                   long_inputs_mask, rand_indices, position_bias,
                                   encoder_decoder_position_bias):
                # Cross-attend over a random subset of the encoded long input instead of the full sequence.
                sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
                key, value = self.create_key_value(sampled_input, decoder_layer)

                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache,
                    key=key, value=value)
                return decoder_layer_original_forward_func(**decoder_layer_args)

            with torch.no_grad():
                # This sampling must be done outside of the checkpoint, to ensure that the same
                # sampling happens in both the "forward" and "backward" passes
                rand_indices = self.sample_random_indices()

            return torch.utils.checkpoint.checkpoint(
                sample_and_forward, hidden_states, attention_mask, encoder_hidden_states,
                encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, None,
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
                rand_indices, position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer

    def sample_random_indices(self):
        rand_indices_list = []
        seq_lens = self.long_inputs_mask.sum(-1).tolist()
        for seq_len in seq_lens:
            # Short inputs already fit into the model window; keep the first window positions.
            if seq_len < self.actual_model_window_size:
                rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
                rand_indices_list.append(rand_indices)
                continue

            # Sample a random window-sized subset of the unpadded positions.
            rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
            rand_indices_list.append(rand_indices)

        rand_indices = torch.stack(rand_indices_list, dim=0)
        return rand_indices

    def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        self.model.base_model.decoder.gradient_checkpointing = False
        self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(
            input_ids=input_ids, attention_mask=attention_mask)
        # TODO: should the inputs be sampled or the truncated beginning?
        # if self.random_knn_initial_inputs:
        #     encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
        # else:
        encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
        encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]

        return self.original_forward_func(encoder_outputs=(encoded_inputs,), labels=labels,
                                          attention_mask=encoded_inputs_mask, **kwargs)

    def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
        if long_inputs_mask.shape[-1] < self.actual_model_window_size:
            return long_inputs_encoded, long_inputs_mask
        batch_size = long_inputs_encoded.shape[0]
        if random_indices is None:
            random_indices = self.sample_random_indices()

        # Turn the sampled indices into a boolean mask and gather the corresponding encoder states.
        random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
            .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
        sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
        sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
        return sampled_input, sampled_mask

    def chunked_encode_input(self, input_ids, attention_mask):
        # Encode the long input in overlapping windows, keeping only the non-overlapping
        # "update" slice of each window so the concatenation covers the full source exactly once.
        long_inputs_encoded = []
        long_inputs_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        self.is_input_encoding_pass = True
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask,
                                                   return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state  # (batch, time, dim)

            encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind]  # (batch, chunked_time, dim)
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)

            long_inputs_encoded.append(encoder_last_hidden_state)  # (batch, chunked_source_len, dim)
            long_inputs_mask.append(chunk_attention_mask)  # (batch, chunked_source_len)

        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1)  # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1)  # (batch, source_len)

        self.is_input_encoding_pass = False
        if self.verbose:
            print(f'Input: '
                  f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                  f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()
        return long_inputs_encoded, long_inputs_mask


class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
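

# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# A minimal, hypothetical example of wrapping a seq2seq model with the random-sampling
# training wrapper. The checkpoint name and the keyword arguments shown are assumptions:
# in practice the training script constructs the wrapper and forwards its
# Unlimiformer-specific options to Unlimiformer.__init__ via **kwargs.
#
#   from transformers import AutoTokenizer, BartForConditionalGeneration
#
#   tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
#   model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#   # convert_model() attaches the hooks in place and returns the same model object.
#   model = RandomTrainingUnlimiformer.convert_model(model, tokenizer=tokenizer)
#   model.train()  # pre_train_hook injects the random-sampling training hooks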