ccdv committed
Commit cf1b60c · 1 Parent(s): adc1a36

replace -1e4 masks

Files changed (2)
  1. README.md +1 -0
  2. modeling_lsg_distilbert.py +11 -7
README.md CHANGED
@@ -45,6 +45,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
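
The parameters listed above are plain config.json fields, so they can also be overridden when loading a checkpoint. The snippet below is a minimal sketch assuming a Hugging Face Transformers setup with `trust_remote_code=True`; the repo id is illustrative only and not taken from this commit.

```python
from transformers import AutoModel, AutoTokenizer

# Hypothetical repo id, for illustration only; any LSG DistilBert checkpoint
# that ships modeling_lsg_distilbert.py and these config.json fields works the same way.
repo_id = "ccdv/lsg-distilbert-base-uncased-4096"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,   # load the custom modeling code from the repo
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # mask first token (redundant with the first global token)
)
```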
modeling_lsg_distilbert.py CHANGED
@@ -227,7 +227,11 @@ class CausalAttentionProduct(nn.Module):
 
         # Add causal mask
         causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
-        causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
+        causal_mask = torch.tril(
+            torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
+            diagonal=-1
+            )
+        causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
         attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
 
         del attention_mask
@@ -345,7 +349,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -378,7 +382,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -511,7 +515,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = - (1. - mask.clamp(0, 1)) * 1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -576,7 +580,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = -10000 * (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
@@ -871,7 +875,7 @@ class LSGTransformerBlock(nn.Module):
         # Self-Attention
         sa_output = self.attention(
             hidden_states=x,
-            attention_mask=-10000*(1 - attn_mask).unsqueeze(1).unsqueeze(1),
+            attention_mask=torch.finfo(x.dtype).min*(1 - attn_mask).unsqueeze(1).unsqueeze(1),
             head_mask=head_mask,
             output_attentions=output_attentions,
         )
@@ -948,7 +952,7 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
 
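
The recurring change in these hunks swaps the hard-coded -10000 / -1e4 additive mask for torch.finfo(dtype).min, so the mask is as negative as the score dtype allows rather than a fixed constant. Below is a minimal standalone sketch of that masking pattern, independent of the LSG code itself.

```python
import torch

# torch.finfo gives the most negative finite value of the dtype in use,
# which is what replaces the old hard-coded -1e4:
print(torch.finfo(torch.float32).min)  # ~ -3.4028e+38
print(torch.finfo(torch.float16).min)  # -65504.0

# Additive mask in the same convention as the diff:
# attn_mask == 1 keeps a position, attn_mask == 0 masks it out.
scores = torch.tensor([[2.0, 1.0, 0.5]])
attn_mask = torch.tensor([[1.0, 1.0, 0.0]])
additive = torch.finfo(scores.dtype).min * (1 - attn_mask)

probs = torch.softmax(scores + additive, dim=-1)
print(probs)  # masked position gets (numerically) zero probability
```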