Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +3 -0
modeling_gemmoe.py
CHANGED
@@ -629,6 +629,9 @@ class GemmoeSdpaAttention(GemmoeAttention):
|
|
629 |
key_states = key_states.contiguous()
|
630 |
value_states = value_states.contiguous()
|
631 |
|
|
|
|
|
|
|
632 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
633 |
query_states,
|
634 |
key_states,
|
|
|
629 |
key_states = key_states.contiguous()
|
630 |
value_states = value_states.contiguous()
|
631 |
|
632 |
+
|
633 |
+
causal_mask = causal_mask.to(query_states.dtype)
|
634 |
+
|
635 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
636 |
query_states,
|
637 |
key_states,
|