Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files - modeling_gemmoe.py (+18 -48)
modeling_gemmoe.py
CHANGED
@@ -221,33 +221,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
|
|
221 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
222 |
return q_embed, k_embed
|
223 |
|
224 |
-
class GemmoeMLP(nn.Module):
|
225 |
-
def __init__(self, config):
|
226 |
-
super().__init__()
|
227 |
-
self.config = config
|
228 |
-
self.hidden_size = config.hidden_size
|
229 |
-
self.intermediate_size = config.intermediate_size
|
230 |
-
|
231 |
-
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
232 |
-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
233 |
-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
234 |
-
|
235 |
-
self.act_fn = approx_gelu
|
236 |
-
|
237 |
-
def forward(self, x):
|
238 |
-
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
239 |
-
|
240 |
-
|
241 |
def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
|
252 |
class GemmoeAttention(nn.Module):
|
253 |
"""
|
@@ -569,17 +552,7 @@ class GemmoeSdpaAttention(GemmoeAttention):
|
|
569 |
GemmoeAttention as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
570 |
SDPA API.
|
571 |
"""
|
572 |
-
|
573 |
-
"""
|
574 |
-
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
575 |
-
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
576 |
-
"""
|
577 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
578 |
-
if n_rep == 1:
|
579 |
-
return hidden_states
|
580 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
581 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
582 |
-
|
583 |
def forward(
|
584 |
self,
|
585 |
hidden_states: torch.Tensor,
|
@@ -670,10 +643,12 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
|
|
670 |
super().__init__()
|
671 |
self.ffn_dim = config.intermediate_size
|
672 |
self.hidden_dim = config.hidden_size
|
|
|
673 |
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
674 |
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
675 |
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
676 |
-
|
|
|
677 |
|
678 |
def forward(self, hidden_states):
|
679 |
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
@@ -734,20 +709,14 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
734 |
|
735 |
|
736 |
class GemmoeDecoderLayer(nn.Module):
|
737 |
-
"""
|
738 |
-
Decoder layer for the Gemmoe model.
|
739 |
-
|
740 |
-
Args:
|
741 |
-
config (GemmoeConfig): The configuration object for the Gemmoe model.
|
742 |
-
layer_idx (int): The index of the layer.
|
743 |
-
"""
|
744 |
def __init__(self, config: GemmoeConfig, layer_idx: int):
|
745 |
super().__init__()
|
746 |
self.hidden_size = config.hidden_size
|
747 |
-
|
748 |
-
self.
|
|
|
749 |
self.block_sparse_moe = GemmoeSparseMoeBlock(config)
|
750 |
-
self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
751 |
self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
752 |
|
753 |
def forward(
|
@@ -901,6 +870,7 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
901 |
self.layers = nn.ModuleList(
|
902 |
[GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
903 |
)
|
|
|
904 |
self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
905 |
|
906 |
self.gradient_checkpointing = False
|
|
|
221 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
222 |
return q_embed, k_embed
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
225 |
+
"""
|
226 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
227 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
228 |
+
"""
|
229 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
230 |
+
if n_rep == 1:
|
231 |
+
return hidden_states
|
232 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
233 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
234 |
|
235 |
class GemmoeAttention(nn.Module):
|
236 |
"""
|
|
|
552 |
GemmoeAttention as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
553 |
SDPA API.
|
554 |
"""
|
555 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
def forward(
|
557 |
self,
|
558 |
hidden_states: torch.Tensor,
|
|
|
643 |
super().__init__()
|
644 |
self.ffn_dim = config.intermediate_size
|
645 |
self.hidden_dim = config.hidden_size
|
646 |
+
|
647 |
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
648 |
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
649 |
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
650 |
+
|
651 |
+
self.act_fn = approx_gelu
|
652 |
|
653 |
def forward(self, hidden_states):
|
654 |
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
|
|
709 |
|
710 |
|
711 |
class GemmoeDecoderLayer(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
712 |
def __init__(self, config: GemmoeConfig, layer_idx: int):
|
713 |
super().__init__()
|
714 |
self.hidden_size = config.hidden_size
|
715 |
+
|
716 |
+
self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
717 |
+
|
718 |
self.block_sparse_moe = GemmoeSparseMoeBlock(config)
|
719 |
+
self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
720 |
self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
721 |
|
722 |
def forward(
|
|
|
870 |
self.layers = nn.ModuleList(
|
871 |
[GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
872 |
)
|
873 |
+
|
874 |
self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
875 |
|
876 |
self.gradient_checkpointing = False
|