Spaces:
Runtime error
Runtime error
import logging | |
import numpy as np | |
import torch | |
from torch import nn | |
from enum import Enum, auto | |
from transformers import BartModel, BartForConditionalGeneration, \ | |
T5Model, T5ForConditionalGeneration, \ | |
LEDModel, LEDForConditionalGeneration, \ | |
AutoModelForCausalLM, AutoModelForSeq2SeqLM, \ | |
MODEL_WITH_LM_HEAD_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING | |
from typing import TypeVar, Generic | |
from .index_building import Datastore, DatastoreBatch | |
# Module logger; INFO (numeric level 20) so per-window encoding progress is reported.
logger = logging.getLogger('Unlimiformer')
logger.setLevel(logging.INFO)
# Type parameter for the wrapped model class (used as Unlimiformer[ModelType]).
ModelType = TypeVar('ModelType')
class Unlimiformer(Generic[ModelType]): | |
def __init__(self, model: ModelType,
        layer_begin=-1, layer_end=None,
        unlimiformer_head_num=None,
        exclude_attention=False,
        model_encoder_max_len=None,
        chunk_overlap=0,
        verbose=False, save_heatmap=False,
        tokenizer=None, unlimiformer_training=False,
        use_datastore=False,
        flat_index=False,
        test_datastore=False, reconstruct_embeddings=False,
        gpu_datastore=False, gpu_index=False,
        index_devices=(0,), datastore_device=0,
        ):
    """Attach Unlimiformer retrieval to ``model``.

    Stores the configuration on ``self``, registers this object as
    ``model.unlimiformer`` and finally calls ``self.break_into(model)``, which
    replaces ``model.eval`` / ``model.train`` with this object's hooks.

    Args:
        model: the wrapped model; ``model.config.is_encoder_decoder`` is read here.
        layer_begin, layer_end: slice bounds selecting the layers that get retrieval.
        unlimiformer_head_num: if given, only this head index is used for the
            retrieval query; ``None`` means all heads.
        exclude_attention: push the scores of the first ``actual_model_window_size``
            prompt positions down by 1e9 before top-k selection.
        model_encoder_max_len: window length used to chunk long inputs; defaults to
            the model's own window size.
        chunk_overlap: overlap between consecutive windows as a fraction of
            ``model_encoder_max_len``; half of it becomes a margin on each side.
        verbose: print the decoded input (and generated beams) while running.
        save_heatmap: record which prompt positions were retrieved at each step.
        tokenizer: used only to decode ids for ``verbose`` / ``save_heatmap`` output.
        unlimiformer_training: install the training-time hooks on ``model.train()``.
        use_datastore: retrieve through ``DatastoreBatch`` indexes instead of scoring
            the cached prompt keys directly.
        flat_index, gpu_index: forwarded to ``DatastoreBatch``; ``gpu_index`` also
            caps top-k at 2048.
        test_datastore: debugging flag that runs both retrieval paths and asserts
            that they agree.
        reconstruct_embeddings: obtain embeddings from the datastore's
            ``search_and_reconstruct`` instead of keeping the hidden states.
        gpu_datastore: if False, stored hidden states are moved to CPU first.
        index_devices, datastore_device: currently ignored (see note in the body).
    """
    super().__init__()
    self.model = model
    # Back-reference so code holding only the model can reach this wrapper.
    model.unlimiformer = self
    self.layer_begin = layer_begin
    self.layer_end = layer_end
    self.specific_head = unlimiformer_head_num
    self.exclude_attention = exclude_attention
    # Filled in by break_into() from self.window_size().
    self.actual_model_window_size = None
    self.model_encoder_max_len = model_encoder_max_len
    self.chunk_overlap = chunk_overlap
    self.verbose = verbose
    self.save_heatmap = save_heatmap
    self.tokenizer = tokenizer
    self.unlimiformer_training = unlimiformer_training
    self.use_datastore = use_datastore
    self.flat_index = flat_index
    self.reconstruct_embeddings = reconstruct_embeddings
    self.gpu_datastore = gpu_datastore
    self.gpu_index = gpu_index
    # NOTE(review): GPU placement is disabled in this copy; the index_devices and
    # datastore_device arguments are ignored and everything is pinned to CPU.
    # if torch.cuda.is_available() and gpu_index:
    #     self.index_devices = [torch.device(f'cuda:{i}') for i in index_devices]
    # else:
    self.index_devices = [torch.device('cpu')]
    self.datastore_device = torch.device('cpu')
    self.test_datastore = test_datastore # flag for debugging
    self.device = torch.device('cpu')
    # Per-run state, populated by the hooks and by reset_memory().
    self.activation_capturer = None
    self.is_encoder_decoder = model.config.is_encoder_decoder
    self.hook_handles = []
    self.is_input_encoding_pass = False
    self.is_first_test_decoding_step = False
    self.prev_tokens = None
    self.last_beam_idx = None
    self.heatmap = None
    self.cur_decoder_layer_index = None
    self.datastore = None
    # Must run last: it reads the attributes set above.
    self.break_into(model)
def break_into(self, model):
    """Derive the windowing settings and take over ``model.eval`` / ``model.train``.

    Sets ``actual_model_window_size`` (from ``self.window_size()``), defaults
    ``model_encoder_max_len`` to it, computes ``window_margin``, ``num_heads`` and
    ``head_nums``, marks both hook sets as not installed, and remembers the original
    ``forward``, ``eval`` and ``train`` of ``model`` before replacing the latter two
    with ``pre_eval_hook`` and ``pre_train_hook``.
    """
    self.actual_model_window_size = self.window_size()
    if self.model_encoder_max_len is None:
        self.model_encoder_max_len = self.actual_model_window_size
    # Half of the requested overlap (in tokens) becomes the margin on each window side.
    self.window_margin = int(self.model_encoder_max_len * self.chunk_overlap / 2)
    self.num_heads = model.config.num_attention_heads
    # Ellipsis keeps every head when used as query[:, self.head_nums].
    self.head_nums = Ellipsis if self.specific_head is None else self.specific_head
    self.hooks_injected = self.training_hooks_injected = False
    self.original_forward_func = model.forward
    # Activate Unlimiformer when calling model.eval(), deactivate for model.train().
    self.original_model_eval_func, model.eval = model.eval, self.pre_eval_hook
    self.original_model_train_func, model.train = model.train, self.pre_train_hook
def pre_eval_hook(self):
    """Replacement for ``model.eval`` installed by ``break_into``.

    Removes the training-time hooks, installs the inference-time hooks and then
    calls the original ``eval``.

    Returns:
        Whatever the original ``eval`` returns. ``torch.nn.Module.eval`` returns the
        module itself, so ``model = model.eval()`` and chained calls such as
        ``model.eval().to(device)`` keep working. Previously ``None`` was returned.
    """
    self.remove_training_hooks(self.model)
    self.inject_hooks(self.model)
    return self.original_model_eval_func()
def pre_train_hook(self, mode=True):
    """Replacement for ``model.train`` installed by ``break_into``.

    ``mode=True`` means ``model.train()`` was called: the inference hooks are removed
    via ``break_out`` and, if ``unlimiformer_training`` is set, the training hooks are
    installed. ``mode=False`` (e.g. ``model.train(False)``) only forwards to the
    original ``train``.

    Returns:
        The original ``train``'s return value. ``torch.nn.Module.train`` returns the
        module itself, so chained calls keep working. Previously ``None`` was returned.
    """
    torch.cuda.empty_cache()
    if mode is True:
        self.break_out(self.model)
        if self.unlimiformer_training:
            self.inject_training_hooks(self.model)
    return self.original_model_train_func(mode)
