guymorganb
commited on
Commit
·
143f3be
1
Parent(s):
f2c95c3
messing around with differing settings.
Browse files- modeling_lsg_bert.py +14 -14
modeling_lsg_bert.py
CHANGED
@@ -59,16 +59,16 @@ class LSGBertConfig(BertConfig):
|
|
59 |
|
60 |
def __init__(
|
61 |
self,
|
|
|
62 |
base_model_prefix="lsg",
|
63 |
block_size=0,
|
64 |
-
sparse_block_size=128,
|
65 |
-
adaptive=True,
|
66 |
-
sparsity_factor=2,
|
67 |
lsh_num_pre_rounds=1,
|
|
|
68 |
mask_first_token=False,
|
69 |
-
num_global_tokens=
|
70 |
-
pool_with_global=
|
71 |
-
|
|
|
72 |
**kwargs
|
73 |
):
|
74 |
"""Constructs LSGBertConfig."""
|
@@ -98,20 +98,20 @@ class LSGBertConfig(BertConfig):
|
|
98 |
"[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
|
99 |
)
|
100 |
|
101 |
-
if self.num_global_tokens < 1:
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
elif self.num_global_tokens > 512:
|
107 |
logger.warning(
|
108 |
"[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
|
109 |
)
|
110 |
self.num_global_tokens = 512
|
111 |
|
112 |
-
if self.sparsity_factor > 0:
|
113 |
-
|
114 |
-
|
115 |
|
116 |
if self.mask_first_token and not pool_with_global:
|
117 |
logger.warning(
|
|
|
59 |
|
60 |
def __init__(
|
61 |
self,
|
62 |
+
adaptive=False,
|
63 |
base_model_prefix="lsg",
|
64 |
block_size=0,
|
|
|
|
|
|
|
65 |
lsh_num_pre_rounds=1,
|
66 |
+
sparse_block_size=0,
|
67 |
mask_first_token=False,
|
68 |
+
num_global_tokens=0,
|
69 |
+
pool_with_global=False,
|
70 |
+
sparsity_factor=1,
|
71 |
+
sparsity_type="non3",
|
72 |
**kwargs
|
73 |
):
|
74 |
"""Constructs LSGBertConfig."""
|
|
|
98 |
"[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
|
99 |
)
|
100 |
|
101 |
+
# if self.num_global_tokens < 1:
|
102 |
+
# logger.warning(
|
103 |
+
# "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
|
104 |
+
# )
|
105 |
+
# self.num_global_tokens = 1
|
106 |
elif self.num_global_tokens > 512:
|
107 |
logger.warning(
|
108 |
"[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
|
109 |
)
|
110 |
self.num_global_tokens = 512
|
111 |
|
112 |
+
# if self.sparsity_factor > 0:
|
113 |
+
# assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
|
114 |
+
# assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
|
115 |
|
116 |
if self.mask_first_token and not pool_with_global:
|
117 |
logger.warning(
|