softmax_in_fp32
Browse files- configuration_qwen.py +4 -2
- modeling_qwen.py +5 -1
configuration_qwen.py
CHANGED
@@ -37,6 +37,7 @@ class QWenConfig(PretrainedConfig):
|
|
37 |
tie_word_embeddings=False,
|
38 |
use_cache_quantization=False,
|
39 |
use_cache_kernel=False,
|
|
|
40 |
**kwargs,
|
41 |
):
|
42 |
self.vocab_size = vocab_size
|
@@ -61,8 +62,9 @@ class QWenConfig(PretrainedConfig):
|
|
61 |
self.use_logn_attn = use_logn_attn
|
62 |
self.use_flash_attn = use_flash_attn
|
63 |
self.no_bias = no_bias
|
64 |
-
self.use_cache_quantization=use_cache_quantization
|
65 |
-
self.use_cache_kernel=use_cache_kernel
|
|
|
66 |
super().__init__(
|
67 |
tie_word_embeddings=tie_word_embeddings,
|
68 |
**kwargs
|
|
|
37 |
tie_word_embeddings=False,
|
38 |
use_cache_quantization=False,
|
39 |
use_cache_kernel=False,
|
40 |
+
softmax_in_fp32=False,
|
41 |
**kwargs,
|
42 |
):
|
43 |
self.vocab_size = vocab_size
|
|
|
62 |
self.use_logn_attn = use_logn_attn
|
63 |
self.use_flash_attn = use_flash_attn
|
64 |
self.no_bias = no_bias
|
65 |
+
self.use_cache_quantization = use_cache_quantization
|
66 |
+
self.use_cache_kernel = use_cache_kernel
|
67 |
+
self.softmax_in_fp32 = softmax_in_fp32
|
68 |
super().__init__(
|
69 |
tie_word_embeddings=tie_word_embeddings,
|
70 |
**kwargs
|
modeling_qwen.py
CHANGED
@@ -280,6 +280,7 @@ class QWenAttention(nn.Module):
|
|
280 |
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
|
281 |
|
282 |
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
|
|
|
283 |
self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
|
284 |
self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
|
285 |
cache_dtype = torch.float
|
@@ -346,7 +347,10 @@ class QWenAttention(nn.Module):
|
|
346 |
if attention_mask is not None:
|
347 |
attn_weights = attn_weights + attention_mask
|
348 |
|
349 |
-
|
|
|
|
|
|
|
350 |
|
351 |
attn_weights = attn_weights.type(query.dtype)
|
352 |
attn_weights = self.attn_dropout(attn_weights)
|
|
|
280 |
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
|
281 |
|
282 |
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
|
283 |
+
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
|
284 |
self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
|
285 |
self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
|
286 |
cache_dtype = torch.float
|
|
|
347 |
if attention_mask is not None:
|
348 |
attn_weights = attn_weights + attention_mask
|
349 |
|
350 |
+
if self.softmax_in_fp32:
|
351 |
+
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
|
352 |
+
else:
|
353 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
354 |
|
355 |
attn_weights = attn_weights.type(query.dtype)
|
356 |
attn_weights = self.attn_dropout(attn_weights)
|