Jackmin108 commited on
Commit
76fc218
·
1 Parent(s): c1736a8

feat: adapter masking finished

Browse files

Signed-off-by: Meow <[email protected]>

Files changed (5) hide show
  1. block.py +11 -11
  2. embedding.py +20 -39
  3. mha.py +28 -37
  4. mlp.py +25 -19
  5. modeling_xlm_roberta.py +23 -28
block.py CHANGED
@@ -233,17 +233,17 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- task_type = mixer_kwargs.get('task_type')
237
- if task_type:
238
- if isinstance(task_type, tuple):
239
- assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
240
- split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
241
- split = mixer_kwargs['cu_seqlens'][split_index]
242
- mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
243
- else:
244
- mlp_out = self.mlp(hidden_states, task_type=task_type)
245
- else:
246
- mlp_out = self.mlp(hidden_states)
247
  if self.return_residual: # mlp out is actually a pair here
248
  mlp_out, hidden_states = mlp_out
249
  if not self.fused_dropout_add_ln:
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states, cu_adapter_mask=mixer_kwargs.get('cu_adapter_mask'))
237
+ # if cu_adapter_mask:
238
+ # if isinstance(task_type, tuple):
239
+ # assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
240
+ # split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
241
+ # split = mixer_kwargs['cu_seqlens'][split_index]
242
+ # mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
243
+ # else:
244
+ # mlp_out = self.mlp(hidden_states, task_type=task_type)
245
+ # else:
246
+ # mlp_out = self.mlp(hidden_states)
247
  if self.return_residual: # mlp out is actually a pair here
248
  mlp_out, hidden_states = mlp_out
249
  if not self.fused_dropout_add_ln:
embedding.py CHANGED
@@ -40,40 +40,25 @@ class XLMRobertaEmbeddings(nn.Module):
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
- def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None, adapter_mask=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
- if isinstance(task_type, tuple):
51
- assert input_ids.shape[0] % 9 == 0
52
- split = int(input_ids.shape[0] / 9)
53
- tensor1 = input_ids[:split, :]
54
- tensor2 = input_ids[split:, :]
55
- emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
56
- emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
57
- embeddings = torch.cat((emb1, emb2), dim=0)
58
-
59
  unique_tasks = torch.unique(adapter_mask).tolist()
60
- torch_dtype = next(self.word_embeddings.parameters()).dtype
61
- embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim, dtype=torch_dtype).to(input_ids.device)
62
- for task in unique_tasks:
63
- indices = (adapter_mask == task).nonzero(as_tuple=True)[0]
64
- inp = input_ids[indices]
65
- lora_kwargs = {'task_type': task} if task is not None else {}
66
- emb = self.word_embeddings(inp, **lora_kwargs)
67
- embeddings[indices] = emb
68
-
69
- exit(0)
70
  else:
71
- unique_task = torch.unique(adapter_mask)[0]
72
- task1_indices = (adapter_mask == unique_task).nonzero(as_tuple=True)[0]
73
- input1 = input_ids[task1_indices]
74
- lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
75
- embeddings = self.word_embeddings(input1, **lora_kwargs)
76
-
77
 
78
  if self.max_position_embeddings > 0:
79
  if position_ids is None:
@@ -84,19 +69,15 @@ class XLMRobertaEmbeddings(nn.Module):
84
  if self.type_vocab_size > 0:
85
  if token_type_ids is None:
86
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
87
- if isinstance(task_type, tuple):
88
- assert embeddings.shape[0] % 9 == 0
89
- split = int(embeddings.shape[0] / 9)
90
- emb1 = embeddings[:split, :, :]
91
- emb2 = embeddings[split:, :, :]
92
- token_type_embs1 = self.token_type_embeddings(token_type_ids, task_type=task_type[0])
93
- token_type_embs2 = self.token_type_embeddings(token_type_ids, task_type=task_type[1])
94
- emb1 = emb1 + token_type_embs1
95
- emb2 = emb2 + token_type_embs2
96
- embeddings = torch.cat((emb1, emb2), dim=0)
97
  else:
98
- unique_task = torch.unique(adapter_mask)[0]
99
- lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
100
- token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
101
  embeddings = embeddings + token_type_embeddings
 
102
  return embeddings
 
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ if adapter_mask is not None:
 
 
 
 
 
 
 
 
51
  unique_tasks = torch.unique(adapter_mask).tolist()
52
+ embedding_dtype = next(self.word_embeddings.parameters()).dtype
53
+ embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
54
+ dtype=embedding_dtype).to(input_ids.device)
55
+ for task_id in unique_tasks:
56
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
57
+ task_input_ids = input_ids[task_indices]
58
+ task_embeddings = self.word_embeddings(task_input_ids, task_type=task_id)
59
+ embeddings[task_indices] = task_embeddings
 
 
60
  else:
61
+ embeddings = self.word_embeddings(input_ids)
 
 
 
 
 
62
 
63
  if self.max_position_embeddings > 0:
64
  if position_ids is None:
 
69
  if self.type_vocab_size > 0:
70
  if token_type_ids is None:
71
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
72
+
73
+ if adapter_mask is not None:
74
+ unique_tasks = torch.unique(adapter_mask).tolist()
75
+ for task_id in unique_tasks:
76
+ task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_type=task_id)
77
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
78
+ embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
 
 
 
79
  else:
80
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
 
81
  embeddings = embeddings + token_type_embeddings
82
+
83
  return embeddings
mha.py CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
- task_type=None,
594
  **kwargs,
595
  ):
596
  """
@@ -643,39 +643,27 @@ class MHA(nn.Module):
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
646
- lora_kwargs = {}
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
 
650
- split = None
651
- if isinstance(task_type, tuple):
652
- assert cu_seqlens.shape[0] % 9 == 1
653
- split_index = int((cu_seqlens.shape[0] - 1) / 9)
654
- split = cu_seqlens[split_index]
655
-
656
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
657
-
658
- if not self.return_residual:
659
- if isinstance(task_type, tuple):
660
- tensor1 = x[:split, :]
661
- tensor2 = x[split:, :]
662
- qkv1 = self.Wqkv(tensor1, task_type=task_type[0])
663
- qkv2 = self.Wqkv(tensor2, task_type=task_type[1])
664
- qkv = torch.cat((qkv1, qkv2), dim=0)
665
- else:
666
- qkv = self.Wqkv(x, **lora_kwargs)
667
  else:
668
- if isinstance(task_type, tuple):
669
- tensor1 = x[:split, :]
670
- tensor2 = x[split:, :]
671
- qkv1, tensor1 = self.Wqkv(tensor1, task_type=task_type[0], residual=True)
672
- qkv2, tensor2 = self.Wqkv(tensor2, task_type=task_type[1], residual=True)
673
- qkv = torch.cat((qkv1, qkv2), dim=0)
674
- x = torch.cat((tensor1, tensor2), dim=0)
675
  else:
676
- if lora_kwargs:
677
- lora_kwargs['residual'] = True
678
- qkv, x = self.Wqkv(x, **lora_kwargs)
679
 
680
  if self.dwconv:
681
  qkv = rearrange(
@@ -762,14 +750,17 @@ class MHA(nn.Module):
762
  else:
763
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
764
 
765
- lora_kwargs.pop('residual', None)
766
  inp = rearrange(context, "... h d -> ... (h d)")
767
- if isinstance(task_type, tuple):
768
- tensor1 = inp[:split, :]
769
- tensor2 = inp[split:, :]
770
- out1 = self.out_proj(tensor1, task_type=task_type[0])
771
- out2 = self.out_proj(tensor2, task_type=task_type[1])
772
- out = torch.cat((out1, out2), dim=0)
 
 
 
 
773
  else:
774
- out = self.out_proj(inp, **lora_kwargs)
775
  return out if not self.return_residual else (out, x)
 
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
+ cu_adapter_mask=None,
594
  **kwargs,
595
  ):
596
  """
 
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
 
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
648
 
649
+ if cu_adapter_mask is not None:
650
+ unique_tasks = torch.unique(cu_adapter_mask).tolist()
651
+ qkv_dtype = next(self.Wqkv.parameters()).dtype
652
+ qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
653
+ dtype=qkv_dtype).to(x.device)
654
+ for task_id in unique_tasks:
655
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
656
+ task_tensor = x[task_indices]
657
+ if not self.return_residual:
658
+ task_qkv = self.Wqkv(task_tensor, task_type=task_id)
659
+ else:
660
+ task_qkv, _ = self.Wqkv(task_tensor, task_type=task_id, residual=True)
661
+ qkv[task_indices] = task_qkv
 
 
 
 
662
  else:
663
+ if not self.return_residual:
664
+ qkv = self.Wqkv(x)
 
 
 
 
 
665
  else:
666
+ qkv, x = self.Wqkv(x)
 
 
667
 
668
  if self.dwconv:
