Spaces:
Runtime error
Runtime error
File size: 63,832 Bytes
abca9bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 |
import logging
import numpy as np
import torch
from torch import nn
from enum import Enum, auto
from transformers import BartModel, BartForConditionalGeneration, \
T5Model, T5ForConditionalGeneration, \
LEDModel, LEDForConditionalGeneration, \
AutoModelForCausalLM, AutoModelForSeq2SeqLM, \
MODEL_WITH_LM_HEAD_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from typing import TypeVar, Generic
from .index_building import Datastore, DatastoreBatch
# Module-level logger shared by the Unlimiformer code; INFO level (numeric 20) so the
# per-chunk "Encoding ..." progress messages are emitted.
logger = logging.getLogger('Unlimiformer')
logger.setLevel(logging.INFO)

# Type variable standing for the wrapped model class (parametrizes Unlimiformer).
ModelType = TypeVar('ModelType')
class Unlimiformer(Generic[ModelType]):
def __init__(self, model: ModelType,
             layer_begin=-1, layer_end=None,
             unlimiformer_head_num=None,
             exclude_attention=False,
             model_encoder_max_len=None,
             chunk_overlap=0,
             verbose=False, save_heatmap=False,
             tokenizer=None, unlimiformer_training=False,
             use_datastore=False,
             flat_index=False,
             test_datastore=False, reconstruct_embeddings=False,
             gpu_datastore=False, gpu_index=False,
             index_devices=(0,), datastore_device=0,
             ):
    """Attach Unlimiformer to ``model`` and patch its ``eval``/``train`` methods.

    Args:
        model: the model to wrap; it receives a back-reference ``model.unlimiformer``.
        layer_begin, layer_end: slice bounds selecting the affected layers
            (used as ``range(num_hidden_layers)[layer_begin:layer_end]``).
        unlimiformer_head_num: restrict retrieval to this head index; ``None`` means all heads.
        exclude_attention: stored as ``self.exclude_attention``.
        model_encoder_max_len: chunk length for long inputs; defaults to the model's
            own window size (set in ``break_into``).
        chunk_overlap: fraction of a chunk used as overlap margin between windows.
        verbose: print the (decoded) input after encoding; needs ``tokenizer``.
        save_heatmap: keep attention heatmaps across generations.
        tokenizer: used only for decoding text in verbose / heatmap output.
        unlimiformer_training: install training hooks when ``model.train()`` is called.
        use_datastore: index prompt hidden states in ``DatastoreBatch`` objects instead of
            keeping raw keys/values.
        flat_index, gpu_index: forwarded to ``DatastoreBatch``.
        test_datastore: debugging flag; also keep raw keys/values when using a datastore.
        reconstruct_embeddings, gpu_datastore: datastore storage options.
        index_devices, datastore_device: accepted but currently ignored (see below).
    """
    super().__init__()

    # Two-way link between the wrapper and the wrapped model.
    self.model = model
    model.unlimiformer = self

    # Which layers / heads are affected.
    self.layer_begin, self.layer_end = layer_begin, layer_end
    self.specific_head = unlimiformer_head_num
    self.exclude_attention = exclude_attention

    # Long-input chunking configuration (window size is resolved in break_into).
    self.actual_model_window_size = None
    self.model_encoder_max_len = model_encoder_max_len
    self.chunk_overlap = chunk_overlap

    # Diagnostics.
    self.verbose = verbose
    self.save_heatmap = save_heatmap
    self.tokenizer = tokenizer
    self.test_datastore = test_datastore  # flag for debugging

    # Training / datastore options.
    self.unlimiformer_training = unlimiformer_training
    self.use_datastore = use_datastore
    self.flat_index = flat_index
    self.reconstruct_embeddings = reconstruct_embeddings
    self.gpu_datastore = gpu_datastore
    self.gpu_index = gpu_index

    # NOTE: `index_devices` and `datastore_device` are not honoured here; index,
    # datastore and model inputs are all pinned to the CPU.
    cpu_device = torch.device('cpu')
    self.index_devices = [cpu_device]
    self.datastore_device = cpu_device
    self.device = cpu_device

    # Runtime state filled in by the hooks during encoding / generation.
    self.activation_capturer = None
    self.is_encoder_decoder = model.config.is_encoder_decoder
    self.hook_handles = []
    self.is_input_encoding_pass = False
    self.is_first_test_decoding_step = False
    self.prev_tokens = None
    self.last_beam_idx = None
    self.heatmap = None
    self.cur_decoder_layer_index = None
    self.datastore = None

    # Compute window geometry and wrap model.eval()/model.train().
    self.break_into(model)
def break_into(self, model):
    """Resolve window geometry and route ``model.eval``/``model.train`` through Unlimiformer.

    Hooks are not installed here; ``model.eval()`` installs the inference hooks and
    ``model.train()`` removes them (see ``pre_eval_hook`` / ``pre_train_hook``).
    """
    window = self.window_size()
    self.actual_model_window_size = window
    if self.model_encoder_max_len is None:
        self.model_encoder_max_len = window
    # Overlap margin (in tokens) that window_indices reserves on each side of a chunk.
    self.window_margin = int(self.model_encoder_max_len * self.chunk_overlap / 2)

    self.num_heads = model.config.num_attention_heads
    # Ellipsis indexes every head; an int selects a single one.
    self.head_nums = Ellipsis if self.specific_head is None else self.specific_head

    self.hooks_injected = self.training_hooks_injected = False

    self.original_forward_func = model.forward
    # Activate Unlimiformer on model.eval(), deactivate on model.train().
    self.original_model_eval_func, model.eval = model.eval, self.pre_eval_hook
    self.original_model_train_func, model.train = model.train, self.pre_train_hook
def pre_eval_hook(self):
    """Replacement for ``model.eval()``.

    Removes any training hooks, installs the inference hooks, then calls the
    original ``eval``. The original's return value is passed through, like
    ``nn.Module.eval``, which returns the module. Previously the hook returned
    ``None``, which broke ``model = model.eval()`` and chained calls.
    """
    self.remove_training_hooks(self.model)
    self.inject_hooks(self.model)
    return self.original_model_eval_func()
def pre_train_hook(self, mode=True):
    """Replacement for ``model.train(mode)``.

    ``mode=True`` (an explicit ``model.train()``) removes the inference hooks and, when
    ``unlimiformer_training`` is set, installs the training hooks. ``mode=False`` (the
    path ``model.eval()`` takes) only forwards to the original ``train``. The original's
    return value is passed through, like ``nn.Module.train``, which returns the module.
    """
    torch.cuda.empty_cache()
    # `is True` on purpose: only a literal True triggers the hook swap.
    if mode is True:
        self.break_out(self.model)
        if self.unlimiformer_training:
            self.inject_training_hooks(self.model)
    return self.original_model_train_func(mode)
