update for transformers >= 4.29.1
modeling_lsg_mbart.py  CHANGED  (+22, -2)
@@ -829,8 +829,13 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
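The LayerDrop draw now uses torch.rand([]) instead of Python's random.uniform(0, 1), matching MBartEncoder in transformers 4.29.1, so the skip decision follows torch's RNG state. A minimal standalone sketch of the decision logic (the helper name is illustrative, not part of the file):

import torch

def should_drop_layer(training: bool, layerdrop: float) -> bool:
    # Sample only during training; skip the encoder layer when the draw
    # falls below the configured layerdrop probability.
    if not training:
        return False
    dropout_probability = torch.rand([])  # scalar tensor in [0, 1)
    return bool(dropout_probability < layerdrop)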
@@ -874,6 +879,9 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
 
 class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
@@ -982,7 +990,10 @@ class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, MBartForConditionalGeneration):
         r"encoder.version",
         r"decoder.version",
         r"lm_head.weight",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config):
 
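_tied_weights_keys is the attribute recent transformers releases (4.29+) use to declare parameters that share storage, while the extra _keys_to_ignore_on_load_missing entries keep checkpoints loading without missing-key warnings. A hedged sanity check, not part of the commit, that the declared tie actually holds after loading a conditional-generation model with tied word embeddings:

def lm_head_is_tied(model) -> bool:
    # True when lm_head shares storage with the shared input embedding,
    # which is what the tied-weights keys above declare.
    return model.lm_head.weight.data_ptr() == model.model.shared.weight.data_ptr()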
@@ -997,6 +1008,9 @@ class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, MBartForConditionalGeneration):
 
 class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
     def __init__(self, config, **kwargs):
 
         LSGMBartPretrainedModel.__init__(self, config, **kwargs)
@@ -1013,6 +1027,9 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
 class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
@@ -1028,6 +1045,9 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
 class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config):
 
         LSGMBartPretrainedModel.__init__(self, config)
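Loading is unchanged: LSG checkpoints ship this modeling code on the Hub, so with transformers >= 4.29.1 installed they are loaded with trust_remote_code=True (the checkpoint name below is illustrative, not a specific released model):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "ccdv/lsg-mbart-base-4096"  # hypothetical LSG mBART checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)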
|