Crystalcareai committed
Commit cf62fb2 · verified · 1 parent: da3db13

Update modeling_gemmoe.py

Files changed (1):
  modeling_gemmoe.py  +12 -1
modeling_gemmoe.py CHANGED
@@ -553,6 +553,17 @@ class GemmoeSdpaAttention(GemmoeAttention):
     SDPA API.
     """
 
+    def repeat_kv(self, x, n_rep):
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = x.shape
+        if n_rep == 1:
+            return x
+        x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return x.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -578,7 +589,7 @@ class GemmoeSdpaAttention(GemmoeAttention):
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-            ),
+            )
 
         bsz, q_len, _ = hidden_states.size()
 
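
For reference, below is a minimal standalone check of the helper added by this commit. It is a sketch, not part of the commit: the method body never touches self, so it is exercised here as a free function, and the equivalence asserted is the one stated in its own docstring. The tensor shapes (batch=2, 4 KV heads, seqlen=8, head_dim=16, n_rep=3) are illustrative, not taken from the model config.

import torch

def repeat_kv(x, n_rep):
    # Same body as the GemmoeSdpaAttention.repeat_kv method added in this commit.
    batch, num_key_value_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return x.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# (2, 4, 8, 16) repeated 3x along the head dim -> (2, 12, 8, 16).
x = torch.randn(2, 4, 8, 16)
out = repeat_kv(x, 3)
assert out.shape == (2, 12, 8, 16)
# Matches the torch.repeat_interleave equivalence claimed in the docstring.
assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=1))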