lhzstar
new commits
abca9bf
raw
history blame
63.8 kB
import logging
import numpy as np
import torch
from torch import nn
from enum import Enum, auto
from transformers import BartModel, BartForConditionalGeneration, \
T5Model, T5ForConditionalGeneration, \
LEDModel, LEDForConditionalGeneration, \
AutoModelForCausalLM, AutoModelForSeq2SeqLM, \
MODEL_WITH_LM_HEAD_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from typing import TypeVar, Generic
from .index_building import Datastore, DatastoreBatch
# Module-level logger. INFO is spelled by name (numeric value 20) so that
# logger.info(...) records are not filtered out at the logger itself.
logger = logging.getLogger('Unlimiformer')
logger.setLevel(logging.INFO)

# Type variable standing for the type of the wrapped model in the generic
# wrapper class defined below.
ModelType = TypeVar('ModelType')
class Unlimiformer(Generic[ModelType]):
def __init__(self, model: ModelType,
        layer_begin=-1, layer_end=None,
        unlimiformer_head_num=None,
        exclude_attention=False,
        model_encoder_max_len=None,
        chunk_overlap=0,
        verbose=False, save_heatmap=False,
        tokenizer=None, unlimiformer_training=False,
        use_datastore=False,
        flat_index=False,
        test_datastore=False, reconstruct_embeddings=False,
        gpu_datastore=False, gpu_index=False,
        index_devices=(0,), datastore_device=0,
        ):
    """Store the configuration and attach this wrapper to ``model``.

    The constructor only records its arguments and initialises runtime
    state to "empty" values; the final ``break_into(model)`` call performs
    the actual patching of the model.

    Note: in this version every device is pinned to CPU. The
    ``index_devices`` and ``datastore_device`` arguments are accepted for
    interface compatibility but their values are not used here.
    """
    super().__init__()

    # Wrapped model; the model gets a back-reference to this wrapper.
    self.model = model
    model.unlimiformer = self
    self.is_encoder_decoder = model.config.is_encoder_decoder

    # Which layers / heads are affected.
    self.layer_begin, self.layer_end = layer_begin, layer_end
    self.specific_head = unlimiformer_head_num
    self.exclude_attention = exclude_attention

    # Windowing of long inputs. The real window size is filled in by break_into().
    self.actual_model_window_size = None
    self.model_encoder_max_len = model_encoder_max_len
    self.chunk_overlap = chunk_overlap

    # Diagnostics and mode switches.
    self.verbose = verbose
    self.save_heatmap = save_heatmap
    self.tokenizer = tokenizer
    self.unlimiformer_training = unlimiformer_training
    self.test_datastore = test_datastore  # flag for debugging

    # Datastore / index options.
    self.use_datastore = use_datastore
    self.flat_index = flat_index
    self.reconstruct_embeddings = reconstruct_embeddings
    self.gpu_datastore = gpu_datastore
    self.gpu_index = gpu_index

    # Device placement: everything is fixed to CPU in this version.
    cpu = torch.device('cpu')
    self.index_devices = [cpu]
    self.datastore_device = cpu
    self.device = cpu

    # Runtime state, reset to "nothing captured yet".
    self.activation_capturer = None
    self.hook_handles = []
    self.is_input_encoding_pass = False
    self.is_first_test_decoding_step = False
    self.prev_tokens = None
    self.last_beam_idx = None
    self.heatmap = None
    self.cur_decoder_layer_index = None
    self.datastore = None

    # Patch the model (window sizes, eval/train overrides).
    self.break_into(model)
def break_into(self, model):
    """Derive window/head settings and override ``model.eval`` / ``model.train``.

    After this call, ``model.eval`` is ``self.pre_eval_hook`` and
    ``model.train`` is ``self.pre_train_hook``; the original callables
    (and the original ``model.forward``) are kept on ``self`` so they can
    be invoked or restored later.
    """
    window = self.window_size()
    self.actual_model_window_size = window
    if self.model_encoder_max_len is None:
        # No explicit chunk length requested: fall back to the model window.
        self.model_encoder_max_len = window

    # Margin on each side of a chunk, derived from the overlap fraction.
    self.window_margin = int(self.model_encoder_max_len * self.chunk_overlap / 2)

    self.num_heads = model.config.num_attention_heads
    # Ellipsis used as an index selects every head.
    self.head_nums = Ellipsis if self.specific_head is None else self.specific_head

    self.hooks_injected = False
    self.training_hooks_injected = False

    # Keep the originals so that they can be called / restored later.
    self.original_forward_func = model.forward
    self.original_model_eval_func = model.eval
    self.original_model_train_func = model.train

    # Activate Unlimiformer on model.eval(), deactivate it on model.train().
    model.eval = self.pre_eval_hook
    model.train = self.pre_train_hook
def pre_eval_hook(self):
    """Replacement for ``model.eval()`` that switches Unlimiformer on.

    Removes any training-time hooks, injects the inference hooks, and then
    delegates to the model's original ``eval``.

    Returns:
        Whatever the original ``eval`` returns, so that call chaining such
        as ``model = model.eval()`` keeps working once ``model.eval`` has
        been replaced by this hook.
    """
    self.remove_training_hooks(self.model)
    self.inject_hooks(self.model)
    return self.original_model_eval_func()
def pre_train_hook(self, mode=True):
    """Replacement for ``model.train(mode)`` that switches Unlimiformer off.

    Args:
        mode: ``True`` means ``model.train()`` was called; ``False`` means
            the call came from ``model.eval()``.

    Returns:
        Whatever the original ``train`` returns, so that call chaining such
        as ``model = model.train()`` keeps working once ``model.train``
        has been replaced by this hook.
    """
    torch.cuda.empty_cache()
    if mode is True:
        # Entering training: remove the inference hooks and, if requested,
        # install the training-time hooks instead.
        self.break_out(self.model)
        if self.unlimiformer_training:
            self.inject_training_hooks(self.model)
    return self.original_model_train_func(mode)
def inject_hooks(self, model):
    """Install the inference-time hooks on ``model`` (no-op if already installed).

    Steps, in order: attach activation capturers, register
    ``attention_forward_hook`` on the attention ops, wrap every selected
    cross-attention ``forward``, and finally replace ``model.generate``,
    ``model.forward`` and ``model._reorder_cache`` with this wrapper's
    hooks. The original callables are saved on ``self`` so that
    ``break_out`` can restore them.
    """
    if self.hooks_injected:
        return

    def attach_capturer(module):
        # One capturer per module, registered as a forward hook on it.
        capturer = ActivationCapturer(module, capture_input=False)
        self.register_hook(module, capturer)
        return capturer

    # Capture activations at every forward pass. An entry may be a single
    # module or a list of modules; the capturer list mirrors that nesting.
    capture_targets = self.activation_to_capture(self.layer_begin, self.layer_end)
    self.activation_capturer = []
    for target in capture_targets:
        if type(target) is list:
            self.activation_capturer.append([attach_capturer(part) for part in target])
        else:
            self.activation_capturer.append(attach_capturer(target))

    # Run our main function after the main attention function.
    for attention_op in self.attention_op_to_run(self.layer_begin, self.layer_end):
        self.register_hook(attention_op, self.attention_forward_hook)

    # Wrap each selected cross-attention forward, remembering the original.
    self.original_decoder_layer_cross_attn_forward_funcs = []
    selected_layers = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for layer_index, layer in enumerate(selected_layers):
        cross_attn = self.cross_attention(layer)
        self.original_decoder_layer_cross_attn_forward_funcs.append(cross_attn.forward)
        cross_attn.forward = self.create_cross_attn_pre_forward_hook(cross_attn.forward, layer_index)

    # model.generate() now goes through pre_generate_hook before reaching
    # the original generate function.
    self.original_generate_func = model.generate
    model.generate = self.pre_generate_hook

    model.forward = self.pre_forward_hook

    self.original_reorder_cache_func = model._reorder_cache
    model._reorder_cache = self.reorder_cache_hook

    self.hooks_injected = True