def inject_hooks(self, model):
    """Install the inference-time hooks on ``model`` (no-op if already installed).

    - registers an ``ActivationCapturer`` (``capture_input=False``) as a forward hook
      on every module returned by ``activation_to_capture``; an entry that is a list
      gets one capturer per element, kept together as a nested list;
    - registers ``attention_forward_hook`` on every module from ``attention_op_to_run``;
    - wraps the cross-attention ``forward`` of each selected decoder layer with
      ``create_cross_attn_pre_forward_hook`` (the originals are kept for ``break_out``);
    - replaces ``model.generate``, ``model.forward`` and ``model._reorder_cache``.
    """
    if self.hooks_injected:
        return
    # Inject our activation_capturer to capture the activations at every forward pass
    # NOTE(review): ActivationCapturer is not imported above; presumably it is defined
    # later in this file — confirm.
    attention_layers_to_capture = self.activation_to_capture(self.layer_begin, self.layer_end)
    self.activation_capturer = []
    for layer in attention_layers_to_capture:
        if type(layer) is list:
            layer_capturers = []
            for k_or_v in layer:
                capturer = ActivationCapturer(k_or_v, capture_input=False)
                layer_capturers.append(capturer)
                self.register_hook(k_or_v, capturer)
            self.activation_capturer.append(layer_capturers)
        else:
            capturer = ActivationCapturer(layer, capture_input=False)
            self.register_hook(layer, capturer)
            self.activation_capturer.append(capturer)
    # Inject our main function after the main attention function
    attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end)
    for layer in attention_layers_to_run:
        self.register_hook(layer, self.attention_forward_hook)
    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    self.original_decoder_layer_cross_attn_forward_funcs = []
    for i, decoder_layer in enumerate(decoder_layers_to_run):
        decoder_layer_cross_attention = self.cross_attention(decoder_layer)
        self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward)
        decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i)
    # Inject our hook function in the beginning of generation.
    # When "model.generate()" is called, it will first call our pre_generate_hook()
    # (which runs reset_memory()), and only then the original "model.generate()".
    self.original_generate_func = model.generate
    model.generate = self.pre_generate_hook
    # The original forward was already saved by break_into().
    model.forward = self.pre_forward_hook
    self.original_reorder_cache_func = model._reorder_cache
    model._reorder_cache = self.reorder_cache_hook
    self.hooks_injected = True
def inject_training_hooks(self, model):
    """Install the training-time hooks on ``model`` (no-op if already installed).

    - replaces ``model.forward`` with ``pre_forward_hook`` (the original forward was
      saved by ``break_into``);
    - wraps each selected decoder layer's self-attention ``forward`` so that
      ``past_key_value`` is always passed as ``None``;
    - wraps each selected layer's cross-attention ``forward`` with
      ``create_cross_attn_pre_forward_hook``;
    - wraps each selected decoder layer's ``forward`` with ``create_decoder_layer_func``
      and the remaining decoder layers via ``inject_hooks_for_unaffected_layers``;
    - registers ``train_attention_forward_hook`` on the modules from
      ``attention_op_to_run``.

    The original forwards are kept in lists so ``remove_training_hooks`` can restore them.
    """
    if self.training_hooks_injected:
        return
    # self.original_forward_func = model.forward
    model.forward = self.pre_forward_hook
    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    self.original_decoder_layer_self_attn_forward_funcs = []
    for decoder_layer in decoder_layers_to_run:
        attention = self.self_attention(decoder_layer)
        self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
        attention.forward = self.create_self_attn_pre_forward_hook(attention.forward)
    self.original_decoder_layer_cross_attn_forward_funcs = []
    for i, decoder_layer in enumerate(decoder_layers_to_run):
        decoder_layer_cross_attention = self.cross_attention(decoder_layer)
        self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward)
        decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i)
    self.original_decoder_layer_forward_funcs = []
    for decoder_layer in decoder_layers_to_run:
        self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
        decoder_layer.forward = self.create_decoder_layer_func(decoder_layer.forward, decoder_layer)
    self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
    attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end)
    for layer in attention_layers_to_run:
        self.register_hook(layer, self.train_attention_forward_hook)
    self.training_hooks_injected = True
def inject_hooks_for_unaffected_layers(self, model, decoder_layers_to_run):
    """Wrap the ``forward`` of every decoder layer that does not get retrieval.

    Every layer returned by ``attention_layer_to_run(0, None)`` that is not in
    ``decoder_layers_to_run`` has its ``forward`` saved (in order) to
    ``original_non_injected_decoder_layer_forward_funcs`` and replaced by
    ``create_noninjected_decoder_layer_func``. ``model`` is not used.
    """
    every_layer = self.attention_layer_to_run(0, None)
    untouched_layers = list(filter(lambda layer: layer not in decoder_layers_to_run, every_layer))
    saved_forwards = []
    self.original_non_injected_decoder_layer_forward_funcs = saved_forwards
    for layer in untouched_layers:
        previous_forward = layer.forward
        saved_forwards.append(previous_forward)
        layer.forward = self.create_noninjected_decoder_layer_func(previous_forward, layer)
def create_self_attn_pre_forward_hook(self, original_self_attn_forward_func):
    """Return a wrapper around a self-attention ``forward`` that drops its cache.

    The wrapper forwards all positional and keyword arguments unchanged, except that
    ``past_key_value`` is always passed as ``None``.
    """
    def self_attention_pre_forward_hook(*args, **kwargs):
        return original_self_attn_forward_func(*args, **{**kwargs, 'past_key_value': None})
    return self_attention_pre_forward_hook
def create_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Return a replacement ``forward`` for a retrieval-enabled decoder layer (training).

    The replacement runs the layer under ``torch.utils.checkpoint.checkpoint``. Inside
    the checkpointed function, ``self.create_key_value(long_inputs, decoder_layer)``
    derives ``key``/``value`` for this layer from the chunk-encoded source
    (``self.long_inputs_encoded``, set by ``pre_forward_hook``), and
    ``self.create_decoder_layer_args`` builds the keyword arguments for the original
    ``forward``.

    NOTE(review): the incoming ``past_key_value`` is ignored (``None`` is passed in its
    place); ``long_inputs_mask`` is passed through checkpoint but unused by the inner
    function; ``use_reentrant`` is not given, which recent torch versions warn about.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):
        def forward_with_all_keys(hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, past_key_value,
                output_attentions, use_cache, long_inputs, long_inputs_mask,
                position_bias, encoder_decoder_position_bias):
            # Computed inside the checkpointed region so it is recomputed on backward.
            key, value = self.create_key_value(long_inputs, decoder_layer)
            decoder_layer_args = self.create_decoder_layer_args(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                position_bias=position_bias,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                use_cache=use_cache,
                key=key,value=value)
            return decoder_layer_original_forward_func(**decoder_layer_args)
        # Positional order must match forward_with_all_keys; None replaces past_key_value.
        return torch.utils.checkpoint.checkpoint(
            forward_with_all_keys, hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask, None,
            output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias)
    return checkpointed_decoder_layer
def create_noninjected_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Return a checkpointed replacement ``forward`` for a decoder layer without retrieval.

    Mirrors ``create_decoder_layer_func``, but calls ``create_decoder_layer_args`` with
    ``key=None, value=None``, so no retrieved keys/values are supplied to the layer.
    ``decoder_layer`` itself is not used here.

    NOTE(review): ``long_inputs`` / ``long_inputs_mask`` are passed through checkpoint
    but unused, and the incoming ``past_key_value`` is replaced by ``None``.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):
        def forward_with_all_keys(hidden_states, attention_mask,
                encoder_hidden_states, encoder_attention_mask, layer_head_mask,
                cross_attn_layer_head_mask, past_key_value,
                output_attentions, use_cache, long_inputs, long_inputs_mask,
                position_bias, encoder_decoder_position_bias):
            decoder_layer_args = self.create_decoder_layer_args(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                position_bias=position_bias,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                use_cache=use_cache, key=None, value=None)
            return decoder_layer_original_forward_func(**decoder_layer_args)
        # Positional order must match forward_with_all_keys; None replaces past_key_value.
        return torch.utils.checkpoint.checkpoint(
            forward_with_all_keys, hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask, None,
            output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias)
    return checkpointed_decoder_layer
def register_hook(self, layer, func, pre=False):
    """Register ``func`` on ``layer`` and remember the returned handle.

    ``pre=True`` uses ``layer.register_forward_pre_hook``; otherwise
    ``layer.register_forward_hook``. Handles are collected in ``self.hook_handles``,
    which ``break_out`` / ``remove_training_hooks`` iterate to remove them.
    """
    if pre:
        add_hook = layer.register_forward_pre_hook
    else:
        add_hook = layer.register_forward_hook
    self.hook_handles.append(add_hook(func))