def inject_hooks(self, model):
    """Install the inference-time hooks on ``model``; a no-op when already installed.

    1. An ``ActivationCapturer`` forward hook on each module from
       ``activation_to_capture`` (one per module when a layer yields a list).
    2. ``attention_forward_hook`` after each op from ``attention_op_to_run``.
    3. A wrapper around each affected layer's cross-attention ``forward``.
    4. ``generate``/``forward``/``_reorder_cache`` on the model replaced by the
       Unlimiformer versions (so ``reset_memory`` runs before generation).
    """
    if self.hooks_injected:
        return

    capture_targets = self.activation_to_capture(self.layer_begin, self.layer_end)
    self.activation_capturer = []
    for target in capture_targets:
        if type(target) is not list:
            single = ActivationCapturer(target, capture_input=False)
            self.register_hook(target, single)
            self.activation_capturer.append(single)
            continue
        # Separate key / value modules: keep their capturers grouped per layer.
        group = []
        for kv_module in target:
            kv_capturer = ActivationCapturer(kv_module, capture_input=False)
            group.append(kv_capturer)
            self.register_hook(kv_module, kv_capturer)
        self.activation_capturer.append(group)

    # Run our retrieval logic right after each attention op.
    for attn_op in self.attention_op_to_run(self.layer_begin, self.layer_end):
        self.register_hook(attn_op, self.attention_forward_hook)

    # Wrap cross-attention forwards, remembering the originals for break_out().
    self.original_decoder_layer_cross_attn_forward_funcs = []
    affected_layers = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for layer_idx, layer in enumerate(affected_layers):
        cross_attn = self.cross_attention(layer)
        self.original_decoder_layer_cross_attn_forward_funcs.append(cross_attn.forward)
        cross_attn.forward = self.create_cross_attn_pre_forward_hook(cross_attn.forward, layer_idx)

    # model.generate() now first calls reset_memory() via pre_generate_hook.
    self.original_generate_func, model.generate = model.generate, self.pre_generate_hook
    model.forward = self.pre_forward_hook
    self.original_reorder_cache_func, model._reorder_cache = model._reorder_cache, self.reorder_cache_hook
    self.hooks_injected = True
def inject_training_hooks(self, model):
    """Install the training-time hooks on ``model``; a no-op when already installed.

    For each affected layer, wraps its self-attention ``forward`` so it runs
    without a KV cache. Also wraps its cross-attention ``forward`` and replaces
    the whole layer ``forward`` with a gradient-checkpointed version. Layers
    outside the range get their own wrapper. Finally,
    ``train_attention_forward_hook`` is registered after each attention op.
    Originals are saved for ``remove_training_hooks``.
    """
    if self.training_hooks_injected:
        return
    model.forward = self.pre_forward_hook
    affected = self.attention_layer_to_run(self.layer_begin, self.layer_end)

    saved_self_attn = []
    self.original_decoder_layer_self_attn_forward_funcs = saved_self_attn
    for layer in affected:
        self_attn = self.self_attention(layer)
        saved_self_attn.append(self_attn.forward)
        self_attn.forward = self.create_self_attn_pre_forward_hook(self_attn.forward)

    saved_cross_attn = []
    self.original_decoder_layer_cross_attn_forward_funcs = saved_cross_attn
    for layer_idx, layer in enumerate(affected):
        cross_attn = self.cross_attention(layer)
        saved_cross_attn.append(cross_attn.forward)
        cross_attn.forward = self.create_cross_attn_pre_forward_hook(cross_attn.forward, layer_idx)

    saved_layer_forwards = []
    self.original_decoder_layer_forward_funcs = saved_layer_forwards
    for layer in affected:
        saved_layer_forwards.append(layer.forward)
        layer.forward = self.create_decoder_layer_func(layer.forward, layer)

    self.inject_hooks_for_unaffected_layers(model, affected)

    for attn_op in self.attention_op_to_run(self.layer_begin, self.layer_end):
        self.register_hook(attn_op, self.train_attention_forward_hook)
    self.training_hooks_injected = True
def inject_hooks_for_unaffected_layers(self, model, decoder_layers_to_run):
    """Wrap every decoder layer NOT in ``decoder_layers_to_run`` with the non-injected checkpointed forward.

    The original ``forward`` methods are kept, in layer order, in
    ``self.original_non_injected_decoder_layer_forward_funcs``.
    """
    every_layer = self.attention_layer_to_run(0, None)
    untouched = [layer for layer in every_layer if layer not in decoder_layers_to_run]
    saved = []
    self.original_non_injected_decoder_layer_forward_funcs = saved
    for layer in untouched:
        plain_forward = layer.forward
        saved.append(plain_forward)
        layer.forward = self.create_noninjected_decoder_layer_func(plain_forward, layer)
def create_self_attn_pre_forward_hook(self, original_self_attn_forward_func):
    """Return a wrapper for a self-attention ``forward`` that always passes ``past_key_value=None``.

    Every other positional and keyword argument is forwarded unchanged.
    """
    def self_attention_pre_forward_hook(*args, **kwargs):
        return original_self_attn_forward_func(*args, **{**kwargs, 'past_key_value': None})
    return self_attention_pre_forward_hook
def create_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Build a gradient-checkpointed ``forward`` for an affected decoder layer (training only).

    Inside the checkpointed function, ``self.create_key_value`` derives key/value from
    ``self.long_inputs_encoded`` (read at call time) and the layer's original forward
    is invoked with them via ``self.create_decoder_layer_args``. The caller's
    ``past_key_value`` is not forwarded: ``None`` is passed to the checkpoint instead.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):

        def run_layer(h, attn_mask, enc_h, enc_mask, head_mask, cross_head_mask, pkv,
                      out_attn, cache, long_inputs, long_inputs_mask, pos_bias, enc_dec_pos_bias):
            # Computed inside the checkpoint, so it is recomputed in backward rather than stored.
            key, value = self.create_key_value(long_inputs, decoder_layer)
            layer_kwargs = self.create_decoder_layer_args(
                hidden_states=h,
                attention_mask=attn_mask,
                encoder_hidden_states=enc_h,
                encoder_attention_mask=enc_mask,
                layer_head_mask=head_mask,
                cross_attn_layer_head_mask=cross_head_mask,
                past_key_value=pkv,
                output_attentions=out_attn,
                position_bias=pos_bias,
                encoder_decoder_position_bias=enc_dec_pos_bias,
                use_cache=cache,
                key=key, value=value)
            return decoder_layer_original_forward_func(**layer_kwargs)

        checkpoint_inputs = (
            hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask,
            None,  # past_key_value is not forwarded
            output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias,
        )
        return torch.utils.checkpoint.checkpoint(run_layer, *checkpoint_inputs)

    return checkpointed_decoder_layer
def create_noninjected_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer):
    """Build a gradient-checkpointed ``forward`` for a decoder layer outside the Unlimiformer range.

    Same plumbing as the injected variant, but the layer is called with ``key=None`` and
    ``value=None`` (no retrieved keys/values). ``self.long_inputs_encoded`` and
    ``self.long_inputs_mask`` are still passed through the checkpoint, and the caller's
    ``past_key_value`` is replaced with ``None``.
    """
    def checkpointed_decoder_layer(
            hidden_states: torch.Tensor,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            output_attentions=False,
            position_bias=None,
            encoder_decoder_position_bias=None,
            use_cache=True):

        def run_layer(h, attn_mask, enc_h, enc_mask, head_mask, cross_head_mask, pkv,
                      out_attn, cache, long_inputs, long_inputs_mask, pos_bias, enc_dec_pos_bias):
            layer_kwargs = self.create_decoder_layer_args(
                hidden_states=h,
                attention_mask=attn_mask,
                encoder_hidden_states=enc_h,
                encoder_attention_mask=enc_mask,
                layer_head_mask=head_mask,
                cross_attn_layer_head_mask=cross_head_mask,
                past_key_value=pkv,
                output_attentions=out_attn,
                position_bias=pos_bias,
                encoder_decoder_position_bias=enc_dec_pos_bias,
                use_cache=cache, key=None, value=None)
            return decoder_layer_original_forward_func(**layer_kwargs)

        checkpoint_inputs = (
            hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask, layer_head_mask,
            cross_attn_layer_head_mask,
            None,  # past_key_value is not forwarded
            output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask,
            position_bias, encoder_decoder_position_bias,
        )
        return torch.utils.checkpoint.checkpoint(run_layer, *checkpoint_inputs)

    return checkpointed_decoder_layer
def register_hook(self, layer, func, pre=False):
    """Register ``func`` on ``layer`` as a forward hook, or a forward-pre hook when ``pre`` is true.

    The returned handle is appended to ``self.hook_handles`` so it can be removed later.
    """
    registrar = layer.register_forward_pre_hook if pre else layer.register_forward_hook
    self.hook_handles.append(registrar(func))