def inject_training_hooks(self, model):
    """Install the training-time hooks on ``model`` (no-op if already installed).

    For every selected decoder layer this wraps, in three separate passes,
    the self-attention forward, the cross-attention forward and the layer
    forward itself; the originals are stored on ``self`` in matching
    order. The remaining decoder layers are handled by
    ``inject_hooks_for_unaffected_layers``, and
    ``train_attention_forward_hook`` is registered on the attention ops.
    """
    if self.training_hooks_injected:
        return

    # The original model.forward was already saved by break_into().
    model.forward = self.pre_forward_hook

    selected_layers = self.attention_layer_to_run(self.layer_begin, self.layer_end)

    # Pass 1: self-attention.
    self.original_decoder_layer_self_attn_forward_funcs = []
    for layer in selected_layers:
        self_attn = self.self_attention(layer)
        self.original_decoder_layer_self_attn_forward_funcs.append(self_attn.forward)
        self_attn.forward = self.create_self_attn_pre_forward_hook(self_attn.forward)

    # Pass 2: cross-attention (each wrapper is told its position in the selection).
    self.original_decoder_layer_cross_attn_forward_funcs = []
    for layer_index, layer in enumerate(selected_layers):
        cross_attn = self.cross_attention(layer)
        self.original_decoder_layer_cross_attn_forward_funcs.append(cross_attn.forward)
        cross_attn.forward = self.create_cross_attn_pre_forward_hook(cross_attn.forward, layer_index)

    # Pass 3: the decoder layer's own forward.
    self.original_decoder_layer_forward_funcs = []
    for layer in selected_layers:
        self.original_decoder_layer_forward_funcs.append(layer.forward)
        layer.forward = self.create_decoder_layer_func(layer.forward, layer)

    self.inject_hooks_for_unaffected_layers(model, selected_layers)

    for attention_op in self.attention_op_to_run(self.layer_begin, self.layer_end):
        self.register_hook(attention_op, self.train_attention_forward_hook)

    self.training_hooks_injected = True
def inject_hooks_for_unaffected_layers(self, model, decoder_layers_to_run):
    """Wrap the forward of every decoder layer that is NOT in ``decoder_layers_to_run``.

    The candidate layers are those returned by
    ``self.attention_layer_to_run(0, None)``. Original forwards are
    appended to ``self.original_non_injected_decoder_layer_forward_funcs``
    in iteration order. ``model`` is accepted but not used here.
    """
    self.original_non_injected_decoder_layer_forward_funcs = []
    for candidate in self.attention_layer_to_run(0, None):
        if candidate in decoder_layers_to_run:
            # Already wrapped by the main training hooks.
            continue
        self.original_non_injected_decoder_layer_forward_funcs.append(candidate.forward)
        candidate.forward = self.create_noninjected_decoder_layer_func(candidate.forward, candidate)
def create_self_attn_pre_forward_hook(self, original_self_attn_forward_func):
    """Return a wrapper around a self-attention ``forward``.

    The wrapper forwards every argument unchanged, except that the
    ``past_key_value`` keyword is always passed as ``None`` (it is added
    when the caller did not supply it).
    """
    def self_attention_pre_forward_hook(*args, **kwargs):
        forwarded_kwargs = dict(kwargs, past_key_value=None)
        return original_self_attn_forward_func(*args, **forwarded_kwargs)
    return self_attention_pre_forward_hook
def create_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Return a checkpointed replacement for an injected decoder layer's ``forward``.

    On every call the replacement builds the layer's key/value from
    ``self.long_inputs_encoded`` via ``self.create_key_value`` and runs the
    original forward under ``torch.utils.checkpoint.checkpoint``. The
    caller-supplied ``past_key_value`` is ignored: ``None`` is passed in
    its place.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):

        def run_layer(hs, attn_mask, enc_hs, enc_attn_mask, head_mask,
                      cross_head_mask, cached_kv, want_attentions, cache_flag,
                      long_inputs, long_inputs_mask, pos_bias, enc_dec_pos_bias):
            # Key/value are built inside the checkpointed function; with
            # checkpointing, this function is re-run during backward.
            key, value = self.create_key_value(long_inputs, decoder_layer)
            layer_kwargs = self.create_decoder_layer_args(
                hidden_states=hs,
                attention_mask=attn_mask,
                encoder_hidden_states=enc_hs,
                encoder_attention_mask=enc_attn_mask,
                layer_head_mask=head_mask,
                cross_attn_layer_head_mask=cross_head_mask,
                past_key_value=cached_kv,
                output_attentions=want_attentions,
                position_bias=pos_bias,
                encoder_decoder_position_bias=enc_dec_pos_bias,
                use_cache=cache_flag,
                key=key, value=value)
            return decoder_layer_original_forward_func(**layer_kwargs)

        # Positional order must match run_layer's parameters; the
        # past_key_value slot (7th) is always None.
        checkpoint_inputs = (
            hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask, None,
            output_attentions, use_cache,
            self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias,
        )
        return torch.utils.checkpoint.checkpoint(run_layer, *checkpoint_inputs)

    return checkpointed_decoder_layer
def create_noninjected_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Return a checkpointed replacement for a non-injected decoder layer's ``forward``.

    Same shape as the wrapper for injected layers, except that no key/value
    is built: ``key=None`` and ``value=None`` are handed to
    ``self.create_decoder_layer_args``. The caller-supplied
    ``past_key_value`` is ignored (``None`` is passed in its place), and
    ``decoder_layer`` itself is not used by the wrapper.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):

        def run_layer(hs, attn_mask, enc_hs, enc_attn_mask, head_mask,
                      cross_head_mask, cached_kv, want_attentions, cache_flag,
                      long_inputs, long_inputs_mask, pos_bias, enc_dec_pos_bias):
            layer_kwargs = self.create_decoder_layer_args(
                hidden_states=hs,
                attention_mask=attn_mask,
                encoder_hidden_states=enc_hs,
                encoder_attention_mask=enc_attn_mask,
                layer_head_mask=head_mask,
                cross_attn_layer_head_mask=cross_head_mask,
                past_key_value=cached_kv,
                output_attentions=want_attentions,
                position_bias=pos_bias,
                encoder_decoder_position_bias=enc_dec_pos_bias,
                use_cache=cache_flag, key=None, value=None)
            return decoder_layer_original_forward_func(**layer_kwargs)

        # Positional order must match run_layer's parameters; the
        # past_key_value slot (7th) is always None.
        checkpoint_inputs = (
            hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask, None,
            output_attentions, use_cache,
            self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias,
        )
        return torch.utils.checkpoint.checkpoint(run_layer, *checkpoint_inputs)

    return checkpointed_decoder_layer
def register_hook(self, layer, func, pre=False):
    """Register ``func`` on ``layer`` and remember the returned handle.

    Args:
        layer: object exposing ``register_forward_hook`` /
            ``register_forward_pre_hook``.
        func: the hook callable.
        pre: when true, register as a forward *pre*-hook; otherwise as a
            regular forward hook.
    """
    if pre:
        registrar = layer.register_forward_pre_hook
    else:
        registrar = layer.register_forward_hook
    self.hook_handles.append(registrar(func))
