disable `torch.jit` when using Ascend NPUs
#20
by
statelesshz
- opened
- modeling_chatglm.py +2 -2
modeling_chatglm.py
CHANGED
@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
|
|
21 |
SequenceClassifierOutputWithPast,
|
22 |
)
|
23 |
from transformers.modeling_utils import PreTrainedModel
|
24 |
-
from transformers.utils import logging
|
25 |
from transformers.generation.logits_process import LogitsProcessor
|
26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
|
@@ -29,7 +29,7 @@ from .configuration_chatglm import ChatGLMConfig
|
|
29 |
|
30 |
# flags required to enable jit fusion kernels
|
31 |
|
32 |
-
if sys.platform != 'darwin':
|
33 |
torch._C._jit_set_profiling_mode(False)
|
34 |
torch._C._jit_set_profiling_executor(False)
|
35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
|
21 |
SequenceClassifierOutputWithPast,
|
22 |
)
|
23 |
from transformers.modeling_utils import PreTrainedModel
|
24 |
+
from transformers.utils import logging, is_torch_npu_available
|
25 |
from transformers.generation.logits_process import LogitsProcessor
|
26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
|
|
|
29 |
|
30 |
# flags required to enable jit fusion kernels
|
31 |
|
32 |
+
if sys.platform != 'darwin' and not is_torch_npu_available():
|
33 |
torch._C._jit_set_profiling_mode(False)
|
34 |
torch._C._jit_set_profiling_executor(False)
|
35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|