winglian committed on
Commit
cc3cebf
·
unverified ·
1 Parent(s): 5894f0e

Pydantic 2.x cfg (#1239)

Browse files

* WIP conversion to use pydantic for config validation

* wip, more fields, add capabilities

* wip

* update pydantic validation to match existing tests

* tweak requirements

* setup deprecated params pydantic model

* more validations

* wrap up rest of the validations

* flesh out the rest of the options from the readme into pydantic

* fix model validators as class methods

remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning

* more missing attributes for cfg

* updates from PR feedback

* fix validation for datasets and pretrain datasets

* fix test for lora check

.mypy.ini CHANGED
@@ -1,5 +1,5 @@
1
  [mypy]
2
-
3
  exclude = venv
4
 
5
  [mypy-alpaca_lora_4bit.*]
 
1
  [mypy]
2
+ plugins = pydantic.mypy
3
  exclude = venv
4
 
5
  [mypy-alpaca_lora_4bit.*]
.pre-commit-config.yaml CHANGED
@@ -31,6 +31,7 @@ repos:
31
  additional_dependencies:
32
  [
33
  'types-PyYAML',
 
34
  ]
35
  - repo: https://github.com/PyCQA/bandit
36
  rev: 1.7.5
 
31
  additional_dependencies:
32
  [
33
  'types-PyYAML',
34
+ 'pydantic>=2.5.3',
35
  ]
36
  - repo: https://github.com/PyCQA/bandit
37
  rev: 1.7.5
README.md CHANGED
@@ -543,7 +543,7 @@ is_mistral_derived_model:
543
  is_qwen_derived_model:
544
 
545
  # optional overrides to the base model configuration
546
- model_config:
547
  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
548
  rope_scaling:
549
  type: # linear | dynamic
@@ -560,8 +560,6 @@ bnb_config_kwargs:
560
 
561
  # Whether you are training a 4-bit GPTQ quantized model
562
  gptq: true
563
- gptq_groupsize: 128 # group size
564
- gptq_model_v1: false # v1 or v2
565
 
566
  # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
567
  load_in_8bit: true
@@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
819
  # For one_cycle optim
820
  lr_div_factor: # Learning rate div factor
821
 
822
- # For log_sweep optim
823
- log_sweep_min_lr:
824
- log_sweep_max_lr:
825
-
826
  # Specify optimizer
827
  # Valid values are driven by the Transformers OptimizerNames class, see:
828
  # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
 
543
  is_qwen_derived_model:
544
 
545
  # optional overrides to the base model configuration
546
+ model_config_overrides:
547
  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
548
  rope_scaling:
549
  type: # linear | dynamic
 
560
 
561
  # Whether you are training a 4-bit GPTQ quantized model
562
  gptq: true
 
 
563
 
564
  # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
565
  load_in_8bit: true
 
817
  # For one_cycle optim
818
  lr_div_factor: # Learning rate div factor
819
 
 
 
 
 
820
  # Specify optimizer
821
  # Valid values are driven by the Transformers OptimizerNames class, see:
822
  # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
requirements.txt CHANGED
@@ -6,6 +6,7 @@ tokenizers==0.15.0
6
  bitsandbytes>=0.41.1
7
  accelerate==0.26.1
8
  deepspeed>=0.13.1
 
9
  addict
10
  fire
11
  PyYAML>=6.0
@@ -27,7 +28,7 @@ scipy
27
  scikit-learn==1.2.2
28
  pynvml
29
  art
30
- fschat==0.2.34
31
  gradio==3.50.2
32
  tensorboard
33
 
 
6
  bitsandbytes>=0.41.1
7
  accelerate==0.26.1
8
  deepspeed>=0.13.1
9
+ pydantic>=2.5.3
10
  addict
11
  fire
12
  PyYAML>=6.0
 
28
  scikit-learn==1.2.2
29
  pynvml
30
  art
31
+ fschat==0.2.36
32
  gradio==3.50.2
33
  tensorboard
34
 
src/axolotl/cli/__init__.py CHANGED
@@ -24,11 +24,13 @@ from art import text2art
24
  from huggingface_hub import HfApi
25
  from huggingface_hub.utils import LocalTokenNotFoundError
26
  from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
 
27
 
28
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
29
  from axolotl.logging_config import configure_logging
30
  from axolotl.train import TrainDatasetMeta
31
  from axolotl.utils.config import (
 
32
  normalize_cfg_datasets,
33
  normalize_config,
34
  validate_config,
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
328
  # load the config from the yaml file
329
  with open(config, encoding="utf-8") as file:
330
  cfg: DictDefault = DictDefault(yaml.safe_load(file))
331
- cfg.axolotl_config_path = config
332
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
333
  # then overwrite the value
334
  cfg_keys = cfg.keys()
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
341
  else:
342
  cfg[k] = kwargs[k]
343
 
344
- validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  prepare_optim_env(cfg)
347
 
 
24
  from huggingface_hub import HfApi
25
  from huggingface_hub.utils import LocalTokenNotFoundError
26
  from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
27
+ from transformers.utils import is_torch_bf16_gpu_available
28
 
29
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
30
  from axolotl.logging_config import configure_logging
31
  from axolotl.train import TrainDatasetMeta
32
  from axolotl.utils.config import (
33
+ GPUCapabilities,
34
  normalize_cfg_datasets,
35
  normalize_config,
36
  validate_config,
 
330
  # load the config from the yaml file
331
  with open(config, encoding="utf-8") as file:
332
  cfg: DictDefault = DictDefault(yaml.safe_load(file))
 
333
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
334
  # then overwrite the value
335
  cfg_keys = cfg.keys()
 
342
  else:
343
  cfg[k] = kwargs[k]
344
 
345
+ cfg.axolotl_config_path = config
346
+
347
+ try:
348
+ device_props = torch.cuda.get_device_properties("cuda")
349
+ gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
350
+ except: # pylint: disable=bare-except # noqa: E722
351
+ gpu_version = None
352
+
353
+ capabilities = GPUCapabilities(
354
+ bf16=is_torch_bf16_gpu_available(),
355
+ n_gpu=os.environ.get("WORLD_SIZE", 1),
356
+ compute_capability=gpu_version,
357
+ )
358
+
359
+ cfg = validate_config(cfg, capabilities=capabilities)
360
 
361
  prepare_optim_env(cfg)
362
 
src/axolotl/utils/{config.py → config/__init__.py} RENAMED
@@ -3,11 +3,17 @@ import json
3
  import logging
4
  import os
5
  from pathlib import Path
 
6
 
7
  import torch
8
  from transformers.utils import is_torch_bf16_gpu_available
9
 
10
  from axolotl.utils.bench import log_gpu_memory_usage
 
 
 
 
 
11
  from axolotl.utils.dict import DictDefault
12
  from axolotl.utils.models import load_model_config
13
 
@@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg):
191
  cfg.datasets[idx].conversation = "chatml"
192
 
193
 
194
- def validate_config(cfg):
 
 
 
 
 
 
 
 
195
  """
196
  This is a "pre-validation" step that handles the yaml configuration before we have any
197
  information about the model architecture
@@ -480,9 +494,6 @@ def validate_config(cfg):
480
  if cfg.rope_scaling:
481
  LOG.warning("`rope_scaling` should now be be a key under `model_config`")
482
 
483
- if cfg.warmup_steps and cfg.warmup_ratio:
484
- raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
485
-
486
  if cfg.wandb_run_id and not cfg.wandb_name:
487
  cfg.wandb_name = cfg.wandb_run_id
488
 
 
3
  import logging
4
  import os
5
  from pathlib import Path
6
+ from typing import Optional
7
 
8
  import torch
9
  from transformers.utils import is_torch_bf16_gpu_available
10
 
11
  from axolotl.utils.bench import log_gpu_memory_usage
12
+ from axolotl.utils.config.models.input.v0_4_1 import (
13
+ AxolotlConfigWCapabilities,
14
+ AxolotlInputConfig,
15
+ )
16
+ from axolotl.utils.config.models.internals import GPUCapabilities
17
  from axolotl.utils.dict import DictDefault
18
  from axolotl.utils.models import load_model_config
19
 
 
197
  cfg.datasets[idx].conversation = "chatml"
198
 
199
 
200
+ def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
201
+ if capabilities:
202
+ return DictDefault(
203
+ dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
204
+ )
205
+ return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict())))
206
+
207
+
208
+ def legacy_validate_config(cfg):
209
  """
210
  This is a "pre-validation" step that handles the yaml configuration before we have any
211
  information about the model architecture
 
494
  if cfg.rope_scaling:
495
  LOG.warning("`rope_scaling` should now be be a key under `model_config`")
496
 
 
 
 
497
  if cfg.wandb_run_id and not cfg.wandb_name:
498
  cfg.wandb_name = cfg.wandb_run_id
499
 
src/axolotl/utils/config/models/__init__.py ADDED
File without changes
src/axolotl/utils/config/models/input/__init__.py ADDED
File without changes
src/axolotl/utils/config/models/input/next/__init__.py ADDED
File without changes
src/axolotl/utils/config/models/input/v0_4_1/__init__.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for pydantic models for configuration
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ from enum import Enum
8
+ from typing import Any, Dict, List, Literal, Optional, Union
9
+
10
+ from pydantic import BaseModel, Field, conlist, field_validator, model_validator
11
+ from transformers import SchedulerType
12
+ from transformers.training_args import OptimizerNames
13
+
14
+ from axolotl.utils.config.models.internals import GPUCapabilities
15
+
16
+ LOG = logging.getLogger("axolotl.utils.config.models.input")
17
+
18
+
19
+ class DeprecatedParameters(BaseModel):
20
+ """configurations that are deprecated"""
21
+
22
+ max_packed_sequence_len: Optional[int] = None
23
+ rope_scaling: Optional[Any] = None
24
+ noisy_embedding_alpha: Optional[float] = None
25
+
26
+ @field_validator("max_packed_sequence_len")
27
+ @classmethod
28
+ def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
29
+ if max_packed_sequence_len:
30
+ raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
31
+ return max_packed_sequence_len
32
+
33
+ @field_validator("rope_scaling")
34
+ @classmethod
35
+ def validate_rope_scaling(cls, rope_scaling):
36
+ if rope_scaling:
37
+ raise DeprecationWarning(
38
+ "`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
39
+ )
40
+ return rope_scaling
41
+
42
+ @field_validator("noisy_embedding_alpha")
43
+ @classmethod
44
+ def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
45
+ if noisy_embedding_alpha:
46
+ LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
47
+ return noisy_embedding_alpha
48
+
49
+
50
+ class PretrainingDataset(BaseModel):
51
+ """pretraining dataset configuration subset"""
52
+
53
+ path: Optional[str] = None
54
+
55
+
56
+ class UserDefinedPrompterType(BaseModel):
57
+ """structure for user defined prompt types"""
58
+
59
+ system_prompt: Optional[str] = None
60
+ system_format: Optional[str] = None
61
+ field_system: Optional[str] = None
62
+ field_instruction: Optional[str] = None
63
+ field_input: Optional[str] = None
64
+ field_output: Optional[str] = None
65
+
66
+ format: Optional[str] = None
67
+ no_input_format: Optional[str] = None
68
+ field: Optional[str] = None
69
+
70
+
71
+ class SFTDataset(BaseModel):
72
+ """SFT configuration subset"""
73
+
74
+ path: Optional[str] = None
75
+ split: Optional[str] = None
76
+ type: Optional[Union[str, UserDefinedPrompterType]] = None
77
+ shards: Optional[int] = None
78
+ conversation: Optional[str] = None
79
+ data_files: Optional[List[str]] = None
80
+ name: Optional[str] = None
81
+ ds_type: Optional[str] = None
82
+ train_on_split: Optional[str] = None
83
+
84
+ field_human: Optional[str] = None
85
+ field_model: Optional[str] = None
86
+
87
+
88
+ class DPODataset(BaseModel):
89
+ """DPO configuration subset"""
90
+
91
+ path: Optional[str] = None
92
+ split: Optional[str] = None
93
+ type: Optional[str] = None
94
+ data_files: Optional[List[str]] = None
95
+
96
+
97
+ class RLType(str, Enum):
98
+ """RL trainer type configuration subset"""
99
+
100
+ dpo = "dpo" # pylint: disable=invalid-name
101
+ ipo = "ipo" # pylint: disable=invalid-name
102
+ kto_pair = "kto_pair" # pylint: disable=invalid-name
103
+
104
+
105
+ class ChatTemplate(str, Enum):
106
+ """Chat templates configuration subset"""
107
+
108
+ chatml = "chatml" # pylint: disable=invalid-name
109
+ inst = "inst" # pylint: disable=invalid-name
110
+
111
+
112
+ class LoftQConfig(BaseModel):
113
+ """LoftQ configuration subset"""
114
+
115
+ loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
116
+ # loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
117
+
118
+
119
+ class PeftConfig(BaseModel):
120
+ """peftq configuration subset"""
121
+
122
+ loftq_config: Optional[LoftQConfig] = None
123
+
124
+
125
+ class AutoType(str, Enum):
126
+ """auto type string configuration subset - used for bf16"""
127
+
128
+ AUTO = "auto"
129
+
130
+
131
+ class SpecialTokensConfig(BaseModel):
132
+ """Special tokens configuration subset"""
133
+
134
+ bos_token: Optional[str] = None
135
+ eos_token: Optional[str] = None
136
+ pad_token: Optional[str] = None
137
+ unk_token: Optional[str] = None
138
+ additional_special_tokens: Optional[List[str]] = None
139
+
140
+
141
+ class LoraConfig(BaseModel):
142
+ """Peft / LoRA configuration subset"""
143
+
144
+ load_in_8bit: Optional[bool] = Field(default=False)
145
+ load_in_4bit: Optional[bool] = Field(default=False)
146
+
147
+ adapter: Optional[str] = None
148
+ lora_model_dir: Optional[str] = None
149
+ lora_rank: Optional[int] = None
150
+ lora_alpha: Optional[int] = None
151
+ lora_fan_in_fan_out: Optional[bool] = None
152
+ lora_target_modules: Optional[List[str]] = None
153
+ lora_target_linear: Optional[bool] = None
154
+ lora_modules_to_save: Optional[List[str]] = None
155
+ lora_dropout: Optional[float] = None
156
+ peft_layers_to_transform: Optional[List[int]] = None
157
+ peft: Optional[PeftConfig] = None
158
+
159
+ lora_on_cpu: Optional[bool] = None
160
+ gptq: Optional[bool] = None
161
+ bnb_config_kwargs: Optional[Dict[str, Any]] = None
162
+
163
+ merge_lora: Optional[bool] = None
164
+
165
+ @model_validator(mode="before")
166
+ @classmethod
167
+ def validate_adapter(cls, data):
168
+ if not data.get("adapter") and (
169
+ data.get("load_in_8bit") or data.get("load_in_4bit")
170
+ ):
171
+ raise ValueError(
172
+ "load_in_8bit and load_in_4bit are not supported without setting an adapter."
173
+ "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
174
+ )
175
+ return data
176
+
177
+ @model_validator(mode="after")
178
+ def validate_qlora(self):
179
+ if self.adapter == "qlora":
180
+ if self.merge_lora:
181
+ # can't merge qlora if loaded in 8bit or 4bit
182
+ if self.load_in_8bit:
183
+ raise ValueError("Can't merge qlora if loaded in 8bit")
184
+
185
+ if self.gptq:
186
+ raise ValueError("Can't merge qlora if gptq")
187
+
188
+ if self.load_in_4bit:
189
+ raise ValueError("Can't merge qlora if loaded in 4bit")
190
+
191
+ else:
192
+ if self.load_in_8bit:
193
+ raise ValueError("Can't load qlora in 8bit")
194
+
195
+ if self.gptq:
196
+ raise ValueError("Can't load qlora if gptq")
197
+
198
+ if not self.load_in_4bit:
199
+ raise ValueError("Require cfg.load_in_4bit to be True for qlora")
200
+ return self
201
+
202
+
203
+ class ReLoRAConfig(BaseModel):
204
+ """ReLoRA configuration subset"""
205
+
206
+ relora_steps: Optional[int] = None
207
+ relora_warmup_steps: Optional[int] = None
208
+ relora_anneal_steps: Optional[int] = None
209
+ relora_prune_ratio: Optional[float] = None
210
+ relora_cpu_offload: Optional[bool] = None
211
+
212
+
213
+ class ModelInputConfig(BaseModel):
214
+ """model to train on configuration subset"""
215
+
216
+ base_model: str
217
+ base_model_config: Optional[str] = None
218
+ tokenizer_config: Optional[str] = None
219
+ tokenizer_use_fast: Optional[bool] = None
220
+ tokenizer_legacy: Optional[bool] = None
221
+ tokenizer_type: Optional[str] = Field(
222
+ default=None, metadata={"help": "transformers tokenizer class"}
223
+ )
224
+ model_type: Optional[str] = Field(default=None)
225
+ model_revision: Optional[str] = None
226
+ trust_remote_code: Optional[bool] = None
227
+
228
+ model_config_overrides: Optional[Dict[str, Any]] = None
229
+
230
+ @field_validator("trust_remote_code")
231
+ @classmethod
232
+ def hint_trust_remote_code(cls, trust_remote_code):
233
+ if trust_remote_code:
234
+ LOG.warning(
235
+ "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
236
+ )
237
+ return trust_remote_code
238
+
239
+
240
+ class HyperparametersConfig(BaseModel):
241
+ """training hyperparams configuration subset"""
242
+
243
+ gradient_accumulation_steps: Optional[int] = Field(default=1)
244
+ micro_batch_size: Optional[int] = Field(
245
+ default=1,
246
+ metadata={"help": "per gpu micro batch size for training"},
247
+ )
248
+ batch_size: Optional[int] = Field(
249
+ default=None,
250
+ metadata={
251
+ "help": "Total batch size, we do not recommended setting this manually"
252
+ },
253
+ )
254
+ eval_batch_size: Optional[int] = Field(
255
+ default=None,
256
+ metadata={
257
+ "help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
258
+ },
259
+ )
260
+
261
+ train_on_inputs: Optional[bool] = None
262
+ group_by_length: Optional[bool] = None
263
+
264
+ learning_rate: Union[str, float]
265
+ weight_decay: Optional[float] = None
266
+ optimizer: Optional[OptimizerNames] = None
267
+ torchdistx_path: Optional[str] = None
268
+ lr_scheduler: Optional[SchedulerType] = None
269
+ lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
270
+ lr_quadratic_warmup: Optional[bool] = None
271
+ cosine_min_lr_ratio: Optional[float] = None
272
+ cosine_constant_lr_ratio: Optional[float] = None
273
+ lr_div_factor: Optional[float] = None
274
+
275
+ adam_epsilon: Optional[float] = None
276
+ adam_beta1: Optional[float] = None
277
+ adam_beta2: Optional[float] = None
278
+ max_grad_norm: Optional[float] = None
279
+ num_epochs: int = Field(default=1)
280
+
281
+ @field_validator("batch_size")
282
+ @classmethod
283
+ def hint_batch_size_set(cls, batch_size):
284
+ if batch_size:
285
+ LOG.warning(
286
+ "%s\n%s",
287
+ "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
288
+ "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
289
+ )
290
+ return batch_size
291
+
292
+
293
+ class ModelOutputConfig(BaseModel):
294
+ """model save configuration subset"""
295
+
296
+ output_dir: str = Field(default="./model-out")
297
+ hub_model_id: Optional[str] = None
298
+ hub_strategy: Optional[str] = None
299
+ save_safetensors: Optional[bool] = None
300
+
301
+
302
+ class MLFlowConfig(BaseModel):
303
+ """mlflow configuration subset"""
304
+
305
+ use_mlflow: Optional[str] = None
306
+ mlflow_tracking_uri: Optional[str] = None
307
+ mlflow_experiment_name: Optional[str] = None
308
+
309
+
310
+ class WandbConfig(BaseModel):
311
+ """wandb configuration subset"""
312
+
313
+ use_wandb: Optional[bool] = None
314
+ wandb_name: Optional[str] = None
315
+ wandb_run_id: Optional[str] = None
316
+ wandb_mode: Optional[str] = None
317
+ wandb_project: Optional[str] = None
318
+ wandb_entity: Optional[str] = None
319
+ wandb_watch: Optional[str] = None
320
+ wandb_log_model: Optional[str] = None
321
+
322
+ @model_validator(mode="before")
323
+ @classmethod
324
+ def check_wandb_run(cls, data):
325
+ if data.get("wandb_run_id") and not data.get("wandb_name"):
326
+ data["wandb_name"] = data.get("wandb_run_id")
327
+
328
+ LOG.warning(
329
+ "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
330
+ )
331
+
332
+ return data
333
+
334
+
335
+ # pylint: disable=too-many-public-methods,too-many-ancestors
336
+ class AxolotlInputConfig(
337
+ ModelInputConfig,
338
+ LoraConfig,
339
+ ReLoRAConfig,
340
+ HyperparametersConfig,
341
+ WandbConfig,
342
+ MLFlowConfig,
343
+ DeprecatedParameters,
344
+ BaseModel,
345
+ ):
346
+ """wrapper of all config options"""
347
+
348
+ strict: Optional[bool] = Field(default=False)
349
+ resume_from_checkpoint: Optional[str] = None
350
+ auto_resume_from_checkpoints: Optional[bool] = None
351
+ resize_token_embeddings_to_32x: Optional[bool] = None
352
+
353
+ rl: Optional[RLType] = None
354
+
355
+ datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
356
+ dataset_prepared_path: Optional[str] = None
357
+ dataset_shard_num: Optional[int] = None
358
+ dataset_shard_idx: Optional[int] = None
359
+
360
+ pretraining_dataset: Optional[ # type: ignore
361
+ conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
362
+ ] = Field(
363
+ default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
364
+ )
365
+ dataset_processes: Optional[int] = Field(default=os.cpu_count())
366
+ dataset_keep_in_memory: Optional[bool] = None
367
+ dataloader_pin_memory: Optional[bool] = None
368
+ dataloader_num_workers: Optional[int] = None
369
+ dataloader_prefetch_factor: Optional[int] = None
370
+ dataloader_drop_last: Optional[bool] = None
371
+
372
+ push_dataset_to_hub: Optional[str] = None
373
+ hf_use_auth_token: Optional[bool] = None
374
+
375
+ device: Optional[Any] = None
376
+ device_map: Optional[Any] = None
377
+ world_size: Optional[int] = None
378
+ local_rank: Optional[int] = None
379
+ ddp: Optional[bool] = None
380
+
381
+ seed: Optional[int] = None
382
+ ddp_timeout: Optional[int] = None
383
+ ddp_bucket_cap_mb: Optional[int] = None
384
+ ddp_broadcast_buffers: Optional[bool] = None
385
+ ddp_find_unused_parameters: Optional[bool] = None
386
+
387
+ eval_table_size: Optional[int] = None
388
+ eval_max_new_tokens: Optional[int] = None
389
+ do_causal_lm_eval: Optional[bool] = None
390
+ eval_causal_lm_metrics: Optional[List[str]] = None
391
+ do_bench_eval: Optional[bool] = None
392
+ bench_dataset: Optional[str] = None
393
+ metric_for_best_model: Optional[str] = None
394
+ greater_is_better: Optional[bool] = None
395
+
396
+ loss_watchdog_threshold: Optional[float] = None
397
+ loss_watchdog_patience: Optional[int] = None
398
+
399
+ bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
400
+ fp16: Optional[bool] = None
401
+ bfloat16: Optional[bool] = None # for non-AMP cases
402
+ float16: Optional[bool] = None # for non-AMP cases
403
+ tf32: Optional[bool] = None
404
+ float32: Optional[bool] = None
405
+
406
+ # torch_dtype: Optional[torch.dtype]
407
+
408
+ gradient_checkpointing: Optional[bool] = Field(default=False)
409
+ gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
410
+
411
+ unfrozen_parameters: Optional[List[str]] = None
412
+
413
+ sequence_len: int = Field(default=1024)
414
+ sample_packing: Optional[bool] = None
415
+ eval_sample_packing: Optional[bool] = None
416
+ pad_to_sequence_len: Optional[bool] = None
417
+
418
+ xformers_attention: Optional[bool] = None
419
+ sdp_attention: Optional[bool] = None
420
+ s2_attention: Optional[bool] = None
421
+ flash_attention: Optional[bool] = None
422
+ flash_attn_cross_entropy: Optional[bool] = None
423
+ flash_attn_rms_norm: Optional[bool] = None
424
+ flash_attn_fuse_qkv: Optional[bool] = None
425
+ flash_attn_fuse_mlp: Optional[bool] = None
426
+ flash_optimum: Optional[bool] = None
427
+
428
+ deepspeed: Optional[Union[str, Dict[str, Any]]] = None
429
+ fsdp: Optional[List[str]] = None
430
+ fsdp_config: Optional[Dict[str, Any]] = None
431
+
432
+ val_set_size: Optional[float] = Field(default=0.0)
433
+
434
+ special_tokens: Optional[SpecialTokensConfig] = None
435
+ tokens: Optional[List[str]] = None
436
+
437
+ torch_compile: Optional[bool] = None
438
+ torch_compile_backend: Optional[str] = None
439
+
440
+ max_steps: Optional[int] = None
441
+ warmup_steps: Optional[int] = None
442
+ warmup_ratio: Optional[float] = None
443
+ eval_steps: Optional[int] = None
444
+ evaluation_strategy: Optional[str] = None
445
+ save_steps: Optional[int] = None
446
+ saves_per_epoch: Optional[int] = None
447
+ save_strategy: Optional[str] = None
448
+ save_total_limit: Optional[int] = None
449
+ logging_steps: Optional[int] = None
450
+ early_stopping_patience: Optional[int] = None
451
+
452
+ neftune_noise_alpha: Optional[float] = None
453
+
454
+ max_memory: Optional[Union[int, str]] = None
455
+ gpu_memory_limit: Optional[Union[int, str]] = None
456
+
457
+ chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
458
+ default_system_message: Optional[str] = None
459
+
460
+ # INTERNALS - document for now, generally not set externally
461
+ is_preprocess: Optional[bool] = None
462
+
463
+ total_num_tokens: Optional[int] = None
464
+ total_supervised_tokens: Optional[int] = None
465
+ sample_packing_eff_est: Optional[float] = None
466
+ axolotl_config_path: Optional[str] = None
467
+
468
+ is_falcon_derived_model: Optional[bool] = Field(default=False)
469
+ is_llama_derived_model: Optional[bool] = Field(default=False)
470
+ is_mistral_derived_model: Optional[bool] = Field(default=False)
471
+ is_qwen_derived_model: Optional[bool] = Field(default=False)
472
+
473
+ @field_validator("datasets", mode="before")
474
+ @classmethod
475
+ def fix_sharegpt_datasets(cls, datasets):
476
+ for idx, ds_cfg in enumerate(datasets):
477
+ if not ds_cfg["type"]:
478
+ continue
479
+ if ds_cfg["type"] == "sharegpt:chat":
480
+ LOG.warning(
481
+ PendingDeprecationWarning(
482
+ "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
483
+ )
484
+ )
485
+ datasets[idx]["type"] = "sharegpt"
486
+ if "sharegpt_simple" in ds_cfg["type"]:
487
+ LOG.warning(
488
+ PendingDeprecationWarning(
489
+ "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
490
+ )
491
+ )
492
+ datasets[idx]["type"] = datasets[idx]["type"].replace(
493
+ "sharegpt_simple", "sharegpt"
494
+ )
495
+ return datasets
496
+
497
+ @model_validator(mode="before")
498
+ @classmethod
499
+ def check_batch_size_fields(cls, data):
500
+ fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
501
+ non_empty_count = sum(1 for field in fields if data.get(field))
502
+
503
+ if non_empty_count < 2:
504
+ raise ValueError(f"At least two of {', '.join(fields)} must be set")
505
+ return data
506
+
507
+ @model_validator(mode="before")
508
+ @classmethod
509
+ def check_pretraining_w_max_steps(cls, data):
510
+ if data.get("pretraining_dataset") and not data.get("max_steps"):
511
+ raise ValueError(
512
+ "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
513
+ )
514
+ return data
515
+
516
+ @model_validator(mode="before")
517
+ @classmethod
518
+ def check_pretraining_w_group_by_length(cls, data):
519
+ if data.get("pretraining_dataset") and data.get("group_by_length"):
520
+ LOG.warning(
521
+ "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
522
+ )
523
+ return data
524
+
525
+ @model_validator(mode="before")
526
+ @classmethod
527
+ def check_gptq_w_revision(cls, data):
528
+ if data.get("gptq") and data.get("model_revision"):
529
+ raise ValueError(
530
+ "model_revision is not supported for GPTQ models. "
531
+ + "Please download the model from HuggingFace Hub manually for correct branch, "
532
+ + "point to its path, and remove model_revision from the config."
533
+ )
534
+ return data
535
+
536
+ @model_validator(mode="before")
537
+ @classmethod
538
+ def check_sample_packing_w_xformers(cls, data):
539
+ if data.get("sample_packing") and data.get("xformers_attention"):
540
+ raise ValueError(
541
+ "sample_packing not compatible with xformers_attention. Use flash_attention"
542
+ )
543
+
544
+ return data
545
+
546
+ @model_validator(mode="before")
547
+ @classmethod
548
+ def check_sample_packing_w_rl(cls, data):
549
+ if data.get("sample_packing") and data.get("rl"):
550
+ raise ValueError("`sample_packing: true` does not work with RLHF training")
551
+ return data
552
+
553
+ @model_validator(mode="before")
554
+ @classmethod
555
+ def hint_sample_packing_padding(cls, data):
556
+ if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
557
+ LOG.warning(
558
+ "`pad_to_sequence_len: true` is recommended when using sample_packing"
559
+ )
560
+ return data
561
+
562
+ @model_validator(mode="before")
563
+ @classmethod
564
+ def check_gas_bsz(cls, data):
565
+ if data.get("gradient_accumulation_steps") and data.get("batch_size"):
566
+ raise ValueError(
567
+ "please set only one of gradient_accumulation_steps or batch_size"
568
+ )
569
+ return data
570
+
571
+ @model_validator(mode="before")
572
+ @classmethod
573
+ def hint_eval_train_mbsz(cls, data):
574
+ if (
575
+ data.get("eval_batch_size")
576
+ and data.get("micro_batch_size")
577
+ and data.get("eval_batch_size") != data.get("micro_batch_size")
578
+ ):
579
+ LOG.warning(
580
+ "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
581
+ )
582
+ return data
583
+
584
+ @model_validator(mode="before")
585
+ @classmethod
586
+ def check_push_ds_auth(cls, data):
587
+ if (
588
+ data.get("push_dataset_to_hub")
589
+ and data.get("hf_use_auth_token") is not True
590
+ ):
591
+ raise ValueError(
592
+ "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
593
+ )
594
+ return data
595
+
596
@model_validator(mode="after")
def check_falcon_fsdp(self):
    """Reject FSDP when the base model name contains "falcon" (case-insensitive).

    Raises:
        ValueError: if ``fsdp`` is truthy for a falcon base model.
    """
    model_name = self.base_model
    is_falcon = bool(model_name) and "falcon" in model_name.lower()
    if is_falcon and self.fsdp:
        raise ValueError("FSDP is not supported for falcon models")
    return self
601
+
602
@model_validator(mode="after")
def check_mpt_checkpointing(self):
    """Reject gradient checkpointing when the base model name contains "mpt".

    The name match is case-insensitive.

    Raises:
        ValueError: if ``gradient_checkpointing`` is truthy for an MPT model.
    """
    model_name = self.base_model
    is_mpt = bool(model_name) and "mpt" in model_name.lower()
    if is_mpt and self.gradient_checkpointing:
        raise ValueError("gradient_checkpointing is not supported for MPT models")
    return self
609
+
610
@model_validator(mode="after")
def check_better_transformers(self):
    """Validate options that interact with ``flash_optimum`` (BetterTransformer).

    Only applies when ``flash_optimum`` is exactly ``True``. Warns about PEFT
    adapters and about a missing float16/bfloat16 load dtype.

    Raises:
        ValueError: if ``fp16`` or ``bf16`` (AMP) is enabled alongside it.
    """
    if self.flash_optimum is not True:
        return self

    if self.adapter:
        LOG.warning(
            "BetterTransformers probably doesn't work with PEFT adapters"
        )
    if self.fp16 or self.bf16:
        raise ValueError("AMP is not supported with BetterTransformer")

    half_precision_load = self.float16 is True or self.bfloat16 is True
    if not half_precision_load:
        LOG.warning(
            "You should probably set bfloat16 or float16 to true to "
            "load the model in float16 for BetterTransformers"
        )
    return self
625
+
626
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
    """Warn when adam hyperparameters are given without an adamw optimizer.

    The optimizer counts as adamw when ``"adamw"`` occurs in its ``.value``.
    """
    has_adam_hparams = bool(
        self.adam_beta1 or self.adam_beta2 or self.adam_epsilon
    )
    if not has_adam_hparams:
        return self

    uses_adamw = bool(self.optimizer) and "adamw" in self.optimizer.value
    if not uses_adamw:
        LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
    return self
633
+
634
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
    """Validate that the checkpoint-saving options do not contradict each other.

    Raises:
        ValueError: if ``save_steps`` is combined with a ``save_strategy``
            other than ``"steps"``, or with ``saves_per_epoch``.
    """
    strategy = data.get("save_strategy")
    steps = data.get("save_steps")

    if strategy and steps and strategy != "steps":
        raise ValueError(
            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
        )
    if data.get("saves_per_epoch") and steps:
        raise ValueError(
            "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
        )
    return data
650
+
651
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
    """Warn when ``hub_model_id`` is set but no save cadence is configured.

    A save cadence is either ``save_steps`` or ``saves_per_epoch``.
    """
    if not data.get("hub_model_id"):
        return data

    save_cadence = data.get("save_steps") or data.get("saves_per_epoch")
    if not save_cadence:
        LOG.warning(
            "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
        )
    return data
661
+
662
@model_validator(mode="before")
@classmethod
def check_evals(cls, data):
    """Validate that the evaluation scheduling options are mutually consistent.

    Raises:
        ValueError: if ``eval_steps`` or ``evals_per_epoch`` is paired with
            an ``evaluation_strategy`` other than ``"steps"``; if evaluation
            is requested with ``val_set_size == 0`` and no ``test_datasets``;
            or if ``eval_steps`` and ``evals_per_epoch`` are both set.
    """
    strategy = data.get("evaluation_strategy")
    steps = data.get("eval_steps")
    per_epoch = data.get("evals_per_epoch")
    # a strategy was given, and it is something other than "steps"
    non_step_strategy = bool(strategy) and strategy != "steps"

    if non_step_strategy and steps:
        raise ValueError(
            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
        )

    eval_requested = steps or strategy
    if (
        data.get("val_set_size") == 0
        and eval_requested
        and not data.get("test_datasets")
    ):
        raise ValueError(
            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
        )

    if per_epoch and steps:
        raise ValueError(
            "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
        )
    if per_epoch and non_step_strategy:
        raise ValueError(
            "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
        )

    return data
696
+
697
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
    """Reject ``eval_table_size`` with sample packing unless eval packing is off.

    Raises:
        ValueError: if ``sample_packing`` and ``eval_table_size`` are both
            truthy and ``eval_sample_packing`` is not explicitly ``False``.
    """
    if not (data.get("sample_packing") and data.get("eval_table_size")):
        return data
    # only an explicit False opts out; None/missing still counts as a conflict
    if data.get("eval_sample_packing") is False:
        return data
    raise ValueError(
        "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
    )
709
+
710
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
    """Ensure ``warmup_steps`` and ``warmup_ratio`` are not both set.

    Raises:
        ValueError: if both keys carry truthy values.
    """
    steps = data.get("warmup_steps")
    if steps and data.get("warmup_ratio"):
        raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
    return data
716
+
717
@model_validator(mode="before")
@classmethod
def check_neftune(cls, data):
    """Migrate the deprecated ``noisy_embedding_alpha`` to ``neftune_noise_alpha``.

    When only the deprecated key is set, its value is moved to
    ``neftune_noise_alpha`` and the deprecated key is removed from the raw
    input mapping (the mapping is mutated in place).

    Raises:
        ValueError: if both ``noisy_embedding_alpha`` and
            ``neftune_noise_alpha`` are set; the user must drop the
            deprecated one.
    """
    deprecated_alpha = data.get("noisy_embedding_alpha")
    if not deprecated_alpha:
        return data

    # Both keys set: refuse rather than silently pick one. The original
    # `elif` repeated the `if` condition, so this branch was unreachable.
    if data.get("neftune_noise_alpha"):
        raise ValueError(
            "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
        )

    data["neftune_noise_alpha"] = deprecated_alpha
    del data["noisy_embedding_alpha"]
    return data
728
+
729
@field_validator("neftune_noise_alpha")
@classmethod
def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
    """Require ``neftune_noise_alpha`` to be strictly positive when provided.

    ``None`` is passed through untouched.

    Raises:
        ValueError: if the value is ``<= 0.0``.
    """
    if neftune_noise_alpha is None:
        return neftune_noise_alpha
    if neftune_noise_alpha <= 0.0:
        raise ValueError("neftune_noise_alpha must be > 0.0")
    return neftune_noise_alpha
735
+
736
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
    """Reject adapter configs mixing layer transforms with unfrozen parameters.

    Raises:
        ValueError: if ``adapter``, ``peft_layers_to_transform`` and
            ``unfrozen_parameters`` are all truthy.
    """
    conflicting_keys = ("adapter", "peft_layers_to_transform", "unfrozen_parameters")
    if all(data.get(key) for key in conflicting_keys):
        raise ValueError(
            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
        )

    return data
749
+
750
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
    """Warn about a full fine-tune setup that is known to be error-prone.

    The warning fires for fp16/float16 without bf16/bfloat16, no adapter,
    no flash attention, and sample packing enabled.
    """
    uses_bf16 = self.bf16 or self.bfloat16
    uses_fp16 = self.fp16 or self.float16
    is_full_finetune = not self.adapter

    if (
        not uses_bf16
        and uses_fp16
        and is_full_finetune
        and not self.flash_attention
        and self.sample_packing
    ):
        LOG.warning(
            "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
        )
        # Typical failures seen with this combination:
        #   ValueError: Attempting to unscale FP16 gradients.
        #   RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
    return self
767
+
768
@model_validator(mode="after")
def check_fused_lora(self):
    """Reject fused flash-attention modules when a LoRA/QLoRA adapter is used.

    Raises:
        ValueError: if ``adapter`` is ``"lora"`` or ``"qlora"`` and either
            ``flash_attn_fuse_qkv`` or ``flash_attn_fuse_mlp`` is truthy.
    """
    if self.adapter not in ("lora", "qlora"):
        return self
    if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
        raise ValueError("Fused modules are not supported with LoRA/QLoRA")
    return self
775
+
776
@model_validator(mode="after")
def hint_lora_8bit(self):
    """Recommend ``load_in_8bit`` for LoRA runs that do not use LoftQ bits.

    Advisory only: a warning is logged and the model is returned unchanged.
    """
    peft_cfg = self.peft
    # truthy only when peft -> loftq_config -> loftq_bits are all set
    loftq_bits = (
        peft_cfg and peft_cfg.loftq_config and peft_cfg.loftq_config.loftq_bits
    )
    if self.adapter == "lora" and not self.load_in_8bit and not loftq_bits:
        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
    return self
784
+
785
@model_validator(mode="after")
def check_early_stopping(self):
    """Validate the save/eval cadence required by ``early_stopping_patience``.

    Early stopping needs both ``save_steps`` and ``eval_steps``, and
    ``eval_steps`` must evenly divide ``save_steps``.

    Raises:
        ValueError: if either step count is missing, or if ``save_steps``
            is not a multiple of ``eval_steps``.
    """
    if not self.early_stopping_patience:
        return self

    # Both must be set. The original tested `or self.eval_steps`, which
    # rejected every config that DID set eval_steps and let a missing
    # eval_steps fall through to the modulo below.
    if not self.save_steps or not self.eval_steps:
        raise ValueError(
            "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
        )
    if self.save_steps % self.eval_steps != 0:
        raise ValueError(
            "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
        )
    return self
797
+
798
@model_validator(mode="after")
def check_relora(self):
    """Validate the options that ReLoRA (``relora_steps``) is incompatible with.

    Raises:
        ValueError: for the first failing rule, checked in this order:
            adapter is not lora/qlora, fsdp set, deepspeed set,
            ``one_cycle`` scheduler, fused flash-attention modules.
    """
    if not self.relora_steps:
        return self

    # (violated?, message) pairs, kept in the order they must be reported
    rules = (
        (
            self.adapter not in ("lora", "qlora"),
            "cfg.adapter must be lora or qlora to use ReLoRA",
        ),
        (self.fsdp, "fsdp not supported with ReLoRA"),
        (self.deepspeed, "deepspeed not supported with ReLoRA"),
        (
            self.lr_scheduler == "one_cycle",
            "ReLoRA is not compatible with the one_cycle scheduler",
        ),
        (
            self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp,
            "Fused modules are not supported with ReLoRA",
        ),
    )
    for violated, message in rules:
        if violated:
            raise ValueError(message)
    return self
818
+
819
@model_validator(mode="before")
@classmethod
def check_mem_mismatch(cls, data):
    """Ensure ``max_memory`` and ``gpu_memory_limit`` are not both provided.

    Raises:
        ValueError: if both keys are present with non-``None`` values.
    """
    max_memory = data.get("max_memory")
    gpu_limit = data.get("gpu_memory_limit")
    if max_memory is None or gpu_limit is None:
        return data
    raise ValueError(
        "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
    )
830
+
831
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
    """Forbid ``use_reentrant: true`` when ``unfrozen_parameters`` is set.

    See https://github.com/huggingface/transformers/issues/21381

    Raises:
        ValueError: if ``unfrozen_parameters`` is truthy and
            ``gradient_checkpointing_kwargs["use_reentrant"]`` is ``True``.
    """
    if not data.get("unfrozen_parameters"):
        return data

    checkpoint_kwargs = data.get("gradient_checkpointing_kwargs")
    if checkpoint_kwargs and checkpoint_kwargs.get("use_reentrant") is True:
        raise ValueError(
            "`use_reentrant` must be false when used with partially frozen model."
        )
    return data
845
+
846
@model_validator(mode="before")
@classmethod
def check_val_w_test_datasets(cls, data):
    """Reject a non-zero ``val_set_size`` when ``test_datasets`` is configured.

    Raises:
        ValueError: if both ``test_datasets`` and ``val_set_size`` are truthy.
    """
    if not data.get("test_datasets"):
        return data
    if data.get("val_set_size"):
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )
    return data
854
+
855
@model_validator(mode="before")
@classmethod
def check_fsdp_w_8bit_optimizer(cls, data):
    """Reject FSDP when the optimizer name contains ``"bnb"``.

    Raises:
        ValueError: if ``fsdp`` is truthy and ``"bnb"`` occurs in the
            configured ``optimizer`` value.
    """
    # `optimizer` may be present but None (e.g. an empty YAML value);
    # normalise to "" so the substring test cannot raise TypeError.
    optimizer = data.get("optimizer") or ""
    if data.get("fsdp") and "bnb" in optimizer:
        raise ValueError(f"FSDP not compatible with {optimizer}")
    return data
861
+
862
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
    """Validate the causal-LM evaluation options.

    Raises:
        ValueError: if ``do_causal_lm_eval`` is combined with
            ``eval_sample_packing``; if ``eval_causal_lm_metrics`` is not a
            list; or if it names a metric outside the supported set.
    """
    if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
        raise ValueError(
            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
        )

    requested_metrics = data.get("eval_causal_lm_metrics")
    if not requested_metrics:
        return data

    supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
    if not isinstance(requested_metrics, list):
        raise ValueError("eval_causal_lm_metrics must be a list")

    # anything requested that is not in the supported list is an error
    unsupported = set(requested_metrics).difference(supported_metrics)
    if unsupported:
        raise ValueError(
            f"eval_causal_lm_metrics must be one of {supported_metrics}"
        )
    return data
880
+
881
@model_validator(mode="before")
@classmethod
def check_dataset_or_pretraining_dataset(cls, data):
    """Require at least one of ``datasets`` or ``pretraining_dataset``.

    Raises:
        ValueError: if both keys are missing or ``None``.
    """
    if data.get("datasets") is not None:
        return data
    if data.get("pretraining_dataset") is not None:
        return data
    raise ValueError("either datasets or pretraining_dataset is required")
887
+
888
+
889
class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """Wrapper to validate GPU capabilities against the configured options."""

    # capabilities that the checks below compare the config against
    capabilities: GPUCapabilities

    @model_validator(mode="after")
    def check_bf16(self):
        """Check the bf16 settings against ``capabilities.bf16``.

        Logs an informational hint when bf16 is supported but unused.

        Raises:
            ValueError: if ``bf16``/``bfloat16`` is requested without bf16
                support, unless ``merge_lora`` or ``is_preprocess`` is set.
        """
        if self.capabilities.bf16:
            if not self.bf16 and not self.bfloat16:
                LOG.info(
                    "bf16 support detected, but not enabled for this configuration."
                )
        else:
            if (
                not self.merge_lora
                and not self.is_preprocess
                and (self.bf16 is True or self.bfloat16 is True)
            ):
                raise ValueError(
                    "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
                )
        return self

    @model_validator(mode="before")
    @classmethod
    def check_sample_packing_w_sdpa_bf16(cls, data):
        """Warn about sample packing + torch SDPA + bf16 off ``sm_90`` hardware.

        Advisory only; the raw input mapping is returned unchanged.
        """
        # Use .get() so a missing `capabilities` key is reported by pydantic
        # as a missing required field instead of a raw KeyError from here.
        capabilities = data.get("capabilities")
        if isinstance(capabilities, dict):
            compute_capability = capabilities.get("compute_capability")
        else:
            # also tolerate an already-built capabilities object (or None)
            compute_capability = getattr(capabilities, "compute_capability", None)
        is_sm_90: bool = compute_capability == "sm_90"

        if (
            data.get("sample_packing")
            and data.get("sdp_attention")
            and (data.get("bfloat16") or data.get("bf16"))
            and not is_sm_90
        ):
            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
            LOG.warning(
                "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
                "This may work on H100s."
            )

        return data
src/axolotl/utils/config/models/internals/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """module for gpu capabilities"""
2
+ from typing import Optional
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+
7
class GPUCapabilities(BaseModel):
    """Static description of the GPU capabilities a config is validated against."""

    # whether bf16 is supported (off by default)
    bf16: bool = Field(False)
    # presumably whether fp8 is supported — confirm against the caller
    fp8: bool = Field(False)
    # presumably the GPU count — confirm against the caller
    n_gpu: int = Field(1)
    # presumably the node count — confirm against the caller
    n_node: int = Field(1)
    # compute capability string such as "sm_90"; None when not provided
    compute_capability: Optional[str] = Field(None)
src/axolotl/utils/dict.py CHANGED
@@ -12,4 +12,4 @@ class DictDefault(Dict):
12
  return None
13
 
14
  def __or__(self, other):
15
- return DictDefault(super().__or__(other))
 
12
  return None
13
 
14
  def __or__(self, other):
15
+ return DictDefault(super().__ror__(other))
src/axolotl/utils/models.py CHANGED
@@ -104,8 +104,8 @@ def load_model_config(cfg):
104
  )