def break_out(self, model):
    """Undo ``inject_hooks``: reset cached prompt state and restore the patched model methods.

    The prompt caches are always cleared. When the inference hooks are installed, all
    registered hook handles are removed and then dropped from ``self.hook_handles``.
    Previously they were kept, so the list grew on every eval/train cycle. The
    model's ``generate``/``forward``/``_reorder_cache`` and each affected layer's
    cross-attention ``forward`` are restored.
    """
    self.prompt_keys = []
    self.prompt_values = []
    self.prompt_attention_mask = []
    self.generated_input_ids = []
    torch.cuda.empty_cache()
    if not self.hooks_injected:
        return
    for handle in self.hook_handles:
        handle.remove()
    # Removed handles are inert; forget them so the list does not keep growing.
    self.hook_handles = []
    model.generate = self.original_generate_func
    model.forward = self.original_forward_func
    model._reorder_cache = self.original_reorder_cache_func
    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    self.hooks_injected = False
def remove_training_hooks(self, model):
    """Undo ``inject_training_hooks`` and clear the cached long-input encodings.

    ``long_inputs_encoded``/``long_inputs_mask`` are always reset. When training hooks
    are installed, all hook handles are removed and then dropped from
    ``self.hook_handles``. Previously they were kept and accumulated across cycles.
    The model ``forward`` and the self-attention, cross-attention and layer
    ``forward`` of every decoder layer (affected or not) are restored.
    """
    self.long_inputs_encoded, self.long_inputs_mask = None, None
    if not self.training_hooks_injected:
        return
    for handle in self.hook_handles:
        handle.remove()
    # Removed handles are inert; forget them so the list does not keep growing.
    self.hook_handles = []
    model.forward = self.original_forward_func

    decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_self_attn_forward_funcs):
        self.self_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs):
        self.cross_attention(decoder_layer).forward = original_func
    for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func

    # Layers outside [layer_begin, layer_end) were wrapped by inject_hooks_for_unaffected_layers.
    non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None)
                                   if l not in decoder_layers_to_run]
    for decoder_layer, original_func in zip(non_injected_decoder_layers, self.original_non_injected_decoder_layer_forward_funcs):
        decoder_layer.forward = original_func
    self.training_hooks_injected = False