def break_out(self, model):
    """Undo ``inject_hooks`` and clear per-generation prompt state.

    Always resets ``prompt_keys``, ``prompt_values``, ``prompt_attention_mask`` and
    ``generated_input_ids`` and calls ``torch.cuda.empty_cache()``. If the inference
    hooks are installed, it also:

    - removes every registered hook handle and empties ``self.hook_handles``;
    - restores ``generate``, ``forward`` and ``_reorder_cache`` on ``model``;
    - restores the wrapped cross-attention ``forward`` methods.
    """
    self.prompt_keys = []
    self.prompt_values = []
    self.prompt_attention_mask = []
    self.generated_input_ids = []
    torch.cuda.empty_cache()
    if not self.hooks_injected:
        return
    for h in self.hook_handles:
        h.remove()
    # The handles are dead after remove(). Drop them so the list does not keep growing
    # across every eval()/train() cycle.
    self.hook_handles.clear()
    model.generate = self.original_generate_func
    model.forward = self.original_forward_func
    model._reorder_cache = self.original_reorder_cache_func
    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    self.hooks_injected = False
def remove_training_hooks(self, model):
    """Undo ``inject_training_hooks``.

    Always clears ``long_inputs_encoded`` / ``long_inputs_mask``. If the training hooks
    are installed, it also:

    - removes every registered hook handle and empties ``self.hook_handles``;
    - restores ``model.forward``;
    - restores the self-attention, cross-attention and decoder-layer ``forward``
      methods of the selected layers;
    - restores the ``forward`` of the non-selected decoder layers.
    """
    self.long_inputs_encoded, self.long_inputs_mask = None, None
    if not self.training_hooks_injected:
        return
    for h in self.hook_handles:
        h.remove()
    # Removed handles are useless; do not let them accumulate across train/eval cycles.
    self.hook_handles.clear()
    model.forward = self.original_forward_func
    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_self_attn_forward_funcs):
        self.self_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func
    non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None)
        if l not in decoder_layers_to_run]
    for decoder_layer, original_func in zip(non_injected_decoder_layers, self.original_non_injected_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func
    self.training_hooks_injected = False
def reset_memory(self, input_ids, attention_mask):
    """Encode the whole prompt window-by-window and cache what retrieval needs.

    Called by ``pre_generate_hook`` before generation. Steps:

    1. With ``use_datastore``, create one ``DatastoreBatch`` (encoder-decoder) or one
       per selected layer (decoder-only), plus empty per-layer hidden-state buffers.
    2. Reset per-generation bookkeeping (``prompt_input_ids``, ``input_ids_size``,
       ``prev_tokens``, heatmap and ``generated_input_ids``).
    3. Run ``self.model`` on every window from ``self.window_indices`` under
       ``torch.inference_mode()`` with ``is_input_encoding_pass`` set, so that
       ``attention_forward_hook`` returns early. Activations are then read back from
       ``self.activation_capturer``.
    4. Keep only the ``[update_start, update_end)`` slice of each window. Either
       append masked hidden states for the datastore, or append per-layer keys/values
       from ``process_key_value``. Both are done when ``test_datastore`` is set.
    5. Concatenate the buffers (training the datastore indexes, if any) and clear
       ``is_input_encoding_pass``.

    Args:
        input_ids: full prompt ids; ``shape[0]`` is used as the batch size and
            ``shape[-1]`` as the prompt length.
        attention_mask: mask aligned with ``input_ids``; it is sliced with the same
            windows.
    """
    if self.use_datastore:
        if self.is_encoder_decoder:
            # Encoder-decoder: a single datastore over the encoder output.
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                gpu_index=self.gpu_index, index_device=self.index_devices[0])]
            self.hidden_states = [[]]
        else:
            # Decoder-only: one datastore per selected layer, spread round-robin over index_devices.
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                gpu_index=self.gpu_index, index_device=self.index_devices[i % len(self.index_devices)])
                for i in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
            self.hidden_states = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
        torch.cuda.empty_cache()
    self.prompt_input_ids = input_ids
    self.input_ids_size = input_ids.shape[-1]
    self.prompt_keys, self.prompt_values = None, None
    self.prev_tokens = [None for _ in range(len(self.original_decoder_layer_cross_attn_forward_funcs))]
    self.last_beam_idx = None
    self.cur_layer_key_value_placeholder = None
    # Makes attention_forward_hook / pre_forward_hook skip their per-step logic while the prompt is encoded.
    self.is_input_encoding_pass = True
    if self.is_encoder_decoder:
        # A one-token decoder input so the encoder-decoder forward pass can run.
        dummy_labels = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
    else:
        dummy_labels = None
    if self.save_heatmap:
        if self.heatmap is not None:
            # Report the heatmap collected during the previous generation before resetting it.
            print(f'Generated: {self.tokenizer.decode(self.generated_input_ids[0])}')
            self.plot_heatmap(self.heatmap[0].detach().cpu().numpy())
        self.heatmap = torch.tensor([], dtype=torch.float, device=input_ids.device)
    self.generated_input_ids = torch.tensor([], dtype=torch.long, device=input_ids.device)
    self.prompt_keys = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
    self.prompt_values = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
    self.prompt_attention_mask = []
    window_indices = self.window_indices(input_ids.shape[-1])
    for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
        logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
        chunk = input_ids[:, context_start_ind:context_end_ind].to(self.device)
        chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind].to(self.device)
        with torch.inference_mode():
            _ = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels) # , return_dict=True, output_hidden_states=True)
        if self.use_datastore:
            # TODO: verify with BART as well
            # hidden_states_to_index = [hidden_states.encoder_last_hidden_state] # list of length 1 of (batch, chunked_source_len, dim)
            hidden_states_to_index = [
                layer_capturer.captured for layer_capturer in self.activation_capturer
            ]
            # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end]
            # update_*_ind are relative to the window, so they slice the captured chunk directly.
            to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
            to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
            # to_apply_mask = to_apply_mask.log().to(to_add[0].dtype)
            to_apply_mask = to_apply_mask.to(to_add[0].dtype)
            if not self.reconstruct_embeddings:
                to_add_embeddings = to_add
                if not self.gpu_datastore:
                    to_add_embeddings = [states.cpu() for states in to_add_embeddings]
                    to_apply_mask = to_apply_mask.cpu()
                for i, layer_states in enumerate(to_add_embeddings):
                    # Zero out padded positions before storing.
                    layer_states = layer_states * to_apply_mask.unsqueeze(-1)
                    self.hidden_states[i].append(layer_states.to(self.datastore_device))
            # list of len layers, inside it there is a list of len batch, each item is (masked_time, dim)
            # for i, to_add_layer in enumerate(to_add):
            #     keys = [key[mask.bool()] for key, mask in zip(to_add_layer, to_apply_mask)]
            #     self.datastore[i].add_keys(keys)
        if (not self.use_datastore) or self.test_datastore:
            layers_kv = [
                self.process_key_value(layer_capturer) # (batch, head, time, dim)
                for layer_capturer in self.activation_capturer
            ] # list of pairs of (batch, head, time, dim)
            # list of (batch, head, chunked_time, dim)
            key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
            value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time)
            # key = torch.stack(key, dim=0) # (num_layers, batch, head, time, dim)
            # value = torch.stack(value, dim=0) # (num_layers, batch, head, time, dim)
            for i, (layer_key, layer_value) in enumerate(zip(key, value)):
                self.prompt_keys[i].append(layer_key) # (num_layers, batch, head, chunked_source_len, dim)
                self.prompt_values[i].append(layer_value) # (num_layers, batch, head, chunked_source_len, dim)
            self.prompt_attention_mask.append(chunk_attention_mask) # (batch, chunked_source_len)
    if self.use_datastore:
        # keys are all in datastore already!
        if not self.reconstruct_embeddings:
            # self.hidden_states = [torch.cat(layer_hidden_states, axis=1) for layer_hidden_states in self.hidden_states]
            concat_hidden_states = []
            for i in range(len(self.hidden_states)):
                concat_hidden_states.append(torch.cat(self.hidden_states[i], axis=1))
                # Release the per-window pieces as soon as they are concatenated.
                self.hidden_states[i] = None
            self.hidden_states = concat_hidden_states
        # NOTE(review): with reconstruct_embeddings=True nothing was appended above, so this
        # zip sees the empty per-layer lists — confirm train_index is meant to get those.
        for datastore, layer_hidden_states in zip(self.datastore, self.hidden_states):
            datastore.train_index(layer_hidden_states)
    if (not self.use_datastore) or self.test_datastore:
        # Join the per-window pieces along the time axis.
        for i, (layer_keys, layer_values) in enumerate(zip(self.prompt_keys, self.prompt_values)):
            self.prompt_keys[i] = torch.cat(layer_keys, dim=-2)
            self.prompt_values[i] = torch.cat(layer_values, dim=-2)
        # self.prompt_keys = torch.cat(self.prompt_keys, dim=-2) # (num_layers, batch, head, source_len, dim)
        # self.prompt_values = torch.cat(self.prompt_values, dim=-2) # (num_layers, batch, head, source_len, dim)
        self.prompt_attention_mask = torch.cat(self.prompt_attention_mask, dim=-1) # (batch, source_len)
    self.is_input_encoding_pass = False
    if self.verbose:
        print(f'Input: '
            f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
            f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
def chunked_encode_input(self, input_ids, attention_mask):
    """Encode ``input_ids`` with the base model's encoder one window at a time.

    Used on the training path by ``pre_forward_hook``. Each window from
    ``self.window_indices`` is run through ``self.model.base_model.encoder``. Only the
    ``[update_start, update_end)`` slice of each output (and of its mask) is kept, and
    the kept slices are concatenated along dim 1.

    ``is_input_encoding_pass`` is True while encoding and is reset even if the encoder
    raises. Otherwise a failed call would leave every later forward pass in
    "encoding" mode.

    Args:
        input_ids: ids of the full input; ``shape[-1]`` is its length.
        attention_mask: mask aligned with ``input_ids``.

    Returns:
        ``(long_inputs_encoded, long_inputs_mask)``: the concatenated last hidden
        states and the matching mask.
    """
    long_inputs_encoded = []
    long_inputs_mask = []
    window_indices = self.window_indices(input_ids.shape[-1])
    self.is_input_encoding_pass = True
    try:
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim)
            # update_*_ind are relative to the window start, so they index the chunk output directly.
            long_inputs_encoded.append(encoder_last_hidden_state[:, update_start_ind:update_end_ind])
            long_inputs_mask.append(chunk_attention_mask[:, update_start_ind:update_end_ind])
        long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim)
        long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len)
    finally:
        self.is_input_encoding_pass = False
    if self.verbose:
        print(f'Input: '
            f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
            f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
    return long_inputs_encoded, long_inputs_mask
def window_indices(self, total_seq_len):
    """Split a sequence of ``total_seq_len`` tokens into (possibly overlapping) windows.

    Adapted from SLED (Ivgy et al., 2022):
    https://github.com/Mivg/SLED/blob/main/sled/modeling_sled.py#L467

    Each entry is ``(context_start, context_end, update_start, update_end)``. The model
    is run on tokens ``[context_start, context_end)``, and only positions
    ``[update_start, update_end)`` *relative to context_start* are kept. Together the
    kept slices cover the sequence once. Consecutive windows advance by
    ``model_encoder_max_len - 2 * window_margin``. The last window is aligned to the
    end of the sequence. For encoder-decoder models the first window also drops its
    right margin.

    Raises:
        ValueError: if the sequence needs several windows but the stride is not
            positive (e.g. ``chunk_overlap >= 1``). The old loop never terminated here.
    """
    if total_seq_len <= self.model_encoder_max_len:
        return [(0, total_seq_len, 0, total_seq_len)]
    stride = self.model_encoder_max_len - 2 * self.window_margin
    if stride <= 0:
        raise ValueError(
            f'model_encoder_max_len ({self.model_encoder_max_len}) must be larger than twice '
            f'the window margin ({self.window_margin}); use chunk_overlap < 1')
    context_start = update_start_ind = 0
    context_end = self.model_encoder_max_len
    if self.is_encoder_decoder:
        update_end_ind = context_end - self.window_margin
    else:
        update_end_ind = context_end
    # first window always should update from the beginning
    results = [(context_start, context_end, update_start_ind, update_end_ind)]
    while context_end < total_seq_len:
        context_end = min(total_seq_len, context_end + stride)
        is_last = context_end >= total_seq_len
        # The last window is shifted back so it still spans a full model window.
        context_start = total_seq_len - self.model_encoder_max_len if is_last else context_start + stride
        update_start_ind = max(update_start_ind + stride, update_end_ind)
        # last window always should update until the end
        update_end_ind = total_seq_len if is_last else min(total_seq_len, update_end_ind + stride)
        results.append((context_start, context_end,
                        update_start_ind - context_start, update_end_ind - context_start))
    return results
def pre_generate_hook(self, input_ids, **kwargs):
    """Replacement for ``model.generate`` installed by ``inject_hooks``.

    Encodes the whole prompt into retrieval memory with ``reset_memory``, then runs the
    original ``generate`` on the part of the prompt that fits the model window. That is
    the first ``actual_model_window_size`` tokens for encoder-decoder models and the
    last ones for decoder-only models. The attention mask is cut with the same slice as
    the ids. A missing or ``None`` mask is treated as all-ones. ``use_cache`` is always
    forced on. All other keyword arguments are passed through.
    """
    attention_mask = kwargs.pop('attention_mask', None)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    self.reset_memory(input_ids, attention_mask)
    window = self.actual_model_window_size
    # The mask must follow the ids. Decoder-only models keep the *last* window, so
    # slicing the mask from the front would misalign it whenever there is padding.
    if self.is_encoder_decoder:
        window_slice = slice(None, window)
    else:
        window_slice = slice(-window, None)
    new_kwargs = dict(kwargs)
    new_kwargs['attention_mask'] = attention_mask[:, window_slice].to(self.device)
    new_kwargs['use_cache'] = True
    input_ids_prefix = input_ids[:, window_slice].to(self.device)
    return self.original_generate_func(input_ids_prefix, **new_kwargs)
def pre_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
    """Replacement for ``model.forward`` while Unlimiformer hooks are installed.

    Disables gradient checkpointing via ``self.set_gradient_checkpointing(False)``.
    Unless the prompt-encoding pass is running:

    - Training (``self.model.training``): the full input is encoded window-by-window
      with ``chunked_encode_input`` and stored in ``self.long_inputs_encoded`` /
      ``self.long_inputs_mask``. Only the first ``actual_model_window_size`` tokens go
      to the original forward. A missing ``attention_mask`` is treated as all-ones,
      because the chunked encoder slices it per window.
    - Inference: flags the first decoding step (no ``past_key_values``), counts one
      more token per call that carries ``input_ids`` and appends ``decoder_input_ids``
      to ``self.generated_input_ids``.

    ``is_first_test_decoding_step`` is cleared after the original forward, even if it
    raises. ``attention_forward_hook`` does nothing while that flag is set.
    """
    self.set_gradient_checkpointing(False)
    if not self.is_input_encoding_pass:
        if self.model.training:
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
            input_ids = input_ids[:, :self.actual_model_window_size]
            attention_mask = attention_mask[:, :self.actual_model_window_size]
        else:
            if kwargs.get('past_key_values') is None:
                self.is_first_test_decoding_step = True
            if input_ids is not None:
                self.input_ids_size += 1
            if kwargs.get('decoder_input_ids') is not None:
                self.generated_input_ids = torch.cat([self.generated_input_ids, kwargs['decoder_input_ids']], axis=-1)
    try:
        return self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
    finally:
        self.is_first_test_decoding_step = False
def create_cross_attn_pre_forward_hook(self, original_cross_attn_forward_func, cur_layer_num):
    """Wrap a cross-attention ``forward`` so retrieval can edit its past key/value.

    The returned function records ``cur_layer_num`` in ``self.cur_decoder_layer_index``,
    which ``attention_forward_hook`` uses to pick this layer's prompt keys or datastore.
    When a ``past_key_value`` kwarg is given, the function converts it to a list and
    passes that same list both to the original forward and to
    ``self.cur_layer_key_value_placeholder``. This lets ``attention_forward_hook``
    replace its entries in place. NOTE(review): this presumably relies on the hooked
    attention op running inside this forward before the cache is consumed — confirm
    against ``attention_op_to_run``.

    In training mode each target position is attended separately. Hidden states are
    reshaped to ``(-1, 1, dim)`` and the mask to ``(-1, 1, 1, mask.shape[-1])`` before
    the call, and the output is reshaped back to ``(batch, tgt_len, dim)``.
    NOTE(review): that branch requires a non-None ``attention_mask``.
    """
    def attention_pre_forward_hook(hidden_states, attention_mask=None, *args, **kwargs):
        self.cur_decoder_layer_index = cur_layer_num
        if kwargs.get('past_key_value') is not None:
            # it's a tuple, and we convert it to a list to be able to perform assignment
            # and modify its items from our attention_forward_hook
            self.cur_layer_key_value_placeholder = \
                kwargs['past_key_value'] = list(kwargs['past_key_value']) # (batch, head, time, attn_dim)
        batch_size, tgt_len, dim = hidden_states.shape
        if self.model.training:
            # from: (batch, tgt_len, dim) to: (batch * tgt_len, 1, dim)
            hidden_states = hidden_states.reshape(-1, 1, hidden_states.shape[-1])
            # from: (batch, 1, tgt_len, dim) to: (batch * tgt_len, 1, 1, dim)
            attention_mask = attention_mask.reshape(-1, 1, 1, attention_mask.shape[-1])
            attn_output, attn_weights_reshaped, past_key_value = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)
            attn_output = attn_output.reshape(batch_size, tgt_len, dim)
            result = (attn_output, attn_weights_reshaped, past_key_value)
        else:
            result = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)
        # Uri: this part adds the generated tokens to the prompt.
        # However it was commented out because currently we always keep the generated tokens in the attention window
        # if not self.is_encoder_decoder and not self.is_input_encoding_pass and \
        #         past_key_value[0].shape[2] > self.prompt_keys[self.cur_decoder_layer_index].shape[2]:
        #     self.prompt_keys[self.cur_decoder_layer_index] = torch.cat([self.prompt_keys[self.cur_decoder_layer_index], past_key_value[0][:,:,-1:]], dim=-2)
        #     self.prompt_values[self.cur_decoder_layer_index] = torch.cat([self.prompt_values[self.cur_decoder_layer_index], past_key_value[1][:,:,-1:]], dim=-2)
        #     if self.cur_decoder_layer_index == self.model.config.num_hidden_layers - 1:
        #         self.prompt_attention_mask = torch.cat([
        #             self.prompt_attention_mask,
        #             torch.ones([self.prompt_attention_mask.shape[0], 1], dtype=self.prompt_attention_mask.dtype).to(self.device)], dim=-1)
        return result
    return attention_pre_forward_hook
