update config about model precision, fix apply_rotary_pos_emb
Files changed:
- config.json +3 -2
- configuration_qwen.py +5 -1
- modeling_qwen.py +80 -35
config.json CHANGED
@@ -10,12 +10,13 @@
   },
   "attn_pdrop": 0.0,
   "bf16": false,
+  "fp16": false,
+  "fp32": false,
   "bias_dropout_fusion": true,
   "bos_token_id": 151643,
   "embd_pdrop": 0.1,
   "eos_token_id": 151643,
   "ffn_hidden_size": 22016,
-  "fp16": false,
   "initializer_range": 0.02,
   "kv_channels": 128,
   "layer_norm_epsilon": 1e-05,
@@ -38,7 +39,7 @@
   "tokenizer_type": "QWenTokenizer",
   "transformers_version": "4.31.0",
   "use_cache": true,
-  "use_flash_attn": true,
+  "use_flash_attn": "auto",
   "vocab_size": 151936,
   "use_dynamic_ntk": true,
   "use_logn_attn": true
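Note: config.json now ships all three precision switches off and defers the flash-attention decision to load time via "use_flash_attn": "auto". The warning messages added in modeling_qwen.py below suggest overriding the automatic choice through from_pretrained; a minimal sketch of that call (the checkpoint id and the trust_remote_code flag are illustrative assumptions, not part of this change):

```python
from transformers import AutoModelForCausalLM

# Passing exactly one of bf16/fp16/fp32 overrides the new "auto" precision logic;
# leaving all three unset lets the model pick a precision based on the device at load time.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B",          # illustrative checkpoint id
    trust_remote_code=True,  # required because modeling_qwen.py is custom code
    bf16=True,               # at most one of bf16/fp16/fp32 may be True
)
```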
configuration_qwen.py CHANGED
@@ -31,7 +31,9 @@ class QWenConfig(PretrainedConfig):
         use_cache=True,
         eos_token_id=151643,
         apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
@@ -63,6 +65,8 @@ class QWenConfig(PretrainedConfig):
             apply_residual_connection_post_layernorm
         )
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
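The three flags are plain keyword arguments on QWenConfig, so they can also be set when building a config by hand. A small sketch, assuming the file is importable from a checkout of this repository:

```python
from configuration_qwen import QWenConfig  # local module from this repo

# All three flags default to False; modeling_qwen.py asserts that at most one is True.
config = QWenConfig(fp16=True)
print(config.bf16, config.fp16, config.fp32)  # -> False True False
```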
modeling_qwen.py CHANGED
@@ -32,26 +32,13 @@ except ImportError:
     rearrange = None
 from torch import nn
 
-
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb_func
-
-    use_flash_rotary = True
-except ImportError:
-    use_flash_rotary = False
-    print(
-        "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
-        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
-    )
+SUPPORT_CUDA = torch.cuda.is_available()
+SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
+SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
-try:
-    from flash_attn.ops.rms_norm import rms_norm
-except ImportError:
-    rms_norm = None
-    print(
-        "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
-        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
-    )
+apply_rotary_emb_func = None
+rms_norm = None
+flash_attn_unpadded_func = None
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
@@ -70,16 +57,6 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-except ImportError:
-    flash_attn_unpadded_func = None
-    print(
-        "Warning: import flash_attn fail, please install FlashAttention "
-        "https://github.com/Dao-AILab/flash-attention"
-    )
-
-
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
         self,
@@ -388,7 +365,7 @@ class QWenAttention(nn.Module):
         present = None
 
         if self.use_logn_attn and not self.training:
-            if self.logn_tensor.device != query.device:
+            if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
                 self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
             seq_start = key.size(1) - query.size(1)
             seq_end = key.size(1)
@@ -775,11 +752,79 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
+        assert (
+            config.bf16 + config.fp16 + config.fp32 <= 1
+        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+
+        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
+
+        if autoset_precision:
+            if SUPPORT_BF16:
+                logger.warn(
+                    "The model is automatically converting to bf16 for faster inference. "
+                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                )
+                config.bf16 = True
+            elif SUPPORT_FP16:
+                logger.warn(
+                    "The model is automatically converting to fp16 for faster inference. "
+                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                )
+                config.fp16 = True
+            else:
+                config.fp32 = True
+
+        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
+            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
+            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+        if config.fp32:
+            if SUPPORT_BF16:
+                logger.warn("Your device supports faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+            elif SUPPORT_FP16:
+                logger.warn("Your device supports faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+
+        if config.use_flash_attn == "auto":
+            if config.bf16 or config.fp16:
+                logger.warn("Try importing flash-attention for faster inference...")
+                config.use_flash_attn = True
+            else:
+                config.use_flash_attn = False
+        if config.use_flash_attn and config.fp32:
+            logger.warn("Flash attention will be disabled because it does NOT support fp32.")
+
+        if config.use_flash_attn:
+            global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+            try:
+                from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
+                apply_rotary_emb_func = __apply_rotary_emb_func
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
+                )
+
+            try:
+                from flash_attn.ops.rms_norm import rms_norm as __rms_norm
+                rms_norm = __rms_norm
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
+                )
+
+            try:
+                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
+                flash_attn_unpadded_func = __flash_attn_unpadded_func
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention"
+                )
+
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        assert not (
-            config.bf16 and config.fp16
-        ), "In config, bf16 and fp16 cannot both be true"
+
         if config.bf16:
             self.transformer.bfloat16()
             self.lm_head.bfloat16()
@@ -1040,8 +1085,8 @@ def _rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
-    if use_flash_rotary:
+def apply_rotary_pos_emb(t, freqs):
+    if apply_rotary_emb_func is not None:
         t_ = t.float()
         freqs = freqs.squeeze(0).squeeze(1)
         cos = freqs[:, : freqs.shape[-1] // 2].cos()
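For reference, the heart of this commit is the device-aware precision selection that now runs in QWenLMHeadModel.__init__, together with the deferred flash-attention imports and the apply_rotary_pos_emb fix, which now keys off apply_rotary_emb_func instead of the removed use_flash_rotary flag. A standalone sketch of the selection flow, for illustration only (the helper name resolve_precision is not part of the repository):

```python
import torch

# Same capability probes as the module-level constants added in modeling_qwen.py.
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7


def resolve_precision(bf16=False, fp16=False, fp32=False, use_flash_attn="auto"):
    """Simplified mirror of the decision flow in QWenLMHeadModel.__init__ (illustrative)."""
    assert bf16 + fp16 + fp32 <= 1, 'Only one of "bf16", "fp16", "fp32" can be true'
    if not (bf16 or fp16 or fp32):        # nothing requested: auto-select
        if SUPPORT_BF16:
            bf16 = True                   # Ampere or newer: prefer bf16
        elif SUPPORT_FP16:
            fp16 = True                   # older CUDA GPUs: fall back to fp16
        else:
            fp32 = True                   # CPU or no half-precision support
    if use_flash_attn == "auto":
        use_flash_attn = bf16 or fp16     # flash attention is only tried for half precision
    return bf16, fp16, fp32, use_flash_attn


print(resolve_precision())  # e.g. (True, False, False, True) on an A100; (False, False, True, False) on CPU
```

Moving the flash_attn imports into __init__ also means the install-hint warnings only fire when flash attention would actually be used, rather than on every import of the module.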