omarmomen committed on
Commit c799e10
1 Parent(s): bfdf249
Files changed (3)
  1. config.json +23 -0
  2. pytorch_model.bin +3 -0
  3. structformer.py +1014 -0
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "architectures": [
+ "TransformerModel"
+ ],
+ "auto_map": {
+ "AutoConfig": "structformer.TransformerConfig",
+ "AutoModelForMaskedLM": "structformer.TransformerModel"
+ },
+ "dropatt": 0.1,
+ "dropout": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "model_type": "transformer",
+ "nhead": 8,
+ "nlayers": 8,
+ "ntokens": 16000,
+ "pad": 1,
+ "pos_emb": true,
+ "relative_bias": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.18.0"
+ }
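The auto_map above routes the Transformers auto classes to the custom classes defined in structformer.py, so the checkpoint has to be loaded with trust_remote_code=True. A minimal loading sketch; the repository id below is a placeholder for the Hub repo this commit belongs to:

from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "<user>/<model-repo>"  # placeholder: replace with the actual Hub repository id

# trust_remote_code=True lets transformers import structformer.py shipped with the repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)

print(config.model_type, config.hidden_size, config.ntokens)  # transformer 512 16000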
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f295cf6b736aeee12757ea7ecc1c32ead791a33adf69b46f60969f6cd0b87a2
+ size 134778679
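The file above is only a Git LFS pointer; the actual ~135 MB weight file lives in the LFS store and is resolved on download. A hedged sketch of fetching the resolved file with huggingface_hub (again, the repository id is a placeholder):

from huggingface_hub import hf_hub_download

repo_id = "<user>/<model-repo>"  # placeholder: the Hub repository this commit belongs to

# Downloads the resolved pytorch_model.bin (not the LFS pointer) into the local cache
weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
print(weights_path)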
structformer.py ADDED
@@ -0,0 +1,1014 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """StructFormer and transformer model."""
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import init
22
+ from transformers import PretrainedConfig, PreTrainedModel
23
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
24
+
25
+ def _get_activation_fn(activation):
26
+ """Get specified activation function."""
27
+ if activation == "relu":
28
+ return nn.ReLU()
29
+ elif activation == "gelu":
30
+ return nn.GELU()
31
+ elif activation == "leakyrelu":
32
+ return nn.LeakyReLU()
33
+
34
+ raise RuntimeError(
35
+ "activation should be relu/gelu/leakyrelu, not {}".format(activation))
36
+
37
+
38
+ class Conv1d(nn.Module):
39
+ """1D convolution layer."""
40
+
41
+ def __init__(self, hidden_size, kernel_size, dilation=1):
42
+ """Initialization.
43
+
44
+ Args:
45
+ hidden_size: dimension of input embeddings
46
+ kernel_size: convolution kernel size
47
+ dilation: the spacing between the kernel points
48
+ """
49
+ super(Conv1d, self).__init__()
50
+
51
+ if kernel_size % 2 == 0:
52
+ padding = (kernel_size // 2) * dilation
53
+ self.shift = True
54
+ else:
55
+ padding = ((kernel_size - 1) // 2) * dilation
56
+ self.shift = False
57
+ self.conv = nn.Conv1d(
58
+ hidden_size,
59
+ hidden_size,
60
+ kernel_size,
61
+ padding=padding,
62
+ dilation=dilation)
63
+
64
+ def forward(self, x):
65
+ """Compute convolution.
66
+
67
+ Args:
68
+ x: input embeddings
69
+ Returns:
70
+ conv_output: convolution results
71
+ """
72
+
73
+ if self.shift:
74
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
75
+ else:
76
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
77
+
78
+
79
+ class MultiheadAttention(nn.Module):
80
+ """Multi-head self-attention layer."""
81
+
82
+ def __init__(self,
83
+ embed_dim,
84
+ num_heads,
85
+ dropout=0.,
86
+ bias=True,
87
+ v_proj=True,
88
+ out_proj=True,
89
+ relative_bias=True):
90
+ """Initialization.
91
+
92
+ Args:
93
+ embed_dim: dimension of input embeddings
94
+ num_heads: number of self-attention heads
95
+ dropout: dropout rate
96
+ bias: bool, indicate whether include bias for linear transformations
97
+ v_proj: bool, indicate whether project inputs to new values
98
+ out_proj: bool, indicate whether project outputs to new values
99
+ relative_bias: bool, indicate whether use a relative position based
100
+ attention bias
101
+ """
102
+
103
+ super(MultiheadAttention, self).__init__()
104
+ self.embed_dim = embed_dim
105
+
106
+ self.num_heads = num_heads
107
+ self.drop = nn.Dropout(dropout)
108
+ self.head_dim = embed_dim // num_heads
109
+ assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
110
+ "divisible by "
111
+ "num_heads")
112
+
113
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
114
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
115
+ if v_proj:
116
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
117
+ else:
118
+ self.v_proj = nn.Identity()
119
+
120
+ if out_proj:
121
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
122
+ else:
123
+ self.out_proj = nn.Identity()
124
+
125
+ if relative_bias:
126
+ self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
127
+ else:
128
+ self.relative_bias = None
129
+
130
+ self._reset_parameters()
131
+
132
+ def _reset_parameters(self):
133
+ """Initialize attention parameters."""
134
+
135
+ init.xavier_uniform_(self.q_proj.weight)
136
+ init.constant_(self.q_proj.bias, 0.)
137
+
138
+ init.xavier_uniform_(self.k_proj.weight)
139
+ init.constant_(self.k_proj.bias, 0.)
140
+
141
+ if isinstance(self.v_proj, nn.Linear):
142
+ init.xavier_uniform_(self.v_proj.weight)
143
+ init.constant_(self.v_proj.bias, 0.)
144
+
145
+ if isinstance(self.out_proj, nn.Linear):
146
+ init.xavier_uniform_(self.out_proj.weight)
147
+ init.constant_(self.out_proj.bias, 0.)
148
+
149
+ def forward(self, query, key_padding_mask=None, attn_mask=None):
150
+ """Compute multi-head self-attention.
151
+
152
+ Args:
153
+ query: input embeddings
154
+ key_padding_mask: 3D mask that prevents attention to certain positions
155
+ attn_mask: 3D mask that rescale the attention weight at each position
156
+ Returns:
157
+ attn_output: self-attention output
158
+ """
159
+
160
+ length, bsz, embed_dim = query.size()
161
+ assert embed_dim == self.embed_dim
162
+
163
+ head_dim = embed_dim // self.num_heads
164
+ assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
165
+ "divisible by num_heads")
166
+ scaling = float(head_dim)**-0.5
167
+
168
+ q = self.q_proj(query)
169
+ k = self.k_proj(query)
170
+ v = self.v_proj(query)
171
+
172
+ q = q * scaling
173
+
174
+ if attn_mask is not None:
175
+ assert list(attn_mask.size()) == [bsz * self.num_heads,
176
+ query.size(0), query.size(0)]
177
+
178
+ q = q.contiguous().view(length, bsz * self.num_heads,
179
+ head_dim).transpose(0, 1)
180
+ k = k.contiguous().view(length, bsz * self.num_heads,
181
+ head_dim).transpose(0, 1)
182
+ v = v.contiguous().view(length, bsz * self.num_heads,
183
+ head_dim).transpose(0, 1)
184
+
185
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
186
+ assert list(
187
+ attn_output_weights.size()) == [bsz * self.num_heads, length, length]
188
+
189
+ if self.relative_bias is not None:
190
+ pos = torch.arange(length, device=query.device)
191
+ relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
192
+ relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
193
+ -1)
194
+
195
+ relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
196
+ relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
197
+ relative_bias = torch.gather(relative_bias, 2, relative_pos)
198
+ attn_output_weights = attn_output_weights + relative_bias
199
+
200
+ if key_padding_mask is not None:
201
+ attn_output_weights = attn_output_weights + key_padding_mask
202
+
203
+ if attn_mask is None:
204
+ attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
205
+ else:
206
+ attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
207
+
208
+ attn_output_weights = self.drop(attn_output_weights)
209
+
210
+ attn_output = torch.bmm(attn_output_weights, v)
211
+
212
+ assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
213
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
214
+ length, bsz, embed_dim)
215
+ attn_output = self.out_proj(attn_output)
216
+
217
+ return attn_output
218
+
219
+
220
+ class TransformerLayer(nn.Module):
221
+ """TransformerEncoderLayer is made up of self-attn and feedforward network."""
222
+
223
+ def __init__(self,
224
+ d_model,
225
+ nhead,
226
+ dim_feedforward=2048,
227
+ dropout=0.1,
228
+ dropatt=0.1,
229
+ activation="leakyrelu",
230
+ relative_bias=True):
231
+ """Initialization.
232
+
233
+ Args:
234
+ d_model: dimension of inputs
235
+ nhead: number of self-attention heads
236
+ dim_feedforward: dimension of hidden layer in feedforward layer
237
+ dropout: dropout rate
238
+ dropatt: drop attention rate
239
+ activation: activation function
240
+ relative_bias: bool, indicate whether use a relative position based
241
+ attention bias
242
+ """
243
+
244
+ super(TransformerLayer, self).__init__()
245
+ self.self_attn = MultiheadAttention(
246
+ d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
247
+ # Implementation of Feedforward model
248
+ self.feedforward = nn.Sequential(
249
+ nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
250
+ _get_activation_fn(activation), nn.Dropout(dropout),
251
+ nn.Linear(dim_feedforward, d_model))
252
+
253
+ self.norm = nn.LayerNorm(d_model)
254
+ self.dropout1 = nn.Dropout(dropout)
255
+ self.dropout2 = nn.Dropout(dropout)
256
+
257
+ self.nhead = nhead
258
+
259
+ def forward(self, src, attn_mask=None, key_padding_mask=None):
260
+ """Pass the input through the encoder layer.
261
+
262
+ Args:
263
+ src: the sequence to the encoder layer (required).
264
+ attn_mask: the mask for the src sequence (optional).
265
+ key_padding_mask: the mask for the src keys per batch (optional).
266
+ Returns:
267
+ src3: the output of the transformer layer; it has the same shape as src.
268
+ """
269
+ src2 = self.self_attn(
270
+ self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
271
+ src2 = src + self.dropout1(src2)
272
+ src3 = self.feedforward(src2)
273
+ src3 = src2 + self.dropout2(src3)
274
+
275
+ return src3
276
+
277
+
278
+ def cumprod(x, reverse=False, exclusive=False):
279
+ """cumulative product."""
280
+ if reverse:
281
+ x = x.flip([-1])
282
+
283
+ if exclusive:
284
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
285
+
286
+ cx = x.cumprod(-1)
287
+
288
+ if reverse:
289
+ cx = cx.flip([-1])
290
+ return cx
291
+
292
+ def cumsum(x, reverse=False, exclusive=False):
293
+ """cumulative sum."""
294
+ bsz, _, length = x.size()
295
+ device = x.device
296
+ if reverse:
297
+ if exclusive:
298
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
299
+ else:
300
+ w = torch.ones([bsz, length, length], device=device).tril(0)
301
+ cx = torch.bmm(x, w)
302
+ else:
303
+ if exclusive:
304
+ w = torch.ones([bsz, length, length], device=device).triu(1)
305
+ else:
306
+ w = torch.ones([bsz, length, length], device=device).triu(0)
307
+ cx = torch.bmm(x, w)
308
+ return cx
309
+
310
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
311
+ """cumulative min."""
312
+ if reverse:
313
+ if exclusive:
314
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
315
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
316
+ else:
317
+ if exclusive:
318
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
319
+ x = x.cummin(-1)[0]
320
+ return x
321
+
322
+ class Transformer(nn.Module):
323
+ """Transformer model."""
324
+
325
+ def __init__(self,
326
+ hidden_size,
327
+ nlayers,
328
+ ntokens,
329
+ nhead=8,
330
+ dropout=0.1,
331
+ dropatt=0.1,
332
+ relative_bias=True,
333
+ pos_emb=False,
334
+ pad=0):
335
+ """Initialization.
336
+
337
+ Args:
338
+ hidden_size: dimension of inputs and hidden states
339
+ nlayers: number of layers
340
+ ntokens: number of output categories
341
+ nhead: number of self-attention heads
342
+ dropout: dropout rate
343
+ dropatt: drop attention rate
344
+ relative_bias: bool, indicate whether use a relative position based
345
+ attention bias
346
+ pos_emb: bool, indicate whether use a learnable positional embedding
347
+ pad: pad token index
348
+ """
349
+
350
+ super(Transformer, self).__init__()
351
+
352
+ self.drop = nn.Dropout(dropout)
353
+
354
+ self.emb = nn.Embedding(ntokens, hidden_size)
355
+ if pos_emb:
356
+ self.pos_emb = nn.Embedding(500, hidden_size)
357
+
358
+ self.layers = nn.ModuleList([
359
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
360
+ dropatt=dropatt, relative_bias=relative_bias)
361
+ for _ in range(nlayers)])
362
+
363
+ self.norm = nn.LayerNorm(hidden_size)
364
+
365
+ self.output_layer = nn.Linear(hidden_size, ntokens)
366
+ self.output_layer.weight = self.emb.weight
367
+
368
+ self.init_weights()
369
+
370
+ self.nlayers = nlayers
371
+ self.nhead = nhead
372
+ self.ntokens = ntokens
373
+ self.hidden_size = hidden_size
374
+ self.pad = pad
375
+
376
+ def init_weights(self):
377
+ """Initialize token embedding and output bias."""
378
+ initrange = 0.1
379
+ self.emb.weight.data.uniform_(-initrange, initrange)
380
+ if hasattr(self, 'pos_emb'):
381
+ self.pos_emb.weight.data.uniform_(-initrange, initrange)
382
+ self.output_layer.bias.data.fill_(0)
383
+
384
+ def visibility(self, x, device):
385
+ """Mask pad tokens."""
386
+ visibility = (x != self.pad).float()
387
+ visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
388
+ visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
389
+ return visibility.log()
390
+
391
+ def encode(self, x, pos):
392
+ """Standard transformer encode process."""
393
+ h = self.emb(x)
394
+ if hasattr(self, 'pos_emb'):
395
+ h = h + self.pos_emb(pos)
396
+ h_list = []
397
+ visibility = self.visibility(x, x.device)
398
+
399
+ for i in range(self.nlayers):
400
+ h_list.append(h)
401
+ h = self.layers[i](
402
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
403
+
404
+ output = h
405
+ h_array = torch.stack(h_list, dim=2)
406
+
407
+ return output, h_array
408
+
409
+ def forward(self, x, pos):
410
+ """Pass the input through the encoder layer.
411
+
412
+ Args:
413
+ x: input tokens (required).
414
+ pos: position for each token (optional).
415
+ Returns:
416
+ output: probability distributions for missing tokens.
417
+ state_dict: parsing results and raw output
418
+ """
419
+
420
+ batch_size, length = x.size()
421
+
422
+ raw_output, _ = self.encode(x, pos)
423
+ raw_output = self.norm(raw_output)
424
+ raw_output = self.drop(raw_output)
425
+
426
+ output = self.output_layer(raw_output)
427
+ return output.view(batch_size * length, -1), {'raw_output': raw_output,}
428
+
429
+ class StructFormer(Transformer):
430
+ """StructFormer model."""
431
+
432
+ def __init__(self,
433
+ hidden_size,
434
+ nlayers,
435
+ ntokens,
436
+ nhead=8,
437
+ dropout=0.1,
438
+ dropatt=0.1,
439
+ relative_bias=False,
440
+ pos_emb=False,
441
+ pad=0,
442
+ n_parser_layers=4,
443
+ conv_size=9,
444
+ relations=('head', 'child'),
445
+ weight_act='softmax'):
446
+ """Initialization.
447
+
448
+ Args:
449
+ hidden_size: dimension of inputs and hidden states
450
+ nlayers: number of layers
451
+ ntokens: number of output categories
452
+ nhead: number of self-attention heads
453
+ dropout: dropout rate
454
+ dropatt: drop attention rate
455
+ relative_bias: bool, indicate whether use a relative position based
456
+ attention bias
457
+ pos_emb: bool, indicate whether use a learnable positional embedding
458
+ pad: pad token index
459
+ n_parser_layers: number of parsing layers
460
+ conv_size: convolution kernel size for parser
461
+ relations: relations that are used to compute self attention
462
+ weight_act: relations distribution activation function
463
+ """
464
+
465
+ super(StructFormer, self).__init__(
466
+ hidden_size,
467
+ nlayers,
468
+ ntokens,
469
+ nhead=nhead,
470
+ dropout=dropout,
471
+ dropatt=dropatt,
472
+ relative_bias=relative_bias,
473
+ pos_emb=pos_emb,
474
+ pad=pad)
475
+
476
+ self.parser_layers = nn.ModuleList([
477
+ nn.Sequential(Conv1d(hidden_size, conv_size),
478
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
479
+ nn.Tanh()) for i in range(n_parser_layers)])
480
+
481
+ self.distance_ff = nn.Sequential(
482
+ Conv1d(hidden_size, 2),
483
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
484
+ nn.Linear(hidden_size, 1))
485
+
486
+ self.height_ff = nn.Sequential(
487
+ nn.Linear(hidden_size, hidden_size),
488
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
489
+ nn.Linear(hidden_size, 1))
490
+
491
+ n_rel = len(relations)
492
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
493
+ self._rel_weight.data.normal_(0, 0.1)
494
+
495
+ self._scaler = nn.Parameter(torch.zeros(2))
496
+
497
+ self.n_parse_layers = n_parser_layers
498
+ self.weight_act = weight_act
499
+ self.relations = relations
500
+
501
+ @property
502
+ def scaler(self):
503
+ return self._scaler.exp()
504
+
505
+ @property
506
+ def rel_weight(self):
507
+ if self.weight_act == 'sigmoid':
508
+ return torch.sigmoid(self._rel_weight)
509
+ elif self.weight_act == 'softmax':
510
+ return torch.softmax(self._rel_weight, dim=-1)
511
+
512
+ def parse(self, x, pos):
513
+ """Parse input sentence.
514
+
515
+ Args:
516
+ x: input tokens (required).
517
+ pos: position for each token (optional).
518
+ Returns:
519
+ distance: syntactic distance
520
+ height: syntactic height
521
+ """
522
+
523
+ mask = (x != self.pad)
524
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
525
+
526
+ h = self.emb(x)
527
+ for i in range(self.n_parse_layers):
528
+ h = h.masked_fill(~mask[:, :, None], 0)
529
+ h = self.parser_layers[i](h)
530
+
531
+ height = self.height_ff(h).squeeze(-1)
532
+ height.masked_fill_(~mask, -1e9)
533
+
534
+ distance = self.distance_ff(h).squeeze(-1)
535
+ distance.masked_fill_(~mask_shifted, 1e9)
536
+
537
+ # Calibrating the distance and height to the same level
538
+ length = distance.size(1)
539
+ height_max = height[:, None, :].expand(-1, length, -1)
540
+ height_max = torch.cummax(
541
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
542
+ dim=-1)[0].triu(0)
543
+
544
+ margin_left = torch.relu(
545
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
546
+ margin_right = torch.relu(distance[:, None, :] - height_max)
547
+ margin = torch.where(margin_left > margin_right, margin_right,
548
+ margin_left).triu(0)
549
+
550
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
551
+ margin.masked_fill_(~margin_mask, 0)
552
+ margin = margin.max()
553
+
554
+ distance = distance - margin
555
+
556
+ return distance, height
557
+
558
+ def compute_block(self, distance, height):
559
+ """Compute constituents from distance and height."""
560
+
561
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
562
+
563
+ gamma = torch.sigmoid(-beta_logits)
564
+ ones = torch.ones_like(gamma)
565
+
566
+ block_mask_left = cummin(
567
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
568
+ block_mask_left = block_mask_left - F.pad(
569
+ block_mask_left[:, :, :-1], (1, 0), value=0)
570
+ block_mask_left.tril_(0)
571
+
572
+ block_mask_right = cummin(
573
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
574
+ block_mask_right = block_mask_right - F.pad(
575
+ block_mask_right[:, :, 1:], (0, 1), value=0)
576
+ block_mask_right.triu_(0)
577
+
578
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
579
+ block = cumsum(block_mask_left).tril(0) + cumsum(
580
+ block_mask_right, reverse=True).triu(1)
581
+
582
+ return block_p, block
583
+
584
+ def compute_head(self, height):
585
+ """Estimate head for each constituent."""
586
+
587
+ _, length = height.size()
588
+ head_logits = height * self.scaler[1]
589
+ index = torch.arange(length, device=height.device)
590
+
591
+ mask = (index[:, None, None] <= index[None, None, :]) * (
592
+ index[None, None, :] <= index[None, :, None])
593
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
594
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
595
+
596
+ head_p = torch.softmax(head_logits, dim=-1)
597
+
598
+ return head_p
599
+
600
+ def generate_mask(self, x, distance, height):
601
+ """Compute head and cibling distribution for each token."""
602
+
603
+ bsz, length = x.size()
604
+
605
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
606
+ eye = eye[None, :, :].expand((bsz, -1, -1))
607
+
608
+ block_p, block = self.compute_block(distance, height)
609
+ head_p = self.compute_head(height)
610
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
611
+ head = head.masked_fill(eye, 0)
612
+ child = head.transpose(1, 2)
613
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
614
+
615
+ rel_list = []
616
+ if 'head' in self.relations:
617
+ rel_list.append(head)
618
+ if 'child' in self.relations:
619
+ rel_list.append(child)
620
+ if 'cibling' in self.relations:
621
+ rel_list.append(cibling)
622
+
623
+ rel = torch.stack(rel_list, dim=1)
624
+
625
+ rel_weight = self.rel_weight
626
+
627
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
628
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
629
+
630
+ return att_mask, cibling, head, block
631
+
632
+ def encode(self, x, pos, att_mask):
633
+ """Structformer encoding process."""
634
+
635
+ visibility = self.visibility(x, x.device)
636
+ h = self.emb(x)
637
+ if hasattr(self, 'pos_emb'):
638
+ assert pos.max() < 500
639
+ h = h + self.pos_emb(pos)
640
+ for i in range(self.nlayers):
641
+ h = self.layers[i](
642
+ h.transpose(0, 1), attn_mask=att_mask[i],
643
+ key_padding_mask=visibility).transpose(0, 1)
644
+ return h
645
+
646
+ def forward(self, x, pos):
647
+ """Pass the input through the encoder layer.
648
+
649
+ Args:
650
+ x: input tokens (required).
651
+ pos: position for each token (optional).
652
+ Returns:
653
+ output: probability distributions for missing tokens.
654
+ state_dict: parsing results and raw output
655
+ """
656
+
657
+ batch_size, length = x.size()
658
+
659
+ distance, height = self.parse(x, pos)
660
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
661
+
662
+ raw_output = self.encode(x, pos, att_mask)
663
+ raw_output = self.norm(raw_output)
664
+ raw_output = self.drop(raw_output)
665
+
666
+ output = self.output_layer(raw_output)
667
+
668
+ return output.view(batch_size * length, -1), \
669
+ {'raw_output': raw_output, 'distance': distance, 'height': height,
670
+ 'cibling': cibling, 'head': head, 'block': block}
671
+
672
+
673
+ ##########################################
674
+ # Classification Head for BabyLM Evaluation Tasks
675
+ ##########################################
676
+ class ClassificationHead(nn.Module):
677
+ """Head for sentence-level classification tasks."""
678
+ def __init__(self, config):
679
+ super(ClassificationHead, self).__init__()
680
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
681
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
682
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
683
+
684
+ def forward(self, features, **kwargs):
685
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
686
+ x = self.dropout(x)
687
+ x = self.dense(x)
688
+ x = torch.tanh(x)
689
+ x = self.dropout(x)
690
+ x = self.out_proj(x)
691
+ return x
692
+
693
+ ##########################################
694
+ # HuggingFace Config
695
+ ##########################################
696
+ class StructFormerConfig(PretrainedConfig):
697
+ model_type = "structformer"
698
+
699
+ def __init__(
700
+ self,
701
+ hidden_size=512,
702
+ nlayers=8,
703
+ ntokens=10_000,
704
+ nhead=8,
705
+ dropout=0.1,
706
+ dropatt=0.1,
707
+ relative_bias=False,
708
+ pos_emb=False,
709
+ pad=0,
710
+ n_parser_layers=4,
711
+ conv_size=9,
712
+ relations=('head', 'child'),
713
+ weight_act='softmax',
714
+ num_labels=1,
715
+ hidden_dropout_prob=0.1,
716
+ initializer_range=0.02,
717
+ **kwargs,
718
+ ):
719
+ self.hidden_size = hidden_size
720
+ self.nlayers = nlayers
721
+ self.ntokens = ntokens
722
+ self.nhead = nhead
723
+ self.dropout = dropout
724
+ self.dropatt = dropatt
725
+ self.relative_bias = relative_bias
726
+ self.pos_emb = pos_emb
727
+ self.pad = pad
728
+ self.n_parser_layers = n_parser_layers
729
+ self.conv_size = conv_size
730
+ self.relations = relations
731
+ self.weight_act = weight_act
732
+ self.num_labels = num_labels
733
+ self.hidden_dropout_prob = hidden_dropout_prob
734
+ self.initializer_range=initializer_range
735
+ super().__init__(**kwargs)
736
+
737
+ class TransformerConfig(PretrainedConfig):
738
+ model_type = "transformer"
739
+
740
+ def __init__(
741
+ self,
742
+ hidden_size=512,
743
+ nlayers=8,
744
+ ntokens=10_000,
745
+ nhead=8,
746
+ dropout=0.1,
747
+ dropatt=0.1,
748
+ relative_bias=False,
749
+ pos_emb=False,
750
+ pad=0,
751
+ num_labels=1,
752
+ hidden_dropout_prob=0.1,
753
+ initializer_range=0.02,
754
+ **kwargs,
755
+ ):
756
+ self.hidden_size = hidden_size
757
+ self.nlayers = nlayers
758
+ self.ntokens = ntokens
759
+ self.nhead = nhead
760
+ self.dropout = dropout
761
+ self.dropatt = dropatt
762
+ self.relative_bias = relative_bias
763
+ self.pos_emb = pos_emb
764
+ self.pad = pad
765
+ self.num_labels = num_labels
766
+ self.hidden_dropout_prob = hidden_dropout_prob
767
+ self.initializer_range=initializer_range
768
+ super().__init__(**kwargs)
769
+
770
+
771
+
772
+ ##########################################
773
+ # HuggingFace Models
774
+ ##########################################
775
+ class StructFormerModel(PreTrainedModel):
776
+ config_class = StructFormerConfig
777
+
778
+ def __init__(self, config):
779
+ super().__init__(config)
780
+ self.model = StructFormer(
781
+ hidden_size=config.hidden_size,
782
+ nlayers=config.nlayers,
783
+ ntokens=config.ntokens,
784
+ nhead=config.nhead,
785
+ dropout=config.dropout,
786
+ dropatt=config.dropatt,
787
+ relative_bias=config.relative_bias,
788
+ pos_emb=config.pos_emb,
789
+ pad=config.pad,
790
+ n_parser_layers=config.n_parser_layers,
791
+ conv_size=config.conv_size,
792
+ relations=config.relations,
793
+ weight_act=config.weight_act
794
+ )
795
+ self.config = config
796
+
797
+ def parse(self, input_ids, **kwargs):
798
+ x = input_ids
799
+ batch_size, length = x.size()
800
+ pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
801
+
802
+ sf_output = self.model(x, pos)
803
+
804
+ return sf_output[1]
805
+
806
+ def forward(self, input_ids, labels=None, **kwargs):
807
+ x = input_ids
808
+ batch_size, length = x.size()
809
+ pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
810
+
811
+ sf_output = self.model(x, pos)
812
+
813
+ loss = None
814
+ if labels is not None:
815
+ loss_fct = nn.CrossEntropyLoss()
816
+ loss = loss_fct(sf_output[0], labels.reshape(-1))
817
+
818
+ return MaskedLMOutput(
819
+ loss=loss, # shape: 1
820
+ logits=sf_output[0].view(batch_size, length, -1), # shape: (batch_size, length, ntokens)
821
+ hidden_states=None,
822
+ attentions=None
823
+ )
824
+
825
+ class StructFormerModelForSequenceClassification(PreTrainedModel):
826
+ config_class = StructFormerConfig
827
+
828
+ def __init__(self, config):
829
+ super().__init__(config)
830
+ self.model = StructFormer(
831
+ hidden_size=config.hidden_size,
832
+ nlayers=config.nlayers,
833
+ ntokens=config.ntokens,
834
+ nhead=config.nhead,
835
+ dropout=config.dropout,
836
+ dropatt=config.dropatt,
837
+ relative_bias=config.relative_bias,
838
+ pos_emb=config.pos_emb,
839
+ pad=config.pad,
840
+ n_parser_layers=config.n_parser_layers,
841
+ conv_size=config.conv_size,
842
+ relations=config.relations,
843
+ weight_act=config.weight_act
844
+ )
845
+ self.config = config
846
+ self.num_labels = config.num_labels
847
+ self.model.classifier = ClassificationHead(config)
848
+
849
+ def _init_weights(self, module):
850
+ """Initialize the weights"""
851
+ if isinstance(module, nn.Linear):
852
+ # Slightly different from the TF version which uses truncated_normal for initialization
853
+ # cf https://github.com/pytorch/pytorch/pull/5617
854
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
855
+ if module.bias is not None:
856
+ module.bias.data.zero_()
857
+ elif isinstance(module, nn.Embedding):
858
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
859
+ if module.padding_idx is not None:
860
+ module.weight.data[module.padding_idx].zero_()
861
+ elif isinstance(module, nn.LayerNorm):
862
+ if module.bias is not None:
863
+ module.bias.data.zero_()
864
+ module.weight.data.fill_(1.0)
865
+
866
+ def forward(self, input_ids, labels=None, **kwargs):
867
+ x = input_ids
868
+ batch_size, length = x.size()
869
+ pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
870
+
871
+ sf_output = self.model(x, pos)
872
+
873
+ logits = self.model.classifier(sf_output[1]['raw_output'])
874
+ loss = None
875
+ if labels is not None:
876
+ if self.config.problem_type is None:
877
+ if self.num_labels == 1:
878
+ self.config.problem_type = "regression"
879
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
880
+ self.config.problem_type = "single_label_classification"
881
+ else:
882
+ self.config.problem_type = "multi_label_classification"
883
+
884
+ if self.config.problem_type == "regression":
885
+ loss_fct = nn.MSELoss()
886
+ if self.num_labels == 1:
887
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
888
+ else:
889
+ loss = loss_fct(logits, labels)
890
+ elif self.config.problem_type == "single_label_classification":
891
+ loss_fct = nn.CrossEntropyLoss()
892
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
893
+ elif self.config.problem_type == "multi_label_classification":
894
+ loss_fct = nn.BCEWithLogitsLoss()
895
+ loss = loss_fct(logits, labels)
896
+
897
+ return SequenceClassifierOutput(
898
+ loss=loss,
899
+ logits=logits,
900
+ hidden_states=None,
901
+ attentions=None,
902
+ )
903
+
904
+ class TransformerModel(PreTrainedModel):
905
+ config_class = TransformerConfig
906
+
907
+ def __init__(self, config):
908
+ super().__init__(config)
909
+ self.model = Transformer(
910
+ hidden_size=config.hidden_size,
911
+ nlayers=config.nlayers,
912
+ ntokens=config.ntokens,
913
+ nhead=config.nhead,
914
+ dropout=config.dropout,
915
+ dropatt=config.dropatt,
916
+ relative_bias=config.relative_bias,
917
+ pos_emb=config.pos_emb,
918
+ pad=config.pad
919
+ )
920
+ self.config = config
921
+
922
+ def forward(self, input_ids, labels=None, **kwargs):
923
+ x = input_ids
924
+ batch_size, length = x.size()
925
+ pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
926
+
927
+ sf_output = self.model(x, pos)
928
+
929
+ loss = None
930
+ if labels is not None:
931
+ loss_fct = nn.CrossEntropyLoss()
932
+ loss = loss_fct(sf_output[0], labels.reshape(-1))
933
+
934
+ return MaskedLMOutput(
935
+ loss=loss, # shape: 1
936
+ logits=sf_output[0].view(batch_size, length, -1), # shape: (batch_size, length, ntokens)
937
+ hidden_states=None,
938
+ attentions=None
939
+ )
940
+
941
+ class TransformerModelForSequenceClassification(PreTrainedModel):
942
+ config_class = TransformerConfig
943
+
944
+ def __init__(self, config):
945
+ super().__init__(config)
946
+ self.model = Transformer(
947
+ hidden_size=config.hidden_size,
948
+ nlayers=config.nlayers,
949
+ ntokens=config.ntokens,
950
+ nhead=config.nhead,
951
+ dropout=config.dropout,
952
+ dropatt=config.dropatt,
953
+ relative_bias=config.relative_bias,
954
+ pos_emb=config.pos_emb,
955
+ pad=config.pad
956
+ )
957
+ self.config = config
958
+ self.num_labels = config.num_labels
959
+ self.model.classifier = ClassificationHead(config)
960
+
961
+ def _init_weights(self, module):
962
+ """Initialize the weights"""
963
+ if isinstance(module, nn.Linear):
964
+ # Slightly different from the TF version which uses truncated_normal for initialization
965
+ # cf https://github.com/pytorch/pytorch/pull/5617
966
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
967
+ if module.bias is not None:
968
+ module.bias.data.zero_()
969
+ elif isinstance(module, nn.Embedding):
970
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
971
+ if module.padding_idx is not None:
972
+ module.weight.data[module.padding_idx].zero_()
973
+ elif isinstance(module, nn.LayerNorm):
974
+ if module.bias is not None:
975
+ module.bias.data.zero_()
976
+ module.weight.data.fill_(1.0)
977
+
978
+ def forward(self, input_ids, labels=None, **kwargs):
979
+ x = input_ids
980
+ batch_size, length = x.size()
981
+ pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
982
+
983
+ sf_output = self.model(x, pos)
984
+
985
+ logits = self.model.classifier(sf_output[1]['raw_output'])
986
+ loss = None
987
+ if labels is not None:
988
+ if self.config.problem_type is None:
989
+ if self.num_labels == 1:
990
+ self.config.problem_type = "regression"
991
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
992
+ self.config.problem_type = "single_label_classification"
993
+ else:
994
+ self.config.problem_type = "multi_label_classification"
995
+
996
+ if self.config.problem_type == "regression":
997
+ loss_fct = nn.MSELoss()
998
+ if self.num_labels == 1:
999
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1000
+ else:
1001
+ loss = loss_fct(logits, labels)
1002
+ elif self.config.problem_type == "single_label_classification":
1003
+ loss_fct = nn.CrossEntropyLoss()
1004
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1005
+ elif self.config.problem_type == "multi_label_classification":
1006
+ loss_fct = nn.BCEWithLogitsLoss()
1007
+ loss = loss_fct(logits, labels)
1008
+
1009
+ return SequenceClassifierOutput(
1010
+ loss=loss,
1011
+ logits=logits,
1012
+ hidden_states=None,
1013
+ attentions=None,
1014
+ )
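Since the classes above are self-contained, the module can also be used directly without the auto classes. A minimal sketch, assuming structformer.py is importable and using a deliberately tiny configuration rather than the released one in config.json; it runs one masked-LM forward pass and checks the logits shape produced by TransformerModel.forward:

import torch
from structformer import TransformerConfig, TransformerModel

# Tiny config for illustration only; the checkpoint in this commit uses the values in config.json
config = TransformerConfig(hidden_size=64, nlayers=2, ntokens=100, nhead=4, pos_emb=True, pad=1)
model = TransformerModel(config)

input_ids = torch.randint(2, config.ntokens, (3, 16))  # (batch_size, length), avoiding the pad id
labels = input_ids.clone()                              # toy MLM targets
out = model(input_ids=input_ids, labels=labels)

print(out.logits.shape)  # torch.Size([3, 16, 100]) -> (batch_size, length, ntokens)
print(out.loss)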