def break_out(self, model):
    """Undo ``inject_hooks``: drop cached prompt state and restore the model.

    The cached prompt keys/values/mask and generated ids are always reset.
    If the inference hooks are currently installed, every registered hook
    handle is removed, ``model.generate`` / ``model.forward`` /
    ``model._reorder_cache`` are restored, and each wrapped cross-attention
    ``forward`` is put back.
    """
    self.prompt_keys = []
    self.prompt_values = []
    self.prompt_attention_mask = []
    self.generated_input_ids = []
    torch.cuda.empty_cache()
    if not self.hooks_injected:
        return

    for handle in self.hook_handles:
        handle.remove()
    # Forget the removed handles; otherwise the list keeps growing with
    # stale entries every time hooks are injected and removed again.
    self.hook_handles.clear()

    model.generate = self.original_generate_func
    model.forward = self.original_forward_func
    model._reorder_cache = self.original_reorder_cache_func

    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    self.hooks_injected = False
def remove_training_hooks(self, model):
    """Undo ``inject_training_hooks`` and restore every wrapped ``forward``.

    The cached encoded long inputs are always reset to ``None``. If the
    training hooks are currently installed, every registered hook handle
    is removed, ``model.forward`` is restored, and the self-attention,
    cross-attention and layer forwards of the selected decoder layers, as
    well as the forwards of the remaining decoder layers, are put back.
    """
    self.long_inputs_encoded, self.long_inputs_mask = None, None
    if not self.training_hooks_injected:
        return

    for handle in self.hook_handles:
        handle.remove()
    # Forget the removed handles; otherwise the list keeps growing with
    # stale entries every time hooks are injected and removed again.
    self.hook_handles.clear()

    model.forward = self.original_forward_func

    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_self_attn_forward_funcs):
        self.self_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func

    # Layers outside the selected range were wrapped separately; recompute
    # the same list that was used when they were wrapped.
    non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None)
                                   if l not in decoder_layers_to_run]
    for decoder_layer, original_func in zip(non_injected_decoder_layers, self.original_non_injected_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func

    self.training_hooks_injected = False
def reset_memory(self, input_ids, attention_mask):
    """Encode the whole prompt window by window and rebuild the per-layer memory.

    For each window returned by ``self.window_indices`` the model is run on
    the corresponding slice of ``input_ids`` / ``attention_mask`` (with
    ``is_input_encoding_pass`` set), and the activations collected by
    ``self.activation_capturer`` are stored:

    * with ``use_datastore``: masked hidden states are accumulated per layer
      and, after the loop, handed to ``train_index`` of each datastore;
    * without it (or additionally, when ``test_datastore`` is set): per-layer
      keys and values are accumulated and concatenated into
      ``self.prompt_keys`` / ``self.prompt_values``, together with
      ``self.prompt_attention_mask``.

    Generation-related state (prompt ids, input size counter, beam state,
    heatmap, generated ids) is reset as well.
    """
    if self.use_datastore:
        # Encoder-decoder models use a single datastore; otherwise one per selected layer.
        if self.is_encoder_decoder:
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                gpu_index=self.gpu_index, index_device=self.index_devices[0])]
            self.hidden_states = [[]]
        else:
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index,
                gpu_index=self.gpu_index, index_device=self.index_devices[i % len(self.index_devices)])
                for i in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
            self.hidden_states = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
        torch.cuda.empty_cache()
    self.prompt_input_ids = input_ids
    self.input_ids_size = input_ids.shape[-1]
    self.prompt_keys, self.prompt_values = None, None
    self.prev_tokens = [None for _ in range(len(self.original_decoder_layer_cross_attn_forward_funcs))]
    self.last_beam_idx = None
    self.cur_layer_key_value_placeholder = None
    # While this flag is set, the attention hooks return immediately and
    # pre_forward_hook passes its inputs straight through.
    self.is_input_encoding_pass = True
    if self.is_encoder_decoder:
        # A single dummy target token per batch item is passed as `labels` for the encoding pass.
        dummy_labels = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
    else:
        dummy_labels = None
    if self.save_heatmap:
        # Flush the heatmap of the previous generation (if any) before starting a new one.
        if self.heatmap is not None:
            print(f'Generated: {self.tokenizer.decode(self.generated_input_ids[0])}')
            self.plot_heatmap(self.heatmap[0].detach().cpu().numpy())
        self.heatmap = torch.tensor([], dtype=torch.float, device=input_ids.device)
    self.generated_input_ids = torch.tensor([], dtype=torch.long, device=input_ids.device)

    # One (initially empty) list of chunks per selected layer.
    self.prompt_keys = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
    self.prompt_values = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]]
    self.prompt_attention_mask = []
    window_indices = self.window_indices(input_ids.shape[-1])

    # context_* indices slice the full input; update_* indices are relative
    # to the chunk and select the part of it that is actually stored.
    for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
        logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
        chunk = input_ids[:, context_start_ind:context_end_ind].to(self.device)
        chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind].to(self.device)
        with torch.inference_mode():
            # The output is discarded; only the captured activations are used below.
            _ = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels) # , return_dict=True, output_hidden_states=True)
        if self.use_datastore:
            # TODO: verify with BART as well
            # hidden_states_to_index = [hidden_states.encoder_last_hidden_state] # list of length 1 of (batch, chunked_source_len, dim)
            hidden_states_to_index = [
                layer_capturer.captured for layer_capturer in self.activation_capturer
            ]
            # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end]
            to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
            to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
            # to_apply_mask = to_apply_mask.log().to(to_add[0].dtype)
            to_apply_mask = to_apply_mask.to(to_add[0].dtype)
            if not self.reconstruct_embeddings:
                to_add_embeddings = to_add
                if not self.gpu_datastore:
                    to_add_embeddings = [states.cpu() for states in to_add_embeddings]
                    to_apply_mask = to_apply_mask.cpu()
                for i, layer_states in enumerate(to_add_embeddings):
                    # Multiply by the mask so that masked-out positions become zero vectors.
                    layer_states = layer_states * to_apply_mask.unsqueeze(-1)
                    self.hidden_states[i].append(layer_states.to(self.datastore_device))
            # list of len layers, inside it there is a list of len batch, each item is (masked_time, dim)
            # for i, to_add_layer in enumerate(to_add):
            #     keys = [key[mask.bool()] for key, mask in zip(to_add_layer, to_apply_mask)]
            #     self.datastore[i].add_keys(keys)
        if (not self.use_datastore) or self.test_datastore:
            layers_kv = [
                self.process_key_value(layer_capturer) # (batch, head, time, dim)
                for layer_capturer in self.activation_capturer
            ] # list of pairs of (batch, head, time, dim)
            # list of (batch, head, chunked_time, dim)
            key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
            value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
            chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time)
            # key = torch.stack(key, dim=0) # (num_layers, batch, head, time, dim)
            # value = torch.stack(value, dim=0) # (num_layers, batch, head, time, dim)
            for i, (layer_key, layer_value) in enumerate(zip(key, value)):
                self.prompt_keys[i].append(layer_key) # (num_layers, batch, head, chunked_source_len, dim)
                self.prompt_values[i].append(layer_value) # (num_layers, batch, head, chunked_source_len, dim)
            self.prompt_attention_mask.append(chunk_attention_mask) # (batch, chunked_source_len)

    if self.use_datastore:
        # keys are all in datastore already!
        if not self.reconstruct_embeddings:
            # self.hidden_states = [torch.cat(layer_hidden_states, axis=1) for layer_hidden_states in self.hidden_states]
            concat_hidden_states = []
            for i in range(len(self.hidden_states)):
                concat_hidden_states.append(torch.cat(self.hidden_states[i], axis=1))
                # Drop the reference to this layer's chunk list as soon as it has been concatenated.
                self.hidden_states[i] = None
            self.hidden_states = concat_hidden_states
        for datastore, layer_hidden_states in zip(self.datastore, self.hidden_states):
            datastore.train_index(layer_hidden_states)
    if (not self.use_datastore) or self.test_datastore:
        # Join the per-chunk pieces along the source-length axis.
        for i, (layer_keys, layer_values) in enumerate(zip(self.prompt_keys, self.prompt_values)):
            self.prompt_keys[i] = torch.cat(layer_keys, dim=-2)
            self.prompt_values[i] = torch.cat(layer_values, dim=-2)
        # self.prompt_keys = torch.cat(self.prompt_keys, dim=-2) # (num_layers, batch, head, source_len, dim)
        # self.prompt_values = torch.cat(self.prompt_values, dim=-2) # (num_layers, batch, head, source_len, dim)
        self.prompt_attention_mask = torch.cat(self.prompt_attention_mask, dim=-1) # (batch, source_len)

    self.is_input_encoding_pass = False
    if self.verbose:
        # Print the first batch item, split at the model's window size.
        print(f'Input: '
            f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
            f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
def chunked_encode_input(self, input_ids, attention_mask):
    """Run the encoder over the input window by window and stitch the results.

    For every window from ``self.window_indices`` the encoder
    (``self.model.base_model.encoder``) is run on the context slice, and
    only the "update" part of its last hidden state (and of the attention
    mask) is kept. The kept parts are concatenated along dimension 1.

    ``is_input_encoding_pass`` is set while the encoder runs and is always
    cleared afterwards, even if the encoder raises; otherwise a failed
    call would leave the flag set and later forward passes would be
    treated as encoding passes.

    Returns:
        ``(long_inputs_encoded, long_inputs_mask)``: the concatenated
        hidden states and the matching concatenated attention mask.
    """
    encoded_chunks = []
    mask_chunks = []
    window_indices = self.window_indices(input_ids.shape[-1])

    self.is_input_encoding_pass = True
    try:
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
            # Keep only the "update" span of this chunk (indices are chunk-relative).
            encoded_chunks.append(output.last_hidden_state[:, update_start_ind:update_end_ind])
            mask_chunks.append(chunk_attention_mask[:, update_start_ind:update_end_ind])
    finally:
        self.is_input_encoding_pass = False

    long_inputs_encoded = torch.cat(encoded_chunks, dim=1)
    long_inputs_mask = torch.cat(mask_chunks, dim=1)

    if self.verbose:
        # Print the first batch item, split at the model's window size.
        print(f'Input: '
            f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
            f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
    return long_inputs_encoded, long_inputs_mask
def window_indices(self, total_seq_len):
    """Split a sequence of ``total_seq_len`` tokens into overlapping windows.

    Copied from SLED (Ivgy et al., 2022)
    https://github.com/Mivg/SLED/blob/main/sled/modeling_sled.py#L467

    Returns:
        A list of ``(context_start, context_end, update_start, update_end)``
        tuples. The context indices address the full sequence; the update
        indices are relative to ``context_start`` and delimit the part of
        the window whose states should be kept.

    Raises:
        ValueError: if the input is longer than one window and the stride
            (``model_encoder_max_len - 2 * window_margin``) is not positive.
            Without this check the loop below would never advance.
    """
    max_len = self.model_encoder_max_len
    if total_seq_len <= max_len:
        # Everything fits in a single window.
        return [(0, total_seq_len, 0, total_seq_len)]

    stride = max_len - 2 * self.window_margin
    if stride <= 0:
        raise ValueError(
            f'window stride must be positive, got {stride} '
            f'(model_encoder_max_len={max_len}, window_margin={self.window_margin})')

    context_start = update_start_ind = 0
    context_end = max_len
    # First window always updates from the beginning; for encoder-decoder
    # models the trailing margin is left to the next window.
    update_end_ind = context_end - self.window_margin if self.is_encoder_decoder else context_end
    results = [(context_start, context_end, update_start_ind, update_end_ind)]

    while context_end < total_seq_len:
        context_end = min(total_seq_len, context_end + stride)
        is_last = context_end >= total_seq_len
        # The last window is aligned to the end of the sequence.
        context_start = total_seq_len - max_len if is_last else context_start + stride
        update_start_ind = max(update_start_ind + stride, update_end_ind)
        # Last window always updates until the end.
        update_end_ind = total_seq_len if is_last else min(total_seq_len, update_end_ind + stride)
        results.append((context_start, context_end,
                        update_start_ind - context_start, update_end_ind - context_start))
    return results
def pre_generate_hook(self, input_ids, **kwargs):
    """Replacement for ``model.generate`` that first encodes the full input.

    The whole ``input_ids`` (with its attention mask) is passed to
    ``reset_memory``; the original ``generate`` is then called on a single
    window of the input: the first ``actual_model_window_size`` tokens for
    encoder-decoder models, the last ones otherwise. ``use_cache`` is
    forced to ``True``.

    The attention mask is cut with the *same* slice as the ids, so that
    each mask entry still describes the token at the same position. A
    missing or ``None`` ``attention_mask`` is treated as all ones.
    """
    attention_mask = kwargs.get('attention_mask')
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    self.reset_memory(input_ids, attention_mask)

    window = self.actual_model_window_size
    if self.is_encoder_decoder:
        window_slice = slice(None, window)   # head of the input
    else:
        window_slice = slice(-window, None)  # tail of the input

    new_kwargs = {k: v for k, v in kwargs.items() if k != 'attention_mask'}
    new_kwargs['attention_mask'] = attention_mask[:, window_slice].to(self.device)
    new_kwargs['use_cache'] = True

    input_ids_prefix = input_ids[:, window_slice].to(self.device)
    return self.original_generate_func(input_ids_prefix, **new_kwargs)
def pre_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
    """Replacement for ``model.forward`` that prepares state before delegating.

    * During an input-encoding pass nothing is changed.
    * In training mode the full input is encoded chunk by chunk (stored in
      ``long_inputs_encoded`` / ``long_inputs_mask``) and the inputs are
      truncated to the model's window size.
    * Otherwise (inference) the first-decoding-step flag, the input size
      counter and the generated ids are updated.

    The original forward is then called, and the first-decoding-step flag
    is cleared once it has returned.
    """
    self.set_gradient_checkpointing(False)

    encoding_pass = self.is_input_encoding_pass
    if not encoding_pass and self.model.training:
        self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(
            input_ids=input_ids, attention_mask=attention_mask)
        window = self.actual_model_window_size
        input_ids = input_ids[:, :window]
        if attention_mask is not None:
            attention_mask = attention_mask[:, :window]
    elif not encoding_pass:
        # No cache yet means this is the first decoding step.
        if kwargs.get('past_key_values') is None:
            self.is_first_test_decoding_step = True
        # One more token per forward call that carries input ids.
        if input_ids is not None:
            self.input_ids_size += 1
        new_decoder_ids = kwargs.get('decoder_input_ids')
        if new_decoder_ids is not None:
            self.generated_input_ids = torch.cat([self.generated_input_ids, kwargs['decoder_input_ids']], axis=-1)

    result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
    self.is_first_test_decoding_step = False
    return result
def create_cross_attn_pre_forward_hook(self, original_cross_attn_forward_func, cur_layer_num):
    """Return a wrapper around a cross-attention ``forward``.

    On each call the wrapper records ``cur_layer_num`` as the current
    decoder layer and, if a ``past_key_value`` is supplied, replaces it by
    a list that is also published as
    ``self.cur_layer_key_value_placeholder`` so that hooks can assign into
    it. In training mode, hidden states and attention mask are reshaped so
    that every target position becomes its own batch item, and the
    attention output is reshaped back afterwards.
    """
    def attention_pre_forward_hook(hidden_states, attention_mask=None, *args, **kwargs):
        self.cur_decoder_layer_index = cur_layer_num

        cached = kwargs.get('past_key_value')
        if cached is not None:
            # It's a tuple; a list allows item assignment from the attention hook.
            cached = list(cached)  # (batch, head, time, attn_dim)
            kwargs['past_key_value'] = cached
            self.cur_layer_key_value_placeholder = cached

        batch_size, tgt_len, dim = hidden_states.shape
        if not self.model.training:
            # Note: adding the generated tokens to the prompt here is
            # intentionally not done; generated tokens are always kept in
            # the attention window.
            return original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)

        # from: (batch, tgt_len, dim) to: (batch * tgt_len, 1, dim)
        flat_hidden_states = hidden_states.reshape(-1, 1, hidden_states.shape[-1])
        # from: (batch, 1, tgt_len, dim) to: (batch * tgt_len, 1, 1, dim)
        flat_attention_mask = attention_mask.reshape(-1, 1, 1, attention_mask.shape[-1])
        attn_output, attn_weights_reshaped, past_key_value = original_cross_attn_forward_func(
            hidden_states=flat_hidden_states, attention_mask=flat_attention_mask, *args, **kwargs)
        return (attn_output.reshape(batch_size, tgt_len, dim), attn_weights_reshaped, past_key_value)

    return attention_pre_forward_hook
def attention_forward_hook(self, module, input, output):
    """Forward hook that swaps retrieved prompt keys/values into the current layer's cache.

    Registered by ``inject_hooks`` on the modules returned by
    ``attention_op_to_run``. It does nothing during the input-encoding pass
    or the first decoding step. Otherwise it takes the query for the last
    position from ``output``, selects the ``topk`` best-matching prompt
    positions (via the datastore, via a direct dot product with the cached
    prompt keys, or both when ``test_datastore`` is set), and overwrites
    the first ``topk`` positions of
    ``self.cur_layer_key_value_placeholder`` with the retrieved keys and
    values.

    Args:
        module: the hooked module (unused).
        input: the module's inputs (unused).
        output: the module's output, from which the query is extracted by
            ``self.process_query``.
    """
    # output: (batch, time, 3 * heads * attention_dim)
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return
    with torch.no_grad():
        prompt_size = self.prompt_input_ids.shape[1]
        generated_size = self.input_ids_size - prompt_size
        window_size = self.cur_layer_key_value_placeholder[0].shape[-2]
        # topk = min(self.actual_model_window_size, attn_weights.shape[-1])
        # Never retrieve more positions than the prompt has or the cache can hold.
        topk = min(prompt_size, window_size)
        if not self.is_encoder_decoder:
            topk = min(topk, window_size - generated_size + 1)
        if self.gpu_index:
            topk = min(topk, 2048)

        # Only the query of the last position is used.
        query = self.process_query(output)[:,-1] # (batch * beam, head, dim)
        query = query[:, self.head_nums] # (batch * beam, head, dim)

        if self.use_datastore:
            # query: (batch, beam, head, dim)
            # need to multiply by key vector
            # query.view(query.shape[0], query.shape[1] * query.shape[2])
            # k_proj in attention?
            datastore_index = 0 if self.is_encoder_decoder else self.cur_decoder_layer_index
            attention_layer_list = self.get_kv_projections(self.layer_begin, self.layer_end)
            k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
            v_proj_layer = [layers[1] for layers in attention_layer_list][self.cur_decoder_layer_index]

            # modify query by k_projs
            k_proj = k_proj_layer.weight
            datastore_query = self.preprocess_query(query, k_proj) # (batch * beam, num_heads, embed_dim)
            batch_size = self.datastore[datastore_index].batch_size
            datastore_query = datastore_query.view((batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)
            # then search
            if self.reconstruct_embeddings:
                # embeddings: (batch, beam * head, actual_model_window_size, dim)
                _, top_search_key_indices, embeddings = self.datastore[datastore_index].search_and_reconstruct(datastore_query, k=topk)
            else:
                _, top_search_key_indices = self.datastore[datastore_index].search(datastore_query, k=topk)
                # self.embeddings: (batch, src_len, dim)
                # indices: (batch, beam * head, actual_model_window_size)
                # embeddings: (batch, beam * head, actual_model_window_size, dim)
                # Gather the stored hidden states at the retrieved indices.
                embeddings = torch.take_along_dim(input=self.hidden_states[datastore_index].unsqueeze(1),
                    indices=top_search_key_indices.unsqueeze(-1).to(self.hidden_states[datastore_index].device), dim=-2)
                embeddings = embeddings.to(self.device)
            # (batch, beam, head, actual_model_window_size)
            # top_search_key_scores = top_search_key_scores.reshape(batch_size, -1, *top_search_key_scores.shape[1:])
            top_search_key_indices = top_search_key_indices.reshape(batch_size, -1, *top_search_key_indices.shape[1:])
            # embeddings: (batch, beam, head, actual_model_window_size, dim)
            embeddings = embeddings.reshape(batch_size, -1, self.num_heads, *embeddings.shape[2:])

        # raw_values are actually token indices; need to look them up
        if (not self.use_datastore) or self.test_datastore:
            this_layer_prompt_keys = self.prompt_keys[self.cur_decoder_layer_index]
            this_layer_prompt_values = self.prompt_values[self.cur_decoder_layer_index]
            # query: (batch * beam, head, dim)
            batch_size = self.prompt_input_ids.shape[0]
            beam_size = query.shape[0] // batch_size
            # query: (batch, beam, head, dim)
            query = query.reshape(batch_size, beam_size, *query.shape[1:])
            # this_layer_prompt_keys: (batch, head, source_len, dim)
            # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim)
            # query.unsqueeze(-1): (batch, beam, head, dim, 1)
            # attn_weights: (batch, beam, head, source_len)
            attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums], query.unsqueeze(-1)).squeeze(-1)
            # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2)
            # Add a large negative value wherever the prompt attention mask is 0.
            prompt_attention_mask_to_add = (1 - self.prompt_attention_mask) * -1e9 # (batch, source_len)
            prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1)
            attn_weights += prompt_attention_mask_to_add # (batch, beam, head, source_len)
            if self.exclude_attention and attn_weights.shape[-1] > self.actual_model_window_size:
                # Penalise the first window so that it is not selected by top-k.
                attn_weights[..., :self.actual_model_window_size] -= 1e9

            # target_keys, target_values, topk = self.get_target_slices(output)
            top_key_scores, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True) # (batch, beam, head, trunc_source)
            if self.save_heatmap:
                # heatrow: (beam, heads, source_len)
                heatrow = torch.zeros([top_key_indices.shape[1], top_key_indices.shape[2], this_layer_prompt_keys.shape[-2]], dtype=torch.float)
                # Mark the selected source positions (of the first batch item) with 1.
                heatrow = heatrow.scatter(index=top_key_indices[0], src=torch.ones_like(top_key_scores[0]), dim=-1)
                # heatrow = torch.nn.functional.softmax(heatrow, dim=-1)
                # self.heatmap: (beam, heads, targets, source_len)
                self.heatmap = torch.cat([self.heatmap, heatrow.unsqueeze(-2)], axis=-2)

        if self.test_datastore:
            # Debug check: datastore retrieval must (almost always) agree with the exact top-k.
            assert top_key_indices.shape == top_search_key_indices.shape
            assert torch.mean((top_key_indices == top_search_key_indices).float()) > 0.99

        if self.verbose:
            if self.is_encoder_decoder:
                for i, beam in enumerate(self.generated_input_ids):
                    print(f'({i}) Generated: {self.tokenizer.decode(beam)}')
            # else:
            #     print(f'Generated: {self.tokenizer.decode(self.input_ids)}')
            print()

    if self.use_datastore:
        # k_proj_layer.weight, v_proj_layer.weight: (embed_dim, embed_dim)
        # embeddings: (batch, beam, head, encoder_len, embed_dim)
        retrieved_keys, retrieved_values = self.post_process_retrieved(embeddings, k_proj_layer, v_proj_layer, top_search_key_indices)
    else:
        # this_layer_prompt_keys: (batch, head, source_len, dim)
        # top_key_indices: (batch, beam, head, trunc_source)
        retrieved_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
            dim=-2) # (batch, head, trunc_source, attn_dim)
        retrieved_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
            dim=-2) # (batch, head, trunc_source, attn_dim)

    if self.test_datastore:
        # Debug check: keys/values rebuilt from the datastore must match the cached ones.
        correct_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
            dim=-2) # (batch, head, trunc_source, attn_dim)
        correct_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
            dim=-2) # (batch, head, trunc_source, attn_dim)
        assert correct_keys.shape == retrieved_keys.shape
        assert correct_values.shape == retrieved_values.shape
        assert torch.mean(torch.isclose(correct_keys, retrieved_keys, rtol=1e-3, atol=1e-3).float()) > 0.99
        assert torch.mean(torch.isclose(correct_values, retrieved_values, rtol=1e-3, atol=1e-3).float()) > 0.99

    # retrieved_keys, retrieved_values: (batch * beam, head, encoder_len, attn_dim)
    retrieved_keys = retrieved_keys.flatten(0, 1)[:,:,:topk]
    retrieved_values = retrieved_values.flatten(0, 1)[:,:,:topk]
    # Replace the first topk cached positions with the retrieved ones; the rest of the cache is kept.
    self.cur_layer_key_value_placeholder[0] = torch.cat([retrieved_keys, self.cur_layer_key_value_placeholder[0][:,:,topk:]], dim=-2)
    self.cur_layer_key_value_placeholder[1] = torch.cat([retrieved_values, self.cur_layer_key_value_placeholder[1][:,:,topk:]], dim=-2)
    return
