guymorganb committed
Commit 9e948ba · 1 Parent(s): 143f3be
messing around with differing settings.

Files changed:
- config.json +7 -6
- modeling_lsg_bert.py +17 -15
config.json CHANGED

@@ -16,8 +16,14 @@
     "AutoModelForSequenceClassification": "modeling_lsg_bert.LSGBertForSequenceClassification",
     "AutoModelForTokenClassification": "modeling_lsg_bert.LSGBertForTokenClassification"
   },
-  "base_model_prefix": "lsg",
   "block_size": 128,
+  "sparse_block_size": 128,
+  "sparsity_factor": 2,
+  "base_model_prefix": "lsg",
+  "sparsity_type": "norm",
+  "is_decoder": false,
+  "pool_with_global": true,
+  "num_global_tokens": 1,
   "classifier_dropout": null,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
@@ -30,14 +36,9 @@
   "max_position_embeddings": 4096,
   "model_type": "bert",
   "num_attention_heads": 16,
-  "num_global_tokens": 1,
   "num_hidden_layers": 24,
   "pad_token_id": 0,
-  "pool_with_global": true,
   "position_embedding_type": "absolute",
-  "sparse_block_size": 128,
-  "sparsity_factor": 2,
-  "sparsity_type": "norm",
   "torch_dtype": "float32",
   "transformers_version": "4.30.2",
   "type_vocab_size": 2,
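For reference, these keys are consumed by the custom LSGBertConfig defined in modeling_lsg_bert.py (see the auto_map entries above), so checkpoints carrying this config must be loaded with trust_remote_code=True. A minimal sketch, where the repo id "user/lsg-bert-checkpoint" is a hypothetical placeholder, not this repository's actual name:

    # Sketch: loading a checkpoint that ships this config and modeling file.
    from transformers import AutoConfig, AutoModel

    repo_id = "user/lsg-bert-checkpoint"  # hypothetical placeholder id

    # trust_remote_code=True is required because the LSGBert* classes
    # live in modeling_lsg_bert.py inside the repo, not in transformers.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    print(config.block_size, config.sparse_block_size, config.sparsity_type)

    # Keyword arguments that match config attributes override the values
    # shipped in config.json at load time:
    model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        sparsity_factor=4,  # overrides the "sparsity_factor": 2 above
    )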
modeling_lsg_bert.py CHANGED

@@ -59,16 +59,17 @@ class LSGBertConfig(BertConfig):
 
     def __init__(
         self,
-        adaptive=
+        adaptive=True,
+        is_decoder = False,
         base_model_prefix="lsg",
-        block_size=
+        block_size=128,
         lsh_num_pre_rounds=1,
-        sparse_block_size=0,
         mask_first_token=False,
-        num_global_tokens=
-        pool_with_global=
-
-
+        num_global_tokens=1,
+        pool_with_global=True,
+        sparse_block_size=128,
+        sparsity_factor=2,
+        sparsity_type="norm",
         **kwargs
     ):
         """Constructs LSGBertConfig."""
@@ -85,6 +86,7 @@ class LSGBertConfig(BertConfig):
         self.sparse_block_size = sparse_block_size
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
+        self.is_decoder = is_decoder
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
@@ -98,20 +100,20 @@ class LSGBertConfig(BertConfig):
                 "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
             )
 
-
-
-
-
-
+        if self.num_global_tokens < 1:
+            logger.warning(
+                "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
+            )
+            self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
                 "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
-
-
-
+        if self.sparsity_factor > 0:
+            assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
+            assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
         if self.mask_first_token and not pool_with_global:
             logger.warning(
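The practical effect of the re-added guards can be checked by constructing the config directly. A minimal sketch, assuming modeling_lsg_bert.py is importable (e.g. the repository is cloned into the working directory):

    # Sketch: exercising the num_global_tokens clamp and the block_size /
    # sparsity_factor divisibility assert from the hunk above.
    from modeling_lsg_bert import LSGBertConfig

    # num_global_tokens < 1 now logs "[WARNING CONFIG]: ..." and clamps to 1.
    cfg = LSGBertConfig(num_global_tokens=0)
    assert cfg.num_global_tokens == 1

    # A block_size that is not divisible by sparsity_factor trips the new assert.
    try:
        LSGBertConfig(block_size=128, sparsity_factor=3)
    except AssertionError as err:
        print(err)  # [ERROR CONFIG]: block_size must be divisible by sparsity_factor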