Crystalcareai committed (verified)
Commit 382509c · 1 Parent(s): 877429a

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py +17 -19
modeling_gemmoe.py CHANGED
@@ -616,11 +616,11 @@ class GemmoeSdpaAttention(GemmoeAttention):
         causal_mask = attention_mask
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
-
-        # Cast query, key, and value states to the same dtype (bf16)
-        query_states = query_states.to(dtype=torch.bfloat16)
-        key_states = key_states.to(dtype=torch.bfloat16)
-        value_states = value_states.to(dtype=torch.bfloat16)
+
+        # Ensure query, key, and value states have the same dtype
+        common_dtype = query_states.dtype
+        key_states = key_states.to(dtype=common_dtype)
+        value_states = value_states.to(dtype=common_dtype)
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
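The hunk above replaces the hard-coded bfloat16 cast with a cast to the query's own dtype. As background, torch.nn.functional.scaled_dot_product_attention requires query, key, and value to share a dtype, which a KV cache held in a different precision would violate. A minimal standalone sketch of that constraint (shapes and dtypes are illustrative, not taken from this model):

import torch
import torch.nn.functional as F

# Toy tensors: batch=1, heads=2, seq=4, head_dim=8. The query is bf16 while the
# cached key/value are fp32, mimicking a precision mismatch between the current
# hidden states and an older cache.
q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.float32)
v = torch.randn(1, 2, 4, 8, dtype=torch.float32)

# Casting key/value to the query dtype (as the new code does) avoids both the
# dtype-mismatch error and the old behaviour of forcing everything to bf16.
k = k.to(dtype=q.dtype)
v = v.to(dtype=q.dtype)
out = F.scaled_dot_product_attention(q, k, v)
print(out.dtype)  # torch.bfloat16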
@@ -629,6 +629,9 @@ class GemmoeSdpaAttention(GemmoeAttention):
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # Cast causal_mask to the same dtype as query_states
+        if causal_mask is not None:
+            causal_mask = causal_mask.to(dtype=query_states.dtype)
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
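The added mask cast follows the same rule: SDPA accepts either a boolean attn_mask or a floating-point mask in the query's dtype, so an additive causal mask built in fp32 has to be converted before a bf16/fp16 forward pass. A small sketch with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)

# Additive causal mask built in fp32: -inf above the diagonal blocks attention
# to future positions, 0 elsewhere.
causal_mask = torch.full((4, 4), float("-inf")).triu(1)

# Cast the float mask to the query dtype before the call, as the diff does.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask.to(q.dtype))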
@@ -1212,15 +1215,11 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
-        # Handle unused parameters
-        if self.training:
-            for expert in self.model.layers[-1].block_sparse_moe.experts:
-                for param in expert.parameters():
-                    if param.requires_grad and param.grad is None:
-                        param.grad = torch.zeros_like(param)
+        # Ensure hidden_states and lm_head have compatible dtypes
+        hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
+
+        logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
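The hunk above drops the logits.float() upcast and the training-time workaround that zero-filled gradients for unused experts, and instead aligns hidden_states with the lm_head weight dtype before the projection. A standalone sketch (layer sizes invented) of the mismatch this avoids, since nn.Linear errors when its input dtype differs from its weight dtype:

import torch
from torch import nn

# A bf16 output head, as is common for models loaded in half precision.
lm_head = nn.Linear(16, 32, bias=False).to(torch.bfloat16)

# Hidden states that ended up in fp32 (e.g. after an upcasting op elsewhere).
hidden_states = torch.randn(2, 4, 16, dtype=torch.float32)

# Without this cast, lm_head(hidden_states) raises a dtype-mismatch error.
hidden_states = hidden_states.to(dtype=lm_head.weight.dtype)
logits = lm_head(hidden_states)
print(logits.dtype)  # torch.bfloat16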
@@ -1299,8 +1298,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
             past_length = 0
         else:
             past_length = cache_position[-1] + 1
-        input_ids = input_ids[:, past_length:]
-        position_ids = position_ids[:, past_length:]
+        input_ids = input_ids[:, -1].unsqueeze(-1)
+        position_ids = position_ids[:, -1].unsqueeze(-1)
 
         cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
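The hunk above narrows the decode-step inputs to the single newest token once the KV cache already covers the earlier positions, rather than slicing from past_length. A tiny illustration with invented token ids:

import torch

input_ids = torch.tensor([[5, 9, 3, 7]])                       # prompt plus tokens generated so far
position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)  # tensor([[0, 1, 2, 3]])

# With a populated cache, only the last token is fed to the next forward pass.
input_ids = input_ids[:, -1].unsqueeze(-1)        # tensor([[7]])
position_ids = position_ids[:, -1].unsqueeze(-1)  # tensor([[3]])
print(input_ids.shape, position_ids.shape)        # torch.Size([1, 1]) torch.Size([1, 1])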
 
@@ -1329,6 +1328,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
@@ -1418,10 +1418,8 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
             else:
                 sequence_lengths = -1
 