105
  raise err
106
 
107
- if cfg.model_config:
108
- for key, val in cfg.model_config.items():
109
  setattr(model_config, key, val)
110
 
111
  check_model_config(cfg, model_config)
 
104
  )
105
  raise err
106
 
107
+ if cfg.model_config_overrides:
108
+ for key, val in cfg.model_config_overrides.items():
109
  setattr(model_config, key, val)
110
 
111
  check_model_config(cfg, model_config)
tests/test_dict.py CHANGED
@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
39
  ), "DictDefault should support in operator for existing keys in list"
40
 
41
  def test_dict_or_operator(self):
42
- cfg = DictDefault(
 
 
43
  {
44
  "key_a": {"key_b": "value_a"},
45
  "key_c": "value_c",
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
48
  }
49
  )
50
 
51
- cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
52
- {"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
53
- )
54
-
55
  assert (
56
  cfg.key_a.key_b == "value_b"
57
  ), "DictDefault should support OR operator for existing nested keys"
 
39
  ), "DictDefault should support in operator for existing keys in list"
40
 
41
  def test_dict_or_operator(self):
42
+ cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
43
+
44
+ cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
45
  {
46
  "key_a": {"key_b": "value_a"},
47
  "key_c": "value_c",
 
50
  }
51
  )
52
 
 
 
 
 
