Utility functions and classes
Below are a variety of utility functions that 🤗 Accelerate provides, broken down by use-case.
Constants
Constants used throughout 🤗 Accelerate for reference
The following are constants used when utilizing Accelerator.save_state()
utils.MODEL_NAME
: "pytorch_model"
utils.OPTIMIZER_NAME
: "optimizer"
utils.RNG_STATE_NAME
: "random_states"
utils.SCALER_NAME
: "scaler.pt"
utils.SCHEDULER_NAME
: "scheduler"
The following are constants used when utilizing Accelerator.save_model()
utils.WEIGHTS_NAME
: "pytorch_model.bin"
utils.SAFE_WEIGHTS_NAME
: "model.safetensors"
utils.WEIGHTS_INDEX_NAME
: "pytorch_model.bin.index.json"
utils.SAFE_WEIGHTS_INDEX_NAME
: "model.safetensors.index.json"
Data Classes
These are basic dataclasses used throughout 🤗 Accelerate and they can be passed in as parameters.
Standalone
These are standalone dataclasses used for checks, such as the type of distributed system being used
class accelerate.utils.ComputeEnvironment
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a type of the compute environment.
Values:
- LOCAL_MACHINE — private/custom cluster hardware.
- AMAZON_SAGEMAKER — Amazon SageMaker as compute environment.
class accelerate.DistributedType
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a type of distributed environment.
Values:
- NO — Not a distributed environment, just a single process.
- MULTI_CPU — Distributed on multiple CPU nodes.
- MULTI_GPU — Distributed on multiple GPUs.
- MULTI_MLU — Distributed on multiple MLUs.
- MULTI_MUSA — Distributed on multiple MUSAs.
- MULTI_NPU — Distributed on multiple NPUs.
- MULTI_XPU — Distributed on multiple XPUs.
- DEEPSPEED — Using DeepSpeed.
- XLA — Using TorchXLA.
class accelerate.utils.DynamoBackend
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html).
Values:
- NO — Do not use torch dynamo.
- EAGER — Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
- AOT_EAGER — Uses AotAutograd with no compiler, i.e., just using PyTorch eager for the AotAutograd’s extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
- INDUCTOR — Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels. Read more
- AOT_TS_NVFUSER — nvFuser with AotAutograd/TorchScript. Read more
- NVPRIMS_NVFUSER — nvFuser with PrimTorch. Read more
- CUDAGRAPHS — cudagraphs with AotAutograd. Read more
- OFI — Uses Torchscript optimize_for_inference. Inference only. Read more
- FX2TRT — Uses Nvidia TensorRT for inference optimizations. Inference only. Read more
- ONNXRT — Uses ONNXRT for inference on CPU/GPU. Inference only. Read more
- TENSORRT — Uses ONNXRT to run TensorRT for inference optimizations. Read more
- AOT_TORCHXLA_TRACE_ONCE — Uses Pytorch/XLA with TorchDynamo optimization, for training. Read more
- TORCHXLA_TRACE_ONCE — Uses Pytorch/XLA with TorchDynamo optimization, for inference. Read more
- IPEX — Uses IPEX for inference on CPU. Inference only. Read more.
- TVM — Uses Apache TVM for inference optimizations. Read more
class accelerate.utils.LoggerType
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a type of supported experiment tracker
Values:
- ALL — all available trackers in the environment that are supported
- TENSORBOARD — TensorBoard as an experiment tracker
- WANDB — wandb as an experiment tracker
- COMETML — comet_ml as an experiment tracker
- DVCLIVE — dvclive as an experiment tracker
class accelerate.utils.PrecisionType
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a type of precision used on floating point values
Values:
- NO — using full precision (FP32)
- FP16 — using half precision
- BF16 — using brain floating point precision
class accelerate.utils.RNGType
< source >( value names = None module = None qualname = None type = None start = 1 )
An enumeration.
class accelerate.utils.SageMakerDistributedType
< source >( value names = None module = None qualname = None type = None start = 1 )
Represents a type of distributed environment.
Values:
- NO — Not a distributed environment, just a single process.
- DATA_PARALLEL — using sagemaker distributed data parallelism.
- MODEL_PARALLEL — using sagemaker distributed model parallelism.
Kwargs
These are configurable arguments for specific interactions throughout the PyTorch ecosystem that Accelerate handles under the hood.
Use this object in your Accelerator to customize how torch.autocast behaves. Please refer to the documentation of this context manager for more information on each argument.
class accelerate.DistributedDataParallelKwargs
< source >( dim: int = 0 broadcast_buffers: bool = True bucket_cap_mb: int = 25 find_unused_parameters: bool = False check_reduction: bool = False gradient_as_bucket_view: bool = False static_graph: bool = False comm_hook: DDPCommunicationHookType = <DDPCommunicationHookType.NO: 'no'> comm_wrapper: Literal = <DDPCommunicationHookType.NO: 'no'> comm_state_option: dict = <factory> )
Use this object in your Accelerator to customize how your model is wrapped in a torch.nn.parallel.DistributedDataParallel. Please refer to the documentation of this wrapper for more information on each argument.
gradient_as_bucket_view is only available in PyTorch 1.7.0 and later versions.
static_graph is only available in PyTorch 1.11.0 and later versions.
class accelerate.utils.FP8RecipeKwargs
< source >( backend: Literal = None use_autocast_during_eval: bool = None opt_level: Literal = None margin: int = None interval: int = None fp8_format: Literal = None amax_history_len: int = None amax_compute_algo: Literal = None override_linear_precision: Tuple = None )
Parameters
- backend (str, optional) — Which FP8 engine to use. Must be one of "msamp" (MS-AMP) or "te" (TransformerEngine). If not passed, will use whichever is available in the environment, prioritizing MS-AMP.
- use_autocast_during_eval (bool, optional, defaults to False) — Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is False.
- margin (int, optional, defaults to 0) — The margin to use for the gradient scaling.
- interval (int, optional, defaults to 1) — The interval to use for how often the scaling factor is recomputed.
- fp8_format (str, optional, defaults to "HYBRID") — The format to use for the FP8 recipe. Must be one of HYBRID or E4M3. (Generally HYBRID for training, E4M3 for evaluation.)
- amax_history_len (int, optional, defaults to 1024) — The length of the history to use for the scaling factor computation.
- amax_compute_algo (str, optional, defaults to "most_recent") — The algorithm to use for the scaling factor computation. Must be one of max or most_recent.
- override_linear_precision (tuple of three bool, optional, defaults to (False, False, False)) — Whether or not to execute fprop, dgrad, and wgrad GEMMs in higher precision.
- opt_level (str, one of O1, O2, defaults to O2) — What level of 8-bit collective communication should be used with MS-AMP. In general:
  - O1: Weight gradients and all_reduce communications are done in fp8, reducing GPU memory usage and communication bandwidth.
  - O2: First-order optimizer states are in 8-bit, and second-order states are in FP16. Only available when using Adam or AdamW. This maintains accuracy and can potentially save the highest memory.
  - O3: Specifically for DeepSpeed, implements capabilities so weights and master weights of models are stored in FP8. If fp8 is selected and deepspeed is enabled, will be used by default. (Not available currently.)
Use this object in your Accelerator to customize the initialization of the recipe for FP8 mixed precision training with transformer-engine or ms-amp.
For more information on transformer-engine args, please refer to the API documentation.
For more information on the ms-amp args, please refer to the Optimization Level documentation.
class accelerate.GradScalerKwargs
< source >( init_scale: float = 65536.0 growth_factor: float = 2.0 backoff_factor: float = 0.5 growth_interval: int = 2000 enabled: bool = True )
Use this object in your Accelerator to customize the behavior of mixed precision, specifically how the torch.cuda.amp.GradScaler used is created. Please refer to the documentation of this scaler for more information on each argument.
GradScaler is only available in PyTorch 1.5.0 and later versions.
class accelerate.InitProcessGroupKwargs
< source >( backend: Optional = 'nccl' init_method: Optional = None timeout: Optional = None )
Use this object in your Accelerator to customize the initialization of the distributed processes. Please refer to the documentation of this method for more information on each argument.
Note: If timeout is set to None, the default will be based upon how backend is set.
Internal mixin that implements a to_kwargs() method for a dataclass.
Returns a dictionary containing the attributes with values different from the default of this class.
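The "values different from the default" behaviour of to_kwargs() can be illustrated with a small standard-library sketch. ToKwargsMixin and ScalerOptions are hypothetical names invented for this illustration; they are not the Accelerate classes, and the real mixin may differ in detail.

```python
import copy
from dataclasses import dataclass


class ToKwargsMixin:
    """Hypothetical stand-in for the internal mixin described above."""

    def to_kwargs(self):
        # Build a fresh instance to learn the default value of each field.
        default = self.__class__()
        current = copy.deepcopy(self.__dict__)
        # Keep only the attributes whose value differs from the default.
        return {k: v for k, v in current.items() if getattr(default, k) != v}


@dataclass
class ScalerOptions(ToKwargsMixin):
    # Field names and defaults mirror the GradScalerKwargs signature above.
    init_scale: float = 65536.0
    growth_factor: float = 2.0
    growth_interval: int = 2000
    enabled: bool = True


changed = ScalerOptions(init_scale=1024.0, growth_interval=100).to_kwargs()
print(changed)
```

Only the overridden fields appear in the result, so the dictionary can be forwarded to the wrapped constructor without repeating its defaults.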
Plugins
These are plugins that can be passed to the Accelerator object. While they are defined elsewhere in the documentation, for convenience all of them are available to see here:
class accelerate.DeepSpeedPlugin
< source >( hf_ds_config: Any = None gradient_accumulation_steps: int = None gradient_clipping: float = None zero_stage: int = None is_train_batch_min: bool = True offload_optimizer_device: str = None offload_param_device: str = None offload_optimizer_nvme_path: str = None offload_param_nvme_path: str = None zero3_init_flag: bool = None zero3_save_16bit_model: bool = None transformer_moe_cls_names: str = None enable_msamp: bool = None msamp_opt_level: Optional = None )
Parameters
- hf_ds_config (Any, defaults to None) — Path to DeepSpeed config file or dict or an object of class accelerate.utils.deepspeed.HfDeepSpeedConfig.
- gradient_accumulation_steps (int, defaults to None) — Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the Accelerator directly.
- gradient_clipping (float, defaults to None) — Enable gradient clipping with value.
- zero_stage (int, defaults to None) — Possible options are 0, 1, 2, 3. Default will be taken from environment variable.
- is_train_batch_min (bool, defaults to True) — If both train & eval dataloaders are specified, this will decide the train_batch_size.
- offload_optimizer_device (str, defaults to None) — Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.
- offload_param_device (str, defaults to None) — Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.
- offload_optimizer_nvme_path (str, defaults to None) — Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
- offload_param_nvme_path (str, defaults to None) — Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
- zero3_init_flag (bool, defaults to None) — Flag to indicate whether to enable deepspeed.zero.Init for constructing massive models. Only applicable with ZeRO Stage-3.
- zero3_save_16bit_model (bool, defaults to None) — Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.
- transformer_moe_cls_names (str, defaults to None) — Comma-separated list of Transformers MoE layer class names (case-sensitive). For example, MixtralSparseMoeBlock, Qwen2MoeSparseMoeBlock, JetMoEAttention, JetMoEBlock, etc.
- enable_msamp (bool, defaults to None) — Flag to indicate whether to enable MS-AMP backend for FP8 training.
- msamp_opt_level (Optional[Literal["O1", "O2"]], defaults to None) — Optimization level for MS-AMP (defaults to 'O1'). Only applicable if enable_msamp is True. Should be one of ['O1' or 'O2'].
This plugin is used to integrate DeepSpeed.
deepspeed_config_process
< source >( prefix = '' mismatches = None config = None must_match = True **kwargs )
Process the DeepSpeed config with the values from the kwargs.
Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration
class accelerate.FullyShardedDataParallelPlugin
< source >( sharding_strategy: Union = None backward_prefetch: Union = None mixed_precision_policy: Union = None auto_wrap_policy: Union = None cpu_offload: Union = None ignored_modules: Optional = None state_dict_type: Union = None state_dict_config: Union = None optim_state_dict_config: Union = None limit_all_gathers: bool = True use_orig_params: bool = None param_init_fn: Optional = None sync_module_states: bool = None forward_prefetch: bool = None activation_checkpointing: bool = None cpu_ram_efficient_loading: bool = None transformer_cls_names_to_wrap: Optional = None min_num_params: Optional = None )
Parameters
- sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD') — Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
- backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH') — Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
- mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]], defaults to None) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the following keys: param_dtype, reduce_dtype, and buffer_dtype.
- auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP) — A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for a direction on what it should look like.
- cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload], defaults to False) — Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload.
- ignored_modules (Optional[Iterable[torch.nn.Module]], defaults to None) — A list of modules to ignore when wrapping with FSDP.
- state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT') — State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
- state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None) — State dict config to use. Is determined based on the state_dict_type if not passed in.
- optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None) — Optim state dict config to use. Is determined based on the state_dict_type if not passed in.
- limit_all_gathers (bool, defaults to True) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
- use_orig_params (bool, defaults to False) — Whether to use the original parameters for the optimizer.
- param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None) — A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. By default is a lambda which calls to_empty on the module.
- sync_module_states (bool, defaults to False) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False unless cpu_ram_efficient_loading is True, in which case it is forcibly enabled.
- forward_prefetch (bool, defaults to False) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
- activation_checkpointing (bool, defaults to False) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
- cpu_ram_efficient_loading (bool, defaults to None) — If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, sync_module_states needs to be True.
- transformer_cls_names_to_wrap (Optional[List[str]], defaults to None) — A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
- min_num_params (Optional[int], defaults to None) — The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.
This plugin is used to enable fully sharded data parallelism.
Given model, creates an auto_wrap_policy based on the passed in policy and if we can use the transformer_cls_to_wrap
Sets the mixed precision policy for FSDP
Set the state dict config based on the StateDictType.
class accelerate.utils.GradientAccumulationPlugin
< source >( num_steps: int = None adjust_scheduler: bool = True sync_with_dataloader: bool = True sync_each_batch: bool = False )
Parameters
- num_steps (int) — The number of steps to accumulate gradients for.
- adjust_scheduler (bool, optional, defaults to True) — Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be True if the used scheduler was not adjusted for gradient accumulation.
- sync_with_dataloader (bool, optional, defaults to True) — Whether to synchronize setting the gradients when at the end of the dataloader.
- sync_each_batch (bool, optional) — Whether to synchronize setting the gradients at each data batch. Setting to True may reduce memory requirements when using gradient accumulation with distributed training, at the expense of speed.
A plugin to configure gradient accumulation behavior. You can only pass one of gradient_accumulation_plugin or gradient_accumulation_steps to Accelerator. Passing both raises an error.
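As a rough illustration of how num_steps and sync_with_dataloader interact, the sketch below lists the batch indices at which gradients would be synchronized. It is a simplified model written for this page, not Accelerate's implementation; the function name sync_batches is hypothetical.

```python
def sync_batches(num_batches, num_steps, sync_with_dataloader=True):
    """Return the 0-based batch indices at which gradients are synchronized."""
    synced = []
    for index in range(num_batches):
        # A full accumulation window has been completed.
        window_done = (index + 1) % num_steps == 0
        # Optionally force a sync on the last batch of the dataloader.
        at_end = sync_with_dataloader and index == num_batches - 1
        if window_done or at_end:
            synced.append(index)
    return synced


with_end_sync = sync_batches(num_batches=10, num_steps=4)
without_end_sync = sync_batches(num_batches=10, num_steps=4, sync_with_dataloader=False)
print(with_end_sync, without_end_sync)
```

With 10 batches and num_steps=4, the last two batches do not fill a whole window; under this model, sync_with_dataloader=True still triggers a synchronization on the final batch so those gradients are not left pending.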
class accelerate.utils.MegatronLMPlugin
< source >( tp_degree: int = None pp_degree: int = None num_micro_batches: int = None gradient_clipping: float = None sequence_parallelism: bool = None recompute_activations: bool = None use_distributed_optimizer: bool = None pipeline_model_parallel_split_rank: int = None num_layers_per_virtual_pipeline_stage: int = None is_train_batch_min: str = True train_iters: int = None train_samples: int = None weight_decay_incr_style: str = 'constant' start_weight_decay: float = None end_weight_decay: float = None lr_decay_style: str = 'linear' lr_decay_iters: int = None lr_decay_samples: int = None lr_warmup_iters: int = None lr_warmup_samples: int = None lr_warmup_fraction: float = None min_lr: float = 0 consumed_samples: List = None no_wd_decay_cond: Optional = None scale_lr_cond: Optional = None lr_mult: float = 1.0 megatron_dataset_flag: bool = False seq_length: int = None encoder_seq_length: int = None decoder_seq_length: int = None tensorboard_dir: str = None set_all_logging_options: bool = False eval_iters: int = 100 eval_interval: int = 1000 return_logits: bool = False custom_train_step_class: Optional = None custom_train_step_kwargs: Optional = None custom_model_provider_function: Optional = None custom_prepare_model_function: Optional = None custom_megatron_datasets_provider_function: Optional = None custom_get_batch_function: Optional = None custom_loss_function: Optional = None other_megatron_args: Optional = None )
Parameters
- tp_degree (int, defaults to None) — Tensor parallelism degree.
- pp_degree (int, defaults to None) — Pipeline parallelism degree.
- num_micro_batches (int, defaults to None) — Number of micro-batches.
- gradient_clipping (float, defaults to None) — Gradient clipping value based on global L2 Norm (0 to disable).
- sequence_parallelism (bool, defaults to None) — Enable sequence parallelism.
- recompute_activations (bool, defaults to None) — Enable selective activation recomputation.
- use_distributed_optimizer (bool, defaults to None) — Enable distributed optimizer.
- pipeline_model_parallel_split_rank (int, defaults to None) — Rank where encoder and decoder should be split.
- num_layers_per_virtual_pipeline_stage (int, defaults to None) — Number of layers per virtual pipeline stage.
- is_train_batch_min (str, defaults to True) — If both train & eval dataloaders are specified, this will decide the micro_batch_size.
- train_iters (int, defaults to None) — Total number of iterations to train over all training runs. Note that either train-iters or train-samples should be provided when using MegatronLMDummyScheduler.
- train_samples (int, defaults to None) — Total number of samples to train over all training runs. Note that either train-iters or train-samples should be provided when using MegatronLMDummyScheduler.
- weight_decay_incr_style (str, defaults to 'constant') — Weight decay increment function. choices=["constant", "linear", "cosine"].
- start_weight_decay (float, defaults to None) — Initial weight decay coefficient for L2 regularization.
- end_weight_decay (float, defaults to None) — End of run weight decay coefficient for L2 regularization.
- lr_decay_style (str, defaults to 'linear') — Learning rate decay function. choices=['constant', 'linear', 'cosine'].
- lr_decay_iters (int, defaults to None) — Number of iterations for learning rate decay. If None defaults to train_iters.
- lr_decay_samples (int, defaults to None) — Number of samples for learning rate decay. If None defaults to train_samples.
- lr_warmup_iters (int, defaults to None) — Number of iterations to linearly warmup learning rate over.
- lr_warmup_samples (int, defaults to None) — Number of samples to linearly warmup learning rate over.
- lr_warmup_fraction (float, defaults to None) — Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
- min_lr (float, defaults to 0) — Minimum value for learning rate. The scheduler clips values below this threshold.
- consumed_samples (List, defaults to None) — Number of samples consumed in the same order as the dataloaders to accelerator.prepare call.
- no_wd_decay_cond (Optional, defaults to None) — Condition to disable weight decay.
- scale_lr_cond (Optional, defaults to None) — Condition to scale learning rate.
- lr_mult (float, defaults to 1.0) — Learning rate multiplier.
- megatron_dataset_flag (bool, defaults to False) — Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.
- seq_length (int, defaults to None) — Maximum sequence length to process.
- encoder_seq_length (int, defaults to None) — Maximum sequence length to process for the encoder.
- decoder_seq_length (int, defaults to None) — Maximum sequence length to process for the decoder.
- tensorboard_dir (str, defaults to None) — Path to save tensorboard logs.
- set_all_logging_options (bool, defaults to False) — Whether to set all logging options.
- eval_iters (int, defaults to 100) — Number of iterations to run for evaluation on the validation/test set.
- eval_interval (int, defaults to 1000) — Interval between running evaluation on validation set.
- return_logits (bool, defaults to False) — Whether to return logits from the model.
- custom_train_step_class (Optional, defaults to None) — Custom train step class.
- custom_train_step_kwargs (Optional, defaults to None) — Custom train step kwargs.
- custom_model_provider_function (Optional, defaults to None) — Custom model provider function.
- custom_prepare_model_function (Optional, defaults to None) — Custom prepare model function.
- custom_megatron_datasets_provider_function (Optional, defaults to None) — Custom megatron train_valid_test datasets provider function.
- custom_get_batch_function (Optional, defaults to None) — Custom get batch function.
- custom_loss_function (Optional, defaults to None) — Custom loss function.
- other_megatron_args (Optional, defaults to None) — Other Megatron-LM arguments. Please refer to Megatron-LM.
Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective activation recomputation and optimized fused kernels.
class accelerate.utils.TorchDynamoPlugin
< source >( backend: DynamoBackend = None mode: str = None fullgraph: bool = None dynamic: bool = None options: Any = None disable: bool = False )
Parameters
- backend (DynamoBackend, defaults to None) — A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details.
- mode (str, defaults to None) — Possible options are 'default', 'reduce-overhead' or 'max-autotune'.
- fullgraph (bool, defaults to None) — Whether it is ok to break model into several subgraphs.
- dynamic (bool, defaults to None) — Whether to use dynamic shape for tracing.
- options (Any, defaults to None) — A dictionary of options to pass to the backend.
- disable (bool, defaults to False) — Turn torch.compile() into a no-op for testing.
This plugin is used to compile a model with PyTorch 2.0
Configurations
These are classes which can be configured and passed through to the appropriate integration
class accelerate.utils.BnbQuantizationConfig
< source >( load_in_8bit: bool = False llm_int8_threshold: float = 6.0 load_in_4bit: bool = False bnb_4bit_quant_type: str = 'fp4' bnb_4bit_use_double_quant: bool = False bnb_4bit_compute_dtype: bool = 'fp16' torch_dtype: dtype = None skip_modules: List = None keep_in_fp32_modules: List = None )
Parameters
- load_in_8bit (bool, defaults to False) — Enable 8bit quantization.
- llm_int8_threshold (float, defaults to 6.0) — Value of the outlier threshold. Only relevant when load_in_8bit=True.
- load_in_4bit (bool, defaults to False) — Enable 4bit quantization.
- bnb_4bit_quant_type (str, defaults to fp4) — Set the quantization data type in the bnb.nn.Linear4Bit layers. Options are {'fp4','nf4'}.
- bnb_4bit_use_double_quant (bool, defaults to False) — Enable nested quantization where the quantization constants from the first quantization are quantized again.
- bnb_4bit_compute_dtype (bool, defaults to fp16) — This sets the computational type which might be different than the input type. For example, inputs might be fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}.
- torch_dtype (torch.dtype, defaults to None) — This sets the dtype of the remaining non quantized layers. The bitsandbytes library suggests setting the value to torch.float16 for 8 bit models and using the same dtype as the compute dtype for 4 bit models.
- skip_modules (List[str], defaults to None) — An explicit list of the modules that we don’t quantize. The dtype of these modules will be torch_dtype.
- keep_in_fp32_modules (List, defaults to None) — An explicit list of the modules that we don’t quantize. We keep them in torch.float32.
A plugin to enable BitsAndBytes 4bit and 8bit quantization
class accelerate.DataLoaderConfiguration
< source >( split_batches: bool = False dispatch_batches: bool = None even_batches: bool = True use_seedable_sampler: bool = False non_blocking: bool = False use_stateful_dataloader: bool = False )
Parameters
- split_batches (bool, defaults to False) — Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If True, the actual batch size used will be the same on any kind of distributed processes, but it must be a round multiple of the num_processes you are using. If False, the actual batch size used will be the one set in your script multiplied by the number of processes.
- dispatch_batches (bool, defaults to None) — If set to True, the dataloader prepared by the Accelerator is only iterated through on the main process and then the batches are split and broadcast to each process. Will default to True for a DataLoader whose underlying dataset is an IterableDataset, False otherwise.
- even_batches (bool, defaults to True) — If set to True, in cases where the total batch size across all processes does not exactly divide the dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among all workers.
- use_seedable_sampler (bool, defaults to False) — Whether or not to use a fully seedable random sampler (data_loader.SeedableRandomSampler). Ensures training results are fully reproducible using a different sampling technique. While seed-to-seed results may differ, on average the differences are negligible when using multiple different seeds to compare. Should also be run with set_seed() for the best results.
- non_blocking (bool, defaults to False) — If set to True, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device transfers, allowing for better overlap between dataloader communication and computation. It is recommended that the prepared dataloader has pin_memory set to True to work properly.
- use_stateful_dataloader (bool, defaults to False) — If set to True, the dataloader prepared by the Accelerator will be backed by torchdata.StatefulDataLoader. This requires torchdata version 0.8.0 or higher that supports StatefulDataLoader to be installed.
Configuration for dataloader-related items when calling accelerator.prepare.
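The batch-size arithmetic described for split_batches can be written out as a short sketch. The helper name batch_sizes is hypothetical and only restates the rule in the parameter description; it does not call into Accelerate.

```python
def batch_sizes(script_batch_size, num_processes, split_batches=False):
    """Return (per_process_batch_size, total_batch_size) for one step."""
    if split_batches:
        # The script's batch is divided across processes, so it must divide evenly.
        if script_batch_size % num_processes != 0:
            raise ValueError("batch size must be a round multiple of num_processes")
        return script_batch_size // num_processes, script_batch_size
    # Otherwise every process consumes a full batch of the size set in the script.
    return script_batch_size, script_batch_size * num_processes


split = batch_sizes(16, num_processes=4, split_batches=True)
unsplit = batch_sizes(16, num_processes=4, split_batches=False)
print(split, unsplit)
```

With split_batches=True the total batch size stays the same regardless of how many processes are used, which keeps hyperparameters comparable across hardware setups.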
class accelerate.utils.ProjectConfiguration
< source >( project_dir: str = None logging_dir: str = None automatic_checkpoint_naming: bool = False total_limit: int = None iteration: int = 0 save_on_each_node: bool = False )
Parameters
- project_dir (str, defaults to None) — A path to a directory for storing data.
- logging_dir (str, defaults to None) — A path to a directory for storing logs of locally-compatible loggers. If None, defaults to project_dir.
- automatic_checkpoint_naming (bool, defaults to False) — Whether saved states should be automatically iteratively named.
- total_limit (int, defaults to None) — The maximum number of total saved states to keep.
- iteration (int, defaults to 0) — The current save iteration.
- save_on_each_node (bool, defaults to False) — When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one.
Configuration for the Accelerator object based on inner-project needs.
Sets self.project_dir and self.logging_dir to the appropriate values.
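To make the interplay of automatic_checkpoint_naming, iteration and total_limit more concrete, here is a minimal sketch of iterative naming with a cap on retained states. The checkpoint_<iteration> naming pattern and the function next_checkpoints are assumptions made for illustration; consult the library source for the exact behaviour.

```python
def next_checkpoints(existing, iteration, total_limit=None):
    """Return the checkpoint names kept after saving the state for `iteration`."""
    kept = list(existing)  # assumed to be ordered oldest first
    if total_limit is not None:
        # Drop the oldest states so the new one fits within the limit.
        while kept and len(kept) + 1 > total_limit:
            kept.pop(0)
    # Assumed naming scheme: one folder per save iteration.
    kept.append(f"checkpoint_{iteration}")
    return kept


limited = next_checkpoints(["checkpoint_0", "checkpoint_1"], iteration=2, total_limit=2)
unlimited = next_checkpoints(["checkpoint_0", "checkpoint_1"], iteration=2)
print(limited, unlimited)
```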
Environmental Variables
These are environmental variables that can be enabled for different use cases
ACCELERATE_DEBUG_MODE (str): Whether to run accelerate in debug mode. More info available here.
Data Manipulation and Operations
These include data operations that mimic the same torch ops but can be used on distributed processes.
accelerate.utils.broadcast
< source >( tensor from_process: int = 0 )
Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
accelerate.utils.broadcast_object_list
< source >( object_list from_process: int = 0 )
Broadcast a list of picklable objects from one process to the others.
accelerate.utils.concatenate
< source >( data dim = 0 )
Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
accelerate.utils.convert_to_fp32
< source >( tensor )
Recursively converts the elements of a nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
accelerate.utils.gather
< source >( tensor )
Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
accelerate.utils.gather_object
< source >( object: Any )
Recursively gather object in a nested list/tuple/dictionary of objects from all devices.
accelerate.utils.listify
< source >( data )
Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.
accelerate.utils.pad_across_processes
< source >( tensor dim = 0 pad_index = 0 pad_first = False )
Parameters
- tensor (nested list/tuple/dictionary of torch.Tensor) — The data to gather.
- dim (int, optional, defaults to 0) — The dimension on which to pad.
- pad_index (int, optional, defaults to 0) — The value with which to pad.
- pad_first (bool, optional, defaults to False) — Whether to pad at the beginning or the end.
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they can safely be gathered.
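To illustrate what pad_index and pad_first mean, here is a minimal pure-Python sketch that pads plain lists standing in for the tensors held by each process. The helper pad_to_common_length is hypothetical and only models the padding rule along one dimension; it performs no distributed communication.

```python
from typing import List


def pad_to_common_length(
    per_process: List[List[int]], pad_index: int = 0, pad_first: bool = False
) -> List[List[int]]:
    """Pad every list to the longest length found across the simulated processes."""
    max_len = max(len(values) for values in per_process)
    padded = []
    for values in per_process:
        padding = [pad_index] * (max_len - len(values))
        # pad_first=True puts the padding before the data, otherwise after it.
        padded.append(padding + values if pad_first else values + padding)
    return padded


# Three simulated processes holding sequences of different lengths.
batches = [[1, 2, 3], [4], [5, 6]]
padded_end = pad_to_common_length(batches)
padded_start = pad_to_common_length(batches, pad_index=-1, pad_first=True)
```

Once every process holds data of the same size, a gather can concatenate the pieces without shape mismatches.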
accelerate.utils.recursively_apply
< source >( func data *args test_type = is_torch_tensor error_on_other_type = False **kwargs )
Parameters
- func (callable) — The function to recursively apply.
- data (nested list/tuple/dictionary of main_type) — The data on which to apply func.
- *args — Positional arguments that will be passed to func when applied on the unpacked data.
- main_type (type, optional, defaults to torch.Tensor) — The base type of the objects to which func is applied.
- error_on_other_type (bool, optional, defaults to False) — Whether to raise an error if, after unpacking data, we encounter an object that is not of type main_type. If False, the function will leave objects of types other than main_type unchanged.
- **kwargs (additional keyword arguments, optional) — Keyword arguments that will be passed to func when applied on the unpacked data.
Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
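The recursion pattern can be sketched in pure Python. The function apply_recursively below is a hypothetical simplification: it uses int as the base type instead of torch.Tensor and does not handle namedtuples or other container types.

```python
from typing import Any, Callable


def apply_recursively(
    func: Callable[[Any], Any],
    data: Any,
    main_type: type = int,
    error_on_other_type: bool = False,
) -> Any:
    """Apply func to every main_type leaf, preserving the container structure."""
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (list or tuple) from the mapped items.
        return type(data)(
            apply_recursively(func, item, main_type, error_on_other_type)
            for item in data
        )
    if isinstance(data, dict):
        return {
            key: apply_recursively(func, value, main_type, error_on_other_type)
            for key, value in data.items()
        }
    if isinstance(data, main_type):
        return func(data)
    if error_on_other_type:
        raise TypeError(f"Unsupported leaf type: {type(data).__name__}")
    # Leaves of any other type are returned unchanged.
    return data


nested = {"a": [1, 2], "b": (3, "skip")}
doubled = apply_recursively(lambda x: x * 2, nested)
```

The string leaf is left untouched because error_on_other_type is False; with True, the same call would raise a TypeError.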
accelerate.utils.reduce
< source >( tensor reduction = 'mean' scale = 1.0 )
Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes using the given reduction operation (the mean by default).
accelerate.utils.send_to_device
< source >( tensor device non_blocking = False skip_keys = None )
Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
accelerate.utils.slice_tensors
< source >( data tensor_slice process_index = None num_processes = None )
Recursively takes a slice in a nested list/tuple/dictionary of tensors.
Environment Checks
These functionalities check the state of the current working environment including information about the operating system itself, what it can support, and if particular dependencies are installed.
Checks if bf16 is supported, optionally ignoring the TPU.
Checks if ipex is installed.
Checks if MPS device is available. The minimum version required is 1.12.
Checks if torch_npu is installed and potentially if an NPU is in the environment.
accelerate.utils.is_torch_version
< source >( operation: str version: str )
Compares the current PyTorch version to a given reference with an operation.
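The idea of comparing a version against a reference with an operation given as a string can be sketched as follows. compare_versions_sketch and parse_release are hypothetical helpers: unlike the library function, the sketch takes the current version as an argument instead of reading the installed PyTorch version, and it only understands purely numeric releases (no pre-release tags such as rc1).

```python
import operator
from typing import Tuple

# Map the operation string onto the matching comparison function.
_OPERATIONS = {
    ">": operator.gt,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,
    "<=": operator.le,
    "<": operator.lt,
}


def parse_release(version: str) -> Tuple[int, ...]:
    """Turn '2.4.0+cu121' into (2, 4, 0), dropping any local '+...' suffix."""
    release = version.split("+")[0]
    return tuple(int(part) for part in release.split("."))


def compare_versions_sketch(current: str, operation: str, reference: str) -> bool:
    if operation not in _OPERATIONS:
        raise ValueError(f"Unknown operation: {operation!r}")
    left, right = parse_release(current), parse_release(reference)
    # Pad with zeros so that '2.4' and '2.4.0' compare as equal.
    width = max(len(left), len(right))
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return _OPERATIONS[operation](left, right)
```

Comparing parsed integer tuples rather than raw strings avoids the classic pitfall where "2.10" sorts before "2.9".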
Checks if torch_xla is available. To train a native PyTorch job in an environment with torch_xla installed, set USE_TORCH_XLA to false.
Checks if XPU acceleration is available either via intel_extension_for_pytorch or via stock PyTorch (>=2.4) and potentially if an XPU is in the environment.
Environment Manipulation
A context manager that will add each keyword argument passed to os.environ and remove them when exiting.
Will convert the values in kwargs to strings and upper-case all the keys.
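The behavior described above (upper-cased keys, stringified values, cleanup on exit) can be sketched with contextlib. patch_environment_sketch is a hypothetical name, and restoring a value that already existed before entering the context is an assumption of this sketch rather than something stated in the description.

```python
import os
from contextlib import contextmanager


@contextmanager
def patch_environment_sketch(**kwargs):
    """Temporarily set environment variables from keyword arguments."""
    previous = {}
    for key, value in kwargs.items():
        name = key.upper()  # keys are upper-cased
        previous[name] = os.environ.get(name)
        os.environ[name] = str(value)  # values are converted to strings
    try:
        yield
    finally:
        for name, old_value in previous.items():
            if old_value is None:
                # The variable did not exist before: remove it again.
                os.environ.pop(name, None)
            else:
                # Assumption: put back whatever value was there before.
                os.environ[name] = old_value


# Make sure the demo variable is not already set, then patch it temporarily.
os.environ.pop("SKETCH_DEMO_FLAG", None)
with patch_environment_sketch(sketch_demo_flag=1):
    inside = os.environ["SKETCH_DEMO_FLAG"]
after = os.environ.get("SKETCH_DEMO_FLAG")
```

The try/finally ensures the cleanup also runs when the body of the with block raises.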
A context manager that will temporarily clear environment variables.
When this context exits, the previous environment variables are restored.
accelerate.commands.config.default.write_basic_config
< source >( mixed_precision = 'no' save_location: str = '/github/home/.cache/huggingface/accelerate/default_config.yaml' use_xpu: bool = False )
Parameters
- mixed_precision (str, optional, defaults to “no”) — Mixed Precision to use. Should be one of “no”, “fp16”, or “bf16”.
- save_location (str, optional, defaults to default_json_config_file) — Optional custom save location. Should be passed to --config_file when using accelerate launch. Default location is inside the huggingface cache folder (~/.cache/huggingface) but can be overridden by setting the HF_HOME environmental variable, followed by accelerate/default_config.yaml.
- use_xpu (bool, optional, defaults to False) — Whether to use XPU if available.
Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also set CPU if it is a CPU-only machine.
When setting up 🤗 Accelerate for the first time, rather than running accelerate config, write_basic_config() can be used as an alternative for quick configuration.
Assigns the current process to a specific NUMA node. Ideally most efficient when having at least 2 cpus per node.
This result is cached between calls. If you want to override it, please use accelerate.utils.environment.override_numa_affinity.
accelerate.utils.environment.override_numa_affinity
< source >( local_process_index: int verbose: Optional = None )
Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the affinity to set; ideally you should use utils.environment.set_numa_affinity instead.
Memory
accelerate.find_executable_batch_size
< source >( function: callable = None starting_batch_size: int = 128 )
A basic decorator that will try to execute function. If it fails from exceptions related to out-of-memory or CUDNN, the batch size is cut in half and passed to function.
function must take in a batch_size parameter as its first argument.
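The retry-and-halve loop can be sketched without any GPU. In the sketch below, halve_on_oom and FakeOutOfMemory are hypothetical names: the custom exception stands in for the out-of-memory and CUDNN errors the real decorator reacts to, and the loop is a simplified model of the behavior, not the library's implementation.

```python
import functools


class FakeOutOfMemory(RuntimeError):
    """Stand-in for an out-of-memory error (assumption for this sketch)."""


def halve_on_oom(starting_batch_size: int = 128):
    def decorator(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            batch_size = starting_batch_size
            while batch_size > 0:
                try:
                    # batch_size is always passed as the first argument.
                    return function(batch_size, *args, **kwargs)
                except FakeOutOfMemory:
                    batch_size //= 2  # cut the batch size in half and retry
            raise RuntimeError("No executable batch size found.")

        return wrapper

    return decorator


attempts = []


@halve_on_oom(starting_batch_size=128)
def train(batch_size):
    attempts.append(batch_size)
    if batch_size > 16:  # pretend anything above 16 does not fit in memory
        raise FakeOutOfMemory
    return batch_size


chosen = train()  # called without batch_size; the decorator supplies it
```

Note that the decorated function is called without the batch_size argument, since the decorator injects it on every attempt.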
Modeling
These utilities relate to interacting with PyTorch models
Computes the total size of the model and its largest layer
accelerate.utils.compute_module_sizes
< source >( model: Module dtype: Union = None special_dtypes: Optional = None buffers_only: bool = False )
Compute the size of each submodule of a given model.
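One way to picture the result is a dictionary in which every dotted prefix of a parameter name accumulates the bytes of the parameters below it. The sketch below works on a plain mapping of parameter names to byte sizes; module_sizes_sketch is a hypothetical helper, and using the empty string as the key for the whole model is an assumption made for illustration.

```python
from collections import defaultdict
from typing import Dict


def module_sizes_sketch(param_bytes: Dict[str, int]) -> Dict[str, int]:
    """Accumulate parameter sizes onto every enclosing (dotted) module name."""
    sizes: Dict[str, int] = defaultdict(int)
    for name, size in param_bytes.items():
        parts = name.split(".")
        # Every prefix counts: "" (whole model), "encoder", "encoder.layer0", ...
        for depth in range(len(parts) + 1):
            sizes[".".join(parts[:depth])] += size
    return dict(sizes)


sizes = module_sizes_sketch(
    {
        "encoder.layer0.weight": 400,
        "encoder.layer0.bias": 40,
        "head.weight": 100,
    }
)
```

Here sizes["encoder"] is the sum of both encoder parameters, and sizes[""] is the total for the model.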
accelerate.utils.extract_model_from_parallel
< source >( model keep_fp32_wrapper: bool = True recursive: bool = False ) → torch.nn.Module
Parameters
- model (torch.nn.Module) — The model to extract.
- keep_fp32_wrapper (bool, optional) — Whether to keep the mixed precision hooks on the model (set to False to remove them).
- recursive (bool, optional, defaults to False) — Whether to recursively extract all cases of module.module from model as well as unwrap child sublayers recursively, not just the top-level distributed containers.
Returns
torch.nn.Module
The extracted model.
Extract a model from its distributed containers.
accelerate.utils.get_balanced_memory
< source >( model: Module max_memory: Optional = None no_split_module_classes: Optional = None dtype: Union = None special_dtypes: Optional = None low_zero: bool = False )
Parameters
- model (torch.nn.Module) — The model to analyze.
- max_memory (Dict, optional) — A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset. Example: max_memory={0: "1GB"}.
- no_split_module_classes (List[str], optional) — A list of layer class names that should never be split across devices (for instance any layer that has a residual connection).
- dtype (str or torch.dtype, optional) — If provided, the weights will be converted to that type when loaded.
- special_dtypes (Dict[str, Union[str, torch.device]], optional) — If provided, special dtypes to consider for some specific weights (will override dtype used as default for all weights).
- low_zero (bool, optional) — Minimizes the number of weights on GPU 0, which is convenient when it’s used for other operations (like the Transformers generate function).
Compute a max_memory dictionary for infer_auto_device_map() that will balance the use of each available GPU.
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the meta device (as it would if initialized within the init_empty_weights context manager).
accelerate.utils.get_max_layer_size
< source >( modules: List module_sizes: Dict no_split_module_classes: List ) → Tuple[int, List[str]]
Parameters
- modules (List[Tuple[str, torch.nn.Module]]) — The list of named modules where we want to determine the maximum layer size.
- module_sizes (Dict[str, int]) — A dictionary mapping each layer name to its size (as generated by compute_module_sizes).
- no_split_module_classes (List[str]) — A list of class names for layers we don’t want to be split.
Returns
Tuple[int, List[str]]
The maximum size of a layer with the list of layer names realizing that maximum size.
Utility function that will scan a list of named modules and return the maximum size used by one full layer. The definition of a layer being:
- a module with no direct children (just parameters and buffers)
- a module whose class name is in the list no_split_module_classes
accelerate.infer_auto_device_map
< source >( model: Module max_memory: Optional = None no_split_module_classes: Optional = None dtype: Union = None special_dtypes: Optional = None verbose: bool = False clean_result: bool = True offload_buffers: bool = False )
Parameters
- model (torch.nn.Module) — The model to analyze.
- max_memory (Dict, optional) — A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset. Example: max_memory={0: "1GB"}.
- no_split_module_classes (List[str], optional) — A list of layer class names that should never be split across devices (for instance any layer that has a residual connection).
- dtype (str or torch.dtype, optional) — If provided, the weights will be converted to that type when loaded.
- special_dtypes (Dict[str, Union[str, torch.device]], optional) — If provided, special dtypes to consider for some specific weights (will override dtype used as default for all weights).
- verbose (bool, optional, defaults to False) — Whether or not to provide debugging statements as the function builds the device_map.
- clean_result (bool, optional, defaults to True) — Clean the resulting device_map by grouping all submodules that go on the same device together.
- offload_buffers (bool, optional, defaults to False) — In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters.
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, such that:
- we don’t exceed the memory available on any of the GPUs.
- if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that has the largest size.
- if offload to the CPU is needed, we don’t exceed the RAM available on the CPU.
- if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk that has the largest size.
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the meta device (as it would if initialized within the init_empty_weights context manager).
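The priority order (GPUs first, then CPU, then disk) can be illustrated with a deliberately simplified greedy allocator. greedy_device_map is a hypothetical helper: it works on a list of layer names with sizes in arbitrary units, and it ignores the "room left for the largest offloaded layer" constraints, tied parameters, and no_split_module_classes that the real function takes into account.

```python
from typing import Dict, List, Tuple, Union

Device = Union[int, str]


def greedy_device_map(
    layers: List[Tuple[str, int]], max_memory: Dict[Device, int]
) -> Dict[str, Device]:
    """Fill devices in priority order; whatever no longer fits goes to disk."""
    devices = list(max_memory)  # insertion order is the priority, e.g. [0, "cpu"]
    device_map: Dict[str, Device] = {}
    index, used = 0, 0
    for name, size in layers:
        # Move on to the next device while the layer does not fit on the current one.
        while index < len(devices) and used + size > max_memory[devices[index]]:
            index += 1
            used = 0
        if index == len(devices):
            device_map[name] = "disk"
        else:
            device_map[name] = devices[index]
            used += size
    return device_map


layers = [("embed", 4), ("block0", 3), ("block1", 3), ("block2", 3), ("head", 4)]
device_map = greedy_device_map(layers, {0: 8, "cpu": 6})
```

In this toy run, GPU 0 takes the first two layers, the CPU the next two, and the last layer is offloaded to disk.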
accelerate.load_checkpoint_in_model
< source >( model: Module checkpoint: Union device_map: Optional = None offload_folder: Union = None dtype: Union = None offload_state_dict: bool = False offload_buffers: bool = False keep_in_fp32_modules: List = None offload_8bit_bnb: bool = False strict: bool = False )
Parameters
- model (torch.nn.Module) — The model in which we want to load a checkpoint.
- checkpoint (str or os.PathLike) — The checkpoint to load. It can be:
  - a path to a file containing a whole model state dict
  - a path to a .json file containing the index to a sharded checkpoint
  - a path to a folder containing a unique .index.json file and the shards of a checkpoint.
  - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
- device_map (Dict[str, Union[int, str, torch.device]], optional) — A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device.
- offload_folder (str or os.PathLike, optional) — If the device_map contains any value "disk", the folder where we will offload weights.
- dtype (str or torch.dtype, optional) — If provided, the weights will be converted to that type when loaded.
- offload_state_dict (bool, optional, defaults to False) — If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit.
- offload_buffers (bool, optional, defaults to False) — Whether or not to include the buffers in the weights offloaded to disk.
- keep_in_fp32_modules (List[str], optional) — A list of the modules that we keep in torch.float32 dtype.
- offload_8bit_bnb (bool, optional) — Whether or not to enable offload of 8-bit modules on cpu/disk.
- strict (bool, optional, defaults to False) — Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model’s state_dict.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded.
Once loaded across devices, you still need to call dispatch_model() on your model to make it able to run. To group the checkpoint loading and dispatch in one single call, use load_checkpoint_and_dispatch().
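For the sharded case, the index file is what ties parameter names to shard files. The sketch below assumes the index is a JSON object with a "weight_map" entry mapping each parameter name to the shard that stores it; the helper shard_files_from_index and the shard file names are hypothetical and only illustrate how the list of shards to read can be derived.

```python
import json
import os
import tempfile
from typing import List


def shard_files_from_index(index_path: str) -> List[str]:
    """Return the distinct shard paths referenced by a checkpoint index file."""
    with open(index_path, encoding="utf-8") as handle:
        index = json.load(handle)
    folder = os.path.dirname(index_path)
    # Several parameters usually share a shard, so deduplicate the file names.
    shard_names = sorted(set(index["weight_map"].values()))
    return [os.path.join(folder, name) for name in shard_names]


# Build a tiny fake index in a temporary folder to exercise the helper.
with tempfile.TemporaryDirectory() as folder:
    index_path = os.path.join(folder, "pytorch_model.bin.index.json")
    weight_map = {
        "encoder.weight": "shard-00001.bin",
        "encoder.bias": "shard-00001.bin",
        "head.weight": "shard-00002.bin",
    }
    with open(index_path, "w", encoding="utf-8") as handle:
        json.dump({"weight_map": weight_map}, handle)
    shard_basenames = [
        os.path.basename(path) for path in shard_files_from_index(index_path)
    ]
```

Loading shard by shard is what keeps peak memory close to the size of the largest shard rather than the whole checkpoint.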
accelerate.utils.load_offloaded_weights
< source >( model index offload_folder )
Loads the weights from the offload folder into the model.
accelerate.utils.load_state_dict
< source >( checkpoint_file device_map = None )
Parameters
- checkpoint_file (str) — The path to the checkpoint to load.
- device_map (Dict[str, Union[int, str, torch.device]], optional) — A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device.
Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the weights can be fast-loaded directly on the GPU.
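The rule that a device_map entry for a module also covers all of its submodules can be sketched as a lookup that walks up the dotted name. resolve_device is a hypothetical helper; picking the most specific matching entry and treating an empty-string key as the whole model are assumptions of this sketch.

```python
from typing import Dict, Union

Device = Union[int, str]


def resolve_device(tensor_name: str, device_map: Dict[str, Device]) -> Device:
    """Find the device of a parameter by walking up its dotted module path."""
    module_name = tensor_name
    while module_name:
        if module_name in device_map:
            return device_map[module_name]
        # Drop the last component: "encoder.layer0.weight" -> "encoder.layer0".
        module_name = module_name.rpartition(".")[0]
    if "" in device_map:  # assumption: "" stands for the whole model
        return device_map[""]
    raise KeyError(f"No device_map entry covers {tensor_name!r}")


device_map = {"encoder": 0, "encoder.layer1": 1, "head": "cpu"}
layer0_device = resolve_device("encoder.layer0.weight", device_map)
layer1_device = resolve_device("encoder.layer1.bias", device_map)
head_device = resolve_device("head.weight", device_map)
```

Because the walk starts from the full name, the entry for encoder.layer1 wins over the broader encoder entry.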
accelerate.utils.offload_state_dict
< source >( save_dir: Union state_dict: Dict )
Offload a state dict in a given folder.
accelerate.utils.retie_parameters
< source >( model tied_params )
Reties tied parameters in a given model if the link was broken (for instance when adding hooks).
accelerate.utils.set_module_tensor_to_device
< source >( module: Module tensor_name: str device: Union value: Optional = None dtype: Union = None fp16_statistics: Optional = None tied_params_map: Optional = None )
Parameters
- module (torch.nn.Module) — The module in which the tensor we want to move lives.
- tensor_name (str) — The full name of the parameter/buffer.
- device (int, str or torch.device) — The device on which to set the tensor.
- value (torch.Tensor, optional) — The value of the tensor (useful when going from the meta device to any other device).
- dtype (torch.dtype, optional) — If passed along the value of the parameter will be cast to this dtype. Otherwise, value will be cast to the dtype of the existing parameter in the model.
- fp16_statistics (torch.HalfTensor, optional) — The list of fp16 statistics to set on the module, used for 8 bit model serialization.
- tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], optional, defaults to None) — A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given execution device, this parameter is useful to reuse the first available pointer of a shared weight on the device for all others, instead of duplicating memory.
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing param.to(device) creates a new tensor not linked to the parameter, which is why we need this function).
Parallel
These include general utilities that should be used when working in parallel.
accelerate.utils.extract_model_from_parallel
< source >( model keep_fp32_wrapper: bool = True recursive: bool = False ) → torch.nn.Module
Parameters
- model (torch.nn.Module) — The model to extract.
- keep_fp32_wrapper (bool, optional) — Whether to keep the mixed precision hooks on the model (set to False to remove them).
- recursive (bool, optional, defaults to False) — Whether to recursively extract all cases of module.module from model as well as unwrap child sublayers recursively, not just the top-level distributed containers.
Returns
torch.nn.Module
The extracted model.
Extract a model from its distributed containers.
accelerate.utils.save
< source >( obj f save_on_each_node: bool = False safe_serialization: bool = False )
Save the data to disk. Use in place of torch.save().
Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
Make sure all processes will reach this instruction, otherwise one of your processes will hang forever.
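The blocking-point semantics can be demonstrated with threads standing in for processes. This is only an analogy built on threading.Barrier from the standard library; run_with_barrier is a hypothetical helper and no distributed backend is involved.

```python
import threading
from typing import List, Tuple


def run_with_barrier(num_workers: int = 3) -> List[Tuple[str, int]]:
    """Run workers that all wait at a barrier and record the order of events."""
    barrier = threading.Barrier(num_workers)
    events: List[Tuple[str, int]] = []

    def worker(rank: int) -> None:
        events.append(("arrived", rank))
        # Blocks until all num_workers threads have called wait().
        barrier.wait()
        events.append(("passed", rank))

    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(num_workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return events


events = run_with_barrier()
```

Every "arrived" event is recorded before any "passed" event. If one worker never reached the barrier, the others would wait indefinitely, which is the hang the warning above refers to.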
Random
These utilities relate to setting and synchronizing all the random states.
accelerate.utils.set_seed
< source >( seed: int device_specific: bool = False deterministic: bool = False )
Helper function for reproducible behavior to set the seed in random, numpy, torch.
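Why seeding gives reproducible behavior can be shown with Python's random module alone. seed_python_rng is a hypothetical helper; the handling of device_specific (offsetting the seed by the process index so each process draws different numbers) is an assumption based on the parameter name, and numpy and torch are left out to keep the sketch dependency-free.

```python
import random
from typing import List


def seed_python_rng(seed: int, process_index: int = 0, device_specific: bool = False) -> int:
    """Seed Python's RNG and return the effective seed that was used."""
    if device_specific:
        # Assumption: each process shifts the seed by its own index.
        seed += process_index
    random.seed(seed)
    return seed


def draw(count: int = 3) -> List[float]:
    return [random.random() for _ in range(count)]


seed_python_rng(42)
first = draw()
seed_python_rng(42)
second = draw()  # same seed, so the same numbers are drawn again

effective = seed_python_rng(42, process_index=1, device_specific=True)
shifted = draw()  # seeded with 43, so the sequence differs
```

Re-seeding with the same value replays the same sequence, which is what makes runs comparable.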
accelerate.utils.synchronize_rng_state
< source >( rng_type: Optional = None generator: Optional = None )
PyTorch XLA
These include utilities that are useful while using PyTorch with XLA.
accelerate.utils.install_xla
< source >( upgrade: bool = False )
Helper function to install appropriate xla wheels based on the torch version in Google Colaboratory.
Loading model weights
These include utilities that are useful to load checkpoints.
accelerate.load_checkpoint_in_model
< source >( model: Module checkpoint: Union device_map: Optional = None offload_folder: Union = None dtype: Union = None offload_state_dict: bool = False offload_buffers: bool = False keep_in_fp32_modules: List = None offload_8bit_bnb: bool = False strict: bool = False )
Parameters
- model (torch.nn.Module) — The model in which we want to load a checkpoint.
- checkpoint (str or os.PathLike) — The checkpoint to load. It can be:
  - a path to a file containing a whole model state dict
  - a path to a .json file containing the index to a sharded checkpoint
  - a path to a folder containing a unique .index.json file and the shards of a checkpoint.
  - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
- device_map (Dict[str, Union[int, str, torch.device]], optional) — A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device.
- offload_folder (str or os.PathLike, optional) — If the device_map contains any value "disk", the folder where we will offload weights.
- dtype (str or torch.dtype, optional) — If provided, the weights will be converted to that type when loaded.
- offload_state_dict (bool, optional, defaults to False) — If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit.
- offload_buffers (bool, optional, defaults to False) — Whether or not to include the buffers in the weights offloaded to disk.
- keep_in_fp32_modules (List[str], optional) — A list of the modules that we keep in torch.float32 dtype.
- offload_8bit_bnb (bool, optional) — Whether or not to enable offload of 8-bit modules on cpu/disk.
- strict (bool, optional, defaults to False) — Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model’s state_dict.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded.
Once loaded across devices, you still need to call dispatch_model() on your model to make it able to run. To group the checkpoint loading and dispatch in one single call, use load_checkpoint_and_dispatch().
Quantization
These include utilities that are useful to quantize a model.
accelerate.utils.load_and_quantize_model
< source >( model: Module bnb_quantization_config: BnbQuantizationConfig weights_location: Union = None device_map: Optional = None no_split_module_classes: Optional = None max_memory: Optional = None offload_folder: Union = None offload_state_dict: bool = False ) → torch.nn.Module
Parameters
- model (torch.nn.Module) — Input model. The model can be already loaded or on the meta device.
- bnb_quantization_config (BnbQuantizationConfig) — The bitsandbytes quantization parameters.
- weights_location (str or os.PathLike) — The location of the weights to load. It can be:
  - a path to a file containing a whole model state dict
  - a path to a .json file containing the index to a sharded checkpoint
  - a path to a folder containing a unique .index.json file and the shards of a checkpoint.
  - a path to a folder containing a unique pytorch_model.bin file.
- device_map (Dict[str, Union[int, str, torch.device]], optional) — A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device.
- no_split_module_classes (List[str], optional) — A list of layer class names that should never be split across devices (for instance any layer that has a residual connection).
- max_memory (Dict, optional) — A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset.
- offload_folder (str or os.PathLike, optional) — If the device_map contains any value "disk", the folder where we will offload weights.
- offload_state_dict (bool, optional, defaults to False) — If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit.
Returns
torch.nn.Module
The quantized model
This function will quantize the input model with the associated config passed in bnb_quantization_config
. If the
model is in the meta device, we will load and dispatch the weights according to the device_map
passed. If the
model is already loaded, we will quantize the model and put the model on the GPU,