def train_attention_forward_hook(self, module, input, output):
    """Training-time hook: keep only the best-scoring encoder keys/values.

    Reads the current layer's keys and values from
    ``self.cur_layer_key_value_placeholder[0]`` / ``[1]``, scores every key
    against the query obtained from ``self.process_query(output)``, and
    overwrites both placeholder entries with the top-k entries per
    (target position, head).

    The hook is a no-op while ``self.is_input_encoding_pass`` or
    ``self.is_first_test_decoding_step`` is set. ``module`` and ``input`` are
    not used; they are part of the hook signature. Always returns ``None``.
    """
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return

    prompt_keys = self.cur_layer_key_value_placeholder[0]
    prompt_values = self.cur_layer_key_value_placeholder[1]

    # Scoring and index selection need no gradients.
    with torch.no_grad():
        queries = self.process_query(output)
        num_batches = prompt_keys.shape[0]
        num_targets = queries.shape[0] // num_batches
        # Fold the leading axis into (batch, target); trailing axes are kept.
        queries = queries.reshape(num_batches, num_targets, *queries.shape[2:])
        num_heads = queries.shape[-2]
        source_len = prompt_keys.shape[-2]

        # Dot product of every key with every query, laid out as
        # (batch, target, head, 1, source_len).
        scores = torch.matmul(prompt_keys.unsqueeze(1), queries.unsqueeze(-1))
        scores = scores.reshape(num_batches, num_targets, num_heads, 1, source_len)

        # Positions where long_inputs_mask is 0 get a large negative score.
        padding_penalty = (1 - self.long_inputs_mask) * -1e9
        scores += padding_penalty[:, None, None, None]

        keep = min(self.actual_model_window_size, source_len)
        _, top_key_indices = torch.topk(scores, k=keep, dim=-1, sorted=True)

    # The gather runs outside no_grad so the selected keys/values keep their
    # autograd history.
    # NOTE(review): the no_grad extent was reconstructed; confirm against upstream.
    gather_index = top_key_indices.unsqueeze(-1)
    for slot, full_tensor in enumerate((prompt_keys, prompt_values)):
        selected = torch.take_along_dim(
            full_tensor.unsqueeze(2).unsqueeze(1), indices=gather_index, dim=-2)
        # Merge (batch, target) into one axis and drop the singleton axis.
        self.cur_layer_key_value_placeholder[slot] = selected.flatten(0, 1).squeeze(2)
    return