53
  assert (
54
  cfg.key_a.key_b == "value_b"
55
  ), "DictDefault should support OR operator for existing nested keys"
tests/test_prompt_tokenizers.py CHANGED
@@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
204
  # fmt: off
205
  # System message, multi-turn conversations
206
  mt_ids = tokenize(test_data['multi_turn_sys'])
207
- assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
208
- assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
209
 
210
  # System message, single-turn conversations
211
  st_ids = tokenize(test_data['single_turn_sys'])
212
- assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
213
- assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
214
 
215
  # No system message, single-turn
216
  ns_ids = tokenize(test_data['single_turn_no_sys'])
 
204
  # fmt: off
205
  # System message, multi-turn conversations
206
  mt_ids = tokenize(test_data['multi_turn_sys'])
207
+ assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
208
+ assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
209
 
210
  # System message, single-turn conversations
211
  st_ids = tokenize(test_data['single_turn_sys'])
212
+ assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
213
+ assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
214
 
215
  # No system message, single-turn
216
  ns_ids = tokenize(test_data['single_turn_no_sys'])
tests/test_validation.py CHANGED
@@ -1,20 +1,39 @@
 
1
  """Module for testing the validation module"""
2
 
3
  import logging
4
  import os
5
- import unittest
6
  from typing import Optional
