Move flash_attn assert from __init__ into calling func

#32
Files changed (1)
  1. modeling_phi3_small.py +2 -1
modeling_phi3_small.py CHANGED
@@ -215,7 +215,6 @@ class Phi3SmallSelfAttention(nn.Module):
                 f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
                 f"{self.config.dense_attention_every_n_layers}"
             )
-            assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
         else:
             # BlockSparse related Parameters
             self.blocksparse_params = BlockSparseParams.from_config(config)
@@ -419,6 +418,8 @@ class Phi3SmallSelfAttention(nn.Module):
         avoid doing that.
 
         """
+        assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
+
         attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
         # Get into the correct shape for the Flash Attention API
         # shape: (bs, seq_len, nqp, hn)
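
The effect of the change is to defer the flash-attention availability check from module construction to the point where dense attention actually runs, so the model can be instantiated (e.g. for inspection or weight conversion) on a machine without flash_attn installed. The sketch below is a simplified illustration of that before/after placement; the class and method names here are hypothetical stand-ins, and only the is_flash_attention_available flag and the assert message come from the patch itself.

# Minimal sketch of the pattern applied by this PR (an assumed simplification,
# not the actual Phi3SmallSelfAttention implementation).

try:
    import flash_attn  # noqa: F401
    is_flash_attention_available = True
except ImportError:
    is_flash_attention_available = False


class DenseAttentionSketch:
    def __init__(self, layer_idx: int):
        # Before this change, the assert lived here, so merely constructing the
        # module failed on machines without flash_attn installed.
        self.layer_idx = layer_idx

    def forward(self, query, key, value):
        # After this change, the check fires only when dense (flash) attention
        # is actually executed.
        assert is_flash_attention_available, (
            "Flash Attention is not available, but is needed for dense attention"
        )
        # ... the real method reshapes q/k/v and calls the flash_attn kernels ...
        return query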