fix version 4.23
modeling_lsg_bart.py  +18 -91
CHANGED
@@ -57,7 +57,8 @@ class LSGBartConfig(BartConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                    setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -73,7 +74,7 @@ class LSGBartConfig(BartConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -81,6 +82,16 @@ class LSGBartConfig(BartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
 
 class BaseSelfAttention(nn.Module):
 
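The three config hunks above only tighten validation: an unknown sparsity_type is reset to None, num_global_tokens is clamped to 512, and pool_with_global is forced on when mask_first_token is set. A minimal usage sketch of how these guards behave, assuming modeling_lsg_bart.py is importable and that LSGBartConfig accepts these keyword arguments (its full constructor signature is not shown in this diff):

```python
# Sketch: how the configuration guards above behave.
# Assumes modeling_lsg_bart.py is importable and that LSGBartConfig accepts
# these keyword arguments (the constructor signature is not shown in this diff).
from modeling_lsg_bart import LSGBartConfig

config = LSGBartConfig(
    sparsity_type="unknown",    # not in the allowed list -> warning, reset to None
    num_global_tokens=1024,     # > 512 -> warning, clamped to 512
    mask_first_token=True,
    pool_with_global=False,     # incompatible with mask_first_token -> forced to True
)

print(config.sparsity_type)      # None
print(config.num_global_tokens)  # 512
print(config.pool_with_global)   # True
```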
@@ -557,8 +568,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):
             attention_mask=attention_mask
         )
 
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         return self.reshape_output(context_layer)
 
     # Split input into global tokens and other tokens
@@ -606,8 +615,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -630,35 +637,14 @@ class LSGBartEncoderLayer(BartEncoderLayer):
             dropout=config.attention_dropout,
         )
 
-
-class LSGBartDecoderLayer(BartDecoderLayer):
-
-    def __init__(self, config):
-
-        super().__init__(config)
 
-
-class LSGBartClassificationHead(BartClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim,
-        inner_dim,
-        num_classes,
-        pooler_dropout,
-    ):
-
-        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
-
-
 class LSGBartPretrainedModel(BartPretrainedModel):
 
     config_class = LSGBartConfig
 
     def _set_gradient_checkpointing(self, module, value=False):
 
-        if isinstance(module, (BartDecoder, BartEncoder,
+        if isinstance(module, (BartDecoder, BartEncoder, LSGBartEncoder)):
             module.gradient_checkpointing = value
 
 
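With LSGBartDecoderLayer and LSGBartClassificationHead dropped, _set_gradient_checkpointing now matches the stock BartDecoder and BartEncoder plus LSGBartEncoder, so gradient checkpointing reaches the custom encoder as well. A minimal sketch, assuming a converted LSG-BART checkpoint loaded with trust_remote_code (the path below is a placeholder):

```python
# Sketch: enabling gradient checkpointing on a converted LSG-BART checkpoint.
# "path/to/lsg-bart-checkpoint" is a placeholder, and trust_remote_code=True is
# needed because the model class is defined in this repository, not in transformers.
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "path/to/lsg-bart-checkpoint",
    trust_remote_code=True,
)

# gradient_checkpointing_enable() calls _set_gradient_checkpointing on every
# module, which now also matches LSGBartEncoder.
model.gradient_checkpointing_enable()
```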
@@ -818,7 +804,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(
+        embed_pos = self.embed_positions(inputs_embeds)
         hidden_states = inputs_embeds + embed_pos
 
         # Add global tokens
@@ -889,43 +875,6 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
-    Args:
-        config: BartConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config, embed_tokens=None):
-
-        LSGBartPretrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = BartLearnedPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-        )
-        self.layers = nn.ModuleList([LSGBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGBartModel(LSGBartPretrainedModel, BartModel):
 
     def __init__(self, config):
@@ -939,7 +888,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         self.num_global_tokens = config.num_global_tokens
 
         self.encoder = LSGBartEncoder(config, self.shared)
-        self.decoder = LSGBartDecoder(config, self.shared)
+        self.decoder = BartDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
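The two hunks above remove the custom LSGBartDecoder entirely and let LSGBartModel build the stock BartDecoder; only the encoder keeps the LSG attention. A quick sanity check, assuming modeling_lsg_bart is importable and the default LSGBartConfig is constructible:

```python
# Quick check: only the encoder is LSG-specific after this change.
from transformers.models.bart.modeling_bart import BartDecoder
from modeling_lsg_bart import LSGBartConfig, LSGBartModel

model = LSGBartModel(LSGBartConfig())          # assumes the default config is valid
print(type(model.encoder).__name__)            # LSGBartEncoder (local-sparse-global attention)
print(isinstance(model.decoder, BartDecoder))  # True, the vanilla transformers decoder
```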
@@ -1052,7 +1001,7 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
 
         LSGBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGBartModel(config)
-        self.classification_head = LSGBartClassificationHead(
+        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
@@ -1077,34 +1026,12 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
         self.model._init_weights(self.qa_outputs)
 
 
-class LSGBartDecoderWrapper(LSGBartPretrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
-    """
-
-    def __init__(self, config: LSGBartConfig):
-        super().__init__(config)
-        self.decoder = LSGBartDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
-
-
 class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
 
     def __init__(self, config: LSGBartConfig):
 
-        config = copy.deepcopy(config)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
         LSGBartPretrainedModel.__init__(self, config)
-        self.model = LSGBartDecoderWrapper(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        BartForCausalLM.__init__(self, config)
 
 
 def str_to_class(classname):
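Similarly, LSGBartDecoderWrapper is removed and LSGBartForCausalLM now delegates decoder and lm_head construction to BartForCausalLM.__init__, which already handles the config deep-copy and the is_decoder / is_encoder_decoder flags. A minimal sketch under the same import assumptions as above:

```python
# Sketch: the causal-LM head is now assembled by the parent BartForCausalLM.
from modeling_lsg_bart import LSGBartConfig, LSGBartForCausalLM

lm = LSGBartForCausalLM(LSGBartConfig())
print(lm.config.is_decoder)     # True, set inside BartForCausalLM.__init__
print(type(lm.model).__name__)  # BartDecoderWrapper, built by the parent constructor
print(lm.lm_head.out_features)  # vocab_size; the lm_head is still created, just not by hand
```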