Crystalcareai commited on
Commit
1318de9
·
verified ·
1 Parent(s): cfc4ccd

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +3 -3
modeling_gemmoe.py CHANGED
@@ -702,10 +702,10 @@ class GemmoeSparseMoeBlock(nn.Module):
702
  expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
703
  current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
704
 
705
- final_hidden_states.index_add_(0, token_indices, current_hidden_states)
 
706
 
707
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
708
- return final_hidden_states, router_logits
709
 
710
 
711
  class GemmoeDecoderLayer(nn.Module):
 
702
  expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
703
  current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
704
 
705
+ # Cast current_hidden_states to the same data type as final_hidden_states
706
+ current_hidden_states = current_hidden_states.to(final_hidden_states.dtype)
707
 
708
+ final_hidden_states.index_add_(0, token_indices, current_hidden_states)
 
709
 
710
 
711
  class GemmoeDecoderLayer(nn.Module):