ccdv committed
Commit 58e3ed1 · 1 Parent(s): ffda167

update for transformers >= 4.29.1

Files changed (1)
1. modeling_lsg_mbart.py +22 -2
modeling_lsg_mbart.py CHANGED
@@ -829,8 +829,13 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -874,6 +879,9 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
 
 class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
@@ -982,7 +990,10 @@ class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, MBartForConditio
         r"encoder.version",
         r"decoder.version",
         r"lm_head.weight",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config):
 
@@ -997,6 +1008,9 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenc
 
 class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
     def __init__(self, config, **kwargs):
 
         LSGMBartPretrainedModel.__init__(self, config, **kwargs)
@@ -1013,6 +1027,9 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnsw
 
 class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
@@ -1028,6 +1045,9 @@ class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
 class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
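The diff makes two kinds of changes, both tracking updates that landed upstream in transformers 4.29. The first replaces the encoder's LayerDrop draw: random.uniform(0, 1) becomes torch.rand([]), so the skip decision comes from torch's RNG (and is therefore seedable via torch.manual_seed) and the draw only happens in training mode. The sketch below isolates that pattern; ToyEncoder and its sizes are hypothetical names for illustration, not part of the actual file.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical encoder isolating the LayerDrop pattern from the diff."""

    def __init__(self, num_layers: int = 4, hidden_size: int = 8, layerdrop: float = 0.5):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.layerdrop = layerdrop

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # transformers >= 4.29 pattern: draw from torch.rand([]) instead of
            # random.uniform(0, 1); the draw is only made in training mode.
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if not to_drop:
                hidden_states = layer(hidden_states)
        return hidden_states


torch.manual_seed(0)  # layer skipping is now reproducible through torch's RNG
encoder = ToyEncoder().train()
print(encoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

The second change declares _tied_weights_keys on LSGMBartModel and every head class. Starting with transformers 4.29, this attribute tells the library which parameters are tied to the shared embeddings; the same keys are also listed in _keys_to_ignore_on_load_missing so that checkpoints storing a shared weight only once do not trigger missing-key warnings at load time. A minimal sketch of the pattern, assuming a hypothetical PreTrainedModel subclass (ToyConfig and ToyForCausalLM are illustrative, not from the file):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"  # hypothetical model type, illustration only

    def __init__(self, vocab_size=10, hidden_size=8, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class ToyForCausalLM(PreTrainedModel):
    config_class = ToyConfig
    # transformers >= 4.29: keys whose weights are tied to the input embeddings
    _tied_weights_keys = ["lm_head.weight"]
    # tied keys may be absent from a checkpoint, so suppress missing-key warnings
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()  # initializes weights and ties lm_head to embed_tokens

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids):
        return self.lm_head(self.embed_tokens(input_ids))


model = ToyForCausalLM(ToyConfig())
assert model.lm_head.weight is model.embed_tokens.weight  # tied by post_init()

Note how the keys are scoped per class: LSGMBartModel lists "encoder.embed_tokens.weight" directly, while the head classes that wrap it prefix the same parameters with "model.", matching where the tied weights live in each state dict.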