Guo committed
Commit fd09ce7 · 1 parent: b5a6954

workaround for import check

Files changed (1)
  1. modeling_jetmoe.py +8 -4
modeling_jetmoe.py CHANGED
@@ -9,7 +9,6 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 from torch.nn import functional as F
 
-#import megablocks
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -30,9 +29,14 @@ from transformers.cache_utils import Cache, DynamicCache
 from .configuration_jetmoe import JetMoEConfig
 from . import moe
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+try:
+    if is_flash_attn_2_available():
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+except ImportError:
+    # Workaround for https://github.com/huggingface/transformers/issues/28459,
+    # don't move to contextlib.suppress(ImportError)
+    pass
 
 logger = logging.get_logger(__name__)
 
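
Why a literal try/except rather than contextlib.suppress(ImportError): when transformers loads remote model code, its dynamic-module loader scans the file's source for imports and raises if any named package is missing, but it strips try/except blocks before scanning. Wrapping the flash_attn imports this way hides them from that source scan while still guarding the runtime import. Below is a minimal standalone sketch of the pattern; the None fallbacks are illustrative additions rather than part of the commit, and is_flash_attn_2_available is the real helper from transformers.utils.

    from transformers.utils import is_flash_attn_2_available

    # Illustrative fallbacks so later code can test for availability.
    flash_attn_func = None
    flash_attn_varlen_func = None

    try:
        # is_flash_attn_2_available() only inspects installed package
        # metadata, so the actual import can still fail at runtime
        # (e.g. a broken or mismatched flash-attn build).
        if is_flash_attn_2_available():
            from flash_attn import flash_attn_func, flash_attn_varlen_func
    except ImportError:
        # Must remain a literal try/except: the import check skips
        # try/except blocks, but would still flag imports guarded only
        # by contextlib.suppress(ImportError).
        pass

Downstream code can then branch on flash_attn_func is not None instead of re-probing the environment at each call site.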