def preprocess_query(self, query, k_proj_weight):
    """Project per-head queries through the key projection for datastore search.

    Args:
        query: tensor whose last two axes are (num_heads, attn_dim).
        k_proj_weight: key-projection weight, assumed to be laid out like
            ``nn.Linear.weight``, i.e. (num_heads * attn_dim, embed_dim).

    Returns:
        Tensor with the same leading axes as ``query`` and last two axes
        (num_heads, embed_dim): each head's query multiplied by that head's
        slice of the key projection.
    """
    attn_dim = query.shape[-1]
    # The trailing axis of the per-head view is the projection's input size,
    # i.e. the weight's last axis. Taking shape[0] (the output size) is only
    # correct for square projections and makes view() fail otherwise.
    k_proj = k_proj_weight.view(1, self.num_heads, attn_dim, k_proj_weight.shape[-1])
    # (..., heads, 1, attn_dim) @ (1, heads, attn_dim, embed_dim) -> (..., heads, 1, embed_dim)
    datastore_query = torch.matmul(query.unsqueeze(-2), k_proj)
    return datastore_query.squeeze(-2)
def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
    """Turn retrieved encoder embeddings into per-head keys and values.

    Args:
        embeddings: retrieved hidden states; the last axis is the embedding
            size and the third-from-last axis is the head axis (the call site
            in this file documents (batch, beam, head, encoder_len, embed_dim)).
        k_proj_layer: key projection exposing ``weight`` and ``bias``
            (``bias`` may be ``None``).
        v_proj_layer: value projection, same contract as ``k_proj_layer``.
        top_search_key_indices: unused here; kept so the signature matches
            the call site.

    Returns:
        ``(retrieved_keys, retrieved_values)``, each shaped like ``embeddings``
        with the last axis replaced by the per-head attention size.
    """
    num_heads = self.num_heads

    def project_per_head(proj_layer):
        # Sizes come from the weight itself rather than from the embeddings,
        # so projections whose output size differs from the embedding size
        # are handled as well as square ones.
        out_features, in_features = proj_layer.weight.shape
        head_dim = out_features // num_heads
        # (1, 1, heads, in_features, head_dim)
        weight = proj_layer.weight.view(1, 1, num_heads, head_dim, in_features).transpose(-2, -1)
        projected = torch.matmul(embeddings, weight)
        if proj_layer.bias is not None:
            # Broadcast each head's bias slice over all leading axes.
            projected = projected + proj_layer.bias.view(1, 1, num_heads, 1, head_dim)
        return projected

    retrieved_keys = project_per_head(k_proj_layer)
    retrieved_values = project_per_head(v_proj_layer)
    return retrieved_keys, retrieved_values
