Crystalcareai committed: Update modeling_gemmoe.py

modeling_gemmoe.py (+8 -3)
@@ -617,9 +617,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

-        #
-
-
+        # Ensure query, key, and value states have the same dtype
+        common_dtype = query_states.dtype
+        key_states = key_states.to(dtype=common_dtype)
+        value_states = value_states.to(dtype=common_dtype)

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -628,6 +629,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+        # Cast causal_mask to the same dtype as query_states
+        if causal_mask is not None:
+            causal_mask = causal_mask.to(dtype=query_states.dtype)
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
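The change aligns the dtypes of the attention inputs before calling SDPA: torch.nn.functional.scaled_dot_product_attention expects query, key, and value (and a floating-point attn_mask) to share a dtype, and a KV cache kept in a different precision than the current query projection can break that. Below is a minimal, self-contained sketch of the same fix outside the model; the tensor shapes and the deliberate bfloat16/float32 mismatch are illustrative assumptions, not values taken from Gemmoe, and bfloat16 is chosen only so the snippet runs on CPU.

# A minimal sketch (not Gemmoe's actual forward) of the dtype alignment
# the commit adds. Shapes and the bfloat16/float32 mismatch are assumed
# for illustration.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8

query_states = torch.randn(batch, heads, q_len, head_dim, dtype=torch.bfloat16)
# Simulate a KV cache stored in a different precision than the query.
key_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.float32)
value_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.float32)
# An additive mask, as a custom attn_mask would be; also float32 here.
causal_mask = torch.zeros(batch, 1, q_len, kv_len, dtype=torch.float32)

# The commit's fix: cast keys, values, and the mask to the query's dtype,
# since SDPA can reject mismatched input dtypes depending on the backend.
common_dtype = query_states.dtype
key_states = key_states.to(dtype=common_dtype)
value_states = value_states.to(dtype=common_dtype)
if causal_mask is not None:
    causal_mask = causal_mask.to(dtype=common_dtype)

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape, attn_output.dtype)  # (1, 2, 4, 8), torch.bfloat16

Without the .to(dtype=...) casts, the same call typically fails with a dtype-mismatch RuntimeError; the commit leaves the existing contiguous() workaround for the torch==2.1.2 memory-efficient-backend bug untouched.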