|
from __future__ import annotations |
|
|
|
from dataclasses import asdict, dataclass, field |
|
from glob import glob |
|
from pathlib import Path |
|
from typing import ( |
|
Any, |
|
Dict, |
|
Iterable, |
|
List, |
|
Optional, |
|
Tuple, |
|
Type, |
|
TypeVar, |
|
Union, |
|
cast, |
|
) |
|
|
|
import torch |
|
from omegaconf import DictConfig, ListConfig |
|
from omegaconf import OmegaConf as om |
|
from omegaconf.errors import OmegaConfBaseException |
|
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy |
|
|
|
from .aliases import PathOrStr |
|
from .beam_search import Sampler |
|
from .exceptions import OLMoConfigurationError |
|
from .util import StrEnum |
|
|
|
__all__ = [ |
|
"ActivationType", |
|
"ActivationCheckpointingStrategy", |
|
"BlockType", |
|
"LayerNormType", |
|
"InitFnType", |
|
"ModelConfig", |
|
"OptimizerType", |
|
"OptimizerConfig", |
|
"SchedulerType", |
|
"SchedulerConfig", |
|
"DataConfig", |
|
"EvaluatorConfig", |
|
"TokenizerConfig", |
|
"TrainConfig", |
|
"PaddingDirection", |
|
"TruncationDirection", |
|
"SpeedMonitorConfig", |
|
"WandbConfig", |
|
"CompilerConfig", |
|
"WandbConfig", |
|
"FSDPPrecision", |
|
"FSDPWrapStrategy", |
|
"FSDPConfig", |
|
"CheckpointType", |
|
] |
|
|
|
C = TypeVar("C", bound="BaseConfig") |
|
D = TypeVar("D", bound="DictConfig|ListConfig") |
|
|
|
|
|
class BaseConfig: |
|
@classmethod |
|
def _register_resolvers(cls, validate_paths: bool = True): |
|
|
|
def path_glob(*paths) -> List[str]: |
|
out = [] |
|
for path in paths: |
|
matches = sorted(glob(path)) |
|
if not matches and validate_paths: |
|
raise FileNotFoundError(f"{path} does not match any files or dirs") |
|
out.extend(matches) |
|
return out |
|
|
|
|
|
def path_choose(*paths) -> str: |
|
from .util import is_url |
|
|
|
for path in paths: |
|
if is_url(path) or Path(path).exists(): |
|
return path |
|
if validate_paths: |
|
raise FileNotFoundError(", ".join(paths)) |
|
else: |
|
return "" |
|
|
|
|
|
def path_last_checkpoint(path) -> str: |
|
from .util import find_latest_checkpoint |
|
|
|
latest_checkpoint = find_latest_checkpoint(path) |
|
if latest_checkpoint is None: |
|
if validate_paths: |
|
raise FileNotFoundError(f"Could not find a latest checkpoint at {path}") |
|
else: |
|
return "" |
|
else: |
|
return str(latest_checkpoint) |
|
|
|
om.register_new_resolver("path.glob", path_glob, replace=True) |
|
om.register_new_resolver("path.choose", path_choose, replace=True) |
|
om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True) |
|
|
|
@classmethod |
|
def update_legacy_settings(cls, config: D) -> D: |
|
""" |
|
Update the legacy config settings whose schemas have undergone backwards-incompatible changes. |
|
""" |
|
return config |
|
|
|
@classmethod |
|
def new(cls: Type[C], **kwargs) -> C: |
|
cls._register_resolvers() |
|
conf = om.structured(cls) |
|
try: |
|
if kwargs: |
|
conf = om.merge(conf, kwargs) |
|
return cast(C, om.to_object(conf)) |
|
except OmegaConfBaseException as e: |
|
raise OLMoConfigurationError(str(e)) |
|
|
|
@classmethod |
|
def load( |
|
cls: Type[C], |
|
path: PathOrStr, |
|
overrides: Optional[List[str]] = None, |
|
key: Optional[str] = None, |
|
validate_paths: bool = True, |
|
) -> C: |
|
"""Load from a YAML file.""" |
|
cls._register_resolvers(validate_paths=validate_paths) |
|
schema = om.structured(cls) |
|
try: |
|
raw = om.load(str(path)) |
|
if key is not None: |
|
raw = raw[key] |
|
raw = cls.update_legacy_settings(raw) |
|
conf = om.merge(schema, raw) |
|
if overrides: |
|
conf = om.merge(conf, om.from_dotlist(overrides)) |
|
return cast(C, om.to_object(conf)) |
|
except OmegaConfBaseException as e: |
|
raise OLMoConfigurationError(str(e)) |
|
|
|
def save(self, path: PathOrStr) -> None: |
|
"""Save to a YAML file.""" |
|
om.save(config=self, f=str(path)) |
|
|
|
def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]: |
|
out = asdict(self) |
|
if exclude is not None: |
|
for name in exclude: |
|
if name in out: |
|
del out[name] |
|
return out |
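# Example usage of the BaseConfig API (a minimal sketch; the file name and override
# values are hypothetical):
#
#   cfg = TrainConfig.load(
#       "config.yaml",
#       overrides=["optimizer.learning_rate=2e-4", "save_folder=/tmp/run-001"],
#   )
#   cfg.save("/tmp/run-001/config.yaml")   # round-trip the resolved config back to YAML
#   flat = cfg.asdict(exclude=["wandb"])   # plain dict view, minus the W&B section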
|
|
|
|
|
class LayerNormType(StrEnum): |
|
default = "default" |
|
""" |
|
The default LayerNorm implementation, equivalent to PyTorch's built-in version. |
|
""" |
|
|
|
low_precision = "low_precision" |
|
""" |
|
A low-precision version of the default LayerNorm. |
|
""" |
|
|
|
rms = "rms" |
|
""" |
|
An RMSNorm implementation. When using ``torch.compile`` this is |
|
probably the fastest implementation. |
|
""" |
|
|
|
|
|
class ActivationType(StrEnum): |
|
gelu = "gelu" |
|
relu = "relu" |
|
swiglu = "swiglu" |
|
|
|
|
|
class BlockType(StrEnum): |
|
sequential = "sequential" |
|
|
|
llama = "llama" |
|
""" |
|
A block similar to the sequential block with slightly different |
|
implementations of operations like attention to imitate the behavior of Llama. |
|
""" |
|
|
|
|
|
class InitFnType(StrEnum): |
|
mitchell = "mitchell" |
|
""" |
|
The strategy suggested to us by Mitchell Wortsman from UW. |
|
This uses a truncated normal distribution with an adaptive standard deviation that depends |
|
on the size of the weights as well as the depth of the layer. |
|
""" |
|
|
|
normal = "normal" |
|
""" |
|
All weights are initialized from the same normal distribution. |
|
""" |
|
|
|
kaiming_normal = "kaiming_normal" |
|
""" |
|
All weights are initialized with the Kaiming method from a normal distribution. |
|
Note this currently won't work with FSDP. |
|
""" |
|
|
|
fan_in = "fan_in" |
|
""" |
|
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in`` |
|
is the input dimensionality of the kernel. |
|
""" |
|
|
|
full_megatron = "full_megatron" |
|
""" |
|
This is what metaseq calls "full megatron init". It is the init used for Llama 2. |
|
""" |
|
|
|
|
|
@dataclass |
|
class ModelConfig(BaseConfig): |
|
""" |
|
OLMo (model) configuration. |
|
""" |
|
|
|
|
|
|
|
d_model: int = 768 |
|
""" |
|
The hidden size of the model. |
|
""" |
|
|
|
n_heads: int = 12 |
|
""" |
|
The number of self-attention heads. |
|
""" |
|
|
|
n_kv_heads: Optional[int] = None |
|
""" |
|
The number of heads to use for keys and values. Defaults to `n_heads`. |
|
Set this to ``None`` or ``n_heads`` for normal multi-head attention. |
|
Set this to 1 for multi-query attention. |
|
Set it to some in-between value for Llama2-style grouped query attention. |
|
""" |
|
|
|
clip_qkv: Optional[float] = None |
|
""" |
|
Clip QKV to this value when set. |
|
""" |
|
|
|
n_layers: int = 12 |
|
""" |
|
The number of layers/blocks. |
|
""" |
|
|
|
mlp_ratio: int = 4 |
|
""" |
|
The ratio of the inner MLP dimensionality to ``d_model``. |
|
This is only used when ``mlp_hidden_size`` is not set. |
|
""" |
|
|
|
mlp_hidden_size: Optional[int] = None |
|
""" |
|
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`. |
|
""" |
|
|
|
activation_type: ActivationType = ActivationType.swiglu |
|
""" |
|
The activation function to use within the MLP layers. |
|
""" |
|
|
|
block_type: BlockType = BlockType.sequential |
|
""" |
|
The transformer block implementation. |
|
""" |
|
|
|
block_group_size: int = 1 |
|
""" |
|
The number of blocks to group together into a single parent block. |
|
This has no effect on the number of parameters in the model and is only used to wrap groups
|
of blocks together with a single FSDP wrapper during training. |
|
""" |
|
|
|
alibi: bool = False |
|
""" |
|
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. |
|
""" |
|
|
|
alibi_bias_max: float = 8.0 |
|
""" |
|
Maximum absolute value of ALiBi bias. |
|
""" |
|
|
|
rope: bool = False |
|
""" |
|
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. |
|
""" |
|
|
|
rope_full_precision: bool = True |
|
""" |
|
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, |
|
apply RoPE at the precision of the input. |
|
""" |
|
|
|
flash_attention: bool = False |
|
""" |
|
If ``True``, use ``FlashAttention``. |
|
""" |
|
|
|
attention_dropout: float = 0.1 |
|
""" |
|
The dropout probability within the attention modules. |
|
""" |
|
|
|
multi_query_attention: Optional[bool] = None |
|
""" |
|
Deprecated. Use n_kv_heads instead. |
|
""" |
|
|
|
attention_layer_norm: bool = False |
|
""" |
|
Apply layer norm to the keys and queries within the attention mechanism. |
|
This can help stabilize training. |
|
""" |
|
|
|
residual_dropout: float = 0.1 |
|
""" |
|
The dropout probability for the MLP and attention output within each block. |
|
""" |
|
|
|
embedding_dropout: float = 0.1 |
|
""" |
|
The dropout probability for embeddings. |
|
""" |
|
|
|
layer_norm_type: LayerNormType = LayerNormType.default |
|
""" |
|
The layernorm implementation to use. |
|
""" |
|
|
|
layer_norm_with_affine: bool = True |
|
""" |
|
Whether to include bias and weight parameters for the layer norms. |
|
This only affects layer norms that are immediately followed by a linear layer in the forward pass, |
|
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine` |
|
to ``False``. |
|
""" |
|
|
|
attention_layer_norm_with_affine: bool = True |
|
""" |
|
Toggle affine transform for the QK norms. |
|
""" |
|
|
|
max_sequence_length: int = 1024 |
|
""" |
|
The maximum input sequence length supported by the model. |
|
""" |
|
|
|
include_bias: bool = True |
|
""" |
|
Whether or not to include bias parameters in linear layers. |
|
In PaLM, they got rid of all bias terms because they found that large |
|
models tend to have near 0 bias terms anyway. |
|
""" |
|
|
|
bias_for_layer_norm: Optional[bool] = None |
|
""" |
|
Whether or not to include bias parameters in layer norm. |
|
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in |
|
layer norm. |
|
When this is None (the default), it inherits the setting from include_bias. |
|
""" |
|
|
|
scale_logits: bool = False |
|
""" |
|
If ``True``, scale the output logits by ``1 / sqrt(d_model)``. |
|
""" |
|
|
|
vocab_size: int = 50257 |
|
""" |
|
Vocabulary size of the model. |
|
""" |
|
|
|
embedding_size: Optional[int] = 50304 |
|
""" |
|
The number of embeddings, i.e. the size of the token embedding table. If set to ``None`` it will default
|
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the |
|
next multiple of 128 that's greater than ``vocab_size`` can improve throughput |
|
substantially. |
|
""" |
|
|
|
weight_tying: bool = True |
|
""" |
|
Whether to tie output linear weights to the input embedding. |
|
""" |
|
|
|
eos_token_id: int = 50256 |
|
""" |
|
The ID of the end-of-sequence special token.
|
""" |
|
|
|
pad_token_id: int = 50256 |
|
""" |
|
The ID of the token to use for padding. Defaults to the ID of the EOS token. |
|
""" |
|
|
|
init_device: Optional[str] = None |
|
""" |
|
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta". |
|
""" |
|
|
|
init_fn: InitFnType = InitFnType.normal |
|
""" |
|
The weight initialization strategy. |
|
""" |
|
|
|
init_std: float = 0.02 |
|
""" |
|
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such |
|
as "normal". |
|
""" |
|
|
|
init_cutoff_factor: Optional[float] = None |
|
""" |
|
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such |
|
as "normal". Setting this to None means values are not cutoff. |
|
""" |
|
|
|
precision: Optional[str] = None |
|
""" |
|
Precision used to train/evaluate with. You shouldn't set this directly. |
|
See :data:`TrainConfig.precision` instead. |
|
""" |
|
|
|
ternary: bool = False |
|
""" |
|
Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf) |
|
""" |
|
|
|
@property |
|
def effective_n_kv_heads(self) -> int: |
|
if self.n_kv_heads is None: |
|
if self.multi_query_attention is True: |
|
return 1 |
|
else: |
|
return self.n_heads |
|
else: |
|
if self.multi_query_attention is None: |
|
return self.n_kv_heads |
|
if self.multi_query_attention: |
|
n_kv_heads_should_be = 1 |
|
else: |
|
n_kv_heads_should_be = self.n_heads |
|
if self.n_kv_heads == n_kv_heads_should_be: |
|
return n_kv_heads_should_be |
|
else: |
|
raise OLMoConfigurationError( |
|
"You can't set `multi_query_attention` and `n_kv_heads` at the same time." |
|
) |
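# Example (illustrative head counts): how ``n_kv_heads`` selects the attention variant.
#
#   ModelConfig(n_heads=32)                # effective_n_kv_heads == 32 -> standard multi-head attention
#   ModelConfig(n_heads=32, n_kv_heads=8)  # effective_n_kv_heads == 8  -> grouped-query attention
#   ModelConfig(n_heads=32, n_kv_heads=1)  # effective_n_kv_heads == 1  -> multi-query attention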
|
|
|
|
|
class OptimizerType(StrEnum): |
|
lionw = "lionw" |
|
adamw = "adamw" |
|
|
|
|
|
@dataclass |
|
class OptimizerConfig(BaseConfig): |
|
name: OptimizerType = OptimizerType.lionw |
|
learning_rate: float = 1.0e-4 |
|
weight_decay: float = 0.01 |
|
betas: Tuple[float, float] = (0.9, 0.95) |
|
|
|
no_decay_norm_and_bias: Optional[bool] = None |
|
""" |
|
Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead. |
|
""" |
|
|
|
decay_norm_and_bias: bool = False |
|
decay_embeddings: bool = False |
|
metrics_log_interval: Optional[int] = None |
|
""" |
|
The interval with which to collect and log detailed parameter-specific metrics. |
|
This only applies when logging to W&B, since these metrics won't be logged to the console. |
|
If not set, defaults to the wandb `log_interval`. |
|
""" |
|
|
|
def __post_init__(self): |
|
self.betas = tuple(self.betas) |
|
|
|
@classmethod |
|
def update_legacy_settings(cls, config: D) -> D: |
|
new_config = config.copy() |
|
if om.is_dict(new_config): |
|
assert isinstance(new_config, DictConfig) |
|
|
|
if hasattr(new_config, "name") and new_config.name == "decoupled_lionw": |
|
new_config.name = "lionw" |
|
if hasattr(new_config, "eps"): |
|
del new_config.eps |
|
|
|
return new_config |
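# Example of the legacy migration above (illustrative YAML): an optimizer section such as
#
#   optimizer:
#     name: decoupled_lionw
#     eps: 1.0e-8
#
# is rewritten to ``name: lionw`` with the no-longer-supported ``eps`` field dropped
# before it is merged into the structured schema.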
|
|
|
|
|
class SchedulerType(StrEnum): |
|
cosine_with_warmup = "cosine_with_warmup" |
|
linear_with_warmup = "linear_with_warmup" |
|
inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup" |
|
max_scheduler = "max_scheduler" |
|
constant = "constant" |
|
|
|
|
|
class SchedulerUnits(StrEnum): |
|
steps = "steps" |
|
tokens = "tokens" |
|
|
|
|
|
@dataclass |
|
class SchedulerConfig(BaseConfig): |
|
name: SchedulerType = SchedulerType.cosine_with_warmup |
|
units: SchedulerUnits = SchedulerUnits.steps |
|
t_warmup: Union[int, float] = 100 |
|
t_max: Optional[Union[int, float]] = None |
|
alpha_f: float = 0.1 |
|
|
|
grad_clip_warmup_steps: Optional[Union[int, float]] = None |
|
""" |
|
The warmup period for which the max grad norm (or norm ratio) will be set to its |
|
warmup value of `max_grad_norm * grad_clip_warmup_factor`. |
|
""" |
|
|
|
grad_clip_warmup_factor: Optional[float] = None |
|
""" |
|
The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period |
|
vs after the warmup period. |
|
""" |
|
|
|
|
|
class PaddingDirection(StrEnum): |
|
right = "right" |
|
left = "left" |
|
|
|
|
|
@dataclass |
|
class DataConfig(BaseConfig): |
|
paths: Optional[List[str]] = None |
|
datasets: Optional[Dict[str, List[str]]] = None |
|
label_mask_paths: Optional[List[str]] = None |
|
pad_direction: PaddingDirection = PaddingDirection.right |
|
generate_attention_mask: bool = False |
|
num_workers: int = 0 |
|
drop_last: bool = False |
|
pin_memory: bool = False |
|
prefetch_factor: Optional[int] = None |
|
persistent_workers: bool = False |
|
timeout: int = 0 |
|
seed: Optional[int] = None |
|
|
|
|
|
class EvaluatorType(StrEnum): |
|
downstream = "downstream" |
|
lm = "lm" |
|
|
|
|
|
@dataclass |
|
class EvaluatorConfig(BaseConfig): |
|
label: str |
|
type: EvaluatorType = EvaluatorType.lm |
|
data: DataConfig = field(default_factory=DataConfig) |
|
device_eval_batch_size: Optional[int] = None |
|
subset_num_batches: Optional[int] = None |
|
|
|
|
|
class TruncationDirection(StrEnum): |
|
right = "right" |
|
left = "left" |
|
|
|
|
|
@dataclass |
|
class TokenizerConfig(BaseConfig): |
|
identifier: str = "gpt2" |
|
truncate_direction: TruncationDirection = TruncationDirection.right |
|
|
|
|
|
@dataclass |
|
class WandbConfig(BaseConfig): |
|
project: Optional[str] = None |
|
entity: Optional[str] = "ai2-llm" |
|
group: Optional[str] = None |
|
name: Optional[str] = None |
|
tags: Optional[List[str]] = field(default_factory=lambda: ["watching"]) |
|
log_artifacts: bool = False |
|
rank_zero_only: bool = True |
|
log_interval: int = 1 |
|
|
|
|
|
@dataclass |
|
class SpeedMonitorConfig(BaseConfig): |
|
window_size: int = 100 |
|
gpu_flops_available: Optional[Union[float, int]] = None |
|
|
|
|
|
@dataclass |
|
class CompilerConfig(BaseConfig): |
|
mode: Optional[str] = None |
|
""" |
|
The mode to compile the model in. At the moment this can be "default", |
|
"reduce-overhead" (useful for smaller models/batches), or "max-autotune" |
|
(the fastest for larger models, but takes a long time to compile). |
|
""" |
|
|
|
fullgraph: bool = False |
|
""" |
|
If ``True``, require ``torch.compile()`` to compile the model as a single graph (no graph breaks).

Note that this is not compatible with FSDP.
|
""" |
|
|
|
backend: str = "inductor" |
|
""" |
|
The backend to use. |
|
""" |
|
|
|
|
|
class FSDPWrapStrategy(StrEnum): |
|
by_block = "by_block" |
|
""" |
|
Wrap each OLMo block with its own FSDP instance. |
|
""" |
|
|
|
by_block_and_size = "by_block_and_size" |
|
""" |
|
Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well. |
|
""" |
|
|
|
by_block_group = "by_block_group" |
|
""" |
|
Wrap each block group together into its own FSDP instance. |
|
This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1. |
|
""" |
|
|
|
by_block_group_and_size = "by_block_group_and_size" |
|
""" |
|
Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well. |
|
""" |
|
|
|
size_based = "size_based" |
|
""" |
|
Use PyTorch's default size-based auto wrap policy.
|
""" |
|
|
|
one_in_two = "one_in_two" |
|
one_in_three = "one_in_three" |
|
one_in_four = "one_in_four" |
|
one_in_five = "one_in_five" |
|
|
|
|
|
class FSDPPrecision(StrEnum): |
|
pure = "pure" |
|
""" |
|
Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``, |
|
and ``buffer_dtype`` all set to the autocast precision data type. |
|
""" |
|
|
|
mixed = "mixed" |
|
""" |
|
Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype`` and ``buffer_dtype``
|
set to the autocast precision data type, while ``reduce_dtype`` is set to fp32. |
|
""" |
|
|
|
|
|
@dataclass |
|
class FSDPConfig(BaseConfig): |
|
use_orig_params: bool = True |
|
""" |
|
This must be ``True`` if using ``compile`` or you want to track the parameter norm during training. |
|
""" |
|
|
|
sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD |
|
|
|
wrapping_strategy: Optional[FSDPWrapStrategy] = None |
|
""" |
|
The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level |
|
FSDP instance. |
|
""" |
|
|
|
precision: FSDPPrecision = FSDPPrecision.pure |
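# Example (illustrative): wrapping at the granularity of block groups requires pairing the
# FSDP wrap strategy with a ``block_group_size`` greater than 1 on the model config.
#
#   model = ModelConfig(n_layers=32, block_group_size=4)   # 8 groups of 4 blocks
#   fsdp = FSDPConfig(wrapping_strategy=FSDPWrapStrategy.by_block_group)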
|
|
|
|
|
class CheckpointType(StrEnum): |
|
sharded = "sharded" |
|
unsharded = "unsharded" |
|
sharded_ephemeral = "sharded_ephemeral" |
|
|
|
|
|
class ShardedCheckpointerType(StrEnum): |
|
torch_new = "torch_new" |
|
torch_legacy = "torch_legacy" |
|
local = "local" |
|
|
|
|
|
class ActivationCheckpointingStrategy(StrEnum): |
|
whole_layer = "whole_layer" |
|
""" |
|
Checkpoint every transformer layer. |
|
""" |
|
|
|
one_in_two = "one_in_two" |
|
""" |
|
Checkpoint one in two transformer layers. |
|
""" |
|
|
|
one_in_three = "one_in_three" |
|
""" |
|
Checkpoint one in three transformer layers. |
|
""" |
|
|
|
one_in_four = "one_in_four" |
|
""" |
|
Checkpoint one in four transformer layers. |
|
""" |
|
|
|
two_in_three = "two_in_three" |
|
""" |
|
Checkpoint two out of every three transformer layers. |
|
""" |
|
|
|
three_in_four = "three_in_four" |
|
""" |
|
Checkpoint three out of every four transformer layers.
|
""" |
|
|
|
fine_grained = "fine_grained" |
|
""" |
|
Focus checkpointing on the operations that are cheap to recompute and save the most memory.
|
""" |
|
|
|
|
|
@dataclass |
|
class TrainConfig(BaseConfig): |
|
""" |
|
OLMo training configuration. |
|
""" |
|
|
|
run_name: Optional[str] = None |
|
""" |
|
The name of the run. |
|
""" |
|
|
|
seed: int = 6198 |
|
""" |
|
Used to seed all initial RNG states. |
|
""" |
|
|
|
epoch: Optional[int] = None |
|
""" |
|
Increment this when starting a new epoch. |
|
""" |
|
|
|
dry_run: bool = False |
|
""" |
|
If ``True``, don't actually train. |
|
""" |
|
|
|
model: ModelConfig = field(default_factory=ModelConfig) |
|
""" |
|
OLMo Model configuration. |
|
""" |
|
|
|
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) |
|
""" |
|
Optimizer configuration. |
|
""" |
|
|
|
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) |
|
""" |
|
Learning rate scheduler configuration. |
|
""" |
|
|
|
data: DataConfig = field(default_factory=DataConfig) |
|
""" |
|
Training data configuration. |
|
""" |
|
|
|
restore_dataloader: bool = True |
|
""" |
|
When restarting, restore the data loader to where it left off. |
|
If you are restarting in order to train on a different dataset, set this to ``False``.
|
""" |
|
|
|
fast_forward_batches: Optional[int] = None |
|
""" |
|
When restarting, use this to fast-forward the dataloader beyond the last checkpoint. |
|
This can be useful when restarting due to a loss spike in order to skip the data that |
|
corresponded to the spike. |
|
""" |
|
|
|
evaluators: List[EvaluatorConfig] = field(default_factory=list) |
|
""" |
|
Evaluation configurations. |
|
""" |
|
|
|
eval_interval: int = 1000 |
|
""" |
|
How often (in terms of batches) to run evaluations. |
|
""" |
|
|
|
tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) |
|
""" |
|
Tokenizer configuration. |
|
""" |
|
|
|
save_folder: str = "./" |
|
""" |
|
The directory to save checkpoints to. |
|
""" |
|
|
|
remote_save_folder: Optional[str] = None |
|
""" |
|
A folder in a cloud bucket to upload saved checkpoints to. |
|
""" |
|
|
|
canceled_check_interval: int = 50 |
|
""" |
|
How often (in batches) to check if the run has been canceled or reached its time limit. |
|
""" |
|
|
|
save_interval: int = 1000 |
|
""" |
|
How often (in terms of steps) to save sharded training state checkpoints. |
|
""" |
|
|
|
save_interval_unsharded: Optional[int] = None |
|
""" |
|
How often (if at all) to save an unsharded training state checkpoint.
|
For large models it can be costly to save these, so it usually makes sense to save |
|
these less often than regular (sharded) training checkpoints. |
|
""" |
|
|
|
save_interval_ephemeral: Optional[int] = None |
|
""" |
|
How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same |
|
as those saved every `save_interval` except that at most only the most recent one of these is kept. |
|
This is useful when you want to checkpoint often for restarts in case of failures, but don't |
|
want to keep the majority of these checkpoints. |
|
|
|
For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save |
|
a temporary checkpoint every 100 steps in case your job fails. In that case you would |
|
set `save_interval=1000` and `save_interval_ephemeral=100`. |
|
""" |
|
|
|
save_num_checkpoints_to_keep: int = -1 |
|
""" |
|
How many sharded checkpoints to keep. |
|
""" |
|
|
|
save_num_unsharded_checkpoints_to_keep: int = -1 |
|
""" |
|
How many unsharded checkpoints to keep. |
|
""" |
|
|
|
save_overwrite: bool = False |
|
""" |
|
If ``True``, overwrite any conflicting checkpoint files. |
|
""" |
|
|
|
force_save_unsharded: bool = False |
|
""" |
|
Save an unsharded checkpoint before training (even during a dry run). |
|
Use this option with `--load_path={PATH}` and `--dry_run` to convert a sharded
|
checkpoint into an unsharded checkpoint. |
|
""" |
|
|
|
no_pre_train_checkpoint: bool = False |
|
""" |
|
Skip saving pre-train checkpoint. |
|
""" |
|
|
|
load_path: Optional[str] = None |
|
""" |
|
The path to a training checkpoint to restore/resume from. |
|
|
|
Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
|
a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory. |
|
For example, |
|
|
|
```bash |
|
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}' |
|
``` |
|
""" |
|
|
|
load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None |
|
""" |
|
The sharded checkpointer type to use to load the initial checkpoint from ``load_path``. |
|
""" |
|
|
|
reset_optimizer_state: bool = False |
|
""" |
|
When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized. |
|
We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning |
|
curve (according to the current learning rate schedule settings), and continues from there. |
|
""" |
|
|
|
reset_trainer_state: bool = False |
|
""" |
|
When this is set we don't restore the trainer state from a checkpoint. |
|
""" |
|
|
|
sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy |
|
""" |
|
The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training. |
|
""" |
|
|
|
new_style_checkpoints: Optional[bool] = None |
|
""" |
|
Deprecated. Use ``sharded_checkpointer`` instead. |
|
""" |
|
|
|
max_duration: Union[int, str] = 10000 |
|
""" |
|
How long to train for. |
|
|
|
If specified without a unit (the default), the units are assumed to be steps. |
|
You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until |
|
2 trillion tokens. |
|
""" |
|
|
|
global_train_batch_size: int = 512 |
|
""" |
|
The effective global batch size. |
|
""" |
|
|
|
device_train_batch_size: Optional[int] = None |
|
""" |
|
Don't set this manually. This will be set to ``global_train_batch_size // world_size``. |
|
""" |
|
|
|
device_train_microbatch_size: int = 16 |
|
""" |
|
The number of instances passed to the model in a single forward-backward pass. You should set |
|
this as large as you can based on available GPU memory. |
|
""" |
|
|
|
device_eval_batch_size: int = 16 |
|
""" |
|
The number of evaluation instances passed to the model in a single forward pass on each device. |
|
""" |
|
|
|
eval_subset_num_batches: int = -1 |
|
""" |
|
The number of batches to use for downstream evaluation from each dataset. |
|
""" |
|
|
|
eval_on_load: bool = False |
|
""" |
|
When resuming from a checkpoint, run the evaluation loop right away. |
|
""" |
|
|
|
device_train_grad_accum: Optional[int] = None |
|
""" |
|
Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``. |
|
""" |
|
|
|
max_grad_norm: Optional[float] = None |
|
""" |
|
Clip gradient norms to this value if set. |
|
""" |
|
|
|
max_grad_norm_ratio: Optional[float] = None |
|
""" |
|
If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`. |
|
This takes priority over `max_grad_norm` when set. |
|
""" |
|
|
|
precision: Optional[str] = None |
|
""" |
|
Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32"). |
|
""" |
|
|
|
wandb: Optional[WandbConfig] = None |
|
""" |
|
Weights & Biases configuration. |
|
""" |
|
|
|
speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig) |
|
""" |
|
Speed monitor configuration. |
|
""" |
|
|
|
console_log_interval: int = 1 |
|
""" |
|
How often to log to the console. |
|
""" |
|
|
|
compile: Optional[CompilerConfig] = None |
|
""" |
|
Settings for compiling the model with ``torch.compile()``. |
|
""" |
|
|
|
fsdp: FSDPConfig = field(default_factory=FSDPConfig) |
|
""" |
|
Fully sharded data parallel settings. |
|
""" |
|
|
|
softmax_auxiliary_loss: bool = False |
|
""" |
|
If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax |
|
normalizing term to be close to 0. |
|
""" |
|
|
|
time_limit: Optional[float] = 60 * 60 * 47.5 |
|
""" |
|
The maximum amount of time to train for before saving a checkpoint and ending early. |
|
On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time |
|
to write out a final checkpoint. |
|
""" |
|
|
|
extra_steps_after_cancel: int = 10 |
|
""" |
|
Under certain conditions when a run is canceled we train for a few extra steps after saving |
|
the final checkpoint so that when the run is restarted from the latest checkpoint we have some |
|
overlap in metrics. |
|
""" |
|
|
|
early_stopping_factor: Optional[float] = None |
|
|
|
save_data_indices: bool = True |
|
""" |
|
Save training data indices from each batch for each worker. |
|
""" |
|
|
|
python_profiling: bool = False |
|
""" |
|
Whether to run the Python profiler on batches 6, 7, and 8. |
|
""" |
|
|
|
torch_profiling: bool = False |
|
""" |
|
Whether to run the PyTorch profiler on batches 6, 7, and 8. |
|
""" |
|
|
|
stop_at: Optional[int] = None |
|
""" |
|
Stop at a specific step. |
|
""" |
|
|
|
stop_after: Optional[int] = None |
|
""" |
|
Stop after a specific number of steps. |
|
""" |
|
|
|
activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None |
|
""" |
|
The activation checkpointing strategy to use. |
|
""" |
|
|
|
fused_loss: Optional[bool] = None |
|
""" |
|
Whether to use the fused CE loss function from `flash-attn`. |
|
""" |
|
|
|
@property |
|
def autocast_precision(self) -> torch.dtype: |
|
if self.precision == "amp_bf16": |
|
return torch.bfloat16 |
|
elif self.precision == "amp_fp16": |
|
return torch.float16 |
|
elif self.precision == "fp32": |
|
return torch.float32 |
|
else: |
|
raise ValueError(f"Unexpected precision type '{self.precision}'") |
|
|
|
@property |
|
def fsdp_precision(self) -> MixedPrecision: |
|
if self.fsdp.precision == FSDPPrecision.pure: |
|
return MixedPrecision( |
|
param_dtype=self.autocast_precision, |
|
reduce_dtype=self.autocast_precision, |
|
buffer_dtype=self.autocast_precision, |
|
) |
|
elif self.fsdp.precision == FSDPPrecision.mixed: |
|
return MixedPrecision( |
|
param_dtype=self.autocast_precision, |
|
reduce_dtype=torch.float32, |
|
buffer_dtype=self.autocast_precision, |
|
) |
|
else: |
|
raise NotImplementedError(f"{self.fsdp.precision}") |
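    # Illustration of the mapping above: with ``precision="amp_bf16"`` and
    # ``fsdp.precision=FSDPPrecision.mixed``, FSDP runs with
    # MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.bfloat16),
    # i.e. gradient reductions happen in full precision while params and buffers stay in bf16.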
|
|
|
@classmethod |
|
def update_legacy_settings(cls, config: D) -> D: |
|
new_config = config.copy() |
|
if om.is_dict(new_config): |
|
assert isinstance(new_config, DictConfig) |
|
|
|
if hasattr(new_config, "activation_checkpointing"): |
|
if new_config.activation_checkpointing is False: |
|
new_config.activation_checkpointing = None |
|
if new_config.activation_checkpointing is True: |
|
new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer |
|
|
|
if hasattr(new_config, "optimizer"): |
|
new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer) |
|
|
|
return new_config |
|
|