Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+12 -26)
@@ -670,44 +670,30 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states = hidden_states.to(self.gate.weight.device)
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        top_routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
         # we cast back to the input dtype
-        top_routing_weights = top_routing_weights.to(hidden_states.dtype)
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-        )
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            token_indices = (selected_experts == expert_idx).any(dim=-1).nonzero(as_tuple=True)[0]
-
-            if token_indices.numel() == 0:
-                continue
-
-            current_state = hidden_states[token_indices]
-            current_hidden_states = expert_layer(current_state)
-
-            expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
-            current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
-
-            final_hidden_states[token_indices] += current_hidden_states
-
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
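For context, here is a minimal, self-contained sketch of the routing-and-dispatch scheme the new forward implements. The toy sizes, the seed, and the plain nn.Linear stand-ins for the gate and for the GemmoeBlockSparseTop2MLP experts are all made up for illustration; only the routing logic mirrors the committed code.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, hidden_dim = 4, 2, 8
batch_size, sequence_length = 2, 3

# Stand-ins: the real block uses GemmoeBlockSparseTop2MLP experts and a learned gate.
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
gate = nn.Linear(hidden_dim, num_experts, bias=False)

x = torch.randn(batch_size, sequence_length, hidden_dim)
hidden_states = x.view(-1, hidden_dim)                # (tokens, hidden_dim), tokens = 6

# Router: softmax in float32 for stability, keep the top_k experts per token,
# renormalize their weights, then cast back to the input dtype.
router_logits = gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(hidden_states.dtype)

# Dispatch: row i of the original becomes rows i*top_k .. i*top_k + top_k - 1,
# so the flattened top-k indices line up with the repeated rows.
hidden_states = hidden_states.repeat_interleave(top_k, dim=0)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)
for i in range(num_experts):
    y[flat_topk_idx == i] = experts[i](hidden_states[flat_topk_idx == i])

# Combine: weight each expert output by its routing weight and sum over top_k.
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
print(final_hidden_states.shape)  # torch.Size([2, 3, 8])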
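Design-wise, the new version trades the Mixtral-style gather loop (a zeros buffer plus per-expert index arithmetic) for duplicating every token top_k times with repeat_interleave, routing rows to experts by boolean mask, and folding the top-k outputs back down with a weighted sum. Continuing the sketch above, a hypothetical cross-check (not part of the commit) against a dense evaluation of every expert confirms the masked dispatch computes the same mixture:

# Dense reference: run every expert on every token, then keep only the top-k outputs.
tokens = x.view(-1, hidden_dim)
dense = torch.stack([expert(tokens) for expert in experts], dim=1)           # (tokens, num_experts, hidden_dim)
picked = dense.gather(1, topk_idx.unsqueeze(-1).expand(-1, -1, hidden_dim))  # (tokens, top_k, hidden_dim)
ref = (picked * topk_weight.unsqueeze(-1)).sum(dim=1)
ref = ref.reshape(batch_size, sequence_length, hidden_dim)
print(torch.allclose(final_hidden_states, ref, atol=1e-6))                   # True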