7
 
8
  import pytest
9
- from transformers.utils import is_torch_bf16_gpu_available
10
 
11
  from axolotl.utils.config import validate_config
 
12
  from axolotl.utils.dict import DictDefault
13
  from axolotl.utils.models import check_model_config
14
  from axolotl.utils.wandb_ import setup_wandb_env_vars
15
 
16
 
17
- class BaseValidation(unittest.TestCase):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  """
19
  Base validation module to setup the log capture
20
  """
@@ -27,199 +46,354 @@ class BaseValidation(unittest.TestCase):
27
 
28
 
29
  # pylint: disable=too-many-public-methods
30
- class ValidationTest(BaseValidation):
31
  """
32
  Test the validation module
33
  """
34
 
35
- def test_batch_size_unused_warning(self):
36
  cfg = DictDefault(
37
  {
38
- "batch_size": 32,
 
 
 
 
39
  }
40
  )
41
 
42
- with self._caplog.at_level(logging.WARNING):
 
 
 
43
  validate_config(cfg)
44
- assert "batch_size is not recommended" in self._caplog.records[0].message
45
 
46
- def test_qlora(self):
47
- base_cfg = DictDefault(
48
- {
49
- "adapter": "qlora",
50
- }
51
- )
52
-
53
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
54
  {
55
- "load_in_8bit": True,
 
 
 
56
  }