def attention_forward_hook(self, module, input, output):
    """Forward hook that swaps retrieved prompt keys/values into the current layer's cache.

    ``inject_hooks`` registers this on every module from ``attention_op_to_run``. It
    does nothing during the prompt-encoding pass or on the first decoding step.
    Otherwise, under ``torch.no_grad()``:

    1. ``topk`` is bounded by the prompt length and the current cache length. For
       decoder-only models it is further bounded by the number of generated tokens.
       With ``gpu_index`` it is capped at 2048.
    2. A query is built from the hooked module's ``output`` via ``process_query``
       (last position only) and restricted to ``self.head_nums``.
    3. The top-k prompt positions per head are selected in one of two ways. Either
       the datastore is searched (``use_datastore``), or the query is scored by dot
       product against the cached ``prompt_keys``. ``test_datastore`` does both and
       asserts that they agree.
    4. The first ``topk`` entries of ``cur_layer_key_value_placeholder`` (the mutable
       past key/value list made by the cross-attention pre-forward hook) are
       overwritten with the retrieved keys and values.

    Always returns ``None``, so the hooked module's output is not replaced.
    """
    # output: (batch, time, 3 * heads * attention_dim)
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return
    with torch.no_grad():
        prompt_size = self.prompt_input_ids.shape[1]
        generated_size = self.input_ids_size - prompt_size
        window_size = self.cur_layer_key_value_placeholder[0].shape[-2]
        # topk = min(self.actual_model_window_size, attn_weights.shape[-1])
        topk = min(prompt_size, window_size)
        if not self.is_encoder_decoder:
            # NOTE(review): presumably leaves room for the generated tokens that stay at
            # the end of the cache (only [:topk] is overwritten below) — confirm.
            topk = min(topk, window_size - generated_size + 1)
        if self.gpu_index:
            topk = min(topk, 2048)
        query = self.process_query(output)[:,-1] # (batch * beam, head, dim)
        query = query[:, self.head_nums] # (batch * beam, head, dim)
        if self.use_datastore:
            # query: (batch, beam, head, dim)
            # need to multiply by key vector
            # query.view(query.shape[0], query.shape[1] * query.shape[2])
            # k_proj in attention?
            # Encoder-decoder models share one datastore; decoder-only has one per layer.
            datastore_index = 0 if self.is_encoder_decoder else self.cur_decoder_layer_index
            attention_layer_list = self.get_kv_projections(self.layer_begin, self.layer_end)
            k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
            v_proj_layer = [layers[1] for layers in attention_layer_list][self.cur_decoder_layer_index]
            # modify query by k_projs
            k_proj = k_proj_layer.weight
            datastore_query = self.preprocess_query(query, k_proj) # (batch * beam, num_heads, embed_dim)
            batch_size = self.datastore[datastore_index].batch_size
            datastore_query = datastore_query.view((batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)
            # then search
            if self.reconstruct_embeddings:
                # embeddings: (batch, beam * head, actual_model_window_size, dim)
                _, top_search_key_indices, embeddings = self.datastore[datastore_index].search_and_reconstruct(datastore_query, k=topk)
            else:
                _, top_search_key_indices = self.datastore[datastore_index].search(datastore_query, k=topk)
                # self.embeddings: (batch, src_len, dim)
                # indices: (batch, beam * head, actual_model_window_size)
                # embeddings: (batch, beam * head, actual_model_window_size, dim)
                embeddings = torch.take_along_dim(input=self.hidden_states[datastore_index].unsqueeze(1),
                    indices=top_search_key_indices.unsqueeze(-1).to(self.hidden_states[datastore_index].device), dim=-2)
                embeddings = embeddings.to(self.device)
            # (batch, beam, head, actual_model_window_size)
            # top_search_key_scores = top_search_key_scores.reshape(batch_size, -1, *top_search_key_scores.shape[1:])
            top_search_key_indices = top_search_key_indices.reshape(batch_size, -1, *top_search_key_indices.shape[1:])
            # embeddings: (batch, beam, head, actual_model_window_size, dim)
            embeddings = embeddings.reshape(batch_size, -1, self.num_heads, *embeddings.shape[2:])
        # raw_values are actually token indices; need to look them up
        if (not self.use_datastore) or self.test_datastore:
            this_layer_prompt_keys = self.prompt_keys[self.cur_decoder_layer_index]
            this_layer_prompt_values = self.prompt_values[self.cur_decoder_layer_index]
            # query: (batch * beam, head, dim)
            batch_size = self.prompt_input_ids.shape[0]
            beam_size = query.shape[0] // batch_size
            # query: (batch, beam, head, dim)
            query = query.reshape(batch_size, beam_size, *query.shape[1:])
            # this_layer_prompt_keys: (batch, head, source_len, dim)
            # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim)
            # query.unsqueeze(-1): (batch, beam, head, dim, 1)
            # attn_weights: (batch, beam, head, source_len)
            attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums], query.unsqueeze(-1)).squeeze(-1)
            # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2)
            # Additive mask: padded prompt positions get -1e9 so top-k never picks them.
            prompt_attention_mask_to_add = (1 - self.prompt_attention_mask) * -1e9 # (batch, source_len)
            prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1)
            attn_weights += prompt_attention_mask_to_add # (batch, beam, head, source_len)
            if self.exclude_attention and attn_weights.shape[-1] > self.actual_model_window_size:
                # Penalize the first actual_model_window_size prompt positions.
                attn_weights[..., :self.actual_model_window_size] -= 1e9
            # target_keys, target_values, topk = self.get_target_slices(output)
            top_key_scores, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True) # (batch, beam, head, trunc_source)
            if self.save_heatmap:
                # heatrow: (beam, heads, source_len)
                heatrow = torch.zeros([top_key_indices.shape[1], top_key_indices.shape[2], this_layer_prompt_keys.shape[-2]], dtype=torch.float)
                heatrow = heatrow.scatter(index=top_key_indices[0], src=torch.ones_like(top_key_scores[0]), dim=-1)
                # heatrow = torch.nn.functional.softmax(heatrow, dim=-1)
                # self.heatmap: (beam, heads, targets, source_len)
                self.heatmap = torch.cat([self.heatmap, heatrow.unsqueeze(-2)], axis=-2)
            if self.test_datastore:
                assert top_key_indices.shape == top_search_key_indices.shape
                assert torch.mean((top_key_indices == top_search_key_indices).float()) > 0.99
        if self.verbose:
            if self.is_encoder_decoder:
                for i, beam in enumerate(self.generated_input_ids):
                    print(f'({i}) Generated: {self.tokenizer.decode(beam)}')
            # else:
            #     print(f'Generated: {self.tokenizer.decode(self.input_ids)}')
            print()
        if self.use_datastore:
            # k_proj_layer.weight, v_proj_layer.weight: (embed_dim, embed_dim)
            # embeddings: (batch, beam, head, encoder_len, embed_dim)
            retrieved_keys, retrieved_values = self.post_process_retrieved(embeddings, k_proj_layer, v_proj_layer, top_search_key_indices)
        else:
            # this_layer_prompt_keys: (batch, head, source_len, dim)
            # top_key_indices: (batch, beam, head, trunc_source)
            retrieved_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                dim=-2) # (batch, head, trunc_source, attn_dim)
            retrieved_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                dim=-2) # (batch, head, trunc_source, attn_dim)
        if self.test_datastore:
            correct_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                dim=-2) # (batch, head, trunc_source, attn_dim)
            correct_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                dim=-2) # (batch, head, trunc_source, attn_dim)
            assert correct_keys.shape == retrieved_keys.shape
            assert correct_values.shape == retrieved_values.shape
            assert torch.mean(torch.isclose(correct_keys, retrieved_keys, rtol=1e-3, atol=1e-3).float()) > 0.99
            assert torch.mean(torch.isclose(correct_values, retrieved_values, rtol=1e-3, atol=1e-3).float()) > 0.99
        # retrieved_keys, retrieved_values: (batch * beam, head, encoder_len, attn_dim)
        retrieved_keys = retrieved_keys.flatten(0, 1)[:,:,:topk]
        retrieved_values = retrieved_values.flatten(0, 1)[:,:,:topk]
        # In-place edit of the list shared with the cross-attention forward: the first
        # topk cache positions become the retrieved entries, the rest are kept.
        self.cur_layer_key_value_placeholder[0] = torch.cat([retrieved_keys, self.cur_layer_key_value_placeholder[0][:,:,topk:]], dim=-2)
        self.cur_layer_key_value_placeholder[1] = torch.cat([retrieved_values, self.cur_layer_key_value_placeholder[1][:,:,topk:]], dim=-2)
    return
