guymorganb committed 9e948ba (1 parent: 143f3be)

messing around with differing settings.

Files changed (2):
  1. config.json (+7 -6)
  2. modeling_lsg_bert.py (+17 -15)
config.json CHANGED
@@ -16,8 +16,14 @@
     "AutoModelForSequenceClassification": "modeling_lsg_bert.LSGBertForSequenceClassification",
     "AutoModelForTokenClassification": "modeling_lsg_bert.LSGBertForTokenClassification"
   },
-  "base_model_prefix": "lsg",
   "block_size": 128,
+  "sparse_block_size": 128,
+  "sparsity_factor": 2,
+  "base_model_prefix": "lsg",
+  "sparsity_type": "norm",
+  "is_decoder": false,
+  "pool_with_global": true,
+  "num_global_tokens": 1,
   "classifier_dropout": null,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
@@ -30,14 +36,9 @@
   "max_position_embeddings": 4096,
   "model_type": "bert",
   "num_attention_heads": 16,
-  "num_global_tokens": 1,
   "num_hidden_layers": 24,
   "pad_token_id": 0,
-  "pool_with_global": true,
   "position_embedding_type": "absolute",
-  "sparse_block_size": 128,
-  "sparsity_factor": 2,
-  "sparsity_type": "norm",
   "torch_dtype": "float32",
   "transformers_version": "4.30.2",
   "type_vocab_size": 2,
modeling_lsg_bert.py CHANGED
@@ -59,16 +59,17 @@ class LSGBertConfig(BertConfig):
 
     def __init__(
         self,
-        adaptive=False,
+        adaptive=True,
+        is_decoder=False,
         base_model_prefix="lsg",
-        block_size=0,
+        block_size=128,
         lsh_num_pre_rounds=1,
-        sparse_block_size=0,
         mask_first_token=False,
-        num_global_tokens=0,
-        pool_with_global=False,
-        sparsity_factor=1,
-        sparsity_type="non3",
+        num_global_tokens=1,
+        pool_with_global=True,
+        sparse_block_size=128,
+        sparsity_factor=2,
+        sparsity_type="norm",
         **kwargs
     ):
         """Constructs LSGBertConfig."""
@@ -85,6 +86,7 @@ class LSGBertConfig(BertConfig):
         self.sparse_block_size = sparse_block_size
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
+        self.is_decoder = is_decoder
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
@@ -98,20 +100,20 @@ class LSGBertConfig(BertConfig):
                 "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
             )
 
-        # if self.num_global_tokens < 1:
-        #     logger.warning(
-        #         "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
-        #     )
-        #     self.num_global_tokens = 1
+        if self.num_global_tokens < 1:
+            logger.warning(
+                "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
+            )
+            self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
                 "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
            )
             self.num_global_tokens = 512
 
-        # if self.sparsity_factor > 0:
-        #     assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
-        #     assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+        if self.sparsity_factor > 0:
+            assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
+            assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
         if self.mask_first_token and not pool_with_global:
             logger.warning(
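
Taken together, the modeling change replaces the constructor's placeholder defaults (zeros, False, the "non3" typo) with the values already used in config.json, and uncomments both the num_global_tokens clamp and the block-size assertions. A small sketch of what the restored validation now enforces, assuming modeling_lsg_bert.py is importable from the working directory:

from modeling_lsg_bert import LSGBertConfig

# The new defaults (block_size=128, sparsity_factor=2) satisfy the restored
# divisibility check: 128 % 2 == 0 and 128 // 2 >= 1.
config = LSGBertConfig()
print(config.block_size, config.sparsity_factor)  # 128 2

# A sparsity_factor that does not divide block_size now fails fast at
# construction time instead of misaligning sparse blocks later.
try:
    LSGBertConfig(block_size=128, sparsity_factor=3)
except AssertionError as err:
    print(err)  # [ERROR CONFIG]: block_size must be divisible by sparsity_factor

# num_global_tokens below 1 is clamped back to 1 with a logged warning.
config = LSGBertConfig(num_global_tokens=0)
print(config.num_global_tokens)  # 1
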