Update myr1/modeling_deepseek.py

myr1/modeling_deepseek.py (+76 -19)
```diff
@@ -2,9 +2,19 @@
 modeling_deepseek.py

 An improved version of the DeepSeekV3 model code with added docstrings, in-line commentary,
-some mild refactoring, and suggestions for potential future enhancements
-
-
+some mild refactoring, and suggestions for potential future enhancements for **reasoning** and **efficiency**.
+This version incorporates architectural considerations for enhanced reasoning,
+efficiency improvements like GQA (configurable), and placeholders for more advanced features.
+Actual performance gains may vary based on your environment and training data.
+
+**Important Notes:**
+
+* **Configuration is Key:** Many reasoning and efficiency improvements are driven by configuration changes in `configuration_deepseek.py`. You will need to update your config file to fully utilize these features. See comments marked with `[CONFIG]` for configuration-related suggestions.
+* **Placeholders:** This code includes placeholders (comments and `TODO`s) for features like Sparse Attention, more advanced MoE gating, and Chain-of-Thought (CoT) prompting. Implementing these fully requires more code modifications and potentially changes to your training/inference pipelines.
+* **Data is Crucial for Reasoning:** Reasoning improvements heavily depend on training data. Consider fine-tuning or pre-training on reasoning-focused datasets and using techniques like Chain-of-Thought data augmentation.
+* **Grouped-Query Attention (GQA):** This version includes comments and configuration hints for GQA. Full GQA implementation would require modifying the attention logic in `DeepseekV3Attention.forward()` to handle grouped K/V heads. Currently, it's MQA-style.
+* **Sparse Attention:** Placeholder for integrating Sparse Attention (e.g., Longformer, BigBird). You would need to implement a `SparseDeepseekV3Attention` class and integrate it.
+
 """

 import math
```
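The `[CONFIG]` comments above assume new knobs in `configuration_deepseek.py`. As a minimal, illustrative sketch (field names follow the comments in this diff; the numeric values are made up, and `_attn_implementation` is the usual transformers attribute), the relevant settings might look like:

```python
# Illustrative only: configuration knobs referenced by the [CONFIG] comments in this diff.
# Verify the exact parameter names against your configuration_deepseek.py.
from .configuration_deepseek import DeepseekV3Config  # same relative import as the model file

config = DeepseekV3Config(
    num_attention_heads=128,     # query heads
    num_key_value_heads=16,      # fewer K/V heads than query heads -> GQA-style grouping
    scoring_func="softmax",      # MoE gate scoring: "sigmoid" or the new "softmax" option
    topk_method="topk_gating",   # expert selection: existing method or the new "topk_gating" option
)
config._attn_implementation = "flash_attention_2"  # key into ATTENTION_CLASSES below
```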
```diff
@@ -45,7 +55,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available

 # Import your configuration
-from .configuration_deepseek import DeepseekV3Config
+from .configuration_deepseek import DeepseekV3Config  # [CONFIG] Make sure DeepseekV3Config has new parameters

 import torch.distributed as dist
 import numpy as np
```
```diff
@@ -330,7 +340,7 @@ class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
         self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)


-# ==============================================================================
+# ==============================================================================
 # General Rotary helper functions
 # ==============================================================================

```
```diff
@@ -438,6 +448,8 @@ class MoEGate(nn.Module):
         logits = F.linear(hidden_states.float(), self.weight.float(), None)
         if self.scoring_func == "sigmoid":
             scores = logits.sigmoid()
+        elif self.scoring_func == "softmax":  # [CONFIG] Option for softmax gating
+            scores = logits.softmax(dim=-1)
         else:
             raise NotImplementedError(
                 f"Unsupported gating scoring function: {self.scoring_func}"
```
```diff
@@ -462,6 +474,9 @@ class MoEGate(nn.Module):
             tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
             _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
             topk_weight = scores_for_choice.gather(1, topk_idx)
+        elif self.topk_method == "topk_gating":  # [CONFIG] Option for simpler top-k gating
+            _, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+            topk_weight = torch.gather(scores, dim=-1, index=topk_idx)
         else:
             raise NotImplementedError(
                 f"Unsupported topk_method: {self.topk_method}"
```
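For readers unfamiliar with the two new branches, here is a standalone sketch of the same gating math (softmax scoring followed by plain top-k), equivalent to the `topk` + `gather` pair in the hunk; the tensor shapes are illustrative:

```python
# Standalone sketch of the gating added above (not the MoEGate module itself):
# softmax scoring over experts followed by a plain top-k selection per token.
import torch

def topk_gating(logits: torch.Tensor, top_k: int):
    """logits: (num_tokens, n_experts) router outputs."""
    scores = logits.softmax(dim=-1)                                    # "softmax" scoring_func
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
    return topk_idx, topk_weight

logits = torch.randn(4, 8)            # 4 tokens, 8 experts (made-up sizes)
idx, w = topk_gating(logits, top_k=2)
print(idx.shape, w.shape)             # torch.Size([4, 2]) torch.Size([4, 2])
```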
```diff
@@ -656,6 +671,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class DeepseekV3Attention(nn.Module):
     """
     Standard multi-headed attention for Deepseek.
+
+    **Reasoning & Efficiency Improvements Considered:**
+    * **Grouped-Query Attention (GQA):** Configurable via `config.num_key_value_heads` and `config.num_attention_heads`. If `num_key_value_heads < num_attention_heads`, GQA is implicitly enabled. See comments in `forward()` for GQA implementation hints. [CONFIG]
+    * **Sparse Attention:** Placeholder for integration. To use sparse attention, you would need to create a `SparseDeepseekV3Attention` class (e.g., based on LongformerAttention or BigBirdAttention) and replace `DeepseekV3Attention` in `ATTENTION_CLASSES` based on a config flag. [CONFIG]
     """
     def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
         super().__init__()
```
```diff
@@ -665,6 +684,13 @@ class DeepseekV3Attention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads  # [CONFIG] For GQA
+
+        if self.num_heads % self.num_key_value_heads != 0:  # GQA check
+            raise ValueError(
+                "num_attention_heads must be divisible by num_key_value_heads (for GQA)"
+            )
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # For GQA

         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
```
```diff
@@ -691,16 +717,16 @@ class DeepseekV3Attention(nn.Module):
             config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
         )

-        # K,V-proj (MQA style)
+        # K,V-proj (MQA/GQA style)
         self.kv_a_proj_with_mqa = nn.Linear(
             self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
+            config.kv_lora_rank + self.qk_rope_head_dim,
             bias=config.attention_bias,
         )
         self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
         self.kv_b_proj = nn.Linear(
             config.kv_lora_rank,
-            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            self.num_key_value_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),  # [GQA] num_key_value_heads here
             bias=False,
         )

```
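A quick, self-contained shape check for the resized `kv_b_proj` above. The head dimensions are illustrative (they follow the published DeepSeek-V3 defaults; substitute your own config values):

```python
# Illustrative shape arithmetic for the GQA-sized projections above (numbers are made up/defaults):
num_attention_heads = 128
num_key_value_heads = 16                               # [CONFIG]
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
q_head_dim = qk_nope_head_dim + qk_rope_head_dim       # 192
kv_lora_rank = 512

assert num_attention_heads % num_key_value_heads == 0
num_key_value_groups = num_attention_heads // num_key_value_heads  # 8 query heads share each K/V head

# kv_b_proj output features with the change above: one (k_nope, v) pair per K/V head
kv_b_out = num_key_value_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)
print(kv_b_out)   # 16 * (192 - 64 + 128) = 4096, versus 128 * 256 = 32768 before the change
```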
```diff
@@ -731,8 +757,8 @@ class DeepseekV3Attention(nn.Module):
                 base=self.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
+            scaling_type = self.config.rope_scaling["type"]      # [CONFIG] Rope scaling type
+            scaling_factor = self.config.rope_scaling["factor"]  # [CONFIG] Rope scaling factor

             if scaling_type == "linear":
                 self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
```
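The branch above reads the standard Hugging Face style `rope_scaling` dict (`{"type": ..., "factor": ...}`). A minimal sketch of what `"linear"` scaling does with that factor, assuming illustrative values (this is not the `DeepseekV3LinearScalingRotaryEmbedding` implementation itself):

```python
# Minimal sketch of linear RoPE scaling: positions are divided by the factor before
# computing the cos/sin tables, stretching the usable context window.
import torch

dim, base, scaling_factor = 64, 10000.0, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(4096, dtype=torch.float32) / scaling_factor   # compressed position indices
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
```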
```diff
@@ -782,6 +808,12 @@ class DeepseekV3Attention(nn.Module):
     ):
         """
         Standard forward pass for multi-headed self-attention.
+
+        **Grouped-Query Attention (GQA) Implementation Notes:**
+        If `num_key_value_heads < num_attention_heads` (GQA is configured):
+        1. The `kv` projection will produce `num_key_value_heads` * (head_dim * 2) channels.
+        2. We need to *repeat* the `key_states` and `value_states` `num_key_value_groups` times along the head dimension to match the `num_attention_heads` for the query.
+        3. The `repeat_kv` utility function is used for this repetition.
         """
         if "padding_mask" in kwargs:
             warnings.warn(
```
```diff
@@ -798,7 +830,7 @@ class DeepseekV3Attention(nn.Module):
         q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
         q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

-        # MQA: K,V from single projection
+        # MQA/GQA: K,V from single projection
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, k_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
```
```diff
@@ -806,7 +838,7 @@ class DeepseekV3Attention(nn.Module):
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
         kv = (
             self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, q_len, self.num_key_value_heads, self.qk_nope_head_dim + self.v_head_dim)  # [GQA] num_key_value_heads here
             .transpose(1, 2)
         )
         k_nope, value_states = torch.split(
```
```diff
@@ -829,10 +861,17 @@ class DeepseekV3Attention(nn.Module):
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, self.num_key_value_heads, q_len, self.q_head_dim)  # [GQA] num_key_value_heads here
         key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
         key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

+        value_states = value_states  # [GQA] num_key_value_heads is already in value_states
+
+        # GQA: Repeat K/V states if num_key_value_heads < num_attention_heads
+        if self.num_key_value_groups != 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # for RoPE
             key_states, value_states = past_key_value.update(
```
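The GQA branch relies on the `repeat_kv` helper whose signature appears in the hunk header further up (`def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:`). For reference, the usual transformers-style implementation of that signature is sketched below; the file's own definition is authoritative:

```python
# Reference sketch of the repeat_kv pattern used for GQA: each K/V head is repeated
# n_rep times along the head dimension so it lines up with the query heads.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```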
```diff
@@ -866,7 +905,7 @@
 class DeepseekV3FlashAttention2(DeepseekV3Attention):
     """
     DeepseekV3 flash attention module. Inherits the same Q/K/V projections from DeepseekV3Attention.
-    Only the forward pass changes to use flash_attn APIs.
+    Only the forward pass changes to use flash_attn APIs. Supports GQA implicitly through `DeepseekV3Attention`.
     """
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
```
```diff
@@ -906,7 +945,7 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
         kv = (
             self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, q_len, self.num_key_value_heads, self.qk_nope_head_dim + self.v_head_dim)  # [GQA] num_key_value_heads here
             .transpose(1, 2)
         )
         k_nope, value_states = torch.split(
```
```diff
@@ -923,10 +962,17 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, self.num_key_value_heads, q_len, self.q_head_dim)  # [GQA] num_key_value_heads here
         key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
         key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

+        value_states = value_states  # [GQA] value_states already has num_key_value_heads
+
+        # GQA: Repeat K/V states if num_key_value_heads < num_attention_heads
+        if self.num_key_value_groups != 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if self.q_head_dim != self.v_head_dim:
             # Pad if needed
             value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
```
```diff
@@ -1091,6 +1137,7 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
 ATTENTION_CLASSES = {
     "eager": DeepseekV3Attention,
     "flash_attention_2": DeepseekV3FlashAttention2,
+    # "sparse_attention": SparseDeepseekV3Attention,  # [TODO] Placeholder for Sparse Attention class - implement SparseDeepseekV3Attention
 }


```
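`SparseDeepseekV3Attention` does not exist yet; the commented-out entry is only a placeholder. As a purely hypothetical sketch, one simple sparsity pattern such a class could apply is a sliding-window causal mask on the attention scores:

```python
# Hypothetical sketch only: a sliding-window causal mask that a future
# SparseDeepseekV3Attention could apply on top of the existing attention scores.
import torch

def sliding_window_causal_mask(seq_len: int, window: int, device=None) -> torch.Tensor:
    """True where attention is allowed: causal and within `window` previous tokens."""
    idx = torch.arange(seq_len, device=device)
    rel = idx[:, None] - idx[None, :]          # query_pos - key_pos
    return (rel >= 0) & (rel < window)

mask = sliding_window_causal_mask(seq_len=8, window=3)
# attn_scores = attn_scores.masked_fill(~mask, float("-inf"))  # applied before softmax
```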
```diff
@@ -1106,7 +1153,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](  # [CONFIG] _attn_implementation to select attention type
             config=config, layer_idx=layer_idx
         )

```
```diff
@@ -1138,7 +1185,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         **kwargs
     ):
         """
-        Forward pass for one Deepseek decoder layer.
+        Forward pass for one Deepseek decoder layer.
         """
         residual = hidden_states

```
```diff
@@ -1443,6 +1490,10 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         Args:
             labels (torch.LongTensor of shape (batch_size, sequence_length), optional):
                 For computing the language modeling loss. Indices in [0, config.vocab_size] or -100.
+
+        **Reasoning Enhancement Considerations:**
+        * **Chain-of-Thought (CoT) Data:** To effectively improve reasoning, fine-tune this model on datasets that include Chain-of-Thought examples. The model architecture is capable of leveraging CoT if trained appropriately.
+        * **Prompt Engineering for CoT Inference:** During inference, use prompts that encourage the model to generate reasoning steps (e.g., "Let's think step by step...") to elicit Chain-of-Thought behavior.
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (output_hidden_states if output_hidden_states is not None
```
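Tying the `-100` label convention to the CoT-data suggestion above, a hedged sketch of how a Chain-of-Thought fine-tuning example could be packed, masking the prompt so the loss only covers the reasoning and answer. It assumes a `tokenizer` and `model` (a `DeepseekV3ForCausalLM`) are already loaded; the example text is made up:

```python
# Hedged sketch: labels for CoT-style fine-tuning, with prompt tokens set to -100 so the
# language-modeling loss is computed only on the reasoning + answer tokens.
import torch

prompt = "Q: A train travels 60 km in 1.5 hours. What is its average speed?\nA: Let's think step by step."
target = " 60 km / 1.5 h = 40 km/h. The answer is 40 km/h."

prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids   # assumes a loaded tokenizer
target_ids = tokenizer(target, add_special_tokens=False).input_ids

input_ids = torch.tensor([prompt_ids + target_ids])
labels = torch.tensor([[-100] * len(prompt_ids) + target_ids])       # -100 => ignored by the loss

outputs = model(input_ids=input_ids, labels=labels)                  # forward pass documented above
loss = outputs.loss
```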
```diff
@@ -1500,6 +1551,12 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
     ):
         """
         Prepare inputs during generation loops.
+
+        **Chain-of-Thought (CoT) Inference Hint:**
+        When using Chain-of-Thought prompting during generation, ensure your prompts are correctly formatted to encourage reasoning.
+        Consider using techniques like:
+        * A "Let's think step by step:" prefix in your prompt.
+        * Sampling strategies that encourage diverse outputs for self-consistency decoding.
         """
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
```
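A hedged sketch of those inference hints, combining a step-by-step prefix with sampled candidates for self-consistency voting. It again assumes a loaded `tokenizer` and `model`, and the final answer-extraction line is a stand-in for a real parser:

```python
# Hedged sketch: CoT prompting plus self-consistency sampling with standard generate() kwargs.
import torch
from collections import Counter

prompt = "Q: If 3 pencils cost 45 cents, how much do 7 pencils cost?\nA: Let's think step by step:"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,           # diverse reasoning paths
        temperature=0.7,
        num_return_sequences=5,   # self-consistency: sample several chains
    )

answers = [
    tokenizer.decode(seq[inputs.input_ids.shape[1]:], skip_special_tokens=True) for seq in out
]
# A real self-consistency decoder would parse the final answer from each chain and majority-vote:
final = Counter((a.strip().splitlines() or [""])[-1] for a in answers).most_common(1)[0][0]
```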
```diff
@@ -1672,4 +1729,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
```