def reset_memory(self, input_ids, attention_mask):
    """Encode a (possibly very long) prompt window by window and cache what retrieval needs.

    Called by ``pre_generate_hook`` before every generation. The prompt is split by
    ``window_indices``. Each window is run through ``self.model`` under
    ``torch.inference_mode()`` only for its side effect: the activation capturers
    record per-layer states. Only each window's ``[update_start, update_end)`` slice
    is kept.

    * ``use_datastore``: captured states, multiplied by the attention mask, are
      concatenated per layer and used to ``train_index`` one ``DatastoreBatch`` per
      datastore.
    * ``not use_datastore`` (or ``test_datastore``): keys/values from
      ``process_key_value`` are concatenated along dim -2 into ``self.prompt_keys`` /
      ``self.prompt_values``, and the kept mask slices along the last dim into
      ``self.prompt_attention_mask``.

    ``self.generated_input_ids`` is always reset to an empty tensor, because
    ``pre_forward_hook`` ``torch.cat``s decoder ids onto it. ``is_input_encoding_pass``
    is reset even if encoding raises, so later forwards are not silently treated as
    encoding passes.

    Args:
        input_ids: prompt token ids, sliced as ``input_ids[:, start:end]`` (batch first).
        attention_mask: mask with the same layout as ``input_ids``.
    """
    layer_range = range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]
    if self.use_datastore:
        if self.is_encoder_decoder:
            # Encoder-decoder: a single datastore (index 0).
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0],
                                             flat_index=self.flat_index, gpu_index=self.gpu_index,
                                             index_device=self.index_devices[0])]
            self.hidden_states = [[]]
        else:
            # Decoder-only: one datastore per affected layer, round-robin over index_devices.
            self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0],
                                             flat_index=self.flat_index, gpu_index=self.gpu_index,
                                             index_device=self.index_devices[i % len(self.index_devices)])
                              for i in layer_range]
            self.hidden_states = [[] for _ in layer_range]
        torch.cuda.empty_cache()

    self.prompt_input_ids = input_ids
    self.input_ids_size = input_ids.shape[-1]
    self.prev_tokens = [None for _ in range(len(self.original_decoder_layer_cross_attn_forward_funcs))]
    self.last_beam_idx = None
    self.cur_layer_key_value_placeholder = None
    if self.is_encoder_decoder:
        # A one-token `labels` for the seq2seq forward; its output is discarded below.
        dummy_labels = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
    else:
        dummy_labels = None
    if self.save_heatmap:
        if self.heatmap is not None:
            print(f'Generated: {self.tokenizer.decode(self.generated_input_ids[0])}')
            self.plot_heatmap(self.heatmap[0].detach().cpu().numpy())
        self.heatmap = torch.tensor([], dtype=torch.float, device=input_ids.device)
    # Always a fresh tensor: pre_forward_hook concatenates decoder_input_ids onto it.
    self.generated_input_ids = torch.tensor([], dtype=torch.long, device=input_ids.device)

    self.prompt_keys = [[] for _ in layer_range]
    self.prompt_values = [[] for _ in layer_range]
    self.prompt_attention_mask = []
    keep_raw_kv = (not self.use_datastore) or self.test_datastore

    self.is_input_encoding_pass = True
    try:
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in self.window_indices(input_ids.shape[-1]):
            logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}')
            chunk = input_ids[:, context_start_ind:context_end_ind].to(self.device)
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind].to(self.device)
            with torch.inference_mode():
                # Run only for the side effect of filling the activation capturers.
                _ = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels)

            if self.use_datastore:
                hidden_states_to_index = [
                    layer_capturer.captured for layer_capturer in self.activation_capturer
                ]
                to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index]
                to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind].to(to_add[0].dtype)
                if not self.reconstruct_embeddings:
                    to_add_embeddings = to_add
                    if not self.gpu_datastore:
                        to_add_embeddings = [states.cpu() for states in to_add_embeddings]
                        to_apply_mask = to_apply_mask.cpu()
                    for i, layer_states in enumerate(to_add_embeddings):
                        # Zero out positions whose attention mask is 0.
                        layer_states = layer_states * to_apply_mask.unsqueeze(-1)
                        self.hidden_states[i].append(layer_states.to(self.datastore_device))

            if keep_raw_kv:
                layers_kv = [
                    self.process_key_value(layer_capturer)  # (batch, head, time, dim)
                    for layer_capturer in self.activation_capturer
                ]
                # list of (batch, head, chunked_time, dim)
                key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
                value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv]
                chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]  # (batch, chunked_time)
                for i, (layer_key, layer_value) in enumerate(zip(key, value)):
                    self.prompt_keys[i].append(layer_key)
                    self.prompt_values[i].append(layer_value)
                self.prompt_attention_mask.append(chunk_attention_mask)

        if self.use_datastore:
            # keys are all in datastore already!
            if not self.reconstruct_embeddings:
                concat_hidden_states = []
                for i in range(len(self.hidden_states)):
                    concat_hidden_states.append(torch.cat(self.hidden_states[i], axis=1))
                    self.hidden_states[i] = None  # drop the per-chunk pieces as we go
                self.hidden_states = concat_hidden_states
            for datastore, layer_hidden_states in zip(self.datastore, self.hidden_states):
                datastore.train_index(layer_hidden_states)

        if keep_raw_kv:
            for i, (layer_keys, layer_values) in enumerate(zip(self.prompt_keys, self.prompt_values)):
                self.prompt_keys[i] = torch.cat(layer_keys, dim=-2)
                self.prompt_values[i] = torch.cat(layer_values, dim=-2)
            self.prompt_attention_mask = torch.cat(self.prompt_attention_mask, dim=-1)  # (batch, source_len)
    finally:
        self.is_input_encoding_pass = False

    if self.verbose:
        print(f'Input: '
              f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
              f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
def chunked_encode_input(self, input_ids, attention_mask):
    """Encode an arbitrarily long input with ``self.model.base_model.encoder``, one window at a time.

    Each window from ``window_indices`` is encoded independently. Only its
    ``[update_start, update_end)`` slice is kept, so overlapping context is not
    duplicated. ``is_input_encoding_pass`` is True while encoding and is reset even
    if the encoder raises. Previously an exception left it stuck at True.

    Returns:
        ``(long_inputs_encoded, long_inputs_mask)``: the kept encoder states
        concatenated on dim 1 ``(batch, source_len, dim)`` and the matching mask
        ``(batch, source_len)``.
    """
    encoded_pieces = []
    mask_pieces = []
    window_indices = self.window_indices(input_ids.shape[-1])
    self.is_input_encoding_pass = True
    try:
        for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
            chunk = input_ids[:, context_start_ind:context_end_ind]
            chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
            output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask,
                                                   return_dict=True, output_hidden_states=True)
            # Keep only this window's update slice: (batch, chunked_time, dim)
            encoded_pieces.append(output.last_hidden_state[:, update_start_ind:update_end_ind])
            mask_pieces.append(chunk_attention_mask[:, update_start_ind:update_end_ind])
        long_inputs_encoded = torch.cat(encoded_pieces, dim=1)  # (batch, source_len, dim)
        long_inputs_mask = torch.cat(mask_pieces, dim=1)  # (batch, source_len)
    finally:
        self.is_input_encoding_pass = False
    if self.verbose:
        print(f'Input: '
              f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
              f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
        print()
    return long_inputs_encoded, long_inputs_mask
def window_indices(self, total_seq_len):
    """Split ``total_seq_len`` positions into model-sized windows (adapted from SLED).

    Returns a list of ``(context_start, context_end, update_start, update_end)`` tuples.

    * ``context_*`` are absolute positions of the window to run through the model.
    * ``update_*`` are relative to ``context_start`` and select the part of the window
      whose outputs are kept.

    Windows advance by ``model_encoder_max_len - 2 * window_margin``. The first
    window is kept from its start. For encoder-decoder models its kept part stops
    ``window_margin`` before the window end; for decoder-only models it runs to the
    window end. The last window is aligned to the end of the sequence and kept up
    to the end.

    Raises:
        ValueError: if more than one window is needed but the stride is not positive
            (``chunk_overlap`` too large). Previously this looped forever.
    """
    # Copied from SLED (Ivgy et al., 2022)
    # https://github.com/Mivg/SLED/blob/main/sled/modeling_sled.py#L467
    if total_seq_len <= self.model_encoder_max_len:
        # Everything fits in one window: encode it all and keep it all.
        return [(0, total_seq_len, 0, total_seq_len)]

    stride = self.model_encoder_max_len - 2 * self.window_margin
    if stride <= 0:
        # A non-positive stride never advances context_end, so the loop below would never end.
        raise ValueError(
            f'Cannot split {total_seq_len} tokens into windows: stride is {stride} '
            f'(model_encoder_max_len={self.model_encoder_max_len}, '
            f'window_margin={self.window_margin}); use a smaller chunk_overlap.')

    results = []
    context_start = update_start_ind = 0
    context_end = self.model_encoder_max_len
    if self.is_encoder_decoder:
        update_end_ind = context_end - self.window_margin
    else:
        update_end_ind = context_end
    # first window always should update from the beginning
    results.append((context_start, context_end, update_start_ind, update_end_ind))
    while context_end < total_seq_len:
        context_end = min(total_seq_len, context_end + stride)
        is_last = context_end >= total_seq_len
        context_start = total_seq_len - self.model_encoder_max_len if is_last else context_start + stride
        update_start_ind = max(update_start_ind + stride, update_end_ind)
        # last window always should update until the end
        update_end_ind = total_seq_len if is_last else min(total_seq_len, update_end_ind + stride)
        # update_* are reported relative to the window start
        results.append((context_start, context_end,
                        update_start_ind - context_start, update_end_ind - context_start))
    return results
def pre_generate_hook(self, input_ids, **kwargs):
    """Replacement for ``model.generate``: encode the full prompt, then generate from one window.

    ``reset_memory`` first encodes all of ``input_ids``. The original ``generate`` is
    then called with at most ``actual_model_window_size`` tokens: the first ones for
    encoder-decoder models, the last ones for decoder-only models. It also gets the
    same slice of the attention mask and ``use_cache=True``.

    A missing or ``None`` ``attention_mask`` is treated as all ones. The mask is sliced
    with the same window as the ids. Previously decoder-only models paired the
    *last* tokens with the *first* mask columns, which is wrong under padding.
    """
    attention_mask = kwargs.get('attention_mask')
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    self.reset_memory(input_ids, attention_mask)

    window = self.actual_model_window_size
    if self.is_encoder_decoder:
        window_slice = slice(None, window)
    else:
        window_slice = slice(-window, None)

    new_kwargs = {k: v for k, v in kwargs.items() if k != 'attention_mask'}
    new_kwargs['attention_mask'] = attention_mask[:, window_slice].to(self.device)
    new_kwargs['use_cache'] = True
    input_ids_prefix = input_ids[:, window_slice].to(self.device)
    return self.original_generate_func(input_ids_prefix, **new_kwargs)
def pre_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
    """Replacement for ``model.forward`` while Unlimiformer hooks are installed.

    * During an input-encoding pass the call is forwarded unchanged.
    * In training mode the full input is encoded window by window
      (``chunked_encode_input``), stored on ``self.long_inputs_encoded`` /
      ``self.long_inputs_mask``, and the model only sees the first
      ``actual_model_window_size`` tokens. A missing ``attention_mask`` is treated as
      all ones. Previously it crashed inside ``chunked_encode_input``.
    * At inference it flags the first decoding step (no ``past_key_values``), bumps
      ``input_ids_size`` when ``input_ids`` is given, and appends any
      ``decoder_input_ids`` to ``generated_input_ids``.

    ``is_first_test_decoding_step`` is cleared after the call, even if it raises.
    """
    self.set_gradient_checkpointing(False)
    if not self.is_input_encoding_pass:
        if self.model.training:
            if attention_mask is None and input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
            input_ids = input_ids[:, :self.actual_model_window_size]
            attention_mask = attention_mask[:, :self.actual_model_window_size] if attention_mask is not None else None
        else:
            if kwargs.get('past_key_values') is None:
                self.is_first_test_decoding_step = True
            if input_ids is not None:
                # One new token per decoding step.
                self.input_ids_size += 1
            if kwargs.get('decoder_input_ids') is not None:
                self.generated_input_ids = torch.cat([self.generated_input_ids, kwargs['decoder_input_ids']], axis=-1)
    try:
        return self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
    finally:
        self.is_first_test_decoding_step = False
def create_cross_attn_pre_forward_hook(self, original_cross_attn_forward_func, cur_layer_num):
    """Wrap a cross-attention ``forward`` so Unlimiformer knows which layer is running.

    The wrapper stores ``cur_layer_num`` in ``self.cur_decoder_layer_index``. When a
    ``past_key_value`` is supplied, it is converted to a list, passed on, and exposed as
    ``self.cur_layer_key_value_placeholder`` so ``attention_forward_hook`` can modify
    its items. In training mode, hidden states are reshaped to
    ``(batch * tgt_len, 1, dim)`` and the mask to ``(-1, 1, 1, mask_last_dim)``
    before the call; the output is reshaped back to ``(batch, tgt_len, dim)`` and
    returned as a 3-tuple.
    """
    def attention_pre_forward_hook(hidden_states, attention_mask=None, *args, **kwargs):
        self.cur_decoder_layer_index = cur_layer_num
        cached_kv = kwargs.get('past_key_value')
        if cached_kv is not None:
            # A list (unlike the incoming tuple) can be modified in place by the attention hook.
            cached_kv = list(cached_kv)  # (batch, head, time, attn_dim)
            kwargs['past_key_value'] = cached_kv
            self.cur_layer_key_value_placeholder = cached_kv

        batch_size, tgt_len, dim = hidden_states.shape
        if not self.model.training:
            return original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs)

        # Attend from each target position independently.
        per_position_hidden = hidden_states.reshape(-1, 1, hidden_states.shape[-1])
        per_position_mask = attention_mask.reshape(-1, 1, 1, attention_mask.shape[-1])
        attn_output, attn_weights_reshaped, past_key_value = original_cross_attn_forward_func(
            hidden_states=per_position_hidden, attention_mask=per_position_mask, *args, **kwargs)
        return (attn_output.reshape(batch_size, tgt_len, dim), attn_weights_reshaped, past_key_value)

    return attention_pre_forward_hook
