import contextlib
import numpy as np
import torch
import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint is available below
from torch import nn
from enum import Enum, auto
from .model import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED
from transformers import BartModel, BartForConditionalGeneration, \
    T5Model, T5ForConditionalGeneration, \
    LEDModel, LEDForConditionalGeneration, \
    AutoModelForSeq2SeqLM


class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
    def __init__(self, model: ModelType, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.training_hooks_injected = False
        self.train_step = 0
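
    # convert_model wraps an existing HuggingFace seq2seq model in place: it
    # instantiates the Random* subclass that matches the model type (which is
    # expected to install the Unlimiformer hooks via the base-class constructor)
    # and returns the same model object rather than a copy.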
    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        # model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
        # model_clone.load_state_dict(model.state_dict()).to(args.device)
        type_to_class = {
            BartModel: RandomUnlimiformerBART,
            BartForConditionalGeneration: RandomUnlimiformerBART,
            T5Model: RandomUnlimiformerT5,
            T5ForConditionalGeneration: RandomUnlimiformerT5,
            LEDModel: RandomUnlimiformerLED,
            LEDForConditionalGeneration: RandomUnlimiformerLED,
        }
        type_to_class[type(model)](model, *args, **kwargs)
        return model

    def pre_eval_hook(self):
        self.remove_training_hooks(self.model)
        self.inject_hooks(self.model)
        self.original_model_eval_func()
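
    # pre_train_hook runs whenever model.train()/model.eval() is called. When
    # entering training mode it alternates between the base class's training
    # hooks and the random-sampling hooks defined in this class: on even values
    # of self.train_step (and only if self.unlimiformer_training is enabled) it
    # defers to super().inject_training_hooks; otherwise it uses the random variant.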
    def pre_train_hook(self, mode=True):
        # mode=True means model.train() is called
        # mode=False means model.eval() is called
        torch.cuda.empty_cache()
        if mode is True:
            self.break_out(self.model)
            self.remove_training_hooks(self.model)
            if self.unlimiformer_training and self.train_step % 2 == 0:
                super().inject_training_hooks(self.model)
            else:
                self.inject_training_hooks(self.model)
            self.train_step += 1
        self.original_model_train_func(mode)
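
    # inject_training_hooks replaces model.forward with the chunked-encoding
    # forward hook below and patches each decoder layer in the range selected by
    # attention_layer_to_run(layer_begin, layer_end): self-attention forwards are
    # wrapped so no past_key_value is reused, and each layer's forward is wrapped
    # in a gradient-checkpointed function that attends over a random sample of
    # the encoded long input.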
    def inject_training_hooks(self, model):
        if self.training_hooks_injected:
            return
        # self.original_forward_func = model.forward
        model.forward = self.random_inputs_forward_hook

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)

        self.original_decoder_layer_self_attn_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            attention = self.self_attention(decoder_layer)
            self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
            attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)

        self.original_decoder_layer_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)

        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)

        self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
        self.training_hooks_injected = True
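
    # The self-attention wrapper simply drops any cached past_key_value before
    # calling the original forward, so keys and values are recomputed each step.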
    def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
        def self_attention_pre_forward_hook(*args, **kwargs):
            kwargs['past_key_value'] = None
            return original_self_attn_forward_func(*args, **kwargs)

        return self_attention_pre_forward_hook
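
    # The wrapped decoder-layer forward samples a window-sized subset of the
    # encoded long input, builds keys/values from it via create_key_value, and
    # runs the original layer forward under torch.utils.checkpoint. The random
    # indices are drawn outside the checkpoint (under no_grad) so that the
    # recomputation during the backward pass sees exactly the same sample.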
    def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):

            def sample_and_forward(hidden_states, attention_mask,
                                   encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                                   cross_attn_layer_head_mask, past_key_value,
                                   output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices,
                                   position_bias, encoder_decoder_position_bias):
                sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
                key, value = self.create_key_value(sampled_input, decoder_layer)

                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache,
                    key=key, value=value)
                return decoder_layer_original_forward_func(**decoder_layer_args)

            with torch.no_grad():
                # This sampling must be done outside of the checkpoint, to ensure that
                # the same sampling happens in both the "forward" and "backward" passes
                rand_indices = self.sample_random_indices()

            return torch.utils.checkpoint.checkpoint(
                sample_and_forward, hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, None,
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, rand_indices,
                position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer
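
    # sample_random_indices draws, for every sequence in the batch, a random
    # permutation of its positions truncated to the model's window size (short
    # sequences fall back to the first window positions) and returns a
    # (batch, window_size) index tensor.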
    def sample_random_indices(self):
        rand_indices_list = []
        seq_lens = self.long_inputs_mask.sum(-1).tolist()
        for seq_len in seq_lens:
            if seq_len < self.actual_model_window_size:
                rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
                rand_indices_list.append(rand_indices)
                continue
            rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
            if seq_len < self.actual_model_window_size:
                padding = max(self.actual_model_window_size - seq_len, 0)
                rand_indices = torch.cat([rand_indices, torch.arange(padding).to(self.device) + seq_len], dim=-1).to(self.device)
            rand_indices_list.append(rand_indices)
        rand_indices = torch.stack(rand_indices_list, dim=0)
        return rand_indices
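
    # random_inputs_forward_hook replaces the model's forward during training:
    # the full long input is encoded chunk by chunk, the decoder is conditioned
    # on the first window of encoder states (see the TODO below), and the
    # decoder-layer hooks installed above re-sample from the full encoded input
    # when building their keys and values.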
    def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        self.model.base_model.decoder.gradient_checkpointing = False
        self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)

        # TODO: should the inputs be sampled or the truncated beginning?
        # if self.random_knn_initial_inputs:
        #     encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
        # else:
        encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
        encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]

        return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs)
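
    # sample_long_input gathers the positions selected by random_indices from the
    # encoded states and the attention mask via a scattered boolean mask, yielding
    # exactly actual_model_window_size positions per sequence.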
    def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
        if long_inputs_mask.shape[-1] < self.actual_model_window_size:
            return long_inputs_encoded, long_inputs_mask
        batch_size = long_inputs_encoded.shape[0]

        if random_indices is None:
            random_indices = self.sample_random_indices()

        random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
            .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
        sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
        sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
        return sampled_input, sampled_mask
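
    # chunked_encode_input runs the encoder over sliding windows of the long input
    # (as produced by window_indices), keeps only the "update" slice of each window,
    # and concatenates the slices into encoder states and a mask that cover the
    # full source length.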
    def chunked_encode_input(self, input_ids, attention_mask):
        long_inputs_encoded = []
        long_inputs_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        self.is_input_encoding_pass = True
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state  # (batch, time, dim)

            # keep only the update slice of this window
            encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind]  # (batch, chunked_time, dim)
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)

            long_inputs_encoded.append(encoder_last_hidden_state)  # (batch, chunked_source_len, dim)
            long_inputs_mask.append(chunk_attention_mask)  # (batch, chunked_source_len)

        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1)  # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1)  # (batch, source_len)
        self.is_input_encoding_pass = False

        if self.verbose:
            print(f'Input: '
                  f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                  f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()
        return long_inputs_encoded, long_inputs_mask


class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)


class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
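

# A minimal usage sketch, assuming the Unlimiformer base-class constructor can run
# with its default keyword arguments (in practice convert_model is typically given
# the same configuration kwargs as the base Unlimiformer); the checkpoint name
# 'facebook/bart-base' is only an example:
#
#     from transformers import BartForConditionalGeneration
#
#     bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#     bart = RandomTrainingUnlimiformer.convert_model(bart)  # hooks installed in place
#     bart.train()  # expected to trigger pre_train_hook via the wrapped train()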