GinnM committed on
Commit 81d10ac · verified · 1 Parent(s): be39b27

Upload DeprotForMaskedLM

Files changed (3)
  1. config.json +1 -1
  2. model.safetensors +3 -0
  3. modeling_deprot.py +1413 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "oracle_checkpoint/ss_0_2051___True/15_proteingym_NoStability_0.390",
+  "_name_or_path": "oracle_checkpoint/ss_0_2051___True/10_proteingym_0.436_proteingym_Stability_0.484_proteingym_NoStability_0.414",
   "architectures": [
     "DeprotForMaskedLM"
   ],
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f563f354509953afa2db37f82c3e7631a050e1fa3e671d9f723d39716b999b6
+size 355285344
modeling_deprot.py ADDED
@@ -0,0 +1,1413 @@
1
+ from collections.abc import Sequence
2
+ from typing import Optional, Tuple, Union
3
+ from transformers import PretrainedConfig
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.modeling_outputs import (
10
+ BaseModelOutput,
11
+ MaskedLMOutput,
12
+ SequenceClassifierOutput,
13
+ TokenClassifierOutput,
14
+ )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from .configuration_deprot import DeprotConfig
17
+ from transformers.models.roformer.modeling_roformer import RoFormerModel
18
+ from transformers.models.esm.modeling_esm import EsmModel
19
+ import torch.nn.functional as F
20
+
21
+
22
+ def build_relative_position(query_size, key_size, device):
+     """
+     Build the relative position matrix for a query/key pair.
+ 
+     We assume the absolute positions of the query \\(P_q\\) range over [0, query_size) and the absolute
+     positions of the key \\(P_k\\) range over [0, key_size). The relative position from query to key is then
+     \\(R_{q \\rightarrow k} = P_q - P_k\\).
+ 
+     Args:
+         query_size (int): the length of the query
+         key_size (int): the length of the key
+ 
+     Returns:
+         `torch.LongTensor`: a tensor of shape [1, query_size, key_size]
+     """
38
+
39
+ q_ids = torch.arange(query_size, dtype=torch.long, device=device)
40
+ k_ids = torch.arange(key_size, dtype=torch.long, device=device)
41
+ rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
42
+ rel_pos_ids = rel_pos_ids[:query_size, :]
43
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
44
+ return rel_pos_ids
45
+
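+ # Example: build_relative_position(3, 3, device) returns
+ #     tensor([[[ 0, -1, -2],
+ #              [ 1,  0, -1],
+ #              [ 2,  1,  0]]])
+ # i.e. entry (q, k) holds the signed offset P_q - P_k used to index the relative
+ # position embedding table.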
46
+
47
+ @torch.jit.script
48
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
49
+ return c2p_pos.expand(
50
+ [
51
+ query_layer.size(0),
52
+ query_layer.size(1),
53
+ query_layer.size(2),
54
+ relative_pos.size(-1),
55
+ ]
56
+ )
57
+
58
+
59
+ @torch.jit.script
60
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
61
+ return c2p_pos.expand(
62
+ [
63
+ query_layer.size(0),
64
+ query_layer.size(1),
65
+ key_layer.size(-2),
66
+ key_layer.size(-2),
67
+ ]
68
+ )
69
+
70
+
71
+ @torch.jit.script
72
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
73
+ return pos_index.expand(
74
+ p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
75
+ )
76
+
77
+
78
+ def rotate_half(x):
79
+ x1, x2 = x.chunk(2, dim=-1)
80
+ return torch.cat((-x2, x1), dim=-1)
81
+
82
+
83
+ def apply_rotary_pos_emb(x, cos, sin):
84
+ cos = cos[:, :, : x.shape[-2], :]
85
+ sin = sin[:, :, : x.shape[-2], :]
86
+
87
+ return (x * cos) + (rotate_half(x) * sin)
88
+
89
+
90
+ class RotaryEmbedding(torch.nn.Module):
91
+ """
92
+ Rotary position embeddings based on those in
93
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
94
+ matrices which depend on their relative positions.
95
+ """
96
+
97
+ def __init__(self, dim: int):
98
+ super().__init__()
99
+         # Generate and save the inverse frequency buffer (non-trainable)
+         inv_freq = 1.0 / (
+             10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
+         )
+         self.register_buffer("inv_freq", inv_freq)
105
+
106
+ self._seq_len_cached = None
107
+ self._cos_cached = None
108
+ self._sin_cached = None
109
+
110
+ def _update_cos_sin_tables(self, x, seq_dimension=2):
111
+ seq_len = x.shape[seq_dimension]
112
+
113
+ # Reset the tables if the sequence length has changed,
114
+ # or if we're on a new device (possibly due to tracing for instance)
115
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
116
+ self._seq_len_cached = seq_len
117
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
118
+ self.inv_freq
119
+ )
120
+ freqs = torch.outer(t, self.inv_freq)
121
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
122
+
123
+ self._cos_cached = emb.cos()[None, None, :, :]
124
+ self._sin_cached = emb.sin()[None, None, :, :]
125
+
126
+ return self._cos_cached, self._sin_cached
127
+
128
+ def forward(
129
+ self, q: torch.Tensor, k: torch.Tensor
130
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
131
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
132
+ k, seq_dimension=-2
133
+ )
134
+
135
+ return (
136
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
137
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
138
+ )
139
+
140
+
141
+ class MaskedConv1d(nn.Conv1d):
142
+ """A masked 1-dimensional convolution layer.
143
+
144
+ Takes the same arguments as torch.nn.Conv1D, except that the padding is set automatically.
145
+
146
+ Shape:
147
+ Input: (N, L, in_channels)
148
+ input_mask: (N, L, 1), optional
149
+ Output: (N, L, out_channels)
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ in_channels: int,
155
+ out_channels: int,
156
+ kernel_size: int,
157
+ stride: int = 1,
158
+ dilation: int = 1,
159
+ groups: int = 1,
160
+ bias: bool = True,
161
+ ):
162
+ """
163
+ :param in_channels: input channels
164
+ :param out_channels: output channels
165
+ :param kernel_size: the kernel width
166
+ :param stride: filter shift
167
+ :param dilation: dilation factor
168
+ :param groups: perform depth-wise convolutions
169
+ :param bias: adds learnable bias to output
170
+ """
171
+ padding = dilation * (kernel_size - 1) // 2
172
+ super().__init__(
173
+ in_channels,
174
+ out_channels,
175
+ kernel_size,
176
+ stride=stride,
177
+ dilation=dilation,
178
+ groups=groups,
179
+ bias=bias,
180
+ padding=padding,
181
+ )
182
+
183
+ def forward(self, x, input_mask=None):
184
+ if input_mask is not None:
185
+ x = x * input_mask
186
+ return super().forward(x.transpose(1, 2)).transpose(1, 2)
187
+
188
+
189
+ class Attention1dPooling(nn.Module):
190
+ def __init__(self, config):
191
+ super().__init__()
192
+ self.layer = MaskedConv1d(config.hidden_size, 1, 1)
193
+
194
+     def forward(self, x, input_mask=None):
+         batch_size = x.shape[0]
+         attn = self.layer(x)
+         attn = attn.view(batch_size, -1)
+         if input_mask is not None:
+             attn = attn.masked_fill_(
+                 ~input_mask.view(batch_size, -1).bool(), float("-inf")
+             )
+         attn = F.softmax(attn, dim=-1).view(batch_size, -1, 1)
+         out = (attn * x).sum(dim=1)
+         return out
205
+
206
+
207
+ class MeanPooling(nn.Module):
208
+ """Mean Pooling for sentence-level classification tasks."""
209
+
210
+ def __init__(self):
211
+ super().__init__()
212
+
213
+ def forward(self, features, input_mask=None):
214
+ if input_mask is not None:
215
+ # Applying input_mask to zero out masked values
216
+ masked_features = features * input_mask.unsqueeze(2)
217
+ sum_features = torch.sum(masked_features, dim=1)
218
+ mean_pooled_features = sum_features / input_mask.sum(dim=1, keepdim=True)
219
+ else:
220
+ mean_pooled_features = torch.mean(features, dim=1)
221
+ return mean_pooled_features
222
+
223
+
224
+ class ContextPooler(nn.Module):
225
+ def __init__(self, config):
226
+ super().__init__()
227
+ scale_hidden = getattr(config, "scale_hidden", 1)
228
+ if config.pooling_head == "mean":
229
+ self.mean_pooling = MeanPooling()
230
+ elif config.pooling_head == "attention":
231
+ self.mean_pooling = Attention1dPooling(config)
232
+ self.dense = nn.Linear(
233
+ config.pooler_hidden_size, scale_hidden * config.pooler_hidden_size
234
+ )
235
+ self.dropout = nn.Dropout(config.pooler_dropout)
236
+ self.config = config
237
+
238
+     def forward(self, hidden_states, input_mask=None):
+         # Pool the token-level hidden states into a single sequence representation
+         # (mean or attention pooling, depending on config.pooling_head).
+         context_token = self.mean_pooling(hidden_states, input_mask)
243
+ context_token = self.dropout(context_token)
244
+ pooled_output = self.dense(context_token)
245
+ pooled_output = torch.tanh(pooled_output)
246
+ return pooled_output
247
+
248
+ @property
249
+ def output_dim(self):
250
+ return self.config.hidden_size
251
+
252
+
253
+ class DeprotLayerNorm(nn.Module):
254
+ """LayerNorm module in the TF style (epsilon inside the square root)."""
255
+
256
+ def __init__(self, size, eps=1e-12):
257
+ super().__init__()
258
+ self.weight = nn.Parameter(torch.ones(size))
259
+ self.bias = nn.Parameter(torch.zeros(size))
260
+ self.variance_epsilon = eps
261
+
262
+ def forward(self, hidden_states):
263
+ input_type = hidden_states.dtype
264
+ hidden_states = hidden_states.float()
265
+ mean = hidden_states.mean(-1, keepdim=True)
266
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
267
+ hidden_states = (hidden_states - mean) / torch.sqrt(
268
+ variance + self.variance_epsilon
269
+ )
270
+ hidden_states = hidden_states.to(input_type)
271
+ y = self.weight * hidden_states + self.bias
272
+ return y
273
+
274
+
275
+ class DisentangledSelfAttention(nn.Module):
276
+
277
+ def __init__(self, config: DeprotConfig):
278
+ super().__init__()
279
+ if config.hidden_size % config.num_attention_heads != 0:
280
+ raise ValueError(
281
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
282
+ f"heads ({config.num_attention_heads})"
283
+ )
284
+ self.num_attention_heads = config.num_attention_heads
285
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
286
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
287
+
288
+ # Q, K, V projection layers
289
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
290
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
291
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
292
+
293
+ # AA->SS, AA->POS, SS->AA, POS->AA and AA->AA attention layers
294
+ self.pos_att_type = (
295
+ config.pos_att_type if config.pos_att_type is not None else []
296
+ )
297
+
298
+ self.relative_attention = getattr(config, "relative_attention", False)
299
+ self.position_embedding_type = getattr(
300
+ config, "position_embedding_type", "relative"
301
+ )
302
+ if self.position_embedding_type == "rotary":
303
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
304
+ if self.relative_attention:
305
+
306
+ if "aa2ss" in self.pos_att_type:
307
+ self.ss_proj = nn.Linear(
308
+ config.hidden_size, self.all_head_size, bias=False
309
+ )
310
+
311
+ if "ss2aa" in self.pos_att_type:
312
+ self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
313
+
314
+ elif self.position_embedding_type == "relative":
315
+ if self.relative_attention:
316
+ self.max_relative_positions = getattr(
317
+ config, "max_relative_positions", -1
318
+ )
319
+ if self.max_relative_positions < 1:
320
+ self.max_relative_positions = config.max_position_embeddings
321
+ self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
322
+
323
+ # amino acid to position
324
+ if "aa2pos" in self.pos_att_type:
325
+ self.pos_proj = nn.Linear(
326
+ config.hidden_size, self.all_head_size, bias=False
327
+ ) # Key
328
+
329
+ if "pos2aa" in self.pos_att_type:
330
+ self.pos_q_proj = nn.Linear(
331
+ config.hidden_size, self.all_head_size
332
+ ) # Query
333
+
334
+ if "aa2ss" in self.pos_att_type:
335
+ self.ss_proj = nn.Linear(
336
+ config.hidden_size, self.all_head_size, bias=False
337
+ )
338
+
339
+ if "ss2aa" in self.pos_att_type:
340
+ self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
341
+
342
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
343
+
344
+ def transpose_for_scores(self, x):
345
+ # x [batch_size, seq_len, all_head_size]
346
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
347
+ # x [batch_size, seq_len, num_attention_heads, attention_head_size]
348
+ x = x.view(new_x_shape)
349
+ # x [batch_size, num_attention_heads, seq_len, attention_head_size]
350
+ return x.permute(0, 2, 1, 3)
351
+
352
+ def forward(
353
+ self,
354
+ hidden_states,
355
+ attention_mask,
356
+ output_attentions=False,
357
+ query_states=None,
358
+ relative_pos=None,
359
+ rel_embeddings=None,
360
+ ss_hidden_states=None,
361
+ ):
362
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
363
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
364
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
365
+
366
+ if self.position_embedding_type == "rotary":
367
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
368
+
369
+ rel_att = None
370
+ scale_factor = 1 + len(self.pos_att_type)
371
+ scale = torch.sqrt(
372
+ torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor
373
+ )
374
+ query_layer = query_layer / scale.to(dtype=query_layer.dtype)
375
+
376
+ # [batch_size, num_attention_heads, seq_len, seq_len]
377
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
378
+
379
+ if self.relative_attention:
380
+ if self.position_embedding_type == "relative":
381
+ rel_embeddings = self.pos_dropout(rel_embeddings)
382
+ rel_att = self.disentangled_att_bias(
383
+ query_layer,
384
+ key_layer,
385
+ relative_pos,
386
+ rel_embeddings,
387
+ scale_factor,
388
+ ss_hidden_states,
389
+ )
390
+
391
+ if rel_att is not None:
392
+ attention_scores = attention_scores + rel_att
393
+
394
+ rmask = ~(attention_mask.to(torch.bool))
395
+ attention_probs = attention_scores.masked_fill(rmask, float("-inf"))
396
+ attention_probs = torch.softmax(attention_probs, -1)
397
+ attention_probs = attention_probs.masked_fill(rmask, 0.0)
398
+ # attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
399
+ attention_probs = self.dropout(attention_probs)
400
+
401
+ context_layer = torch.matmul(attention_probs, value_layer)
402
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
403
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
404
+ context_layer = context_layer.view(new_context_layer_shape)
405
+ if output_attentions:
406
+ return (context_layer, attention_probs)
407
+ else:
408
+ return context_layer
409
+
410
+ def disentangled_att_bias(
411
+ self,
412
+ query_layer,
413
+ key_layer,
414
+ relative_pos,
415
+ rel_embeddings,
416
+ scale_factor,
417
+ ss_hidden_states,
418
+ ):
419
+ if self.position_embedding_type == "relative":
420
+ if relative_pos is None:
421
+ q = query_layer.size(-2)
422
+ relative_pos = build_relative_position(
423
+ q, key_layer.size(-2), query_layer.device
424
+ )
425
+ if relative_pos.dim() == 2:
426
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
427
+ elif relative_pos.dim() == 3:
428
+ relative_pos = relative_pos.unsqueeze(1)
429
+ # bxhxqxk
430
+ elif relative_pos.dim() != 4:
431
+ raise ValueError(
432
+ f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}"
433
+ )
434
+
435
+ att_span = min(
436
+ max(query_layer.size(-2), key_layer.size(-2)),
437
+ self.max_relative_positions,
438
+ )
439
+ relative_pos = relative_pos.long().to(query_layer.device)
440
+ rel_embeddings = rel_embeddings[
441
+ self.max_relative_positions
442
+ - att_span : self.max_relative_positions
443
+ + att_span,
444
+ :,
445
+ ].unsqueeze(0)
446
+
447
+ score = 0
448
+
449
+ if "aa2pos" in self.pos_att_type:
450
+ pos_key_layer = self.pos_proj(rel_embeddings)
451
+ pos_key_layer = self.transpose_for_scores(pos_key_layer)
452
+ aa2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
453
+ aa2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
454
+ aa2p_att = torch.gather(
455
+ aa2p_att,
456
+ dim=-1,
457
+ index=c2p_dynamic_expand(aa2p_pos, query_layer, relative_pos),
458
+ )
459
+ score += aa2p_att
460
+
461
+ if "pos2aa" in self.pos_att_type:
462
+ pos_query_layer = self.pos_q_proj(rel_embeddings)
463
+ pos_query_layer = self.transpose_for_scores(pos_query_layer)
464
+ pos_query_layer /= torch.sqrt(
465
+ torch.tensor(pos_query_layer.size(-1), dtype=torch.float)
466
+ * scale_factor
467
+ )
468
+ if query_layer.size(-2) != key_layer.size(-2):
469
+ r_pos = build_relative_position(
470
+ key_layer.size(-2), key_layer.size(-2), query_layer.device
471
+ )
472
+ else:
473
+ r_pos = relative_pos
474
+ p2aa_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
475
+ p2aa_att = torch.matmul(
476
+ key_layer,
477
+ pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype),
478
+ )
479
+ p2aa_att = torch.gather(
480
+ p2aa_att,
481
+ dim=-1,
482
+ index=p2c_dynamic_expand(p2aa_pos, query_layer, key_layer),
483
+ ).transpose(-1, -2)
484
+
485
+ if query_layer.size(-2) != key_layer.size(-2):
486
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
487
+ p2aa_att = torch.gather(
488
+ p2aa_att,
489
+ dim=-2,
490
+ index=pos_dynamic_expand(pos_index, p2aa_att, key_layer),
491
+ )
492
+ score += p2aa_att
493
+
494
+ # content -> structure
495
+ if "aa2ss" in self.pos_att_type:
496
+ assert ss_hidden_states is not None
497
+ ss_key_layer = self.ss_proj(ss_hidden_states)
498
+ ss_key_layer = self.transpose_for_scores(ss_key_layer)
499
+ # [batch_size, num_attention_heads, seq_len, seq_len]
500
+ aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
501
+ score += aa2ss_att
502
+
503
+ if "ss2aa" in self.pos_att_type:
504
+ assert ss_hidden_states is not None
505
+ ss_query_layer = self.ss_q_proj(ss_hidden_states)
506
+ ss_query_layer = self.transpose_for_scores(ss_query_layer)
507
+ ss_query_layer /= torch.sqrt(
508
+ torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
509
+ * scale_factor
510
+ )
511
+ ss2aa_att = torch.matmul(
512
+ key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
513
+ )
514
+ score += ss2aa_att
515
+ return score
516
+ elif self.position_embedding_type == "rotary":
517
+ score = 0
518
+ if "aa2ss" in self.pos_att_type:
519
+ assert ss_hidden_states is not None
520
+ ss_key_layer = self.ss_proj(ss_hidden_states)
521
+ ss_key_layer = self.transpose_for_scores(ss_key_layer)
522
+ aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
523
+ score += aa2ss_att
524
+
525
+ if "ss2aa" in self.pos_att_type:
526
+ assert ss_hidden_states is not None
527
+ ss_query_layer = self.ss_q_proj(ss_hidden_states)
528
+ ss_query_layer = self.transpose_for_scores(ss_query_layer)
529
+ ss_query_layer /= torch.sqrt(
530
+ torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
531
+ * scale_factor
532
+ )
533
+ ss2aa_att = torch.matmul(
534
+ key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
535
+ )
536
+ score += ss2aa_att
537
+ return score
538
+
539
+
540
+ class DeprotSelfOutput(nn.Module):
541
+ def __init__(self, config):
542
+ super().__init__()
543
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
544
+ self.LayerNorm = DeprotLayerNorm(config.hidden_size, config.layer_norm_eps)
545
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
546
+
547
+ def forward(self, hidden_states, input_tensor):
548
+ hidden_states = self.dense(hidden_states)
549
+ hidden_states = self.dropout(hidden_states)
550
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
551
+ return hidden_states
552
+
553
+
554
+ class DeprotAttention(nn.Module):
555
+ def __init__(self, config):
556
+ super().__init__()
557
+ self.self = DisentangledSelfAttention(config)
558
+ self.output = DeprotSelfOutput(config)
559
+ self.config = config
560
+
561
+ def forward(
562
+ self,
563
+ hidden_states,
564
+ attention_mask,
565
+ output_attentions=False,
566
+ query_states=None,
567
+ relative_pos=None,
568
+ rel_embeddings=None,
569
+ ss_hidden_states=None,
570
+ ):
571
+ self_output = self.self(
572
+ hidden_states,
573
+ attention_mask,
574
+ output_attentions,
575
+ query_states=query_states,
576
+ relative_pos=relative_pos,
577
+ rel_embeddings=rel_embeddings,
578
+ ss_hidden_states=ss_hidden_states,
579
+ )
580
+ if output_attentions:
581
+ self_output, att_matrix = self_output
582
+ if query_states is None:
583
+ query_states = hidden_states
584
+ attention_output = self.output(self_output, query_states)
585
+
586
+ if output_attentions:
587
+ return (attention_output, att_matrix)
588
+ else:
589
+ return attention_output
590
+
591
+
592
+ class DeprotIntermediate(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
596
+ if isinstance(config.hidden_act, str):
597
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
598
+ else:
599
+ self.intermediate_act_fn = config.hidden_act
600
+
601
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
602
+ hidden_states = self.dense(hidden_states)
603
+ hidden_states = self.intermediate_act_fn(hidden_states)
604
+ return hidden_states
605
+
606
+
607
+ class DeprotOutput(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
611
+ self.LayerNorm = DeprotLayerNorm(config.hidden_size, config.layer_norm_eps)
612
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
613
+ self.config = config
614
+
615
+ def forward(self, hidden_states, input_tensor):
616
+ hidden_states = self.dense(hidden_states)
617
+ hidden_states = self.dropout(hidden_states)
618
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
619
+ return hidden_states
620
+
621
+
622
+ class DeprotLayer(nn.Module):
623
+ def __init__(self, config):
624
+ super().__init__()
625
+ self.attention = DeprotAttention(config)
626
+ self.intermediate = DeprotIntermediate(config)
627
+ self.output = DeprotOutput(config)
628
+
629
+ def forward(
630
+ self,
631
+ hidden_states,
632
+ attention_mask,
633
+ query_states=None,
634
+ relative_pos=None,
635
+ rel_embeddings=None,
636
+ output_attentions=False,
637
+ ss_hidden_states=None,
638
+ ):
639
+ attention_output = self.attention(
640
+ hidden_states,
641
+ attention_mask,
642
+ output_attentions=output_attentions,
643
+ query_states=query_states,
644
+ relative_pos=relative_pos,
645
+ rel_embeddings=rel_embeddings,
646
+ ss_hidden_states=ss_hidden_states,
647
+ )
648
+ if output_attentions:
649
+ attention_output, att_matrix = attention_output
650
+ intermediate_output = self.intermediate(attention_output)
651
+ layer_output = self.output(intermediate_output, attention_output)
652
+ if output_attentions:
653
+ return (layer_output, att_matrix)
654
+ else:
655
+ return layer_output
656
+
657
+
658
+ class DeprotEncoder(nn.Module):
659
+ """Modified BertEncoder with relative position bias support"""
660
+
661
+ def __init__(self, config):
662
+ super().__init__()
663
+ self.layer = nn.ModuleList(
664
+ [DeprotLayer(config) for _ in range(config.num_hidden_layers)]
665
+ )
666
+ self.relative_attention = getattr(config, "relative_attention", False)
667
+ if self.relative_attention:
668
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
669
+ if self.max_relative_positions < 1:
670
+ self.max_relative_positions = config.max_position_embeddings
671
+ self.rel_embeddings = nn.Embedding(
672
+ self.max_relative_positions * 2, config.hidden_size
673
+ )
674
+ self.gradient_checkpointing = False
675
+
676
+ def get_rel_embedding(self):
677
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
678
+ return rel_embeddings
679
+
680
+ def get_attention_mask(self, attention_mask):
681
+ if attention_mask.dim() <= 2:
682
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
683
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(
684
+ -2
685
+ ).unsqueeze(-1)
686
+ elif attention_mask.dim() == 3:
687
+ attention_mask = attention_mask.unsqueeze(1)
688
+
689
+ return attention_mask
690
+
691
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
692
+ if self.relative_attention and relative_pos is None:
693
+ q = (
694
+ query_states.size(-2)
695
+ if query_states is not None
696
+ else hidden_states.size(-2)
697
+ )
698
+ relative_pos = build_relative_position(
699
+ q, hidden_states.size(-2), hidden_states.device
700
+ )
701
+ return relative_pos
702
+
703
+ def forward(
704
+ self,
705
+ hidden_states,
706
+ attention_mask,
707
+ output_hidden_states=True,
708
+ output_attentions=False,
709
+ query_states=None,
710
+ relative_pos=None,
711
+ ss_hidden_states=None,
712
+ return_dict=True,
713
+ ):
714
+ attention_mask = self.get_attention_mask(attention_mask)
715
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
716
+
717
+ all_hidden_states = () if output_hidden_states else None
718
+ all_attentions = () if output_attentions else None
719
+
720
+ if isinstance(hidden_states, Sequence):
721
+ next_kv = hidden_states[0]
722
+ else:
723
+ next_kv = hidden_states
724
+ rel_embeddings = self.get_rel_embedding()
725
+ for i, layer_module in enumerate(self.layer):
726
+ if output_hidden_states:
727
+ all_hidden_states = all_hidden_states + (hidden_states,)
728
+
729
+ if self.gradient_checkpointing and self.training:
730
+
731
+ def create_custom_forward(module):
732
+ def custom_forward(*inputs):
733
+ return module(*inputs, output_attentions)
734
+
735
+ return custom_forward
736
+
737
+ hidden_states = torch.utils.checkpoint.checkpoint(
738
+ create_custom_forward(layer_module),
739
+ next_kv,
740
+ attention_mask,
741
+ query_states,
742
+ relative_pos,
743
+ rel_embeddings,
744
+ ss_hidden_states,
745
+ )
746
+ else:
747
+ hidden_states = layer_module(
748
+ next_kv,
749
+ attention_mask,
750
+ query_states=query_states,
751
+ relative_pos=relative_pos,
752
+ rel_embeddings=rel_embeddings,
753
+ output_attentions=output_attentions,
754
+ ss_hidden_states=ss_hidden_states,
755
+ )
756
+
757
+ if output_attentions:
758
+ hidden_states, att_m = hidden_states
759
+
760
+ if query_states is not None:
761
+ query_states = hidden_states
762
+ if isinstance(hidden_states, Sequence):
763
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
764
+ else:
765
+ next_kv = hidden_states
766
+
767
+ if output_attentions:
768
+ all_attentions = all_attentions + (att_m,)
769
+
770
+ if output_hidden_states:
771
+ all_hidden_states = all_hidden_states + (hidden_states,)
772
+
773
+ if not return_dict:
774
+ return tuple(
775
+ v
776
+ for v in [hidden_states, all_hidden_states, all_attentions]
777
+ if v is not None
778
+ )
779
+ return BaseModelOutput(
780
+ last_hidden_state=hidden_states,
781
+ hidden_states=all_hidden_states,
782
+ attentions=all_attentions,
783
+ )
784
+
785
+
786
+ class DeprotEmbeddings(nn.Module):
787
+ """Construct the embeddings from word, position and token_type embeddings."""
788
+
789
+ def __init__(self, config):
790
+ super().__init__()
791
+ pad_token_id = getattr(config, "pad_token_id", 0)
792
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
793
+ self.word_embeddings = nn.Embedding(
794
+ config.vocab_size, self.embedding_size, padding_idx=pad_token_id
795
+ )
796
+
797
+ self.position_biased_input = getattr(config, "position_biased_input", False)
798
+ if not self.position_biased_input:
799
+ self.position_embeddings = None
800
+ else:
801
+ # assert getattr(config, "position_embedding_type", "relative") == "absolute"
802
+ self.position_embeddings = nn.Embedding(
803
+ config.max_position_embeddings, self.embedding_size
804
+ )
805
+
806
+ if config.type_vocab_size > 0:
807
+ self.token_type_embeddings = nn.Embedding(
808
+ config.type_vocab_size, self.embedding_size
809
+ )
810
+
811
+ if config.ss_vocab_size > 0:
812
+ self.ss_embeddings = nn.Embedding(config.ss_vocab_size, self.embedding_size)
813
+ self.ss_layer_norm = DeprotLayerNorm(
814
+ config.hidden_size, config.layer_norm_eps
815
+ )
816
+
817
+ if self.embedding_size != config.hidden_size:
818
+ self.embed_proj = nn.Linear(
819
+ self.embedding_size, config.hidden_size, bias=False
820
+ )
821
+ self.LayerNorm = DeprotLayerNorm(config.hidden_size, config.layer_norm_eps)
822
+
823
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
824
+ self.config = config
825
+
826
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
827
+ if self.position_biased_input:
828
+ self.register_buffer(
829
+ "position_ids",
830
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
831
+ persistent=False,
832
+ )
833
+
834
+ def forward(
835
+ self,
836
+ input_ids=None,
837
+ ss_input_ids=None,
838
+ token_type_ids=None,
839
+ position_ids=None,
840
+ mask=None,
841
+ inputs_embeds=None,
842
+ ):
843
+ if input_ids is not None:
844
+ input_shape = input_ids.size()
845
+ else:
846
+ input_shape = inputs_embeds.size()[:-1]
847
+
848
+ seq_length = input_shape[1]
849
+
850
+ if position_ids is None and self.position_biased_input:
851
+ position_ids = self.position_ids[:, :seq_length]
852
+ if seq_length > position_ids.size(1):
853
+ zero_padding = (
854
+ torch.zeros(
855
+ (input_shape[0], seq_length - position_ids.size(1)),
856
+ dtype=torch.long,
857
+ device=position_ids.device,
858
+ )
859
+ + 2047
860
+ )
861
+ position_ids = torch.cat([position_ids, zero_padding], dim=1)
862
+
863
+ if token_type_ids is None:
864
+ token_type_ids = torch.zeros(
865
+ input_shape, dtype=torch.long, device=self.position_ids.device
866
+ )
867
+
868
+ if inputs_embeds is None:
869
+ if self.config.token_dropout:
870
+ inputs_embeds = self.word_embeddings(input_ids)
871
+ inputs_embeds.masked_fill_(
872
+ (input_ids == self.config.mask_token_id).unsqueeze(-1), 0.0
873
+ )
874
+ mask_ratio_train = self.config.mlm_probability * 0.8
875
+ src_lengths = mask.sum(dim=-1)
876
+ mask_ratio_observed = (input_ids == self.config.mask_token_id).sum(
877
+ -1
878
+ ).to(inputs_embeds.dtype) / src_lengths
879
+ inputs_embeds = (
880
+ inputs_embeds
881
+ * (1 - mask_ratio_train)
882
+ / (1 - mask_ratio_observed)[:, None, None]
883
+ )
884
+ else:
885
+ inputs_embeds = self.word_embeddings(input_ids)
886
+
887
+ if self.position_embeddings is not None and self.position_biased_input:
888
+ position_embeddings = self.position_embeddings(position_ids.long())
889
+ else:
890
+ position_embeddings = torch.zeros_like(inputs_embeds)
891
+
892
+ embeddings = inputs_embeds
893
+ if self.position_biased_input:
894
+ embeddings += position_embeddings
895
+ if self.config.type_vocab_size > 0:
896
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
897
+ embeddings += token_type_embeddings
898
+
899
+ if self.embedding_size != self.config.hidden_size:
900
+ embeddings = self.embed_proj(embeddings)
901
+
902
+ embeddings = self.LayerNorm(embeddings)
903
+
904
+ if mask is not None:
905
+ if mask.dim() != embeddings.dim():
906
+ if mask.dim() == 4:
907
+ mask = mask.squeeze(1).squeeze(1)
908
+ mask = mask.unsqueeze(2)
909
+ mask = mask.to(embeddings.dtype)
910
+ embeddings = embeddings * mask
911
+
912
+ embeddings = self.dropout(embeddings)
913
+
914
+ if self.config.ss_vocab_size > 0:
915
+ ss_embeddings = self.ss_embeddings(ss_input_ids)
916
+ ss_embeddings = self.ss_layer_norm(ss_embeddings)
917
+ if mask is not None:
918
+ if mask.dim() != ss_embeddings.dim():
919
+ if mask.dim() == 4:
920
+ mask = mask.squeeze(1).squeeze(1)
921
+ mask = mask.unsqueeze(2)
922
+ mask = mask.to(ss_embeddings.dtype)
923
+ ss_embeddings = ss_embeddings * mask
924
+ ss_embeddings = self.dropout(ss_embeddings)
925
+ return embeddings, ss_embeddings
926
+
927
+ return embeddings, None
928
+
929
+
930
+ class DeprotPreTrainedModel(PreTrainedModel):
931
+ """
932
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
933
+ models.
934
+ """
935
+
936
+ config_class = DeprotConfig
937
+ base_model_prefix = "deprot"
938
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
939
+ supports_gradient_checkpointing = True
940
+
941
+ def _init_weights(self, module):
942
+ """Initialize the weights."""
943
+ if isinstance(module, nn.Linear):
944
+ # Slightly different from the TF version which uses truncated_normal for initialization
945
+ # cf https://github.com/pytorch/pytorch/pull/5617
946
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
947
+ if module.bias is not None:
948
+ module.bias.data.zero_()
949
+ elif isinstance(module, nn.Embedding):
950
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
951
+ if module.padding_idx is not None:
952
+ module.weight.data[module.padding_idx].zero_()
953
+
954
+ def _set_gradient_checkpointing(self, module, value=False):
955
+ if isinstance(module, DeprotEncoder):
956
+ module.gradient_checkpointing = value
957
+
958
+
959
+ class DeprotModel(DeprotPreTrainedModel):
960
+ def __init__(self, config):
961
+ super().__init__(config)
962
+
963
+ self.embeddings = DeprotEmbeddings(config)
964
+ self.encoder = DeprotEncoder(config)
965
+ self.config = config
966
+ # Initialize weights and apply final processing
967
+ self.post_init()
968
+
969
+ def get_input_embeddings(self):
970
+ return self.embeddings.word_embeddings
971
+
972
+ def set_input_embeddings(self, new_embeddings):
973
+ self.embeddings.word_embeddings = new_embeddings
974
+
975
+ def _prune_heads(self, heads_to_prune):
976
+ """
977
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
978
+ class PreTrainedModel
979
+ """
980
+ raise NotImplementedError(
981
+ "The prune function is not implemented in DeBERTa model."
982
+ )
983
+
984
+ def forward(
985
+ self,
986
+ input_ids: Optional[torch.Tensor] = None,
987
+ ss_input_ids: Optional[torch.Tensor] = None,
988
+ attention_mask: Optional[torch.Tensor] = None,
989
+ token_type_ids: Optional[torch.Tensor] = None,
990
+ position_ids: Optional[torch.Tensor] = None,
991
+ inputs_embeds: Optional[torch.Tensor] = None,
992
+ output_attentions: Optional[bool] = None,
993
+ output_hidden_states: Optional[bool] = None,
994
+ return_dict: Optional[bool] = None,
995
+ ) -> Union[Tuple, BaseModelOutput]:
996
+ output_attentions = (
997
+ output_attentions
998
+ if output_attentions is not None
999
+ else self.config.output_attentions
1000
+ )
1001
+ output_hidden_states = (
1002
+ output_hidden_states
1003
+ if output_hidden_states is not None
1004
+ else self.config.output_hidden_states
1005
+ )
1006
+ return_dict = (
1007
+ return_dict if return_dict is not None else self.config.use_return_dict
1008
+ )
1009
+
1010
+ if input_ids is not None and inputs_embeds is not None:
1011
+ raise ValueError(
1012
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1013
+ )
1014
+ elif input_ids is not None:
1015
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1016
+ input_shape = input_ids.size()
1017
+ elif inputs_embeds is not None:
1018
+ input_shape = inputs_embeds.size()[:-1]
1019
+ else:
1020
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1021
+
1022
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1023
+
1024
+ if attention_mask is None:
1025
+ attention_mask = torch.ones(input_shape, device=device)
1026
+ if token_type_ids is None:
1027
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1028
+
1029
+ embedding_output, ss_embeddings = self.embeddings(
1030
+ input_ids=input_ids,
1031
+ ss_input_ids=ss_input_ids,
1032
+ token_type_ids=token_type_ids,
1033
+ position_ids=position_ids,
1034
+ mask=attention_mask,
1035
+ inputs_embeds=inputs_embeds,
1036
+ )
1037
+
1038
+ encoder_outputs = self.encoder(
1039
+ embedding_output,
1040
+ attention_mask,
1041
+ output_hidden_states=True,
1042
+ output_attentions=output_attentions,
1043
+ return_dict=return_dict,
1044
+ ss_hidden_states=ss_embeddings,
1045
+ )
1046
+ encoded_layers = encoder_outputs[1]
1047
+
1048
+ sequence_output = encoded_layers[-1]
1049
+
1050
+ if not return_dict:
1051
+ return (sequence_output,) + encoder_outputs[
1052
+ (1 if output_hidden_states else 2) :
1053
+ ]
1054
+
1055
+ return BaseModelOutput(
1056
+ last_hidden_state=sequence_output,
1057
+ hidden_states=(
1058
+ encoder_outputs.hidden_states if output_hidden_states else None
1059
+ ),
1060
+ attentions=encoder_outputs.attentions,
1061
+ )
1062
+
1063
+
1064
+ class DeprotPredictionHeadTransform(nn.Module):
1065
+ def __init__(self, config):
1066
+ super().__init__()
1067
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
1068
+
1069
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
1070
+ if isinstance(config.hidden_act, str):
1071
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1072
+ else:
1073
+ self.transform_act_fn = config.hidden_act
1074
+ self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
1075
+
1076
+ def forward(self, hidden_states):
1077
+ hidden_states = self.dense(hidden_states)
1078
+ hidden_states = self.transform_act_fn(hidden_states)
1079
+ hidden_states = self.LayerNorm(hidden_states)
1080
+ return hidden_states
1081
+
1082
+
1083
+ class DeprotLMPredictionHead(nn.Module):
1084
+ def __init__(self, config):
1085
+ super().__init__()
1086
+ self.transform = DeprotPredictionHeadTransform(config)
1087
+
1088
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
1089
+ # The output weights are the same as the input embeddings, but there is
1090
+ # an output-only bias for each token.
1091
+ self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
1092
+
1093
+ # self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1094
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1095
+ # self.decoder.bias = self.bias
1096
+
1097
+ def forward(self, hidden_states):
1098
+ hidden_states = self.transform(hidden_states)
1099
+ hidden_states = self.decoder(hidden_states)
1100
+ return hidden_states
1101
+
1102
+
1103
+ class DeprotOnlyMLMHead(nn.Module):
1104
+ def __init__(self, config):
1105
+ super().__init__()
1106
+ self.predictions = DeprotLMPredictionHead(config)
1107
+
1108
+ def forward(self, sequence_output):
1109
+ prediction_scores = self.predictions(sequence_output)
1110
+ return prediction_scores
1111
+
1112
+
+
1141
+
1142
+ class DeprotForMaskedLM(DeprotPreTrainedModel):
1143
+ _tied_weights_keys = [
1144
+ "cls.predictions.decoder.weight",
1145
+ "cls.predictions.decoder.bias",
1146
+ ]
1147
+
1148
+ def __init__(self, config):
1149
+ super().__init__(config)
1150
+
1151
+ self.deprot = DeprotModel(config)
1152
+ self.cls = DeprotOnlyMLMHead(config)
1153
+
1154
+ # Initialize weights and apply final processing
1155
+ self.post_init()
1156
+
1157
+ def get_output_embeddings(self):
1158
+ return self.cls.predictions.decoder
1159
+
1160
+ def set_output_embeddings(self, new_embeddings):
1161
+ self.cls.predictions.decoder = new_embeddings
1162
+
1163
+ def forward(
1164
+ self,
1165
+ input_ids: Optional[torch.Tensor] = None,
1166
+ ss_input_ids: Optional[torch.Tensor] = None,
1167
+ attention_mask: Optional[torch.Tensor] = None,
1168
+ token_type_ids: Optional[torch.Tensor] = None,
1169
+ position_ids: Optional[torch.Tensor] = None,
1170
+ inputs_embeds: Optional[torch.Tensor] = None,
1171
+ labels: Optional[torch.Tensor] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ return_dict: Optional[bool] = None,
1175
+ ) -> Union[Tuple, MaskedLMOutput]:
1176
+ r"""
1177
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1178
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1179
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1180
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1181
+ """
1182
+
1183
+ return_dict = (
1184
+ return_dict if return_dict is not None else self.config.use_return_dict
1185
+ )
1186
+
1187
+ outputs = self.deprot(
1188
+ input_ids,
1189
+ ss_input_ids=ss_input_ids,
1190
+ attention_mask=attention_mask,
1191
+ token_type_ids=token_type_ids,
1192
+ position_ids=position_ids,
1193
+ inputs_embeds=inputs_embeds,
1194
+ output_attentions=output_attentions,
1195
+ output_hidden_states=output_hidden_states,
1196
+ return_dict=return_dict,
1197
+ )
1198
+
1199
+ sequence_output = outputs[0]
1200
+ prediction_scores = self.cls(sequence_output)
1201
+
1202
+ masked_lm_loss = None
1203
+ if labels is not None:
1204
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1205
+ masked_lm_loss = loss_fct(
1206
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1207
+ )
1208
+
1209
+ if not return_dict:
1210
+ output = (prediction_scores,) + outputs[1:]
1211
+ return (
1212
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1213
+ )
1214
+
1215
+ return MaskedLMOutput(
1216
+ loss=masked_lm_loss,
1217
+ logits=prediction_scores,
1218
+ hidden_states=outputs.hidden_states,
1219
+ attentions=outputs.attentions,
1220
+ )
1221
+
1222
+
1223
+ class DeprotForSequenceClassification(DeprotPreTrainedModel):
1224
+ def __init__(self, config):
1225
+ super().__init__(config)
1226
+
1227
+ num_labels = getattr(config, "num_labels", 2)
1228
+ self.num_labels = num_labels
1229
+ self.scale_hidden = getattr(config, "scale_hidden", 1)
1230
+ self.deprot = DeprotModel(config)
1231
+ self.pooler = ContextPooler(config)
1232
+ output_dim = self.pooler.output_dim * self.scale_hidden
1233
+
1234
+ self.classifier = nn.Linear(output_dim, num_labels)
1235
+ drop_out = getattr(config, "cls_dropout", None)
1236
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1237
+ self.dropout = nn.Dropout(drop_out)
1238
+
1239
+ # Initialize weights and apply final processing
1240
+ self.post_init()
1241
+
1242
+ def get_input_embeddings(self):
1243
+ return self.deprot.get_input_embeddings()
1244
+
1245
+ def set_input_embeddings(self, new_embeddings):
1246
+ self.deprot.set_input_embeddings(new_embeddings)
1247
+
1248
+ def forward(
1249
+ self,
1250
+ input_ids: Optional[torch.Tensor] = None,
1251
+ ss_input_ids: Optional[torch.Tensor] = None,
1252
+ attention_mask: Optional[torch.Tensor] = None,
1253
+ token_type_ids: Optional[torch.Tensor] = None,
1254
+ position_ids: Optional[torch.Tensor] = None,
1255
+ inputs_embeds: Optional[torch.Tensor] = None,
1256
+ labels: Optional[torch.Tensor] = None,
1257
+ output_attentions: Optional[bool] = None,
1258
+ output_hidden_states: Optional[bool] = None,
1259
+ return_dict: Optional[bool] = None,
1260
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1261
+ r"""
1262
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1263
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1264
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1265
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1266
+ """
1267
+ return_dict = (
1268
+ return_dict if return_dict is not None else self.config.use_return_dict
1269
+ )
1270
+
1271
+ outputs = self.deprot(
1272
+ input_ids,
1273
+ ss_input_ids=ss_input_ids,
1274
+ token_type_ids=token_type_ids,
1275
+ attention_mask=attention_mask,
1276
+ position_ids=position_ids,
1277
+ inputs_embeds=inputs_embeds,
1278
+ output_attentions=output_attentions,
1279
+ output_hidden_states=output_hidden_states,
1280
+ return_dict=return_dict,
1281
+ )
1282
+
1283
+ encoder_layer = outputs[0]
1284
+ pooled_output = self.pooler(encoder_layer, attention_mask)
1285
+ pooled_output = self.dropout(pooled_output)
1286
+ logits = self.classifier(pooled_output)
1287
+
1288
+ loss = None
1289
+ if labels is not None:
1290
+ if self.config.problem_type is None:
1291
+ if self.num_labels == 1:
1292
+ # regression task
1293
+ loss_fn = nn.MSELoss()
1294
+ logits = logits.view(-1).to(labels.dtype)
1295
+ loss = loss_fn(logits, labels.view(-1))
1296
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1297
+ label_index = (labels >= 0).nonzero()
1298
+ labels = labels.long()
1299
+ if label_index.size(0) > 0:
1300
+ labeled_logits = torch.gather(
1301
+ logits,
1302
+ 0,
1303
+ label_index.expand(label_index.size(0), logits.size(1)),
1304
+ )
1305
+ labels = torch.gather(labels, 0, label_index.view(-1))
1306
+ loss_fct = CrossEntropyLoss()
1307
+ loss = loss_fct(
1308
+ labeled_logits.view(-1, self.num_labels).float(),
1309
+ labels.view(-1),
1310
+ )
1311
+ else:
1312
+ loss = torch.tensor(0).to(logits)
1313
+ else:
1314
+ log_softmax = nn.LogSoftmax(-1)
1315
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1316
+ elif self.config.problem_type == "regression":
1317
+ loss_fct = MSELoss()
1318
+ if self.num_labels == 1:
1319
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1320
+ else:
1321
+ loss = loss_fct(logits, labels)
1322
+ elif self.config.problem_type == "binary_classification":
1323
+ loss_fct = BCEWithLogitsLoss()
1324
+ loss = loss_fct(logits.squeeze(), labels.squeeze().to(logits.dtype))
1325
+ elif self.config.problem_type == "single_label_classification":
1326
+ loss_fct = CrossEntropyLoss()
1327
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1328
+ elif self.config.problem_type == "multi_label_classification":
1329
+ loss_fct = BCEWithLogitsLoss()
1330
+ loss = loss_fct(logits, labels.to(logits.dtype))
1331
+ if not return_dict:
1332
+ output = (logits,) + outputs[1:]
1333
+ return ((loss,) + output) if loss is not None else output
1334
+
1335
+ return SequenceClassifierOutput(
1336
+ loss=loss,
1337
+ logits=logits,
1338
+ hidden_states=outputs.hidden_states,
1339
+ attentions=outputs.attentions,
1340
+ )
1341
+
1342
+
1343
+ class DeprotForTokenClassification(DeprotPreTrainedModel):
1344
+ def __init__(self, config):
1345
+ super().__init__(config)
1346
+ self.num_labels = config.num_labels
1347
+
1348
+ self.deprot = DeprotModel(config)
1349
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1350
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1351
+
1352
+ # Initialize weights and apply final processing
1353
+ self.post_init()
1354
+
1355
+ def forward(
1356
+ self,
1357
+ input_ids: Optional[torch.Tensor] = None,
1358
+ attention_mask: Optional[torch.Tensor] = None,
1359
+ token_type_ids: Optional[torch.Tensor] = None,
1360
+ position_ids: Optional[torch.Tensor] = None,
1361
+ inputs_embeds: Optional[torch.Tensor] = None,
1362
+ labels: Optional[torch.Tensor] = None,
1363
+ output_attentions: Optional[bool] = None,
1364
+ output_hidden_states: Optional[bool] = None,
1365
+ return_dict: Optional[bool] = None,
1366
+ ) -> Union[Tuple, TokenClassifierOutput]:
1367
+ r"""
1368
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1369
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1370
+ """
1371
+ return_dict = (
1372
+ return_dict if return_dict is not None else self.config.use_return_dict
1373
+ )
1374
+
1375
+ outputs = self.deprot(
1376
+ input_ids,
1377
+ attention_mask=attention_mask,
1378
+ token_type_ids=token_type_ids,
1379
+ position_ids=position_ids,
1380
+ inputs_embeds=inputs_embeds,
1381
+ output_attentions=output_attentions,
1382
+ output_hidden_states=output_hidden_states,
1383
+ return_dict=return_dict,
1384
+ )
1385
+
1386
+ sequence_output = outputs[0]
1387
+
1388
+ sequence_output = self.dropout(sequence_output)
1389
+ logits = self.classifier(sequence_output)
1390
+
1391
+ loss = None
1392
+ if labels is not None:
1393
+ loss_fct = CrossEntropyLoss()
1394
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1395
+
1396
+ if not return_dict:
1397
+ output = (logits,) + outputs[1:]
1398
+ return ((loss,) + output) if loss is not None else output
1399
+
1400
+ return TokenClassifierOutput(
1401
+ loss=loss,
1402
+ logits=logits,
1403
+ hidden_states=outputs.hidden_states,
1404
+ attentions=outputs.attentions,
1405
+ )
1406
+
1407
+
1408
+ DeprotModel.register_for_auto_class("AutoModel")
1409
+ DeprotForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
1410
+ DeprotForSequenceClassification.register_for_auto_class(
1411
+ "AutoModelForSequenceClassification"
1412
+ )
1413
+ DeprotForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
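For reference, a minimal usage sketch (the repo id and tokenizer behaviour below are assumptions, not part of this commit) showing how the auto-class registrations above let the checkpoint be loaded with `trust_remote_code=True`:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "GinnM/Deprot"  # hypothetical placeholder; use the actual Hub id of this repository

# trust_remote_code is required so that modeling_deprot.py (this file) is downloaded and used
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Score a masked residue; depending on the config (ss_vocab_size > 0, aa2ss/ss2aa attention),
# structure tokens may also need to be supplied via ss_input_ids.
sequence = "MKT" + tokenizer.mask_token + "ALVLLLCLAVFSNA"  # assumes a <mask>-style token
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # [batch, seq_len, vocab_size]

mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
probs = logits[mask_positions].softmax(dim=-1)
top = probs.topk(5, dim=-1)
print(top.indices, top.values)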