def train_attention_forward_hook(self, module, input, output):
    """Forward hook used in training: keep only the top-k prompt keys/values per target.

    Scores every prompt key against every target-position query. Positions where
    ``self.long_inputs_mask`` is 0 get -1e9 added so they are never selected. It then keeps
    the ``min(self.actual_model_window_size, source_len)`` best keys per (target, head).

    Reads:
        self.cur_layer_key_value_placeholder[0] / [1]: prompt keys / values,
            (batch, head, source_len, dim).
        self.long_inputs_mask: (batch, source_len).

    Writes:
        self.cur_layer_key_value_placeholder[0] / [1]: selected keys / values,
            (batch * tgt_len, head, k, dim).

    Does nothing during the input-encoding pass or the first test decoding step.
    """
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return
    prompt_keys = self.cur_layer_key_value_placeholder[0]
    prompt_values = self.cur_layer_key_value_placeholder[1]
    with torch.no_grad():
        query = self.process_query(output)
        batch_size = prompt_keys.shape[0]
        tgt_len = query.shape[0] // batch_size
        # The reshape below only succeeds when query is (batch * tgt_len, 1, head, dim);
        # result: (batch, tgt_len, head, dim).
        query = query.reshape(batch_size, tgt_len, *query.shape[2:])

        # (batch, 1, head, source_len, dim) @ (batch, tgt_len, head, dim, 1)
        #   -> (batch, tgt_len, head, source_len, 1), laid out as (batch, tgt_len, head, 1, source_len)
        attn_weights = torch.matmul(prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \
            .reshape(batch_size, tgt_len, query.shape[-2], 1, prompt_keys.shape[-2])

        # Masked-out prompt positions get a huge negative score so top-k never picks them.
        prompt_attention_mask_to_add = (1 - self.long_inputs_mask) * -1e9  # (batch, source_len)
        # -> (batch, 1, 1, 1, source_len), broadcast over tgt_len and head
        prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        attn_weights += prompt_attention_mask_to_add  # (batch, tgt_len, head, 1, source_len)

        topk = min(self.actual_model_window_size, attn_weights.shape[-1])
        _, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True)  # (batch, tgt_len, head, 1, topk)

        gather_indices = top_key_indices.unsqueeze(-1)  # (batch, tgt_len, head, 1, topk, 1)
        # Prompt tensors become (batch, 1, head, 1, source_len, dim) and are gathered along
        # source_len -> (batch, tgt_len, head, 1, topk, dim).
        new_keys = torch.take_along_dim(prompt_keys.unsqueeze(2).unsqueeze(1), indices=gather_indices, dim=-2)
        new_values = torch.take_along_dim(prompt_values.unsqueeze(2).unsqueeze(1), indices=gather_indices, dim=-2)

        # -> (batch * tgt_len, head, topk, dim)
        self.cur_layer_key_value_placeholder[0] = new_keys.flatten(0, 1).squeeze(2)
        self.cur_layer_key_value_placeholder[1] = new_values.flatten(0, 1).squeeze(2)
