import contextlib

import numpy as np
import torch
from torch import nn
from enum import Enum, auto

from .model import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED

from transformers import BartModel, BartForConditionalGeneration, \
    T5Model, T5ForConditionalGeneration, \
    LEDModel, LEDForConditionalGeneration, \
    AutoModelForSeq2SeqLM
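

# RandomTrainingUnlimiformer is the training-time variant of Unlimiformer that, instead of
# (or alternating with) the retrieved kNN encoder states used by the base class, lets every
# patched decoder layer cross-attend to a randomly sampled, window-sized subset of encoder
# hidden states. The long input is first encoded in chunks; the hooks below monkey-patch the
# wrapped Hugging Face model so that each training step still fits in the model's native window.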
class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
    def __init__(self, model: ModelType, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.training_hooks_injected = False
        self.train_step = 0
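
    # `convert_model` wraps an existing Hugging Face seq2seq model in place: constructing the
    # matching Random* subclass injects the hooks into `model`, and the same object is returned.
    # A usage sketch (arguments beyond `model` are assumptions; they depend on the base
    # Unlimiformer constructor and are not spelled out here):
    #
    #   model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    #   model = RandomTrainingUnlimiformer.convert_model(model, ...)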
    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        # model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
        # model_clone.load_state_dict(model.state_dict()).to(args.device)
        type_to_class = {
            BartModel: RandomUnlimiformerBART,
            BartForConditionalGeneration: RandomUnlimiformerBART,
            T5Model: RandomUnlimiformerT5,
            T5ForConditionalGeneration: RandomUnlimiformerT5,
            LEDModel: RandomUnlimiformerLED,
            LEDForConditionalGeneration: RandomUnlimiformerLED,
        }
        type_to_class[type(model)](model, *args, **kwargs)
        return model
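
    # Switching to eval mode removes the training hooks and injects the base Unlimiformer
    # (retrieval) hooks instead, so evaluation and generation do not use random sampling.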
    def pre_eval_hook(self):
        self.remove_training_hooks(self.model)
        self.inject_hooks(self.model)
        self.original_model_eval_func()
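
    # Switching back to train mode clears the injected state and re-injects training hooks.
    # When `unlimiformer_training` is enabled, even-numbered steps use the parent class's
    # retrieval-based training hooks and the remaining steps use the random-sampling hooks below.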
    def pre_train_hook(self, mode=True):
        # mode=True means model.train() is called
        # mode=False means model.eval() is called
        torch.cuda.empty_cache()
        if mode is True:
            self.break_out(self.model)
            self.remove_training_hooks(self.model)
            if self.unlimiformer_training and self.train_step % 2 == 0:
                super().inject_training_hooks(self.model)
            else:
                self.inject_training_hooks(self.model)
            self.train_step += 1
        self.original_model_train_func(mode)
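
    # Random-sampling training patches three things: the model's top-level forward (to encode the
    # long input in chunks first), each decoder layer's self-attention (to drop its key/value cache),
    # and each decoder layer's forward (to rebuild its cross-attention keys/values from a freshly
    # sampled window of encoder states). The cross-attention forwards are only recorded here,
    # presumably so that `remove_training_hooks` can restore them later.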
    def inject_training_hooks(self, model):
        if self.training_hooks_injected:
            return
        # self.original_forward_func = model.forward
        model.forward = self.random_inputs_forward_hook

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)

        self.original_decoder_layer_self_attn_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            attention = self.self_attention(decoder_layer)
            self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
            attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)

        self.original_decoder_layer_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)

        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)

        self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
        self.training_hooks_injected = True
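
    # The self-attention wrapper only forces `past_key_value` to None, so no cached keys/values are
    # reused when the checkpointed decoder layer is replayed; everything else passes through unchanged.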
    def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
        def self_attention_pre_forward_hook(*args, **kwargs):
            kwargs['past_key_value'] = None
            return original_self_attn_forward_func(*args, **kwargs)
        return self_attention_pre_forward_hook
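
    # Wraps a single decoder layer with torch.utils.checkpoint.checkpoint: projecting the sampled
    # encoder states into this layer's cross-attention keys/values and running the layer are
    # recomputed during backward instead of being stored. The random indices are drawn once, outside
    # the checkpoint and under no_grad, because the checkpointed function runs twice (forward and
    # backward replay) and both runs must see the same sample. A minimal sketch of that pattern,
    # independent of this class:
    #
    #   idx = torch.randperm(n)[:k]                           # sampled once, outside the checkpoint
    #   out = torch.utils.checkpoint.checkpoint(fn, x, idx)   # fn is replayed in backward with the same idx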
    def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):

            def sample_and_forward(hidden_states, attention_mask,
                                   encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                                   cross_attn_layer_head_mask, past_key_value,
                                   output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices,
                                   position_bias, encoder_decoder_position_bias):

                sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
                key, value = self.create_key_value(sampled_input, decoder_layer)

                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache,
                    key=key, value=value
                )
                return decoder_layer_original_forward_func(**decoder_layer_args)

            with torch.no_grad():
                # This sampling must be done outside of the checkpoint, to ensure that the same sampling happens
                # both in "forward" and "backward" passes
                rand_indices = self.sample_random_indices()

            return torch.utils.checkpoint.checkpoint(
                sample_and_forward, hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, None,
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, rand_indices,
                position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer
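
    # For each sequence in the batch, draws `actual_model_window_size` positions without replacement
    # from that sequence's non-padded length; sequences shorter than the window simply keep their
    # first `actual_model_window_size` positions. E.g. with a window of 4 and a length-10 sequence,
    # one draw might be [7, 2, 9, 0] (illustrative values only).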
    def sample_random_indices(self):
        rand_indices_list = []
        seq_lens = self.long_inputs_mask.sum(-1).tolist()
        for seq_len in seq_lens:
            if seq_len < self.actual_model_window_size:
                # Shorter-than-window sequences keep their first window of positions unchanged.
                rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
                rand_indices_list.append(rand_indices)
                continue
            rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
            rand_indices_list.append(rand_indices)
        rand_indices = torch.stack(rand_indices_list, dim=0)
        return rand_indices
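
    # Replaces the model's top-level forward during training: the full input is encoded chunk by
    # chunk, and the first window of encoder states is passed to the original forward as
    # `encoder_outputs` so the encoder is not run a second time; the decoder-layer hooks above then
    # re-sample from the full encoded input at every patched layer.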
    def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        self.model.base_model.decoder.gradient_checkpointing = False
        self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
        # TODO: should the inputs be sampled or the truncated beginning?
        # if self.random_knn_initial_inputs:
        #     encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
        # else:
        encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
        encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]
        return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs)
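
    # Gathers the sampled positions from the chunk-encoded input: the index tensor is turned into a
    # boolean mask via scatter_, and because every sequence contributes exactly
    # `actual_model_window_size` positions, the masked selection reshapes cleanly to
    # (batch, window, dim).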
    def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
        if long_inputs_mask.shape[-1] < self.actual_model_window_size:
            return long_inputs_encoded, long_inputs_mask
        batch_size = long_inputs_encoded.shape[0]
        if random_indices is None:
            random_indices = self.sample_random_indices()

        random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
            .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
        sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
        sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
        return sampled_input, sampled_mask
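
    # Encodes an arbitrarily long input with the fixed-size encoder by sliding a window over the
    # input ids (`window_indices` is inherited from the base Unlimiformer class), keeping only each
    # chunk's non-overlapping "update" slice of the last hidden state, and concatenating the slices
    # into a single (batch, source_len, dim) tensor.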
    def chunked_encode_input(self, input_ids, attention_mask):
        long_inputs_encoded = []
        long_inputs_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        self.is_input_encoding_pass = True
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state  # (batch, time, dim)

            # list of (batch, head, chunked_time, dim)
            encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind]  # (batch, chunked_time, dim)
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)

            long_inputs_encoded.append(encoder_last_hidden_state)  # (batch, chunked_source_len, dim)
            long_inputs_mask.append(chunk_attention_mask)  # (batch, chunked_source_len)

        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1)  # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1)  # (batch, source_len)

        self.is_input_encoding_pass = False
        if self.verbose:
            print(f'Input: '
                  f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                  f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()
        return long_inputs_encoded, long_inputs_mask
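

# The concrete subclasses below only pick the architecture-specific mixin from model.py
# (attention accessors, key/value construction); all of the random-sampling logic lives in
# RandomTrainingUnlimiformer.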
class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
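

# Usage sketch (not part of the original module; every argument other than `model` is an
# assumption that depends on the base Unlimiformer constructor):
#
#   from transformers import AutoTokenizer, BartForConditionalGeneration
#
#   tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
#   model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#   model = RandomTrainingUnlimiformer.convert_model(model, tokenizer=tokenizer, ...)
#   model.train()  # the injected pre_train_hook fires here and installs the training hooks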