def set_gradient_checkpointing(self, value):
    """Set ``gradient_checkpointing`` on the wrapped model's decoder to ``value``."""
    decoder = self.model.base_model.decoder
    setattr(decoder, 'gradient_checkpointing', value)
def reorder_cache_hook(self, past, beam_idx):
    """Reorder this object's per-beam state by ``beam_idx``, then delegate.

    Stores ``beam_idx`` in ``self.last_beam_idx``, re-indexes
    ``self.generated_input_ids``, every non-``None`` entry of
    ``self.prev_tokens`` and, when heatmap saving is on and the heatmap is
    non-empty, ``self.heatmap``.

    Returns:
        Whatever ``self.original_reorder_cache_func(past, beam_idx)`` returns.
    """
    self.last_beam_idx = beam_idx
    self.generated_input_ids = self.generated_input_ids[beam_idx]

    for layer_idx, tokens in enumerate(self.prev_tokens):
        if tokens is None:
            continue
        # Merge the first two axes so beam_idx can index them as one, then
        # restore the original layout.
        original_shape = tokens.shape
        merged = tokens.flatten(0, 1)
        self.prev_tokens[layer_idx] = merged[beam_idx].reshape(original_shape)

    if self.save_heatmap:
        if self.heatmap.numel() > 0:
            self.heatmap = self.heatmap[beam_idx]

    return self.original_reorder_cache_func(past, beam_idx)
