katuni4ka committed on
Commit
c4d0882
1 Parent(s): 28a5a00

Update modeling_internlm.py

Browse files
Files changed (1) hide show
  1. modeling_internlm.py +4 -16
modeling_internlm.py CHANGED
@@ -243,22 +243,10 @@ def rotate_half(x):
243
 
244
  # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
245
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
246
- if position_ids.size(1) == 1:
247
- q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
248
- q_sin = sin[position_ids].unsqueeze(1).expand(q.shape)
249
- q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
250
-
251
- position_ids = position_ids.flatten() + 1
252
- max_length = max(position_ids)
253
- position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
254
- k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
255
- k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
256
- k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
257
- else:
258
- cos = cos[position_ids].unsqueeze(1)
259
- sin = sin[position_ids].unsqueeze(1)
260
- q_embed = (q * cos) + (rotate_half(q) * sin)
261
- k_embed = (k * cos) + (rotate_half(k) * sin)
262
  return q_embed, k_embed
263
 
264
 
 
243
 
244
  # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
245
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
246
+ cos = cos[position_ids].unsqueeze(1)
247
+ sin = sin[position_ids].unsqueeze(1)
248
+ q_embed = (q * cos) + (rotate_half(q) * sin)
249
+ k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
 
 
 
250
  return q_embed, k_embed
251
 
252