Crystalcareai committed on
Commit
99b259e
·
verified ·
1 Parent(s): acc1e63

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +3 -0
modeling_gemmoe.py CHANGED
@@ -629,6 +629,9 @@ class GemmoeSdpaAttention(GemmoeAttention):
629
  key_states = key_states.contiguous()
630
  value_states = value_states.contiguous()
631
 
 
 
 
632
  attn_output = torch.nn.functional.scaled_dot_product_attention(
633
  query_states,
634
  key_states,
 
629
  key_states = key_states.contiguous()
630
  value_states = value_states.contiguous()
631
 
632
+
633
+ causal_mask = causal_mask.to(query_states.dtype)
634
+
635
  attn_output = torch.nn.functional.scaled_dot_product_attention(
636
  query_states,
637
  key_states,