MiniMax-AI Rocketknight1 HF staff commited on
Commit
372fb1d
·
verified ·
1 Parent(s): 40a455a

Remove import that no longer exists in Transformers (#5)

Browse files

- Remove import that no longer exists in Transformers (72ba538b44b7a26e0167115ef469c41440150fd2)


Co-authored-by: Matthew Carrigan <[email protected]>

Files changed (1) hide show
  1. modeling_minimax_text_01.py +1 -4
modeling_minimax_text_01.py CHANGED
@@ -22,7 +22,6 @@ from transformers.modeling_outputs import (
22
  SequenceClassifierOutputWithPast,
23
  )
24
  from transformers.modeling_utils import PreTrainedModel
25
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
26
  from transformers.utils import (
27
  add_start_docstrings,
28
  add_start_docstrings_to_model_forward,
@@ -43,10 +42,8 @@ if is_flash_attn_2_available():
43
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
44
  # It means that the function will not be traced through and simply appear as a node in the graph.
45
  if is_torch_fx_available():
46
- if not is_torch_greater_or_equal_than_1_13:
47
- import torch.fx
48
-
49
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
 
50
  use_triton = eval(os.environ.get("use_triton", default="False"))
51
  debug = eval(os.environ.get("debug", default="False"))
52
  do_eval = eval(os.environ.get("do_eval", default="False"))
 
22
  SequenceClassifierOutputWithPast,
23
  )
24
  from transformers.modeling_utils import PreTrainedModel
 
25
  from transformers.utils import (
26
  add_start_docstrings,
27
  add_start_docstrings_to_model_forward,
 
42
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
43
  # It means that the function will not be traced through and simply appear as a node in the graph.
44
  if is_torch_fx_available():
 
 
 
45
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
46
+
47
  use_triton = eval(os.environ.get("use_triton", default="False"))
48
  debug = eval(os.environ.get("debug", default="False"))
49
  do_eval = eval(os.environ.get("do_eval", default="False"))