Crystalcareai committed on
Commit
8825292
·
verified ·
1 Parent(s): 1318de9

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +4 -1
modeling_gemmoe.py CHANGED
@@ -669,7 +669,7 @@ class GemmoeSparseMoeBlock(nn.Module):
669
 
670
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
671
 
672
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
673
  hidden_states = hidden_states.to(self.gate.weight.device)
674
  batch_size, sequence_length, hidden_dim = hidden_states.shape
675
  hidden_states = hidden_states.view(-1, hidden_dim)
@@ -707,6 +707,9 @@ class GemmoeSparseMoeBlock(nn.Module):
707
 
708
  final_hidden_states.index_add_(0, token_indices, current_hidden_states)
709
 
 
 
 
710
 
711
  class GemmoeDecoderLayer(nn.Module):
712
  def __init__(self, config: GemmoeConfig, layer_idx: int):
 
669
 
670
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
671
 
672
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
673
  hidden_states = hidden_states.to(self.gate.weight.device)
674
  batch_size, sequence_length, hidden_dim = hidden_states.shape
675
  hidden_states = hidden_states.view(-1, hidden_dim)
 
707
 
708
  final_hidden_states.index_add_(0, token_indices, current_hidden_states)
709
 
710
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
711
+ return final_hidden_states, router_logits
712
+
713
 
714
  class GemmoeDecoderLayer(nn.Module):
715
  def __init__(self, config: GemmoeConfig, layer_idx: int):