Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files
- modeling_gemmoe.py +3 -3
modeling_gemmoe.py
CHANGED
@@ -702,10 +702,10 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
702 |
expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
|
703 |
current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
|
704 |
|
705 |
-
|
|
|
706 |
|
707 |
-
|
708 |
-
return final_hidden_states, router_logits
|
709 |
|
710 |
|
711 |
class GemmoeDecoderLayer(nn.Module):
|
|
|
702 |
expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
|
703 |
current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
|
704 |
|
705 |
+
# Cast current_hidden_states to the same data type as final_hidden_states
|
706 |
+
current_hidden_states = current_hidden_states.to(final_hidden_states.dtype)
|
707 |
|
708 |
+
final_hidden_states.index_add_(0, token_indices, current_hidden_states)
|
|
|
709 |
|
710 |
|
711 |
class GemmoeDecoderLayer(nn.Module):
|