replace -1e4 masks
- README.md +1 -0
- modeling_lsg_distilbert.py +11 -7
README.md CHANGED
```diff
@@ -45,6 +45,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
```
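These parameters are plain config attributes, so they can be overridden at load time. A rough usage sketch (the checkpoint name is illustrative, and LSG models ship custom modeling code, hence `trust_remote_code=True`):

```python
from transformers import AutoModel, AutoTokenizer

# Checkpoint name is illustrative; substitute the LSG DistilBERT checkpoint you actually use.
model = AutoModel.from_pretrained(
    "ccdv/lsg-distilbert-base-uncased-4096",
    trust_remote_code=True,   # LSG attention lives in custom code on the Hub
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # first token is redundant with the first global token
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096")
```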
modeling_lsg_distilbert.py CHANGED
```diff
@@ -227,7 +227,11 @@ class CausalAttentionProduct(nn.Module):
 
             # Add causal mask
             causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
-            causal_mask = torch.tril(
+            causal_mask = torch.tril(
+                torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
+                diagonal=-1
+                )
+            causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
             attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
 
             del attention_mask
```
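The rewritten block builds the causal mask additively: disallowed (future) positions get the most negative value representable in the score dtype instead of a hard-coded -1e4, and softmax then drives their weights to zero. A standalone sketch of the same pattern (shapes and names are illustrative, not the LSG internals):

```python
import torch

def causal_additive_mask(size: int, dtype: torch.dtype, device=None) -> torch.Tensor:
    # Strictly-upper-triangular (future) positions get the dtype minimum, the rest 0,
    # so the result can simply be added to raw attention scores.
    full = torch.full((size, size), torch.finfo(dtype).min, dtype=dtype, device=device)
    return torch.triu(full, diagonal=1)

scores = torch.randn(2, 4, 5, 5, dtype=torch.float16)      # (batch, heads, queries, keys)
probs = (scores + causal_additive_mask(5, scores.dtype)).softmax(dim=-1)
# each query now puts ~0 weight on keys to its right
```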
```diff
@@ -345,7 +349,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
```
```diff
@@ -378,7 +382,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
```
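In both of these spots the additive attention mask is padded up to a multiple of the block size before being reshaped into blocks; the padded key positions must themselves look masked, so the fill value becomes the dtype minimum rather than a fixed constant. A minimal illustration of that padding step, with made-up shapes:

```python
import torch
import torch.nn.functional as F

block_size = 128
additive_mask = torch.zeros(2, 1, 1, 250, dtype=torch.float16)    # 0 = keep, dtype-min = masked
pad = (-additive_mask.size(-1)) % block_size                      # 6 extra positions to reach 256

# Newly created positions correspond to no real token, so pad with the dtype minimum.
padded = F.pad(additive_mask, (0, pad), value=torch.finfo(additive_mask.dtype).min)
assert padded.size(-1) % block_size == 0
```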
```diff
@@ -511,7 +515,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
```
```diff
@@ -576,7 +580,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
```
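In these two hunks `mask` appears to hold per-block token sums (it is used as a denominator just above), so it is first clamped to {0, 1} and only then converted into additive form. The same conversion in isolation, on a made-up tensor:

```python
import torch

# Assume 'counts' plays the role of the mask above: how many real tokens landed in each
# pooled block (0 means the block is entirely padding) -- an illustrative stand-in.
counts = torch.tensor([[0., 1., 3., 2., 0.]], dtype=torch.float16)

# clamp(0, 1) turns counts into a keep/drop indicator; the usual additive conversion follows:
# kept blocks -> 0, empty blocks -> dtype minimum.
additive = (1. - counts.clamp(0, 1)) * torch.finfo(counts.dtype).min
```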
```diff
@@ -871,7 +875,7 @@ class LSGTransformerBlock(nn.Module):
         # Self-Attention
         sa_output = self.attention(
             hidden_states=x,
-            attention_mask
+            attention_mask=torch.finfo(x.dtype).min*(1 - attn_mask).unsqueeze(1).unsqueeze(1),
             head_mask=head_mask,
             output_attentions=output_attentions,
         )
```
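This hunk hands the attention module a ready-made additive mask: the 0/1 padding mask is inverted, scaled by the dtype minimum, and reshaped to broadcast over heads and query positions. A small sketch of that conversion with illustrative shapes:

```python
import torch

x = torch.randn(2, 6, 8, dtype=torch.float16)                     # (batch, seq, hidden), illustrative
attn_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1, 1]], dtype=x.dtype)      # 1 = real token, 0 = padding

# (batch, seq) -> (batch, 1, 1, seq): broadcasts over heads and query positions.
additive_mask = torch.finfo(x.dtype).min * (1 - attn_mask).unsqueeze(1).unsqueeze(1)

scores = torch.randn(2, 4, 6, 6, dtype=x.dtype)                    # (batch, heads, queries, keys)
probs = (scores + additive_mask).softmax(dim=-1)                   # padded keys get ~0 attention
```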
```diff
@@ -948,7 +952,7 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
 
```
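The thread running through these hunks is that the hard-coded -1e4 fill values become `torch.finfo(dtype).min`, presumably so the mask is always the most negative value the score dtype can represent, whatever precision the model runs in. For reference:

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, torch.finfo(dtype).min)
# torch.float32  -> about -3.4e38
# torch.float16  -> -65504.0
# torch.bfloat16 -> about -3.39e38
```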