Ayushk44 commited on
Commit
213e2b5
·
verified ·
1 Parent(s): 95b72bd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 17_99p_300t/step10000/README.md +1 -0
  2. 17_99p_300t/step10000/generation_config.json +6 -0
  3. 17_99p_300t/step10000/tokenizer.json +0 -0
  4. 17_99p_300t/step10000/tokenizer_config.json +214 -0
  5. 17_99p_300t/step100000/README.md +1 -0
  6. 17_99p_300t/step100000/config.json +30 -0
  7. 17_99p_300t/step100000/configuration_spectra2.py +174 -0
  8. 17_99p_300t/step100000/modeling_spectra2.py +849 -0
  9. 17_99p_300t/step100000/special_tokens_map.json +23 -0
  10. 17_99p_300t/step100000/tokenizer.json +0 -0
  11. 17_99p_300t/step110000/README.md +1 -0
  12. 17_99p_300t/step110000/config.json +30 -0
  13. 17_99p_300t/step110000/configuration_spectra2.py +174 -0
  14. 17_99p_300t/step110000/generation_config.json +6 -0
  15. 17_99p_300t/step110000/modeling_spectra2.py +849 -0
  16. 17_99p_300t/step110000/special_tokens_map.json +23 -0
  17. 17_99p_300t/step110000/tokenizer.json +0 -0
  18. 17_99p_300t/step110000/tokenizer_config.json +214 -0
  19. 17_99p_300t/step140000/modeling_spectra2.py +849 -0
  20. 17_99p_300t/step140000/special_tokens_map.json +23 -0
  21. 17_99p_300t/step140000/tokenizer.json +0 -0
  22. 17_99p_300t/step50000/tokenizer.json +0 -0
  23. 17_99p_300t/step60000/config.json +30 -0
  24. 17_99p_300t/step60000/generation_config.json +6 -0
  25. 17_99p_300t/step60000/special_tokens_map.json +23 -0
  26. 17_99p_300t/step60000/tokenizer.json +0 -0
  27. 17_99p_300t/step60000/tokenizer_config.json +214 -0
  28. 17_99p_300t/step70000/README.md +1 -0
  29. 17_99p_300t/step70000/config.json +30 -0
  30. 17_99p_300t/step70000/configuration_spectra2.py +174 -0
  31. 17_99p_300t/step70000/generation_config.json +6 -0
  32. 17_99p_300t/step70000/modeling_spectra2.py +849 -0
  33. 17_99p_300t/step70000/special_tokens_map.json +23 -0
  34. 17_99p_300t/step70000/tokenizer.json +0 -0
  35. 17_99p_300t/step70000/tokenizer_config.json +214 -0
  36. 17_99p_300t/step80000/config.json +30 -0
  37. 17_99p_300t/step80000/configuration_spectra2.py +174 -0
  38. 17_99p_300t/step80000/tokenizer.json +0 -0
  39. 17_99p_300t/step80000/tokenizer_config.json +214 -0
  40. 17_99p_300t/step90000/README.md +1 -0
  41. 17_99p_300t/step90000/configuration_spectra2.py +174 -0
  42. 17_99p_300t/step90000/generation_config.json +6 -0
  43. 17_99p_300t/step90000/modeling_spectra2.py +849 -0
  44. 17_99p_300t/step90000/special_tokens_map.json +23 -0
  45. 17_99p_300t/step90000/tokenizer.json +0 -0
  46. 17_99p_300t/step90000/tokenizer_config.json +214 -0
  47. data-indices/rank0.tsv.gz +3 -0
  48. data-indices/rank1.tsv.gz +3 -0
  49. data-indices/rank10.tsv.gz +3 -0
  50. data-indices/rank100.tsv.gz +3 -0
