Update modeling_internlm.py
Browse files- modeling_internlm.py +4 -16
modeling_internlm.py
CHANGED
@@ -243,22 +243,10 @@ def rotate_half(x):
|
|
243 |
|
244 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
245 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
position_ids = position_ids.flatten() + 1
|
252 |
-
max_length = max(position_ids)
|
253 |
-
position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
|
254 |
-
k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
|
255 |
-
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
|
256 |
-
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
|
257 |
-
else:
|
258 |
-
cos = cos[position_ids].unsqueeze(1)
|
259 |
-
sin = sin[position_ids].unsqueeze(1)
|
260 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
261 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
262 |
return q_embed, k_embed
|
263 |
|
264 |
|
|
|
243 |
|
244 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
245 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embedding to the query and key tensors.

    ``cos`` and ``sin`` are indexed with ``position_ids`` and given a new
    axis at dim 1; the results are then combined with ``q`` and ``k`` as
    ``x * cos + rotate_half(x) * sin``.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, indexed along its first axis by ``position_ids``.
        sin: Sine table, indexed along its first axis by ``position_ids``.
        position_ids: Integer indices used to select entries of ``cos``/``sin``.

    Returns:
        A ``(q_embed, k_embed)`` tuple holding the rotated query and key.
    """
    # NOTE(review): the axis inserted at dim 1 presumably broadcasts over the
    # attention-heads dimension of q/k -- confirm against the caller's shapes.
    cos_at_pos = cos[position_ids].unsqueeze(1)
    sin_at_pos = sin[position_ids].unsqueeze(1)

    def _rotate(states):
        # Same formula for query and key: x * cos + rotate_half(x) * sin.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    q_embed = _rotate(q)
    k_embed = _rotate(k)
    return q_embed, k_embed
|
251 |
|
252 |
|