@classmethod
def convert_model(cls, model, *args, **kwargs):
    """Attach the matching Unlimiformer wrapper to ``model`` and return ``model``.

    The wrapper class is chosen from the model's type. Subclasses of a
    supported model class are accepted as well: the model's MRO is searched
    in order, so an exact type match always wins.

    Args:
        model: a BART, T5 or LED model instance (or a subclass of one).
        *args, **kwargs: forwarded unchanged to the wrapper's constructor.

    Returns:
        The same ``model`` object. The wrapper instance is not returned; the
        base constructor stores it on the model as ``model.unlimiformer``.

    Raises:
        KeyError: if no class in the model's MRO is supported.
    """
    type_to_class = {
        BartModel: UnlimiformerBART,
        BartForConditionalGeneration: UnlimiformerBART,
        T5Model: UnlimiformerT5,
        T5ForConditionalGeneration: UnlimiformerT5,
        LEDModel: UnlimiformerLED,
        LEDForConditionalGeneration: UnlimiformerLED,
        # LlamaModel: UnlimiformerLLaMa,
        # LlamaForCausalLM: UnlimiformerLLaMa,
    }
    for candidate in type(model).__mro__:
        wrapper_cls = type_to_class.get(candidate)
        if wrapper_cls is not None:
            break
    else:
        # Same exception type as the plain dict lookup raised before, but
        # with a message that says what went wrong.
        supported = ', '.join(supported_type.__name__ for supported_type in type_to_class)
        raise KeyError(
            f'Unsupported model type: {type(model).__name__} (supported: {supported})')
    wrapper_cls(model, *args, **kwargs)
    return model
def plot_heatmap(self, data, xticklabels='auto', yticklabels='auto'):
    """Draw one heatmap per head, save each as ``knns_head{i}.pdf`` and show it.

    Args:
        data: array-like indexed as (heads, targets, source_len).
        xticklabels: forwarded to ``seaborn.heatmap``. The default ``'auto'``
            keeps the tick step of 512 that used to be hard-coded; any other
            value is now honoured instead of being ignored.
        yticklabels: forwarded to ``seaborn.heatmap``.
    """
    import seaborn as sb
    import matplotlib.pyplot as plt

    # isinstance guard: comparing an array of labels with a string would be
    # evaluated element-wise.
    if isinstance(xticklabels, str) and xticklabels == 'auto':
        xticklabels = 512

    for i in range(data.shape[0]):
        fig, axes = plt.subplots(1, 1, figsize=(40, 100))
        axes.set_title(f'Head #{i}, length: {data.shape[2]}, target length: {data.shape[1]}')
        ax = sb.heatmap(data[i], annot=False, fmt='.2f',
                        xticklabels=xticklabels, yticklabels=yticklabels, ax=axes)
        ax.xaxis.tick_top()
        plt.savefig(f'knns_head{i}.pdf')
        plt.show()
        # Release the figure so one large figure per head does not accumulate.
        plt.close(fig)