669
  qkv = rearrange(
 
750
  else:
751
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
752
 
 
753
  inp = rearrange(context, "... h d -> ... (h d)")
754
+ if cu_adapter_mask is not None:
755
+ unique_tasks = torch.unique(cu_adapter_mask).tolist()
756
+ out_dtype = next(self.out_proj.parameters()).dtype
757
+ out = torch.empty(inp.shape[0], self.out_proj.out_features,
758
+ dtype=out_dtype).to(inp.device)
759
+ for task_id in unique_tasks:
760
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
761
+ task_tensor = inp[task_indices]
762
+ task_out = self.out_proj(task_tensor, task_type=task_id)
763
+ out[task_indices] = task_out
764
  else:
765
+ out = self.out_proj(inp)
766
  return out if not self.return_residual else (out, x)
mlp.py CHANGED
@@ -47,30 +47,36 @@ class Mlp(nn.Module):
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
- def forward(self, x, task_type=None, split=None):
51
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
52
- if split:
53
- assert isinstance(task_type, tuple)
54
- tensor1 = x[:split, :]
55
- tensor2 = x[split:, :]
56
- y1 = self.fc1(tensor1, task_type=task_type[0])
57
- y2 = self.fc1(tensor2, task_type=task_type[1])
58
- y = torch.cat((y1, y2), dim=0)
 
 
59
  else:
60
- y = self.fc1(x, **lora_kwargs)
61
 
62
  y = self.activation(y)
63
 
64
- if split:
65
- assert isinstance(task_type, tuple)
66
- tensor1 = y[:split, :]
67
- tensor2 = y[split:, :]
68
- y1 = self.fc2(tensor1, task_type=task_type[0])
69
- y2 = self.fc2(tensor2, task_type=task_type[1])
70
- y = torch.cat((y1, y2), dim=0)
 
 
 
71
  else:
72
- y = self.fc2(y, **lora_kwargs)
73
- return y if not self.return_residual else (y, x)
 
74
 
75
 
76
  class ParallelMLP(nn.Module):
 
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
+ def forward(self, x, cu_adapter_mask=None):
51
+ if cu_adapter_mask is not None:
52
+ unique_tasks = torch.unique(cu_adapter_mask).tolist()
53
+ fc1_dtype = next(self.fc1.parameters()).dtype
54
+ y = torch.empty(x.shape[0], self.fc1.out_features,
55
+ dtype=fc1_dtype).to(x.device)
56
+ for task_id in unique_tasks:
57
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
+ task_tensor = x[task_indices]
59
+ task_y = self.fc1(task_tensor, task_type=task_id)
60
+ y[task_indices] = task_y
61
  else:
62
+ y = self.fc1(x)
63
 
64
  y = self.activation(y)
65
 
66
+ if cu_adapter_mask is not None:
67
+ unique_tasks = torch.unique(cu_adapter_mask).tolist()
68
+ fc2_dtype = next(self.fc2.parameters()).dtype
69
+ out = torch.empty(y.shape[0], self.fc2.out_features,
70
+ dtype=fc2_dtype).to(y.device)
71
+ for task_id in unique_tasks:
72
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
+ task_tensor = y[task_indices]
74
+ task_out = self.fc2(task_tensor, task_type=task_id)
75
+ out[task_indices] = task_out
76
  else:
77
+ out = self.fc1(y)
78
+
79
+ return out if not self.return_residual else (out, x)
80
 
81
 
82
  class ParallelMLP(nn.Module):
modeling_xlm_roberta.py CHANGED
@@ -204,18 +204,16 @@ class XLMRobertaEncoder(nn.Module):
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
- def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None, adapter_mask=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
- mixer_kwargs = (
214
- {"key_padding_mask": key_padding_mask.bool()}
215
- if key_padding_mask is not None
216
- else None
217
- )
218
- mixer_kwargs['task_type'] = task_type
219
  for layer in self.layers:
220
  if self._grad_checkpointing:
221
  hidden_states = torch.utils.checkpoint.checkpoint(
@@ -233,7 +231,8 @@ class XLMRobertaEncoder(nn.Module):
233
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
234
  hidden_states, key_padding_mask, adapter_mask
235
  )
236
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type, "cu_adapter_mask": cu_adapter_mask}
 
237
  if subset_mask is None:
238
  for layer in self.layers:
239
  if self._grad_checkpointing:
@@ -310,24 +309,22 @@ class XLMRobertaPooler(nn.Module):
310
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
311
  self.activation = nn.Tanh()
312
 
313
- def forward(self, hidden_states, pool=True, task_type=None):
314
  # We "pool" the model by simply taking the hidden state corresponding
315
  # to the first token.
316
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
317
-
318
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
319
-
320
- if isinstance(task_type, tuple):
321
- assert first_token_tensor.shape[0] % 9 == 0
322
- split = int(first_token_tensor.shape[0] / 9)
323
- tensor1 = first_token_tensor[:split, :]
324
- tensor2 = first_token_tensor[split:, :]
325
- pooled_out1 = self.dense(tensor1, task_type=task_type[0])
326
- pooled_out2 = self.dense(tensor2, task_type=task_type[0])
327
- pooled_output = torch.cat((pooled_out1, pooled_out2), dim=0)
 
