Crystalcareai committed
Commit a6948ec · verified · 1 Parent(s): 2fbb6c8

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py +4 -6
modeling_gemmoe.py CHANGED
@@ -616,11 +616,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
         causal_mask = attention_mask
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
-
-        # Cast query, key, and value states to the same dtype (bf16)
-        query_states = query_states.to(dtype=torch.bfloat16)
-        key_states = key_states.to(dtype=torch.bfloat16)
-        value_states = value_states.to(dtype=torch.bfloat16)
+
+        # Convert causal_mask to the same dtype as query_states
+        if causal_mask is not None:
+            causal_mask = causal_mask.to(dtype=query_states.dtype)
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -629,7 +628,6 @@ class GemmoeSdpaAttention(GemmoeAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
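For context on why the mask is cast instead of the states: torch.nn.functional.scaled_dot_product_attention accepts either a boolean mask or a floating-point additive mask whose dtype matches the query's, so a float32 mask paired with bfloat16 query/key/value states raises a dtype-mismatch error. The removed workaround forced q/k/v into bfloat16; this commit instead casts the mask to whatever dtype the states already have. The standalone sketch below reproduces the pattern — the shapes and the mask construction are illustrative assumptions, not code taken from modeling_gemmoe.py:

    # Standalone sketch (illustrative shapes; not the model's actual forward pass).
    import torch
    import torch.nn.functional as F

    batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8

    # Query/key/value states in bfloat16, as in the Gemmoe attention path.
    query_states = torch.randn(batch, heads, q_len, head_dim, dtype=torch.bfloat16)
    key_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.bfloat16)
    value_states = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.bfloat16)

    # Additive causal mask built in float32: 0 where attention is allowed,
    # a large negative value above the diagonal. Broadcasts over heads.
    causal_mask = torch.triu(
        torch.full((batch, 1, q_len, kv_len), torch.finfo(torch.float32).min),
        diagonal=1,
    )

    # Without this cast, SDPA rejects a float mask whose dtype differs from
    # the query's. This is the line the commit adds.
    if causal_mask is not None:
        causal_mask = causal_mask.to(dtype=query_states.dtype)

    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=causal_mask
    )
    print(attn_output.shape)  # torch.Size([1, 2, 4, 8])

Casting the mask rather than the states also keeps the module dtype-agnostic: the same code path works whether the model runs in float16, bfloat16, or float32.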