def attention_forward_hook(self, module, input, output):
    """Overwrite the current layer's cross-attention keys/values with the top-k retrieved ones.

    Forward-hook signature ``(module, input, output)``. ``output`` is split into
    heads by ``self.process_query`` and its last time step is used as the query.
    Nothing is returned, so the hooked module's output is unchanged; the effect
    is the in-place update of the first ``topk`` slots of
    ``self.cur_layer_key_value_placeholder[0]`` / ``[1]`` (presumably the list the
    cross-attention pre-forward hook built from ``kwargs['past_key_value']``).

    Retrieval paths:
      * ``use_datastore``: the query is folded through the key projection
        (``preprocess_query``), searched in ``self.datastore``, and the matching
        embeddings (reconstructed by the index or gathered from
        ``self.hidden_states``) are projected by ``post_process_retrieved``.
      * otherwise: dot-product scores against this layer's cached
        ``self.prompt_keys``; padding positions (``prompt_attention_mask == 0``)
        are pushed down by 1e9, as are the first ``actual_model_window_size``
        positions when ``exclude_attention`` is set; then ``torch.topk``.
    With ``test_datastore`` both paths run and are asserted to agree.

    When ``save_heatmap`` is set, a 0/1 row marking the selected source positions
    of the first batch item is appended to ``self.heatmap``. That row is built on
    ``self.heatmap``'s device and filled with a scalar, so the indices' device and
    the scores' dtype no longer have to match it.
    """
    # output: presumably (batch, time, heads * attn_dim); process_query splits the heads
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return
    with torch.no_grad():
        prompt_size = self.prompt_input_ids.shape[1]
        generated_size = self.input_ids_size - prompt_size
        window_size = self.cur_layer_key_value_placeholder[0].shape[-2]
        # Retrieve no more keys than the prompt has, nor than the window can hold.
        topk = min(prompt_size, window_size)
        if not self.is_encoder_decoder:
            # Decoder-only: the tokens generated so far also occupy the window.
            topk = min(topk, window_size - generated_size + 1)
        if self.gpu_index:
            # NOTE(review): presumably the largest k the GPU index search supports — confirm.
            topk = min(topk, 2048)

        query = self.process_query(output)[:, -1]  # (batch * beam, head, dim)
        query = query[:, self.head_nums]  # (batch * beam, head, dim)

        if self.use_datastore:
            datastore_index = 0 if self.is_encoder_decoder else self.cur_decoder_layer_index
            attention_layer_list = self.get_kv_projections(self.layer_begin, self.layer_end)
            kv_projections = attention_layer_list[self.cur_decoder_layer_index]
            k_proj_layer, v_proj_layer = kv_projections[0], kv_projections[1]
            # Fold the key projection into the query so it can be matched directly
            # against the stored embeddings.
            datastore_query = self.preprocess_query(query, k_proj_layer.weight)  # (batch * beam, num_heads, embed_dim)
            batch_size = self.datastore[datastore_index].batch_size
            datastore_query = datastore_query.view((batch_size, -1, datastore_query.shape[2]))  # (batch, beam * num_heads, embed_dim)
            if self.reconstruct_embeddings:
                # embeddings: (batch, beam * head, topk, dim)
                _, top_search_key_indices, embeddings = self.datastore[datastore_index].search_and_reconstruct(datastore_query, k=topk)
            else:
                _, top_search_key_indices = self.datastore[datastore_index].search(datastore_query, k=topk)
                # self.hidden_states[i]: (batch, src_len, dim)
                # indices: (batch, beam * head, topk)
                # embeddings: (batch, beam * head, topk, dim)
                stored_hidden_states = self.hidden_states[datastore_index]
                embeddings = torch.take_along_dim(
                    input=stored_hidden_states.unsqueeze(1),
                    indices=top_search_key_indices.unsqueeze(-1).to(stored_hidden_states.device),
                    dim=-2)
                embeddings = embeddings.to(self.device)
            top_search_key_indices = top_search_key_indices.reshape(batch_size, -1, *top_search_key_indices.shape[1:])
            # embeddings: (batch, beam, head, topk, dim)
            embeddings = embeddings.reshape(batch_size, -1, self.num_heads, *embeddings.shape[2:])

        if (not self.use_datastore) or self.test_datastore:
            this_layer_prompt_keys = self.prompt_keys[self.cur_decoder_layer_index]
            this_layer_prompt_values = self.prompt_values[self.cur_decoder_layer_index]
            # query: (batch * beam, head, dim) -> (batch, beam, head, dim)
            batch_size = self.prompt_input_ids.shape[0]
            beam_size = query.shape[0] // batch_size
            query = query.reshape(batch_size, beam_size, *query.shape[1:])
            # keys.unsqueeze(1): (batch, 1, head, source_len, dim)
            # query.unsqueeze(-1): (batch, beam, head, dim, 1)
            # attn_weights: (batch, beam, head, source_len)
            attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums], query.unsqueeze(-1)).squeeze(-1)
            # Padding positions (mask == 0) get a large negative score.
            prompt_attention_mask_to_add = (1 - self.prompt_attention_mask) * -1e9  # (batch, source_len)
            prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1)
            attn_weights += prompt_attention_mask_to_add  # (batch, beam, head, source_len)
            if self.exclude_attention and attn_weights.shape[-1] > self.actual_model_window_size:
                attn_weights[..., :self.actual_model_window_size] -= 1e9
            _, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True)  # (batch, beam, head, topk)
            if self.save_heatmap:
                # heatrow: (beam, heads, source_len), 1.0 at the selected positions of batch item 0
                heatrow = torch.zeros(
                    [top_key_indices.shape[1], top_key_indices.shape[2], this_layer_prompt_keys.shape[-2]],
                    dtype=torch.float, device=self.heatmap.device)
                heatrow = heatrow.scatter(-1, top_key_indices[0].to(heatrow.device), 1.0)
                # self.heatmap: (beam, heads, targets, source_len)
                self.heatmap = torch.cat([self.heatmap, heatrow.unsqueeze(-2)], dim=-2)
            if self.test_datastore:
                assert top_key_indices.shape == top_search_key_indices.shape
                assert torch.mean((top_key_indices == top_search_key_indices).float()) > 0.99

        if self.verbose:
            if self.is_encoder_decoder:
                for i, beam in enumerate(self.generated_input_ids):
                    print(f'({i}) Generated: {self.tokenizer.decode(beam)}')
            print()

        if self.use_datastore:
            # embeddings: (batch, beam, head, topk, embed_dim)
            retrieved_keys, retrieved_values = self.post_process_retrieved(embeddings, k_proj_layer, v_proj_layer, top_search_key_indices)
        else:
            # keys.unsqueeze(1): (batch, 1, head, source_len, dim); indices: (batch, beam, head, topk, 1)
            retrieved_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                  dim=-2)  # (batch, beam, head, topk, attn_dim)
            retrieved_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                    dim=-2)  # (batch, beam, head, topk, attn_dim)

        if self.test_datastore:
            correct_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                dim=-2)
            correct_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1),
                                                  dim=-2)
            assert correct_keys.shape == retrieved_keys.shape
            assert correct_values.shape == retrieved_values.shape
            assert torch.mean(torch.isclose(correct_keys, retrieved_keys, rtol=1e-3, atol=1e-3).float()) > 0.99
            assert torch.mean(torch.isclose(correct_values, retrieved_values, rtol=1e-3, atol=1e-3).float()) > 0.99

        # (batch * beam, head, topk, attn_dim); the retrieved entries overwrite the first topk slots
        retrieved_keys = retrieved_keys.flatten(0, 1)[:, :, :topk]
        retrieved_values = retrieved_values.flatten(0, 1)[:, :, :topk]
        self.cur_layer_key_value_placeholder[0] = torch.cat([retrieved_keys, self.cur_layer_key_value_placeholder[0][:, :, topk:]], dim=-2)
        self.cur_layer_key_value_placeholder[1] = torch.cat([retrieved_values, self.cur_layer_key_value_placeholder[1][:, :, topk:]], dim=-2)
    return