57
  )
58
 
59
- with pytest.raises(ValueError, match=r".*8bit.*"):
 
 
60
  validate_config(cfg)
61
 
62
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
 
63
  {
64
- "gptq": True,
 
 
 
 
 
65
  }
66
  )
67
 
68
- with pytest.raises(ValueError, match=r".*gptq.*"):
 
 
 
69
  validate_config(cfg)
70
 
71
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
 
72
  {
73
- "load_in_4bit": False,
 
 
 
 
 
 
 
 
 
 
74
  }
75
  )
76
 
77
- with pytest.raises(ValueError, match=r".*4bit.*"):
78
- validate_config(cfg)
79
 
80
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
 
81
  {
82
- "load_in_4bit": True,
 
 
 
 
 
 
 
 
 
83
  }
84
  )
85
 
86
  validate_config(cfg)
87
 
88
- def test_qlora_merge(self):
89
- base_cfg = DictDefault(
90
  {
91
- "adapter": "qlora",
92
- "merge_lora": True,
 
 
 
 
 
 
 
 
93
  }
94
  )
95
 
96
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
 
 
 
 
 
97
  {
98
- "load_in_8bit": True,
 
 
 
 
 
 
 
 
99
  }
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  with pytest.raises(ValueError, match=r".*8bit.*"):
103
  validate_config(cfg)
104
 
105
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
106
- {
107
- "gptq": True,
108
- }
 
 
 
109
  )
110
 
111
  with pytest.raises(ValueError, match=r".*gptq.*"):
112
  validate_config(cfg)
113
 
114
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
115
- {
116
- "load_in_4bit": True,
117
- }
 
 
 
118
  )
119
 
120
  with pytest.raises(ValueError, match=r".*4bit.*"):
121
  validate_config(cfg)
122
 
123
- def test_hf_use_auth_token(self):
124
- cfg = DictDefault(
125
- {
126
- "push_dataset_to_hub": "namespace/repo",
127
- }
 
 
128
  )
129
 
130
- with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  validate_config(cfg)
132
 
133
- cfg = DictDefault(
134
- {
135
- "push_dataset_to_hub": "namespace/repo",
136
- "hf_use_auth_token": True,
137
- }
 
 
138
  )
139
- validate_config(cfg)
140
 
141
- def test_gradient_accumulations_or_batch_size(self):
142
- cfg = DictDefault(
143
- {
144
- "gradient_accumulation_steps": 1,
145
- "batch_size": 1,
146
- }
 
 
 
 
147
  )
148
 
149
- with pytest.raises(
150
- ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
151
- ):
152
  validate_config(cfg)
153
 
154
- cfg = DictDefault(
155
- {
156
- "batch_size": 1,
157
- }
 
 
 
 
158
  )
159
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  validate_config(cfg)
161
 
 
162
  cfg = DictDefault(
163
  {
 
 
 
 
 
 
 
 
164
  "gradient_accumulation_steps": 1,
 
165
  }
166
  )
167
 
168
- validate_config(cfg)
 
 
 
169
 
170
- def test_falcon_fsdp(self):
171
  regex_exp = r".*FSDP is not supported for falcon models.*"
172
 
173
  # Check for lower-case
174
- cfg = DictDefault(
175
- {
176
- "base_model": "tiiuae/falcon-7b",
177
- "fsdp": ["full_shard", "auto_wrap"],
178
- }
 
 
 
179
  )
180
 
181
  with pytest.raises(ValueError, match=regex_exp):
182
  validate_config(cfg)
183
 
184
  # Check for upper-case
185
- cfg = DictDefault(
186
- {
187
- "base_model": "Falcon-7b",
188
- "fsdp": ["full_shard", "auto_wrap"],
189
- }
 
 
 
190
  )
191
 
192
  with pytest.raises(ValueError, match=regex_exp):
193
  validate_config(cfg)
194
 
195
- cfg = DictDefault(
196
- {
197
- "base_model": "tiiuae/falcon-7b",
198
- }
 
 
 
199
  )
200
 
201
  validate_config(cfg)
202
 
203
- def test_mpt_gradient_checkpointing(self):
204
  regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
205
 
206
  # Check for lower-case
207
- cfg = DictDefault(
208
- {
209
- "base_model": "mosaicml/mpt-7b",
210
- "gradient_checkpointing": True,
211
- }
 
 
 
212
  )
213
 
214
  with pytest.raises(ValueError, match=regex_exp):
215
  validate_config(cfg)
216
 
217
- def test_flash_optimum(self):
218
- cfg = DictDefault(
219
- {
220
- "flash_optimum": True,
221
- "adapter": "lora",
222
- }
 
 
 
 
223
  )
224
 
225
  with self._caplog.at_level(logging.WARNING):
@@ -230,10 +404,14 @@ class ValidationTest(BaseValidation):
230
  for record in self._caplog.records
231
  )
232
 
233
- cfg = DictDefault(
234
- {
235
- "flash_optimum": True,
236
- }
 
 
 
 
237
  )
238
 
239
  with self._caplog.at_level(logging.WARNING):
@@ -243,34 +421,43 @@ class ValidationTest(BaseValidation):
243
  for record in self._caplog.records
244
  )
245
 
246
- cfg = DictDefault(
247
- {
248
- "flash_optimum": True,
249
- "fp16": True,
250
- }
 
 
 
251
  )
252
  regex_exp = r".*AMP is not supported.*"
253
 
254
  with pytest.raises(ValueError, match=regex_exp):
255
  validate_config(cfg)
256
 
257
- cfg = DictDefault(
258
- {
259
- "flash_optimum": True,
260
- "bf16": True,
261
- }
 
 
 
262
  )
263
  regex_exp = r".*AMP is not supported.*"
264
 
265
  with pytest.raises(ValueError, match=regex_exp):
266
  validate_config(cfg)
267
 
268
- def test_adamw_hyperparams(self):
269
- cfg = DictDefault(
270
- {
271
- "optimizer": None,
272
- "adam_epsilon": 0.0001,
273
- }
 
 
 
274
  )
275
 
276
  with self._caplog.at_level(logging.WARNING):
@@ -281,11 +468,14 @@ class ValidationTest(BaseValidation):
281
  for record in self._caplog.records
282
  )
283
 
284
- cfg = DictDefault(
285
- {
286
- "optimizer": "adafactor",
287
- "adam_beta1": 0.0001,
288
- }
 
 
 
289
  )
290
 
291
  with self._caplog.at_level(logging.WARNING):
@@ -296,30 +486,39 @@ class ValidationTest(BaseValidation):
296
  for record in self._caplog.records
297
  )
298
 
299
- cfg = DictDefault(
300
- {
301
- "optimizer": "adamw_bnb_8bit",
302
- "adam_beta1": 0.9,
303
- "adam_beta2": 0.99,
304
- "adam_epsilon": 0.0001,
305
- }
 
 
 
306
  )
307
 
308
  validate_config(cfg)
309
 
310
- cfg = DictDefault(
311
- {
312
- "optimizer": "adafactor",
313
- }
 
 
 
314
  )
315
 
316
  validate_config(cfg)
317
 
318
- def test_deprecated_packing(self):
319
- cfg = DictDefault(
320
- {
321
- "max_packed_sequence_len": 1024,
322
- }
 
 
 
323
  )
324
  with pytest.raises(
325
  DeprecationWarning,
@@ -327,12 +526,15 @@ class ValidationTest(BaseValidation):
327
  ):
328
  validate_config(cfg)
329
 
330
- def test_packing(self):
331
- cfg = DictDefault(
332
- {
333
- "sample_packing": True,
334
- "pad_to_sequence_len": None,
335
- }
 
 
 
336
  )
337
  with self._caplog.at_level(logging.WARNING):
338
  validate_config(cfg)
@@ -342,62 +544,79 @@ class ValidationTest(BaseValidation):
342
  for record in self._caplog.records
343
  )
344
 
345
- @pytest.mark.skipif(
346
- is_torch_bf16_gpu_available(),
347
- reason="test should only run on gpus w/o bf16 support",
348
- )
349
- def test_merge_lora_no_bf16_fail(self):
350
  """
351
  This is assumed to be run on a CPU machine, so bf16 is not supported.
352
  """
353
 
354
- cfg = DictDefault(
355
- {
356
- "bf16": True,
357
- }
 
 
 
 
358
  )
359
 
360
  with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
361
- validate_config(cfg)
362
-
363
- cfg = DictDefault(
364
- {
365
- "bf16": True,
366
- "merge_lora": True,
367
- }
 
 
 
 
368
  )
369
 
370
  validate_config(cfg)
371
 
