ccdv committed
Commit 4688f82 · 1 Parent(s): ce82caa

fix for 4.23

Files changed (1)
  1. modeling_lsg_mbart.py +20 -94
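What the fix changes: with transformers 4.23 the MBart encoder calls its learned positional embedding on the input tensor itself rather than on its shape, and only the first two dimensions of that tensor are read. The diff below therefore switches self.embed_positions(input_shape) to self.embed_positions(inputs_embeds), and drops the thin LSGMBartDecoderLayer / LSGMBartClassificationHead / LSGMBartDecoder / LSGMBartDecoderWrapper subclasses in favour of the stock MBart classes (only the encoder carries LSG attention). A minimal sketch of the assumed positional-embedding behaviour, using a hypothetical stand-in class rather than the real transformers implementation:

import torch
import torch.nn as nn

class LearnedPositionalEmbeddingSketch(nn.Embedding):
    # Hypothetical stand-in for the 4.23-style (M)Bart learned positional embedding.
    # Assumption: positions are derived from the argument's first two dimensions only.
    OFFSET = 2  # MBart-style offset reserving the first two embedding slots

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__(num_embeddings + self.OFFSET, embedding_dim)

    def forward(self, inputs, past_key_values_length=0):
        # Accepts [bsz, seq_len] token ids or [bsz, seq_len, d_model] input embeddings.
        bsz, seq_len = inputs.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, device=self.weight.device
        ).expand(bsz, -1)
        return super().forward(positions + self.OFFSET)

embed = LearnedPositionalEmbeddingSketch(1024, 16)
input_ids = torch.randint(0, 100, (2, 8))
inputs_embeds = torch.randn(2, 8, 16)
# Same positions either way, which is why the encoder can now pass inputs_embeds directly.
assert torch.equal(embed(input_ids), embed(inputs_embeds))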
modeling_lsg_mbart.py CHANGED
@@ -15,7 +15,7 @@ AUTO_MAP = {
 
 class LSGMBartConfig(MBartConfig):
     """
-    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.MBartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -57,7 +57,8 @@ class LSGMBartConfig(MBartConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -73,7 +74,7 @@ class LSGMBartConfig(MBartConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -81,6 +82,16 @@ class LSGMBartConfig(MBartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
 
 class BaseSelfAttention(nn.Module):
 
@@ -557,8 +568,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
             attention_mask=attention_mask
         )
 
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         return self.reshape_output(context_layer)
 
     # Split input into global tokens and other tokens
@@ -606,8 +615,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -631,27 +638,6 @@ class LSGMBartEncoderLayer(MBartEncoderLayer):
         )
 
 
-class LSGMBartDecoderLayer(MBartDecoderLayer):
-
-    def __init__(self, config):
-
-        super().__init__(config)
-
-
-class LSGMBartClassificationHead(MBartClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim,
-        inner_dim,
-        num_classes,
-        pooler_dropout,
-    ):
-
-        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
-
-
 class LSGMBartPretrainedModel(MBartPreTrainedModel):
 
     config_class = LSGMBartConfig
@@ -659,7 +645,7 @@ class LSGMBartPretrainedModel(MBartPreTrainedModel):
     supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartDecoder, LSGMBartEncoder)):
+        if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartEncoder)):
             module.gradient_checkpointing = value
 
 
@@ -674,7 +660,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
 
     def __init__(self, config, embed_tokens=None):
 
-        super().__init__(config)
+        LSGMBartPretrainedModel.__init__(self, config)
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
 
@@ -811,7 +797,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(inputs_embeds)
        hidden_states = inputs_embeds + embed_pos
 
         # Add global tokens
@@ -884,44 +870,6 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
         )
 
 
-class LSGMBartDecoder(LSGMBartPretrainedModel, MBartDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
-    Args:
-        config: BartConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config, embed_tokens=None):
-
-        LSGMBartPretrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = MBartLearnedPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-        )
-        self.layers = nn.ModuleList([LSGMBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-        self.layer_norm = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
     def __init__(self, config):
@@ -935,7 +883,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
         self.num_global_tokens = config.num_global_tokens
 
         self.encoder = LSGMBartEncoder(config, self.shared)
-        self.decoder = LSGMBartDecoder(config, self.shared)
+        self.decoder = MBartDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1051,7 +999,7 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenc
 
         LSGMBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGMBartModel(config)
-        self.classification_head = LSGMBartClassificationHead(
+        self.classification_head = MBartClassificationHead(
             config.d_model,
             config.d_model,
             config.num_labels,
@@ -1075,35 +1023,13 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnsw
 
         self.model._init_weights(self.qa_outputs)
 
-
-class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
-    """
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.decoder = LSGMBartDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
-
 
 class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
     def __init__(self, config):
 
-        config = copy.deepcopy(config)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
         LSGMBartPretrainedModel.__init__(self, config)
-        self.model = LSGMBartDecoderWrapper(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        MBartForCausalLM.__init__(self, config)
 
 
 def str_to_class(classname):
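For completeness, a checkpoint that ships this modeling file is exposed to the Auto classes through the AUTO_MAP entry near the top of the file, so it is typically loaded with remote code enabled. A usage sketch under two assumptions: transformers >= 4.23 (which this commit targets) and an AUTO_MAP entry for the seq2seq class; the repository id below is a placeholder, not a real checkpoint name:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder repository id; substitute the actual LSG-MBart checkpoint hosting modeling_lsg_mbart.py.
checkpoint = "user/lsg-mbart-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# trust_remote_code=True lets transformers import modeling_lsg_mbart.py from the repository.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)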