def train_attention_forward_hook(self, module, input, output):
    """Training-time hook: keep, for each target position, only the best-scoring prompt keys/values.

    The current layer's keys and values are read from
    ``self.cur_layer_key_value_placeholder``. Every target position gets its own
    top-k selection, where ``k = min(actual_model_window_size, source_len)``.
    Padding positions (``long_inputs_mask == 0``) are pushed down by 1e9 before
    the selection. The placeholder slots are then replaced by the selected
    entries, shaped (batch * tgt_len, head, k, attn_dim). Returns None.
    """
    if self.is_input_encoding_pass or self.is_first_test_decoding_step:
        return
    cached_keys = self.cur_layer_key_value_placeholder[0]  # (batch, head, source_len, dim)
    cached_values = self.cur_layer_key_value_placeholder[1]
    with torch.no_grad():
        flat_query = self.process_query(output)  # (batch * tgt_len, 1, head, dim)
        n_batch = cached_keys.shape[0]
        n_tgt = flat_query.shape[0] // n_batch
        per_target_query = flat_query.reshape(n_batch, n_tgt, *flat_query.shape[2:])  # (batch, tgt_len, head, dim)
        n_heads = per_target_query.shape[-2]
        src_len = cached_keys.shape[-2]

        # (batch, 1, head, src, dim) @ (batch, tgt, head, dim, 1) -> (batch, tgt, head, 1, src)
        scores = torch.matmul(cached_keys.unsqueeze(1), per_target_query.unsqueeze(-1))
        scores = scores.reshape(n_batch, n_tgt, n_heads, 1, src_len)
        padding_penalty = (1 - self.long_inputs_mask) * -1e9  # (batch, source_len)
        scores += padding_penalty[:, None, None, None]

        k = min(self.actual_model_window_size, scores.shape[-1])
        _, best = torch.topk(scores, k=k, dim=-1, sorted=True)  # (batch, tgt, head, 1, k)
        gather_index = best.unsqueeze(-1)
        # (batch, 1, head, 1, src, dim) gathered along src -> (batch, tgt, head, 1, k, dim)
        kept_keys = torch.take_along_dim(cached_keys.unsqueeze(2).unsqueeze(1), indices=gather_index, dim=-2)
        kept_values = torch.take_along_dim(cached_values.unsqueeze(2).unsqueeze(1), indices=gather_index, dim=-2)

        self.cur_layer_key_value_placeholder[0] = kept_keys.flatten(0, 1).squeeze(2)
        self.cur_layer_key_value_placeholder[1] = kept_values.flatten(0, 1).squeeze(2)
def preprocess_query(self, query, k_proj_weight):
    """Fold the per-head key projection into the query.

    For head ``h`` the projected key is ``W_h @ x`` (plus bias), where ``W_h`` is
    rows ``[h * attn_dim, (h + 1) * attn_dim)`` of ``k_proj_weight``. Hence
    ``q_h . (W_h @ x) == (q_h @ W_h) . x``, and the returned vector can be
    matched directly against the raw input embeddings ``x``.

    Args:
        query: (batch * beam, num_heads, attn_dim).
        k_proj_weight: ``nn.Linear`` weight of the key projection, shaped
            (out_features, in_features) with out_features == num_heads * attn_dim.

    Returns:
        (batch * beam, num_heads, in_features).

    The last view axis is ``in_features`` (``shape[-1]``). The previous code used
    ``shape[0]``, which only worked for square projections such as BART's; it
    failed for T5 configurations where ``inner_dim != d_model``.
    """
    attn_dim = query.shape[-1]
    in_features = k_proj_weight.shape[-1]
    k_proj = k_proj_weight.reshape(1, self.num_heads, attn_dim, in_features)  # (1, num_heads, attn_dim, embed_dim)
    datastore_query = torch.matmul(query.unsqueeze(-2), k_proj)  # (batch * beam, num_heads, 1, embed_dim)
    return datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)
def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
    """Project retrieved input embeddings into per-head keys and values.

    Args:
        embeddings: (batch, beam, head, encoder_len, embed_dim), where embed_dim
            is the projections' ``in_features``.
        k_proj_layer, v_proj_layer: ``nn.Linear``-like modules (``weight`` and
            optional ``bias``) with out_features == num_heads * attn_dim.
        top_search_key_indices: unused here; kept for the hook interface.

    Returns:
        ``(retrieved_keys, retrieved_values)``, each
        (batch, beam, head, encoder_len, attn_dim).

    The per-head size comes from the projection's ``out_features``. The previous
    code used ``embed_dim // num_heads``, which assumed square projections.
    """
    embed_dim = embeddings.shape[-1]

    def project(layer):
        # Split the weight rows by head: (1, 1, heads, embed_dim, attn_dim) after the transpose.
        attn_dim = layer.weight.shape[0] // self.num_heads
        weight = layer.weight.view(1, 1, self.num_heads, attn_dim, embed_dim).transpose(-2, -1)
        bias = 0
        if layer.bias is not None:
            bias = layer.bias.view(1, self.num_heads, attn_dim).unsqueeze(-2).unsqueeze(0)  # (1, 1, heads, 1, attn_dim)
        return torch.matmul(embeddings, weight) + bias

    retrieved_keys = project(k_proj_layer)
    retrieved_values = project(v_proj_layer)
    return retrieved_keys, retrieved_values
def set_gradient_checkpointing(self, value):
    """Set the ``gradient_checkpointing`` flag on the wrapped model's decoder."""
    decoder = self.model.base_model.decoder
    decoder.gradient_checkpointing = value
def reorder_cache_hook(self, past, beam_idx):
    """Apply a beam reordering to this object's per-beam state, then delegate.

    Stores ``beam_idx`` as ``last_beam_idx`` and reindexes ``generated_input_ids``.
    Every non-None entry of ``prev_tokens`` is reindexed along its flattened
    first two axes, keeping its shape. When heatmap saving is on and the heatmap
    is non-empty, it is reindexed too. Returns whatever
    ``original_reorder_cache_func(past, beam_idx)`` returns.
    """
    self.last_beam_idx = beam_idx
    self.generated_input_ids = self.generated_input_ids[beam_idx]
    for layer_index, cached in enumerate(self.prev_tokens):
        if cached is None:
            continue
        reordered = cached.flatten(0, 1)[beam_idx]
        self.prev_tokens[layer_index] = reordered.reshape(cached.shape)
    if self.save_heatmap and self.heatmap.numel() > 0:
        self.heatmap = self.heatmap[beam_idx]
    return self.original_reorder_cache_func(past, beam_idx)
@classmethod
def convert_model(cls, model, *args, **kwargs):
    """Attach the matching Unlimiformer adapter to ``model`` and return ``model``.

    The adapter instance itself is discarded; presumably its constructor installs
    hooks on ``model`` in place — confirm. The adapter is chosen by walking
    ``type(model).__mro__``, so subclasses of the supported Hugging Face classes
    are accepted too. Exact matches resolve exactly as before.

    Raises:
        KeyError: if no supported class appears in the model's MRO. The type is
            kept as KeyError for compatibility with the previous dict lookup.
    """
    type_to_class = {
        BartModel: UnlimiformerBART,
        BartForConditionalGeneration: UnlimiformerBART,
        T5Model: UnlimiformerT5,
        T5ForConditionalGeneration: UnlimiformerT5,
        LEDModel: UnlimiformerLED,
        LEDForConditionalGeneration: UnlimiformerLED,
        # LlamaModel: UnlimiformerLLaMa,
        # LlamaForCausalLM: UnlimiformerLLaMa,
    }
    for klass in type(model).__mro__:
        adapter_class = type_to_class.get(klass)
        if adapter_class is not None:
            break
    else:
        raise KeyError(f'Unsupported model type: {type(model)}')
    adapter_class(model, *args, **kwargs)
    return model
def plot_heatmap(self, data, xticklabels='auto', yticklabels='auto'):
    """Draw one seaborn heatmap per head, save each as ``knns_head{i}.pdf``, then show them.

    ``data`` is indexed as (heads, targets, source_len); the title reports
    axis 2 as the length and axis 1 as the target length.

    NOTE(review): ``xticklabels`` is accepted but not used; x tick labels are
    always drawn every 512 columns. The figures are not closed before
    ``plt.show()``.
    """
    import seaborn as sb
    import matplotlib.pyplot as plt
    for head in range(data.shape[0]):
        fig, ax = plt.subplots(1, 1, figsize=(40, 100))
        ax.set_title(f'Head #{head}, length: {data.shape[2]}, target length: {data.shape[1]}')
        heat_ax = sb.heatmap(data[head], annot=False, fmt='.2f',
                             xticklabels=512, yticklabels=yticklabels, ax=ax)
        heat_ax.xaxis.tick_top()
        plt.savefig(f'knns_head{head}.pdf')
    plt.show()
class UnlimiformerBART(Unlimiformer[BartModel]):
    """Unlimiformer adapter for BART-family encoder-decoder models.

    Cross-attention is reached through ``decoder_layer.encoder_attn``, whose
    ``q_proj`` / ``k_proj`` / ``v_proj`` projections and ``num_heads`` /
    ``head_dim`` attributes are used below.
    """

    def __init__(self, model: BartModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Return this layer's cross-attention ``(key, value)``, each (batch, heads, time, attn_dim)."""
        cross_attn = decoder_layer.encoder_attn

        def split_heads(projected):
            # (batch, time, heads * attn_dim) -> (batch, heads, time, attn_dim)
            return projected.view(projected.shape[0], -1, cross_attn.num_heads, cross_attn.head_dim).transpose(1, 2).contiguous()

        key = split_heads(cross_attn.k_proj(encoder_hidden_states))
        value = split_heads(cross_attn.v_proj(encoder_hidden_states))
        return key, value

    def process_key_value(self, capturers):
        """Reshape the captured k/v projection outputs to (batch, heads, time, attn_dim)."""
        key_capturer, value_capturer = capturers
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        heads, head_dim = cross_attn.num_heads, cross_attn.head_dim
        key, value = (
            captured.view(captured.shape[0], -1, heads, head_dim).transpose(1, 2).contiguous()
            for captured in (key_capturer.captured, value_capturer.captured)
        )
        return key, value

    def process_query(self, output):
        """Split a query-projection output into heads: (batch, time, heads, attn_dim)."""
        cross_attn = self.model.base_model.decoder.layers[-1].encoder_attn
        split_shape = (output.shape[0], output.shape[1], cross_attn.num_heads, cross_attn.head_dim)
        # Heads stay on axis 2 (no transpose), unlike create_key_value.
        return output.view(*split_shape).contiguous()

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k_proj, v_proj]`` of each decoder layer in ``[layer_begin, layer_end)``."""
        decoder_layers = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [[layer.encoder_attn.k_proj, layer.encoder_attn.v_proj] for layer in decoder_layers]

    def activation_to_capture(self, layer_begin, layer_end):
        """Return the last encoder layer when using a datastore, otherwise the k/v projections."""
        if not self.use_datastore:
            return self.get_kv_projections(layer_begin, layer_end)
        return [self.model.base_model.encoder.layers[-1]]

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention query projection of each selected decoder layer."""
        decoder_layers = self.model.base_model.decoder.layers[layer_begin:layer_end]
        return [layer.encoder_attn.q_proj for layer in decoder_layers]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the decoder layers in ``[layer_begin, layer_end)``."""
        decoder_layers = self.model.base_model.decoder.layers
        return decoder_layers[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the decoder layer's self-attention module."""
        self_attn = decoder_layer.self_attn
        return self_attn

    def cross_attention(self, decoder_layer):
        """Return the decoder layer's cross-attention module."""
        cross_attn = decoder_layer.encoder_attn
        return cross_attn

    def window_size(self):
        """Return ``config.max_position_embeddings``."""
        config = self.model.config
        return config.max_position_embeddings

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build the keyword arguments for a decoder-layer call.

        ``key``/``value`` go into the last two slots of ``past_key_value``
        (presumably the cross-attention cache entries); when both are None the
        cache is None. The ``past_key_value``, ``position_bias`` and
        ``encoder_decoder_position_bias`` parameters are not used.
        """
        no_given_kv = key is None and value is None
        return {
            'hidden_states': hidden_states,
            'attention_mask': attention_mask,
            'encoder_hidden_states': encoder_hidden_states,
            'encoder_attention_mask': encoder_attention_mask,
            'layer_head_mask': layer_head_mask,
            'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
            'past_key_value': None if no_given_kv else (None, None, key, value),
            'output_attentions': output_attentions,
            'use_cache': use_cache,
        }
class UnlimiformerT5(Unlimiformer[T5Model]):
    """Unlimiformer adapter for T5 encoder-decoder models.

    Cross-attention lives at ``block.layer[1].EncDecAttention``. Its projections
    are ``q`` / ``k`` / ``v``, and the head count and per-head size are
    ``n_heads`` / ``key_value_proj_dim``.
    """

    def __init__(self, model: T5Model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def create_key_value(self, encoder_hidden_states, decoder_layer):
        """Return this layer's cross-attention ``(key, value)``, each (batch, heads, time, attn_dim)."""
        attention = decoder_layer.layer[1].EncDecAttention
        key = attention.k(encoder_hidden_states)
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = attention.v(encoder_hidden_states)
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_key_value(self, capturers):
        """Reshape the captured k/v projection outputs to (batch, heads, time, attn_dim)."""
        key_capturer, value_capturer = capturers
        key, value = key_capturer.captured, value_capturer.captured
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous()
        return key, value

    def process_query(self, output):
        """Split a query-projection output into heads: (batch, time, heads, attn_dim)."""
        attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention
        # Heads stay on axis 2 (no transpose).
        query = output.view(output.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).contiguous()
        return query

    def get_kv_projections(self, layer_begin, layer_end):
        """Return ``[k, v]`` of each decoder block in ``[layer_begin, layer_end)``."""
        return [
            [layer.layer[1].EncDecAttention.k, layer.layer[1].EncDecAttention.v]
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def activation_to_capture(self, layer_begin, layer_end):
        """Return the modules whose activations are captured.

        With a datastore, this is the encoder's ``final_layer_norm``. The previous
        ``encoder.layers[-1]`` does not exist on a T5 encoder, whose layers are
        in ``block``, so that path raised AttributeError. In Hugging Face's
        T5Stack the encoder output is the last block's output passed through
        ``final_layer_norm`` (then dropout). Capturing the norm therefore yields
        the states that ``k``/``v`` project in ``create_key_value``, and a plain
        tensor rather than a block's multi-element output tuple. Without a
        datastore, the per-layer k/v projections are captured.
        """
        if self.use_datastore:
            return [self.model.base_model.encoder.final_layer_norm]
        return self.get_kv_projections(layer_begin, layer_end)

    def attention_op_to_run(self, layer_begin, layer_end):
        """Return the cross-attention query projection of each selected decoder block."""
        return [
            layer.layer[1].EncDecAttention.q
            for layer in self.model.base_model.decoder.block[layer_begin:layer_end]
        ]

    def attention_layer_to_run(self, layer_begin, layer_end):
        """Return the decoder blocks in ``[layer_begin, layer_end)``."""
        return self.model.base_model.decoder.block[layer_begin:layer_end]

    def self_attention(self, decoder_layer):
        """Return the block's self-attention sub-layer (``layer[0]``)."""
        return decoder_layer.layer[0]

    def cross_attention(self, decoder_layer):
        """Return the block's cross-attention sub-layer (``layer[1]``)."""
        return decoder_layer.layer[1]

    def window_size(self):
        """Return ``config.n_positions``, or 1024 when the config has no such attribute."""
        return getattr(self.model.config, 'n_positions', 1024)

    def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states,
                                  encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask,
                                  past_key_value, output_attentions, position_bias,
                                  encoder_decoder_position_bias, use_cache, key, value):
        """Build the keyword arguments for a T5 decoder-block call.

        ``key``/``value`` go into the last two slots of ``past_key_value``; when
        both are None the cache is None. The ``past_key_value`` parameter itself
        is not used.
        """
        args = {'hidden_states': hidden_states,
                'attention_mask': attention_mask,
                'position_bias': position_bias,
                'encoder_hidden_states': encoder_hidden_states,
                'encoder_attention_mask': encoder_attention_mask,
                'encoder_decoder_position_bias': encoder_decoder_position_bias,
                'layer_head_mask': layer_head_mask,
                'cross_attn_layer_head_mask': cross_attn_layer_head_mask,
                'past_key_value': (None, None, key, value),
                'use_cache': use_cache,
                'output_attentions': output_attentions}
        if key is None and value is None:
            args['past_key_value'] = None
        return args
class UnlimiformerLED(UnlimiformerBART):
    """BART adapter reused for LED; only the window size differs."""

    def __init__(self, model: LEDModel, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def window_size(self):
        """Return ``config.max_encoder_position_embeddings``."""
        config = self.model.config
        return config.max_encoder_position_embeddings
# class UnlimiformerLLaMa(Unlimiformer[LlamaModel]):
# def __init__(self, model: LlamaModel, *args, **kwargs):
# super().__init__(model, *args, **kwargs)
# def get_kv_projections(self, layer_begin, layer_end):
# return [
# [layer.self_attn.k_proj, layer.self_attn.v_proj]
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# def activation_to_capture(self, layer_begin, layer_end):
# if self.use_datastore:
# return [
# layer.input_layernorm
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# else:
# return self.get_kv_projections(layer_begin, layer_end)
# def attention_op_to_run(self, layer_begin, layer_end):
# return [
# layer.self_attn.q_proj
# for layer in self.model.base_model.layers[layer_begin:layer_end]
# ]
# def attention_layer_to_run(self, layer_begin, layer_end):
# return self.model.base_model.layers[layer_begin:layer_end]
# def self_attention(self, decoder_layer):
# return decoder_layer.self_attn
# def cross_attention(self, decoder_layer):
# return decoder_layer.self_attn
# def window_size(self):
# return self.model.config.max_position_embeddings
# def set_gradient_checkpointing(self, value):
# self.model.base_model.gradient_checkpointing = value
# def process_key_value(self, capturers):
# key_capturer, value_capturer = capturers
# # (batch, time, heads * attn_dim)
# key, value = key_capturer.captured, value_capturer.captured
# attention = self.model.base_model.layers[-1].self_attn
# # (batch, heads, time, attn_dim)
# key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
# value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous()
# return key, value
# def process_query(self, output):
# # output: (batch, time, heads * attn_dim)
# attention = self.model.base_model.layers[-1].self_attn
# # query: (batch, time, heads, attn_dim)
# query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous()
# return query
# def rotate_half(self, x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
# def preprocess_query(self, query, k_proj_weight):
# # query: (batch * time, head, dim)
# attention = self.model.base_model.layers[-1].self_attn
# num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1], self.actual_model_window_size)
# cos, sin = attention.rotary_emb(query, seq_len=num_generated)
# cos = cos[:,:,-1] # [1, 1, dim]
# sin = sin[:,:,-1] # [1, 1, dim]
# # cos = cos[-1].unsqueeze(0).unsqueeze(0) # [bs, 1, seq_len, dim]
# # sin = sin[-1].unsqueeze(0) # [bs, 1, seq_len, dim]
# query = (query * cos) + (self.rotate_half(query) * sin)
# k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0]) # (1, num_heads, attn_dim, embed_dim)
# k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :]
# k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :]
# k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2)
# datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
# datastore_query = torch.matmul(datastore_query, k_proj + k_proj_rotated) # (batch * beam, num_heads, 1, embed_dim)
# datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim)
# return datastore_query
# def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices):
# embed_dim = embeddings.shape[-1]
# k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, 1, heads, embed_dim, attn_dim)
# k_bias = 0
# if k_proj_layer.bias is not None:
# k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
# v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, heads, embed_dim, attn_dim)
# v_bias = 0
# if v_proj_layer.bias is not None:
# v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0)
# # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim)
# retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias # (beam, head, encoder_len, embed_dim)
# retrieved_values = torch.matmul(embeddings, v_weight) + v_bias # (beam, head, encoder_len, embed_dim)
# attention = self.model.base_model.layers[-1].self_attn
# cos, sin = attention.rotary_emb(retrieved_values, seq_len=self.hidden_states[0].shape[1])
# cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
# sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
# if self.prompt_input_ids.shape[1] > self.actual_model_window_size:
# # scale the top key indices to the actual model window size, such that the model will not see
# # positional embeddings that did not appear at training time
# scaled_key_indices = ((top_search_key_indices / self.prompt_input_ids.shape[1]) * self.actual_model_window_size).int()
# else:
# scaled_key_indices = top_search_key_indices
# # top_search_key_indices = top_search_key_indices.to(cos.device)
# scaled_key_indices = scaled_key_indices.to(cos.device)
# cos = cos[scaled_key_indices] # [bs, 1, seq_len, dim]
# sin = sin[scaled_key_indices] # [bs, 1, seq_len, dim]
# retrieved_keys = (retrieved_keys * cos) + (self.rotate_half(retrieved_keys) * sin)
# return retrieved_keys, retrieved_values
class ActivationCapturer(nn.Module):
    """Forward-hook callable that stores a layer's input or output in ``captured``.

    ``forward`` has the ``(module, input, output)`` hook signature. It stores
    ``layer_input`` when ``capture_input`` is true and ``layer_output``
    otherwise; a 1-tuple is replaced by its sole element.
    """

    def __init__(self, layer, capture_input=False):
        super().__init__()
        self.layer = layer
        self.capture_input = capture_input
        self.captured = None

    def unwrap_tuple(self, t):
        """Return the sole element of a 1-tuple; any other value is returned unchanged."""
        is_singleton = isinstance(t, tuple) and len(t) == 1
        return t[0] if is_singleton else t

    def forward(self, module, layer_input, layer_output):
        source = layer_input if self.capture_input else layer_output
        self.captured = self.unwrap_tuple(source)
|