372
- def test_sharegpt_deprecation(self):
373
- cfg = DictDefault(
374
- {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
 
 
 
375
  )
376
  with self._caplog.at_level(logging.WARNING):
377
- validate_config(cfg)
378
  assert any(
379
  "`type: sharegpt:chat` will soon be deprecated." in record.message
380
  for record in self._caplog.records
381
  )
382
- assert cfg.datasets[0].type == "sharegpt"
383
-
384
- cfg = DictDefault(
385
- {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
 
 
 
 
 
 
 
386
  )
387
  with self._caplog.at_level(logging.WARNING):
388
- validate_config(cfg)
389
  assert any(
390
  "`type: sharegpt_simple` will soon be deprecated." in record.message
391
  for record in self._caplog.records
392
  )
393
- assert cfg.datasets[0].type == "sharegpt:load_role"
394
-
395
- def test_no_conflict_save_strategy(self):
396
- cfg = DictDefault(
397
- {
398
- "save_strategy": "epoch",
399
- "save_steps": 10,
400
- }
 
 
 
401
  )
402
 
403
  with pytest.raises(
@@ -405,11 +624,14 @@ class ValidationTest(BaseValidation):
405
  ):
406
  validate_config(cfg)
407
 
408
- cfg = DictDefault(
409
- {
410
- "save_strategy": "no",
411
- "save_steps": 10,
412
- }
 
 
 
413
  )
414
 
415
  with pytest.raises(
@@ -417,45 +639,60 @@ class ValidationTest(BaseValidation):
417
  ):
418
  validate_config(cfg)
419
 
420
- cfg = DictDefault(
421
- {
422
- "save_strategy": "steps",
423
- }
 
 
 
424
  )
425
 
426
  validate_config(cfg)
427
 
428
- cfg = DictDefault(
429
- {
430
- "save_strategy": "steps",
431
- "save_steps": 10,
432
- }
 
 
 
433
  )
434
 
435
  validate_config(cfg)
436
 
437
- cfg = DictDefault(
438
- {
439
- "save_steps": 10,
440
- }
 
 
 
441
  )
442
 
443
  validate_config(cfg)
444
 
445
- cfg = DictDefault(
446
- {
447
- "save_strategy": "no",
448
- }
 
 
 
449
  )
450
 
451
  validate_config(cfg)
452
 
453
- def test_no_conflict_eval_strategy(self):
454
- cfg = DictDefault(
455
- {
456
- "evaluation_strategy": "epoch",
457
- "eval_steps": 10,
458
- }
 
 
 
459
  )
460
 
461
  with pytest.raises(
@@ -463,11 +700,14 @@ class ValidationTest(BaseValidation):
463
  ):
464
  validate_config(cfg)
465
 
466
- cfg = DictDefault(
467
- {
468
- "evaluation_strategy": "no",
469
- "eval_steps": 10,
470
- }
 
 
 
471
  )
472
 
473
  with pytest.raises(
@@ -475,44 +715,59 @@ class ValidationTest(BaseValidation):
475
  ):
476
  validate_config(cfg)
477
 
478
- cfg = DictDefault(
479
- {
480
- "evaluation_strategy": "steps",
481
- }
 
 
 
482
  )
483
 
484
  validate_config(cfg)
485
 
486
- cfg = DictDefault(
487
- {
488
- "evaluation_strategy": "steps",
489
- "eval_steps": 10,
490
- }
 
 
 
491
  )
492
 
493
  validate_config(cfg)
494
 
495
- cfg = DictDefault(
496
- {
497
- "eval_steps": 10,
498
- }
 
 
 
499
  )
500
 
501
  validate_config(cfg)
502
 
503
- cfg = DictDefault(
504
- {
505
- "evaluation_strategy": "no",
506
- }
 
 
 
507
  )
508
 
509
  validate_config(cfg)
510
 
511
- cfg = DictDefault(
512
- {
513
- "evaluation_strategy": "epoch",
514
- "val_set_size": 0,
515
- }
 
 
 
516
  )
517
 
518
  with pytest.raises(
@@ -521,11 +776,14 @@ class ValidationTest(BaseValidation):
521
  ):
522
  validate_config(cfg)
523
 
524
- cfg = DictDefault(
525
- {
526
- "eval_steps": 10,
527
- "val_set_size": 0,
528
- }
 
 
 
529
  )
530
 
531
  with pytest.raises(
@@ -534,38 +792,50 @@ class ValidationTest(BaseValidation):
534
  ):
535
  validate_config(cfg)
536
 
537
- cfg = DictDefault(
538
- {
539
- "val_set_size": 0,
540
- }
 
 
 
541
  )
542
 
543
  validate_config(cfg)
544
 
545
- cfg = DictDefault(
546
- {
547
- "eval_steps": 10,
548
- "val_set_size": 0.01,
549
- }
 
 
 
550
  )
551
 
552
  validate_config(cfg)
553
 
554
- cfg = DictDefault(
555
- {
556
- "evaluation_strategy": "epoch",
557
- "val_set_size": 0.01,
558
- }
 
 
 
559
  )
560
 
561
  validate_config(cfg)
562
 
563
- def test_eval_table_size_conflict_eval_packing(self):
564
- cfg = DictDefault(
565
- {
566
- "sample_packing": True,
567
- "eval_table_size": 100,
568
- }
 
 
 
569
  )
570
 
571
  with pytest.raises(
@@ -573,39 +843,51 @@ class ValidationTest(BaseValidation):
573
  ):
574
  validate_config(cfg)
575
 
576
- cfg = DictDefault(
577
- {
578
- "sample_packing": True,
579
- "eval_sample_packing": False,
580
- }
 
 
 
581
  )
582
 
583
  validate_config(cfg)
584
 
585
- cfg = DictDefault(
586
- {
587
- "sample_packing": False,
588
- "eval_table_size": 100,
589
- }
 
 
 
590
  )
591
 
592
  validate_config(cfg)
593
 
594
- cfg = DictDefault(
595
- {
596
- "sample_packing": True,
597
- "eval_table_size": 100,
598
- "eval_sample_packing": False,
599
- }
 
 
 
600
  )
601
 
602
  validate_config(cfg)
603
 
604
- def test_load_in_x_bit_without_adapter(self):
605
- cfg = DictDefault(
606
- {
607
- "load_in_4bit": True,
608
- }
 
 
 
609
  )
610
 
611
  with pytest.raises(
@@ -614,10 +896,13 @@ class ValidationTest(BaseValidation):
614
  ):
615
  validate_config(cfg)
616
 
617
- cfg = DictDefault(
618
- {
619
- "load_in_8bit": True,
620
- }
 
 
 
621
  )
622
 
623
  with pytest.raises(
@@ -626,30 +911,39 @@ class ValidationTest(BaseValidation):
626
  ):
627
  validate_config(cfg)
628
 
629
- cfg = DictDefault(
630
- {
631
- "load_in_4bit": True,
632
- "adapter": "qlora",
633
- }
 
 
 
634
  )
635
 
636
  validate_config(cfg)
637
 
638
- cfg = DictDefault(
639
- {
640
- "load_in_8bit": True,
641
- "adapter": "lora",
642
- }
 
 
 
643
  )
644
 
645
  validate_config(cfg)
646
 
647
- def test_warmup_step_no_conflict(self):
648
- cfg = DictDefault(
649
- {
650
- "warmup_steps": 10,
651
- "warmup_ratio": 0.1,
652
- }
 
 
 
653
  )
654
 
655
  with pytest.raises(
@@ -658,29 +952,40 @@ class ValidationTest(BaseValidation):
658
  ):
659
  validate_config(cfg)
660
 
661
- cfg = DictDefault(
662
- {
663
- "warmup_steps": 10,
664
- }
 
 
 
665
  )
666
 
667
  validate_config(cfg)
668
 
669
- cfg = DictDefault(
670
- {
671
- "warmup_ratio": 0.1,
672
- }
 
 
 
673
  )
674
 
675
  validate_config(cfg)
676
 
677
- def test_unfrozen_parameters_w_peft_layers_to_transform(self):
678
- cfg = DictDefault(
679
- {
680
- "adapter": "lora",
681
- "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
682
- "peft_layers_to_transform": [0, 1],
683
- }
 
 
 
 
 
684
  )
685
 
686
  with pytest.raises(
@@ -689,8 +994,8 @@ class ValidationTest(BaseValidation):
689
  ):
690
  validate_config(cfg)
691
 
692
- def test_hub_model_id_save_value_warns(self):
693
- cfg = DictDefault({"hub_model_id": "test"})
694
 
695
  with self._caplog.at_level(logging.WARNING):
696
  validate_config(cfg)
@@ -698,22 +1003,25 @@ class ValidationTest(BaseValidation):
698
  "set without any models being saved" in self._caplog.records[0].message
699
  )
700
 
701
- def test_hub_model_id_save_value(self):
702
- cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
703
 
704
  with self._caplog.at_level(logging.WARNING):
705
  validate_config(cfg)
706
  assert len(self._caplog.records) == 0
707
 
708
 
709
- class ValidationCheckModelConfig(BaseValidation):
710
  """
711
  Test the validation for the config when the model config is available
712
  """
713
 
714
- def test_llama_add_tokens_adapter(self):
715
- cfg = DictDefault(
716
- {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
 
 
 
717
  )
718
  model_config = DictDefault({"model_type": "llama"})
719
 
@@ -723,13 +1031,16 @@ class ValidationCheckModelConfig(BaseValidation):
723
  ):
724
  check_model_config(cfg, model_config)
725
 
726
- cfg = DictDefault(
727
- {
728
- "adapter": "qlora",
729
- "load_in_4bit": True,
730
- "tokens": ["<|imstart|>"],
731
- "lora_modules_to_save": ["embed_tokens"],
732
- }
 
 
 
733
  )
734
 
735
  with pytest.raises(
@@ -738,20 +1049,26 @@ class ValidationCheckModelConfig(BaseValidation):
738
  ):
739
  check_model_config(cfg, model_config)
740
 
741
- cfg = DictDefault(
742
- {
743
- "adapter": "qlora",
744
- "load_in_4bit": True,
745
- "tokens": ["<|imstart|>"],
746
- "lora_modules_to_save": ["embed_tokens", "lm_head"],
747
- }
 
 
 
748
  )
749
 
750
  check_model_config(cfg, model_config)
751
 
752
- def test_phi_add_tokens_adapter(self):
753
- cfg = DictDefault(
754
- {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
 
 
 
755
  )
756
  model_config = DictDefault({"model_type": "phi"})
757
 
@@ -761,13 +1078,16 @@ class ValidationCheckModelConfig(BaseValidation):
761
  ):
762
  check_model_config(cfg, model_config)
763
 
764
- cfg = DictDefault(
765
- {
766
- "adapter": "qlora",
767
- "load_in_4bit": True,
768
- "tokens": ["<|imstart|>"],
769
- "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
770
- }
 
 
 
771
  )
772
 
773
  with pytest.raises(
@@ -776,66 +1096,78 @@ class ValidationCheckModelConfig(BaseValidation):
776
  ):
777
  check_model_config(cfg, model_config)
778
 
779
- cfg = DictDefault(
780
- {
781
- "adapter": "qlora",
782
- "load_in_4bit": True,
783
- "tokens": ["<|imstart|>"],
784
- "lora_modules_to_save": ["embed_tokens", "lm_head"],
785
- }
 
 
 
786
  )
787
 
788
  check_model_config(cfg, model_config)
789
 
790
 
791
- class ValidationWandbTest(BaseValidation):
792
  """
793
  Validation test for wandb
794
  """
795
 
796
- def test_wandb_set_run_id_to_name(self):
797
- cfg = DictDefault(
798
- {
799
- "wandb_run_id": "foo",
800
- }
 
 
 
801
  )
802
 
803
  with self._caplog.at_level(logging.WARNING):
804
- validate_config(cfg)
805
  assert any(
806
  "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
807
  in record.message
808
  for record in self._caplog.records
809
  )
810
 
811
- assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
812
 
813
- cfg = DictDefault(
814
- {
815
- "wandb_name": "foo",
816
- }
 
 
 
817
  )
818
 
819
- validate_config(cfg)
820
 
821
- assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
822
 
823
- def test_wandb_sets_env(self):
824
- cfg = DictDefault(
825
- {
826
- "wandb_project": "foo",
827
- "wandb_name": "bar",
828
- "wandb_run_id": "bat",
829
- "wandb_entity": "baz",
830
- "wandb_mode": "online",
831
- "wandb_watch": "false",
832
- "wandb_log_model": "checkpoint",
833
- }
 
 
 
834
  )
835
 
836
- validate_config(cfg)
837
 
838
- setup_wandb_env_vars(cfg)
839
 
840
  assert os.environ.get("WANDB_PROJECT", "") == "foo"
841
  assert os.environ.get("WANDB_NAME", "") == "bar"
@@ -855,24 +1187,27 @@ class ValidationWandbTest(BaseValidation):
855
  os.environ.pop("WANDB_LOG_MODEL", None)
856
  os.environ.pop("WANDB_DISABLED", None)
857
 
858
- def test_wandb_set_disabled(self):
859
- cfg = DictDefault({})
860
 
861
- validate_config(cfg)
862
 
863
- setup_wandb_env_vars(cfg)
864
 
865
  assert os.environ.get("WANDB_DISABLED", "") == "true"
866
 
867
- cfg = DictDefault(
868
- {
869
- "wandb_project": "foo",
870
- }
 
 
 
871
  )
872
 
873
- validate_config(cfg)
874
 
875
- setup_wandb_env_vars(cfg)
876
 
877
  assert os.environ.get("WANDB_DISABLED", "") != "true"
878
 
 
1
+ # pylint: disable=too-many-lines
2
  """Module for testing the validation module"""
3
 
4
  import logging
5
  import os
 
6
  from typing import Optional
7
 
8
  import pytest
9
+ from pydantic import ValidationError
10
 
11
  from axolotl.utils.config import validate_config
12
+ from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
13
  from axolotl.utils.dict import DictDefault
14
  from axolotl.utils.models import check_model_config
15
  from axolotl.utils.wandb_ import setup_wandb_env_vars
16
 
17
 
18
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
    """Shared base config for the tests: model, learning rate, one alpaca dataset and batch settings."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
    return DictDefault(settings)
34
+
35
+
36
+ class BaseValidation:
37
  """
38
  Base validation module to setup the log capture
39
  """
 
46
 
47
 
48
  # pylint: disable=too-many-public-methods
49
+ class TestValidation(BaseValidation):
50
  """
51
  Test the validation module
52
  """
53
 
54
def test_datasets_min_length(self):
    """An explicitly empty ``datasets`` list must fail validation.

    Expects ``validate_config`` to raise a pydantic ``ValidationError``
    whose text contains the "at least 1 item" message.
    """
    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
        }
    )

    # Close the pattern with ``.*`` (not a bare ``*``, which would merely make
    # the final "n" repeatable) so the whole phrase has to be present.
    with pytest.raises(
        ValidationError,
        match=r".*List should have at least 1 item after validation.*",
    ):
        validate_config(cfg)
 
70
 
71
def test_datasets_min_length_empty(self):
    """A config with neither ``datasets`` nor ``pretraining_dataset`` must fail validation.

    Expects ``validate_config`` to raise ``ValueError`` naming the missing keys.
    """
    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
        }
    )

    # ``.*`` rather than a bare trailing ``*``: the latter only quantifies the
    # last character ("d") and would let a truncated message slip through.
    with pytest.raises(
        ValueError, match=r".*either datasets or pretraining_dataset is required.*"
    ):
        validate_config(cfg)
85
 
86
def test_pretrain_dataset_min_length(self):
    """An explicitly empty ``pretraining_dataset`` list must fail validation.

    Expects ``validate_config`` to raise a pydantic ``ValidationError``
    whose text contains the "at least 1 item" message.
    """
    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "pretraining_dataset": [],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "max_steps": 100,
        }
    )

    # Close the pattern with ``.*`` (not a bare ``*``, which would merely make
    # the final "n" repeatable) so the whole phrase has to be present.
    with pytest.raises(
        ValidationError,
        match=r".*List should have at least 1 item after validation.*",
    ):
        validate_config(cfg)
103
 
104
def test_valid_pretrain_dataset(self):
    """A config with one pretraining dataset and ``max_steps`` validates without raising."""
    pretraining_entry = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "pretraining_dataset": [pretraining_entry],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "max_steps": 100,
    }

    # No assertion needed: the test passes as long as validation does not raise.
    validate_config(DictDefault(settings))
 
122
 
123
def test_valid_sft_dataset(self):
    """A config with a single alpaca entry under ``datasets`` validates without raising."""
    sft_entry = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [sft_entry],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }

    # No assertion needed: the test passes as long as validation does not raise.
    validate_config(DictDefault(settings))
140
 
141
def test_batch_size_unused_warning(self):
    """Setting ``batch_size`` next to ``micro_batch_size`` must log a "not recommended" warning."""
    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 4,
            "batch_size": 32,
        }
    )

    with self._caplog.at_level(logging.WARNING):
        validate_config(cfg)
        # Scan every captured record rather than indexing records[0]: the check
        # then neither depends on the order warnings are emitted in nor dies
        # with an IndexError when nothing was logged at all.
        assert any(
            "batch_size is not recommended" in record.message
            for record in self._caplog.records
        )
160
+
161
def test_batch_size_more_params(self):
    """``batch_size`` on its own is not enough to derive the batch settings.

    Expects ``validate_config`` to raise ``ValueError`` with an
    "At least two of" message.
    """
    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "batch_size": 32,
        }
    )

    # ``.*`` rather than a bare trailing ``*``, which would only quantify the
    # final "f" and accept a message that stops at "At least two o".
    with pytest.raises(ValueError, match=r".*At least two of.*"):
        validate_config(cfg)
178
+
179
def test_qlora(self, minimal_cfg):
    """With ``adapter: qlora``, 8bit, gptq and ``load_in_4bit: False`` are rejected; ``load_in_4bit: True`` passes."""
    qlora_cfg = DictDefault({"adapter": "qlora"}) | minimal_cfg

    # (extra option, pattern the resulting ValueError message must match)
    rejected_options = (
        ({"load_in_8bit": True}, r".*8bit.*"),
        ({"gptq": True}, r".*gptq.*"),
        ({"load_in_4bit": False}, r".*4bit.*"),
    )
    for extra_option, error_pattern in rejected_options:
        cfg = (
            DictDefault(extra_option)  # pylint: disable=unsupported-binary-operation
            | qlora_cfg
        )
        with pytest.raises(ValueError, match=error_pattern):
            validate_config(cfg)

    accepted_cfg = (
        DictDefault({"load_in_4bit": True})  # pylint: disable=unsupported-binary-operation
        | qlora_cfg
    )
    validate_config(accepted_cfg)
235
+
236
def test_qlora_merge(self, minimal_cfg):
    """With ``adapter: qlora`` and ``merge_lora: True``, 8bit, gptq and ``load_in_4bit: True`` are all rejected."""
    merge_cfg = (
        DictDefault(
            {
                "adapter": "qlora",
                "merge_lora": True,
            }
        )
        | minimal_cfg
    )

    # (extra option, pattern the resulting ValueError message must match)
    rejected_options = (
        ({"load_in_8bit": True}, r".*8bit.*"),
        ({"gptq": True}, r".*gptq.*"),
        ({"load_in_4bit": True}, r".*4bit.*"),
    )
    for extra_option, error_pattern in rejected_options:
        cfg = (
            DictDefault(extra_option)  # pylint: disable=unsupported-binary-operation
            | merge_cfg
        )
        with pytest.raises(ValueError, match=error_pattern):
            validate_config(cfg)
282
 
283
def test_hf_use_auth_token(self, minimal_cfg):
    """``push_dataset_to_hub`` is rejected unless ``hf_use_auth_token`` is enabled as well."""
    hub_target = {"push_dataset_to_hub": "namespace/repo"}

    without_token = DictDefault(dict(hub_target)) | minimal_cfg
    with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
        validate_config(without_token)

    with_token = (
        DictDefault({**hub_target, "hf_use_auth_token": True}) | minimal_cfg
    )
    validate_config(with_token)
306
 
307
def test_gradient_accumulations_or_batch_size(self):
    """Supplying both ``gradient_accumulation_steps`` and ``batch_size`` raises ``ValueError``."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [alpaca_dataset],
        "gradient_accumulation_steps": 1,
        "batch_size": 1,
    }
    cfg = DictDefault(settings)

    expected_error = r".*gradient_accumulation_steps or batch_size.*"
    with pytest.raises(ValueError, match=expected_error):
        validate_config(cfg)