def preprocess_query(self, query, k_proj_weight):
    """Fold the key projection into per-head queries so they can be matched against raw
    (un-projected) embeddings.

    For a key projection ``k = W x + b`` whose rows are grouped by head (``W_h``),
    ``q_h . k_h = (q_h^T W_h) . x + q_h . b_h``. This returns ``q_h^T W_h`` for every head.
    The bias term is dropped; presumably it is irrelevant to the caller's search because it
    is constant for a given query.

    Args:
        query: per-head queries with trailing dims (num_heads, attn_dim), presumably
            (batch * beam, num_heads, attn_dim).
        k_proj_weight: assumed to be an ``nn.Linear`` weight of shape
            (num_heads * attn_dim, in_features).

    Returns:
        Tensor with trailing dims (num_heads, in_features).
    """
    # Use in_features (last dim), not out_features: the two differ for non-square key
    # projections, and the previous shape[0]-based view failed for them.
    in_features = k_proj_weight.shape[-1]
    k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], in_features)  # (1, num_heads, attn_dim, in_features)
    datastore_query = torch.matmul(query.unsqueeze(-2), k_proj)  # (..., num_heads, 1, in_features)
    return datastore_query.squeeze(-2)  # (..., num_heads, in_features)
def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
    """Project retrieved raw embeddings into per-head keys and values.

    Args:
        embeddings: retrieved embeddings whose last dim is the projection's input size,
            presumably (batch, beam, head, encoder_len, embed_dim).
        k_proj_layer, v_proj_layer: linear layers exposing ``.weight`` of shape
            (num_heads * attn_dim, embed_dim) and an optional ``.bias``.
        top_search_key_indices: unused here; kept for the interface.

    Returns:
        ``(retrieved_keys, retrieved_values)``, each (..., head, encoder_len, attn_dim).
    """
    embed_dim = embeddings.shape[-1]

    def _per_head(layer):
        # One (embed_dim, attn_dim) slice per head, plus a broadcastable bias (or 0).
        # attn_dim comes from the layer's out_features; it need not equal
        # embed_dim // num_heads when the projection is not square.
        attn_dim = layer.weight.shape[0] // self.num_heads
        weight = layer.weight.view(1, 1, self.num_heads, attn_dim, embed_dim).transpose(-2, -1)  # (1, 1, heads, embed_dim, attn_dim)
        bias = 0
        if layer.bias is not None:
            bias = layer.bias.view(1, 1, self.num_heads, 1, attn_dim)
        return weight, bias

    k_weight, k_bias = _per_head(k_proj_layer)
    v_weight, v_bias = _per_head(v_proj_layer)
    retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias  # (..., head, encoder_len, attn_dim)
    retrieved_values = torch.matmul(embeddings, v_weight) + v_bias  # (..., head, encoder_len, attn_dim)
    return retrieved_keys, retrieved_values
def set_gradient_checkpointing(self, value):
    """Set ``gradient_checkpointing`` on the wrapped model's decoder to ``value``."""
    decoder = self.model.base_model.decoder
    decoder.gradient_checkpointing = value
