guymorganb commited on
Commit
143f3be
·
1 Parent(s): f2c95c3

messing around with differing settings.

Browse files
Files changed (1) hide show
  1. modeling_lsg_bert.py +14 -14
modeling_lsg_bert.py CHANGED
@@ -59,16 +59,16 @@ class LSGBertConfig(BertConfig):
59
 
60
  def __init__(
61
  self,
 
62
  base_model_prefix="lsg",
63
  block_size=0,
64
- sparse_block_size=128,
65
- adaptive=True,
66
- sparsity_factor=2,
67
  lsh_num_pre_rounds=1,
 
68
  mask_first_token=False,
69
- num_global_tokens=1,
70
- pool_with_global=True,
71
- sparsity_type="none",
 
72
  **kwargs
73
  ):
74
  """Constructs LSGBertConfig."""
@@ -98,20 +98,20 @@ class LSGBertConfig(BertConfig):
98
  "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
99
  )
100
 
101
- if self.num_global_tokens < 1:
102
- logger.warning(
103
- "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
104
- )
105
- self.num_global_tokens = 1
106
  elif self.num_global_tokens > 512:
107
  logger.warning(
108
  "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
109
  )
110
  self.num_global_tokens = 512
111
 
112
- if self.sparsity_factor > 0:
113
- assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
114
- assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
115
 
116
  if self.mask_first_token and not pool_with_global:
117
  logger.warning(
 
59
 
60
  def __init__(
61
  self,
62
+ adaptive=False,
63
  base_model_prefix="lsg",
64
  block_size=0,
 
 
 
65
  lsh_num_pre_rounds=1,
66
+ sparse_block_size=0,
67
  mask_first_token=False,
68
+ num_global_tokens=0,
69
+ pool_with_global=False,
70
+ sparsity_factor=1,
71
+ sparsity_type="non3",
72
  **kwargs
73
  ):
74
  """Constructs LSGBertConfig."""
 
98
  "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
99
  )
100
 
101
+ # if self.num_global_tokens < 1:
102
+ # logger.warning(
103
+ # "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
104
+ # )
105
+ # self.num_global_tokens = 1
106
  elif self.num_global_tokens > 512:
107
  logger.warning(
108
  "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
109
  )
110
  self.num_global_tokens = 512
111
 
112
+ # if self.sparsity_factor > 0:
113
+ # assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
114
+ # assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
115
 
116
  if self.mask_first_token and not pool_with_global:
117
  logger.warning(