ccdv committed on
Commit
24f6c2a
·
1 Parent(s): a5a907c

update for transformers >= 4.29.1

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +23 -3
modeling_lsg_bart.py CHANGED
@@ -643,6 +643,11 @@ class LSGBartEncoderLayer(BartEncoderLayer):
643
  class LSGBartPretrainedModel(BartPretrainedModel):
644
 
645
  config_class = LSGBartConfig
 
 
 
 
 
646
 
647
  def _set_gradient_checkpointing(self, module, value=False):
648
 
@@ -836,8 +841,13 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
836
  if output_hidden_states:
837
  encoder_states = encoder_states + (hidden_states,)
838
  # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
839
- dropout_probability = random.uniform(0, 1)
840
- if self.training and (dropout_probability < self.layerdrop): # skip the layer
 
 
 
 
 
841
  layer_outputs = (None, None)
842
  else:
843
  if self.gradient_checkpointing and self.training:
@@ -879,6 +889,8 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
879
 
880
  class LSGBartModel(LSGBartPretrainedModel, BartModel):
881
 
 
 
882
  def __init__(self, config):
883
 
884
  LSGBartPretrainedModel.__init__(self, config)
@@ -984,7 +996,8 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
984
  class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
985
 
986
  base_model_prefix = "model"
987
- _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
 
988
 
989
  def __init__(self, config):
990
 
@@ -999,6 +1012,8 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditional
999
 
1000
  class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
1001
 
 
 
1002
  def __init__(self, config: LSGBartConfig, **kwargs):
1003
 
1004
  LSGBartPretrainedModel.__init__(self, config, **kwargs)
@@ -1015,6 +1030,8 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceCl
1015
 
1016
  class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
1017
 
 
 
1018
  def __init__(self, config: LSGBartConfig):
1019
 
1020
  LSGBartPretrainedModel.__init__(self, config)
@@ -1030,6 +1047,9 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnsweri
1030
 
1031
  class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
1032
 
 
 
 
1033
  def __init__(self, config: LSGBartConfig):
1034
 
1035
  LSGBartPretrainedModel.__init__(self, config)
 
643
  class LSGBartPretrainedModel(BartPretrainedModel):
644
 
645
  config_class = LSGBartConfig
646
+ base_model_prefix = "model"
647
+ supports_gradient_checkpointing = True
648
+ _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
649
+ _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
650
+ _skip_keys_device_placement = "past_key_values"
651
 
652
  def _set_gradient_checkpointing(self, module, value=False):
653
 
 
841
  if output_hidden_states:
842
  encoder_states = encoder_states + (hidden_states,)
843
  # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
844
+ to_drop = False
845
+ if self.training:
846
+ dropout_probability = torch.rand([])
847
+ if dropout_probability < self.layerdrop: # skip the layer
848
+ to_drop = True
849
+
850
+ if to_drop:
851
  layer_outputs = (None, None)
852
  else:
853
  if self.gradient_checkpointing and self.training:
 
889
 
890
  class LSGBartModel(LSGBartPretrainedModel, BartModel):
891
 
892
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
893
+
894
  def __init__(self, config):
895
 
896
  LSGBartPretrainedModel.__init__(self, config)
 
996
  class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
997
 
998
  base_model_prefix = "model"
999
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1000
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
1001
 
1002
  def __init__(self, config):
1003
 
 
1012
 
1013
  class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
1014
 
1015
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1016
+
1017
  def __init__(self, config: LSGBartConfig, **kwargs):
1018
 
1019
  LSGBartPretrainedModel.__init__(self, config, **kwargs)
 
1030
 
1031
  class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
1032
 
1033
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1034
+
1035
  def __init__(self, config: LSGBartConfig):
1036
 
1037
  LSGBartPretrainedModel.__init__(self, config)
 
1047
 
1048
  class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
1049
 
1050
+ _keys_to_ignore_on_load_missing = ["lm_head.weight"]
1051
+ _tied_weights_keys = ["lm_head.weight"]
1052
+
1053
  def __init__(self, config: LSGBartConfig):
1054
 
1055
  LSGBartPretrainedModel.__init__(self, config)