fix(modeling_phi): Fixes cached generation when above maximum context length.
Browse files- modeling_phi.py +11 -20
modeling_phi.py
CHANGED
@@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
|
|
261 |
seqlen_offset: int = 0,
|
262 |
**kwargs,
|
263 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
264 |
-
seq_start = seqlen_offset
|
265 |
-
seq_end = seq_start + qkv.shape[1]
|
266 |
-
|
267 |
if (
|
268 |
-
self.
|
|
|
269 |
or self._cos_cached.dtype != qkv.dtype
|
270 |
or (self.training and self._cos_cached.is_inference())
|
271 |
):
|
272 |
-
self._update_cos_sin_cache(
|
273 |
|
274 |
if kv is None:
|
275 |
return _apply_rotary_emb_qkv(
|
276 |
qkv,
|
277 |
-
self._cos_cached[
|
278 |
-
self._sin_cached[
|
279 |
)
|
280 |
else:
|
281 |
q = _apply_rotary_emb(
|
282 |
qkv,
|
283 |
-
self._cos_cached[
|
284 |
-
self._sin_cached[
|
285 |
)
|
286 |
kv = _apply_rotary_emb_kv(
|
287 |
kv,
|
288 |
-
self._cos_cached[
|
289 |
-
self._sin_cached[
|
290 |
)
|
291 |
|
292 |
return q, kv
|
@@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
|
|
498 |
sequence_end = sequence_start + kv.shape[1]
|
499 |
|
500 |
# When the current sequence length is equal to or larger than the maximum sequence length,
|
501 |
-
# we need to
|
502 |
if sequence_end >= inference_params.max_seqlen:
|
503 |
-
inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx]
|
504 |
|
505 |
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
506 |
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
|
@@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
864 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
865 |
**kwargs,
|
866 |
) -> Dict[str, Any]:
|
867 |
-
# Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
|
868 |
-
# the maximum sequence length
|
869 |
-
if input_ids.shape[1] > self.config.n_positions:
|
870 |
-
input_ids = input_ids[:, -self.config.n_positions :]
|
871 |
-
if attention_mask is not None:
|
872 |
-
attention_mask = attention_mask[:, -self.config.n_positions :]
|
873 |
-
|
874 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
875 |
past_key_values = InferenceParams(
|
876 |
max_seqlen=self.config.n_positions,
|
|
|
261 |
seqlen_offset: int = 0,
|
262 |
**kwargs,
|
263 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
264 |
if (
|
265 |
+
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
266 |
+
or self._cos_cached.device != qkv.device
|
267 |
or self._cos_cached.dtype != qkv.dtype
|
268 |
or (self.training and self._cos_cached.is_inference())
|
269 |
):
|
270 |
+
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
271 |
|
272 |
if kv is None:
|
273 |
return _apply_rotary_emb_qkv(
|
274 |
qkv,
|
275 |
+
self._cos_cached[seqlen_offset:],
|
276 |
+
self._sin_cached[seqlen_offset:],
|
277 |
)
|
278 |
else:
|
279 |
q = _apply_rotary_emb(
|
280 |
qkv,
|
281 |
+
self._cos_cached[seqlen_offset:],
|
282 |
+
self._sin_cached[seqlen_offset:],
|
283 |
)
|
284 |
kv = _apply_rotary_emb_kv(
|
285 |
kv,
|
286 |
+
self._cos_cached[seqlen_offset:],
|
287 |
+
self._sin_cached[seqlen_offset:],
|
288 |
)
|
289 |
|
290 |
return q, kv
|
|
|
496 |
sequence_end = sequence_start + kv.shape[1]
|
497 |
|
498 |
# When the current sequence length is equal to or larger than the maximum sequence length,
|
499 |
+
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
500 |
if sequence_end >= inference_params.max_seqlen:
|
501 |
+
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
502 |
|
503 |
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
504 |
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
|
|
|
862 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
863 |
**kwargs,
|
864 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
865 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
866 |
past_key_values = InferenceParams(
|
867 |
max_seqlen=self.config.n_positions,
|