import contextlib
import numpy as np
import torch
import torch.utils.checkpoint  # used below via torch.utils.checkpoint.checkpoint
from torch import nn
from enum import Enum, auto
from .model import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED
from transformers import BartModel, BartForConditionalGeneration, \
    T5Model, T5ForConditionalGeneration, \
    LEDModel, LEDForConditionalGeneration, \
    AutoModelForSeq2SeqLM

class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
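    """Unlimiformer wrapper that trains with randomly sampled encoder states.

    Instead of the kNN retrieval used by the base Unlimiformer, each selected decoder
    layer builds its cross-attention keys/values from a window of encoder hidden states
    sampled uniformly at random from the full, chunk-encoded input. When
    `unlimiformer_training` is enabled, pre_train_hook alternates between the parent
    class's training hooks and the random-sampling hooks on every other step.
    """
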
    def __init__(self, model: ModelType, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.training_hooks_injected = False
        self.train_step = 0

    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        # model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
        # model_clone.load_state_dict(model.state_dict()).to(args.device)
        type_to_class = {
            BartModel: RandomUnlimiformerBART,
            BartForConditionalGeneration: RandomUnlimiformerBART,
            T5Model: RandomUnlimiformerT5,
            T5ForConditionalGeneration: RandomUnlimiformerT5,
            LEDModel: RandomUnlimiformerLED,
            LEDForConditionalGeneration: RandomUnlimiformerLED,
        }
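        # Constructing the wrapper is done for its side effects: the base Unlimiformer
        # constructor is expected to register its hooks on `model` in place, so the
        # wrapper instance itself is discarded and the original model object is returned.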
        type_to_class[type(model)](model, *args, **kwargs)
        return model

    def pre_eval_hook(self):
        self.remove_training_hooks(self.model)
        self.inject_hooks(self.model)   
        self.original_model_eval_func()

    def pre_train_hook(self, mode=True):
        # mode=True means model.train() is called
        # mode=False means model.eval() is called
        torch.cuda.empty_cache()
        if mode is True:
            self.break_out(self.model)
            self.remove_training_hooks(self.model)
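            # Alternate between hook sets every training step: when unlimiformer_training
            # is enabled, even steps use the parent Unlimiformer training hooks and odd
            # steps use the random-sampling hooks below; otherwise always sample randomly.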
            if self.unlimiformer_training and self.train_step % 2 == 0:
                super().inject_training_hooks(self.model)
            else:
                self.inject_training_hooks(self.model)
            self.train_step += 1
        self.original_model_train_func(mode)
    
    def inject_training_hooks(self, model):
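        """Swap in the random-sampling training hooks.

        model.forward becomes random_inputs_forward_hook, self-attention layers drop any
        cached past_key_value, and each selected decoder layer is replaced by a
        checkpointed wrapper that re-samples encoder states. Cross-attention forwards are
        only recorded here so that remove_training_hooks can restore them.
        """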
        if self.training_hooks_injected:
            return
        # self.original_forward_func = model.forward
        model.forward = self.random_inputs_forward_hook

        decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
        
        self.original_decoder_layer_self_attn_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            attention = self.self_attention(decoder_layer)
            self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
            attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)

        self.original_decoder_layer_forward_funcs = []
        for decoder_layer in decoder_layers_to_run:
            self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
            decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)

        self.original_decoder_layer_cross_attn_forward_funcs = []
        for i, decoder_layer in enumerate(decoder_layers_to_run):
            attention = self.cross_attention(decoder_layer)
            self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)

        self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)

        self.training_hooks_injected = True

    def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
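        """Return a wrapper around self-attention that discards any cached past_key_value,
        so self-attention is always recomputed from the current hidden states in training."""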
        def self_attention_pre_forward_hook(*args, **kwargs):
            kwargs['past_key_value'] = None
            return original_self_attn_forward_func(*args, **kwargs)
        
        return self_attention_pre_forward_hook

    def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
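        """Return a gradient-checkpointed replacement for `decoder_layer.forward` whose
        cross-attention keys/values are built from a randomly sampled window of the full
        chunk-encoded input (see sample_and_forward below)."""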
        def checkpointed_decoder_layer(
                hidden_states: torch.Tensor,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=None,
                output_attentions=False,
                position_bias=None,
                encoder_decoder_position_bias=None,
                use_cache=True):

            # sample_and_forward runs under gradient checkpointing: it gathers the randomly
            # selected encoder states, builds this layer's cross-attention key/value from
            # them, and then calls the original decoder-layer forward.
            def sample_and_forward(hidden_states, attention_mask, 
                    encoder_hidden_states, encoder_attention_mask, layer_head_mask, 
                    cross_attn_layer_head_mask, past_key_value, 
                    output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices,
                    position_bias, encoder_decoder_position_bias):
                
                sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
                key, value = self.create_key_value(sampled_input, decoder_layer)
                decoder_layer_args = self.create_decoder_layer_args(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    use_cache=use_cache,
                    key=key,value=value
                )
                return decoder_layer_original_forward_func(**decoder_layer_args)
                    

            with torch.no_grad():
                # This sampling must be done outside of the checkpoint, to ensure that the same sampling happens
                # both in "forward" and "backward" passes
                rand_indices = self.sample_random_indices()

            return torch.utils.checkpoint.checkpoint(
                sample_and_forward, hidden_states, attention_mask, 
                encoder_hidden_states, encoder_attention_mask, layer_head_mask, 
                cross_attn_layer_head_mask, None, 
                output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, rand_indices,
                position_bias, encoder_decoder_position_bias)

        return checkpointed_decoder_layer

    def sample_random_indices(self):
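        """Sample, per sequence, `actual_model_window_size` token positions uniformly at
        random; sequences shorter than the window fall back to the identity indices.
        Returns a (batch, window_size) tensor of indices on self.device."""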
        rand_indices_list = []
        seq_lens = self.long_inputs_mask.sum(-1).tolist()
        for seq_len in seq_lens:
            if seq_len < self.actual_model_window_size:
                rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
                rand_indices_list.append(rand_indices)
                continue
            
            # seq_len >= actual_model_window_size here (shorter sequences were handled
            # above), so a random permutation truncated to the window size suffices.
            rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
            rand_indices_list.append(rand_indices)
        rand_indices = torch.stack(rand_indices_list, dim=0)
        return rand_indices

    def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
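        """Training-time replacement for model.forward: chunk-encode the full long input,
        cache it for the decoder-layer hooks, and run the wrapped model on only the first
        `actual_model_window_size` encoder states."""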
        self.model.base_model.decoder.gradient_checkpointing = False
        self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)

        #  TODO: should the inputs be sampled or the truncated beginning?
        # if self.random_knn_initial_inputs:
        #     encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
        # else:
        encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
        encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]
        return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs)

    def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
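        """Gather encoder states and mask at `random_indices`, sampling fresh indices when
        none are given; inputs no longer than the model window are returned unchanged."""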
        if long_inputs_mask.shape[-1] < self.actual_model_window_size:
            return long_inputs_encoded, long_inputs_mask
        batch_size = long_inputs_encoded.shape[0]
        
        if random_indices is None:
            random_indices = self.sample_random_indices()
        random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
            .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
        sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
        sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
        return sampled_input, sampled_mask

    def chunked_encode_input(self, input_ids, attention_mask):
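        """Encode an arbitrarily long input with the standard encoder by processing the
        windows given by self.window_indices and concatenating each window's updated slice,
        producing encoder states and an attention mask over the full source length."""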
        long_inputs_encoded = []
        long_inputs_mask = []
        window_indices = self.window_indices(input_ids.shape[-1])

        self.is_input_encoding_pass = True
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim)
            
            # list of (batch, head, chunked_time, dim)
            encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind] # (batch, chunked_time, dim)
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time)

            long_inputs_encoded.append(encoder_last_hidden_state) # (batch, chunked_source_len, dim)
            long_inputs_mask.append(chunk_attention_mask) # (batch, chunked_source_len)
        
        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len)

        self.is_input_encoding_pass = False
        if self.verbose:
            print(f'Input: '
                f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
                f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
            print()
        return long_inputs_encoded, long_inputs_mask

class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
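
# Minimal usage sketch (illustrative only): convert_model wraps an existing seq2seq model
# in place. The extra positional/keyword arguments it forwards to the Unlimiformer
# constructor (tokenizer, layer range, window sizes, etc.) are defined in .model and are
# not spelled out here.
#
#   from transformers import BartForConditionalGeneration
#
#   model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
#   model = RandomTrainingUnlimiformer.convert_model(model, ...)  # plus the Unlimiformer kwargs
#   model.train()  # should trigger pre_train_hook and inject the training hooks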