FancyZhao commited on
Commit
7326a58
·
1 Parent(s): 60dcc50

feat: check flash-attn version if installed (#15)

Browse files

- feat: check flash-attn version if installed (86ebcde8a3ee75d5a0c20c28408e1734740f3070)
- add version info (49d999cf6a7241dc9a34523c583650751235b92d)
- use packaging.version (0e24941c9f142fddd95079be8dcb6347674c023b)
- update (824bd56461efd039413c4c701e46285947808b49)

Files changed (1) hide show
  1. modeling_yi.py +10 -3
modeling_yi.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
4
 
5
  import torch.utils.checkpoint
6
  from einops import repeat
 
7
  from torch import nn
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
  from transformers.activations import ACT2FN
@@ -25,8 +26,12 @@ from .configuration_yi import YiConfig
25
 
26
  is_flash_attn_available = True
27
  try:
28
- from flash_attn import flash_attn_func
29
- except Exception:
 
 
 
 
30
  is_flash_attn_available = False
31
 
32
  logger = logging.get_logger(__name__)
@@ -539,7 +544,9 @@ class YiModel(YiPreTrainedModel):
539
  def _prepare_decoder_attention_mask(
540
  self, attention_mask, input_ids, inputs_embeds, past_key_values_length
541
  ):
542
- input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
 
 
543
  # create causal mask
544
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
545
  combined_attention_mask = None
 
4
 
5
  import torch.utils.checkpoint
6
  from einops import repeat
7
+ from packaging import version
8
  from torch import nn
9
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
  from transformers.activations import ACT2FN
 
26
 
27
  is_flash_attn_available = True
28
  try:
29
+ from flash_attn import flash_attn_func, __version__
30
+
31
+ assert version.parse(__version__) >= version.parse(
32
+ "2.3.0"
33
+ ), "please update your flash_attn version (>= 2.3.0)"
34
+ except ModuleNotFoundError:
35
  is_flash_attn_available = False
36
 
37
  logger = logging.get_logger(__name__)
 
544
  def _prepare_decoder_attention_mask(
545
  self, attention_mask, input_ids, inputs_embeds, past_key_values_length
546
  ):
547
+ input_shape = (
548
+ input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
549
+ )
550
  # create causal mask
551
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
552
  combined_attention_mask = None