Crystalcareai committed on
Commit 1b5a82b · verified · 1 Parent(s): a6948ec

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +8 -3
modeling_gemmoe.py CHANGED
@@ -617,9 +617,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
 
-        # Convert causal_mask to the same dtype as query_states
-        if causal_mask is not None:
-            causal_mask = causal_mask.to(dtype=query_states.dtype)
+        # Ensure query, key, and value states have the same dtype
+        common_dtype = query_states.dtype
+        key_states = key_states.to(dtype=common_dtype)
+        value_states = value_states.to(dtype=common_dtype)
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -628,6 +629,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
+        # Cast causal_mask to the same dtype as query_states
+        if causal_mask is not None:
+            causal_mask = causal_mask.to(dtype=query_states.dtype)
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
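Both hunks enforce the same invariant: torch.nn.functional.scaled_dot_product_attention expects query, key, value, and any additive attn_mask to share a single dtype, and typically raises a RuntimeError on a mismatch. Below is a minimal, self-contained sketch of the pattern this commit applies; the tensor shapes and the fp16/fp32 mismatch are illustrative assumptions, not values taken from GemMoE.

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8

# Assumed mixed-precision mismatch: fp16 queries, fp32 keys/values/mask.
query_states = torch.randn(batch, heads, q_len, head_dim, dtype=torch.float16)
key_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.float32)
value_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.float32)
# Additive causal mask: 0 where attention is allowed, -inf strictly above the diagonal.
causal_mask = torch.full((q_len, kv_len), float("-inf")).triu(diagonal=1)

# First hunk's pattern: align key/value dtypes with the query.
common_dtype = query_states.dtype
key_states = key_states.to(dtype=common_dtype)
value_states = value_states.to(dtype=common_dtype)

# Second hunk's pattern: cast the mask to the query dtype just before SDPA.
if causal_mask is not None:
    causal_mask = causal_mask.to(dtype=query_states.dtype)

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
)
print(attn_output.dtype)  # torch.float16

With the casts in place the call succeeds and the output inherits the query dtype (torch.float16 here).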