def reorder_cache_hook(self, past, beam_idx):
    """Reorder Unlimiformer's own per-beam state by ``beam_idx``, then delegate.

    Updates ``last_beam_idx``, ``generated_input_ids``, every non-None entry of
    ``prev_tokens`` (in place, keeping the same list), and ``heatmap`` when heatmap saving
    is on and it is non-empty. Returns whatever ``original_reorder_cache_func`` returns.
    """
    self.last_beam_idx = beam_idx
    self.generated_input_ids = self.generated_input_ids[beam_idx]

    for layer_idx, cached in enumerate(self.prev_tokens):
        if cached is None:
            continue
        # Select along the flattened leading two dims, then restore the original shape.
        flat = cached.flatten(0, 1)
        self.prev_tokens[layer_idx] = flat[beam_idx].reshape(cached.shape)

    # `and` short-circuits, so `heatmap` is only touched when saving is enabled.
    heatmap_active = self.save_heatmap and self.heatmap.numel() > 0
    if heatmap_active:
        self.heatmap = self.heatmap[beam_idx]

    return self.original_reorder_cache_func(past, beam_idx)
@classmethod
def convert_model(cls, model, *args, **kwargs):
    """Attach the Unlimiformer adapter that matches ``model``'s class and return ``model``.

    The adapter is instantiated as ``Adapter(model, *args, **kwargs)``. The base
    ``Unlimiformer.__init__`` stores the adapter on ``model.unlimiformer``.

    Raises:
        KeyError: if neither ``type(model)`` nor any of its base classes is supported.
    """
    type_to_class = {
        BartModel: UnlimiformerBART,
        BartForConditionalGeneration: UnlimiformerBART,
        T5Model: UnlimiformerT5,
        T5ForConditionalGeneration: UnlimiformerT5,
        LEDModel: UnlimiformerLED,
        LEDForConditionalGeneration: UnlimiformerLED,
    }
    # Walk the MRO so user subclasses of a supported class also work. The exact type comes
    # first in the MRO, so directly supported classes map exactly as before.
    for klass in type(model).__mro__:
        adapter_cls = type_to_class.get(klass)
        if adapter_cls is not None:
            adapter_cls(model, *args, **kwargs)
            return model
    supported = ', '.join(sorted(t.__name__ for t in type_to_class))
    raise KeyError(f'Unsupported model type: {type(model).__name__} (supported: {supported})')
def plot_heatmap(self, data, xticklabels='auto', yticklabels='auto'):
    """Draw, save and show one heatmap per head of ``data``.

    Args:
        data: indexable array with ``.shape``, laid out as (heads, targets, source_len).
        xticklabels: currently not forwarded; x ticks are labelled every 512 columns.
        yticklabels: forwarded to ``seaborn.heatmap``.

    Side effects:
        Writes ``knns_head{i}.pdf`` to the current directory for each head ``i`` and calls
        ``plt.show()``. Each figure is closed afterwards.
    """
    import seaborn as sb
    import matplotlib.pyplot as plt

    for i in range(data.shape[0]):
        fig, ax = plt.subplots(1, 1, figsize=(40, 100))
        ax.set_title(f'Head #{i}, length: {data.shape[2]}, target length: {data.shape[1]}')
        # NOTE(review): `xticklabels` is intentionally left as 512 to preserve existing output.
        ax = sb.heatmap(data[i], annot=False, fmt='.2f',
                        xticklabels=512, yticklabels=yticklabels, ax=ax)
        ax.xaxis.tick_top()
        plt.savefig(f'knns_head{i}.pdf')
        plt.show()
        # Without this, one very large figure per head stays open in pyplot's registry.
        plt.close(fig)