327
 
328
def test_falcon_fsdp(self, minimal_cfg):
    """FSDP with a falcon base model is rejected for either spelling of the name; falcon without FSDP passes."""
    regex_exp = r".*FSDP is not supported for falcon models.*"

    # First the lower-case model name, then the capitalised one.
    for model_name in ("tiiuae/falcon-7b", "Falcon-7b"):
        fsdp_cfg = (
            DictDefault(
                {
                    "base_model": model_name,
                    "fsdp": ["full_shard", "auto_wrap"],
                }
            )
            | minimal_cfg
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(fsdp_cfg)

    plain_cfg = DictDefault({"base_model": "tiiuae/falcon-7b"}) | minimal_cfg
    validate_config(plain_cfg)
369
 
370
def test_mpt_gradient_checkpointing(self, minimal_cfg):
    """``gradient_checkpointing`` with an MPT base model must fail validation.

    Expects ``validate_config`` to raise ``ValueError`` stating that
    gradient checkpointing is not supported for MPT models.
    """
    # ``.*`` rather than a bare trailing ``*``, which would only quantify the
    # final "s" instead of acting as a wildcard.
    regex_exp = r".*gradient_checkpointing is not supported for MPT models.*"

    cfg = (
        DictDefault(
            {
                "base_model": "mosaicml/mpt-7b",
                "gradient_checkpointing": True,
            }
        )
        | minimal_cfg
    )

    with pytest.raises(ValueError, match=regex_exp):
        validate_config(cfg)
386
 
387
+ def test_flash_optimum(self, minimal_cfg):
388
+ cfg = (
389
+ DictDefault(
390
+ {
391
+ "flash_optimum": True,
392
+ "adapter": "lora",
393
+ "bf16": False,
394
+ }
395
+ )
396
+ | minimal_cfg
397
  )
398
 
399
  with self._caplog.at_level(logging.WARNING):
 
404
  for record in self._caplog.records
405
  )
406
 
407
+ cfg = (
408
+ DictDefault(
409
+ {
410
+ "flash_optimum": True,
411
+ "bf16": False,
412
+ }
413
+ )
414
+ | minimal_cfg
415
  )
416
 
417
  with self._caplog.at_level(logging.WARNING):
 
421
  for record in self._caplog.records
422
  )
423
 
424
+ cfg = (
425
+ DictDefault(
426
+ {
427
+ "flash_optimum": True,
428
+ "fp16": True,
429
+ }
430
+ )
431
+ | minimal_cfg
432
  )
433
  regex_exp = r".*AMP is not supported.*"
434
 
435
  with pytest.raises(ValueError, match=regex_exp):
436
  validate_config(cfg)
437
 
438
+ cfg = (
439
+ DictDefault(
440
+ {
441
+ "flash_optimum": True,
442
+ "bf16": True,
443
+ }
444
+ )
445
+ | minimal_cfg
446
  )
447
  regex_exp = r".*AMP is not supported.*"
448
 
449
  with pytest.raises(ValueError, match=regex_exp):
450
  validate_config(cfg)
451
 
452
+ def test_adamw_hyperparams(self, minimal_cfg):
453
+ cfg = (
454
+ DictDefault(
455
+ {
456
+ "optimizer": None,
457
+ "adam_epsilon": 0.0001,
458
+ }
459
+ )
460
+ | minimal_cfg
461
  )
462
 
463
  with self._caplog.at_level(logging.WARNING):
 
468
  for record in self._caplog.records
469
  )
470
 
471
+ cfg = (
472
+ DictDefault(
473
+ {
474
+ "optimizer": "adafactor",
475
+ "adam_beta1": 0.0001,
476
+ }
477
+ )
478
+ | minimal_cfg
479
  )
480
 
481
  with self._caplog.at_level(logging.WARNING):
 
486
  for record in self._caplog.records
487
  )
488
 
489
+ cfg = (
490
+ DictDefault(
491
+ {
492
+ "optimizer": "adamw_bnb_8bit",
493
+ "adam_beta1": 0.9,
494
+ "adam_beta2": 0.99,
495
+ "adam_epsilon": 0.0001,
496
+ }
497
+ )
498
+ | minimal_cfg
499
  )
500
 
501
  validate_config(cfg)
502
 
503
+ cfg = (
504
+ DictDefault(
505
+ {
506
+ "optimizer": "adafactor",
507
+ }
508
+ )
509
+ | minimal_cfg
510
  )
511
 
512
  validate_config(cfg)
513
 
514
+ def test_deprecated_packing(self, minimal_cfg):
515
+ cfg = (
516
+ DictDefault(
517
+ {
518
+ "max_packed_sequence_len": 1024,
519
+ }
520
+ )
521
+ | minimal_cfg
522
  )
523
  with pytest.raises(
524
  DeprecationWarning,
 
526
  ):
527
  validate_config(cfg)
528
 
529
+ def test_packing(self, minimal_cfg):
530
+ cfg = (
531
+ DictDefault(
532
+ {
533
+ "sample_packing": True,
534
+ "pad_to_sequence_len": None,
535
+ }
536
+ )
537
+ | minimal_cfg
538
  )
539
  with self._caplog.at_level(logging.WARNING):
540
  validate_config(cfg)
 
544
  for record in self._caplog.records
545
  )
546
 
547
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
    """
    This is assumed to be run on a CPU machine, so bf16 is not supported.

    Requesting ``bf16`` while the reported capabilities say it is unavailable
    must raise when the capability-aware model is built directly; the same
    request with ``merge_lora`` enabled must pass ``validate_config``.
    """

    cfg = (
        DictDefault(
            {
                "bf16": True,
                "capabilities": {"bf16": False},
            }
        )
        | minimal_cfg
    )

    # ``.*`` rather than a bare trailing ``*``, which would only quantify the
    # final "U" instead of acting as a wildcard.
    with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU.*"):
        AxolotlConfigWCapabilities(**cfg.to_dict())

    cfg = (
        DictDefault(
            {
                "bf16": True,
                "merge_lora": True,
                "capabilities": {"bf16": False},
            }
        )
        | minimal_cfg
    )

    validate_config(cfg)
577
 