17_99p_300t/step10000/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `pip install git+https://github.com/huggingface/transformers.git@05260a1`
17_99p_300t/step10000/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50277,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.45.2"
6
+ }
17_99p_300t/step10000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step10000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
17_99p_300t/step100000/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `pip install git+https://github.com/huggingface/transformers.git@05260a1`
17_99p_300t/step100000/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Spectra2ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50277,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1280,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "spectra2",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 16,
17
+ "num_key_value_heads": 8,
18
+ "pad_token_id": 1,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float32",
23
+ "use_cache": true,
24
+ "vocab_size": 50304,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_spectra2.Spectra2Config",
27
+ "AutoModel": "modeling_spectra2.Spectra2Model",
28
+ "AutoModelForCausalLM": "modeling_spectra2.Spectra2ForCausalLM"
29
+ }
30
+ }
17_99p_300t/step100000/configuration_spectra2.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Spectra2 model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class Spectra2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Spectra2Model`]. It is used to instantiate an Spectra2
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [SpectraSuite/Spectra2-3B-base](https://huggingface.co/spectrasuite/Spectra2-3B-base).
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50304):
38
+ Vocabulary size of the Spectra2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Spectra2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*, defaults to 1):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 50279):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
77
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
78
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
79
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
80
+ these scaling strategies behave:
81
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
82
+ experimental feature, subject to breaking API changes in future versions.
83
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
84
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
85
+ attention_dropout (`float`, *optional*, defaults to 0.0):
86
+ The dropout ratio for the attention probabilities.
87
+ clip_qkv (`float`, *optional*):
88
+ If not `None`, elements of query, key and value attention states are clipped so that their
89
+ absolute value does not exceed this value.
90
+ ```python
91
+ >>> from transformers import Spectra2Model, Spectra2Config
92
+ >>> # Initializing a Spectra2 3B style configuration
93
+ >>> configuration = Spectra2Config()
94
+ >>> # Initializing a model from the Spectra2 3B style configuration
95
+ >>> model = Spectra2Model(configuration)
96
+ >>> # Accessing the model configuration
97
+ >>> configuration = model.config
98
+ ```"""
99
+
100
+ model_type = "spectra2"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50304,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ use_cache=True,
115
+ pad_token_id=1,
116
+ bos_token_id=None,
117
+ eos_token_id=50279,
118
+ tie_word_embeddings=False,
119
+ rope_theta=10000.0,
120
+ rope_scaling=None,
121
+ attention_bias=False,
122
+ attention_dropout=0.0,
123
+ clip_qkv=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.hidden_size = hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+
133
+ # for backward compatibility
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+
137
+ self.num_key_value_heads = num_key_value_heads
138
+ self.hidden_act = hidden_act
139
+ self.initializer_range = initializer_range
140
+ self.use_cache = use_cache
141
+ self.rope_theta = rope_theta
142
+ self.rope_scaling = rope_scaling
143
+ self._rope_scaling_validation()
144
+ self.attention_bias = attention_bias
145
+ self.attention_dropout = attention_dropout
146
+ self.clip_qkv = clip_qkv
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ def _rope_scaling_validation(self):
157
+ """
158
+ Validate the `rope_scaling` configuration.
159
+ """
160
+ if self.rope_scaling is None:
161
+ return
162
+
163
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
164
+ raise ValueError(
165
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
166
+ )
167
+ rope_scaling_type = self.rope_scaling.get("type", None)
168
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
169
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
170
+ raise ValueError(
171
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
172
+ )
173
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
174
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
17_99p_300t/step100000/modeling_spectra2.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/spectra2/modular_spectra2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_spectra2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_spectra2 import Spectra2Config
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "Spectra2Config"
34
+
35
+
36
+ class Spectra2RMSLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.variance_epsilon = 1e-05 # Hardcoded
43
+ self.normalized_shape = (hidden_size,)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ with torch.autocast(enabled=False, device_type=hidden_states.device.type):
47
+ og_dtype = hidden_states.dtype
48
+ hidden_states = hidden_states.to(torch.float32)
49
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
50
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
51
+ hidden_states = hidden_states.to(og_dtype)
52
+ return self.weight * hidden_states
53
+
54
+
55
+ class Spectra2MLP(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.config = config
59
+ self.hidden_size = config.hidden_size
60
+ self.intermediate_size = config.intermediate_size
61
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
62
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[config.hidden_act]
65
+
66
+ def forward(self, x):
67
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
68
+ return down_proj
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
+ """Applies Rotary Position Embedding to the query and key tensors.
80
+
81
+ Args:
82
+ q (`torch.Tensor`): The query tensor.
83
+ k (`torch.Tensor`): The key tensor.
84
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
85
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
86
+ position_ids (`torch.Tensor`, *optional*):
87
+ Deprecated and unused.
88
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
89
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
90
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
91
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
92
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
93
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
94
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
95
+ Returns:
96
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
97
+ """
98
+ cos = cos.unsqueeze(unsqueeze_dim)
99
+ sin = sin.unsqueeze(unsqueeze_dim)
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
106
+ """
107
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
108
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
109
+ """
110
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
111
+ if n_rep == 1:
112
+ return hidden_states
113
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
+
116
+
117
+ def eager_attention_forward(
118
+ module: nn.Module,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor],
123
+ scaling: float,
124
+ dropout: float = 0.0,
125
+ **kwargs,
126
+ ):
127
+ key_states = repeat_kv(key, module.num_key_value_groups)
128
+ value_states = repeat_kv(value, module.num_key_value_groups)
129
+
130
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
131
+ if attention_mask is not None:
132
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
133
+ attn_weights = attn_weights + causal_mask
134
+
135
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
136
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
137
+ attn_output = torch.matmul(attn_weights, value_states)
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+
140
+ return attn_output, attn_weights
141
+
142
+
143
+ class Spectra2Attention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Spectra2Config, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
173
+ attention_mask: Optional[torch.Tensor],
174
+ past_key_value: Optional[Cache] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ **kwargs,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ input_shape = hidden_states.shape[:-1]
179
+ hidden_shape = (*input_shape, -1, self.head_dim)
180
+
181
+ query_states = self.q_proj(hidden_states)
182
+ key_states = self.k_proj(hidden_states)
183
+ value_states = self.v_proj(hidden_states)
184
+
185
+ if self.config.clip_qkv is not None:
186
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
187
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
188
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
189
+
190
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
191
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
192
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+
197
+ if past_key_value is not None:
198
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
199
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
200
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
201
+
202
+ attention_interface: Callable = eager_attention_forward
203
+ if self.config._attn_implementation != "eager":
204
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
205
+ logger.warning_once(
206
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
207
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
208
+ )
209
+ else:
210
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
211
+
212
+ attn_output, attn_weights = attention_interface(
213
+ self,
214
+ query_states,
215
+ key_states,
216
+ value_states,
217
+ attention_mask,
218
+ dropout=0.0 if not self.training else self.attention_dropout,
219
+ scaling=self.scaling,
220
+ **kwargs,
221
+ )
222
+
223
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
224
+ attn_output = self.o_proj(attn_output)
225
+ return attn_output, attn_weights
226
+
227
+
228
+ class Spectra2DecoderLayer(nn.Module):
229
+ def __init__(self, config: Spectra2Config, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = Spectra2Attention(config=config, layer_idx=layer_idx)
233
+
234
+ self.mlp = Spectra2MLP(config)
235
+ self.input_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
236
+ self.post_attention_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Cache] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
248
+ **kwargs: Unpack[FlashAttentionKwargs],
249
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
250
+ residual = hidden_states
251
+
252
+ hidden_states = self.input_rms_layernorm(hidden_states)
253
+
254
+ # Self Attention
255
+ hidden_states, self_attn_weights = self.self_attn(
256
+ hidden_states=hidden_states,
257
+ attention_mask=attention_mask,
258
+ position_ids=position_ids,
259
+ past_key_value=past_key_value,
260
+ output_attentions=output_attentions,
261
+ use_cache=use_cache,
262
+ cache_position=cache_position,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ hidden_states = residual + hidden_states
267
+
268
+ # Fully Connected
269
+ residual = hidden_states
270
+ hidden_states = self.post_attention_rms_layernorm(hidden_states)
271
+ hidden_states = self.mlp(hidden_states)
272
+ hidden_states = residual + hidden_states
273
+
274
+ outputs = (hidden_states,)
275
+ if output_attentions:
276
+ outputs += (self_attn_weights,)
277
+
278
+ return outputs
279
+
280
+
281
+ class Spectra2RotaryEmbedding(nn.Module):
282
+ def __init__(
283
+ self,
284
+ config: Spectra2Config,
285
+ device=None,
286
+ ):
287
+ super().__init__()
288
+ self.rope_kwargs = {}
289
+ # BC: "rope_type" was originally "type"
290
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
291
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
292
+ else:
293
+ self.rope_type = "default"
294
+ self.max_seq_len_cached = config.max_position_embeddings
295
+ self.original_max_seq_len = config.max_position_embeddings
296
+
297
+ self.config = config
298
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
299
+
300
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
301
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
302
+ self.original_inv_freq = self.inv_freq
303
+
304
+ def _dynamic_frequency_update(self, position_ids, device):
305
+ """
306
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
307
+ 1 - growing beyond the cached sequence length (allow scaling)
308
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
309
+ """
310
+ seq_len = torch.max(position_ids) + 1
311
+ if seq_len > self.max_seq_len_cached: # growth
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(
313
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
314
+ )
315
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
319
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
320
+ self.max_seq_len_cached = self.original_max_seq_len
321
+
322
+ @torch.no_grad()
323
+ def forward(self, x, position_ids):
324
+ if "dynamic" in self.rope_type:
325
+ self._dynamic_frequency_update(position_ids, device=x.device)
326
+
327
+ # Core RoPE block
328
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
329
+ position_ids_expanded = position_ids[:, None, :].float()
330
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
331
+ device_type = x.device.type
332
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
333
+ with torch.autocast(device_type=device_type, enabled=False):
334
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+
339
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
340
+ cos = cos * self.attention_scaling
341
+ sin = sin * self.attention_scaling
342
+
343
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
344
+
345
+
346
+ SPECTRA2_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`Spectra2Config`]):
357
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
358
+ load the weights associated with the model, only the configuration. Check out the
359
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
360
+ """
361
+
362
+
363
+ @add_start_docstrings(
364
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
365
+ SPECTRA2_START_DOCSTRING,
366
+ )
367
+ class Spectra2PreTrainedModel(PreTrainedModel):
368
+ config_class = Spectra2Config
369
+ base_model_prefix = "model"
370
+ supports_gradient_checkpointing = True
371
+ _no_split_modules = ["Spectra2DecoderLayer"]
372
+ _skip_keys_device_placement = ["past_key_values"]
373
+ _supports_flash_attn_2 = True
374
+ _supports_sdpa = True
375
+ _supports_cache_class = True
376
+ _supports_quantized_cache = True
377
+ _supports_static_cache = True
378
+
379
+ def _init_weights(self, module):
380
+ std = self.config.initializer_range
381
+ if isinstance(module, nn.Linear):
382
+ module.weight.data.normal_(mean=0.0, std=std)
383
+ if module.bias is not None:
384
+ module.bias.data.zero_()
385
+ elif isinstance(module, nn.Embedding):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.padding_idx is not None:
388
+ module.weight.data[module.padding_idx].zero_()
389
+
390
+ SPECTRA2_INPUTS_DOCSTRING = r"""
391
+ Args:
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
394
+ it.
395
+
396
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
397
+ [`PreTrainedTokenizer.__call__`] for details.
398
+
399
+ [What are input IDs?](../glossary#input-ids)
400
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
402
+
403
+ - 1 for tokens that are **not masked**,
404
+ - 0 for tokens that are **masked**.
405
+
406
+ [What are attention masks?](../glossary#attention-mask)
407
+
408
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
409
+ [`PreTrainedTokenizer.__call__`] for details.
410
+
411
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
412
+ `past_key_values`).
413
+
414
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
415
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
416
+ information on the default strategy.
417
+
418
+ - 1 indicates the head is **not masked**,
419
+ - 0 indicates the head is **masked**.
420
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
421
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
422
+ config.n_positions - 1]`.
423
+
424
+ [What are position IDs?](../glossary#position-ids)
425
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
426
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
427
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
428
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
429
+
430
+ Two formats are allowed:
431
+ - a [`~cache_utils.Cache`] instance, see our
432
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
433
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
434
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
435
+ cache format.
436
+
437
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
438
+ legacy cache format will be returned.
439
+
440
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
441
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
442
+ of shape `(batch_size, sequence_length)`.
443
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
444
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
445
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
446
+ model's internal embedding lookup matrix.
447
+ use_cache (`bool`, *optional*):
448
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
449
+ `past_key_values`).
450
+ output_attentions (`bool`, *optional*):
451
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
452
+ tensors for more detail.
453
+ output_hidden_states (`bool`, *optional*):
454
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
455
+ more detail.
456
+ return_dict (`bool`, *optional*):
457
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
458
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
459
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
460
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
461
+ the complete sequence length.
462
+ """
463
+
464
+
465
+ @add_start_docstrings(
466
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
467
+ SPECTRA2_START_DOCSTRING,
468
+ )
469
+ class Spectra2Model(Spectra2PreTrainedModel):
470
+ """
471
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Spectra2DecoderLayer`]
472
+
473
+ Args:
474
+ config: Spectra2Config
475
+ """
476
+
477
+ def __init__(self, config: Spectra2Config):
478
+ super().__init__(config)
479
+ self.padding_idx = config.pad_token_id
480
+ self.vocab_size = config.vocab_size
481
+
482
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
483
+ self.layers = nn.ModuleList(
484
+ [Spectra2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.norm = Spectra2RMSLayerNorm(config.hidden_size)
487
+ self.rotary_emb = Spectra2RotaryEmbedding(config=config)
488
+ self.gradient_checkpointing = False
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
500
+ def forward(
501
+ self,
502
+ input_ids: torch.LongTensor = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_values: Optional[Cache] = None,
506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ return_dict: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
520
+
521
+ if (input_ids is None) ^ (inputs_embeds is not None):
522
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
523
+
524
+ if self.gradient_checkpointing and self.training and use_cache:
525
+ logger.warning_once(
526
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
527
+ )
528
+ use_cache = False
529
+
530
+ if inputs_embeds is None:
531
+ inputs_embeds = self.embed_tokens(input_ids)
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = DynamicCache()
535
+
536
+ if cache_position is None:
537
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ cache_position = torch.arange(
539
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
540
+ )
541
+
542
+ if position_ids is None:
543
+ position_ids = cache_position.unsqueeze(0)
544
+
545
+ causal_mask = self._update_causal_mask(
546
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
547
+ )
548
+
549
+ hidden_states = inputs_embeds
550
+
551
+ # create position embeddings to be shared across the decoder layers
552
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+
558
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
559
+ if output_hidden_states:
560
+ all_hidden_states += (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ layer_outputs = self._gradient_checkpointing_func(
564
+ decoder_layer.__call__,
565
+ hidden_states,
566
+ causal_mask,
567
+ position_ids,
568
+ past_key_values,
569
+ output_attentions,
570
+ use_cache,
571
+ cache_position,
572
+ position_embeddings,
573
+ )
574
+ else:
575
+ layer_outputs = decoder_layer(
576
+ hidden_states,
577
+ attention_mask=causal_mask,
578
+ position_ids=position_ids,
579
+ past_key_value=past_key_values,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ cache_position=cache_position,
583
+ position_embeddings=position_embeddings,
584
+ **flash_attn_kwargs,
585
+ )
586
+
587
+ hidden_states = layer_outputs[0]
588
+
589
+ if output_attentions:
590
+ all_self_attns += (layer_outputs[1],)
591
+
592
+ hidden_states = self.norm(hidden_states)
593
+
594
+ # add hidden states from the last decoder layer
595
+ if output_hidden_states:
596
+ all_hidden_states += (hidden_states,)
597
+
598
+ output = BaseModelOutputWithPast(
599
+ last_hidden_state=hidden_states,
600
+ past_key_values=past_key_values if use_cache else None,
601
+ hidden_states=all_hidden_states,
602
+ attentions=all_self_attns,
603
+ )
604
+ return output if return_dict else output.to_tuple()
605
+
606
+ def _update_causal_mask(
607
+ self,
608
+ attention_mask: torch.Tensor,
609
+ input_tensor: torch.Tensor,
610
+ cache_position: torch.Tensor,
611
+ past_key_values: Cache,
612
+ output_attentions: bool,
613
+ ):
614
+ if self.config._attn_implementation == "flash_attention_2":
615
+ if attention_mask is not None and 0.0 in attention_mask:
616
+ return attention_mask
617
+ return None
618
+
619
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
620
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
621
+ # to infer the attention mask.
622
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
623
+ using_static_cache = isinstance(past_key_values, StaticCache)
624
+
625
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
626
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
627
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
628
+ attention_mask,
629
+ inputs_embeds=input_tensor,
630
+ past_key_values_length=past_seen_tokens,
631
+ is_training=self.training,
632
+ ):
633
+ return None
634
+
635
+ dtype, device = input_tensor.dtype, input_tensor.device
636
+ sequence_length = input_tensor.shape[1]
637
+ if using_static_cache:
638
+ target_length = past_key_values.get_max_cache_shape()
639
+ else:
640
+ target_length = (
641
+ attention_mask.shape[-1]
642
+ if isinstance(attention_mask, torch.Tensor)
643
+ else past_seen_tokens + sequence_length + 1
644
+ )
645
+
646
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
647
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
648
+ attention_mask,
649
+ sequence_length=sequence_length,
650
+ target_length=target_length,
651
+ dtype=dtype,
652
+ device=device,
653
+ cache_position=cache_position,
654
+ batch_size=input_tensor.shape[0],
655
+ )
656
+
657
+ if (
658
+ self.config._attn_implementation == "sdpa"
659
+ and attention_mask is not None
660
+ and attention_mask.device.type == "cuda"
661
+ and not output_attentions
662
+ ):
663
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
664
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
665
+ # Details: https://github.com/pytorch/pytorch/issues/110213
666
+ min_dtype = torch.finfo(dtype).min
667
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
668
+
669
+ return causal_mask
670
+
671
+ @staticmethod
672
+ def _prepare_4d_causal_attention_mask_with_cache_position(
673
+ attention_mask: torch.Tensor,
674
+ sequence_length: int,
675
+ target_length: int,
676
+ dtype: torch.dtype,
677
+ device: torch.device,
678
+ cache_position: torch.Tensor,
679
+ batch_size: int,
680
+ **kwargs,
681
+ ):
682
+ """
683
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
684
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
685
+
686
+ Args:
687
+ attention_mask (`torch.Tensor`):
688
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
689
+ `(batch_size, 1, query_length, key_value_length)`.
690
+ sequence_length (`int`):
691
+ The sequence length being processed.
692
+ target_length (`int`):
693
+ The target length: when generating with static cache, the mask should be as long as the static cache,
694
+ to account for the 0 padding, the part of the cache that is not filled yet.
695
+ dtype (`torch.dtype`):
696
+ The dtype to use for the 4D attention mask.
697
+ device (`torch.device`):
698
+ The device to plcae the 4D attention mask on.
699
+ cache_position (`torch.Tensor`):
700
+ Indices depicting the position of the input sequence tokens in the sequence.
701
+ batch_size (`torch.Tensor`):
702
+ Batch size.
703
+ """
704
+ if attention_mask is not None and attention_mask.dim() == 4:
705
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
706
+ causal_mask = attention_mask
707
+ else:
708
+ min_dtype = torch.finfo(dtype).min
709
+ causal_mask = torch.full(
710
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
711
+ )
712
+ if sequence_length != 1:
713
+ causal_mask = torch.triu(causal_mask, diagonal=1)
714
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
715
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
716
+ if attention_mask is not None:
717
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
718
+ mask_length = attention_mask.shape[-1]
719
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
720
+ padding_mask = padding_mask == 0
721
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
722
+ padding_mask, min_dtype
723
+ )
724
+
725
+ return causal_mask
726
+
727
+
728
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
729
+
730
+
731
+ class Spectra2ForCausalLM(Spectra2PreTrainedModel, GenerationMixin):
732
+ _tied_weights_keys = ["lm_head.weight"]
733
+ _tp_plan = {"lm_head": "colwise_rep"}
734
+
735
+ def __init__(self, config):
736
+ super().__init__(config)
737
+ self.model = Spectra2Model(config)
738
+ self.vocab_size = config.vocab_size
739
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self):
745
+ return self.model.embed_tokens
746
+
747
+ def set_input_embeddings(self, value):
748
+ self.model.embed_tokens = value
749
+
750
+ def get_output_embeddings(self):
751
+ return self.lm_head
752
+
753
+ def set_output_embeddings(self, new_embeddings):
754
+ self.lm_head = new_embeddings
755
+
756
+ def set_decoder(self, decoder):
757
+ self.model = decoder
758
+
759
+ def get_decoder(self):
760
+ return self.model
761
+
762
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
763
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
764
+ def forward(
765
+ self,
766
+ input_ids: torch.LongTensor = None,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
770
+ inputs_embeds: Optional[torch.FloatTensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ use_cache: Optional[bool] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ cache_position: Optional[torch.LongTensor] = None,
777
+ num_logits_to_keep: int = 0,
778
+ **kwargs: Unpack[KwargsForCausalLM],
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ r"""
781
+ Args:
782
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
783
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
784
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
785
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
786
+
787
+ num_logits_to_keep (`int`, *optional*):
788
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
789
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
790
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
791
+
792
+ Returns:
793
+
794
+ Example:
795
+
796
+ ```python
797
+ >>> from transformers import AutoTokenizer, Spectra2ForCausalLM
798
+
799
+ >>> model = Spectra2ForCausalLM.from_pretrained("SpectraSuite/Spectra2-3B-base")
800
+ >>> tokenizer = AutoTokenizer.from_pretrained("SpectraSuite/Spectra2-3B-base")
801
+
802
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
803
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
804
+
805
+ >>> # Generate
806
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
807
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
808
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
809
+ ```"""
810
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
811
+ output_hidden_states = (
812
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
813
+ )
814
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
815
+
816
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
817
+ outputs = self.model(
818
+ input_ids=input_ids,
819
+ attention_mask=attention_mask,
820
+ position_ids=position_ids,
821
+ past_key_values=past_key_values,
822
+ inputs_embeds=inputs_embeds,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ cache_position=cache_position,
828
+ **kwargs,
829
+ )
830
+
831
+ hidden_states = outputs[0]
832
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
833
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
834
+
835
+ loss = None
836
+ if labels is not None:
837
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
838
+
839
+ if not return_dict:
840
+ output = (logits,) + outputs[1:]
841
+ return (loss,) + output if loss is not None else output
842
+
843
+ return CausalLMOutputWithPast(
844
+ loss=loss,
845
+ logits=logits,
846
+ past_key_values=outputs.past_key_values,
847
+ hidden_states=outputs.hidden_states,
848
+ attentions=outputs.attentions,
849
+ )
17_99p_300t/step100000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step100000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step110000/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `pip install git+https://github.com/huggingface/transformers.git@05260a1`
17_99p_300t/step110000/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Spectra2ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50277,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1280,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "spectra2",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 16,
17
+ "num_key_value_heads": 8,
18
+ "pad_token_id": 1,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float32",
23
+ "use_cache": true,
24
+ "vocab_size": 50304,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_spectra2.Spectra2Config",
27
+ "AutoModel": "modeling_spectra2.Spectra2Model",
28
+ "AutoModelForCausalLM": "modeling_spectra2.Spectra2ForCausalLM"
29
+ }
30
+ }
17_99p_300t/step110000/configuration_spectra2.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Spectra2 model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class Spectra2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Spectra2Model`]. It is used to instantiate an Spectra2
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [SpectraSuite/Spectra2-3B-base](https://huggingface.co/spectrasuite/Spectra2-3B-base).
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50304):
38
+ Vocabulary size of the Spectra2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Spectra2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*, defaults to 1):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 50279):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
77
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
78
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
79
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
80
+ these scaling strategies behave:
81
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
82
+ experimental feature, subject to breaking API changes in future versions.
83
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
84
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
85
+ attention_dropout (`float`, *optional*, defaults to 0.0):
86
+ The dropout ratio for the attention probabilities.
87
+ clip_qkv (`float`, *optional*):
88
+ If not `None`, elements of query, key and value attention states are clipped so that their
89
+ absolute value does not exceed this value.
90
+ ```python
91
+ >>> from transformers import Spectra2Model, Spectra2Config
92
+ >>> # Initializing a Spectra2 3B style configuration
93
+ >>> configuration = Spectra2Config()
94
+ >>> # Initializing a model from the Spectra2 3B style configuration
95
+ >>> model = Spectra2Model(configuration)
96
+ >>> # Accessing the model configuration
97
+ >>> configuration = model.config
98
+ ```"""
99
+
100
+ model_type = "spectra2"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50304,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ use_cache=True,
115
+ pad_token_id=1,
116
+ bos_token_id=None,
117
+ eos_token_id=50279,
118
+ tie_word_embeddings=False,
119
+ rope_theta=10000.0,
120
+ rope_scaling=None,
121
+ attention_bias=False,
122
+ attention_dropout=0.0,
123
+ clip_qkv=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.hidden_size = hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+
133
+ # for backward compatibility
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+
137
+ self.num_key_value_heads = num_key_value_heads
138
+ self.hidden_act = hidden_act
139
+ self.initializer_range = initializer_range
140
+ self.use_cache = use_cache
141
+ self.rope_theta = rope_theta
142
+ self.rope_scaling = rope_scaling
143
+ self._rope_scaling_validation()
144
+ self.attention_bias = attention_bias
145
+ self.attention_dropout = attention_dropout
146
+ self.clip_qkv = clip_qkv
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ def _rope_scaling_validation(self):
157
+ """
158
+ Validate the `rope_scaling` configuration.
159
+ """
160
+ if self.rope_scaling is None:
161
+ return
162
+
163
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
164
+ raise ValueError(
165
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
166
+ )
167
+ rope_scaling_type = self.rope_scaling.get("type", None)
168
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
169
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
170
+ raise ValueError(
171
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
172
+ )
173
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
174
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
17_99p_300t/step110000/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50277,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.45.2"
6
+ }
17_99p_300t/step110000/modeling_spectra2.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/spectra2/modular_spectra2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_spectra2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_spectra2 import Spectra2Config
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "Spectra2Config"
34
+
35
+
36
+ class Spectra2RMSLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.variance_epsilon = 1e-05 # Hardcoded
43
+ self.normalized_shape = (hidden_size,)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ with torch.autocast(enabled=False, device_type=hidden_states.device.type):
47
+ og_dtype = hidden_states.dtype
48
+ hidden_states = hidden_states.to(torch.float32)
49
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
50
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
51
+ hidden_states = hidden_states.to(og_dtype)
52
+ return self.weight * hidden_states
53
+
54
+
55
+ class Spectra2MLP(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.config = config
59
+ self.hidden_size = config.hidden_size
60
+ self.intermediate_size = config.intermediate_size
61
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
62
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[config.hidden_act]
65
+
66
+ def forward(self, x):
67
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
68
+ return down_proj
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
+ """Applies Rotary Position Embedding to the query and key tensors.
80
+
81
+ Args:
82
+ q (`torch.Tensor`): The query tensor.
83
+ k (`torch.Tensor`): The key tensor.
84
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
85
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
86
+ position_ids (`torch.Tensor`, *optional*):
87
+ Deprecated and unused.
88
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
89
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
90
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
91
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
92
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
93
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
94
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
95
+ Returns:
96
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
97
+ """
98
+ cos = cos.unsqueeze(unsqueeze_dim)
99
+ sin = sin.unsqueeze(unsqueeze_dim)
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
106
+ """
107
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
108
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
109
+ """
110
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
111
+ if n_rep == 1:
112
+ return hidden_states
113
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
+
116
+
117
+ def eager_attention_forward(
118
+ module: nn.Module,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor],
123
+ scaling: float,
124
+ dropout: float = 0.0,
125
+ **kwargs,
126
+ ):
127
+ key_states = repeat_kv(key, module.num_key_value_groups)
128
+ value_states = repeat_kv(value, module.num_key_value_groups)
129
+
130
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
131
+ if attention_mask is not None:
132
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
133
+ attn_weights = attn_weights + causal_mask
134
+
135
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
136
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
137
+ attn_output = torch.matmul(attn_weights, value_states)
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+
140
+ return attn_output, attn_weights
141
+
142
+
143
+ class Spectra2Attention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Spectra2Config, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
173
+ attention_mask: Optional[torch.Tensor],
174
+ past_key_value: Optional[Cache] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ **kwargs,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ input_shape = hidden_states.shape[:-1]
179
+ hidden_shape = (*input_shape, -1, self.head_dim)
180
+
181
+ query_states = self.q_proj(hidden_states)
182
+ key_states = self.k_proj(hidden_states)
183
+ value_states = self.v_proj(hidden_states)
184
+
185
+ if self.config.clip_qkv is not None:
186
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
187
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
188
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
189
+
190
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
191
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
192
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+
197
+ if past_key_value is not None:
198
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
199
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
200
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
201
+
202
+ attention_interface: Callable = eager_attention_forward
203
+ if self.config._attn_implementation != "eager":
204
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
205
+ logger.warning_once(
206
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
207
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
208
+ )
209
+ else:
210
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
211
+
212
+ attn_output, attn_weights = attention_interface(
213
+ self,
214
+ query_states,
215
+ key_states,
216
+ value_states,
217
+ attention_mask,
218
+ dropout=0.0 if not self.training else self.attention_dropout,
219
+ scaling=self.scaling,
220
+ **kwargs,
221
+ )
222
+
223
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
224
+ attn_output = self.o_proj(attn_output)
225
+ return attn_output, attn_weights
226
+
227
+
228
+ class Spectra2DecoderLayer(nn.Module):
229
+ def __init__(self, config: Spectra2Config, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = Spectra2Attention(config=config, layer_idx=layer_idx)
233
+
234
+ self.mlp = Spectra2MLP(config)
235
+ self.input_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
236
+ self.post_attention_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Cache] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
248
+ **kwargs: Unpack[FlashAttentionKwargs],
249
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
250
+ residual = hidden_states
251
+
252
+ hidden_states = self.input_rms_layernorm(hidden_states)
253
+
254
+ # Self Attention
255
+ hidden_states, self_attn_weights = self.self_attn(
256
+ hidden_states=hidden_states,
257
+ attention_mask=attention_mask,
258
+ position_ids=position_ids,
259
+ past_key_value=past_key_value,
260
+ output_attentions=output_attentions,
261
+ use_cache=use_cache,
262
+ cache_position=cache_position,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ hidden_states = residual + hidden_states
267
+
268
+ # Fully Connected
269
+ residual = hidden_states
270
+ hidden_states = self.post_attention_rms_layernorm(hidden_states)
271
+ hidden_states = self.mlp(hidden_states)
272
+ hidden_states = residual + hidden_states
273
+
274
+ outputs = (hidden_states,)
275
+ if output_attentions:
276
+ outputs += (self_attn_weights,)
277
+
278
+ return outputs
279
+
280
+
281
+ class Spectra2RotaryEmbedding(nn.Module):
282
+ def __init__(
283
+ self,
284
+ config: Spectra2Config,
285
+ device=None,
286
+ ):
287
+ super().__init__()
288
+ self.rope_kwargs = {}
289
+ # BC: "rope_type" was originally "type"
290
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
291
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
292
+ else:
293
+ self.rope_type = "default"
294
+ self.max_seq_len_cached = config.max_position_embeddings
295
+ self.original_max_seq_len = config.max_position_embeddings
296
+
297
+ self.config = config
298
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
299
+
300
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
301
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
302
+ self.original_inv_freq = self.inv_freq
303
+
304
+ def _dynamic_frequency_update(self, position_ids, device):
305
+ """
306
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
307
+ 1 - growing beyond the cached sequence length (allow scaling)
308
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
309
+ """
310
+ seq_len = torch.max(position_ids) + 1
311
+ if seq_len > self.max_seq_len_cached: # growth
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(
313
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
314
+ )
315
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
319
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
320
+ self.max_seq_len_cached = self.original_max_seq_len
321
+
322
+ @torch.no_grad()
323
+ def forward(self, x, position_ids):
324
+ if "dynamic" in self.rope_type:
325
+ self._dynamic_frequency_update(position_ids, device=x.device)
326
+
327
+ # Core RoPE block
328
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
329
+ position_ids_expanded = position_ids[:, None, :].float()
330
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
331
+ device_type = x.device.type
332
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
333
+ with torch.autocast(device_type=device_type, enabled=False):
334
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+
339
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
340
+ cos = cos * self.attention_scaling
341
+ sin = sin * self.attention_scaling
342
+
343
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
344
+
345
+
346
+ SPECTRA2_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`Spectra2Config`]):
357
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
358
+ load the weights associated with the model, only the configuration. Check out the
359
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
360
+ """
361
+
362
+
363
+ @add_start_docstrings(
364
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
365
+ SPECTRA2_START_DOCSTRING,
366
+ )
367
+ class Spectra2PreTrainedModel(PreTrainedModel):
368
+ config_class = Spectra2Config
369
+ base_model_prefix = "model"
370
+ supports_gradient_checkpointing = True
371
+ _no_split_modules = ["Spectra2DecoderLayer"]
372
+ _skip_keys_device_placement = ["past_key_values"]
373
+ _supports_flash_attn_2 = True
374
+ _supports_sdpa = True
375
+ _supports_cache_class = True
376
+ _supports_quantized_cache = True
377
+ _supports_static_cache = True
378
+
379
+ def _init_weights(self, module):
380
+ std = self.config.initializer_range
381
+ if isinstance(module, nn.Linear):
382
+ module.weight.data.normal_(mean=0.0, std=std)
383
+ if module.bias is not None:
384
+ module.bias.data.zero_()
385
+ elif isinstance(module, nn.Embedding):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.padding_idx is not None:
388
+ module.weight.data[module.padding_idx].zero_()
389
+
390
+ SPECTRA2_INPUTS_DOCSTRING = r"""
391
+ Args:
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
394
+ it.
395
+
396
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
397
+ [`PreTrainedTokenizer.__call__`] for details.
398
+
399
+ [What are input IDs?](../glossary#input-ids)
400
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
402
+
403
+ - 1 for tokens that are **not masked**,
404
+ - 0 for tokens that are **masked**.
405
+
406
+ [What are attention masks?](../glossary#attention-mask)
407
+
408
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
409
+ [`PreTrainedTokenizer.__call__`] for details.
410
+
411
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
412
+ `past_key_values`).
413
+
414
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
415
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
416
+ information on the default strategy.
417
+
418
+ - 1 indicates the head is **not masked**,
419
+ - 0 indicates the head is **masked**.
420
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
421
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
422
+ config.n_positions - 1]`.
423
+
424
+ [What are position IDs?](../glossary#position-ids)
425
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
426
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
427
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
428
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
429
+
430
+ Two formats are allowed:
431
+ - a [`~cache_utils.Cache`] instance, see our
432
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
433
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
434
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
435
+ cache format.
436
+
437
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
438
+ legacy cache format will be returned.
439
+
440
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
441
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
442
+ of shape `(batch_size, sequence_length)`.
443
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
444
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
445
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
446
+ model's internal embedding lookup matrix.
447
+ use_cache (`bool`, *optional*):
448
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
449
+ `past_key_values`).
450
+ output_attentions (`bool`, *optional*):
451
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
452
+ tensors for more detail.
453
+ output_hidden_states (`bool`, *optional*):
454
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
455
+ more detail.
456
+ return_dict (`bool`, *optional*):
457
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
458
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
459
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
460
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
461
+ the complete sequence length.
462
+ """
463
+
464
+
465
+ @add_start_docstrings(
466
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
467
+ SPECTRA2_START_DOCSTRING,
468
+ )
469
+ class Spectra2Model(Spectra2PreTrainedModel):
470
+ """
471
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Spectra2DecoderLayer`]
472
+
473
+ Args:
474
+ config: Spectra2Config
475
+ """
476
+
477
+ def __init__(self, config: Spectra2Config):
478
+ super().__init__(config)
479
+ self.padding_idx = config.pad_token_id
480
+ self.vocab_size = config.vocab_size
481
+
482
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
483
+ self.layers = nn.ModuleList(
484
+ [Spectra2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.norm = Spectra2RMSLayerNorm(config.hidden_size)
487
+ self.rotary_emb = Spectra2RotaryEmbedding(config=config)
488
+ self.gradient_checkpointing = False
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
500
+ def forward(
501
+ self,
502
+ input_ids: torch.LongTensor = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_values: Optional[Cache] = None,
506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ return_dict: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
520
+
521
+ if (input_ids is None) ^ (inputs_embeds is not None):
522
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
523
+
524
+ if self.gradient_checkpointing and self.training and use_cache:
525
+ logger.warning_once(
526
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
527
+ )
528
+ use_cache = False
529
+
530
+ if inputs_embeds is None:
531
+ inputs_embeds = self.embed_tokens(input_ids)
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = DynamicCache()
535
+
536
+ if cache_position is None:
537
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ cache_position = torch.arange(
539
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
540
+ )
541
+
542
+ if position_ids is None:
543
+ position_ids = cache_position.unsqueeze(0)
544
+
545
+ causal_mask = self._update_causal_mask(
546
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
547
+ )
548
+
549
+ hidden_states = inputs_embeds
550
+
551
+ # create position embeddings to be shared across the decoder layers
552
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+
558
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
559
+ if output_hidden_states:
560
+ all_hidden_states += (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ layer_outputs = self._gradient_checkpointing_func(
564
+ decoder_layer.__call__,
565
+ hidden_states,
566
+ causal_mask,
567
+ position_ids,
568
+ past_key_values,
569
+ output_attentions,
570
+ use_cache,
571
+ cache_position,
572
+ position_embeddings,
573
+ )
574
+ else:
575
+ layer_outputs = decoder_layer(
576
+ hidden_states,
577
+ attention_mask=causal_mask,
578
+ position_ids=position_ids,
579
+ past_key_value=past_key_values,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ cache_position=cache_position,
583
+ position_embeddings=position_embeddings,
584
+ **flash_attn_kwargs,
585
+ )
586
+
587
+ hidden_states = layer_outputs[0]
588
+
589
+ if output_attentions:
590
+ all_self_attns += (layer_outputs[1],)
591
+
592
+ hidden_states = self.norm(hidden_states)
593
+
594
+ # add hidden states from the last decoder layer
595
+ if output_hidden_states:
596
+ all_hidden_states += (hidden_states,)
597
+
598
+ output = BaseModelOutputWithPast(
599
+ last_hidden_state=hidden_states,
600
+ past_key_values=past_key_values if use_cache else None,
601
+ hidden_states=all_hidden_states,
602
+ attentions=all_self_attns,
603
+ )
604
+ return output if return_dict else output.to_tuple()
605
+
606
+ def _update_causal_mask(
607
+ self,
608
+ attention_mask: torch.Tensor,
609
+ input_tensor: torch.Tensor,
610
+ cache_position: torch.Tensor,
611
+ past_key_values: Cache,
612
+ output_attentions: bool,
613
+ ):
614
+ if self.config._attn_implementation == "flash_attention_2":
615
+ if attention_mask is not None and 0.0 in attention_mask:
616
+ return attention_mask
617
+ return None
618
+
619
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
620
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
621
+ # to infer the attention mask.
622
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
623
+ using_static_cache = isinstance(past_key_values, StaticCache)
624
+
625
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
626
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
627
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
628
+ attention_mask,
629
+ inputs_embeds=input_tensor,
630
+ past_key_values_length=past_seen_tokens,
631
+ is_training=self.training,
632
+ ):
633
+ return None
634
+
635
+ dtype, device = input_tensor.dtype, input_tensor.device
636
+ sequence_length = input_tensor.shape[1]
637
+ if using_static_cache:
638
+ target_length = past_key_values.get_max_cache_shape()
639
+ else:
640
+ target_length = (
641
+ attention_mask.shape[-1]
642
+ if isinstance(attention_mask, torch.Tensor)
643
+ else past_seen_tokens + sequence_length + 1
644
+ )
645
+
646
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
647
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
648
+ attention_mask,
649
+ sequence_length=sequence_length,
650
+ target_length=target_length,
651
+ dtype=dtype,
652
+ device=device,
653
+ cache_position=cache_position,
654
+ batch_size=input_tensor.shape[0],
655
+ )
656
+
657
+ if (
658
+ self.config._attn_implementation == "sdpa"
659
+ and attention_mask is not None
660
+ and attention_mask.device.type == "cuda"
661
+ and not output_attentions
662
+ ):
663
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
664
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
665
+ # Details: https://github.com/pytorch/pytorch/issues/110213
666
+ min_dtype = torch.finfo(dtype).min
667
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
668
+
669
+ return causal_mask
670
+
671
+ @staticmethod
672
+ def _prepare_4d_causal_attention_mask_with_cache_position(
673
+ attention_mask: torch.Tensor,
674
+ sequence_length: int,
675
+ target_length: int,
676
+ dtype: torch.dtype,
677
+ device: torch.device,
678
+ cache_position: torch.Tensor,
679
+ batch_size: int,
680
+ **kwargs,
681
+ ):
682
+ """
683
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
684
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
685
+
686
+ Args:
687
+ attention_mask (`torch.Tensor`):
688
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
689
+ `(batch_size, 1, query_length, key_value_length)`.
690
+ sequence_length (`int`):
691
+ The sequence length being processed.
692
+ target_length (`int`):
693
+ The target length: when generating with static cache, the mask should be as long as the static cache,
694
+ to account for the 0 padding, the part of the cache that is not filled yet.
695
+ dtype (`torch.dtype`):
696
+ The dtype to use for the 4D attention mask.
697
+ device (`torch.device`):
698
+ The device to plcae the 4D attention mask on.
699
+ cache_position (`torch.Tensor`):
700
+ Indices depicting the position of the input sequence tokens in the sequence.
701
+ batch_size (`torch.Tensor`):
702
+ Batch size.
703
+ """
704
+ if attention_mask is not None and attention_mask.dim() == 4:
705
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
706
+ causal_mask = attention_mask
707
+ else:
708
+ min_dtype = torch.finfo(dtype).min
709
+ causal_mask = torch.full(
710
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
711
+ )
712
+ if sequence_length != 1:
713
+ causal_mask = torch.triu(causal_mask, diagonal=1)
714
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
715
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
716
+ if attention_mask is not None:
717
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
718
+ mask_length = attention_mask.shape[-1]
719
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
720
+ padding_mask = padding_mask == 0
721
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
722
+ padding_mask, min_dtype
723
+ )
724
+
725
+ return causal_mask
726
+
727
+
728
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
729
+
730
+
731
+ class Spectra2ForCausalLM(Spectra2PreTrainedModel, GenerationMixin):
732
+ _tied_weights_keys = ["lm_head.weight"]
733
+ _tp_plan = {"lm_head": "colwise_rep"}
734
+
735
+ def __init__(self, config):
736
+ super().__init__(config)
737
+ self.model = Spectra2Model(config)
738
+ self.vocab_size = config.vocab_size
739
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self):
745
+ return self.model.embed_tokens
746
+
747
+ def set_input_embeddings(self, value):
748
+ self.model.embed_tokens = value
749
+
750
+ def get_output_embeddings(self):
751
+ return self.lm_head
752
+
753
+ def set_output_embeddings(self, new_embeddings):
754
+ self.lm_head = new_embeddings
755
+
756
+ def set_decoder(self, decoder):
757
+ self.model = decoder
758
+
759
+ def get_decoder(self):
760
+ return self.model
761
+
762
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
763
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
764
+ def forward(
765
+ self,
766
+ input_ids: torch.LongTensor = None,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
770
+ inputs_embeds: Optional[torch.FloatTensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ use_cache: Optional[bool] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ cache_position: Optional[torch.LongTensor] = None,
777
+ num_logits_to_keep: int = 0,
778
+ **kwargs: Unpack[KwargsForCausalLM],
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ r"""
781
+ Args:
782
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
783
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
784
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
785
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
786
+
787
+ num_logits_to_keep (`int`, *optional*):
788
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
789
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
790
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
791
+
792
+ Returns:
793
+
794
+ Example:
795
+
796
+ ```python
797
+ >>> from transformers import AutoTokenizer, Spectra2ForCausalLM
798
+
799
+ >>> model = Spectra2ForCausalLM.from_pretrained("SpectraSuite/Spectra2-3B-base")
800
+ >>> tokenizer = AutoTokenizer.from_pretrained("SpectraSuite/Spectra2-3B-base")
801
+
802
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
803
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
804
+
805
+ >>> # Generate
806
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
807
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
808
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
809
+ ```"""
810
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
811
+ output_hidden_states = (
812
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
813
+ )
814
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
815
+
816
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
817
+ outputs = self.model(
818
+ input_ids=input_ids,
819
+ attention_mask=attention_mask,
820
+ position_ids=position_ids,
821
+ past_key_values=past_key_values,
822
+ inputs_embeds=inputs_embeds,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ cache_position=cache_position,
828
+ **kwargs,
829
+ )
830
+
831
+ hidden_states = outputs[0]
832
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
833
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
834
+
835
+ loss = None
836
+ if labels is not None:
837
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
838
+
839
+ if not return_dict:
840
+ output = (logits,) + outputs[1:]
841
+ return (loss,) + output if loss is not None else output
842
+
843
+ return CausalLMOutputWithPast(
844
+ loss=loss,
845
+ logits=logits,
846
+ past_key_values=outputs.past_key_values,
847
+ hidden_states=outputs.hidden_states,
848
+ attentions=outputs.attentions,
849
+ )
17_99p_300t/step110000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step110000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step110000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
17_99p_300t/step140000/modeling_spectra2.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/spectra2/modular_spectra2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_spectra2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_spectra2 import Spectra2Config
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "Spectra2Config"
34
+
35
+
36
+ class Spectra2RMSLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.variance_epsilon = 1e-05 # Hardcoded
43
+ self.normalized_shape = (hidden_size,)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ with torch.autocast(enabled=False, device_type=hidden_states.device.type):
47
+ og_dtype = hidden_states.dtype
48
+ hidden_states = hidden_states.to(torch.float32)
49
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
50
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
51
+ hidden_states = hidden_states.to(og_dtype)
52
+ return self.weight * hidden_states
53
+
54
+
55
+ class Spectra2MLP(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.config = config
59
+ self.hidden_size = config.hidden_size
60
+ self.intermediate_size = config.intermediate_size
61
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
62
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[config.hidden_act]
65
+
66
+ def forward(self, x):
67
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
68
+ return down_proj
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
+ """Applies Rotary Position Embedding to the query and key tensors.
80
+
81
+ Args:
82
+ q (`torch.Tensor`): The query tensor.
83
+ k (`torch.Tensor`): The key tensor.
84
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
85
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
86
+ position_ids (`torch.Tensor`, *optional*):
87
+ Deprecated and unused.
88
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
89
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
90
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
91
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
92
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
93
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
94
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
95
+ Returns:
96
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
97
+ """
98
+ cos = cos.unsqueeze(unsqueeze_dim)
99
+ sin = sin.unsqueeze(unsqueeze_dim)
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
106
+ """
107
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
108
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
109
+ """
110
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
111
+ if n_rep == 1:
112
+ return hidden_states
113
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
+
116
+
117
+ def eager_attention_forward(
118
+ module: nn.Module,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor],
123
+ scaling: float,
124
+ dropout: float = 0.0,
125
+ **kwargs,
126
+ ):
127
+ key_states = repeat_kv(key, module.num_key_value_groups)
128
+ value_states = repeat_kv(value, module.num_key_value_groups)
129
+
130
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
131
+ if attention_mask is not None:
132
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
133
+ attn_weights = attn_weights + causal_mask
134
+
135
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
136
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
137
+ attn_output = torch.matmul(attn_weights, value_states)
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+
140
+ return attn_output, attn_weights
141
+
142
+
143
+ class Spectra2Attention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Spectra2Config, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
173
+ attention_mask: Optional[torch.Tensor],
174
+ past_key_value: Optional[Cache] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ **kwargs,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ input_shape = hidden_states.shape[:-1]
179
+ hidden_shape = (*input_shape, -1, self.head_dim)
180
+
181
+ query_states = self.q_proj(hidden_states)
182
+ key_states = self.k_proj(hidden_states)
183
+ value_states = self.v_proj(hidden_states)
184
+
185
+ if self.config.clip_qkv is not None:
186
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
187
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
188
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
189
+
190
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
191
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
192
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+
197
+ if past_key_value is not None:
198
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
199
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
200
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
201
+
202
+ attention_interface: Callable = eager_attention_forward
203
+ if self.config._attn_implementation != "eager":
204
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
205
+ logger.warning_once(
206
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
207
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
208
+ )
209
+ else:
210
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
211
+
212
+ attn_output, attn_weights = attention_interface(
213
+ self,
214
+ query_states,
215
+ key_states,
216
+ value_states,
217
+ attention_mask,
218
+ dropout=0.0 if not self.training else self.attention_dropout,
219
+ scaling=self.scaling,
220
+ **kwargs,
221
+ )
222
+
223
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
224
+ attn_output = self.o_proj(attn_output)
225
+ return attn_output, attn_weights
226
+
227
+
228
+ class Spectra2DecoderLayer(nn.Module):
229
+ def __init__(self, config: Spectra2Config, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = Spectra2Attention(config=config, layer_idx=layer_idx)
233
+
234
+ self.mlp = Spectra2MLP(config)
235
+ self.input_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
236
+ self.post_attention_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Cache] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
248
+ **kwargs: Unpack[FlashAttentionKwargs],
249
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
250
+ residual = hidden_states
251
+
252
+ hidden_states = self.input_rms_layernorm(hidden_states)
253
+
254
+ # Self Attention
255
+ hidden_states, self_attn_weights = self.self_attn(
256
+ hidden_states=hidden_states,
257
+ attention_mask=attention_mask,
258
+ position_ids=position_ids,
259
+ past_key_value=past_key_value,
260
+ output_attentions=output_attentions,
261
+ use_cache=use_cache,
262
+ cache_position=cache_position,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ hidden_states = residual + hidden_states
267
+
268
+ # Fully Connected
269
+ residual = hidden_states
270
+ hidden_states = self.post_attention_rms_layernorm(hidden_states)
271
+ hidden_states = self.mlp(hidden_states)
272
+ hidden_states = residual + hidden_states
273
+
274
+ outputs = (hidden_states,)
275
+ if output_attentions:
276
+ outputs += (self_attn_weights,)
277
+
278
+ return outputs
279
+
280
+
281
+ class Spectra2RotaryEmbedding(nn.Module):
282
+ def __init__(
283
+ self,
284
+ config: Spectra2Config,
285
+ device=None,
286
+ ):
287
+ super().__init__()
288
+ self.rope_kwargs = {}
289
+ # BC: "rope_type" was originally "type"
290
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
291
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
292
+ else:
293
+ self.rope_type = "default"
294
+ self.max_seq_len_cached = config.max_position_embeddings
295
+ self.original_max_seq_len = config.max_position_embeddings
296
+
297
+ self.config = config
298
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
299
+
300
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
301
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
302
+ self.original_inv_freq = self.inv_freq
303
+
304
+ def _dynamic_frequency_update(self, position_ids, device):
305
+ """
306
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
307
+ 1 - growing beyond the cached sequence length (allow scaling)
308
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
309
+ """
310
+ seq_len = torch.max(position_ids) + 1
311
+ if seq_len > self.max_seq_len_cached: # growth
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(
313
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
314
+ )
315
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
319
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
320
+ self.max_seq_len_cached = self.original_max_seq_len
321
+
322
+ @torch.no_grad()
323
+ def forward(self, x, position_ids):
324
+ if "dynamic" in self.rope_type:
325
+ self._dynamic_frequency_update(position_ids, device=x.device)
326
+
327
+ # Core RoPE block
328
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
329
+ position_ids_expanded = position_ids[:, None, :].float()
330
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
331
+ device_type = x.device.type
332
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
333
+ with torch.autocast(device_type=device_type, enabled=False):
334
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+
339
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
340
+ cos = cos * self.attention_scaling
341
+ sin = sin * self.attention_scaling
342
+
343
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
344
+
345
+
346
+ SPECTRA2_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`Spectra2Config`]):
357
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
358
+ load the weights associated with the model, only the configuration. Check out the
359
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
360
+ """
361
+
362
+
363
+ @add_start_docstrings(
364
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
365
+ SPECTRA2_START_DOCSTRING,
366
+ )
367
+ class Spectra2PreTrainedModel(PreTrainedModel):
368
+ config_class = Spectra2Config
369
+ base_model_prefix = "model"
370
+ supports_gradient_checkpointing = True
371
+ _no_split_modules = ["Spectra2DecoderLayer"]
372
+ _skip_keys_device_placement = ["past_key_values"]
373
+ _supports_flash_attn_2 = True
374
+ _supports_sdpa = True
375
+ _supports_cache_class = True
376
+ _supports_quantized_cache = True
377
+ _supports_static_cache = True
378
+
379
+ def _init_weights(self, module):
380
+ std = self.config.initializer_range
381
+ if isinstance(module, nn.Linear):
382
+ module.weight.data.normal_(mean=0.0, std=std)
383
+ if module.bias is not None:
384
+ module.bias.data.zero_()
385
+ elif isinstance(module, nn.Embedding):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.padding_idx is not None:
388
+ module.weight.data[module.padding_idx].zero_()
389
+
390
+ SPECTRA2_INPUTS_DOCSTRING = r"""
391
+ Args:
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
394
+ it.
395
+
396
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
397
+ [`PreTrainedTokenizer.__call__`] for details.
398
+
399
+ [What are input IDs?](../glossary#input-ids)
400
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
402
+
403
+ - 1 for tokens that are **not masked**,
404
+ - 0 for tokens that are **masked**.
405
+
406
+ [What are attention masks?](../glossary#attention-mask)
407
+
408
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
409
+ [`PreTrainedTokenizer.__call__`] for details.
410
+
411
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
412
+ `past_key_values`).
413
+
414
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
415
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
416
+ information on the default strategy.
417
+
418
+ - 1 indicates the head is **not masked**,
419
+ - 0 indicates the head is **masked**.
420
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
421
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
422
+ config.n_positions - 1]`.
423
+
424
+ [What are position IDs?](../glossary#position-ids)
425
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
426
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
427
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
428
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
429
+
430
+ Two formats are allowed:
431
+ - a [`~cache_utils.Cache`] instance, see our
432
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
433
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
434
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
435
+ cache format.
436
+
437
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
438
+ legacy cache format will be returned.
439
+
440
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
441
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
442
+ of shape `(batch_size, sequence_length)`.
443
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
444
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
445
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
446
+ model's internal embedding lookup matrix.
447
+ use_cache (`bool`, *optional*):
448
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
449
+ `past_key_values`).
450
+ output_attentions (`bool`, *optional*):
451
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
452
+ tensors for more detail.
453
+ output_hidden_states (`bool`, *optional*):
454
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
455
+ more detail.
456
+ return_dict (`bool`, *optional*):
457
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
458
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
459
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
460
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
461
+ the complete sequence length.
462
+ """
463
+
464
+
465
+ @add_start_docstrings(
466
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
467
+ SPECTRA2_START_DOCSTRING,
468
+ )
469
+ class Spectra2Model(Spectra2PreTrainedModel):
470
+ """
471
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Spectra2DecoderLayer`]
472
+
473
+ Args:
474
+ config: Spectra2Config
475
+ """
476
+
477
+ def __init__(self, config: Spectra2Config):
478
+ super().__init__(config)
479
+ self.padding_idx = config.pad_token_id
480
+ self.vocab_size = config.vocab_size
481
+
482
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
483
+ self.layers = nn.ModuleList(
484
+ [Spectra2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.norm = Spectra2RMSLayerNorm(config.hidden_size)
487
+ self.rotary_emb = Spectra2RotaryEmbedding(config=config)
488
+ self.gradient_checkpointing = False
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
500
+ def forward(
501
+ self,
502
+ input_ids: torch.LongTensor = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_values: Optional[Cache] = None,
506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ return_dict: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
520
+
521
+ if (input_ids is None) ^ (inputs_embeds is not None):
522
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
523
+
524
+ if self.gradient_checkpointing and self.training and use_cache:
525
+ logger.warning_once(
526
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
527
+ )
528
+ use_cache = False
529
+
530
+ if inputs_embeds is None:
531
+ inputs_embeds = self.embed_tokens(input_ids)
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = DynamicCache()
535
+
536
+ if cache_position is None:
537
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ cache_position = torch.arange(
539
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
540
+ )
541
+
542
+ if position_ids is None:
543
+ position_ids = cache_position.unsqueeze(0)
544
+
545
+ causal_mask = self._update_causal_mask(
546
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
547
+ )
548
+
549
+ hidden_states = inputs_embeds
550
+
551
+ # create position embeddings to be shared across the decoder layers
552
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+
558
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
559
+ if output_hidden_states:
560
+ all_hidden_states += (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ layer_outputs = self._gradient_checkpointing_func(
564
+ decoder_layer.__call__,
565
+ hidden_states,
566
+ causal_mask,
567
+ position_ids,
568
+ past_key_values,
569
+ output_attentions,
570
+ use_cache,
571
+ cache_position,
572
+ position_embeddings,
573
+ )
574
+ else:
575
+ layer_outputs = decoder_layer(
576
+ hidden_states,
577
+ attention_mask=causal_mask,
578
+ position_ids=position_ids,
579
+ past_key_value=past_key_values,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ cache_position=cache_position,
583
+ position_embeddings=position_embeddings,
584
+ **flash_attn_kwargs,
585
+ )
586
+
587
+ hidden_states = layer_outputs[0]
588
+
589
+ if output_attentions:
590
+ all_self_attns += (layer_outputs[1],)
591
+
592
+ hidden_states = self.norm(hidden_states)
593
+
594
+ # add hidden states from the last decoder layer
595
+ if output_hidden_states:
596
+ all_hidden_states += (hidden_states,)
597
+
598
+ output = BaseModelOutputWithPast(
599
+ last_hidden_state=hidden_states,
600
+ past_key_values=past_key_values if use_cache else None,
601
+ hidden_states=all_hidden_states,
602
+ attentions=all_self_attns,
603
+ )
604
+ return output if return_dict else output.to_tuple()
605
+
606
+ def _update_causal_mask(
607
+ self,
608
+ attention_mask: torch.Tensor,
609
+ input_tensor: torch.Tensor,
610
+ cache_position: torch.Tensor,
611
+ past_key_values: Cache,
612
+ output_attentions: bool,
613
+ ):
614
+ if self.config._attn_implementation == "flash_attention_2":
615
+ if attention_mask is not None and 0.0 in attention_mask:
616
+ return attention_mask
617
+ return None
618
+
619
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
620
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
621
+ # to infer the attention mask.
622
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
623
+ using_static_cache = isinstance(past_key_values, StaticCache)
624
+
625
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
626
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
627
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
628
+ attention_mask,
629
+ inputs_embeds=input_tensor,
630
+ past_key_values_length=past_seen_tokens,
631
+ is_training=self.training,
632
+ ):
633
+ return None
634
+
635
+ dtype, device = input_tensor.dtype, input_tensor.device
636
+ sequence_length = input_tensor.shape[1]
637
+ if using_static_cache:
638
+ target_length = past_key_values.get_max_cache_shape()
639
+ else:
640
+ target_length = (
641
+ attention_mask.shape[-1]
642
+ if isinstance(attention_mask, torch.Tensor)
643
+ else past_seen_tokens + sequence_length + 1
644
+ )
645
+
646
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
647
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
648
+ attention_mask,
649
+ sequence_length=sequence_length,
650
+ target_length=target_length,
651
+ dtype=dtype,
652
+ device=device,
653
+ cache_position=cache_position,
654
+ batch_size=input_tensor.shape[0],
655
+ )
656
+
657
+ if (
658
+ self.config._attn_implementation == "sdpa"
659
+ and attention_mask is not None
660
+ and attention_mask.device.type == "cuda"
661
+ and not output_attentions
662
+ ):
663
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
664
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
665
+ # Details: https://github.com/pytorch/pytorch/issues/110213
666
+ min_dtype = torch.finfo(dtype).min
667
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
668
+
669
+ return causal_mask
670
+
671
+ @staticmethod
672
+ def _prepare_4d_causal_attention_mask_with_cache_position(
673
+ attention_mask: torch.Tensor,
674
+ sequence_length: int,
675
+ target_length: int,
676
+ dtype: torch.dtype,
677
+ device: torch.device,
678
+ cache_position: torch.Tensor,
679
+ batch_size: int,
680
+ **kwargs,
681
+ ):
682
+ """
683
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
684
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
685
+
686
+ Args:
687
+ attention_mask (`torch.Tensor`):
688
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
689
+ `(batch_size, 1, query_length, key_value_length)`.
690
+ sequence_length (`int`):
691
+ The sequence length being processed.
692
+ target_length (`int`):
693
+ The target length: when generating with static cache, the mask should be as long as the static cache,
694
+ to account for the 0 padding, the part of the cache that is not filled yet.
695
+ dtype (`torch.dtype`):
696
+ The dtype to use for the 4D attention mask.
697
+ device (`torch.device`):
698
+ The device to plcae the 4D attention mask on.
699
+ cache_position (`torch.Tensor`):
700
+ Indices depicting the position of the input sequence tokens in the sequence.
701
+ batch_size (`torch.Tensor`):
702
+ Batch size.
703
+ """
704
+ if attention_mask is not None and attention_mask.dim() == 4:
705
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
706
+ causal_mask = attention_mask
707
+ else:
708
+ min_dtype = torch.finfo(dtype).min
709
+ causal_mask = torch.full(
710
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
711
+ )
712
+ if sequence_length != 1:
713
+ causal_mask = torch.triu(causal_mask, diagonal=1)
714
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
715
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
716
+ if attention_mask is not None:
717
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
718
+ mask_length = attention_mask.shape[-1]
719
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
720
+ padding_mask = padding_mask == 0
721
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
722
+ padding_mask, min_dtype
723
+ )
724
+
725
+ return causal_mask
726
+
727
+
728
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
729
+
730
+
731
+ class Spectra2ForCausalLM(Spectra2PreTrainedModel, GenerationMixin):
732
+ _tied_weights_keys = ["lm_head.weight"]
733
+ _tp_plan = {"lm_head": "colwise_rep"}
734
+
735
+ def __init__(self, config):
736
+ super().__init__(config)
737
+ self.model = Spectra2Model(config)
738
+ self.vocab_size = config.vocab_size
739
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self):
745
+ return self.model.embed_tokens
746
+
747
+ def set_input_embeddings(self, value):
748
+ self.model.embed_tokens = value
749
+
750
+ def get_output_embeddings(self):
751
+ return self.lm_head
752
+
753
+ def set_output_embeddings(self, new_embeddings):
754
+ self.lm_head = new_embeddings
755
+
756
+ def set_decoder(self, decoder):
757
+ self.model = decoder
758
+
759
+ def get_decoder(self):
760
+ return self.model
761
+
762
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
763
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
764
+ def forward(
765
+ self,
766
+ input_ids: torch.LongTensor = None,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
770
+ inputs_embeds: Optional[torch.FloatTensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ use_cache: Optional[bool] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ cache_position: Optional[torch.LongTensor] = None,
777
+ num_logits_to_keep: int = 0,
778
+ **kwargs: Unpack[KwargsForCausalLM],
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ r"""
781
+ Args:
782
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
783
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
784
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
785
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
786
+
787
+ num_logits_to_keep (`int`, *optional*):
788
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
789
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
790
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
791
+
792
+ Returns:
793
+
794
+ Example:
795
+
796
+ ```python
797
+ >>> from transformers import AutoTokenizer, Spectra2ForCausalLM
798
+
799
+ >>> model = Spectra2ForCausalLM.from_pretrained("SpectraSuite/Spectra2-3B-base")
800
+ >>> tokenizer = AutoTokenizer.from_pretrained("SpectraSuite/Spectra2-3B-base")
801
+
802
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
803
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
804
+
805
+ >>> # Generate
806
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
807
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
808
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
809
+ ```"""
810
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
811
+ output_hidden_states = (
812
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
813
+ )
814
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
815
+
816
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
817
+ outputs = self.model(
818
+ input_ids=input_ids,
819
+ attention_mask=attention_mask,
820
+ position_ids=position_ids,
821
+ past_key_values=past_key_values,
822
+ inputs_embeds=inputs_embeds,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ cache_position=cache_position,
828
+ **kwargs,
829
+ )
830
+
831
+ hidden_states = outputs[0]
832
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
833
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
834
+
835
+ loss = None
836
+ if labels is not None:
837
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
838
+
839
+ if not return_dict:
840
+ output = (logits,) + outputs[1:]
841
+ return (loss,) + output if loss is not None else output
842
+
843
+ return CausalLMOutputWithPast(
844
+ loss=loss,
845
+ logits=logits,
846
+ past_key_values=outputs.past_key_values,
847
+ hidden_states=outputs.hidden_states,
848
+ attentions=outputs.attentions,
849
+ )
17_99p_300t/step140000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step140000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step50000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step60000/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Spectra2ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50277,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1280,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "spectra2",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 16,
17
+ "num_key_value_heads": 8,
18
+ "pad_token_id": 1,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float32",
23
+ "use_cache": true,
24
+ "vocab_size": 50304,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_spectra2.Spectra2Config",
27
+ "AutoModel": "modeling_spectra2.Spectra2Model",
28
+ "AutoModelForCausalLM": "modeling_spectra2.Spectra2ForCausalLM"
29
+ }
30
+ }
17_99p_300t/step60000/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50277,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.45.2"
6
+ }
17_99p_300t/step60000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step60000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step60000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
17_99p_300t/step70000/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `pip install git+https://github.com/huggingface/transformers.git@05260a1`
17_99p_300t/step70000/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Spectra2ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50277,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1280,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "spectra2",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 16,
17
+ "num_key_value_heads": 8,
18
+ "pad_token_id": 1,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float32",
23
+ "use_cache": true,
24
+ "vocab_size": 50304,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_spectra2.Spectra2Config",
27
+ "AutoModel": "modeling_spectra2.Spectra2Model",
28
+ "AutoModelForCausalLM": "modeling_spectra2.Spectra2ForCausalLM"
29
+ }
30
+ }
17_99p_300t/step70000/configuration_spectra2.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Spectra2 model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class Spectra2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Spectra2Model`]. It is used to instantiate an Spectra2
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [SpectraSuite/Spectra2-3B-base](https://huggingface.co/spectrasuite/Spectra2-3B-base).
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50304):
38
+ Vocabulary size of the Spectra2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Spectra2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*, defaults to 1):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 50279):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
77
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
78
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
79
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
80
+ these scaling strategies behave:
81
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
82
+ experimental feature, subject to breaking API changes in future versions.
83
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
84
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
85
+ attention_dropout (`float`, *optional*, defaults to 0.0):
86
+ The dropout ratio for the attention probabilities.
87
+ clip_qkv (`float`, *optional*):
88
+ If not `None`, elements of query, key and value attention states are clipped so that their
89
+ absolute value does not exceed this value.
90
+ ```python
91
+ >>> from transformers import Spectra2Model, Spectra2Config
92
+ >>> # Initializing a Spectra2 3B style configuration
93
+ >>> configuration = Spectra2Config()
94
+ >>> # Initializing a model from the Spectra2 3B style configuration
95
+ >>> model = Spectra2Model(configuration)
96
+ >>> # Accessing the model configuration
97
+ >>> configuration = model.config
98
+ ```"""
99
+
100
+ model_type = "spectra2"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50304,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ use_cache=True,
115
+ pad_token_id=1,
116
+ bos_token_id=None,
117
+ eos_token_id=50279,
118
+ tie_word_embeddings=False,
119
+ rope_theta=10000.0,
120
+ rope_scaling=None,
121
+ attention_bias=False,
122
+ attention_dropout=0.0,
123
+ clip_qkv=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.hidden_size = hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+
133
+ # for backward compatibility
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+
137
+ self.num_key_value_heads = num_key_value_heads
138
+ self.hidden_act = hidden_act
139
+ self.initializer_range = initializer_range
140
+ self.use_cache = use_cache
141
+ self.rope_theta = rope_theta
142
+ self.rope_scaling = rope_scaling
143
+ self._rope_scaling_validation()
144
+ self.attention_bias = attention_bias
145
+ self.attention_dropout = attention_dropout
146
+ self.clip_qkv = clip_qkv
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ def _rope_scaling_validation(self):
157
+ """
158
+ Validate the `rope_scaling` configuration.
159
+ """
160
+ if self.rope_scaling is None:
161
+ return
162
+
163
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
164
+ raise ValueError(
165
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
166
+ )
167
+ rope_scaling_type = self.rope_scaling.get("type", None)
168
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
169
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
170
+ raise ValueError(
171
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
172
+ )
173
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
174
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
17_99p_300t/step70000/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50277,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.45.2"
6
+ }
17_99p_300t/step70000/modeling_spectra2.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/spectra2/modular_spectra2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_spectra2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_spectra2 import Spectra2Config
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "Spectra2Config"
34
+
35
+
36
+ class Spectra2RMSLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.variance_epsilon = 1e-05 # Hardcoded
43
+ self.normalized_shape = (hidden_size,)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ with torch.autocast(enabled=False, device_type=hidden_states.device.type):
47
+ og_dtype = hidden_states.dtype
48
+ hidden_states = hidden_states.to(torch.float32)
49
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
50
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
51
+ hidden_states = hidden_states.to(og_dtype)
52
+ return self.weight * hidden_states
53
+
54
+
55
+ class Spectra2MLP(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.config = config
59
+ self.hidden_size = config.hidden_size
60
+ self.intermediate_size = config.intermediate_size
61
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
62
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[config.hidden_act]
65
+
66
+ def forward(self, x):
67
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
68
+ return down_proj
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
+ """Applies Rotary Position Embedding to the query and key tensors.
80
+
81
+ Args:
82
+ q (`torch.Tensor`): The query tensor.
83
+ k (`torch.Tensor`): The key tensor.
84
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
85
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
86
+ position_ids (`torch.Tensor`, *optional*):
87
+ Deprecated and unused.
88
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
89
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
90
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
91
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
92
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
93
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
94
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
95
+ Returns:
96
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
97
+ """
98
+ cos = cos.unsqueeze(unsqueeze_dim)
99
+ sin = sin.unsqueeze(unsqueeze_dim)
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
106
+ """
107
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
108
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
109
+ """
110
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
111
+ if n_rep == 1:
112
+ return hidden_states
113
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
+
116
+
117
+ def eager_attention_forward(
118
+ module: nn.Module,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor],
123
+ scaling: float,
124
+ dropout: float = 0.0,
125
+ **kwargs,
126
+ ):
127
+ key_states = repeat_kv(key, module.num_key_value_groups)
128
+ value_states = repeat_kv(value, module.num_key_value_groups)
129
+
130
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
131
+ if attention_mask is not None:
132
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
133
+ attn_weights = attn_weights + causal_mask
134
+
135
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
136
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
137
+ attn_output = torch.matmul(attn_weights, value_states)
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+
140
+ return attn_output, attn_weights
141
+
142
+
143
+ class Spectra2Attention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Spectra2Config, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
173
+ attention_mask: Optional[torch.Tensor],
174
+ past_key_value: Optional[Cache] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ **kwargs,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ input_shape = hidden_states.shape[:-1]
179
+ hidden_shape = (*input_shape, -1, self.head_dim)
180
+
181
+ query_states = self.q_proj(hidden_states)
182
+ key_states = self.k_proj(hidden_states)
183
+ value_states = self.v_proj(hidden_states)
184
+
185
+ if self.config.clip_qkv is not None:
186
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
187
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
188
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
189
+
190
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
191
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
192
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+
197
+ if past_key_value is not None:
198
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
199
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
200
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
201
+
202
+ attention_interface: Callable = eager_attention_forward
203
+ if self.config._attn_implementation != "eager":
204
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
205
+ logger.warning_once(
206
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
207
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
208
+ )
209
+ else:
210
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
211
+
212
+ attn_output, attn_weights = attention_interface(
213
+ self,
214
+ query_states,
215
+ key_states,
216
+ value_states,
217
+ attention_mask,
218
+ dropout=0.0 if not self.training else self.attention_dropout,
219
+ scaling=self.scaling,
220
+ **kwargs,
221
+ )
222
+
223
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
224
+ attn_output = self.o_proj(attn_output)
225
+ return attn_output, attn_weights
226
+
227
+
228
+ class Spectra2DecoderLayer(nn.Module):
229
+ def __init__(self, config: Spectra2Config, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = Spectra2Attention(config=config, layer_idx=layer_idx)
233
+
234
+ self.mlp = Spectra2MLP(config)
235
+ self.input_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
236
+ self.post_attention_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Cache] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
248
+ **kwargs: Unpack[FlashAttentionKwargs],
249
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
250
+ residual = hidden_states
251
+
252
+ hidden_states = self.input_rms_layernorm(hidden_states)
253
+
254
+ # Self Attention
255
+ hidden_states, self_attn_weights = self.self_attn(
256
+ hidden_states=hidden_states,
257
+ attention_mask=attention_mask,
258
+ position_ids=position_ids,
259
+ past_key_value=past_key_value,
260
+ output_attentions=output_attentions,
261
+ use_cache=use_cache,
262
+ cache_position=cache_position,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ hidden_states = residual + hidden_states
267
+
268
+ # Fully Connected
269
+ residual = hidden_states
270
+ hidden_states = self.post_attention_rms_layernorm(hidden_states)
271
+ hidden_states = self.mlp(hidden_states)
272
+ hidden_states = residual + hidden_states
273
+
274
+ outputs = (hidden_states,)
275
+ if output_attentions:
276
+ outputs += (self_attn_weights,)
277
+
278
+ return outputs
279
+
280
+
281
+ class Spectra2RotaryEmbedding(nn.Module):
282
+ def __init__(
283
+ self,
284
+ config: Spectra2Config,
285
+ device=None,
286
+ ):
287
+ super().__init__()
288
+ self.rope_kwargs = {}
289
+ # BC: "rope_type" was originally "type"
290
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
291
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
292
+ else:
293
+ self.rope_type = "default"
294
+ self.max_seq_len_cached = config.max_position_embeddings
295
+ self.original_max_seq_len = config.max_position_embeddings
296
+
297
+ self.config = config
298
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
299
+
300
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
301
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
302
+ self.original_inv_freq = self.inv_freq
303
+
304
+ def _dynamic_frequency_update(self, position_ids, device):
305
+ """
306
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
307
+ 1 - growing beyond the cached sequence length (allow scaling)
308
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
309
+ """
310
+ seq_len = torch.max(position_ids) + 1
311
+ if seq_len > self.max_seq_len_cached: # growth
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(
313
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
314
+ )
315
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
319
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
320
+ self.max_seq_len_cached = self.original_max_seq_len
321
+
322
+ @torch.no_grad()
323
+ def forward(self, x, position_ids):
324
+ if "dynamic" in self.rope_type:
325
+ self._dynamic_frequency_update(position_ids, device=x.device)
326
+
327
+ # Core RoPE block
328
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
329
+ position_ids_expanded = position_ids[:, None, :].float()
330
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
331
+ device_type = x.device.type
332
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
333
+ with torch.autocast(device_type=device_type, enabled=False):
334
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+
339
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
340
+ cos = cos * self.attention_scaling
341
+ sin = sin * self.attention_scaling
342
+
343
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
344
+
345
+
346
+ SPECTRA2_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`Spectra2Config`]):
357
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
358
+ load the weights associated with the model, only the configuration. Check out the
359
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
360
+ """
361
+
362
+
363
+ @add_start_docstrings(
364
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
365
+ SPECTRA2_START_DOCSTRING,
366
+ )
367
+ class Spectra2PreTrainedModel(PreTrainedModel):
368
+ config_class = Spectra2Config
369
+ base_model_prefix = "model"
370
+ supports_gradient_checkpointing = True
371
+ _no_split_modules = ["Spectra2DecoderLayer"]
372
+ _skip_keys_device_placement = ["past_key_values"]
373
+ _supports_flash_attn_2 = True
374
+ _supports_sdpa = True
375
+ _supports_cache_class = True
376
+ _supports_quantized_cache = True
377
+ _supports_static_cache = True
378
+
379
+ def _init_weights(self, module):
380
+ std = self.config.initializer_range
381
+ if isinstance(module, nn.Linear):
382
+ module.weight.data.normal_(mean=0.0, std=std)
383
+ if module.bias is not None:
384
+ module.bias.data.zero_()
385
+ elif isinstance(module, nn.Embedding):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.padding_idx is not None:
388
+ module.weight.data[module.padding_idx].zero_()
389
+
390
+ SPECTRA2_INPUTS_DOCSTRING = r"""
391
+ Args:
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
394
+ it.
395
+
396
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
397
+ [`PreTrainedTokenizer.__call__`] for details.
398
+
399
+ [What are input IDs?](../glossary#input-ids)
400
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
402
+
403
+ - 1 for tokens that are **not masked**,
404
+ - 0 for tokens that are **masked**.
405
+
406
+ [What are attention masks?](../glossary#attention-mask)
407
+
408
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
409
+ [`PreTrainedTokenizer.__call__`] for details.
410
+
411
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
412
+ `past_key_values`).
413
+
414
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
415
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
416
+ information on the default strategy.
417
+
418
+ - 1 indicates the head is **not masked**,
419
+ - 0 indicates the head is **masked**.
420
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
421
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
422
+ config.n_positions - 1]`.
423
+
424
+ [What are position IDs?](../glossary#position-ids)
425
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
426
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
427
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
428
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
429
+
430
+ Two formats are allowed:
431
+ - a [`~cache_utils.Cache`] instance, see our
432
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
433
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
434
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
435
+ cache format.
436
+
437
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
438
+ legacy cache format will be returned.
439
+
440
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
441
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
442
+ of shape `(batch_size, sequence_length)`.
443
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
444
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
445
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
446
+ model's internal embedding lookup matrix.
447
+ use_cache (`bool`, *optional*):
448
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
449
+ `past_key_values`).
450
+ output_attentions (`bool`, *optional*):
451
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
452
+ tensors for more detail.
453
+ output_hidden_states (`bool`, *optional*):
454
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
455
+ more detail.
456
+ return_dict (`bool`, *optional*):
457
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
458
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
459
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
460
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
461
+ the complete sequence length.
462
+ """
463
+
464
+
465
+ @add_start_docstrings(
466
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
467
+ SPECTRA2_START_DOCSTRING,
468
+ )
469
+ class Spectra2Model(Spectra2PreTrainedModel):
470
+ """
471
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Spectra2DecoderLayer`]
472
+
473
+ Args:
474
+ config: Spectra2Config
475
+ """
476
+
477
+ def __init__(self, config: Spectra2Config):
478
+ super().__init__(config)
479
+ self.padding_idx = config.pad_token_id
480
+ self.vocab_size = config.vocab_size
481
+
482
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
483
+ self.layers = nn.ModuleList(
484
+ [Spectra2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.norm = Spectra2RMSLayerNorm(config.hidden_size)
487
+ self.rotary_emb = Spectra2RotaryEmbedding(config=config)
488
+ self.gradient_checkpointing = False
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
500
+ def forward(
501
+ self,
502
+ input_ids: torch.LongTensor = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_values: Optional[Cache] = None,
506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ return_dict: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
520
+
521
+ if (input_ids is None) ^ (inputs_embeds is not None):
522
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
523
+
524
+ if self.gradient_checkpointing and self.training and use_cache:
525
+ logger.warning_once(
526
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
527
+ )
528
+ use_cache = False
529
+
530
+ if inputs_embeds is None:
531
+ inputs_embeds = self.embed_tokens(input_ids)
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = DynamicCache()
535
+
536
+ if cache_position is None:
537
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ cache_position = torch.arange(
539
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
540
+ )
541
+
542
+ if position_ids is None:
543
+ position_ids = cache_position.unsqueeze(0)
544
+
545
+ causal_mask = self._update_causal_mask(
546
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
547
+ )
548
+
549
+ hidden_states = inputs_embeds
550
+
551
+ # create position embeddings to be shared across the decoder layers
552
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+
558
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
559
+ if output_hidden_states:
560
+ all_hidden_states += (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ layer_outputs = self._gradient_checkpointing_func(
564
+ decoder_layer.__call__,
565
+ hidden_states,
566
+ causal_mask,
567
+ position_ids,
568
+ past_key_values,
569
+ output_attentions,
570
+ use_cache,
571
+ cache_position,
572
+ position_embeddings,
573
+ )
574
+ else:
575
+ layer_outputs = decoder_layer(
576
+ hidden_states,
577
+ attention_mask=causal_mask,
578
+ position_ids=position_ids,
579
+ past_key_value=past_key_values,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ cache_position=cache_position,
583
+ position_embeddings=position_embeddings,
584
+ **flash_attn_kwargs,
585
+ )
586
+
587
+ hidden_states = layer_outputs[0]
588
+
589
+ if output_attentions:
590
+ all_self_attns += (layer_outputs[1],)
591
+
592
+ hidden_states = self.norm(hidden_states)
593
+
594
+ # add hidden states from the last decoder layer
595
+ if output_hidden_states:
596
+ all_hidden_states += (hidden_states,)
597
+
598
+ output = BaseModelOutputWithPast(
599
+ last_hidden_state=hidden_states,
600
+ past_key_values=past_key_values if use_cache else None,
601
+ hidden_states=all_hidden_states,
602
+ attentions=all_self_attns,
603
+ )
604
+ return output if return_dict else output.to_tuple()
605
+
606
+ def _update_causal_mask(
607
+ self,
608
+ attention_mask: torch.Tensor,
609
+ input_tensor: torch.Tensor,
610
+ cache_position: torch.Tensor,
611
+ past_key_values: Cache,
612
+ output_attentions: bool,
613
+ ):
614
+ if self.config._attn_implementation == "flash_attention_2":
615
+ if attention_mask is not None and 0.0 in attention_mask:
616
+ return attention_mask
617
+ return None
618
+
619
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
620
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
621
+ # to infer the attention mask.
622
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
623
+ using_static_cache = isinstance(past_key_values, StaticCache)
624
+
625
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
626
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
627
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
628
+ attention_mask,
629
+ inputs_embeds=input_tensor,
630
+ past_key_values_length=past_seen_tokens,
631
+ is_training=self.training,
632
+ ):
633
+ return None
634
+
635
+ dtype, device = input_tensor.dtype, input_tensor.device
636
+ sequence_length = input_tensor.shape[1]
637
+ if using_static_cache:
638
+ target_length = past_key_values.get_max_cache_shape()
639
+ else:
640
+ target_length = (
641
+ attention_mask.shape[-1]
642
+ if isinstance(attention_mask, torch.Tensor)
643
+ else past_seen_tokens + sequence_length + 1
644
+ )
645
+
646
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
647
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
648
+ attention_mask,
649
+ sequence_length=sequence_length,
650
+ target_length=target_length,
651
+ dtype=dtype,
652
+ device=device,
653
+ cache_position=cache_position,
654
+ batch_size=input_tensor.shape[0],
655
+ )
656
+
657
+ if (
658
+ self.config._attn_implementation == "sdpa"
659
+ and attention_mask is not None
660
+ and attention_mask.device.type == "cuda"
661
+ and not output_attentions
662
+ ):
663
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
664
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
665
+ # Details: https://github.com/pytorch/pytorch/issues/110213
666
+ min_dtype = torch.finfo(dtype).min
667
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
668
+
669
+ return causal_mask
670
+
671
+ @staticmethod
672
+ def _prepare_4d_causal_attention_mask_with_cache_position(
673
+ attention_mask: torch.Tensor,
674
+ sequence_length: int,
675
+ target_length: int,
676
+ dtype: torch.dtype,
677
+ device: torch.device,
678
+ cache_position: torch.Tensor,
679
+ batch_size: int,
680
+ **kwargs,
681
+ ):
682
+ """
683
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
684
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
685
+
686
+ Args:
687
+ attention_mask (`torch.Tensor`):
688
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
689
+ `(batch_size, 1, query_length, key_value_length)`.
690
+ sequence_length (`int`):
691
+ The sequence length being processed.
692
+ target_length (`int`):
693
+ The target length: when generating with static cache, the mask should be as long as the static cache,
694
+ to account for the 0 padding, the part of the cache that is not filled yet.
695
+ dtype (`torch.dtype`):
696
+ The dtype to use for the 4D attention mask.
697
+ device (`torch.device`):
698
+ The device to plcae the 4D attention mask on.
699
+ cache_position (`torch.Tensor`):
700
+ Indices depicting the position of the input sequence tokens in the sequence.
701
+ batch_size (`torch.Tensor`):
702
+ Batch size.
703
+ """
704
+ if attention_mask is not None and attention_mask.dim() == 4:
705
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
706
+ causal_mask = attention_mask
707
+ else:
708
+ min_dtype = torch.finfo(dtype).min
709
+ causal_mask = torch.full(
710
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
711
+ )
712
+ if sequence_length != 1:
713
+ causal_mask = torch.triu(causal_mask, diagonal=1)
714
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
715
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
716
+ if attention_mask is not None:
717
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
718
+ mask_length = attention_mask.shape[-1]
719
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
720
+ padding_mask = padding_mask == 0
721
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
722
+ padding_mask, min_dtype
723
+ )
724
+
725
+ return causal_mask
726
+
727
+
728
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
729
+
730
+
731
+ class Spectra2ForCausalLM(Spectra2PreTrainedModel, GenerationMixin):
732
+ _tied_weights_keys = ["lm_head.weight"]
733
+ _tp_plan = {"lm_head": "colwise_rep"}
734
+
735
+ def __init__(self, config):
736
+ super().__init__(config)
737
+ self.model = Spectra2Model(config)
738
+ self.vocab_size = config.vocab_size
739
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self):
745
+ return self.model.embed_tokens
746
+
747
+ def set_input_embeddings(self, value):
748
+ self.model.embed_tokens = value
749
+
750
+ def get_output_embeddings(self):
751
+ return self.lm_head
752
+
753
+ def set_output_embeddings(self, new_embeddings):
754
+ self.lm_head = new_embeddings
755
+
756
+ def set_decoder(self, decoder):
757
+ self.model = decoder
758
+
759
+ def get_decoder(self):
760
+ return self.model
761
+
762
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
763
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
764
+ def forward(
765
+ self,
766
+ input_ids: torch.LongTensor = None,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
770
+ inputs_embeds: Optional[torch.FloatTensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ use_cache: Optional[bool] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ cache_position: Optional[torch.LongTensor] = None,
777
+ num_logits_to_keep: int = 0,
778
+ **kwargs: Unpack[KwargsForCausalLM],
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ r"""
781
+ Args:
782
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
783
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
784
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
785
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
786
+
787
+ num_logits_to_keep (`int`, *optional*):
788
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
789
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
790
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
791
+
792
+ Returns:
793
+
794
+ Example:
795
+
796
+ ```python
797
+ >>> from transformers import AutoTokenizer, Spectra2ForCausalLM
798
+
799
+ >>> model = Spectra2ForCausalLM.from_pretrained("SpectraSuite/Spectra2-3B-base")
800
+ >>> tokenizer = AutoTokenizer.from_pretrained("SpectraSuite/Spectra2-3B-base")
801
+
802
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
803
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
804
+
805
+ >>> # Generate
806
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
807
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
808
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
809
+ ```"""
810
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
811
+ output_hidden_states = (
812
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
813
+ )
814
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
815
+
816
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
817
+ outputs = self.model(
818
+ input_ids=input_ids,
819
+ attention_mask=attention_mask,
820
+ position_ids=position_ids,
821
+ past_key_values=past_key_values,
822
+ inputs_embeds=inputs_embeds,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ cache_position=cache_position,
828
+ **kwargs,
829
+ )
830
+
831
+ hidden_states = outputs[0]
832
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
833
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
834
+
835
+ loss = None
836
+ if labels is not None:
837
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
838
+
839
+ if not return_dict:
840
+ output = (logits,) + outputs[1:]
841
+ return (loss,) + output if loss is not None else output
842
+
843
+ return CausalLMOutputWithPast(
844
+ loss=loss,
845
+ logits=logits,
846
+ past_key_values=outputs.past_key_values,
847
+ hidden_states=outputs.hidden_states,
848
+ attentions=outputs.attentions,
849
+ )
17_99p_300t/step70000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step70000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step70000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
17_99p_300t/step80000/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Spectra2ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50277,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1280,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "spectra2",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 16,
17
+ "num_key_value_heads": 8,
18
+ "pad_token_id": 1,
19
+ "rope_scaling": null,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "float32",
23
+ "use_cache": true,
24
+ "vocab_size": 50304,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_spectra2.Spectra2Config",
27
+ "AutoModel": "modeling_spectra2.Spectra2Model",
28
+ "AutoModelForCausalLM": "modeling_spectra2.Spectra2ForCausalLM"
29
+ }
30
+ }
17_99p_300t/step80000/configuration_spectra2.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Spectra2 model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class Spectra2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Spectra2Model`]. It is used to instantiate an Spectra2
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [SpectraSuite/Spectra2-3B-base](https://huggingface.co/spectrasuite/Spectra2-3B-base).
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50304):
38
+ Vocabulary size of the Spectra2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Spectra2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*, defaults to 1):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 50279):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
77
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
78
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
79
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
80
+ these scaling strategies behave:
81
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
82
+ experimental feature, subject to breaking API changes in future versions.
83
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
84
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
85
+ attention_dropout (`float`, *optional*, defaults to 0.0):
86
+ The dropout ratio for the attention probabilities.
87
+ clip_qkv (`float`, *optional*):
88
+ If not `None`, elements of query, key and value attention states are clipped so that their
89
+ absolute value does not exceed this value.
90
+ ```python
91
+ >>> from transformers import Spectra2Model, Spectra2Config
92
+ >>> # Initializing a Spectra2 3B style configuration
93
+ >>> configuration = Spectra2Config()
94
+ >>> # Initializing a model from the Spectra2 3B style configuration
95
+ >>> model = Spectra2Model(configuration)
96
+ >>> # Accessing the model configuration
97
+ >>> configuration = model.config
98
+ ```"""
99
+
100
+ model_type = "spectra2"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50304,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ use_cache=True,
115
+ pad_token_id=1,
116
+ bos_token_id=None,
117
+ eos_token_id=50279,
118
+ tie_word_embeddings=False,
119
+ rope_theta=10000.0,
120
+ rope_scaling=None,
121
+ attention_bias=False,
122
+ attention_dropout=0.0,
123
+ clip_qkv=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.hidden_size = hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+
133
+ # for backward compatibility
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+
137
+ self.num_key_value_heads = num_key_value_heads
138
+ self.hidden_act = hidden_act
139
+ self.initializer_range = initializer_range
140
+ self.use_cache = use_cache
141
+ self.rope_theta = rope_theta
142
+ self.rope_scaling = rope_scaling
143
+ self._rope_scaling_validation()
144
+ self.attention_bias = attention_bias
145
+ self.attention_dropout = attention_dropout
146
+ self.clip_qkv = clip_qkv
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ def _rope_scaling_validation(self):
157
+ """
158
+ Validate the `rope_scaling` configuration.
159
+ """
160
+ if self.rope_scaling is None:
161
+ return
162
+
163
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
164
+ raise ValueError(
165
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
166
+ )
167
+ rope_scaling_type = self.rope_scaling.get("type", None)
168
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
169
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
170
+ raise ValueError(
171
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
172
+ )
173
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
174
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
17_99p_300t/step80000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step80000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
17_99p_300t/step90000/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `pip install git+https://github.com/huggingface/transformers.git@05260a1`
17_99p_300t/step90000/configuration_spectra2.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Spectra2 model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class Spectra2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Spectra2Model`]. It is used to instantiate an Spectra2
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [SpectraSuite/Spectra2-3B-base](https://huggingface.co/spectrasuite/Spectra2-3B-base).
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50304):
38
+ Vocabulary size of the Spectra2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Spectra2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*, defaults to 1):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 50279):
70
+ End of stream token id.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+ rope_scaling (`Dict`, *optional*):
76
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
77
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
78
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
79
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
80
+ these scaling strategies behave:
81
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
82
+ experimental feature, subject to breaking API changes in future versions.
83
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
84
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
85
+ attention_dropout (`float`, *optional*, defaults to 0.0):
86
+ The dropout ratio for the attention probabilities.
87
+ clip_qkv (`float`, *optional*):
88
+ If not `None`, elements of query, key and value attention states are clipped so that their
89
+ absolute value does not exceed this value.
90
+ ```python
91
+ >>> from transformers import Spectra2Model, Spectra2Config
92
+ >>> # Initializing a Spectra2 3B style configuration
93
+ >>> configuration = Spectra2Config()
94
+ >>> # Initializing a model from the Spectra2 3B style configuration
95
+ >>> model = Spectra2Model(configuration)
96
+ >>> # Accessing the model configuration
97
+ >>> configuration = model.config
98
+ ```"""
99
+
100
+ model_type = "spectra2"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=50304,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ use_cache=True,
115
+ pad_token_id=1,
116
+ bos_token_id=None,
117
+ eos_token_id=50279,
118
+ tie_word_embeddings=False,
119
+ rope_theta=10000.0,
120
+ rope_scaling=None,
121
+ attention_bias=False,
122
+ attention_dropout=0.0,
123
+ clip_qkv=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.hidden_size = hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+
133
+ # for backward compatibility
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+
137
+ self.num_key_value_heads = num_key_value_heads
138
+ self.hidden_act = hidden_act
139
+ self.initializer_range = initializer_range
140
+ self.use_cache = use_cache
141
+ self.rope_theta = rope_theta
142
+ self.rope_scaling = rope_scaling
143
+ self._rope_scaling_validation()
144
+ self.attention_bias = attention_bias
145
+ self.attention_dropout = attention_dropout
146
+ self.clip_qkv = clip_qkv
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ def _rope_scaling_validation(self):
157
+ """
158
+ Validate the `rope_scaling` configuration.
159
+ """
160
+ if self.rope_scaling is None:
161
+ return
162
+
163
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
164
+ raise ValueError(
165
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
166
+ )
167
+ rope_scaling_type = self.rope_scaling.get("type", None)
168
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
169
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
170
+ raise ValueError(
171
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
172
+ )
173
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
174
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
17_99p_300t/step90000/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50277,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.45.2"
6
+ }
17_99p_300t/step90000/modeling_spectra2.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/spectra2/modular_spectra2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_spectra2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_spectra2 import Spectra2Config
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "Spectra2Config"
34
+
35
+
36
+ class Spectra2RMSLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(hidden_size))
42
+ self.variance_epsilon = 1e-05 # Hardcoded
43
+ self.normalized_shape = (hidden_size,)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ with torch.autocast(enabled=False, device_type=hidden_states.device.type):
47
+ og_dtype = hidden_states.dtype
48
+ hidden_states = hidden_states.to(torch.float32)
49
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
50
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
51
+ hidden_states = hidden_states.to(og_dtype)
52
+ return self.weight * hidden_states
53
+
54
+
55
+ class Spectra2MLP(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.config = config
59
+ self.hidden_size = config.hidden_size
60
+ self.intermediate_size = config.intermediate_size
61
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
62
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[config.hidden_act]
65
+
66
+ def forward(self, x):
67
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
68
+ return down_proj
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
+ """Applies Rotary Position Embedding to the query and key tensors.
80
+
81
+ Args:
82
+ q (`torch.Tensor`): The query tensor.
83
+ k (`torch.Tensor`): The key tensor.
84
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
85
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
86
+ position_ids (`torch.Tensor`, *optional*):
87
+ Deprecated and unused.
88
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
89
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
90
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
91
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
92
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
93
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
94
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
95
+ Returns:
96
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
97
+ """
98
+ cos = cos.unsqueeze(unsqueeze_dim)
99
+ sin = sin.unsqueeze(unsqueeze_dim)
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
106
+ """
107
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
108
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
109
+ """
110
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
111
+ if n_rep == 1:
112
+ return hidden_states
113
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
+
116
+
117
+ def eager_attention_forward(
118
+ module: nn.Module,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor],
123
+ scaling: float,
124
+ dropout: float = 0.0,
125
+ **kwargs,
126
+ ):
127
+ key_states = repeat_kv(key, module.num_key_value_groups)
128
+ value_states = repeat_kv(value, module.num_key_value_groups)
129
+
130
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
131
+ if attention_mask is not None:
132
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
133
+ attn_weights = attn_weights + causal_mask
134
+
135
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
136
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
137
+ attn_output = torch.matmul(attn_weights, value_states)
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+
140
+ return attn_output, attn_weights
141
+
142
+
143
+ class Spectra2Attention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Spectra2Config, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
173
+ attention_mask: Optional[torch.Tensor],
174
+ past_key_value: Optional[Cache] = None,
175
+ cache_position: Optional[torch.LongTensor] = None,
176
+ **kwargs,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ input_shape = hidden_states.shape[:-1]
179
+ hidden_shape = (*input_shape, -1, self.head_dim)
180
+
181
+ query_states = self.q_proj(hidden_states)
182
+ key_states = self.k_proj(hidden_states)
183
+ value_states = self.v_proj(hidden_states)
184
+
185
+ if self.config.clip_qkv is not None:
186
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
187
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
188
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
189
+
190
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
191
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
192
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+
197
+ if past_key_value is not None:
198
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
199
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
200
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
201
+
202
+ attention_interface: Callable = eager_attention_forward
203
+ if self.config._attn_implementation != "eager":
204
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
205
+ logger.warning_once(
206
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
207
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
208
+ )
209
+ else:
210
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
211
+
212
+ attn_output, attn_weights = attention_interface(
213
+ self,
214
+ query_states,
215
+ key_states,
216
+ value_states,
217
+ attention_mask,
218
+ dropout=0.0 if not self.training else self.attention_dropout,
219
+ scaling=self.scaling,
220
+ **kwargs,
221
+ )
222
+
223
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
224
+ attn_output = self.o_proj(attn_output)
225
+ return attn_output, attn_weights
226
+
227
+
228
+ class Spectra2DecoderLayer(nn.Module):
229
+ def __init__(self, config: Spectra2Config, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = Spectra2Attention(config=config, layer_idx=layer_idx)
233
+
234
+ self.mlp = Spectra2MLP(config)
235
+ self.input_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
236
+ self.post_attention_rms_layernorm = Spectra2RMSLayerNorm(config.hidden_size)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Cache] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
+ cache_position: Optional[torch.LongTensor] = None,
247
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
248
+ **kwargs: Unpack[FlashAttentionKwargs],
249
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
250
+ residual = hidden_states
251
+
252
+ hidden_states = self.input_rms_layernorm(hidden_states)
253
+
254
+ # Self Attention
255
+ hidden_states, self_attn_weights = self.self_attn(
256
+ hidden_states=hidden_states,
257
+ attention_mask=attention_mask,
258
+ position_ids=position_ids,
259
+ past_key_value=past_key_value,
260
+ output_attentions=output_attentions,
261
+ use_cache=use_cache,
262
+ cache_position=cache_position,
263
+ position_embeddings=position_embeddings,
264
+ **kwargs,
265
+ )
266
+ hidden_states = residual + hidden_states
267
+
268
+ # Fully Connected
269
+ residual = hidden_states
270
+ hidden_states = self.post_attention_rms_layernorm(hidden_states)
271
+ hidden_states = self.mlp(hidden_states)
272
+ hidden_states = residual + hidden_states
273
+
274
+ outputs = (hidden_states,)
275
+ if output_attentions:
276
+ outputs += (self_attn_weights,)
277
+
278
+ return outputs
279
+
280
+
281
+ class Spectra2RotaryEmbedding(nn.Module):
282
+ def __init__(
283
+ self,
284
+ config: Spectra2Config,
285
+ device=None,
286
+ ):
287
+ super().__init__()
288
+ self.rope_kwargs = {}
289
+ # BC: "rope_type" was originally "type"
290
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
291
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
292
+ else:
293
+ self.rope_type = "default"
294
+ self.max_seq_len_cached = config.max_position_embeddings
295
+ self.original_max_seq_len = config.max_position_embeddings
296
+
297
+ self.config = config
298
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
299
+
300
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
301
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
302
+ self.original_inv_freq = self.inv_freq
303
+
304
+ def _dynamic_frequency_update(self, position_ids, device):
305
+ """
306
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
307
+ 1 - growing beyond the cached sequence length (allow scaling)
308
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
309
+ """
310
+ seq_len = torch.max(position_ids) + 1
311
+ if seq_len > self.max_seq_len_cached: # growth
312
+ inv_freq, self.attention_scaling = self.rope_init_fn(
313
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
314
+ )
315
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
316
+ self.max_seq_len_cached = seq_len
317
+
318
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
319
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
320
+ self.max_seq_len_cached = self.original_max_seq_len
321
+
322
+ @torch.no_grad()
323
+ def forward(self, x, position_ids):
324
+ if "dynamic" in self.rope_type:
325
+ self._dynamic_frequency_update(position_ids, device=x.device)
326
+
327
+ # Core RoPE block
328
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
329
+ position_ids_expanded = position_ids[:, None, :].float()
330
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
331
+ device_type = x.device.type
332
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
333
+ with torch.autocast(device_type=device_type, enabled=False):
334
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+
339
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
340
+ cos = cos * self.attention_scaling
341
+ sin = sin * self.attention_scaling
342
+
343
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
344
+
345
+
346
+ SPECTRA2_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`Spectra2Config`]):
357
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
358
+ load the weights associated with the model, only the configuration. Check out the
359
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
360
+ """
361
+
362
+
363
+ @add_start_docstrings(
364
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
365
+ SPECTRA2_START_DOCSTRING,
366
+ )
367
+ class Spectra2PreTrainedModel(PreTrainedModel):
368
+ config_class = Spectra2Config
369
+ base_model_prefix = "model"
370
+ supports_gradient_checkpointing = True
371
+ _no_split_modules = ["Spectra2DecoderLayer"]
372
+ _skip_keys_device_placement = ["past_key_values"]
373
+ _supports_flash_attn_2 = True
374
+ _supports_sdpa = True
375
+ _supports_cache_class = True
376
+ _supports_quantized_cache = True
377
+ _supports_static_cache = True
378
+
379
+ def _init_weights(self, module):
380
+ std = self.config.initializer_range
381
+ if isinstance(module, nn.Linear):
382
+ module.weight.data.normal_(mean=0.0, std=std)
383
+ if module.bias is not None:
384
+ module.bias.data.zero_()
385
+ elif isinstance(module, nn.Embedding):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.padding_idx is not None:
388
+ module.weight.data[module.padding_idx].zero_()
389
+
390
+ SPECTRA2_INPUTS_DOCSTRING = r"""
391
+ Args:
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
394
+ it.
395
+
396
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
397
+ [`PreTrainedTokenizer.__call__`] for details.
398
+
399
+ [What are input IDs?](../glossary#input-ids)
400
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
402
+
403
+ - 1 for tokens that are **not masked**,
404
+ - 0 for tokens that are **masked**.
405
+
406
+ [What are attention masks?](../glossary#attention-mask)
407
+
408
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
409
+ [`PreTrainedTokenizer.__call__`] for details.
410
+
411
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
412
+ `past_key_values`).
413
+
414
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
415
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
416
+ information on the default strategy.
417
+
418
+ - 1 indicates the head is **not masked**,
419
+ - 0 indicates the head is **masked**.
420
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
421
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
422
+ config.n_positions - 1]`.
423
+
424
+ [What are position IDs?](../glossary#position-ids)
425
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
426
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
427
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
428
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
429
+
430
+ Two formats are allowed:
431
+ - a [`~cache_utils.Cache`] instance, see our
432
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
433
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
434
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
435
+ cache format.
436
+
437
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
438
+ legacy cache format will be returned.
439
+
440
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
441
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
442
+ of shape `(batch_size, sequence_length)`.
443
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
444
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
445
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
446
+ model's internal embedding lookup matrix.
447
+ use_cache (`bool`, *optional*):
448
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
449
+ `past_key_values`).
450
+ output_attentions (`bool`, *optional*):
451
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
452
+ tensors for more detail.
453
+ output_hidden_states (`bool`, *optional*):
454
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
455
+ more detail.
456
+ return_dict (`bool`, *optional*):
457
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
458
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
459
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
460
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
461
+ the complete sequence length.
462
+ """
463
+
464
+
465
+ @add_start_docstrings(
466
+ "The bare Spectra2 Model outputting raw hidden-states without any specific head on top.",
467
+ SPECTRA2_START_DOCSTRING,
468
+ )
469
+ class Spectra2Model(Spectra2PreTrainedModel):
470
+ """
471
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Spectra2DecoderLayer`]
472
+
473
+ Args:
474
+ config: Spectra2Config
475
+ """
476
+
477
+ def __init__(self, config: Spectra2Config):
478
+ super().__init__(config)
479
+ self.padding_idx = config.pad_token_id
480
+ self.vocab_size = config.vocab_size
481
+
482
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
483
+ self.layers = nn.ModuleList(
484
+ [Spectra2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
485
+ )
486
+ self.norm = Spectra2RMSLayerNorm(config.hidden_size)
487
+ self.rotary_emb = Spectra2RotaryEmbedding(config=config)
488
+ self.gradient_checkpointing = False
489
+
490
+ # Initialize weights and apply final processing
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self):
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value):
497
+ self.embed_tokens = value
498
+
499
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
500
+ def forward(
501
+ self,
502
+ input_ids: torch.LongTensor = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_values: Optional[Cache] = None,
506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ return_dict: Optional[bool] = None,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
513
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
514
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
515
+ output_hidden_states = (
516
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
517
+ )
518
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
520
+
521
+ if (input_ids is None) ^ (inputs_embeds is not None):
522
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
523
+
524
+ if self.gradient_checkpointing and self.training and use_cache:
525
+ logger.warning_once(
526
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
527
+ )
528
+ use_cache = False
529
+
530
+ if inputs_embeds is None:
531
+ inputs_embeds = self.embed_tokens(input_ids)
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = DynamicCache()
535
+
536
+ if cache_position is None:
537
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ cache_position = torch.arange(
539
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
540
+ )
541
+
542
+ if position_ids is None:
543
+ position_ids = cache_position.unsqueeze(0)
544
+
545
+ causal_mask = self._update_causal_mask(
546
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
547
+ )
548
+
549
+ hidden_states = inputs_embeds
550
+
551
+ # create position embeddings to be shared across the decoder layers
552
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
553
+
554
+ # decoder layers
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attns = () if output_attentions else None
557
+
558
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
559
+ if output_hidden_states:
560
+ all_hidden_states += (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ layer_outputs = self._gradient_checkpointing_func(
564
+ decoder_layer.__call__,
565
+ hidden_states,
566
+ causal_mask,
567
+ position_ids,
568
+ past_key_values,
569
+ output_attentions,
570
+ use_cache,
571
+ cache_position,
572
+ position_embeddings,
573
+ )
574
+ else:
575
+ layer_outputs = decoder_layer(
576
+ hidden_states,
577
+ attention_mask=causal_mask,
578
+ position_ids=position_ids,
579
+ past_key_value=past_key_values,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ cache_position=cache_position,
583
+ position_embeddings=position_embeddings,
584
+ **flash_attn_kwargs,
585
+ )
586
+
587
+ hidden_states = layer_outputs[0]
588
+
589
+ if output_attentions:
590
+ all_self_attns += (layer_outputs[1],)
591
+
592
+ hidden_states = self.norm(hidden_states)
593
+
594
+ # add hidden states from the last decoder layer
595
+ if output_hidden_states:
596
+ all_hidden_states += (hidden_states,)
597
+
598
+ output = BaseModelOutputWithPast(
599
+ last_hidden_state=hidden_states,
600
+ past_key_values=past_key_values if use_cache else None,
601
+ hidden_states=all_hidden_states,
602
+ attentions=all_self_attns,
603
+ )
604
+ return output if return_dict else output.to_tuple()
605
+
606
+ def _update_causal_mask(
607
+ self,
608
+ attention_mask: torch.Tensor,
609
+ input_tensor: torch.Tensor,
610
+ cache_position: torch.Tensor,
611
+ past_key_values: Cache,
612
+ output_attentions: bool,
613
+ ):
614
+ if self.config._attn_implementation == "flash_attention_2":
615
+ if attention_mask is not None and 0.0 in attention_mask:
616
+ return attention_mask
617
+ return None
618
+
619
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
620
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
621
+ # to infer the attention mask.
622
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
623
+ using_static_cache = isinstance(past_key_values, StaticCache)
624
+
625
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
626
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
627
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
628
+ attention_mask,
629
+ inputs_embeds=input_tensor,
630
+ past_key_values_length=past_seen_tokens,
631
+ is_training=self.training,
632
+ ):
633
+ return None
634
+
635
+ dtype, device = input_tensor.dtype, input_tensor.device
636
+ sequence_length = input_tensor.shape[1]
637
+ if using_static_cache:
638
+ target_length = past_key_values.get_max_cache_shape()
639
+ else:
640
+ target_length = (
641
+ attention_mask.shape[-1]
642
+ if isinstance(attention_mask, torch.Tensor)
643
+ else past_seen_tokens + sequence_length + 1
644
+ )
645
+
646
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
647
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
648
+ attention_mask,
649
+ sequence_length=sequence_length,
650
+ target_length=target_length,
651
+ dtype=dtype,
652
+ device=device,
653
+ cache_position=cache_position,
654
+ batch_size=input_tensor.shape[0],
655
+ )
656
+
657
+ if (
658
+ self.config._attn_implementation == "sdpa"
659
+ and attention_mask is not None
660
+ and attention_mask.device.type == "cuda"
661
+ and not output_attentions
662
+ ):
663
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
664
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
665
+ # Details: https://github.com/pytorch/pytorch/issues/110213
666
+ min_dtype = torch.finfo(dtype).min
667
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
668
+
669
+ return causal_mask
670
+
671
+ @staticmethod
672
+ def _prepare_4d_causal_attention_mask_with_cache_position(
673
+ attention_mask: torch.Tensor,
674
+ sequence_length: int,
675
+ target_length: int,
676
+ dtype: torch.dtype,
677
+ device: torch.device,
678
+ cache_position: torch.Tensor,
679
+ batch_size: int,
680
+ **kwargs,
681
+ ):
682
+ """
683
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
684
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
685
+
686
+ Args:
687
+ attention_mask (`torch.Tensor`):
688
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
689
+ `(batch_size, 1, query_length, key_value_length)`.
690
+ sequence_length (`int`):
691
+ The sequence length being processed.
692
+ target_length (`int`):
693
+ The target length: when generating with static cache, the mask should be as long as the static cache,
694
+ to account for the 0 padding, the part of the cache that is not filled yet.
695
+ dtype (`torch.dtype`):
696
+ The dtype to use for the 4D attention mask.
697
+ device (`torch.device`):
698
+ The device to plcae the 4D attention mask on.
699
+ cache_position (`torch.Tensor`):
700
+ Indices depicting the position of the input sequence tokens in the sequence.
701
+ batch_size (`torch.Tensor`):
702
+ Batch size.
703
+ """
704
+ if attention_mask is not None and attention_mask.dim() == 4:
705
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
706
+ causal_mask = attention_mask
707
+ else:
708
+ min_dtype = torch.finfo(dtype).min
709
+ causal_mask = torch.full(
710
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
711
+ )
712
+ if sequence_length != 1:
713
+ causal_mask = torch.triu(causal_mask, diagonal=1)
714
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
715
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
716
+ if attention_mask is not None:
717
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
718
+ mask_length = attention_mask.shape[-1]
719
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
720
+ padding_mask = padding_mask == 0
721
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
722
+ padding_mask, min_dtype
723
+ )
724
+
725
+ return causal_mask
726
+
727
+
728
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
729
+
730
+
731
+ class Spectra2ForCausalLM(Spectra2PreTrainedModel, GenerationMixin):
732
+ _tied_weights_keys = ["lm_head.weight"]
733
+ _tp_plan = {"lm_head": "colwise_rep"}
734
+
735
+ def __init__(self, config):
736
+ super().__init__(config)
737
+ self.model = Spectra2Model(config)
738
+ self.vocab_size = config.vocab_size
739
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self):
745
+ return self.model.embed_tokens
746
+
747
+ def set_input_embeddings(self, value):
748
+ self.model.embed_tokens = value
749
+
750
+ def get_output_embeddings(self):
751
+ return self.lm_head
752
+
753
+ def set_output_embeddings(self, new_embeddings):
754
+ self.lm_head = new_embeddings
755
+
756
+ def set_decoder(self, decoder):
757
+ self.model = decoder
758
+
759
+ def get_decoder(self):
760
+ return self.model
761
+
762
+ @add_start_docstrings_to_model_forward(SPECTRA2_INPUTS_DOCSTRING)
763
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
764
+ def forward(
765
+ self,
766
+ input_ids: torch.LongTensor = None,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
770
+ inputs_embeds: Optional[torch.FloatTensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ use_cache: Optional[bool] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ cache_position: Optional[torch.LongTensor] = None,
777
+ num_logits_to_keep: int = 0,
778
+ **kwargs: Unpack[KwargsForCausalLM],
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ r"""
781
+ Args:
782
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
783
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
784
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
785
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
786
+
787
+ num_logits_to_keep (`int`, *optional*):
788
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
789
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
790
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
791
+
792
+ Returns:
793
+
794
+ Example:
795
+
796
+ ```python
797
+ >>> from transformers import AutoTokenizer, Spectra2ForCausalLM
798
+
799
+ >>> model = Spectra2ForCausalLM.from_pretrained("SpectraSuite/Spectra2-3B-base")
800
+ >>> tokenizer = AutoTokenizer.from_pretrained("SpectraSuite/Spectra2-3B-base")
801
+
802
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
803
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
804
+
805
+ >>> # Generate
806
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
807
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
808
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
809
+ ```"""
810
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
811
+ output_hidden_states = (
812
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
813
+ )
814
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
815
+
816
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
817
+ outputs = self.model(
818
+ input_ids=input_ids,
819
+ attention_mask=attention_mask,
820
+ position_ids=position_ids,
821
+ past_key_values=past_key_values,
822
+ inputs_embeds=inputs_embeds,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ cache_position=cache_position,
828
+ **kwargs,
829
+ )
830
+
831
+ hidden_states = outputs[0]
832
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
833
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
834
+
835
+ loss = None
836
+ if labels is not None:
837
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
838
+
839
+ if not return_dict:
840
+ output = (logits,) + outputs[1:]
841
+ return (loss,) + output if loss is not None else output
842
+
843
+ return CausalLMOutputWithPast(
844
+ loss=loss,
845
+ logits=logits,
846
+ past_key_values=outputs.past_key_values,
847
+ hidden_states=outputs.hidden_states,
848
+ attentions=outputs.attentions,
849
+ )
17_99p_300t/step90000/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
17_99p_300t/step90000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
17_99p_300t/step90000/tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": true,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }
data-indices/rank0.tsv.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cd74eac961030462ab2cab1b11d312e76a58c0901b2949ffdb4ba0eb6f5014b
3
+ size 453
data-indices/rank1.tsv.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba9480b440fcf9d31df4360bf05861a033a66a5c112beca7fdcde5c8cacac35c
3
+ size 451
data-indices/rank10.tsv.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5f4a0ad100dea970b8c7f451b0ddcf0a86a0f73b21d81f87db1c08a002c836a
3
+ size 452
data-indices/rank100.tsv.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:900a0fd7899a013c6cbb886f03d241ced0d7fc362c8439d27c7b930b7a564a4e
3
+ size 457