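"""Training-time variant of Unlimiformer.

Instead of retrieving top-scoring hidden states from a datastore, every decoder
layer here attends to a freshly drawn random window of the chunk-encoded long
input. When `unlimiformer_training` is enabled, training alternates between the
standard Unlimiformer hooks and these random-sampling hooks on every other step.
"""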
import torch
from .model import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED
from transformers import BartModel, BartForConditionalGeneration, \
T5Model, T5ForConditionalGeneration, \
LEDModel, LEDForConditionalGeneration, \
AutoModelForSeq2SeqLM
class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
def __init__(self, model: ModelType, *args, **kwargs):
super().__init__(model, *args, **kwargs)
self.training_hooks_injected = False
self.train_step = 0
@classmethod
def convert_model(cls, model, *args, **kwargs):
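        """Wrap `model` with the Random* subclass matching its type.

        Constructing the wrapper attaches all hooks to `model` in place, so the
        (mutated) original model is returned rather than the wrapper itself.
        """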
# model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
# model_clone.load_state_dict(model.state_dict()).to(args.device)
type_to_class = {
BartModel: RandomUnlimiformerBART,
BartForConditionalGeneration: RandomUnlimiformerBART,
T5Model: RandomUnlimiformerT5,
T5ForConditionalGeneration: RandomUnlimiformerT5,
LEDModel: RandomUnlimiformerLED,
LEDForConditionalGeneration: RandomUnlimiformerLED,
}
type_to_class[type(model)](model, *args, **kwargs)
return model
def pre_eval_hook(self):
self.remove_training_hooks(self.model)
self.inject_hooks(self.model)
self.original_model_eval_func()
def pre_train_hook(self, mode=True):
# mode=True means model.train() is called
# mode=False means model.eval() is called
torch.cuda.empty_cache()
if mode is True:
self.break_out(self.model)
self.remove_training_hooks(self.model)
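            # Alternate every step: on even steps (when Unlimiformer training is
            # enabled) use the standard training hooks from the parent class;
            # otherwise use the random-sampling hooks defined in this class.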
if self.unlimiformer_training and self.train_step % 2 == 0:
super().inject_training_hooks(self.model)
else:
self.inject_training_hooks(self.model)
self.train_step += 1
self.original_model_train_func(mode)
def inject_training_hooks(self, model):
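        """Replace the model's forward functions for random-sampling training.

        `model.forward` is routed through `random_inputs_forward_hook`, and every
        decoder layer in [layer_begin, layer_end) is wrapped so that its
        cross-attention keys/values come from a random sample of the encoded input.
        """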
if self.training_hooks_injected:
return
# self.original_forward_func = model.forward
model.forward = self.random_inputs_forward_hook
decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
self.original_decoder_layer_self_attn_forward_funcs = []
for decoder_layer in decoder_layers_to_run:
attention = self.self_attention(decoder_layer)
self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)
self.original_decoder_layer_forward_funcs = []
for decoder_layer in decoder_layers_to_run:
self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)
self.original_decoder_layer_cross_attn_forward_funcs = []
for i, decoder_layer in enumerate(decoder_layers_to_run):
attention = self.cross_attention(decoder_layer)
self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)
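            # Saved but not replaced: under random sampling, keys/values are
            # rebuilt per layer in create_decoder_layer_random_func, so the
            # cross-attention forward itself needs no hook; saving the original
            # keeps remove_training_hooks symmetric with the parent class.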
self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
self.training_hooks_injected = True
def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
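        """Wrap a self-attention forward so it discards any cached past_key_value;
        under checkpointing, each re-run must recompute attention from scratch."""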
def self_attention_pre_forward_hook(*args, **kwargs):
kwargs['past_key_value'] = None
return original_self_attn_forward_func(*args, **kwargs)
return self_attention_pre_forward_hook
def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
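        """Wrap a decoder layer's forward with gradient checkpointing.

        The random indices are drawn outside the checkpoint (so the forward and
        backward re-runs sample identically), while the layer's cross-attention
        keys/values are rebuilt inside it from the sampled encoder states,
        keeping the full encoded input out of the autograd graph.
        """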
def checkpointed_decoder_layer(
hidden_states: torch.Tensor,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
output_attentions=False,
position_bias=None,
encoder_decoder_position_bias=None,
use_cache=True):
def sample_and_forward(hidden_states, attention_mask,
encoder_hidden_states, encoder_attention_mask, layer_head_mask,
cross_attn_layer_head_mask, past_key_value,
output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices,
position_bias, encoder_decoder_position_bias):
sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
key, value = self.create_key_value(sampled_input, decoder_layer)
decoder_layer_args = self.create_decoder_layer_args(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
use_cache=use_cache,
                    key=key, value=value
)
return decoder_layer_original_forward_func(**decoder_layer_args)
with torch.no_grad():
# This sampling must be done outside of the checkpoint, to ensure that the same sampling happens
# both in "forward" and "backward" passes
rand_indices = self.sample_random_indices()
return torch.utils.checkpoint.checkpoint(
sample_and_forward, hidden_states, attention_mask,
encoder_hidden_states, encoder_attention_mask, layer_head_mask,
cross_attn_layer_head_mask, None,
output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, rand_indices,
position_bias, encoder_decoder_position_bias)
return checkpointed_decoder_layer
def sample_random_indices(self):
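        """For each sequence in the batch, return `actual_model_window_size`
        indices into the encoded input: a random subset of its non-padded
        positions, or simply the first window if the sequence is shorter."""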
rand_indices_list = []
seq_lens = self.long_inputs_mask.sum(-1).tolist()
for seq_len in seq_lens:
if seq_len < self.actual_model_window_size:
rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
rand_indices_list.append(rand_indices)
continue
            # Sequences shorter than the window were already handled above, so a
            # plain truncated permutation of the valid positions suffices here.
            rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
            rand_indices_list.append(rand_indices)
rand_indices = torch.stack(rand_indices_list, dim=0)
return rand_indices
def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
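        """Replacement for `model.forward` during random-sampling training:
        encode the full long input chunk by chunk, then call the original
        forward with the first window of the encoding as `encoder_outputs`
        (the decoder-layer hooks resample the full encoding for keys/values)."""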
self.model.base_model.decoder.gradient_checkpointing = False
self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
# TODO: should the inputs be sampled or the truncated beginning?
# if self.random_knn_initial_inputs:
# encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
# else:
encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]
return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs)
def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
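        """Gather a window-sized subset of the encoded input and its mask,
        using the given per-batch indices or freshly sampled ones."""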
if long_inputs_mask.shape[-1] < self.actual_model_window_size:
return long_inputs_encoded, long_inputs_mask
batch_size = long_inputs_encoded.shape[0]
if random_indices is None:
random_indices = self.sample_random_indices()
random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
.scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
return sampled_input, sampled_mask
def chunked_encode_input(self, input_ids, attention_mask):
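        """Encode an arbitrarily long input with the fixed-size encoder by
        sliding a window over it, keeping only each window's update slice, and
        concatenating the slices along the time dimension."""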
long_inputs_encoded = []
long_inputs_mask = []
window_indices = self.window_indices(input_ids.shape[-1])
self.is_input_encoding_pass = True
for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
chunk = input_ids[:, context_start_ind:context_end_ind]
chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim)
encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind] # (batch, chunked_time, dim)
chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time)
long_inputs_encoded.append(encoder_last_hidden_state) # (batch, chunked_source_len, dim)
long_inputs_mask.append(chunk_attention_mask) # (batch, chunked_source_len)
long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim)
long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len)
self.is_input_encoding_pass = False
if self.verbose:
print(f'Input: '
f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
print()
return long_inputs_encoded, long_inputs_mask
class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
def __init__(self, model: BartModel, *args, **kwargs):
super().__init__(model, *args, **kwargs)
class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
def __init__(self, model: T5Model, *args, **kwargs):
super().__init__(model, *args, **kwargs)
class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
def __init__(self, model: LEDModel, *args, **kwargs):
super().__init__(model, *args, **kwargs)
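
# Example usage (a minimal sketch). The exact constructor kwargs are defined by
# the base Unlimiformer class; `tokenizer` below is an assumption for
# illustration, not a verified signature:
#
#   from transformers import AutoTokenizer, BartForConditionalGeneration
#
#   tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
#   model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#   model = RandomTrainingUnlimiformer.convert_model(model, tokenizer=tokenizer)
#   model.train()  # pre_train_hook injects the random-sampling training hooks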