Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +12 -1
modeling_gemmoe.py
CHANGED
@@ -553,6 +553,17 @@ class GemmoeSdpaAttention(GemmoeAttention):
|
|
553 |
SDPA API.
|
554 |
"""
|
555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
def forward(
|
557 |
self,
|
558 |
hidden_states: torch.Tensor,
|
@@ -578,7 +589,7 @@ class GemmoeSdpaAttention(GemmoeAttention):
|
|
578 |
output_attentions=output_attentions,
|
579 |
use_cache=use_cache,
|
580 |
cache_position=cache_position,
|
581 |
-
|
582 |
|
583 |
bsz, q_len, _ = hidden_states.size()
|
584 |
|
|
|
553 |
SDPA API.
|
554 |
"""
|
555 |
|
556 |
+
def repeat_kv(self, x, n_rep):
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): a tensor of
    shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), with each head's
    copies laid out consecutively.
    """
    bsz, n_kv_heads, seq_len, hdim = x.shape
    # One repetition means nothing to do — hand the tensor back untouched.
    if n_rep == 1:
        return x
    # Insert a broadcast axis, expand it (a view, no copy), then fold the
    # repetition axis into the head axis with a single reshape.
    expanded = x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seq_len, hdim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, hdim)
|
566 |
+
|
567 |
def forward(
|
568 |
self,
|
569 |
hidden_states: torch.Tensor,
|
|
|
589 |
output_attentions=output_attentions,
|
590 |
use_cache=use_cache,
|
591 |
cache_position=cache_position,
|
592 |
+
)
|
593 |
|
594 |
bsz, q_len, _ = hidden_states.size()
|
595 |
|