fix for 4.23

modeling_lsg_mbart.py  (+20 -94)
@@ -15,7 +15,7 @@ AUTO_MAP = {
 
 class LSGMBartConfig(MBartConfig):
     """
-    This class overrides :class:`~transformers.…
+    This class overrides :class:`~transformers.MBartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -57,7 +57,8 @@ class LSGMBartConfig(MBartConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'],…
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -73,7 +74,7 @@ class LSGMBartConfig(MBartConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not…
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -81,6 +82,16 @@ class LSGMBartConfig(MBartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
 
 class BaseSelfAttention(nn.Module):
 
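The guards above coerce incompatible settings instead of raising. A minimal sketch of the resulting behaviour, assuming modeling_lsg_mbart.py is importable locally and that the constructor accepts these options as keyword arguments (as the guard code implies):

from modeling_lsg_mbart import LSGMBartConfig  # assumes the file sits on the Python path

config = LSGMBartConfig(
    sparsity_type="foo",      # not in the allowed list -> warning, reset to None
    num_global_tokens=1024,   # > 512 -> warning, clamped to 512
    mask_first_token=True,
    pool_with_global=False,   # incompatible with mask_first_token -> forced to True
)
print(config.sparsity_type, config.num_global_tokens, config.pool_with_global)
# Expected, per the guards above: None 512 True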
@@ -557,8 +568,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
                 attention_mask=attention_mask
                 )
 
-            if head_mask is not None:
-                context_layer = context_layer * head_mask[:, :, :1, :1]
             return self.reshape_output(context_layer)
 
         # Split input into global tokens and other tokens
@@ -606,8 +615,6 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -631,27 +638,6 @@ class LSGMBartEncoderLayer(MBartEncoderLayer):
         )
 
 
-class LSGMBartDecoderLayer(MBartDecoderLayer):
-
-    def __init__(self, config):
-
-        super().__init__(config)
-
-
-class LSGMBartClassificationHead(MBartClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim,
-        inner_dim,
-        num_classes,
-        pooler_dropout,
-    ):
-
-        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
-
-
 class LSGMBartPretrainedModel(MBartPreTrainedModel):
 
     config_class = LSGMBartConfig
@@ -659,7 +645,7 @@ class LSGMBartPretrainedModel(MBartPreTrainedModel):
     supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (MBartDecoder, MBartEncoder,…
+        if isinstance(module, (MBartDecoder, MBartEncoder, LSGMBartEncoder)):
             module.gradient_checkpointing = value
 
 
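After the fix, the isinstance tuple lists MBartDecoder, MBartEncoder and LSGMBartEncoder. Enabling checkpointing on a loaded model toggles the flag on every matching module; a rough usage sketch (the checkpoint name is hypothetical, and remote code must be trusted):

from transformers import AutoModelForSeq2SeqLM

# Hypothetical LSG-MBart checkpoint; any repo shipping this modeling file behaves the same.
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-mbart-large-4096", trust_remote_code=True)
model.gradient_checkpointing_enable()               # routed through _set_gradient_checkpointing
print(model.model.encoder.gradient_checkpointing)   # True for the LSG encoder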
@@ -674,7 +660,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
 
     def __init__(self, config, embed_tokens=None):
 
-
+        LSGMBartPretrainedModel.__init__(self, config)
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
 
@@ -811,7 +797,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(…
+        embed_pos = self.embed_positions(inputs_embeds)
         hidden_states = inputs_embeds + embed_pos
 
         # Add global tokens
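The positional-embedding call now receives inputs_embeds directly. This matches a 4.23-style MBartLearnedPositionalEmbedding.forward that reads only the first two dimensions of its argument, so an embedding tensor yields the same (batch, seq_len) a token-id tensor would; that signature detail is an assumption about the upstream class, not something visible in this diff:

import torch

inputs_embeds = torch.zeros(2, 4096, 1024)   # (batch, seq_len, d_model)
bsz, seq_len = inputs_embeds.shape[:2]       # all a shape-based positional embedding needs
print(bsz, seq_len)                          # 2 4096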
@@ -884,44 +870,6 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, MBartEncoder):
         )
 
 
-class LSGMBartDecoder(LSGMBartPretrainedModel, MBartDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
-    Args:
-        config: BartConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config, embed_tokens=None):
-
-        LSGMBartPretrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = MBartLearnedPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-        )
-        self.layers = nn.ModuleList([LSGMBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-        self.layer_norm = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
 
     def __init__(self, config):
@@ -935,7 +883,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, MBartModel):
         self.num_global_tokens = config.num_global_tokens
 
         self.encoder = LSGMBartEncoder(config, self.shared)
-        self.decoder =…
+        self.decoder = MBartDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
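With the custom decoder classes gone, the seq2seq model now pairs the LSG encoder with the stock MBartDecoder. A quick sanity sketch, under the same local-import assumption as above:

from modeling_lsg_mbart import LSGMBartConfig, LSGMBartModel

model = LSGMBartModel(LSGMBartConfig())   # default config
print(type(model.encoder).__name__)       # LSGMBartEncoder (long-input attention)
print(type(model.decoder).__name__)       # MBartDecoder (unchanged transformers decoder)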
@@ -1051,7 +999,7 @@ class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, MBartForSequenceClassification):
 
         LSGMBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGMBartModel(config)
-        self.classification_head =…
+        self.classification_head = MBartClassificationHead(
             config.d_model,
             config.d_model,
             config.num_labels,
@@ -1075,35 +1023,13 @@ class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, MBartForQuestionAnswering):
 
         self.model._init_weights(self.qa_outputs)
 
-
-class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
-    """
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.decoder = LSGMBartDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
-
 
 class LSGMBartForCausalLM(LSGMBartPretrainedModel, MBartForCausalLM):
 
     def __init__(self, config):
 
-        config = copy.deepcopy(config)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
         LSGMBartPretrainedModel.__init__(self, config)
-        self…
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        MBartForCausalLM.__init__(self, config)
 
 
 def str_to_class(classname):
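The causal-LM constructor now defers to MBartForCausalLM.__init__, which already performs the setup deleted here (deep-copying the config, flagging it as decoder-only, building the decoder wrapper and the lm_head), so the duplicated code could go. A small check, again assuming the module is importable locally:

from modeling_lsg_mbart import LSGMBartConfig, LSGMBartForCausalLM

lm = LSGMBartForCausalLM(LSGMBartConfig())
print(lm.config.is_decoder)                             # True, set inside MBartForCausalLM.__init__
print(lm.lm_head.out_features == lm.config.vocab_size)  # True, lm_head built by the parent class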