Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py  +19 -17  CHANGED
@@ -616,11 +616,11 @@ class GemmoeSdpaAttention(GemmoeAttention):
         causal_mask = attention_mask
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
-
-
-
-
-
+
+        # Cast query, key, and value states to the same dtype (bf16)
+        query_states = query_states.to(dtype=torch.bfloat16)
+        key_states = key_states.to(dtype=torch.bfloat16)
+        value_states = value_states.to(dtype=torch.bfloat16)

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -629,9 +629,6 @@ class GemmoeSdpaAttention(GemmoeAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # Cast causal_mask to the same dtype as query_states
-        if causal_mask is not None:
-            causal_mask = causal_mask.to(dtype=query_states.dtype)

         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
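Note on the two attention hunks above: the new code casts query_states, key_states, and value_states to a single dtype (bfloat16) before calling torch.nn.functional.scaled_dot_product_attention and drops the earlier cast of causal_mask alone; SDPA expects its tensor inputs to share a dtype. A minimal sketch of that constraint, separate from the model code (shapes and values here are illustrative, not taken from modeling_gemmoe.py):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
k = torch.randn(1, 4, 8, 16, dtype=torch.float32)  # mismatched dtype
v = torch.randn(1, 4, 8, 16, dtype=torch.float32)

try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    print("mixed dtypes rejected:", err)

# Casting everything to one dtype, as the commit does with bf16, succeeds.
k = k.to(dtype=torch.bfloat16)
v = v.to(dtype=torch.bfloat16)
out = F.scaled_dot_product_attention(q, k, v)
print(out.dtype)  # torch.bfloat16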
@@ -1215,11 +1212,15 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )

         hidden_states = outputs[0]
-
-        # Ensure hidden_states and lm_head have compatible dtypes
-        hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
-
         logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        # Handle unused parameters
+        if self.training:
+            for expert in self.model.layers[-1].block_sparse_moe.experts:
+                for param in expert.parameters():
+                    if param.requires_grad and param.grad is None:
+                        param.grad = torch.zeros_like(param)

         loss = None
         if labels is not None:
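Note on the hunk above: logits are now upcast with logits.float() after the lm_head projection, and during training any expert parameter in the last layer's block_sparse_moe that never received a gradient is given an explicit zero gradient, so every trainable parameter ends up with a .grad tensor (useful when an optimizer step or gradient reducer expects one per parameter). A self-contained sketch of the zero-grad idea, with a plain nn.Linear standing in for an unselected expert (the stand-in is hypothetical, not part of the model):

import torch
import torch.nn as nn

# Hypothetical stand-in for a MoE expert the router routed no tokens to.
unused_expert = nn.Linear(8, 8)

# Backward ran elsewhere and never touched this expert, so its parameters
# have requires_grad=True but .grad is still None.
for param in unused_expert.parameters():
    if param.requires_grad and param.grad is None:
        param.grad = torch.zeros_like(param)  # materialize a zero gradient

print(all(p.grad is not None for p in unused_expert.parameters()))  # True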
@@ -1298,8 +1299,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
             past_length = 0
         else:
             past_length = cache_position[-1] + 1
-        input_ids = input_ids[:,
-        position_ids = position_ids[:,
+        input_ids = input_ids[:, past_length:]
+        position_ids = position_ids[:, past_length:]

         cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

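Note on the hunk above: in prepare_inputs_for_generation, input_ids and position_ids are now sliced from past_length onward, so only the tokens not already covered by the KV cache are fed to the model, and cache_position continues counting from past_length. A toy illustration of that bookkeeping (tensor values are made up):

import torch

input_ids = torch.tensor([[5, 7, 9, 11, 13]])                  # full sequence so far
position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)  # [[0, 1, 2, 3, 4]]
past_length = 3                                                # tokens already cached

input_ids = input_ids[:, past_length:]        # [[11, 13]]
position_ids = position_ids[:, past_length:]  # [[3, 4]]
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1])
print(input_ids.tolist(), position_ids.tolist(), cache_position.tolist())
# [[11, 13]] [[3, 4]] [3, 4]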
@@ -1328,7 +1329,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
@@ -1418,8 +1418,10 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-
-                sequence_lengths =
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1

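Note on the hunk above: the sequence-classification head now locates the last non-pad token by taking argmax over the pad-token mask and subtracting one, then applies a modulo over the sequence length so that rows with no pad token wrap to the last position instead of relying on negative indexing, which keeps the computation ONNX-exportable. A small worked example (pad_token_id and the tensor values are illustrative):

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13,  0,  0],   # padded row: last real token at index 2
    [21, 22, 23, 24, 25],   # unpadded row: last real token at index 4
])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# argmax finds the first pad position (0 when there is none) -> tensor([ 2, -1])
sequence_lengths = sequence_lengths % input_ids.shape[-1]
# the -1 of the unpadded row wraps to the last index         -> tensor([2, 4])
print(sequence_lengths)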