class UnlimiformerBART(Unlimiformer[BartModel]):
    """Unlimiformer hooks for models whose decoder layers expose ``encoder_attn``."""

    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    @staticmethod
    def _heads_first(tensor, attention):
        # (batch, time, heads * head_dim) -> (batch, heads, time, head_dim)
        per_head = tensor.view(tensor.shape[0], -1, attention.num_heads, attention.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Project encoder states with the layer's cross-attention k/v projections.

        Returns ``(key, value)``, each reshaped to (batch, heads, time, head_dim).
        """
        cross_attn = decoder_layer.encoder_attn
        key = self._heads_first(cross_attn.k_proj(encoder_hidden_states), cross_attn)
        value = self._heads_first(cross_attn.v_proj(encoder_hidden_states), cross_attn)
        return key, value

    def process_key_value(self, capturers):
        """Reshape the captured key/value activations to (batch, heads, time, head_dim).

        ``capturers`` is a ``(key_capturer, value_capturer)`` pair; each must
        expose ``.captured``. Head sizes are read from the last decoder layer.
        """
        key_capturer, value_capturer = capturers
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        key = self._heads_first(key_capturer.captured, cross_attn)
        value = self._heads_first(value_capturer.captured, cross_attn)
        return key, value

    def process_query(self, output):
        """Split the last axis of ``output`` into (heads, head_dim); no transpose."""
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        batch, time = output.shape[0], output.shape[1]
        per_head = output.view(batch, time, cross_attn.num_heads, cross_attn.head_dim)
        return per_head.contiguous()

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k_proj, v_proj]`` for each decoder layer in the slice."""
        selected_layers = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [
            [decoder_layer.encoder_attn.k_proj, decoder_layer.encoder_attn.v_proj]
            for decoder_layer in selected_layers
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        """Return the modules to capture: last encoder layer (datastore) or k/v projections."""
        if not self.use_datastore:
            return self.get_kv_projections(layer_begin, layer_end)
        return [self.model.base_model.encoder.layers[-1]]

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention ``q_proj`` of each decoder layer in the slice."""
        selected_layers = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [decoder_layer.encoder_attn.q_proj for decoder_layer in selected_layers]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the decoder layers in ``[layer_begin:layer_end]``."""
        decoder = self.model.base_model.decoder
        return decoder.layers[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the layer's self-attention module."""
        return decoder_layer.self_attn

    def cross_attention(self, decoder_layer):
        """Return the layer's cross-attention module."""
        return decoder_layer.encoder_attn

    def window_size(self):
        """Return ``max_position_embeddings`` from the model config."""
        config = self.model.config
        return config.max_position_embeddings

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build the keyword arguments for a decoder layer call.

        ``past_key_value``, ``position_bias`` and ``encoder_decoder_position_bias``
        are accepted but not forwarded. The forwarded ``past_key_value`` is
        ``(None, None, key, value)``, or ``None`` when both ``key`` and
        ``value`` are ``None``.
        """
        if key is None and value is None:
            layer_cache = None
        else:
            layer_cache = (None, None, key, value)
        return {
            'hidden_states': hidden_states,
            'attention_mask': attention_mask,
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': encoder_attention_mask,
            'layer_head_mask': layer_head_mask,
            'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
            'past_key_value': layer_cache,
            'output_attentions': output_attentions,
            'use_cache': use_cache,
        }
class UnlimiformerT5(Unlimiformer[T5Model]):
    """Unlimiformer hooks for T5-style models (decoder ``block`` / ``EncDecAttention``)."""

    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Project encoder states with the block's cross-attention k/v projections.

        Returns ``(key, value)``, each reshaped to (batch, heads, time, key_value_proj_dim).
        """
        attention = decoder_layer.layer[1].EncDecAttention
        key = attention.k(encoder_hidden_states)
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = attention.v(encoder_hidden_states)
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_key_value(self, capturers):
        """Reshape captured key/value activations to (batch, heads, time, key_value_proj_dim).

        ``capturers`` is a ``(key_capturer, value_capturer)`` pair; each must
        expose ``.captured``. Head sizes are read from the last decoder block.
        """
        key_capturer, value_capturer = capturers
        key, value = key_capturer.captured, value_capturer.captured
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_query(self, output):
        """Split the last axis of ``output`` into (heads, key_value_proj_dim); no transpose."""
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        query = output.view(output.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).contiguous()
        return query

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k, v]`` cross-attention projections for each decoder block in the slice."""
        return [
            [layer.layer[1].EncDecAttention.k, layer.layer[1].EncDecAttention.v]
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        """Return the modules to capture: encoder output (datastore) or k/v projections."""
        if self.use_datastore:
            # The T5 encoder stack keeps its layers in `block` and has no
            # `layers` attribute, so `encoder.layers[-1]` raised
            # AttributeError here. The final layer norm is captured instead:
            # its output is a plain tensor and is what the cross-attention
            # k/v projections are applied to (dropout aside).
            # NOTE(review): confirm the datastore path expects the normalized
            # encoder states rather than the raw output of the last block.
            return [self.model.base_model.encoder.final_layer_norm]
        else:
            return self.get_kv_projections(layer_begin, layer_end)

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention ``q`` projection of each decoder block in the slice."""
        return [
            layer.layer[1].EncDecAttention.q
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the decoder blocks in ``[layer_begin:layer_end]``."""
        return self.model.base_model.decoder.block[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the block's first sub-layer (self-attention)."""
        return decoder_layer.layer[0]

    def cross_attention(self, decoder_layer):
        """Return the block's second sub-layer (cross-attention)."""
        return decoder_layer.layer[1]

    def window_size(self):
        """Return ``config.n_positions``, or 1024 when the config lacks that attribute."""
        return getattr(self.model.config, 'n_positions', 1024)

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build the keyword arguments for a decoder block call.

        The incoming ``past_key_value`` is not forwarded. The forwarded
        ``past_key_value`` is ``(None, None, key, value)``, or ``None`` when
        both ``key`` and ``value`` are ``None``.
        """
        args = {'hidden_states': hidden_states,
                'attention_mask': attention_mask,
                'position_bias': position_bias,
                'encoder_hidden_states': encoder_hidden_states,
                'encoder_attention_mask': encoder_attention_mask,
                'encoder_decoder_position_bias': encoder_decoder_position_bias,
                'layer_head_mask': layer_head_mask,
                'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
                'past_key_value': (None, None, key, value),
                'use_cache': use_cache,
                'output_attentions': output_attentions}
        if key is None and value is None:
            args['past_key_value'] = None
        return args
class UnlimiformerLED(UnlimiformerBART):
    """LED variant of the BART hooks; only the window size lookup is overridden."""

    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def window_size(self):
        """Return ``max_encoder_position_embeddings`` from the model config."""
        config = self.model.config
        return config.max_encoder_position_embeddings
# class UnlimiformerLLaMa(Unlimiformer[LlamaModel]):
# def __init__(self, model: LlamaModel, *args, **kwargs):
# super().__init__(model, *args, **kwargs)
# def get_kv_projections(self, layer_begin, layer_end):
# return [
# [layer.self_attn.k_proj, layer.self_attn.v_proj]
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# def activation_to_capture(self, layer_begin, layer_end):
# if self.use_datastore:
# return [
# layer.input_layernorm
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# else:
# return self.get_kv_projections(layer_begin, layer_end)
# def attention_op_to_run(self, layer_begin, layer_end):
# return [
# layer.self_attn.q_proj
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# def attention_layer_to_run(self, layer_begin, layer_end):
# return self.model.base_model.layers[layer_begin:layer_end]
# def self_attention(self, decoder_layer):
# return decoder_layer.self_attn
# def cross_attention(self, decoder_layer):
# return decoder_layer.self_attn
# def window_size(self):
# return self.model.config.max_position_embeddings
# def set_gradient_checkpointing(self, value):
# self.model.base_model.gradient_checkpointing = value
# def process_key_value(self, capturers):
# key_capturer, value_capturer = capturers
# # (batch, time, heads * attn_dim)
# key, value = key_capturer.captured, value_capturer.captured
# attention = self.model.base_model.layers[-1].self_attn
# # (batch, heads, time, attn_dim)
# key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
# value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
# return key, value
# def process_query(self, output):
# # output: (batch, time, heads * attn_dim)
# attention = self.model.base_model.layers[-1].self_attn
# # query: (batch, time, heads, attn_dim)
# query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous()
# return query
# def rotate_half(self, x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
# def preprocess_query(self, query, k_proj_weight):
# # query: (batch * time, head, dim)
# attention = self.model.base_model.layers[-1].self_attn
# num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1], self.actual_model_window_size)
# cos, sin = attention.rotary_emb(query, seq_len=num_generated)
# cos = cos[:,:,-1] # [1, 1, dim]
# sin = sin[:,:,-1] # [1, 1, dim]
# # cos = cos[-1].unsqueeze(0).unsqueeze(0) # [bs, 1, seq_len, dim]
# # sin = sin[-1].unsqueeze(0) # [bs, 1, seq_len, dim]
# query = (query * cos) + (self.rotate_half(query) * sin)
# k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0]) # (1, num_heads, attn_dim, embed_dim)
# k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :]
# k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :]
# k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2)
# datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
# datastore_query = torch.matmul(datastore_query, k_proj + k_proj_rotated) # (batch * beam, num_heads, 1, embed_dim)
# datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim)
# return datastore_query
# def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
# embed_dim = embeddings.shape[-1]
# k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, 1, heads, embed_dim, attn_dim)
# k_bias = 0
# if k_proj_layer.bias is not None:
# k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
# v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, heads, embed_dim, attn_dim)
# v_bias = 0
# if v_proj_layer.bias is not None:
# v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
# # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim)
# retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias # (beam, head, encoder_len, embed_dim)
# retrieved_values = torch.matmul(embeddings, v_weight) + v_bias # (beam, head, encoder_len, embed_dim)
# attention = self.model.base_model.layers[-1].self_attn
# cos, sin = attention.rotary_emb(retrieved_values, seq_len=self.hidden_states[0].shape[1])
# cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
# sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
# if self.prompt_input_ids.shape[1] > self.actual_model_window_size:
# # scale the top key indices to the actual model window size, such that the model will not see
# # positional embeddings that did not appear at training time
# scaled_key_indices = ((top_search_key_indices / self.prompt_input_ids.shape[1]) * self.actual_model_window_size).int()
# else:
# scaled_key_indices = top_search_key_indices
# # top_search_key_indices = top_search_key_indices.to(cos.device)
# scaled_key_indices = scaled_key_indices.to(cos.device)
# cos = cos[scaled_key_indices] # [bs, 1, seq_len, dim]
# sin = sin[scaled_key_indices] # [bs, 1, seq_len, dim]
# retrieved_keys = (retrieved_keys * cos) + (self.rotate_half(retrieved_keys) * sin)
# return retrieved_keys, retrieved_values
class ActivationCapturer(nn.Module):
    """Remembers the most recent input or output passed to ``forward``.

    ``forward`` takes ``(module, layer_input, layer_output)``, so an instance
    can presumably be registered as a forward hook on ``layer``. The value is
    kept in ``self.captured`` (``None`` until the first call).
    """

    def __init__(self, layer, capture_input=False):
        super().__init__()
        self.layer = layer
        self.capture_input = capture_input
        self.captured = None

    def unwrap_tuple(self, t):
        """Return ``t[0]`` when ``t`` is a one-element tuple, otherwise ``t`` itself."""
        is_single_item_tuple = isinstance(t, tuple) and len(t) == 1
        return t[0] if is_single_item_tuple else t

    def forward(self, module, layer_input, layer_output):
        """Store the input (if ``capture_input``) or the output; ``module`` is unused."""
        chosen = layer_input if self.capture_input else layer_output
        self.captured = self.unwrap_tuple(chosen)