omarmomen committed on
Commit
340c8dd
1 Parent(s): b6be2be

added notebook for generation

generation.ipynb ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "c7544851-fb94-44e0-86eb-65d74fad45aa",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from long_models.longformer_mbart import MLongformerEncoderDecoderConfig, MLongformerEncoderDecoderForConditionalGeneration\n",
11
+ "from transformers import MBartTokenizer"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "c1446321-3b74-4fbb-9871-403a82ceb0de",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "tokenizer = MBartTokenizer.from_pretrained(\"./\")\n",
22
+ "config = MLongformerEncoderDecoderConfig.from_pretrained('./')\n",
23
+ "model = MLongformerEncoderDecoderForConditionalGeneration.from_pretrained('./', config=config)\n",
24
+ "tokenizer.src_lang = 'de_DE'"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "65105f0c-eb1c-4e2f-8f21-23e3fbec81c1",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "input_txt = \"Ein Gebet von Mose, dem Mann Gottes: Mein Herr, eine sichere Wohnung bist du für uns gewesen von Generation zu Generation. Bevor die Berge geboren waren und du die Erde und die irdische Welt hervorgebracht hattest, und von Ewigkeit zu Ewigkeit bist du Gott. Du lässt den Menschen zum Staub zurückkehren und sprichst: „Kehrt zurück, ihr Kinder des Menschen! “ Denn tausend Jahre sind in deinen Augen wie der gestrige Tag, wenn er vergangen ist, oder eine Wache in der Nacht. Du schwemmst sie weg, sie sind wie Schlaf, am Morgen wie Gras, das aufsprosst. Am Morgen blüht es und sprosst auf, zum Abend verwelkt und verdorrt es. Denn wir vergehen durch deinen Zorn, und durch deine Zorneshitze werden wir verstört. Du stellst unsere Fehler vor dich, unsere Geheimnisse vor das Licht deiner Gegenwart. Ja, alle unsere Tage fahren dahin durch deinen Zorn, wir vollenden unsere Jahre wie einen Seufzer. Die Tage unserer Jahre, in ihnen sind siebzig Jahre, und mit Kraft achtzig Jahre. und Ihr Stolz ist Mühe und Beschwerde, denn er ist schnell vergangen und wir fliegen davon. Wer erkennt die Stärke deines Zorns? Wie die Furcht vor dir ist dein Grimm. Darum lehre uns, unsere Tage zu zählen, damit wir ein Herz der Weisheit bekommen. Kehre doch zurück, JHWH! Wie lange? Habe Mitleid mit deinen Knechten! Sättige uns am Morgen mit deiner Güte, dann werden wir jubeln und uns freuen an allen unseren Tagen! Erfreue uns so viele Tage, wie du uns bedrückt hast, so viele Jahre, wie wir Unglück gesehen haben! Zeige deinen Knechten dein Handeln und deine Herrlichkeit ihren Kindern! Die Freundlichkeit des Herrn, unseres Gottes sei über uns! Das Werk unserer Hände festige über uns, und das Werk unserer Hände, festige es!\""
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "d7448395-6af9-44f4-88a0-33ed5fb80fd8",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "print(input_txt)"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "f9483233-7c6b-4f23-b69a-b19b4de053be",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "inputs = tokenizer(\n",
55
+ " input_txt, \n",
56
+ " padding='max_length',\n",
57
+ " return_tensors='pt')"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "4aea1744-dfeb-498e-9fc7-d3f77ff012cb",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "outputs = model.generate(**inputs, num_beams=6, decoder_start_token_id=tokenizer.convert_tokens_to_ids('de_SI'))"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "id": "7dcc7f7d-edff-4647-acab-fd610e27d0a7",
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "tokenizer.batch_decode(outputs)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "5dd353d8-4de6-470e-b0e6-86c7c6640207",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": []
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "ba203c11-b222-484a-8d53-44a4d9da8ef5",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": []
95
+ }
96
+ ],
97
+ "metadata": {
98
+ "kernelspec": {
99
+ "display_name": "Python 3 (ipykernel)",
100
+ "language": "python",
101
+ "name": "python3"
102
+ },
103
+ "language_info": {
104
+ "codemirror_mode": {
105
+ "name": "ipython",
106
+ "version": 3
107
+ },
108
+ "file_extension": ".py",
109
+ "mimetype": "text/x-python",
110
+ "name": "python",
111
+ "nbconvert_exporter": "python",
112
+ "pygments_lexer": "ipython3",
113
+ "version": "3.10.6"
114
+ }
115
+ },
116
+ "nbformat": 4,
117
+ "nbformat_minor": 5
118
+ }
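For readers who prefer a plain script to the notebook, the cells above boil down to the sketch below. As in the notebook, it assumes the current directory holds the converted checkpoint (tokenizer, config and weights) and that `de_SI` is a target-language token present in the tokenizer's vocabulary.

from long_models.longformer_mbart import (
    MLongformerEncoderDecoderConfig,
    MLongformerEncoderDecoderForConditionalGeneration,
)
from transformers import MBartTokenizer

# load tokenizer, config and model from the checkpoint in the current directory
tokenizer = MBartTokenizer.from_pretrained("./")
config = MLongformerEncoderDecoderConfig.from_pretrained("./")
model = MLongformerEncoderDecoderForConditionalGeneration.from_pretrained("./", config=config)
tokenizer.src_lang = "de_DE"

# German source text; the full psalm from the notebook is shortened here
input_txt = "Ein Gebet von Mose, dem Mann Gottes: ..."

# pad to the tokenizer's maximum length, as the notebook does
inputs = tokenizer(input_txt, padding="max_length", return_tensors="pt")

# beam search, starting the decoder with the target-language token 'de_SI'
outputs = model.generate(
    **inputs,
    num_beams=6,
    decoder_start_token_id=tokenizer.convert_tokens_to_ids("de_SI"),
)
print(tokenizer.batch_decode(outputs))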
long_models/diagonaled_mm_tvm.py ADDED
@@ -0,0 +1,339 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+
6
+ This code is from AllenAI's Longformer:
7
+ https://github.com/allenai/longformer/
8
+
9
+ """
10
+ from typing import Union
11
+ from functools import lru_cache
12
+
13
+ import torch
14
+ import os.path
15
+
16
+
17
+ class DiagonaledMM(torch.autograd.Function):
18
+ '''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
19
+ this function from PyTorch
20
+ '''
21
+
22
+ function_dict = {} # save a list of functions, each has a different set of parameters
23
+
24
+ @staticmethod
25
+ def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
26
+ '''Compiles a tvm function that computes diagonal_mm
27
+ args:
28
+ dtype: str in ['float64', 'float32', 'float16']
29
+ device: str in ['cpu', 'cuda']
30
+ b0, b1, b2: size of tensor tiles. Very important for good performance
31
+
32
+ '''
33
+ import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
34
+ from tvm.contrib import nvcc
35
+ @tvm.register_func
36
+ def tvm_callback_cuda_compile(code):
37
+ """Use nvcc compiler for better perf."""
38
+ ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
39
+ return ptx
40
+
41
+ assert dtype in ['float16', 'float32', 'float64']
42
+ assert device in ['cpu', 'cuda']
43
+ device = None if device == 'cpu' else device
44
+ tgt_host="llvm"
45
+
46
+ b = tvm.var('b') # batch size
47
+ n = tvm.var('n') # sequence length
48
+ h = tvm.var('h') # number of heads
49
+ m = tvm.var('m') # hidden dimension
50
+ w = tvm.var('w') # window size
51
+ w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w`
52
+ padding = tvm.var('padding') # padding
53
+ transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed
54
+ t1d3 = tvm.var('t1d3') # last dimension of t1
55
+ t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor)
56
+ X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
57
+ Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
58
+ k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over
59
+ D = tvm.placeholder((h), name='D', dtype='int') # dilation per head
60
+ output_shape = (b, n, h, t3d3) # shape of the result tensor
61
+ algorithm = lambda l, i, q, j: tvm.sum(
62
+ tvm.if_then_else(
63
+ t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
64
+ tvm.if_then_else(
65
+ transpose_t1 == 0,
66
+ tvm.if_then_else(
67
+ tvm.all(
68
+ i + D[q] * (k - w) >= 0,
69
+ i + D[q] * (k - w) < n,
70
+ ),
71
+ X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled
72
+ padding
73
+ ),
74
+ tvm.if_then_else(
75
+ tvm.all(
76
+ i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True`
77
+ i + D[q] * (k - w_upper) < n,
78
+ ),
79
+ X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # t1 is diagonaled and should be transposed
80
+ padding
81
+ ),
82
+ ),
83
+ tvm.if_then_else(
84
+ tvm.all(
85
+ i + D[q] * (j - w) >= 0,
86
+ i + D[q] * (j - w) < n,
87
+ ),
88
+ X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be
89
+ padding
90
+ )
91
+ ), axis=k)
92
+
93
+ Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
94
+ s = tvm.create_schedule(Z.op)
95
+
96
+ print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
97
+
98
+ # split long axis into smaller chunks and assign each one to a separate GPU thread/block
99
+ ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
100
+ ZF = s.rfactor(Z, ki)
101
+
102
+ j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
103
+ i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)
104
+
105
+ s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
106
+ s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))
107
+
108
+ s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
109
+ s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))
110
+
111
+ tx = tvm.thread_axis("threadIdx.x")
112
+ s[Z].bind(s[Z].op.reduce_axis[0], tx)
113
+ s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
114
+ s[Z].set_store_predicate(tx.var.equal(0))
115
+
116
+ print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
117
+
118
+ # compiling the automatically generated cuda code
119
+ diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
120
+ return diagonaled_mm
121
+
122
+ @staticmethod
123
+ def _get_lib_filename(dtype: str, device: str):
124
+ base_filename = 'longformer/lib/lib_diagonaled_mm'
125
+ return '{}_{}_{}.so'.format(base_filename, dtype, device)
126
+
127
+ @staticmethod
128
+ def _save_compiled_function(f, dtype: str, device: str):
129
+ if not os.path.exists('longformer/lib/'):
130
+ os.makedirs('longformer/lib/')
131
+ f.export_library(DiagonaledMM._get_lib_filename(dtype, device))
132
+
133
+ @staticmethod
134
+ def _load_compiled_function(dtype: str, device: str):
135
+ from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing
136
+ filename = DiagonaledMM._get_lib_filename(dtype, device)
137
+ current_dir = os.path.dirname(os.path.abspath(__file__))
138
+ potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
139
+ for potential_dir in potential_dirs:
140
+ filepath = '{}{}'.format(potential_dir, filename)
141
+ if os.path.isfile(filepath):
142
+ print('Loading tvm binary from: {}'.format(filepath))
143
+ return load(filepath)
144
+ return None
145
+
146
+ @staticmethod
147
+ def _get_function(dtype: str, device: str):
148
+ '''Loads the function from disk or compiles it'''
149
+ # A list of arguments that define the function
150
+ args = (dtype, device)
151
+ if args not in DiagonaledMM.function_dict:
152
+ diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk
153
+ if not diagonaled_mm:
154
+ print('Tvm binary not found. Compiling ...')
155
+ diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile
156
+ DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk
157
+ # convert the tvm function into a pytorch function
158
+ from tvm.contrib import dlpack
159
+ diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function
160
+ # save the function into a dictionary to be reused
161
+ DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time
162
+ return DiagonaledMM.function_dict[args]
163
+
164
+ @staticmethod
165
+ def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
166
+ is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
167
+ autoregressive: bool = False):
168
+ '''Calls the compiled function after checking the input format. This function is called in three different modes.
169
+ t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
170
+ t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
171
+ t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
172
+ the calculations in the backward pass.
173
+ '''
174
+ dtype = str(t1.dtype).split('.')[1]
175
+ device = t1.device.type
176
+ assert len(t1.shape) == 4
177
+ assert len(t1.shape) == len(t2.shape)
178
+ assert t1.shape[:3] == t2.shape[:3]
179
+ if isinstance(d, int): # if d is an integer, replace it with a tensor of the same length
180
+ # as number of heads, and it is filled with the same dilation value
181
+ d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)
182
+
183
+ assert len(d.shape) == 1
184
+ assert d.shape[0] == t1.shape[2] # number of dilation scores should match number of heads
185
+ b = t1.shape[0] # batch size
186
+ n = t1.shape[1] # sequence length
187
+ h = t1.shape[2] # number of heads
188
+ m = t2.shape[3] # hidden dimension
189
+ w_upper = 0 if autoregressive else w
190
+ c = w_upper + w + 1 # number of diagonals
191
+ if is_t1_diagonaled:
192
+ assert t1.shape[3] == c
193
+ r = t1.new_empty(b, n, h, m) # allocate space for the result tensor
194
+ else:
195
+ assert not transpose_t1
196
+ assert t1.shape[3] == m
197
+ r = t1.new_empty(b, n, h, c) # allocate space for the result tensor
198
+
199
+ # gets function from memory, from disk or compiles it from scratch
200
+ _diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)
201
+
202
+ # The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor
203
+ # We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa).
204
+ # The second reason is that the lambda expression in `_compile_function` is easier to express when the shape
205
+ # of the output is known
206
+ # This functions computes diagonal_mm then saves the result in `r`
207
+ if m == c:
208
+ # FIXME
209
+ print(f'Error: the hidden dimension {m} shouldn\'t match the number of diagonals {c}')
210
+ assert False
211
+ _diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
212
+ return r
213
+
214
+ @staticmethod
215
+ def _prepare_tensors(t):
216
+ '''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
217
+ For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
218
+ TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
219
+ Here's an example to reproduce this issue:
220
+ import torch
221
+ print(torch.randn(1, 10).stride())
222
+ > (10, 1)
223
+ print(torch.randn(10, 1).t().contiguous().stride())
224
+ > (1, 1) # expected it to be (10, 1) as above
225
+ print(torch.randn(10, 2).t().contiguous().stride())
226
+ > (10, 1) # but gets the expected stride if the first dimension is > 1
227
+ '''
228
+ assert t.is_contiguous()
229
+ t_stride = list(t.stride())
230
+ t_size = list(t.size())
231
+ # Fix wrong stride information for the first dimension. This occurs when batch_size=1
232
+ if t_size[0] == 1 and t_stride[0] == t_stride[1]:
233
+ # In this case, the stride of the first dimension should be the product
234
+ # of the sizes of all other dimensions
235
+ t_stride[0] = t_size[1] * t_size[2] * t_size[3]
236
+ t = t.as_strided(size=t_size, stride=t_stride)
237
+ return t
238
+
239
+ min_seq_len = 16 # unexpected output if seq_len < 16
240
+
241
+ @staticmethod
242
+ def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
243
+ '''Computes diagonal_mm of t1 and t2.
244
+ args:
245
+ t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
246
+ t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
247
+ t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
248
+ tensor, e.g. `key_layer` or `value_layer`
249
+ w: int = window size; number of attentions on each side of the word
250
+ d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
251
+ heads. If torch.Tensor, it should be 1D of length=number of attention heads
252
+ is_t1_diagonaled: is t1 a diagonaled or a regular tensor
253
+ padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
254
+ needs to be a very large negative value (to compute softmax of attentions). For other usecases,
255
+ please use zero padding.
256
+ autoregressive: if true, return only the lower triangle
257
+ returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
258
+ if t1 is diagonaled, result is non-diagonaled, and vice versa
259
+ '''
260
+ batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
261
+ assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME
262
+ ctx.save_for_backward(t1, t2)
263
+ ctx.w = w
264
+ ctx.d = d
265
+ ctx.is_t1_diagonaled = is_t1_diagonaled
266
+ ctx.autoregressive = autoregressive
267
+ t1 = DiagonaledMM._prepare_tensors(t1)
268
+ t2 = DiagonaledMM._prepare_tensors(t2)
269
+ # output = t1.mm(t2) # what would have been called if this was a regular matmul
270
+ output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
271
+ return output
272
+
273
+ @staticmethod
274
+ def backward(ctx, grad_output):
275
+ t1, t2 = ctx.saved_tensors
276
+ w = ctx.w
277
+ d = ctx.d
278
+ is_t1_diagonaled = ctx.is_t1_diagonaled
279
+ autoregressive = ctx.autoregressive
280
+ if not grad_output.is_contiguous():
281
+ grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
282
+ grad_output = DiagonaledMM._prepare_tensors(grad_output)
283
+ t1 = DiagonaledMM._prepare_tensors(t1)
284
+ t2 = DiagonaledMM._prepare_tensors(t2)
285
+ # http://cs231n.github.io/optimization-2/
286
+ # https://pytorch.org/docs/master/notes/extending.html
287
+ # grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
288
+ grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
289
+ # grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
290
+ if is_t1_diagonaled:
291
+ grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
292
+ else:
293
+ grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
294
+ return grad_t1, grad_t2, None, None, None, None, None
295
+
296
+
297
+ def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
298
+ diagonals_list = []
299
+ for j in range(-d * w, d, d):
300
+ diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
301
+ diagonal_mask[:-j] = 1
302
+ diagonals_list.append(diagonal_mask)
303
+ return torch.stack(diagonals_list, dim=-1)
304
+
305
+ @lru_cache()
306
+ def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
307
+ if isinstance(d, int):
308
+ affected_seq_len = w * d
309
+ mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
310
+ mask = mask[None, :, None, :]
311
+ else:
312
+ affected_seq_len = w * d.max()
313
+ head_masks = []
314
+ d_list = d.cpu().numpy().tolist()
315
+ for d in d_list:
316
+ one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
317
+ head_masks.append(one_head_mask)
318
+ mask = torch.stack(head_masks, dim=-2)
319
+ mask = mask[None, :, :, :]
320
+
321
+ ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
322
+ return affected_seq_len, mask.bool().to(device), ending_mask
323
+
324
+ def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
325
+ affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
326
+ seq_len = input_tensor.size(1)
327
+ beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
328
+ beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
329
+ beginning_input.masked_fill_(beginning_mask, -float('inf'))
330
+ if not autoregressive:
331
+ ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
332
+ ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
333
+ ending_input.masked_fill_(ending_mask, -float('inf'))
334
+
335
+
336
+ diagonaled_mm = DiagonaledMM.apply
337
+
338
+ # The non-tvm implementation is the default, we don't need to load the kernel at loading time.
339
+ # DiagonaledMM._get_function('float32', 'cuda')
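A rough shape sketch of the two calling modes described in `_diagonaled_mm`, using the same positional argument order as `LongformerSelfAttention` in longformer.py. This is illustrative only: it assumes a CUDA device, float32 tensors, and that the TVM kernel can be loaded or compiled; all sizes are made up.

import torch
from long_models.diagonaled_mm_tvm import diagonaled_mm, mask_invalid_locations

# illustrative sizes; seq_len must be >= DiagonaledMM.min_seq_len (16)
bsz, seq_len, num_heads, head_dim, w = 1, 64, 2, 8, 4

q = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float32, device="cuda")
k = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float32, device="cuda")
v = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float32, device="cuda")

# mode 1: query x key -> diagonaled scores of shape (bsz, seq_len, num_heads, 2*w + 1)
scores = diagonaled_mm(q, k, w, 1, False, 0, False)
mask_invalid_locations(scores, w, 1, False)  # out-of-sequence diagonals become -inf

probs = torch.softmax(scores, dim=-1)

# mode 2: diagonaled probs x value -> context of shape (bsz, seq_len, num_heads, head_dim)
context = diagonaled_mm(probs, v, w, 1, True, 0, False)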
long_models/longformer.py ADDED
@@ -0,0 +1,277 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+
6
+ This code is adapted from AllenAI's Longformer:
7
+ https://github.com/allenai/longformer/
8
+
9
+ """
10
+ from typing import List
11
+ import math
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ from .diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
16
+ from .sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
17
+ from .sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
18
+ from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
19
+
20
+
21
+ class Longformer(RobertaModel):
22
+ def __init__(self, config):
23
+ super(Longformer, self).__init__(config)
24
+ if config.attention_mode == 'n2':
25
+ pass # do nothing, use BertSelfAttention instead
26
+ else:
27
+ for i, layer in enumerate(self.encoder.layer):
28
+ layer.attention.self = LongformerSelfAttention(config, layer_id=i)
29
+
30
+
31
+ class LongformerForMaskedLM(RobertaForMaskedLM):
32
+ def __init__(self, config):
33
+ super(LongformerForMaskedLM, self).__init__(config)
34
+ if config.attention_mode == 'n2':
35
+ pass # do nothing, use BertSelfAttention instead
36
+ else:
37
+ for i, layer in enumerate(self.roberta.encoder.layer):
38
+ layer.attention.self = LongformerSelfAttention(config, layer_id=i)
39
+
40
+
41
+ class LongformerConfig(RobertaConfig):
42
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
43
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs):
44
+ """
45
+ Args:
46
+ attention_window: list of attention window sizes of length = number of layers.
47
+ window size = number of attention locations on each side.
48
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
49
+ which is 256 on each side.
50
+ attention_dilation: list of attention dilation of length = number of layers.
51
+ attention dilation of `1` means no dilation.
52
+ autoregressive: do autoregressive attention or have attention on both sides
53
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
54
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
55
+ """
56
+ super().__init__(**kwargs)
57
+ self.attention_window = attention_window
58
+ self.attention_dilation = attention_dilation
59
+ self.autoregressive = autoregressive
60
+ self.attention_mode = attention_mode
61
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']
62
+
63
+
64
+ class LongformerSelfAttention(nn.Module):
65
+ def __init__(self, config, layer_id):
66
+ super(LongformerSelfAttention, self).__init__()
67
+ if config.hidden_size % config.num_attention_heads != 0:
68
+ raise ValueError(
69
+ "The hidden size (%d) is not a multiple of the number of attention "
70
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
71
+ self.num_heads = config.num_attention_heads
72
+ self.head_dim = int(config.hidden_size / config.num_attention_heads)
73
+ self.embed_dim = config.hidden_size
74
+
75
+ self.query = nn.Linear(config.hidden_size, self.embed_dim)
76
+ self.key = nn.Linear(config.hidden_size, self.embed_dim)
77
+ self.value = nn.Linear(config.hidden_size, self.embed_dim)
78
+
79
+ self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
80
+ self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
81
+ self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
82
+
83
+ self.dropout = config.attention_probs_dropout_prob
84
+
85
+ self.layer_id = layer_id
86
+ self.attention_window = config.attention_window[self.layer_id]
87
+ self.attention_dilation = config.attention_dilation[self.layer_id]
88
+ self.attention_mode = config.attention_mode
89
+ self.autoregressive = config.autoregressive
90
+ assert self.attention_window > 0
91
+ assert self.attention_dilation > 0
92
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
93
+ if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
94
+ assert not self.autoregressive # not supported
95
+ assert self.attention_dilation == 1 # dilation is not supported
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ attention_mask=None,
101
+ head_mask=None,
102
+ encoder_hidden_states=None,
103
+ encoder_attention_mask=None,
104
+ output_attentions=False,
105
+ ):
106
+ '''
107
+ The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
108
+ -ve: no attention
109
+ 0: local attention
110
+ +ve: global attention
111
+ '''
112
+ assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
113
+ assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"
114
+
115
+ if attention_mask is not None:
116
+ key_padding_mask = attention_mask < 0
117
+ extra_attention_mask = attention_mask > 0
118
+ remove_from_windowed_attention_mask = attention_mask != 0
119
+
120
+ num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
121
+ max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
122
+ if max_num_extra_indices_per_batch <= 0:
123
+ extra_attention_mask = None
124
+ else:
125
+ # To support the case of variable number of global attention in the rows of a batch,
126
+ # we use the following three selection masks to select global attention embeddings
127
+ # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
128
+ # 1) selecting embeddings that correspond to global attention
129
+ extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
130
+ zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
131
+ device=num_extra_indices_per_batch.device)
132
+ # mask indicating which values are actually going to be padding
133
+ selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
134
+ # 2) location of the non-padding values in the selected global attention
135
+ selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
136
+ # 3) location of the padding values in the selected global attention
137
+ selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
138
+ else:
139
+ remove_from_windowed_attention_mask = None
140
+ extra_attention_mask = None
141
+ key_padding_mask = None
142
+
143
+ hidden_states = hidden_states.transpose(0, 1)
144
+ seq_len, bsz, embed_dim = hidden_states.size()
145
+ assert embed_dim == self.embed_dim
146
+ q = self.query(hidden_states)
147
+ k = self.key(hidden_states)
148
+ v = self.value(hidden_states)
149
+ q /= math.sqrt(self.head_dim)
150
+
151
+ q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
152
+ k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
153
+ # attn_weights = (bsz, seq_len, num_heads, window*2+1)
154
+ if self.attention_mode == 'tvm':
155
+ q = q.float().contiguous()
156
+ k = k.float().contiguous()
157
+ attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
158
+ elif self.attention_mode == "sliding_chunks":
159
+ attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
160
+ elif self.attention_mode == "sliding_chunks_no_overlap":
161
+ attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
162
+ else:
163
+ raise ValueError(f'Unsupported attention mode: {self.attention_mode}')
164
+ mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
165
+ if remove_from_windowed_attention_mask is not None:
166
+ # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
167
+ # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
168
+ remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
169
+ # cast to float/half then replace 1's with -inf
170
+ float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
171
+ repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
172
+ float_mask = float_mask.repeat(1, 1, repeat_size, 1)
173
+ ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
174
+ # diagonal mask with zeros everywhere and -inf in place of padding
175
+ if self.attention_mode == 'tvm':
176
+ d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
177
+ elif self.attention_mode == "sliding_chunks":
178
+ d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
179
+ elif self.attention_mode == "sliding_chunks_no_overlap":
180
+ d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
181
+
182
+ attn_weights += d_mask
183
+ assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
184
+ assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
185
+
186
+ # the extra attention
187
+ if extra_attention_mask is not None:
188
+ selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
189
+ selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
190
+ # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
191
+ selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
192
+ selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
193
+ # concat to attn_weights
194
+ # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
195
+ attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
196
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
197
+ if key_padding_mask is not None:
198
+ # softmax sometimes inserts NaN if all positions are masked, replace them with 0
199
+ attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
200
+ attn_weights = attn_weights_float.type_as(attn_weights)
201
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
202
+ v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
203
+ attn = 0
204
+ if extra_attention_mask is not None:
205
+ selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
206
+ selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
207
+ selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
208
+ # use `matmul` because `einsum` crashes sometimes with fp16
209
+ # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
210
+ attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
211
+ attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
212
+
213
+ if self.attention_mode == 'tvm':
214
+ v = v.float().contiguous()
215
+ attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
216
+ elif self.attention_mode == "sliding_chunks":
217
+ attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
218
+ elif self.attention_mode == "sliding_chunks_no_overlap":
219
+ attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
220
+ else:
221
+ raise ValueError(f'Unsupported attention mode: {self.attention_mode}')
222
+
223
+ attn = attn.type_as(hidden_states)
224
+ assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
225
+ attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()
226
+
227
+ # For this case, we'll just recompute the attention for these indices
228
+ # and overwrite the attn tensor. TODO: remove the redundant computation
229
+ if extra_attention_mask is not None:
230
+ selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
231
+ selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]
232
+
233
+ q = self.query_global(selected_hidden_states)
234
+ k = self.key_global(hidden_states)
235
+ v = self.value_global(hidden_states)
236
+ q /= math.sqrt(self.head_dim)
237
+
238
+ q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
239
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
240
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
241
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
242
+ assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
243
+
244
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
245
+ attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
246
+ if key_padding_mask is not None:
247
+ attn_weights = attn_weights.masked_fill(
248
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
249
+ -10000.0,
250
+ )
251
+ attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
252
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
253
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
254
+ selected_attn = torch.bmm(attn_probs, v)
255
+ assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]
256
+
257
+ selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
258
+ nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
259
+ attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)
260
+
261
+ context_layer = attn.transpose(0, 1) # attn shape: (seq_len, bsz, embed_dim), context_layer shape: (bsz, seq_len, embed_dim)
262
+ if output_attentions:
263
+ if extra_attention_mask is not None:
264
+ # With global attention, return global attention probabilities only
265
+ # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
266
+ # which is the attention weights from tokens with global attention to all tokens
267
+ # It does not return local attention
268
+ # In case of variable number of global attention in the rows of a batch,
269
+ # attn_weights are padded with -10000.0 attention scores
270
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
271
+ else:
272
+ # without global attention, return local attention probabilities
273
+ # batch_size x num_heads x sequence_length x window_size
274
+ # which is the attention weights of every token attending to its neighbours
275
+ attn_weights = attn_weights.permute(0, 2, 1, 3)
276
+ outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
277
+ return outputs
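A minimal usage sketch of `LongformerSelfAttention` in `sliding_chunks` mode. Sizes are illustrative, the repo's `sliding_chunks` module is assumed to be importable, and the mask follows the convention documented in `forward` (negative = no attention, 0 = local, positive = global); the standard sliding-chunks implementation also expects the sequence length to be a multiple of twice the window size.

import torch
from long_models.longformer import LongformerConfig, LongformerSelfAttention

num_layers = 2
config = LongformerConfig(
    attention_window=[8] * num_layers,    # 8 positions on each side
    attention_dilation=[1] * num_layers,  # sliding_chunks requires dilation 1
    attention_mode="sliding_chunks",
    autoregressive=False,
    hidden_size=64,
    num_attention_heads=4,
    attention_probs_dropout_prob=0.1,
)
attn = LongformerSelfAttention(config, layer_id=0)

bsz, seq_len = 1, 32  # 32 is a multiple of 2 * attention_window
hidden_states = torch.randn(bsz, seq_len, config.hidden_size)

attention_mask = torch.zeros(bsz, seq_len)
attention_mask[:, 0] = 1.0  # give the first token global attention

(context,) = attn(hidden_states, attention_mask=attention_mask)
print(context.shape)  # (bsz, seq_len, hidden_size)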
long_models/longformer_bart.py ADDED
@@ -0,0 +1,356 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+
6
+ This code is in part adapted from AllenAI's Longformer:
7
+ https://github.com/allenai/longformer/
8
+ and in part adapted from:
9
+ https://github.com/huggingface/transformers
10
+
11
+ Author: Annette Rios ([email protected])
12
+
13
+ """
14
+ from typing import List, Optional, Tuple, Dict, Union
15
+ from torch import nn, Tensor, zeros
16
+ import torch
17
+ import math
18
+ import random
19
+ from .longformer import LongformerSelfAttention
20
+ from transformers.models.bart.modeling_bart import BartConfig, BartForConditionalGeneration, BartEncoder, BartLearnedPositionalEmbedding, BartEncoderLayer, BartDecoder, BartModel, _expand_mask
21
+ from transformers.modeling_outputs import BaseModelOutput
22
+
23
+ class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
24
+ def __init__(self, config):
25
+ super(BartForConditionalGeneration, self).__init__(config)
26
+
27
+ self.model = LongBartModel(config)
28
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
29
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
30
+ #print(self)
31
+
32
+ if config.attention_mode == 'n2':
33
+ pass # do nothing, use BartSelfAttention instead
34
+ else:
35
+ for i, layer in enumerate(self.model.encoder.layers):
36
+ layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
37
+ # Initialize weights and apply final processing
38
+ self.post_init()
39
+
40
+
41
+ class LongformerEncoderDecoderConfig(BartConfig):
42
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
43
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
44
+ gradient_checkpointing: bool = False, **kwargs):
45
+ """
46
+ Args:
47
+ attention_window: list of attention window sizes of length = number of layers.
48
+ window size = number of attention locations on each side.
49
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
50
+ which is 256 on each side.
51
+ attention_dilation: list of attention dilation of length = number of layers.
52
+ attention dilation of `1` means no dilation.
53
+ autoregressive: do autoregressive attention or have attention on both sides
54
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
55
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
56
+ """
57
+ super().__init__(**kwargs)
58
+ self.attention_window = attention_window
59
+ self.attention_dilation = attention_dilation
60
+ self.autoregressive = autoregressive
61
+ self.attention_mode = attention_mode
62
+ self.gradient_checkpointing = gradient_checkpointing
63
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
64
+
65
+ class LongformerSelfAttentionForBart(nn.Module):
66
+ def __init__(self, config, layer_id):
67
+ super().__init__()
68
+ self.embed_dim = config.d_model
69
+ self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
70
+ self.output = nn.Linear(self.embed_dim, self.embed_dim)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: Tensor, # shape (batch_size, q_len, model_size)
75
+ key_value_states: Optional[Tensor] = None, # cross-attention in transformers.models.bart.modeling_bart
76
+ past_key_value: Optional[Tuple[Tensor]] = None, # only for decoder
77
+ attention_mask: Optional[Tensor] = None, # shape (batch_size, k_len) -> changed in transformers.models.bart.modeling_bart.BartEncoder and BartEncoderLayer (new mask uses bool -> global attention positions are lost, need to use the inverted original mask)
78
+ layer_head_mask: Optional[Tensor] = None, # head dropout?
79
+ output_attentions: bool = False
80
+ ) -> Tuple[Tensor, Optional[Tensor]]:
81
+
82
+ bsz, tgt_len, embed_dim = hidden_states.size()
83
+ assert embed_dim == self.embed_dim
84
+ assert list(hidden_states.size()) == [bsz, tgt_len, embed_dim]
85
+
86
+ outputs = self.longformer_self_attn(
87
+ hidden_states,
88
+ attention_mask=attention_mask * -1, # shape (batch_size, 1, 1, key_len)
89
+ head_mask=None,
90
+ encoder_hidden_states=None,
91
+ encoder_attention_mask=None,
92
+ output_attentions=output_attentions,
93
+ )
94
+
95
+ ## new: Bart encoder expects shape (seq_len, bsz, embed_dim), no transpose needed
96
+ attn_output = self.output(outputs[0])
97
+ # new return in BartAttention has attn_output, attn_weights_reshaped, past_key_value (only for decoder), need to return 3 values (None for past_key_value)
98
+ return (attn_output, outputs[1:], None) if len(outputs) == 2 else (attn_output, None, None)
99
+
100
+
101
+ class LongBartEncoder(BartEncoder):
102
+ """
103
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
104
+ [`BartEncoderLayer`].
105
+
106
+ Args:
107
+ config: BartConfig
108
+ embed_tokens (nn.Embedding): output embedding
109
+ """
110
+
111
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
112
+ super().__init__(config)
113
+
114
+ self.dropout = config.dropout
115
+ self.layerdrop = config.encoder_layerdrop
116
+
117
+ embed_dim = config.d_model
118
+ self.padding_idx = config.pad_token_id
119
+ self.max_source_positions = config.max_encoder_position_embeddings
120
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
121
+
122
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
123
+
124
+ if embed_tokens is not None:
125
+ self.embed_tokens.weight = embed_tokens.weight
126
+
127
+ self.embed_positions = BartLearnedPositionalEmbedding(
128
+ self.max_source_positions,
129
+ embed_dim,
130
+ )
131
+ self.layers = nn.ModuleList([LongBartEncoderLayer(config) for _ in range(config.encoder_layers)])
132
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
133
+
134
+ self.gradient_checkpointing = False
135
+ # Initialize weights and apply final processing
136
+ self.post_init()
137
+
138
+ def forward(
139
+ self,
140
+ input_ids: torch.LongTensor = None,
141
+ attention_mask: Optional[torch.Tensor] = None,
142
+ head_mask: Optional[torch.Tensor] = None,
143
+ inputs_embeds: Optional[torch.FloatTensor] = None,
144
+ output_attentions: Optional[bool] = None,
145
+ output_hidden_states: Optional[bool] = None,
146
+ return_dict: Optional[bool] = None,
147
+ ) -> Union[Tuple, BaseModelOutput]:
148
+ r"""
149
+ Args:
150
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
151
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
152
+ provide it.
153
+
154
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
155
+ [`PreTrainedTokenizer.__call__`] for details.
156
+
157
+ [What are input IDs?](../glossary#input-ids)
158
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
159
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
160
+
161
+ - 1 for tokens that are **not masked**,
162
+ - 0 for tokens that are **masked**.
163
+
164
+ [What are attention masks?](../glossary#attention-mask)
165
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
166
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
167
+
168
+ - 1 indicates the head is **not masked**,
169
+ - 0 indicates the head is **masked**.
170
+
171
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
172
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
173
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
174
+ than the model's internal embedding lookup matrix.
175
+ output_attentions (`bool`, *optional*):
176
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
177
+ returned tensors for more detail.
178
+ output_hidden_states (`bool`, *optional*):
179
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
180
+ for more detail.
181
+ return_dict (`bool`, *optional*):
182
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
183
+ """
184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
185
+ output_hidden_states = (
186
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
187
+ )
188
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
189
+
190
+ # retrieve input_ids and inputs_embeds
191
+ if input_ids is not None and inputs_embeds is not None:
192
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
193
+ elif input_ids is not None:
194
+ input = input_ids
195
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
196
+ elif inputs_embeds is not None:
197
+ input = inputs_embeds[:, :, -1]
198
+ else:
199
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
200
+
201
+ if inputs_embeds is None:
202
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
203
+
204
+ embed_pos = self.embed_positions(input)
205
+ embed_pos = embed_pos.to(inputs_embeds.device)
206
+
207
+ hidden_states = inputs_embeds + embed_pos
208
+ hidden_states = self.layernorm_embedding(hidden_states)
209
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
210
+
211
+ # expand attention_mask
212
+ longformer_attention_mask = None
213
+ if attention_mask is not None:
214
+ # need to return original, inverted mask for longformer attention, else value for global attention (=2 in given mask, will be -1) is lost
215
+ longformer_attention_mask = 1 - attention_mask
216
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
217
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
218
+
219
+
220
+ encoder_states = () if output_hidden_states else None
221
+ all_attentions = () if output_attentions else None
222
+
223
+ # check if head_mask has a correct number of layers specified if desired
224
+ if head_mask is not None:
225
+ if head_mask.size()[0] != len(self.layers):
226
+ raise ValueError(
227
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
228
+ f" {head_mask.size()[0]}."
229
+ )
230
+
231
+ for idx, encoder_layer in enumerate(self.layers):
232
+ if output_hidden_states:
233
+ encoder_states = encoder_states + (hidden_states,)
234
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
235
+ dropout_probability = random.uniform(0, 1)
236
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
237
+ layer_outputs = (None, None)
238
+ else:
239
+ if self.gradient_checkpointing and self.training:
240
+
241
+ def create_custom_forward(module):
242
+ def custom_forward(*inputs):
243
+ return module(*inputs, output_attentions)
244
+
245
+ return custom_forward
246
+
247
+ layer_outputs = torch.utils.checkpoint.checkpoint(
248
+ create_custom_forward(encoder_layer),
249
+ hidden_states,
250
+ attention_mask,
251
+ longformer_attention_mask,
252
+ (head_mask[idx] if head_mask is not None else None),
253
+ )
254
+ else:
255
+ layer_outputs = encoder_layer(
256
+ hidden_states,
257
+ attention_mask,
258
+ longformer_attention_mask,
259
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
260
+ output_attentions=output_attentions,
261
+ )
262
+
263
+ hidden_states = layer_outputs[0]
264
+
265
+ if output_attentions:
266
+ all_attentions = all_attentions + (layer_outputs[1],)
267
+
268
+ if output_hidden_states:
269
+ encoder_states = encoder_states + (hidden_states,)
270
+
271
+ if not return_dict:
272
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
273
+ return BaseModelOutput(
274
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
275
+ )
276
+
277
+
278
+ class LongBartModel(BartModel):
279
+ def __init__(self, config: BartConfig):
280
+ super().__init__(config)
281
+
282
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
283
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
284
+
285
+ self.encoder = LongBartEncoder(config, self.shared)
286
+ self.decoder = BartDecoder(config, self.shared)
287
+
288
+ # Initialize weights and apply final processing
289
+ self.post_init()
290
+
291
+
292
+ class LongBartEncoderLayer(BartEncoderLayer):
293
+ def __init__(self, config: BartConfig):
294
+ super().__init__(config)
295
+
296
+ def forward(
297
+ self,
298
+ hidden_states: torch.FloatTensor,
299
+ attention_mask: torch.FloatTensor,
300
+ longformer_attention_mask: torch.Tensor,
301
+ layer_head_mask: torch.FloatTensor,
302
+ output_attentions: bool = False,
303
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
304
+ """
305
+ Args:
306
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
307
+ attention_mask (`torch.FloatTensor`): attention mask of size
308
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
309
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
310
+ `(encoder_attention_heads,)`.
311
+ output_attentions (`bool`, *optional*):
312
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
313
+ returned tensors for more detail.
314
+ """
315
+ # if longformer attention instead of bart self attention: use special mask
316
+ if isinstance(self.self_attn, LongformerSelfAttentionForBart):
317
+ attention_mask = longformer_attention_mask
318
+ residual = hidden_states
319
+ hidden_states, attn_weights, _ = self.self_attn(
320
+ hidden_states=hidden_states,
321
+ attention_mask=attention_mask,
322
+ layer_head_mask=layer_head_mask,
323
+ output_attentions=output_attentions,
324
+ )
325
+
333
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
334
+ hidden_states = residual + hidden_states
335
+ hidden_states = self.self_attn_layer_norm(hidden_states)
336
+
337
+ residual = hidden_states
338
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
339
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
340
+ hidden_states = self.fc2(hidden_states)
341
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
342
+ hidden_states = residual + hidden_states
343
+ hidden_states = self.final_layer_norm(hidden_states)
344
+
345
+ if hidden_states.dtype == torch.float16 and (
346
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
347
+ ):
348
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
349
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
350
+
351
+ outputs = (hidden_states,)
352
+
353
+ if output_attentions:
354
+ outputs += (attn_weights,)
355
+
356
+ return outputs
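To show how the classes in this file fit together, here is a hedged construction sketch. All hyperparameters are placeholders; in practice the config and weights come from a converted BART checkpoint, and the conversion step is expected to provide the extra attributes the long encoder and attention read, such as `max_encoder_position_embeddings` and `attention_probs_dropout_prob` (passed as plain kwargs here).

from long_models.longformer_bart import (
    LongformerEncoderDecoderConfig,
    LongformerEncoderDecoderForConditionalGeneration,
)

num_layers = 6  # illustrative; must match the checkpoint being converted
config = LongformerEncoderDecoderConfig(
    attention_window=[256] * num_layers,  # effective window of 512 tokens per encoder layer
    attention_dilation=[1] * num_layers,  # sliding_chunks requires dilation 1
    attention_mode="sliding_chunks",
    autoregressive=False,
    encoder_layers=num_layers,
    decoder_layers=num_layers,
    d_model=768,
    vocab_size=50265,
    max_position_embeddings=1024,          # decoder side
    max_encoder_position_embeddings=4096,  # long source side, read by LongBartEncoder
    attention_probs_dropout_prob=0.1,      # attribute name read by LongformerSelfAttention
)
model = LongformerEncoderDecoderForConditionalGeneration(config)
print(sum(p.numel() for p in model.parameters()) // 10**6, "M parameters")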
long_models/longformer_mbart.py ADDED
@@ -0,0 +1,352 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+
6
+ This code is in part adapted from AllenAI's Longformer:
7
+ https://github.com/allenai/longformer/
8
+ and in part adapted from:
9
+ https://github.com/huggingface/transformers
10
+
11
+ Author: Annette Rios ([email protected])
12
+
13
+ """
14
+ from typing import List, Optional, Tuple, Dict, Union
15
+ from torch import nn, Tensor, zeros
16
+ import torch
17
+ import math
18
+ import random
19
+ from .longformer import LongformerSelfAttention
20
+ from transformers.models.mbart.modeling_mbart import MBartConfig, MBartForConditionalGeneration, MBartEncoder, MBartLearnedPositionalEmbedding, MBartEncoderLayer, MBartDecoder, MBartModel, _expand_mask
21
+ from transformers.modeling_outputs import BaseModelOutput
22
+
23
+ class MLongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
24
+ def __init__(self, config):
25
+ super(MBartForConditionalGeneration, self).__init__(config)
26
+
27
+ self.model = LongMBartModel(config)
28
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
29
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
30
+ #print(self)
31
+
32
+ if config.attention_mode == 'n2':
33
+ pass # do nothing, keep the default MBart self-attention
34
+ else:
35
+ for i, layer in enumerate(self.model.encoder.layers):
36
+ layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)
37
+ # Initialize weights and apply final processing
38
+ self.post_init()
39
+
40
+
41
+ class MLongformerEncoderDecoderConfig(MBartConfig):
42
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
43
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
44
+ gradient_checkpointing: bool = False, **kwargs):
45
+ """
46
+ Args:
47
+ attention_window: list of attention window sizes of length = number of layers.
48
+ window size = number of attention locations on each side.
49
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
50
+ which is 256 on each side.
51
+ attention_dilation: list of attention dilation of length = number of layers.
52
+ attention dilation of `1` means no dilation.
53
+ autoregressive: use autoregressive attention, or attend to both sides
54
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
55
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
56
+ """
57
+ super().__init__(**kwargs)
58
+ self.attention_window = attention_window
59
+ self.attention_dilation = attention_dilation
60
+ self.autoregressive = autoregressive
61
+ self.attention_mode = attention_mode
62
+ self.gradient_checkpointing = gradient_checkpointing
63
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
64
+
65
+ class LongformerSelfAttentionForMBart(nn.Module):
66
+ def __init__(self, config, layer_id):
67
+ super().__init__()
68
+ self.embed_dim = config.d_model
69
+ self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
70
+ self.output = nn.Linear(self.embed_dim, self.embed_dim)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: Tensor, # shape (batch_size, q_len, model_size)
75
+ key_value_states: Optional[Tensor] = None, # cross-attention in transformers.models.mbart.modeling_mbart
76
+ past_key_value: Optional[Tuple[Tensor]] = None, # only for decoder
77
+ attention_mask: Optional[Tensor] = None, # shape (batch_size, k_len) -> changed in transformers.models.mbart.modeling_mbart.MBartEncoder and MBartEncoderLayer (the new mask uses bool -> global attention positions are lost, need to use the inverted original mask)
78
+ layer_head_mask: Optional[Tensor] = None, # head dropout?
79
+ output_attentions: bool = False
80
+ ) -> Tuple[Tensor, Optional[Tensor]]:
81
+
82
+ bsz, tgt_len, embed_dim = hidden_states.size()
83
+ assert embed_dim == self.embed_dim
84
+ assert list(hidden_states.size()) == [bsz, tgt_len, embed_dim]
85
+
86
+ outputs = self.longformer_self_attn(
87
+ hidden_states,
88
+ attention_mask=attention_mask * -1, # shape (batch_size, 1, 1, key_len)
89
+ head_mask=None,
90
+ encoder_hidden_states=None,
91
+ encoder_attention_mask=None,
92
+ output_attentions=output_attentions,
93
+ )
94
+
95
+ ## new: MBart encoder keeps batch-first hidden states (batch_size, seq_len, embed_dim), so no transpose is needed here
96
+ attn_output = self.output(outputs[0])
97
+ # new return in MBartAttention has attn_output, attn_weights_reshaped, past_key_value (only for decoder), need to return 3 values (None for past_key_value)
98
+ return (attn_output, outputs[1:], None) if len(outputs) == 2 else (attn_output, None, None)
99
+
100
+
101
+ class LongMBartEncoder(MBartEncoder):
102
+ """
103
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
104
+ [`MBartEncoderLayer`].
105
+
106
+ Args:
107
+ config: MBartConfig
108
+ embed_tokens (nn.Embedding): output embedding
109
+ """
110
+
111
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
112
+ super().__init__(config)
113
+
114
+ self.dropout = config.dropout
115
+ self.layerdrop = config.encoder_layerdrop
116
+
117
+ embed_dim = config.d_model
118
+ self.padding_idx = config.pad_token_id
119
+ self.max_source_positions = config.max_encoder_position_embeddings
120
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
121
+
122
+ if embed_tokens is not None:
123
+ self.embed_tokens = embed_tokens
124
+ else:
125
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
126
+
127
+ self.embed_positions = MBartLearnedPositionalEmbedding(
128
+ self.max_source_positions,
129
+ embed_dim,
130
+ )
131
+ self.layers = nn.ModuleList([LongMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
132
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
133
+ self.layer_norm = nn.LayerNorm(config.d_model)
134
+
135
+ self.gradient_checkpointing = False
136
+ # Initialize weights and apply final processing
137
+ self.post_init()
138
+
139
+ def forward(
140
+ self,
141
+ input_ids: torch.LongTensor = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ head_mask: Optional[torch.Tensor] = None,
144
+ inputs_embeds: Optional[torch.FloatTensor] = None,
145
+ output_attentions: Optional[bool] = None,
146
+ output_hidden_states: Optional[bool] = None,
147
+ return_dict: Optional[bool] = None,
148
+ ) -> Union[Tuple, BaseModelOutput]:
149
+ r"""
150
+ Args:
151
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
152
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
153
+ provide it.
154
+
155
+ Indices can be obtained using [`MBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
156
+ [`PreTrainedTokenizer.__call__`] for details.
157
+
158
+ [What are input IDs?](../glossary#input-ids)
159
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
160
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
161
+
162
+ - 1 for tokens that are **not masked**,
163
+ - 0 for tokens that are **masked**.
164
+
165
+ [What are attention masks?](../glossary#attention-mask)
166
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
167
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
168
+
169
+ - 1 indicates the head is **not masked**,
170
+ - 0 indicates the head is **masked**.
171
+
172
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
173
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
174
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
175
+ than the model's internal embedding lookup matrix.
176
+ output_attentions (`bool`, *optional*):
177
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
178
+ returned tensors for more detail.
179
+ output_hidden_states (`bool`, *optional*):
180
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
181
+ for more detail.
182
+ return_dict (`bool`, *optional*):
183
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
184
+ """
185
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
186
+ output_hidden_states = (
187
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
188
+ )
189
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
190
+
191
+ # retrieve input_ids and inputs_embeds
192
+ if input_ids is not None and inputs_embeds is not None:
193
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
194
+ elif input_ids is not None:
195
+ input = input_ids
196
+ input_shape = input.shape
197
+ input_ids = input_ids.view(-1, input_shape[-1])
198
+ elif inputs_embeds is not None:
199
+ input = inputs_embeds[:, :, -1]
200
+ else:
201
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
202
+
203
+ if inputs_embeds is None:
204
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
205
+
206
+ embed_pos = self.embed_positions(input)
207
+
208
+ hidden_states = inputs_embeds + embed_pos
209
+ hidden_states = self.layernorm_embedding(hidden_states)
210
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
211
+
212
+ # expand attention_mask
213
+ longformer_attention_mask = None
214
+ if attention_mask is not None:
215
+ # need to keep the original mask (inverted) for longformer attention, otherwise the global attention value (2 in the given mask, -1 after inversion) is lost
216
+ longformer_attention_mask = 1 - attention_mask
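+ # e.g. attention_mask [1, 1, 2, 0] (1=attend, 2=global, 0=pad) -> longformer_attention_mask [0, 0, -1, 1] (0=local, -1=global, 1=pad)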
217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
218
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
219
+
220
+
221
+ encoder_states = () if output_hidden_states else None
222
+ all_attentions = () if output_attentions else None
223
+
224
+ # check if head_mask has a correct number of layers specified if desired
225
+ if head_mask is not None:
226
+ if head_mask.size()[0] != len(self.layers):
227
+ raise ValueError(
228
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
229
+ f" {head_mask.size()[0]}."
230
+ )
231
+ for idx, encoder_layer in enumerate(self.layers):
232
+ if output_hidden_states:
233
+ encoder_states = encoder_states + (hidden_states,)
234
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
235
+ dropout_probability = random.uniform(0, 1)
236
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
237
+ layer_outputs = (None, None)
238
+ else:
239
+ if self.gradient_checkpointing and self.training:
240
+
241
+ def create_custom_forward(module):
242
+ def custom_forward(*inputs):
243
+ return module(*inputs, output_attentions)
244
+
245
+ return custom_forward
246
+
247
+ layer_outputs = torch.utils.checkpoint.checkpoint(
248
+ create_custom_forward(encoder_layer),
249
+ hidden_states,
250
+ attention_mask,
251
+ longformer_attention_mask,
252
+ (head_mask[idx] if head_mask is not None else None),
253
+ )
254
+ else:
255
+ layer_outputs = encoder_layer(
256
+ hidden_states,
257
+ attention_mask,
258
+ longformer_attention_mask,
259
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
260
+ output_attentions=output_attentions,
261
+ )
262
+
263
+ hidden_states = layer_outputs[0]
264
+
265
+ if output_attentions:
266
+ all_attentions = all_attentions + (layer_outputs[1],)
267
+
268
+ hidden_states = self.layer_norm(hidden_states)
269
+
270
+ if output_hidden_states:
271
+ encoder_states = encoder_states + (hidden_states,)
272
+
273
+ if not return_dict:
274
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
275
+ return BaseModelOutput(
276
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
277
+ )
278
+
279
+
280
+ class LongMBartModel(MBartModel):
281
+ def __init__(self, config: MBartConfig):
282
+ super().__init__(config)
283
+
284
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
285
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
286
+
287
+ self.encoder = LongMBartEncoder(config, self.shared)
288
+ self.decoder = MBartDecoder(config, self.shared)
289
+
290
+ # Initialize weights and apply final processing
291
+ self.post_init()
292
+
293
+
294
+ class LongMBartEncoderLayer(MBartEncoderLayer):
295
+ def __init__(self, config: MBartConfig):
296
+ super().__init__(config)
297
+
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ attention_mask: torch.Tensor,
302
+ longformer_attention_mask: torch.Tensor,
303
+ layer_head_mask: torch.Tensor,
304
+ output_attentions: bool = False,
305
+ ) -> torch.Tensor:
306
+ """
307
+ Args:
308
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
309
+ attention_mask (`torch.FloatTensor`): attention mask of size
310
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
311
+ longformer_attention_mask (`torch.FloatTensor`): attention mask of size
312
+ `(batch, src_len)` where 0=local, -1=global, 1=padding.
313
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
314
+ *(encoder_attention_heads,)*.
315
+ output_attentions (`bool`, *optional*):
316
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
317
+ returned tensors for more detail.
318
+ """
319
+ # if longformer attention instead of mbart self attention: use special mask
320
+ if isinstance(self.self_attn, LongformerSelfAttentionForMBart):
321
+ attention_mask = longformer_attention_mask
322
+ residual = hidden_states
323
+ hidden_states = self.self_attn_layer_norm(hidden_states)
324
+ hidden_states, attn_weights, _ = self.self_attn(
325
+ hidden_states=hidden_states,
326
+ attention_mask=attention_mask,
327
+ layer_head_mask=layer_head_mask,
328
+ output_attentions=output_attentions,
329
+ )
330
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
331
+ hidden_states = residual + hidden_states
332
+
333
+ residual = hidden_states
334
+ hidden_states = self.final_layer_norm(hidden_states)
335
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
336
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
337
+ hidden_states = self.fc2(hidden_states)
338
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
339
+ hidden_states = residual + hidden_states
340
+
341
+ if hidden_states.dtype == torch.float16 and (
342
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
343
+ ):
344
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
345
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
346
+
347
+ outputs = (hidden_states,)
348
+
349
+ if output_attentions:
350
+ outputs += (attn_weights,)
351
+
352
+ return outputs
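
A minimal usage sketch for the classes above (all sizes here are illustrative assumptions, not values taken from this repository; the LongformerSelfAttention layers in long_models/longformer.py may expect further config fields in practice):

    from long_models.longformer_mbart import (
        MLongformerEncoderDecoderConfig,
        MLongformerEncoderDecoderForConditionalGeneration,
    )

    num_layers = 12  # assumed encoder depth
    config = MLongformerEncoderDecoderConfig(
        attention_window=[256] * num_layers,   # effective window of 512 tokens per layer
        attention_dilation=[1] * num_layers,   # no dilation
        attention_mode="sliding_chunks",
        autoregressive=False,
        encoder_layers=num_layers,
        max_encoder_position_embeddings=4096,  # read by LongMBartEncoder above
    )
    model = MLongformerEncoderDecoderForConditionalGeneration(config)

With attention_mode='n2' the stock MBart self-attention is kept; with 'sliding_chunks' or 'tvm' every encoder layer's self_attn is replaced by LongformerSelfAttentionForMBart.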
long_models/sliding_chunks.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+
6
+ This code is from AllenAI's Longformer:
7
+ https://github.com/allenai/longformer/
8
+
9
+ """
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from .diagonaled_mm_tvm import mask_invalid_locations
13
+
14
+
15
+ def _skew(x, direction, padding_value):
16
+ '''Convert diagonals into columns (or columns into diagonals, depending on `direction`)'''
17
+ x_padded = F.pad(x, direction, value=padding_value)
18
+ x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
19
+ return x_padded
20
+
21
+
22
+ def _skew2(x, padding_value):
23
+ '''shift every row 1 step to right converting columns into diagonals'''
24
+ # X = B x C x M x L
25
+ B, C, M, L = x.size()
26
+ x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1)
27
+ x = x.view(B, C, -1) # B x C x ML+MM+M
28
+ x = x[:, :, :-M] # B x C x ML+MM
29
+ x = x.view(B, C, M, M + L) # B x C x M x (L+M)
30
+ x = x[:, :, :, :-1]
31
+ return x
32
+
33
+
34
+ def _chunk(x, w):
35
+ '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
36
+
37
+ # non-overlapping chunks of size = 2w
38
+ x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
39
+
40
+ # use `as_strided` to make the chunks overlap with an overlap size = w
41
+ chunk_size = list(x.size())
42
+ chunk_size[1] = chunk_size[1] * 2 - 1
43
+
44
+ chunk_stride = list(x.stride())
45
+ chunk_stride[1] = chunk_stride[1] // 2
46
+ return x.as_strided(size=chunk_size, stride=chunk_stride)
47
+
48
+
49
+ def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
50
+ '''Matrix multiplication of query x key tensors using a sliding window attention pattern.
51
+ This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
52
+ with an overlap of size w'''
53
+ bsz, seqlen, num_heads, head_dim = q.size()
54
+ assert seqlen % (w * 2) == 0
55
+ assert q.size() == k.size()
56
+
57
+ chunks_count = seqlen // w - 1
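+ # e.g. seqlen=1024, w=256 -> 3 overlapping chunks of size 512 with overlap 256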
58
+
59
+ # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
60
+ q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
61
+ k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
62
+
63
+ chunk_q = _chunk(q, w)
64
+ chunk_k = _chunk(k, w)
65
+
66
+ # matrix multiplication
67
+ # bcxd: bsz*num_heads x chunks x 2w x head_dim
68
+ # bcyd: bsz*num_heads x chunks x 2w x head_dim
69
+ # bcxy: bsz*num_heads x chunks x 2w x 2w
70
+ chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply
71
+
72
+ # convert diagonals into columns
73
+ diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
74
+
75
+ # allocate space for the overall attention matrix where the chunks are combined. The last dimension
76
+ # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
77
+ # w previous words). The following column is attention score from each word to itself, then
78
+ # followed by w columns for the upper triangle.
79
+
80
+ diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
81
+
82
+ # copy parts from diagonal_chunk_attn into the combined matrix of attentions
83
+ # - copying the main diagonal and the upper triangle
84
+ diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
85
+ diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
86
+ # - copying the lower triangle
87
+ diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
88
+ diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]
89
+
90
+ # separate bsz and num_heads dimensions again
91
+ diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
92
+
93
+ mask_invalid_locations(diagonal_attn, w, 1, False)
94
+ return diagonal_attn
95
+
96
+
97
+ def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
98
+ '''Same as sliding_chunks_matmul_qk but for prob and value tensors. It expects the same output
99
+ format from sliding_chunks_matmul_qk'''
100
+ bsz, seqlen, num_heads, head_dim = v.size()
101
+ assert seqlen % (w * 2) == 0
102
+ assert prob.size()[:3] == v.size()[:3]
103
+ assert prob.size(3) == 2 * w + 1
104
+ chunks_count = seqlen // w - 1
105
+ # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
106
+ chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)
107
+
108
+ # group bsz and num_heads dimensions into one
109
+ v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
110
+
111
+ # pad seqlen with w at the beginning of the sequence and another w at the end
112
+ padded_v = F.pad(v, (0, 0, w, w), value=-1)
113
+
114
+ # chunk padded_v into chunks of size 3w and an overlap of size w
115
+ chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
116
+ chunk_v_stride = padded_v.stride()
117
+ chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
118
+ chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
119
+
120
+ skewed_prob = _skew2(chunk_prob, padding_value=0)
121
+
122
+ context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
123
+ return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)
124
+
125
+
126
+ def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
127
+ one_sided_window_size: int, pad_token_id: int):
128
+ '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
129
+ Input:
130
+ input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
131
+ attention_mask = torch.Tensor(bsz x seqlen): attention mask
132
+ one_sided_window_size = int: window size on one side of each token
133
+ pad_token_id = int: tokenizer.pad_token_id
134
+ Returns
135
+ (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
136
+ '''
137
+ w = int(2 * one_sided_window_size)
138
+ seqlen = input_ids.size(1)
139
+ padding_len = (w - seqlen % w) % w
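+ # e.g. one_sided_window_size=256 -> w=512; seqlen=1000 -> padding_len=24 (padded length 1024)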
140
+ input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
141
+ attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
142
+ return input_ids, attention_mask
143
+
144
+
145
+ # ========= "sliding_chunks_no_overlap": alternative implemenation of the sliding window attention =========
146
+ # This implementation uses non-overlapping chunks (or blocks) of size `w` with number of local attention = 3xw
147
+ # To make this implementation comparable to "sliding_chunks" set w such that
148
+ # w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
149
+ # For example,
150
+ # w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
151
+ # w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
152
+ # Performance:
153
+ # - Speed: 30% faster than "sliding_chunks"
154
+ # - Memory: 95% of the memory usage of "sliding_chunks"
155
+ # The windows are asymmetric: the number of attended locations on each side of a token ranges between w and 2w,
156
+ # while "sliding_chunks" has a symmetric window around each token.
157
+
158
+
159
+ def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
160
+ bsz, seqlen, num_heads, head_dim = q.size()
161
+ assert seqlen % w == 0
162
+ assert q.size() == k.size()
163
+ # chunk seqlen into non-overlapping chunks of size w
164
+ chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
165
+ chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
166
+ chunk_k_expanded = torch.stack((
167
+ F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
168
+ chunk_k,
169
+ F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
170
+ ), dim=-1)
171
+ diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply
172
+ return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)
173
+
174
+
175
+ def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
176
+ bsz, seqlen, num_heads, head_dim = v.size()
177
+ chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
178
+ chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
179
+ chunk_v_extended = torch.stack((
180
+ F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
181
+ chunk_v,
182
+ F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
183
+ ), dim=-1)
184
+ context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
185
+ return context.reshape(bsz, seqlen, num_heads, head_dim)
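
A short usage sketch for pad_to_window_size (the tensor sizes and pad id are assumptions for illustration):

    import torch
    from long_models.sliding_chunks import pad_to_window_size

    input_ids = torch.randint(5, 1000, (2, 1000))  # dummy ids, bsz=2, seqlen=1000
    attention_mask = torch.ones_like(input_ids)
    input_ids, attention_mask = pad_to_window_size(
        input_ids, attention_mask,
        one_sided_window_size=256,  # effective window of 512 tokens
        pad_token_id=1,             # e.g. tokenizer.pad_token_id
    )
    print(input_ids.shape)  # torch.Size([2, 1024]), divisible by 2 * 256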