578
def test_sharegpt_deprecation(self, minimal_cfg):
    """Deprecated sharegpt dataset types log a warning and come back rewritten in the validated config."""
    chat_dataset = {"path": "lorem/ipsum", "type": "sharegpt:chat"}
    cfg = DictDefault({"datasets": [chat_dataset]}) | minimal_cfg
    with self._caplog.at_level(logging.WARNING):
        new_cfg = validate_config(cfg)
        chat_notice = "`type: sharegpt:chat` will soon be deprecated."
        assert any(chat_notice in record.message for record in self._caplog.records)
    assert new_cfg.datasets[0].type == "sharegpt"

    simple_dataset = {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
    cfg = DictDefault({"datasets": [simple_dataset]}) | minimal_cfg
    with self._caplog.at_level(logging.WARNING):
        new_cfg = validate_config(cfg)
        simple_notice = "`type: sharegpt_simple` will soon be deprecated."
        assert any(
            simple_notice in record.message for record in self._caplog.records
        )
    assert new_cfg.datasets[0].type == "sharegpt:load_role"
610
+
611
+ def test_no_conflict_save_strategy(self, minimal_cfg):
612
+ cfg = (
613
+ DictDefault(
614
+ {
615
+ "save_strategy": "epoch",
616
+ "save_steps": 10,
617
+ }
618
+ )
619
+ | minimal_cfg
620
  )
621
 
622
  with pytest.raises(
 
624
  ):
625
  validate_config(cfg)
626
 
627
+ cfg = (
628
+ DictDefault(
629
+ {
630
+ "save_strategy": "no",
631
+ "save_steps": 10,
632
+ }
633
+ )
634
+ | minimal_cfg
635
  )
636
 
637
  with pytest.raises(
 
639
  ):
640
  validate_config(cfg)
641
 
642
+ cfg = (
643
+ DictDefault(
644
+ {
645
+ "save_strategy": "steps",
646
+ }
647
+ )
648
+ | minimal_cfg
649
  )
650
 
651
  validate_config(cfg)
652
 
653
+ cfg = (
654
+ DictDefault(
655
+ {
656
+ "save_strategy": "steps",
657
+ "save_steps": 10,
658
+ }
659
+ )
660
+ | minimal_cfg
661
  )
662
 
663
  validate_config(cfg)
664
 
665
+ cfg = (
666
+ DictDefault(
667
+ {
668
+ "save_steps": 10,
669
+ }
670
+ )
671
+ | minimal_cfg
672
  )
673
 
674
  validate_config(cfg)
675
 
676
+ cfg = (
677
+ DictDefault(
678
+ {
679
+ "save_strategy": "no",
680
+ }
681
+ )
682
+ | minimal_cfg
683
  )
684
 
685
  validate_config(cfg)
686
 
687
+ def test_no_conflict_eval_strategy(self, minimal_cfg):
688
+ cfg = (
689
+ DictDefault(
690
+ {
691
+ "evaluation_strategy": "epoch",
692
+ "eval_steps": 10,
693
+ }
694
+ )
695
+ | minimal_cfg
696
  )
697
 
698
  with pytest.raises(
 
700
  ):
701
  validate_config(cfg)
702
 
703
+ cfg = (
704
+ DictDefault(
705
+ {
706
+ "evaluation_strategy": "no",
707
+ "eval_steps": 10,
708
+ }
709
+ )
710
+ | minimal_cfg
711
  )
712
 
713
  with pytest.raises(
 
715
  ):
716
  validate_config(cfg)
717
 
718
+ cfg = (
719
+ DictDefault(
720
+ {
721
+ "evaluation_strategy": "steps",
722
+ }
723
+ )
724
+ | minimal_cfg
725
  )
726
 
727
  validate_config(cfg)
728
 
729
+ cfg = (
730
+ DictDefault(
731
+ {
732
+ "evaluation_strategy": "steps",
733
+ "eval_steps": 10,
734
+ }
735
+ )
736
+ | minimal_cfg
737
  )
738
 
739
  validate_config(cfg)
740
 
741
+ cfg = (
742
+ DictDefault(
743
+ {
744
+ "eval_steps": 10,
745
+ }
746
+ )
747
+ | minimal_cfg
748
  )
749
 
750
  validate_config(cfg)
751
 
752
+ cfg = (
753
+ DictDefault(
754
+ {
755
+ "evaluation_strategy": "no",
756
+ }
757
+ )
758
+ | minimal_cfg
759
  )
760
 
761
  validate_config(cfg)
762
 
763
+ cfg = (
764
+ DictDefault(
765
+ {
766
+ "evaluation_strategy": "epoch",
767
+ "val_set_size": 0,
768
+ }
769
+ )
770
+ | minimal_cfg
771
  )
772
 
773
  with pytest.raises(
 
776
  ):
777
  validate_config(cfg)
778
 
779
+ cfg = (
780
+ DictDefault(
781
+ {
782
+ "eval_steps": 10,
783
+ "val_set_size": 0,
784
+ }
785
+ )
786
+ | minimal_cfg
787
  )
788
 
789
  with pytest.raises(
 
792
  ):
793
  validate_config(cfg)
794
 
795
+ cfg = (
796
+ DictDefault(
797
+ {
798
+ "val_set_size": 0,
799
+ }
800
+ )
801
+ | minimal_cfg
802
  )
803
 
804
  validate_config(cfg)
805
 
806
+ cfg = (
807
+ DictDefault(
808
+ {
809
+ "eval_steps": 10,
810
+ "val_set_size": 0.01,
811
+ }
812
+ )
813
+ | minimal_cfg
814
  )
815
 
816
  validate_config(cfg)
817
 
818
+ cfg = (
819
+ DictDefault(
820
+ {
821
+ "evaluation_strategy": "epoch",
822
+ "val_set_size": 0.01,
823
+ }
824
+ )
825
+ | minimal_cfg
826
  )
827
 
828
  validate_config(cfg)
829
 
830
+ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
831
+ cfg = (
832
+ DictDefault(
833
+ {
834
+ "sample_packing": True,
835
+ "eval_table_size": 100,
836
+ }
837
+ )
838
+ | minimal_cfg
839
  )
840
 
841
  with pytest.raises(
 
843
  ):
844
  validate_config(cfg)
845
 
846
+ cfg = (
847
+ DictDefault(
848
+ {
849
+ "sample_packing": True,
850
+ "eval_sample_packing": False,
851
+ }
852
+ )
853
+ | minimal_cfg
854
  )
855
 
856
  validate_config(cfg)
857
 
858
+ cfg = (
859
+ DictDefault(
860
+ {
861
+ "sample_packing": False,
862
+ "eval_table_size": 100,
863
+ }
864
+ )
865
+ | minimal_cfg
866
  )
867
 
868
  validate_config(cfg)
869
 
870
+ cfg = (
871
+ DictDefault(
872
+ {
873
+ "sample_packing": True,
874
+ "eval_table_size": 100,
875
+ "eval_sample_packing": False,
876
+ }
877
+ )
878
+ | minimal_cfg
879
  )
880
 
881
  validate_config(cfg)
882
 
883
+ def test_load_in_x_bit_without_adapter(self, minimal_cfg):
884
+ cfg = (
885
+ DictDefault(
886
+ {
887
+ "load_in_4bit": True,
888
+ }
889
+ )
890
+ | minimal_cfg
891
  )
892
 
893
  with pytest.raises(
 
896
  ):
897
  validate_config(cfg)
898
 
899
+ cfg = (
900
+ DictDefault(
901
+ {
902
+ "load_in_8bit": True,
903
+ }
904
+ )
905
+ | minimal_cfg
906
  )
907
 
908
  with pytest.raises(
 
911
  ):
912
  validate_config(cfg)
913
 
914
+ cfg = (
915
+ DictDefault(
916
+ {
917
+ "load_in_4bit": True,
918
+ "adapter": "qlora",
919
+ }
920
+ )
921
+ | minimal_cfg
922
  )
923
 
924
  validate_config(cfg)
925
 
926
+ cfg = (
927
+ DictDefault(
928
+ {
929
+ "load_in_8bit": True,
930
+ "adapter": "lora",
931
+ }
932
+ )
933
+ | minimal_cfg
934
  )
935
 
936
  validate_config(cfg)
937
 
938
+ def test_warmup_step_no_conflict(self, minimal_cfg):
939
+ cfg = (
940
+ DictDefault(
941
+ {
942
+ "warmup_steps": 10,
943
+ "warmup_ratio": 0.1,
944
+ }
945
+ )
946
+ | minimal_cfg
947
  )
948
 
949
  with pytest.raises(
 
952
  ):
953
  validate_config(cfg)
954
 
955
+ cfg = (
956
+ DictDefault(
957
+ {
958
+ "warmup_steps": 10,
959
+ }
960
+ )
961
+ | minimal_cfg
962
  )
963
 
964
  validate_config(cfg)
965
 
966
+ cfg = (
967
+ DictDefault(
968
+ {
969
+ "warmup_ratio": 0.1,
970
+ }
971
+ )
972
+ | minimal_cfg
973
  )
974
 
975
  validate_config(cfg)
976
 
977
+ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
978
+ cfg = (
979
+ DictDefault(
980
+ {
981
+ "adapter": "lora",
982
+ "unfrozen_parameters": [
983
+ "model.layers.2[0-9]+.block_sparse_moe.gate.*"
984
+ ],
985
+ "peft_layers_to_transform": [0, 1],
986
+ }
987
+ )
988
+ | minimal_cfg
989
  )
990
 
991
  with pytest.raises(
 
994
  ):
995
  validate_config(cfg)
996
 
997
+ def test_hub_model_id_save_value_warns(self, minimal_cfg):
998
+ cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
999
 
1000
  with self._caplog.at_level(logging.WARNING):
1001
  validate_config(cfg)
 
1003
  "set without any models being saved" in self._caplog.records[0].message
1004
  )
1005
 
1006
def test_hub_model_id_save_value(self, minimal_cfg):
    """With ``saves_per_epoch`` set next to ``hub_model_id``, validation captures no log records."""
    overrides = {"hub_model_id": "test", "saves_per_epoch": 4}
    cfg = DictDefault(overrides) | minimal_cfg

    with self._caplog.at_level(logging.WARNING):
        validate_config(cfg)
        captured = self._caplog.records
        assert len(captured) == 0
1012
 
1013
 
1014
+ class TestValidationCheckModelConfig(BaseValidation):
1015
  """
1016
  Test the validation for the config when the model config is available
1017
  """
1018
 
1019
+ def test_llama_add_tokens_adapter(self, minimal_cfg):
1020
+ cfg = (
1021
+ DictDefault(
1022
+ {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
1023
+ )
1024
+ | minimal_cfg
1025
  )
1026
  model_config = DictDefault({"model_type": "llama"})
1027
 
 
1031
  ):
1032
  check_model_config(cfg, model_config)
1033
 
1034
+ cfg = (
1035
+ DictDefault(
1036
+ {
1037
+ "adapter": "qlora",
1038
+ "load_in_4bit": True,
1039
+ "tokens": ["<|imstart|>"],
1040
+ "lora_modules_to_save": ["embed_tokens"],
1041
+ }
1042
+ )
1043
+ | minimal_cfg
1044
  )
1045
 
1046
  with pytest.raises(
 
1049
  ):
1050
  check_model_config(cfg, model_config)
1051
 
1052
+ cfg = (
1053
+ DictDefault(
1054
+ {
1055
+ "adapter": "qlora",
1056
+ "load_in_4bit": True,
1057
+ "tokens": ["<|imstart|>"],
1058
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
1059
+ }
1060
+ )
1061
+ | minimal_cfg
1062
  )
1063
 
1064
  check_model_config(cfg, model_config)
1065
 
1066
+ def test_phi_add_tokens_adapter(self, minimal_cfg):
1067
+ cfg = (
1068
+ DictDefault(
1069
+ {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
1070
+ )
1071
+ | minimal_cfg
1072
  )
1073
  model_config = DictDefault({"model_type": "phi"})
1074
 
 
1078
  ):
1079
  check_model_config(cfg, model_config)
1080
 
1081
+ cfg = (
1082
+ DictDefault(
1083
+ {
1084
+ "adapter": "qlora",
1085
+ "load_in_4bit": True,
1086
+ "tokens": ["<|imstart|>"],
1087
+ "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
1088
+ }
1089
+ )
1090
+ | minimal_cfg
1091
  )
1092
 
1093
  with pytest.raises(
 
1096
  ):
1097
  check_model_config(cfg, model_config)
1098
 
1099
+ cfg = (
1100
+ DictDefault(
1101
+ {
1102
+ "adapter": "qlora",
1103
+ "load_in_4bit": True,
1104
+ "tokens": ["<|imstart|>"],
1105
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
1106
+ }
1107
+ )
1108
+ | minimal_cfg
1109
  )
1110
 
1111
  check_model_config(cfg, model_config)
1112
 
1113
 
1114
+ class TestValidationWandb(BaseValidation):
1115
  """
1116
  Validation test for wandb
1117
  """
1118
 
1119
def test_wandb_set_run_id_to_name(self, minimal_cfg):
    """A lone ``wandb_run_id`` warns and is mirrored into ``wandb_name``; a lone ``wandb_name`` leaves the run id unset."""
    run_id_cfg = DictDefault({"wandb_run_id": "foo"}) | minimal_cfg

    with self._caplog.at_level(logging.WARNING):
        new_cfg = validate_config(run_id_cfg)
        notice = "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
        assert any(notice in record.message for record in self._caplog.records)

    assert new_cfg.wandb_name == "foo"
    assert new_cfg.wandb_run_id == "foo"

    name_cfg = DictDefault({"wandb_name": "foo"}) | minimal_cfg

    new_cfg = validate_config(name_cfg)

    assert new_cfg.wandb_name == "foo"
    assert new_cfg.wandb_run_id is None
1151
 
1152
+ def test_wandb_sets_env(self, minimal_cfg):
1153
+ cfg = (
1154
+ DictDefault(
1155
+ {
1156
+ "wandb_project": "foo",
1157
+ "wandb_name": "bar",
1158
+ "wandb_run_id": "bat",
1159
+ "wandb_entity": "baz",
1160
+ "wandb_mode": "online",
1161
+ "wandb_watch": "false",
1162
+ "wandb_log_model": "checkpoint",
1163
+ }
1164
+ )
1165
+ | minimal_cfg
1166
  )
1167
 
1168
+ new_cfg = validate_config(cfg)
1169
 
1170
+ setup_wandb_env_vars(new_cfg)
1171
 
1172
  assert os.environ.get("WANDB_PROJECT", "") == "foo"
1173
  assert os.environ.get("WANDB_NAME", "") == "bar"
 
1187
  os.environ.pop("WANDB_LOG_MODEL", None)
1188
  os.environ.pop("WANDB_DISABLED", None)
1189
 
1190
def test_wandb_set_disabled(self, minimal_cfg):
    """``WANDB_DISABLED`` is "true" when no wandb project is configured and is not "true" once one is.

    ``setup_wandb_env_vars`` communicates through ``os.environ``, so the
    variables touched here are removed again when the test finishes.
    """
    try:
        cfg = DictDefault({}) | minimal_cfg

        new_cfg = validate_config(cfg)

        setup_wandb_env_vars(new_cfg)

        assert os.environ.get("WANDB_DISABLED", "") == "true"

        cfg = (
            DictDefault(
                {
                    "wandb_project": "foo",
                }
            )
            | minimal_cfg
        )

        new_cfg = validate_config(cfg)

        setup_wandb_env_vars(new_cfg)

        assert os.environ.get("WANDB_DISABLED", "") != "true"
    finally:
        # Clean up even when an assertion fails, so tests that run afterwards
        # do not inherit wandb settings from this one.
        os.environ.pop("WANDB_PROJECT", None)
        os.environ.pop("WANDB_DISABLED", None)
1213