328
  else:
329
- pooled_output = self.dense(first_token_tensor, **lora_kwargs)
330
-
331
  pooled_output = self.activation(pooled_output)
332
  return pooled_output
333
 
@@ -440,7 +437,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
440
  "gelu_fast",
441
  "gelu_pytorch_tanh",
442
  ]
443
-
444
  self.embeddings = XLMRobertaEmbeddings(
445
  config.hidden_size,
446
  config.vocab_size,
@@ -648,7 +644,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
648
  layer output for these tokens.
649
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
650
  """
651
- task_type = kwargs.pop('task_type', None)
652
  adapter_mask = kwargs.pop('adapter_mask', None)
653
  if kwargs:
654
  for key, value in kwargs.items():
@@ -663,7 +658,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
663
  )
664
 
665
  hidden_states = self.embeddings(
666
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type, adapter_mask=adapter_mask
667
  )
668
  # TD [2022-12:18]: Don't need to force residual in fp32
669
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -687,12 +682,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
687
  subset_mask = None
688
 
689
  sequence_output = self.encoder(
690
- hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type, adapter_mask=adapter_mask
691
  )
692
 
693
  if masked_tokens_mask is None:
694
  pooled_output = (
695
- self.pooler(sequence_output, task_type=task_type) if self.pooler is not None else None
696
  )
697
  else:
698
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -706,7 +701,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
706
  pool_input = sequence_output[first_col_mask[subset_mask]]
707
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
708
  pooled_output = (
709
- self.pooler(pool_input, pool=False, task_type=task_type) if self.pooler is not None else None
710
  )
711
 
712
  if not return_dict:
 
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
+ mixer_kwargs = {'adapter_mask': adapter_mask}
214
+ if key_padding_mask is not None:
215
+ mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
216
+
 
 
217
  for layer in self.layers:
218
  if self._grad_checkpointing:
219
  hidden_states = torch.utils.checkpoint.checkpoint(
 
231
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
232
  hidden_states, key_padding_mask, adapter_mask
233
  )
234
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "cu_adapter_mask": cu_adapter_mask}
235
+
236
  if subset_mask is None:
237
  for layer in self.layers:
238
  if self._grad_checkpointing:
 
309
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
310
  self.activation = nn.Tanh()
311
 
312
+ def forward(self, hidden_states, pool=True, adapter_mask=None):
313
  # We "pool" the model by simply taking the hidden state corresponding
314
  # to the first token.
 
 
315
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
316
+ if adapter_mask is not None:
317
+ unique_tasks = torch.unique(adapter_mask).tolist()
318
+ pool_dtype = next(self.dense.parameters()).dtype
319
+ pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
320
+ dtype=pool_dtype).to(first_token_tensor.device)
321
+ for task_id in unique_tasks:
322
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
323
+ task_first_token_tensor = first_token_tensor[task_indices]
324
+ task_pooled_output = self.dense(task_first_token_tensor, task_type=task_id)
325
+ pooled_output[task_indices] = task_pooled_output
326
  else:
327
+ pooled_output = self.dense(first_token_tensor)
 
328
  pooled_output = self.activation(pooled_output)
329
  return pooled_output
330
 
 
437
  "gelu_fast",
438
  "gelu_pytorch_tanh",
439
  ]
 
440
  self.embeddings = XLMRobertaEmbeddings(
441
  config.hidden_size,
442
  config.vocab_size,
 
644
  layer output for these tokens.
645
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
646
  """
 
647
  adapter_mask = kwargs.pop('adapter_mask', None)
648
  if kwargs:
649
  for key, value in kwargs.items():
 
658
  )
659
 
660
  hidden_states = self.embeddings(
661
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, adapter_mask=adapter_mask
662
  )
663
  # TD [2022-12:18]: Don't need to force residual in fp32
664
  # BERT puts embedding LayerNorm before embedding dropout.
 
682
  subset_mask = None
683
 
684
  sequence_output = self.encoder(
685
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, adapter_mask=adapter_mask
686
  )
687
 
688
  if masked_tokens_mask is None:
689
  pooled_output = (
690
+ self.pooler(sequence_output, adapter_mask=adapter_mask) if self.pooler is not None else None
691
  )
692
  else:
693
  # TD [2022-03-01]: the indexing here is very tricky.
 
701
  pool_input = sequence_output[first_col_mask[subset_mask]]
702
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
703
  pooled_output = (
704
+ self.pooler(pool_input, pool=False, adapter_mask=adapter_mask) if self.pooler is not None else None
705
  )
706
 
707
  if not return_dict: