added notebook for generation
Browse files- generation.ipynb +118 -0
- long_models/diagonaled_mm_tvm.py +339 -0
- long_models/longformer.py +277 -0
- long_models/longformer_bart.py +356 -0
- long_models/longformer_mbart.py +352 -0
- long_models/sliding_chunks.py +185 -0
generation.ipynb
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "c7544851-fb94-44e0-86eb-65d74fad45aa",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from long_models.longformer_mbart import MLongformerEncoderDecoderConfig, MLongformerEncoderDecoderForConditionalGeneration\n",
|
11 |
+
"from transformers import MBartTokenizer"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": null,
|
17 |
+
"id": "c1446321-3b74-4fbb-9871-403a82ceb0de",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"tokenizer = MBartTokenizer.from_pretrained(\"./\")\n",
|
22 |
+
"config = MLongformerEncoderDecoderConfig.from_pretrained('./')\n",
|
23 |
+
"model = MLongformerEncoderDecoderForConditionalGeneration.from_pretrained('./', config=config)\n",
|
24 |
+
"tokenizer.src_lang = 'de_DE'"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": null,
|
30 |
+
"id": "65105f0c-eb1c-4e2f-8f21-23e3fbec81c1",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"input_txt = \"Ein Gebet von Mose, dem Mann Gottes: Mein Herr, eine sichere Wohnung bist du für uns gewesen von Generation zu Generation. Bevor die Berge geboren waren und du die Erde und die irdische Welt hervorgebracht hattest, und von Ewigkeit zu Ewigkeit bist du Gott. Du lässt den Menschen zum Staub zurückkehren und sprichst: „Kehrt zurück, ihr Kinder des Menschen! “ Denn tausend Jahre sind in deinen Augen wie der gestrige Tag, wenn er vergangen ist, oder eine Wache in der Nacht. Du schwemmst sie weg, sie sind wie Schlaf, am Morgen wie Gras, das aufsprosst. Am Morgen blüht es und sprosst auf, zum Abend verwelkt und verdorrt es. Denn wir vergehen durch deinen Zorn, und durch deine Zorneshitze werden wir verstört. Du stellst unsere Fehler vor dich, unsere Geheimnisse vor das Licht deiner Gegenwart. Ja, alle unsere Tage fahren dahin durch deinen Zorn, wir vollenden unsere Jahre wie einen Seufzer. Die Tage unserer Jahre, in ihnen sind siebzig Jahre, und mit Kraft achtzig Jahre. und Ihr Stolz ist Mühe und Beschwerde, denn er ist schnell vergangen und wir fliegen davon. Wer erkennt die Stärke deines Zorns? Wie die Furcht vor dir ist dein Grimm. Darum lehre uns, unsere Tage zu zählen, damit wir ein Herz der Weisheit bekommen. Kehre doch zurück, JHWH! Wie lange? Habe Mitleid mit deinen Knechten! Sättige uns am Morgen mit deiner Güte, dann werden wir jubeln und uns freuen an allen unseren Tagen! Erfreue uns so viele Tage, wie du uns bedrückt hast, so viele Jahre, wie wir Unglück gesehen haben! Zeige deinen Knechten dein Handeln und deine Herrlichkeit ihren Kindern! Die Freundlichkeit des Herrn, unseres Gottes sei über uns! Das Werk unserer Hände festige über uns, und das Werk unserer Hände, festige es!\""
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"id": "d7448395-6af9-44f4-88a0-33ed5fb80fd8",
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": [
|
44 |
+
"print(input_txt)"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": null,
|
50 |
+
"id": "f9483233-7c6b-4f23-b69a-b19b4de053be",
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"inputs = tokenizer(\n",
|
55 |
+
" input_txt, \n",
|
56 |
+
" padding='max_length',\n",
|
57 |
+
" return_tensors='pt')"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": null,
|
63 |
+
"id": "4aea1744-dfeb-498e-9fc7-d3f77ff012cb",
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [],
|
66 |
+
"source": [
|
67 |
+
"outputs = model.generate(**inputs, num_beams=6, decoder_start_token_id=tokenizer.convert_tokens_to_ids('de_SI'))"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "code",
|
72 |
+
"execution_count": null,
|
73 |
+
"id": "7dcc7f7d-edff-4647-acab-fd610e27d0a7",
|
74 |
+
"metadata": {},
|
75 |
+
"outputs": [],
|
76 |
+
"source": [
|
77 |
+
"tokenizer.batch_decode(outputs)"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": null,
|
83 |
+
"id": "5dd353d8-4de6-470e-b0e6-86c7c6640207",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [],
|
86 |
+
"source": []
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"id": "ba203c11-b222-484a-8d53-44a4d9da8ef5",
|
92 |
+
"metadata": {},
|
93 |
+
"outputs": [],
|
94 |
+
"source": []
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"metadata": {
|
98 |
+
"kernelspec": {
|
99 |
+
"display_name": "Python 3 (ipykernel)",
|
100 |
+
"language": "python",
|
101 |
+
"name": "python3"
|
102 |
+
},
|
103 |
+
"language_info": {
|
104 |
+
"codemirror_mode": {
|
105 |
+
"name": "ipython",
|
106 |
+
"version": 3
|
107 |
+
},
|
108 |
+
"file_extension": ".py",
|
109 |
+
"mimetype": "text/x-python",
|
110 |
+
"name": "python",
|
111 |
+
"nbconvert_exporter": "python",
|
112 |
+
"pygments_lexer": "ipython3",
|
113 |
+
"version": "3.10.6"
|
114 |
+
}
|
115 |
+
},
|
116 |
+
"nbformat": 4,
|
117 |
+
"nbformat_minor": 5
|
118 |
+
}
|
long_models/diagonaled_mm_tvm.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
This code is from AllenAI's Longformer:
|
7 |
+
https://github.com/allenai/longformer/
|
8 |
+
|
9 |
+
"""
|
10 |
+
from typing import Union
|
11 |
+
from functools import lru_cache
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import os.path
|
15 |
+
|
16 |
+
|
17 |
+
class DiagonaledMM(torch.autograd.Function):
|
18 |
+
'''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
|
19 |
+
this function from PyTorch
|
20 |
+
'''
|
21 |
+
|
22 |
+
function_dict = {} # save a list of functions, each has a different set of parameters
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
|
26 |
+
'''Compiles a tvm function that computes diagonal_mm
|
27 |
+
args:
|
28 |
+
dtype: str in ['float64', 'float32', 'float16']
|
29 |
+
device: str in ['cpu' or 'cuda']
|
30 |
+
b0, b1, b2: size of tensor tiles. Very important for good performance
|
31 |
+
|
32 |
+
'''
|
33 |
+
import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
|
34 |
+
from tvm.contrib import nvcc
|
35 |
+
@tvm.register_func
|
36 |
+
def tvm_callback_cuda_compile(code):
|
37 |
+
"""Use nvcc compiler for better perf."""
|
38 |
+
ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
|
39 |
+
return ptx
|
40 |
+
|
41 |
+
assert dtype in ['float16', 'float32', 'float64']
|
42 |
+
assert device in ['cpu', 'cuda']
|
43 |
+
device = None if device == 'cpu' else device
|
44 |
+
tgt_host="llvm"
|
45 |
+
|
46 |
+
b = tvm.var('b') # batch size
|
47 |
+
n = tvm.var('n') # sequence length
|
48 |
+
h = tvm.var('h') # number of heads
|
49 |
+
m = tvm.var('m') # hidden dimension
|
50 |
+
w = tvm.var('w') # window size
|
51 |
+
w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w`
|
52 |
+
padding = tvm.var('padding') # padding
|
53 |
+
transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed
|
54 |
+
t1d3 = tvm.var('t1d3') # last dimension of t1
|
55 |
+
t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor)
|
56 |
+
X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
|
57 |
+
Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
|
58 |
+
k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over
|
59 |
+
D = tvm.placeholder((h), name='D', dtype='int') # dilation per head
|
60 |
+
output_shape = (b, n, h, t3d3) # shape of the result tensor
|
61 |
+
algorithm = lambda l, i, q, j: tvm.sum(
|
62 |
+
tvm.if_then_else(
|
63 |
+
t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
|
64 |
+
tvm.if_then_else(
|
65 |
+
transpose_t1 == 0,
|
66 |
+
tvm.if_then_else(
|
67 |
+
tvm.all(
|
68 |
+
i + D[q] * (k - w) >= 0,
|
69 |
+
i + D[q] * (k - w) < n,
|
70 |
+
),
|
71 |
+
X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled
|
72 |
+
padding
|
73 |
+
),
|
74 |
+
tvm.if_then_else(
|
75 |
+
tvm.all(
|
76 |
+
i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True`
|
77 |
+
i + D[q] * (k - w_upper) < n,
|
78 |
+
),
|
79 |
+
X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # # t1 is diagonaled and should be transposed
|
80 |
+
padding
|
81 |
+
),
|
82 |
+
),
|
83 |
+
tvm.if_then_else(
|
84 |
+
tvm.all(
|
85 |
+
i + D[q] * (j - w) >= 0,
|
86 |
+
i + D[q] * (j - w) < n,
|
87 |
+
),
|
88 |
+
X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be
|
89 |
+
padding
|
90 |
+
)
|
91 |
+
), axis=k)
|
92 |
+
|
93 |
+
Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
|
94 |
+
s = tvm.create_schedule(Z.op)
|
95 |
+
|
96 |
+
print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
|
97 |
+
|
98 |
+
# split long axis into smaller chunks and assing each one to a separate GPU thread/block
|
99 |
+
ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
|
100 |
+
ZF = s.rfactor(Z, ki)
|
101 |
+
|
102 |
+
j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
|
103 |
+
i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)
|
104 |
+
|
105 |
+
s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
|
106 |
+
s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))
|
107 |
+
|
108 |
+
s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
|
109 |
+
s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))
|
110 |
+
|
111 |
+
tx = tvm.thread_axis("threadIdx.x")
|
112 |
+
s[Z].bind(s[Z].op.reduce_axis[0], tx)
|
113 |
+
s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
|
114 |
+
s[Z].set_store_predicate(tx.var.equal(0))
|
115 |
+
|
116 |
+
print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
|
117 |
+
|
118 |
+
# compiling the automatically generated cuda code
|
119 |
+
diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
|
120 |
+
return diagonaled_mm
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def _get_lib_filename(dtype: str, device: str):
|
124 |
+
base_filename = 'longformer/lib/lib_diagonaled_mm'
|
125 |
+
return '{}_{}_{}.so'.format(base_filename, dtype, device)
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def _save_compiled_function(f, dtype: str, device: str):
|
129 |
+
if not os.path.exists('longformer/lib/'):
|
130 |
+
os.makedirs('longformer/lib/')
|
131 |
+
f.export_library(DiagonaledMM._get_lib_filename(dtype, device))
|
132 |
+
|
133 |
+
@staticmethod
|
134 |
+
def _load_compiled_function(dtype: str, device: str):
|
135 |
+
from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing
|
136 |
+
filename = DiagonaledMM._get_lib_filename(dtype, device)
|
137 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
138 |
+
potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
|
139 |
+
for potential_dir in potential_dirs:
|
140 |
+
filepath = '{}{}'.format(potential_dir, filename)
|
141 |
+
if os.path.isfile(filepath):
|
142 |
+
print('Loading tvm binary from: {}'.format(filepath))
|
143 |
+
return load(filepath)
|
144 |
+
return None
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def _get_function(dtype: str, device: str):
|
148 |
+
'''Loads the function from the disk or compile it'''
|
149 |
+
# A list of arguments that define the function
|
150 |
+
args = (dtype, device)
|
151 |
+
if args not in DiagonaledMM.function_dict:
|
152 |
+
diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk
|
153 |
+
if not diagonaled_mm:
|
154 |
+
print('Tvm binary not found. Compiling ...')
|
155 |
+
diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile
|
156 |
+
DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk
|
157 |
+
# convert the tvm function into a pytorch function
|
158 |
+
from tvm.contrib import dlpack
|
159 |
+
diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function
|
160 |
+
# save the function into a dictionary to be reused
|
161 |
+
DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time
|
162 |
+
return DiagonaledMM.function_dict[args]
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
|
166 |
+
is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
|
167 |
+
autoregressive: bool = False):
|
168 |
+
'''Calls the compiled function after checking the input format. This function is called in three different modes.
|
169 |
+
t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
|
170 |
+
t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compuate attantion_scores x value = context
|
171 |
+
t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
|
172 |
+
the calculations in the backward pass.
|
173 |
+
'''
|
174 |
+
dtype = str(t1.dtype).split('.')[1]
|
175 |
+
device = t1.device.type
|
176 |
+
assert len(t1.shape) == 4
|
177 |
+
assert len(t1.shape) == len(t2.shape)
|
178 |
+
assert t1.shape[:3] == t2.shape[:3]
|
179 |
+
if isinstance(d, int): # if d is an integer, replace it with a tensor of the same length
|
180 |
+
# as number of heads, and it is filled with the same dilation value
|
181 |
+
d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)
|
182 |
+
|
183 |
+
assert len(d.shape) == 1
|
184 |
+
assert d.shape[0] == t1.shape[2] # number of dilation scores should match number of heads
|
185 |
+
b = t1.shape[0] # batch size
|
186 |
+
n = t1.shape[1] # sequence length
|
187 |
+
h = t1.shape[2] # number of heads
|
188 |
+
m = t2.shape[3] # hidden dimension
|
189 |
+
w_upper = 0 if autoregressive else w
|
190 |
+
c = w_upper + w + 1 # number of diagonals
|
191 |
+
if is_t1_diagonaled:
|
192 |
+
assert t1.shape[3] == c
|
193 |
+
r = t1.new_empty(b, n, h, m) # allocate spase for the result tensor
|
194 |
+
else:
|
195 |
+
assert not transpose_t1
|
196 |
+
assert t1.shape[3] == m
|
197 |
+
r = t1.new_empty(b, n, h, c) # allocate spase for the result tensor
|
198 |
+
|
199 |
+
# gets function from memory, from disk or compiles it from scratch
|
200 |
+
_diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)
|
201 |
+
|
202 |
+
# The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor
|
203 |
+
# We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa).
|
204 |
+
# The second reason is that the lambda expression in `_compile_function` is easier to express when the shape
|
205 |
+
# of the output is known
|
206 |
+
# This functions computes diagonal_mm then saves the result in `r`
|
207 |
+
if m == c:
|
208 |
+
# FIXME
|
209 |
+
print('Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}')
|
210 |
+
assert False
|
211 |
+
_diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
|
212 |
+
return r
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def _prepare_tensors(t):
|
216 |
+
'''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
|
217 |
+
For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
|
218 |
+
TVM expects this value to be the `product(t.size()[1:])` but PyTorch some times sets it to `t.stride()[1]`.
|
219 |
+
Here's an example to reporduce this issue:
|
220 |
+
import torch
|
221 |
+
print(torch.randn(1, 10).stride())
|
222 |
+
> (10, 1)
|
223 |
+
print(torch.randn(10, 1).t().contiguous().stride())
|
224 |
+
> (1, 1) # expected it to be (10, 1) as above
|
225 |
+
print(torch.randn(10, 2).t().contiguous().stride())
|
226 |
+
> (10, 1) # but gets the expected stride if the first dimension is > 1
|
227 |
+
'''
|
228 |
+
assert t.is_contiguous()
|
229 |
+
t_stride = list(t.stride())
|
230 |
+
t_size = list(t.size())
|
231 |
+
# Fix wrong stride information for the first dimension. This occures when batch_size=1
|
232 |
+
if t_size[0] == 1 and t_stride[0] == t_stride[1]:
|
233 |
+
# In this case, the stride of the first dimension should be the product
|
234 |
+
# of the sizes of all other dimensions
|
235 |
+
t_stride[0] = t_size[1] * t_size[2] * t_size[3]
|
236 |
+
t = t.as_strided(size=t_size, stride=t_stride)
|
237 |
+
return t
|
238 |
+
|
239 |
+
min_seq_len = 16 # unexpected output if seq_len < 16
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
|
243 |
+
'''Compuates diagonal_mm of t1 and t2.
|
244 |
+
args:
|
245 |
+
t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
|
246 |
+
t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
|
247 |
+
t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
|
248 |
+
tensor, e.g. `key_layer` or `value_layer`
|
249 |
+
w: int = window size; number of attentions on each side of the word
|
250 |
+
d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
|
251 |
+
heads. If torch.Tensor, it should be 1D of lenth=number of attention heads
|
252 |
+
is_t1_diagonaled: is t1 a diagonaled or a regular tensor
|
253 |
+
padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
|
254 |
+
needs to be a very large negative value (to compute softmax of attentions). For other usecases,
|
255 |
+
please use zero padding.
|
256 |
+
autoregressive: if true, return only the lower triangle
|
257 |
+
returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
|
258 |
+
if t1 is diagonaed, result is non-diagonaled, and vice versa
|
259 |
+
'''
|
260 |
+
batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
|
261 |
+
assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME
|
262 |
+
ctx.save_for_backward(t1, t2)
|
263 |
+
ctx.w = w
|
264 |
+
ctx.d = d
|
265 |
+
ctx.is_t1_diagonaled = is_t1_diagonaled
|
266 |
+
ctx.autoregressive = autoregressive
|
267 |
+
t1 = DiagonaledMM._prepare_tensors(t1)
|
268 |
+
t2 = DiagonaledMM._prepare_tensors(t2)
|
269 |
+
# output = t1.mm(t2) # what would have been called if this was a regular matmul
|
270 |
+
output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
|
271 |
+
return output
|
272 |
+
|
273 |
+
@staticmethod
|
274 |
+
def backward(ctx, grad_output):
|
275 |
+
t1, t2 = ctx.saved_tensors
|
276 |
+
w = ctx.w
|
277 |
+
d = ctx.d
|
278 |
+
is_t1_diagonaled = ctx.is_t1_diagonaled
|
279 |
+
autoregressive = ctx.autoregressive
|
280 |
+
if not grad_output.is_contiguous():
|
281 |
+
grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
|
282 |
+
grad_output = DiagonaledMM._prepare_tensors(grad_output)
|
283 |
+
t1 = DiagonaledMM._prepare_tensors(t1)
|
284 |
+
t2 = DiagonaledMM._prepare_tensors(t2)
|
285 |
+
# http://cs231n.github.io/optimization-2/
|
286 |
+
# https://pytorch.org/docs/master/notes/extending.html
|
287 |
+
# grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
|
288 |
+
grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
|
289 |
+
# grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
|
290 |
+
if is_t1_diagonaled:
|
291 |
+
grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
|
292 |
+
else:
|
293 |
+
grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
|
294 |
+
return grad_t1, grad_t2, None, None, None, None, None
|
295 |
+
|
296 |
+
|
297 |
+
def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
|
298 |
+
diagonals_list = []
|
299 |
+
for j in range(-d * w, d, d):
|
300 |
+
diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
|
301 |
+
diagonal_mask[:-j] = 1
|
302 |
+
diagonals_list.append(diagonal_mask)
|
303 |
+
return torch.stack(diagonals_list, dim=-1)
|
304 |
+
|
305 |
+
@lru_cache()
|
306 |
+
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
|
307 |
+
if isinstance(d, int):
|
308 |
+
affected_seq_len = w * d
|
309 |
+
mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
|
310 |
+
mask = mask[None, :, None, :]
|
311 |
+
else:
|
312 |
+
affected_seq_len = w * d.max()
|
313 |
+
head_masks = []
|
314 |
+
d_list = d.cpu().numpy().tolist()
|
315 |
+
for d in d_list:
|
316 |
+
one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
|
317 |
+
head_masks.append(one_head_mask)
|
318 |
+
mask = torch.stack(head_masks, dim=-2)
|
319 |
+
mask = mask[None, :, :, :]
|
320 |
+
|
321 |
+
ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
|
322 |
+
return affected_seq_len, mask.bool().to(device), ending_mask
|
323 |
+
|
324 |
+
def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
|
325 |
+
affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
|
326 |
+
seq_len = input_tensor.size(1)
|
327 |
+
beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
|
328 |
+
beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
|
329 |
+
beginning_input.masked_fill_(beginning_mask, -float('inf'))
|
330 |
+
if not autoregressive:
|
331 |
+
ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
|
332 |
+
ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
|
333 |
+
ending_input.masked_fill_(ending_mask, -float('inf'))
|
334 |
+
|
335 |
+
|
336 |
+
diagonaled_mm = DiagonaledMM.apply
|
337 |
+
|
338 |
+
# The non-tvm implementation is the default, we don't need to load the kernel at loading time.
|
339 |
+
# DiagonaledMM._get_function('float32', 'cuda')
|
long_models/longformer.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
This code is adapted from AllenAI's Longformer:
|
7 |
+
https://github.com/allenai/longformer/
|
8 |
+
|
9 |
+
"""
|
10 |
+
from typing import List
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from .diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
|
16 |
+
from .sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
|
17 |
+
from .sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
|
18 |
+
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
|
19 |
+
|
20 |
+
|
21 |
+
class Longformer(RobertaModel):
|
22 |
+
def __init__(self, config):
|
23 |
+
super(Longformer, self).__init__(config)
|
24 |
+
if config.attention_mode == 'n2':
|
25 |
+
pass # do nothing, use BertSelfAttention instead
|
26 |
+
else:
|
27 |
+
for i, layer in enumerate(self.encoder.layer):
|
28 |
+
layer.attention.self = LongformerSelfAttention(config, layer_id=i)
|
29 |
+
|
30 |
+
|
31 |
+
class LongformerForMaskedLM(RobertaForMaskedLM):
|
32 |
+
def __init__(self, config):
|
33 |
+
super(LongformerForMaskedLM, self).__init__(config)
|
34 |
+
if config.attention_mode == 'n2':
|
35 |
+
pass # do nothing, use BertSelfAttention instead
|
36 |
+
else:
|
37 |
+
for i, layer in enumerate(self.roberta.encoder.layer):
|
38 |
+
layer.attention.self = LongformerSelfAttention(config, layer_id=i)
|
39 |
+
|
40 |
+
|
41 |
+
class LongformerConfig(RobertaConfig):
|
42 |
+
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
|
43 |
+
autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs):
|
44 |
+
"""
|
45 |
+
Args:
|
46 |
+
attention_window: list of attention window sizes of length = number of layers.
|
47 |
+
window size = number of attention locations on each side.
|
48 |
+
For an affective window size of 512, use `attention_window=[256]*num_layers`
|
49 |
+
which is 256 on each side.
|
50 |
+
attention_dilation: list of attention dilation of length = number of layers.
|
51 |
+
attention dilation of `1` means no dilation.
|
52 |
+
autoregressive: do autoregressive attention or have attention of both sides
|
53 |
+
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implemenation of Longformer
|
54 |
+
selfattention, 'sliding_chunks' for another implementation of Longformer selfattention
|
55 |
+
"""
|
56 |
+
super().__init__(**kwargs)
|
57 |
+
self.attention_window = attention_window
|
58 |
+
self.attention_dilation = attention_dilation
|
59 |
+
self.autoregressive = autoregressive
|
60 |
+
self.attention_mode = attention_mode
|
61 |
+
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']
|
62 |
+
|
63 |
+
|
64 |
+
class LongformerSelfAttention(nn.Module):
|
65 |
+
def __init__(self, config, layer_id):
|
66 |
+
super(LongformerSelfAttention, self).__init__()
|
67 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
68 |
+
raise ValueError(
|
69 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
70 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
71 |
+
self.num_heads = config.num_attention_heads
|
72 |
+
self.head_dim = int(config.hidden_size / config.num_attention_heads)
|
73 |
+
self.embed_dim = config.hidden_size
|
74 |
+
|
75 |
+
self.query = nn.Linear(config.hidden_size, self.embed_dim)
|
76 |
+
self.key = nn.Linear(config.hidden_size, self.embed_dim)
|
77 |
+
self.value = nn.Linear(config.hidden_size, self.embed_dim)
|
78 |
+
|
79 |
+
self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
|
80 |
+
self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
|
81 |
+
self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
|
82 |
+
|
83 |
+
self.dropout = config.attention_probs_dropout_prob
|
84 |
+
|
85 |
+
self.layer_id = layer_id
|
86 |
+
self.attention_window = config.attention_window[self.layer_id]
|
87 |
+
self.attention_dilation = config.attention_dilation[self.layer_id]
|
88 |
+
self.attention_mode = config.attention_mode
|
89 |
+
self.autoregressive = config.autoregressive
|
90 |
+
assert self.attention_window > 0
|
91 |
+
assert self.attention_dilation > 0
|
92 |
+
assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
|
93 |
+
if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
|
94 |
+
assert not self.autoregressive # not supported
|
95 |
+
assert self.attention_dilation == 1 # dilation is not supported
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states,
|
100 |
+
attention_mask=None,
|
101 |
+
head_mask=None,
|
102 |
+
encoder_hidden_states=None,
|
103 |
+
encoder_attention_mask=None,
|
104 |
+
output_attentions=False,
|
105 |
+
):
|
106 |
+
'''
|
107 |
+
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
|
108 |
+
-ve: no attention
|
109 |
+
0: local attention
|
110 |
+
+ve: global attention
|
111 |
+
'''
|
112 |
+
assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
|
113 |
+
assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"
|
114 |
+
|
115 |
+
if attention_mask is not None:
|
116 |
+
key_padding_mask = attention_mask < 0
|
117 |
+
extra_attention_mask = attention_mask > 0
|
118 |
+
remove_from_windowed_attention_mask = attention_mask != 0
|
119 |
+
|
120 |
+
num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
|
121 |
+
max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
|
122 |
+
if max_num_extra_indices_per_batch <= 0:
|
123 |
+
extra_attention_mask = None
|
124 |
+
else:
|
125 |
+
# To support the case of variable number of global attention in the rows of a batch,
|
126 |
+
# we use the following three selection masks to select global attention embeddings
|
127 |
+
# in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
|
128 |
+
# 1) selecting embeddings that correspond to global attention
|
129 |
+
extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
|
130 |
+
zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
|
131 |
+
device=num_extra_indices_per_batch.device)
|
132 |
+
# mask indicating which values are actually going to be padding
|
133 |
+
selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
|
134 |
+
# 2) location of the non-padding values in the selected global attention
|
135 |
+
selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
|
136 |
+
# 3) location of the padding values in the selected global attention
|
137 |
+
selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
|
138 |
+
else:
|
139 |
+
remove_from_windowed_attention_mask = None
|
140 |
+
extra_attention_mask = None
|
141 |
+
key_padding_mask = None
|
142 |
+
|
143 |
+
hidden_states = hidden_states.transpose(0, 1)
|
144 |
+
seq_len, bsz, embed_dim = hidden_states.size()
|
145 |
+
assert embed_dim == self.embed_dim
|
146 |
+
q = self.query(hidden_states)
|
147 |
+
k = self.key(hidden_states)
|
148 |
+
v = self.value(hidden_states)
|
149 |
+
q /= math.sqrt(self.head_dim)
|
150 |
+
|
151 |
+
q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
|
152 |
+
k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
|
153 |
+
# attn_weights = (bsz, seq_len, num_heads, window*2+1)
|
154 |
+
if self.attention_mode == 'tvm':
|
155 |
+
q = q.float().contiguous()
|
156 |
+
k = k.float().contiguous()
|
157 |
+
attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
|
158 |
+
elif self.attention_mode == "sliding_chunks":
|
159 |
+
attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
|
160 |
+
elif self.attention_mode == "sliding_chunks_no_overlap":
|
161 |
+
attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
|
162 |
+
else:
|
163 |
+
raise False
|
164 |
+
mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
|
165 |
+
if remove_from_windowed_attention_mask is not None:
|
166 |
+
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
|
167 |
+
# from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
|
168 |
+
remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
|
169 |
+
# cast to float/half then replace 1's with -inf
|
170 |
+
float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
|
171 |
+
repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
|
172 |
+
float_mask = float_mask.repeat(1, 1, repeat_size, 1)
|
173 |
+
ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
|
174 |
+
# diagonal mask with zeros everywhere and -inf inplace of padding
|
175 |
+
if self.attention_mode == 'tvm':
|
176 |
+
d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
|
177 |
+
elif self.attention_mode == "sliding_chunks":
|
178 |
+
d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
|
179 |
+
elif self.attention_mode == "sliding_chunks_no_overlap":
|
180 |
+
d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
|
181 |
+
|
182 |
+
attn_weights += d_mask
|
183 |
+
assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
|
184 |
+
assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
|
185 |
+
|
186 |
+
# the extra attention
|
187 |
+
if extra_attention_mask is not None:
|
188 |
+
selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
|
189 |
+
selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
|
190 |
+
# (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
|
191 |
+
selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
|
192 |
+
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
|
193 |
+
# concat to attn_weights
|
194 |
+
# (bsz, seq_len, num_heads, extra attention count + 2*window+1)
|
195 |
+
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
|
196 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
197 |
+
if key_padding_mask is not None:
|
198 |
+
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
199 |
+
attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
|
200 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
201 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
202 |
+
v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
|
203 |
+
attn = 0
|
204 |
+
if extra_attention_mask is not None:
|
205 |
+
selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
|
206 |
+
selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
|
207 |
+
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
|
208 |
+
# use `matmul` because `einsum` crashes sometimes with fp16
|
209 |
+
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
|
210 |
+
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
|
211 |
+
attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
|
212 |
+
|
213 |
+
if self.attention_mode == 'tvm':
|
214 |
+
v = v.float().contiguous()
|
215 |
+
attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
|
216 |
+
elif self.attention_mode == "sliding_chunks":
|
217 |
+
attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
|
218 |
+
elif self.attention_mode == "sliding_chunks_no_overlap":
|
219 |
+
attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
|
220 |
+
else:
|
221 |
+
raise False
|
222 |
+
|
223 |
+
attn = attn.type_as(hidden_states)
|
224 |
+
assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
|
225 |
+
attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()
|
226 |
+
|
227 |
+
# For this case, we'll just recompute the attention for these indices
|
228 |
+
# and overwrite the attn tensor. TODO: remove the redundant computation
|
229 |
+
if extra_attention_mask is not None:
|
230 |
+
selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
|
231 |
+
selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]
|
232 |
+
|
233 |
+
q = self.query_global(selected_hidden_states)
|
234 |
+
k = self.key_global(hidden_states)
|
235 |
+
v = self.value_global(hidden_states)
|
236 |
+
q /= math.sqrt(self.head_dim)
|
237 |
+
|
238 |
+
q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
|
239 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
|
240 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
|
241 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
242 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
|
243 |
+
|
244 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
|
245 |
+
attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
|
246 |
+
if key_padding_mask is not None:
|
247 |
+
attn_weights = attn_weights.masked_fill(
|
248 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
249 |
+
-10000.0,
|
250 |
+
)
|
251 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
|
252 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
253 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
254 |
+
selected_attn = torch.bmm(attn_probs, v)
|
255 |
+
assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]
|
256 |
+
|
257 |
+
selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
|
258 |
+
nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
|
259 |
+
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)
|
260 |
+
|
261 |
+
context_layer = attn.transpose(0, 1) # attn shape: (seq_len, bsz, embed_dim), context_layer shape: (bsz, seq_len, embed_dim)
|
262 |
+
if output_attentions:
|
263 |
+
if extra_attention_mask is not None:
|
264 |
+
# With global attention, return global attention probabilities only
|
265 |
+
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
|
266 |
+
# which is the attention weights from tokens with global attention to all tokens
|
267 |
+
# It doesn't not return local attention
|
268 |
+
# In case of variable number of global attantion in the rows of a batch,
|
269 |
+
# attn_weights are padded with -10000.0 attention scores
|
270 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
|
271 |
+
else:
|
272 |
+
# without global attention, return local attention probabilities
|
273 |
+
# batch_size x num_heads x sequence_length x window_size
|
274 |
+
# which is the attention weights of every token attending to its neighbours
|
275 |
+
attn_weights = attn_weights.permute(0, 2, 1, 3)
|
276 |
+
outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
|
277 |
+
return outputs
|
long_models/longformer_bart.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
This code is in part adapted from AllenAI's Longformer:
|
7 |
+
https://github.com/allenai/longformer/
|
8 |
+
and in part adapted from:
|
9 |
+
https://github.com/huggingface/transformers
|
10 |
+
|
11 |
+
Author: Annette Rios ([email protected])
|
12 |
+
|
13 |
+
"""
|
14 |
+
from typing import List, Optional, Tuple, Dict, Union
|
15 |
+
from torch import nn, Tensor, zeros
|
16 |
+
import torch
|
17 |
+
import math
|
18 |
+
import random
|
19 |
+
from .longformer import LongformerSelfAttention
|
20 |
+
from transformers.models.bart.modeling_bart import BartConfig, BartForConditionalGeneration, BartEncoder, BartLearnedPositionalEmbedding, BartEncoderLayer, BartDecoder, BartModel, _expand_mask
|
21 |
+
from transformers.modeling_outputs import BaseModelOutput
|
22 |
+
|
23 |
+
class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
|
24 |
+
def __init__(self, config):
|
25 |
+
super(BartForConditionalGeneration, self).__init__(config)
|
26 |
+
|
27 |
+
self.model = LongBartModel(config)
|
28 |
+
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
29 |
+
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
30 |
+
#print(self)
|
31 |
+
|
32 |
+
if config.attention_mode == 'n2':
|
33 |
+
pass # do nothing, use BartSelfAttention instead
|
34 |
+
else:
|
35 |
+
for i, layer in enumerate(self.model.encoder.layers):
|
36 |
+
layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
|
37 |
+
# Initialize weights and apply final processing
|
38 |
+
self.post_init()
|
39 |
+
|
40 |
+
|
41 |
+
class LongformerEncoderDecoderConfig(BartConfig):
|
42 |
+
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
|
43 |
+
autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
|
44 |
+
gradient_checkpointing: bool = False, **kwargs):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
attention_window: list of attention window sizes of length = number of layers.
|
48 |
+
window size = number of attention locations on each side.
|
49 |
+
For an affective window size of 512, use `attention_window=[256]*num_layers`
|
50 |
+
which is 256 on each side.
|
51 |
+
attention_dilation: list of attention dilation of length = number of layers.
|
52 |
+
attention dilation of `1` means no dilation.
|
53 |
+
autoregressive: do autoregressive attention or have attention of both sides
|
54 |
+
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implemenation of Longformer
|
55 |
+
selfattention, 'sliding_chunks' for another implementation of Longformer selfattention
|
56 |
+
"""
|
57 |
+
super().__init__(**kwargs)
|
58 |
+
self.attention_window = attention_window
|
59 |
+
self.attention_dilation = attention_dilation
|
60 |
+
self.autoregressive = autoregressive
|
61 |
+
self.attention_mode = attention_mode
|
62 |
+
self.gradient_checkpointing = gradient_checkpointing
|
63 |
+
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
|
64 |
+
|
65 |
+
class LongformerSelfAttentionForBart(nn.Module):
|
66 |
+
def __init__(self, config, layer_id):
|
67 |
+
super().__init__()
|
68 |
+
self.embed_dim = config.d_model
|
69 |
+
self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
|
70 |
+
self.output = nn.Linear(self.embed_dim, self.embed_dim)
|
71 |
+
|
72 |
+
def forward(
|
73 |
+
self,
|
74 |
+
hidden_states: Tensor, # shape (batch_size, q_len, model_size)
|
75 |
+
key_value_states: Optional[Tensor] = None, # cross-attention in transformers.models.bart.modeling_bart
|
76 |
+
past_key_value: Optional[Tuple[Tensor]] = None, # only for decoder
|
77 |
+
attention_mask: Optional[Tensor] = None, # shape (batch_size, k_len) -> changed in transformers.models.modeling_bart.BartEncoder and BartEncoderLayer (new mask uses bool -> global attention positions are lost, need to use the inverted orignal mask
|
78 |
+
layer_head_mask: Optional[Tensor] = None, # head dropout?
|
79 |
+
output_attentions: bool = False
|
80 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
81 |
+
|
82 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
83 |
+
assert embed_dim == self.embed_dim
|
84 |
+
assert list(hidden_states.size()) == [bsz, tgt_len, embed_dim]
|
85 |
+
|
86 |
+
outputs = self.longformer_self_attn(
|
87 |
+
hidden_states,
|
88 |
+
attention_mask=attention_mask * -1, # shape (batch_size, 1, 1, key_len)
|
89 |
+
head_mask=None,
|
90 |
+
encoder_hidden_states=None,
|
91 |
+
encoder_attention_mask=None,
|
92 |
+
output_attentions=output_attentions,
|
93 |
+
)
|
94 |
+
|
95 |
+
## new: Bart encoder expects shape (seq_len, bsz, embed_dim), no transpose needed
|
96 |
+
attn_output = self.output(outputs[0])
|
97 |
+
# new return in BartAttention has attn_output, attn_weights_reshaped, past_key_value (only for decoder), need to return 3 values (None for past_key_value)
|
98 |
+
return (attn_output, outputs[1:] ,None) if len(outputs) == 2 else (attn_output, None, None)
|
99 |
+
|
100 |
+
|
101 |
+
class LongBartEncoder(BartEncoder):
|
102 |
+
"""
|
103 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
104 |
+
[`BartEncoderLayer`].
|
105 |
+
|
106 |
+
Args:
|
107 |
+
config: BartConfig
|
108 |
+
embed_tokens (nn.Embedding): output embedding
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
112 |
+
super().__init__(config)
|
113 |
+
|
114 |
+
self.dropout = config.dropout
|
115 |
+
self.layerdrop = config.encoder_layerdrop
|
116 |
+
|
117 |
+
embed_dim = config.d_model
|
118 |
+
self.padding_idx = config.pad_token_id
|
119 |
+
self.max_source_positions = config.max_encoder_position_embeddings
|
120 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
121 |
+
|
122 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
123 |
+
|
124 |
+
if embed_tokens is not None:
|
125 |
+
self.embed_tokens.weight = embed_tokens.weight
|
126 |
+
|
127 |
+
self.embed_positions = BartLearnedPositionalEmbedding(
|
128 |
+
self.max_source_positions,
|
129 |
+
embed_dim,
|
130 |
+
)
|
131 |
+
self.layers = nn.ModuleList([LongBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
132 |
+
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
133 |
+
|
134 |
+
self.gradient_checkpointing = False
|
135 |
+
# Initialize weights and apply final processing
|
136 |
+
self.post_init()
|
137 |
+
|
138 |
+
def forward(
|
139 |
+
self,
|
140 |
+
input_ids: torch.LongTensor = None,
|
141 |
+
attention_mask: Optional[torch.Tensor] = None,
|
142 |
+
head_mask: Optional[torch.Tensor] = None,
|
143 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
144 |
+
output_attentions: Optional[bool] = None,
|
145 |
+
output_hidden_states: Optional[bool] = None,
|
146 |
+
return_dict: Optional[bool] = None,
|
147 |
+
) -> Union[Tuple, BaseModelOutput]:
|
148 |
+
r"""
|
149 |
+
Args:
|
150 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
151 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
152 |
+
provide it.
|
153 |
+
|
154 |
+
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
155 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
156 |
+
|
157 |
+
[What are input IDs?](../glossary#input-ids)
|
158 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
159 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
160 |
+
|
161 |
+
- 1 for tokens that are **not masked**,
|
162 |
+
- 0 for tokens that are **masked**.
|
163 |
+
|
164 |
+
[What are attention masks?](../glossary#attention-mask)
|
165 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
166 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
167 |
+
|
168 |
+
- 1 indicates the head is **not masked**,
|
169 |
+
- 0 indicates the head is **masked**.
|
170 |
+
|
171 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
172 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
173 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
174 |
+
than the model's internal embedding lookup matrix.
|
175 |
+
output_attentions (`bool`, *optional*):
|
176 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
177 |
+
returned tensors for more detail.
|
178 |
+
output_hidden_states (`bool`, *optional*):
|
179 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
180 |
+
for more detail.
|
181 |
+
return_dict (`bool`, *optional*):
|
182 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
183 |
+
"""
|
184 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
185 |
+
output_hidden_states = (
|
186 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
187 |
+
)
|
188 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
189 |
+
|
190 |
+
# retrieve input_ids and inputs_embeds
|
191 |
+
if input_ids is not None and inputs_embeds is not None:
|
192 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
193 |
+
elif input_ids is not None:
|
194 |
+
input = input_ids
|
195 |
+
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
196 |
+
elif inputs_embeds is not None:
|
197 |
+
input = inputs_embeds[:, :, -1]
|
198 |
+
else:
|
199 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
200 |
+
|
201 |
+
if inputs_embeds is None:
|
202 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
203 |
+
|
204 |
+
embed_pos = self.embed_positions(input)
|
205 |
+
embed_pos = embed_pos.to(inputs_embeds.device)
|
206 |
+
|
207 |
+
hidden_states = inputs_embeds + embed_pos
|
208 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
209 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
210 |
+
|
211 |
+
# expand attention_mask
|
212 |
+
longformer_attention_mask = None
|
213 |
+
if attention_mask is not None:
|
214 |
+
# need to return original, inverted mask for longformer attention, else value for global attention (=2 in given mask, will be -1) is lost
|
215 |
+
longformer_attention_mask = 1 - attention_mask
|
216 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
217 |
+
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
218 |
+
|
219 |
+
|
220 |
+
encoder_states = () if output_hidden_states else None
|
221 |
+
all_attentions = () if output_attentions else None
|
222 |
+
|
223 |
+
# check if head_mask has a correct number of layers specified if desired
|
224 |
+
if head_mask is not None:
|
225 |
+
if head_mask.size()[0] != len(self.layers):
|
226 |
+
raise ValueError(
|
227 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
228 |
+
f" {head_mask.size()[0]}."
|
229 |
+
)
|
230 |
+
|
231 |
+
for idx, encoder_layer in enumerate(self.layers):
|
232 |
+
if output_hidden_states:
|
233 |
+
encoder_states = encoder_states + (hidden_states,)
|
234 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
235 |
+
dropout_probability = random.uniform(0, 1)
|
236 |
+
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
237 |
+
layer_outputs = (None, None)
|
238 |
+
else:
|
239 |
+
if self.gradient_checkpointing and self.training:
|
240 |
+
|
241 |
+
def create_custom_forward(module):
|
242 |
+
def custom_forward(*inputs):
|
243 |
+
return module(*inputs, output_attentions)
|
244 |
+
|
245 |
+
return custom_forward
|
246 |
+
|
247 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
248 |
+
create_custom_forward(encoder_layer),
|
249 |
+
hidden_states,
|
250 |
+
attention_mask,
|
251 |
+
longformer_attention_mask,
|
252 |
+
(head_mask[idx] if head_mask is not None else None),
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
layer_outputs = encoder_layer(
|
256 |
+
hidden_states,
|
257 |
+
attention_mask,
|
258 |
+
longformer_attention_mask,
|
259 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
260 |
+
output_attentions=output_attentions,
|
261 |
+
)
|
262 |
+
|
263 |
+
hidden_states = layer_outputs[0]
|
264 |
+
|
265 |
+
if output_attentions:
|
266 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
267 |
+
|
268 |
+
if output_hidden_states:
|
269 |
+
encoder_states = encoder_states + (hidden_states,)
|
270 |
+
|
271 |
+
if not return_dict:
|
272 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
273 |
+
return BaseModelOutput(
|
274 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
275 |
+
)
|
276 |
+
|
277 |
+
|
278 |
+
class LongBartModel(BartModel):
|
279 |
+
def __init__(self, config: BartConfig):
|
280 |
+
super().__init__(config)
|
281 |
+
|
282 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
283 |
+
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
284 |
+
|
285 |
+
self.encoder = LongBartEncoder(config, self.shared)
|
286 |
+
self.decoder = BartDecoder(config, self.shared)
|
287 |
+
|
288 |
+
# Initialize weights and apply final processing
|
289 |
+
self.post_init()
|
290 |
+
|
291 |
+
|
292 |
+
class LongBartEncoderLayer(BartEncoderLayer):
|
293 |
+
def __init__(self, config: BartConfig):
|
294 |
+
super().__init__(config)
|
295 |
+
|
296 |
+
def forward(
|
297 |
+
self,
|
298 |
+
hidden_states: torch.FloatTensor,
|
299 |
+
attention_mask: torch.FloatTensor,
|
300 |
+
longformer_attention_mask: torch.Tensor,
|
301 |
+
layer_head_mask: torch.FloatTensor,
|
302 |
+
output_attentions: bool = False,
|
303 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
304 |
+
"""
|
305 |
+
Args:
|
306 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
307 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
308 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
309 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
310 |
+
`(encoder_attention_heads,)`.
|
311 |
+
output_attentions (`bool`, *optional*):
|
312 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
313 |
+
returned tensors for more detail.
|
314 |
+
"""
|
315 |
+
# if longformer attention instead of bart self attention: use special mask
|
316 |
+
if isinstance(self.self_attn, LongformerSelfAttentionForBart):
|
317 |
+
attention_mask = longformer_attention_mask
|
318 |
+
residual = hidden_states
|
319 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
320 |
+
hidden_states=hidden_states,
|
321 |
+
attention_mask=attention_mask,
|
322 |
+
layer_head_mask=layer_head_mask,
|
323 |
+
output_attentions=output_attentions,
|
324 |
+
)
|
325 |
+
|
326 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
327 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
328 |
+
hidden_states=hidden_states,
|
329 |
+
attention_mask=attention_mask,
|
330 |
+
layer_head_mask=layer_head_mask,
|
331 |
+
output_attentions=output_attentions,
|
332 |
+
)
|
333 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
334 |
+
hidden_states = residual + hidden_states
|
335 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
336 |
+
|
337 |
+
residual = hidden_states
|
338 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
339 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
340 |
+
hidden_states = self.fc2(hidden_states)
|
341 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
342 |
+
hidden_states = residual + hidden_states
|
343 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
344 |
+
|
345 |
+
if hidden_states.dtype == torch.float16 and (
|
346 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
347 |
+
):
|
348 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
349 |
+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
350 |
+
|
351 |
+
outputs = (hidden_states,)
|
352 |
+
|
353 |
+
if output_attentions:
|
354 |
+
outputs += (attn_weights,)
|
355 |
+
|
356 |
+
return outputs
|
long_models/longformer_mbart.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
This code is in part adapted from AllenAI's Longformer:
|
7 |
+
https://github.com/allenai/longformer/
|
8 |
+
and in part adapted from:
|
9 |
+
https://github.com/huggingface/transformers
|
10 |
+
|
11 |
+
Author: Annette Rios ([email protected])
|
12 |
+
|
13 |
+
"""
|
14 |
+
from typing import List, Optional, Tuple, Dict, Union
|
15 |
+
from torch import nn, Tensor, zeros
|
16 |
+
import torch
|
17 |
+
import math
|
18 |
+
import random
|
19 |
+
from .longformer import LongformerSelfAttention
|
20 |
+
from transformers.models.mbart.modeling_mbart import MBartConfig, MBartForConditionalGeneration, MBartEncoder, MBartLearnedPositionalEmbedding, MBartEncoderLayer, MBartDecoder, MBartModel, _expand_mask
|
21 |
+
from transformers.modeling_outputs import BaseModelOutput
|
22 |
+
|
23 |
+
class MLongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
|
24 |
+
def __init__(self, config):
|
25 |
+
super(MBartForConditionalGeneration, self).__init__(config)
|
26 |
+
|
27 |
+
self.model = LongMBartModel(config)
|
28 |
+
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
29 |
+
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
30 |
+
#print(self)
|
31 |
+
|
32 |
+
if config.attention_mode == 'n2':
|
33 |
+
pass # do nothing, use MBartSelfAttention instead
|
34 |
+
else:
|
35 |
+
for i, layer in enumerate(self.model.encoder.layers):
|
36 |
+
layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)
|
37 |
+
# Initialize weights and apply final processing
|
38 |
+
self.post_init()
|
39 |
+
|
40 |
+
|
41 |
+
class MLongformerEncoderDecoderConfig(MBartConfig):
|
42 |
+
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
|
43 |
+
autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
|
44 |
+
gradient_checkpointing: bool = False, **kwargs):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
attention_window: list of attention window sizes of length = number of layers.
|
48 |
+
window size = number of attention locations on each side.
|
49 |
+
For an affective window size of 512, use `attention_window=[256]*num_layers`
|
50 |
+
which is 256 on each side.
|
51 |
+
attention_dilation: list of attention dilation of length = number of layers.
|
52 |
+
attention dilation of `1` means no dilation.
|
53 |
+
autoregressive: do autoregressive attention or have attention of both sides
|
54 |
+
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implemenation of Longformer
|
55 |
+
selfattention, 'sliding_chunks' for another implementation of Longformer selfattention
|
56 |
+
"""
|
57 |
+
super().__init__(**kwargs)
|
58 |
+
self.attention_window = attention_window
|
59 |
+
self.attention_dilation = attention_dilation
|
60 |
+
self.autoregressive = autoregressive
|
61 |
+
self.attention_mode = attention_mode
|
62 |
+
self.gradient_checkpointing = gradient_checkpointing
|
63 |
+
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
|
64 |
+
|
65 |
+
class LongformerSelfAttentionForMBart(nn.Module):
|
66 |
+
def __init__(self, config, layer_id):
|
67 |
+
super().__init__()
|
68 |
+
self.embed_dim = config.d_model
|
69 |
+
self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
|
70 |
+
self.output = nn.Linear(self.embed_dim, self.embed_dim)
|
71 |
+
|
72 |
+
def forward(
|
73 |
+
self,
|
74 |
+
hidden_states: Tensor, # shape (batch_size, q_len, model_size)
|
75 |
+
key_value_states: Optional[Tensor] = None, # cross-attention in transformers.models.mbart.modeling_mbart
|
76 |
+
past_key_value: Optional[Tuple[Tensor]] = None, # only for decoder
|
77 |
+
attention_mask: Optional[Tensor] = None, # shape (batch_size, k_len) -> changed in transformers.models.modeling_mbart.MBartEncoder and MBartEncoderLayer (new mask uses bool -> global attention positions are lost, need to use the inverted orignal mask
|
78 |
+
layer_head_mask: Optional[Tensor] = None, # head dropout?
|
79 |
+
output_attentions: bool = False
|
80 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
81 |
+
|
82 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
83 |
+
assert embed_dim == self.embed_dim
|
84 |
+
assert list(hidden_states.size()) == [bsz, tgt_len, embed_dim]
|
85 |
+
|
86 |
+
outputs = self.longformer_self_attn(
|
87 |
+
hidden_states,
|
88 |
+
attention_mask=attention_mask * -1, # shape (batch_size, 1, 1, key_len)
|
89 |
+
head_mask=None,
|
90 |
+
encoder_hidden_states=None,
|
91 |
+
encoder_attention_mask=None,
|
92 |
+
output_attentions=output_attentions,
|
93 |
+
)
|
94 |
+
|
95 |
+
## new: MBart encoder expects shape (seq_len, bsz, embed_dim), no transpose needed
|
96 |
+
attn_output = self.output(outputs[0])
|
97 |
+
# new return in MBartAttention has attn_output, attn_weights_reshaped, past_key_value (only for decoder), need to return 3 values (None for past_key_value)
|
98 |
+
return (attn_output, outputs[1:] ,None) if len(outputs) == 2 else (attn_output, None, None)
|
99 |
+
|
100 |
+
|
101 |
+
class LongMBartEncoder(MBartEncoder):
|
102 |
+
"""
|
103 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
104 |
+
[`MBartEncoderLayer`].
|
105 |
+
|
106 |
+
Args:
|
107 |
+
config: MBartConfig
|
108 |
+
embed_tokens (nn.Embedding): output embedding
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
112 |
+
super().__init__(config)
|
113 |
+
|
114 |
+
self.dropout = config.dropout
|
115 |
+
self.layerdrop = config.encoder_layerdrop
|
116 |
+
|
117 |
+
embed_dim = config.d_model
|
118 |
+
self.padding_idx = config.pad_token_id
|
119 |
+
self.max_source_positions = config.max_encoder_position_embeddings
|
120 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
121 |
+
|
122 |
+
if embed_tokens is not None:
|
123 |
+
self.embed_tokens = embed_tokens
|
124 |
+
else:
|
125 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
126 |
+
|
127 |
+
self.embed_positions = MBartLearnedPositionalEmbedding(
|
128 |
+
self.max_source_positions,
|
129 |
+
embed_dim,
|
130 |
+
)
|
131 |
+
self.layers = nn.ModuleList([LongMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
132 |
+
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
133 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
134 |
+
|
135 |
+
self.gradient_checkpointing = False
|
136 |
+
# Initialize weights and apply final processing
|
137 |
+
self.post_init()
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self,
|
141 |
+
input_ids: torch.LongTensor = None,
|
142 |
+
attention_mask: Optional[torch.Tensor] = None,
|
143 |
+
head_mask: Optional[torch.Tensor] = None,
|
144 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
145 |
+
output_attentions: Optional[bool] = None,
|
146 |
+
output_hidden_states: Optional[bool] = None,
|
147 |
+
return_dict: Optional[bool] = None,
|
148 |
+
) -> Union[Tuple, BaseModelOutput]:
|
149 |
+
r"""
|
150 |
+
Args:
|
151 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
152 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
153 |
+
provide it.
|
154 |
+
|
155 |
+
Indices can be obtained using [`MBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
156 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
157 |
+
|
158 |
+
[What are input IDs?](../glossary#input-ids)
|
159 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
160 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
161 |
+
|
162 |
+
- 1 for tokens that are **not masked**,
|
163 |
+
- 0 for tokens that are **masked**.
|
164 |
+
|
165 |
+
[What are attention masks?](../glossary#attention-mask)
|
166 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
167 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
168 |
+
|
169 |
+
- 1 indicates the head is **not masked**,
|
170 |
+
- 0 indicates the head is **masked**.
|
171 |
+
|
172 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
173 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
174 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
175 |
+
than the model's internal embedding lookup matrix.
|
176 |
+
output_attentions (`bool`, *optional*):
|
177 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
178 |
+
returned tensors for more detail.
|
179 |
+
output_hidden_states (`bool`, *optional*):
|
180 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
181 |
+
for more detail.
|
182 |
+
return_dict (`bool`, *optional*):
|
183 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
184 |
+
"""
|
185 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
186 |
+
output_hidden_states = (
|
187 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
188 |
+
)
|
189 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
190 |
+
|
191 |
+
# retrieve input_ids and inputs_embeds
|
192 |
+
if input_ids is not None and inputs_embeds is not None:
|
193 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
194 |
+
elif input_ids is not None:
|
195 |
+
input = input_ids
|
196 |
+
input_shape = input.shape
|
197 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
198 |
+
elif inputs_embeds is not None:
|
199 |
+
input = inputs_embeds[:, :, -1]
|
200 |
+
else:
|
201 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
202 |
+
|
203 |
+
if inputs_embeds is None:
|
204 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
205 |
+
|
206 |
+
embed_pos = self.embed_positions(input)
|
207 |
+
|
208 |
+
hidden_states = inputs_embeds + embed_pos
|
209 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
210 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
211 |
+
|
212 |
+
# expand attention_mask
|
213 |
+
longformer_attention_mask = None
|
214 |
+
if attention_mask is not None:
|
215 |
+
# need to return original, inverted mask for longformer attention, else value for global attention (=2 in given mask, will be -1) is lost
|
216 |
+
longformer_attention_mask = 1 - attention_mask
|
217 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
218 |
+
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
219 |
+
|
220 |
+
|
221 |
+
encoder_states = () if output_hidden_states else None
|
222 |
+
all_attentions = () if output_attentions else None
|
223 |
+
|
224 |
+
# check if head_mask has a correct number of layers specified if desired
|
225 |
+
if head_mask is not None:
|
226 |
+
if head_mask.size()[0] != len(self.layers):
|
227 |
+
raise ValueError(
|
228 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
229 |
+
f" {head_mask.size()[0]}."
|
230 |
+
)
|
231 |
+
for idx, encoder_layer in enumerate(self.layers):
|
232 |
+
if output_hidden_states:
|
233 |
+
encoder_states = encoder_states + (hidden_states,)
|
234 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
235 |
+
dropout_probability = random.uniform(0, 1)
|
236 |
+
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
237 |
+
layer_outputs = (None, None)
|
238 |
+
else:
|
239 |
+
if self.gradient_checkpointing and self.training:
|
240 |
+
|
241 |
+
def create_custom_forward(module):
|
242 |
+
def custom_forward(*inputs):
|
243 |
+
return module(*inputs, output_attentions)
|
244 |
+
|
245 |
+
return custom_forward
|
246 |
+
|
247 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
248 |
+
create_custom_forward(encoder_layer),
|
249 |
+
hidden_states,
|
250 |
+
attention_mask,
|
251 |
+
longformer_attention_mask,
|
252 |
+
(head_mask[idx] if head_mask is not None else None),
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
layer_outputs = encoder_layer(
|
256 |
+
hidden_states,
|
257 |
+
attention_mask,
|
258 |
+
longformer_attention_mask,
|
259 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
260 |
+
output_attentions=output_attentions,
|
261 |
+
)
|
262 |
+
|
263 |
+
hidden_states = layer_outputs[0]
|
264 |
+
|
265 |
+
if output_attentions:
|
266 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
267 |
+
|
268 |
+
hidden_states = self.layer_norm(hidden_states)
|
269 |
+
|
270 |
+
if output_hidden_states:
|
271 |
+
encoder_states = encoder_states + (hidden_states,)
|
272 |
+
|
273 |
+
if not return_dict:
|
274 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
275 |
+
return BaseModelOutput(
|
276 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
277 |
+
)
|
278 |
+
|
279 |
+
|
280 |
+
class LongMBartModel(MBartModel):
|
281 |
+
def __init__(self, config: MBartConfig):
|
282 |
+
super().__init__(config)
|
283 |
+
|
284 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
285 |
+
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
286 |
+
|
287 |
+
self.encoder = LongMBartEncoder(config, self.shared)
|
288 |
+
self.decoder = MBartDecoder(config, self.shared)
|
289 |
+
|
290 |
+
# Initialize weights and apply final processing
|
291 |
+
self.post_init()
|
292 |
+
|
293 |
+
|
294 |
+
class LongMBartEncoderLayer(MBartEncoderLayer):
|
295 |
+
def __init__(self, config: MBartConfig):
|
296 |
+
super().__init__(config)
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
hidden_states: torch.Tensor,
|
301 |
+
attention_mask: torch.Tensor,
|
302 |
+
longformer_attention_mask: torch.Tensor,
|
303 |
+
layer_head_mask: torch.Tensor,
|
304 |
+
output_attentions: bool = False,
|
305 |
+
) -> torch.Tensor:
|
306 |
+
"""
|
307 |
+
Args:
|
308 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
309 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
310 |
+
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
311 |
+
longformer_attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
312 |
+
`(batch, src_len)` where 0=local, -1=global, 1=padding.
|
313 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
314 |
+
*(encoder_attention_heads,)*.
|
315 |
+
output_attentions (`bool`, *optional*):
|
316 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
317 |
+
returned tensors for more detail.
|
318 |
+
"""
|
319 |
+
# if longformer attention instead of mbart self attention: use special mask
|
320 |
+
if isinstance(self.self_attn, LongformerSelfAttentionForMBart):
|
321 |
+
attention_mask = longformer_attention_mask
|
322 |
+
residual = hidden_states
|
323 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
324 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
325 |
+
hidden_states=hidden_states,
|
326 |
+
attention_mask=attention_mask,
|
327 |
+
layer_head_mask=layer_head_mask,
|
328 |
+
output_attentions=output_attentions,
|
329 |
+
)
|
330 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
331 |
+
hidden_states = residual + hidden_states
|
332 |
+
|
333 |
+
residual = hidden_states
|
334 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
335 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
336 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
337 |
+
hidden_states = self.fc2(hidden_states)
|
338 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
339 |
+
hidden_states = residual + hidden_states
|
340 |
+
|
341 |
+
if hidden_states.dtype == torch.float16 and (
|
342 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
343 |
+
):
|
344 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
345 |
+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
346 |
+
|
347 |
+
outputs = (hidden_states,)
|
348 |
+
|
349 |
+
if output_attentions:
|
350 |
+
outputs += (attn_weights,)
|
351 |
+
|
352 |
+
return outputs
|
long_models/sliding_chunks.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
This code is from AllenAI's Longformer:
|
7 |
+
https://github.com/allenai/longformer/
|
8 |
+
|
9 |
+
"""
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from .diagonaled_mm_tvm import mask_invalid_locations
|
13 |
+
|
14 |
+
|
15 |
+
def _skew(x, direction, padding_value):
|
16 |
+
'''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
|
17 |
+
x_padded = F.pad(x, direction, value=padding_value)
|
18 |
+
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
|
19 |
+
return x_padded
|
20 |
+
|
21 |
+
|
22 |
+
def _skew2(x, padding_value):
|
23 |
+
'''shift every row 1 step to right converting columns into diagonals'''
|
24 |
+
# X = B x C x M x L
|
25 |
+
B, C, M, L = x.size()
|
26 |
+
x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1)
|
27 |
+
x = x.view(B, C, -1) # B x C x ML+MM+M
|
28 |
+
x = x[:, :, :-M] # B x C x ML+MM
|
29 |
+
x = x.view(B, C, M, M + L) # B x C, M x L+M
|
30 |
+
x = x[:, :, :, :-1]
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def _chunk(x, w):
|
35 |
+
'''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
|
36 |
+
|
37 |
+
# non-overlapping chunks of size = 2w
|
38 |
+
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
|
39 |
+
|
40 |
+
# use `as_strided` to make the chunks overlap with an overlap size = w
|
41 |
+
chunk_size = list(x.size())
|
42 |
+
chunk_size[1] = chunk_size[1] * 2 - 1
|
43 |
+
|
44 |
+
chunk_stride = list(x.stride())
|
45 |
+
chunk_stride[1] = chunk_stride[1] // 2
|
46 |
+
return x.as_strided(size=chunk_size, stride=chunk_stride)
|
47 |
+
|
48 |
+
|
49 |
+
def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
|
50 |
+
'''Matrix multiplicatio of query x key tensors using with a sliding window attention pattern.
|
51 |
+
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
|
52 |
+
with an overlap of size w'''
|
53 |
+
bsz, seqlen, num_heads, head_dim = q.size()
|
54 |
+
assert seqlen % (w * 2) == 0
|
55 |
+
assert q.size() == k.size()
|
56 |
+
|
57 |
+
chunks_count = seqlen // w - 1
|
58 |
+
|
59 |
+
# group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
|
60 |
+
q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
|
61 |
+
k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
|
62 |
+
|
63 |
+
chunk_q = _chunk(q, w)
|
64 |
+
chunk_k = _chunk(k, w)
|
65 |
+
|
66 |
+
# matrix multipication
|
67 |
+
# bcxd: bsz*num_heads x chunks x 2w x head_dim
|
68 |
+
# bcyd: bsz*num_heads x chunks x 2w x head_dim
|
69 |
+
# bcxy: bsz*num_heads x chunks x 2w x 2w
|
70 |
+
chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply
|
71 |
+
|
72 |
+
# convert diagonals into columns
|
73 |
+
diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
|
74 |
+
|
75 |
+
# allocate space for the overall attention matrix where the chunks are compined. The last dimension
|
76 |
+
# has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
|
77 |
+
# w previous words). The following column is attention score from each word to itself, then
|
78 |
+
# followed by w columns for the upper triangle.
|
79 |
+
|
80 |
+
diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
|
81 |
+
|
82 |
+
# copy parts from diagonal_chunk_attn into the compined matrix of attentions
|
83 |
+
# - copying the main diagonal and the upper triangle
|
84 |
+
diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
|
85 |
+
diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
|
86 |
+
# - copying the lower triangle
|
87 |
+
diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
|
88 |
+
diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]
|
89 |
+
|
90 |
+
# separate bsz and num_heads dimensions again
|
91 |
+
diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
|
92 |
+
|
93 |
+
mask_invalid_locations(diagonal_attn, w, 1, False)
|
94 |
+
return diagonal_attn
|
95 |
+
|
96 |
+
|
97 |
+
def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
|
98 |
+
'''Same as sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output
|
99 |
+
format from sliding_chunks_matmul_qk'''
|
100 |
+
bsz, seqlen, num_heads, head_dim = v.size()
|
101 |
+
assert seqlen % (w * 2) == 0
|
102 |
+
assert prob.size()[:3] == v.size()[:3]
|
103 |
+
assert prob.size(3) == 2 * w + 1
|
104 |
+
chunks_count = seqlen // w - 1
|
105 |
+
# group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
|
106 |
+
chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)
|
107 |
+
|
108 |
+
# group bsz and num_heads dimensions into one
|
109 |
+
v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
|
110 |
+
|
111 |
+
# pad seqlen with w at the beginning of the sequence and another w at the end
|
112 |
+
padded_v = F.pad(v, (0, 0, w, w), value=-1)
|
113 |
+
|
114 |
+
# chunk padded_v into chunks of size 3w and an overlap of size w
|
115 |
+
chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
|
116 |
+
chunk_v_stride = padded_v.stride()
|
117 |
+
chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
|
118 |
+
chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
|
119 |
+
|
120 |
+
skewed_prob = _skew2(chunk_prob, padding_value=0)
|
121 |
+
|
122 |
+
context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
|
123 |
+
return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)
|
124 |
+
|
125 |
+
|
126 |
+
def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
|
127 |
+
one_sided_window_size: int, pad_token_id: int):
|
128 |
+
'''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer selfattention.
|
129 |
+
Input:
|
130 |
+
input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
|
131 |
+
attention_mask = torch.Tensor(bsz x seqlen): attention mask
|
132 |
+
one_sided_window_size = int: window size on one side of each token
|
133 |
+
pad_token_id = int: tokenizer.pad_token_id
|
134 |
+
Returns
|
135 |
+
(input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
|
136 |
+
'''
|
137 |
+
w = int(2 * one_sided_window_size)
|
138 |
+
seqlen = input_ids.size(1)
|
139 |
+
padding_len = (w - seqlen % w) % w
|
140 |
+
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
|
141 |
+
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
|
142 |
+
return input_ids, attention_mask
|
143 |
+
|
144 |
+
|
145 |
+
# ========= "sliding_chunks_no_overlap": alternative implemenation of the sliding window attention =========
|
146 |
+
# This implementation uses non-overlapping chunks (or blocks) of size `w` with number of local attention = 3xw
|
147 |
+
# To make this implemenation comparable to "sliding_chunks" set w such that
|
148 |
+
# w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
|
149 |
+
# For example,
|
150 |
+
# w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
|
151 |
+
# w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
|
152 |
+
# Performance:
|
153 |
+
# - Speed: 30% faster than "sliding_chunks"
|
154 |
+
# - Memory: 95% of the memory usage of "sliding_chunks"
|
155 |
+
# The windows are asymmetric where number of attention on each side of a token ranges between w to 2w
|
156 |
+
# while "sliding_chunks" has a symmetric window around each token.
|
157 |
+
|
158 |
+
|
159 |
+
def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
|
160 |
+
bsz, seqlen, num_heads, head_dim = q.size()
|
161 |
+
assert seqlen % w == 0
|
162 |
+
assert q.size() == k.size()
|
163 |
+
# chunk seqlen into non-overlapping chunks of size w
|
164 |
+
chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
|
165 |
+
chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
|
166 |
+
chunk_k_expanded = torch.stack((
|
167 |
+
F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
|
168 |
+
chunk_k,
|
169 |
+
F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
|
170 |
+
), dim=-1)
|
171 |
+
diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply
|
172 |
+
return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)
|
173 |
+
|
174 |
+
|
175 |
+
def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
|
176 |
+
bsz, seqlen, num_heads, head_dim = v.size()
|
177 |
+
chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
|
178 |
+
chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
|
179 |
+
chunk_v_extended = torch.stack((
|
180 |
+
F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
|
181 |
+
chunk_v,
|
182 |
+
F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
|
183 |
+
), dim=-1)
|
184 |
+
context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
|
185 |
+
return context.reshape(bsz, seqlen, num_heads, head_dim)
|