class UnlimiformerBART(Unlimiformer[BartModel]):
    """Unlimiformer adapter for BART-style encoder-decoder models.

    Every hook targets the decoder layers' cross-attention module (``layer.encoder_attn``)
    and its ``q_proj`` / ``k_proj`` / ``v_proj`` projections.
    """

    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Project encoder states through ``decoder_layer``'s cross-attention.

        Returns:
            ``(key, value)``, each viewed as (batch, heads, time, head_dim).
        """
        cross_attn = decoder_layer.encoder_attn
        heads, head_dim = cross_attn.num_heads, cross_attn.head_dim

        projected = cross_attn.k_proj(encoder_hidden_states)
        key = projected.view(projected.shape[0], -1, heads, head_dim).transpose(1, 2).contiguous()

        projected = cross_attn.v_proj(encoder_hidden_states)
        value = projected.view(projected.shape[0], -1, heads, head_dim).transpose(1, 2).contiguous()
        return key, value

    def process_key_value(self, capturers):
        """Reshape captured k/v projection outputs to (batch, heads, time, head_dim).

        ``capturers`` is a ``(key_capturer, value_capturer)`` pair exposing ``.captured``.
        The head geometry is read from the last decoder layer's cross-attention.
        """
        key_capturer, value_capturer = capturers
        raw_key = key_capturer.captured
        raw_value = value_capturer.captured
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        heads, head_dim = cross_attn.num_heads, cross_attn.head_dim
        key = raw_key.view(raw_key.shape[0], -1, heads, head_dim).transpose(1, 2).contiguous()
        value = raw_value.view(raw_value.shape[0], -1, heads, head_dim).transpose(1, 2).contiguous()
        return key, value

    def process_query(self, output):
        """Split a query projection output into heads: (batch, time, heads, head_dim).

        Unlike keys/values, no transpose is applied; the time axis stays at dim 1.
        """
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        batch, time = output.shape[0], output.shape[1]
        return output.view(batch, time, cross_attn.num_heads, cross_attn.head_dim).contiguous()

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k_proj, v_proj]`` of each selected decoder layer's cross-attention."""
        selected = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [[layer.encoder_attn.k_proj, layer.encoder_attn.v_proj] for layer in selected]

    def activation_to_capture(self, layer_begin, layer_end):
        """Modules whose activations are captured.

        Without a datastore these are the k/v projections. With one, it is the last
        encoder layer.
        """
        if not self.use_datastore:
            return self.get_kv_projections(layer_begin, layer_end)
        return [self.model.base_model.encoder.layers[-1]]

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention query projection of each selected decoder layer."""
        selected = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [layer.encoder_attn.q_proj for layer in selected]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the selected slice of decoder layers."""
        return self.model.base_model.decoder.layers[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the decoder layer's self-attention module."""
        return decoder_layer.self_attn

    def cross_attention(self, decoder_layer):
        """Return the decoder layer's cross-attention module."""
        return decoder_layer.encoder_attn

    def window_size(self):
        """Return ``config.max_position_embeddings``."""
        return self.model.config.max_position_embeddings

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build keyword arguments for a decoder layer call.

        ``past_key_value`` is ``(None, None, key, value)`` unless both ``key`` and ``value``
        are None, in which case it is None. The incoming ``past_key_value``,
        ``position_bias`` and ``encoder_decoder_position_bias`` are ignored here.
        """
        no_cross_cache = key is None and value is None
        return {
            'hidden_states': hidden_states,
            'attention_mask': attention_mask,
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': encoder_attention_mask,
            'layer_head_mask': layer_head_mask,
            'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
            'past_key_value': None if no_cross_cache else (None, None, key, value),
            'output_attentions': output_attentions,
            'use_cache': use_cache,
        }
class UnlimiformerT5(Unlimiformer[T5Model]):
    """Unlimiformer adapter for T5-style encoder-decoder models.

    Hooks target each decoder block's cross-attention sub-layer
    (``block.layer[1].EncDecAttention``) and its ``q`` / ``k`` / ``v`` projections.
    """

    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Project encoder states through ``decoder_layer``'s cross-attention.

        Returns:
            ``(key, value)``, each viewed as (batch, heads, time, key_value_proj_dim).
        """
        attention = decoder_layer.layer[1].EncDecAttention
        key = attention.k(encoder_hidden_states)
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = attention.v(encoder_hidden_states)
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_key_value(self, capturers):
        """Reshape captured k/v projection outputs to (batch, heads, time, key_value_proj_dim).

        ``capturers`` is a ``(key_capturer, value_capturer)`` pair exposing ``.captured``.
        The head geometry is read from the last decoder block's cross-attention.
        """
        key_capturer, value_capturer = capturers
        key, value = key_capturer.captured, value_capturer.captured
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_query(self, output):
        """Split a query projection output into heads: (batch, time, heads, key_value_proj_dim).

        Unlike keys/values, no transpose is applied; the time axis stays at dim 1.
        """
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        return output.view(output.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).contiguous()

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k, v]`` of each selected decoder block's cross-attention."""
        return [
            [layer.layer[1].EncDecAttention.k, layer.layer[1].EncDecAttention.v]
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        """Modules whose activations are captured.

        Without a datastore these are the k/v projections. With one, it is the encoder's
        final layer norm.
        """
        if self.use_datastore:
            # The T5 encoder stack keeps its blocks in `.block` (like the decoder accesses in
            # this class), so the previous `encoder.layers[-1]` raised AttributeError.
            # Capture `final_layer_norm` instead: its output is the encoder hidden state that
            # cross-attention projects (dropout aside), and it is a plain tensor. A T5 block
            # returns a multi-element tuple.
            return [self.model.base_model.encoder.final_layer_norm]
        return self.get_kv_projections(layer_begin, layer_end)

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention query projection of each selected decoder block."""
        return [
            layer.layer[1].EncDecAttention.q
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the selected slice of decoder blocks."""
        return self.model.base_model.decoder.block[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the block's self-attention sub-layer (``layer[0]``)."""
        return decoder_layer.layer[0]

    def cross_attention(self, decoder_layer):
        """Return the block's cross-attention sub-layer (``layer[1]``)."""
        return decoder_layer.layer[1]

    def window_size(self):
        """Return ``config.n_positions`` if present, otherwise 1024."""
        try:
            return self.model.config.n_positions
        except AttributeError:
            # Some T5 configs do not define n_positions; fall back to a fixed window.
            return 1024

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build keyword arguments for a decoder block call.

        ``past_key_value`` is ``(None, None, key, value)`` unless both ``key`` and ``value``
        are None, in which case it is None. The incoming ``past_key_value`` is ignored.
        """
        args = {'hidden_states': hidden_states,
                'attention_mask': attention_mask,
                'position_bias': position_bias,
                'encoder_hidden_states': encoder_hidden_states,
                'encoder_attention_mask': encoder_attention_mask,
                'encoder_decoder_position_bias': encoder_decoder_position_bias,
                'layer_head_mask': layer_head_mask,
                'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
                'past_key_value': (None, None, key, value),
                'use_cache': use_cache,
                'output_attentions': output_attentions}
        if key is None and value is None:
            args['past_key_value'] = None
        return args
class UnlimiformerLED(UnlimiformerBART):
    """Unlimiformer adapter for LED models.

    Reuses all of the BART hooks and overrides only the window size.
    """

    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def window_size(self):
        """Return ``config.max_encoder_position_embeddings``."""
        config = self.model.config
        return config.max_encoder_position_embeddings
# class UnlimiformerLLaMa(Unlimiformer[LlamaModel]): | |
# def __init__(self, model: LlamaModel, *args, **kwargs): | |
# super().__init__(model, *args, **kwargs) | |
# def get_kv_projections(self, layer_begin, layer_end): | |
# return [ | |
# [layer.self_attn.k_proj, layer.self_attn.v_proj] | |
# for layer in self.model.base_model.layers[layer_begin:layer_end] | |
# ] | |
# def activation_to_capture(self, layer_begin, layer_end): | |
# if self.use_datastore: | |
# return [ | |
# layer.input_layernorm | |
# for layer in self.model.base_model.layers[layer_begin:layer_end] | |
# ] | |
# else: | |
# return self.get_kv_projections(layer_begin, layer_end) | |
# def attention_op_to_run(self, layer_begin, layer_end): | |
# return [ | |
# layer.self_attn.q_proj | |
# for layer in self.model.base_model.layers[layer_begin:layer_end] | |
# ] | |
# def attention_layer_to_run(self, layer_begin, layer_end): | |
# return self.model.base_model.layers[layer_begin:layer_end] | |
# def self_attention(self, decoder_layer): | |
# return decoder_layer.self_attn | |
# def cross_attention(self, decoder_layer): | |
# return decoder_layer.self_attn | |
# def window_size(self): | |
# return self.model.config.max_position_embeddings | |
# def set_gradient_checkpointing(self, value): | |
# self.model.base_model.gradient_checkpointing = value | |
# def process_key_value(self, capturers): | |
# key_capturer, value_capturer = capturers | |
# # (batch, time, heads * attn_dim) | |
# key, value = key_capturer.captured, value_capturer.captured | |
# attention = self.model.base_model.layers[-1].self_attn | |
# # (batch, heads, time, attn_dim) | |
# key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() | |
# value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() | |
# return key, value | |
# def process_query(self, output): | |
# # output: (batch, time, heads * attn_dim) | |
# attention = self.model.base_model.layers[-1].self_attn | |
# # query: (batch, time, heads, attn_dim) | |
# query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous() | |
# return query | |
# def rotate_half(self, x): | |
# """Rotates half the hidden dims of the input.""" | |
# x1 = x[..., : x.shape[-1] // 2] | |
# x2 = x[..., x.shape[-1] // 2 :] | |
# return torch.cat((-x2, x1), dim=-1) | |
# def preprocess_query(self, query, k_proj_weight): | |
# # query: (batch * time, head, dim) | |
# attention = self.model.base_model.layers[-1].self_attn | |
# num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1], self.actual_model_window_size) | |
# cos, sin = attention.rotary_emb(query, seq_len=num_generated) | |
# cos = cos[:,:,-1] # [1, 1, dim] | |
# sin = sin[:,:,-1] # [1, 1, dim] | |
# # cos = cos[-1].unsqueeze(0).unsqueeze(0) # [bs, 1, seq_len, dim] | |
# # sin = sin[-1].unsqueeze(0) # [bs, 1, seq_len, dim] | |
# query = (query * cos) + (self.rotate_half(query) * sin) | |
# k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0]) # (1, num_heads, attn_dim, embed_dim) | |
# k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :] | |
# k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :] | |
# k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2) | |
# datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim) | |
# datastore_query = torch.matmul(datastore_query, k_proj + k_proj_rotated) # (batch * beam, num_heads, 1, embed_dim) | |
# datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim) | |
# return datastore_query | |
# def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices): | |
# embed_dim = embeddings.shape[-1] | |
# k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, 1, heads, embed_dim, attn_dim) | |
# k_bias = 0 | |
# if k_proj_layer.bias is not None: | |
# k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) | |
# v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, heads, embed_dim, attn_dim) | |
# v_bias = 0 | |
# if v_proj_layer.bias is not None: | |
# v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) | |
# # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim) | |
# retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias # (beam, head, encoder_len, embed_dim) | |
# retrieved_values = torch.matmul(embeddings, v_weight) + v_bias # (beam, head, encoder_len, embed_dim) | |
# attention = self.model.base_model.layers[-1].self_attn | |
# cos, sin = attention.rotary_emb(retrieved_values, seq_len=self.hidden_states[0].shape[1]) | |
# cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] | |
# sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] | |
# if self.prompt_input_ids.shape[1] > self.actual_model_window_size: | |
# # scale the top key indices to the actual model window size, such that the model will not see | |
# # positional embeddings that did not appear at training time | |
# scaled_key_indices = ((top_search_key_indices / self.prompt_input_ids.shape[1]) * self.actual_model_window_size).int() | |
# else: | |
# scaled_key_indices = top_search_key_indices | |
# # top_search_key_indices = top_search_key_indices.to(cos.device) | |
# scaled_key_indices = scaled_key_indices.to(cos.device) | |
# cos = cos[scaled_key_indices] # [bs, 1, seq_len, dim] | |
# sin = sin[scaled_key_indices] # [bs, 1, seq_len, dim] | |
# retrieved_keys = (retrieved_keys * cos) + (self.rotate_half(retrieved_keys) * sin) | |
# return retrieved_keys, retrieved_values | |
class ActivationCapturer(nn.Module): | |
def __init__(self, layer, capture_input=False): | |
super().__init__() | |
self.layer = layer | |
self.capture_input = capture_input | |
self.captured = None | |
def unwrap_tuple(self, t):
    """Return the sole element of a 1-tuple; return any other value unchanged."""
    is_singleton = isinstance(t, tuple) and len(t) == 1
    return t[0] if is_singleton else t
def forward(self, module, layer_input, layer_output):
    """Record an activation into ``self.captured``.

    Uses the (module, input, output) signature of a PyTorch forward hook. Stores the
    unwrapped ``layer_input`` when ``capture_input`` is truthy, otherwise the unwrapped
    ``layer_output``.
    """
    source = layer_input if self.capture_input else layer_output
    self.captured = self.unwrap_tuple(source)