ccdv committed on
Commit
f7c6158
·
1 Parent(s): 1483ba7

replace -1e4 masks

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +12 -14
modeling_lsg_bart.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from transformers.models.bart.modeling_bart import *
4
  from transformers.models.bart.modeling_bart import _expand_mask
5
  import torch.nn as nn
6
- from torch.nn import BCEWithLogitsLoss
7
  import sys
8
 
9
  AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {
16
 
17
  class LSGBartConfig(BartConfig):
18
  """
19
- This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
20
  documentation alongside usage examples.
21
  """
22
 
@@ -266,8 +265,8 @@ class LSGAttentionProduct(nn.Module):
266
  s = (size - step) // 2
267
 
268
  # Pad before block reshaping
269
- if is_attn_mask:
270
- pad_value = -10000
271
  hidden_states = hidden_states.transpose(-1, -2)
272
  else:
273
  pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
296
 
297
  # Pad before block reshaping
298
  if is_attn_mask:
299
- pad_value = -10000
300
  hidden_states = hidden_states.transpose(-1, -2)
301
  else:
302
  pad_value = 0
@@ -425,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
425
  keys = keys.sum(dim=-2) / (mask + 1e-6)
426
  values = values.sum(dim=-2) / (mask + 1e-6)
427
 
428
- mask = - (1. - mask.clamp(0, 1)) * 1e4
429
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
430
 
431
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,8 +489,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
490
  keys /= mask + 1e-8
491
  values /= mask + 1e-8
492
 
493
- mask = -10000 * (1. - mask.clamp(0, 1))
494
-
495
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
496
 
497
  def lsh_round(self, keys, values, mask, output_size):
@@ -739,7 +737,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
739
  n, t = inputs_.size()[:2]
740
 
741
  if attention_mask is None:
742
- attention_mask = torch.ones(n, t, device=inputs_.device)
743
  if self.mask_first_token:
744
  attention_mask[:, 0] = 0
745
 
@@ -891,7 +889,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
891
  )
892
 
893
 
894
- class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
895
  """
896
  Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
897
  Args:
@@ -1032,7 +1030,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
1032
  )
1033
 
1034
 
1035
- class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
1036
 
1037
  base_model_prefix = "model"
1038
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1046,7 @@ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretr
1048
  self.post_init()
1049
 
1050
 
1051
- class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
1052
 
1053
  def __init__(self, config: LSGBartConfig, **kwargs):
1054
 
@@ -1064,7 +1062,7 @@ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPre
1064
  self.model._init_weights(self.classification_head.out_proj)
1065
 
1066
 
1067
- class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
1068
 
1069
  def __init__(self, config: LSGBartConfig):
1070
 
@@ -1093,7 +1091,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
1093
  return self.decoder(*args, **kwargs)
1094
 
1095
 
1096
- class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
1097
 
1098
  def __init__(self, config: LSGBartConfig):
1099
 
 
3
  from transformers.models.bart.modeling_bart import *
4
  from transformers.models.bart.modeling_bart import _expand_mask
5
  import torch.nn as nn
 
6
  import sys
7
 
8
  AUTO_MAP = {
 
15
 
16
  class LSGBartConfig(BartConfig):
17
  """
18
+ This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
19
  documentation alongside usage examples.
20
  """
21
 
 
265
  s = (size - step) // 2
266
 
267
  # Pad before block reshaping
268
+ if is_attn_mask:
269
+ pad_value = torch.finfo(hidden_states.dtype).min
270
  hidden_states = hidden_states.transpose(-1, -2)
271
  else:
272
  pad_value = 0
 
295
 
296
  # Pad before block reshaping
297
  if is_attn_mask:
298
+ pad_value = torch.finfo(hidden_states.dtype).min
299
  hidden_states = hidden_states.transpose(-1, -2)
300
  else:
301
  pad_value = 0
 
424
  keys = keys.sum(dim=-2) / (mask + 1e-6)
425
  values = values.sum(dim=-2) / (mask + 1e-6)
426
 
427
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
428
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
429
 
430
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
489
  keys /= mask + 1e-8
490
  values /= mask + 1e-8
491
 
492
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
493
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
494
 
495
  def lsh_round(self, keys, values, mask, output_size):
 
737
  n, t = inputs_.size()[:2]
738
 
739
  if attention_mask is None:
740
+ attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
741
  if self.mask_first_token:
742
  attention_mask[:, 0] = 0
743
 
 
889
  )
890
 
891
 
892
+ class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
893
  """
894
  Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
895
  Args:
 
1030
  )
1031
 
1032
 
1033
+ class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
1034
 
1035
  base_model_prefix = "model"
1036
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
 
1046
  self.post_init()
1047
 
1048
 
1049
+ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
1050
 
1051
  def __init__(self, config: LSGBartConfig, **kwargs):
1052
 
 
1062
  self.model._init_weights(self.classification_head.out_proj)
1063
 
1064
 
1065
+ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
1066
 
1067
  def __init__(self, config: LSGBartConfig):
1068
 
 
1091
  return self.decoder(*args, **kwargs)
1092
 
1093
 
1094
+ class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
1095
 
1096
  def __init__(self, config: LSGBartConfig):
1097