Accelerate documentation

Fully Sharded Data Parallel utilities

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v1.3.0).

enable_fsdp_ram_efficient_loading

accelerate.utils.enable_fsdp_ram_efficient_loading

( )

Enables RAM efficient loading of Hugging Face models for FSDP in the environment.

disable_fsdp_ram_efficient_loading

accelerate.utils.disable_fsdp_ram_efficient_loading

( )

Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
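A minimal sketch of how the two toggles might be used together. It assumes they are called before the model is loaded, so that the setting is already in the environment at load time. The helper name `set_ram_efficient_loading` is hypothetical and not part of Accelerate.

```python
def set_ram_efficient_loading(enabled: bool) -> None:
    """Turn RAM efficient FSDP loading on or off for the current process.

    Hypothetical convenience wrapper: only the two imported functions
    belong to Accelerate.
    """
    # Imported lazily so this helper can be defined without Accelerate installed.
    from accelerate.utils import (
        disable_fsdp_ram_efficient_loading,
        enable_fsdp_ram_efficient_loading,
    )

    if enabled:
        enable_fsdp_ram_efficient_loading()
    else:
        disable_fsdp_ram_efficient_loading()
```

A typical call would be `set_ram_efficient_loading(True)` early in the training script, before the pretrained model is instantiated.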

merge_fsdp_weights

accelerate.utils.merge_fsdp_weights

( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )

Parameters

  • checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
  • output_path (str) — The path to save the merged checkpoint.
  • safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
  • remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.

Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this when SHARDED_STATE_DICT was used for the model. The merged weights are saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.
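As a rough illustration of the call and of where the merged file ends up, the sketch below wraps merge_fsdp_weights and returns the expected output file. The helper names `merged_weights_file` and `merge_sharded_checkpoint` are hypothetical; the file names follow the description above.

```python
import os


def merged_weights_file(output_path: str, safe_serialization: bool = True) -> str:
    """Return the file the merged weights are expected to be written to."""
    name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return os.path.join(output_path, name)


def merge_sharded_checkpoint(
    checkpoint_dir: str, output_path: str, safe_serialization: bool = True
) -> str:
    """Merge a SHARDED_STATE_DICT model checkpoint and return the merged file path."""
    # Imported lazily so this helper can be defined without Accelerate installed.
    from accelerate.utils import merge_fsdp_weights

    # remove_checkpoint_dir is left at its default (False), so the shards are kept.
    merge_fsdp_weights(checkpoint_dir, output_path, safe_serialization=safe_serialization)
    return merged_weights_file(output_path, safe_serialization)
```

Because merging is CPU-bound, it can be run as a separate step after training rather than inside the training loop.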

FullyShardedDataParallelPlugin

class accelerate.FullyShardedDataParallelPlugin

(
    sharding_strategy: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy')] = None,
    backward_prefetch: typing.Union[str, ForwardRef('torch.distributed.fsdp.BackwardPrefetch')] = None,
    mixed_precision_policy: typing.Union[dict, ForwardRef('torch.distributed.fsdp.MixedPrecision'), NoneType] = None,
    auto_wrap_policy: typing.Union[typing.Callable, typing.Literal['transformer_based_wrap', 'size_based_wrap', 'no_wrap'], NoneType] = None,
    cpu_offload: typing.Union[bool, ForwardRef('torch.distributed.fsdp.CPUOffload')] = None,
    ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None,
    state_dict_type: typing.Union[str, ForwardRef('torch.distributed.fsdp.StateDictType')] = None,
    state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedStateDictConfig'), NoneType] = None,
    optim_state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullOptimStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedOptimStateDictConfig'), NoneType] = None,
    limit_all_gathers: bool = True,
    use_orig_params: bool = None,
    param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None,
    sync_module_states: bool = None,
    forward_prefetch: bool = None,
    activation_checkpointing: bool = None,
    cpu_ram_efficient_loading: bool = None,
    transformer_cls_names_to_wrap: typing.Optional[typing.List[str]] = None,
    min_num_params: typing.Optional[int] = None
)

Parameters

  • sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD') — Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
  • backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH') — Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
  • mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]], defaults to None) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the following keys: param_dtype, reduce_dtype, and buffer_dtype.
  • auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP) -- A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for an example of what a callable policy should look like.
  • cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload], defaults to False) — Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload.
  • ignored_modules (Optional[Iterable[torch.nn.Module]], defaults to None) — A list of modules to ignore when wrapping with FSDP.
  • state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT') — State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
  • state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None) — State dict config to use. If not passed in, it is determined from state_dict_type.
  • optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None) — Optim state dict config to use. If not passed in, it is determined from state_dict_type.
  • limit_all_gathers (bool, defaults to True) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
  • use_orig_params (bool, defaults to False) — Whether to use the original parameters for the optimizer.
  • param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None) — A Callable[[torch.nn.Module], None] that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. By default, it is a lambda that calls to_empty on the module.
  • sync_module_states (bool, defaults to False) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False, but is forcibly enabled when cpu_ram_efficient_loading is True.
  • forward_prefetch (bool, defaults to False) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
  • activation_checkpointing (bool, defaults to False) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
  • cpu_ram_efficient_loading (bool, defaults to None) — If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, sync_module_states needs to be True.
  • transformer_cls_names_to_wrap (Optional[List[str]], defaults to None) — A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
  • min_num_params (Optional[int], defaults to None) — The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.

This plugin is used to enable fully sharded data parallelism.
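The sketch below shows one way the parameters above might be combined and handed to an Accelerator. The values are illustrative assumptions rather than recommendations, `MyDecoderLayer` is a hypothetical layer class name, and `build_accelerator` is a hypothetical helper. The script is assumed to be started with a distributed launcher.

```python
# Keyword arguments taken from the parameter list above.
# The values are illustrative assumptions, not recommendations.
FSDP_PLUGIN_KWARGS = {
    "sharding_strategy": "FULL_SHARD",
    "auto_wrap_policy": "transformer_based_wrap",
    # Hypothetical class name: replace with the block class of your own model.
    "transformer_cls_names_to_wrap": ["MyDecoderLayer"],
    "state_dict_type": "sharded_state_dict",
    "cpu_ram_efficient_loading": True,
    # The parameter list states this must be True with cpu_ram_efficient_loading.
    "sync_module_states": True,
    "activation_checkpointing": False,
}


def build_accelerator(plugin_kwargs=None):
    """Create an Accelerator configured with an FSDP plugin."""
    # Imported lazily so this helper can be defined without Accelerate installed.
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    kwargs = dict(FSDP_PLUGIN_KWARGS if plugin_kwargs is None else plugin_kwargs)
    plugin = FullyShardedDataParallelPlugin(**kwargs)
    return Accelerator(fsdp_plugin=plugin)
```

Keeping the settings in a plain dict makes it easy to log them or to override a single entry per experiment.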

set_auto_wrap_policy

( model )

Given model, creates an auto_wrap_policy based on the passed-in policy and on whether transformer_cls_to_wrap can be used.

set_mixed_precision

( mixed_precision, buffer_autocast = False, override = False )

Sets the mixed precision policy for FSDP.
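A small sketch of calling this method on an existing plugin. It assumes that a precision string such as "bf16" is accepted and that override=True replaces a policy already set on the plugin. Neither detail is stated above, so check both against the installed version. The helper name `apply_mixed_precision` is hypothetical.

```python
def apply_mixed_precision(plugin, precision="bf16", buffer_autocast=False):
    """Set the FSDP mixed precision policy on a plugin and return it.

    Assumption: `precision` is a string such as "bf16", and override=True
    replaces any policy the plugin already holds.
    """
    plugin.set_mixed_precision(precision, buffer_autocast=buffer_autocast, override=True)
    return plugin.mixed_precision_policy
```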

set_state_dict_type

( state_dict_type = None )

Set the state dict config based on the StateDictType.
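For illustration, the helper below switches a plugin to sharded state dicts, the format that merge_fsdp_weights expects for model checkpoints. Passing the type as a string is an assumption based on the state_dict_type parameter accepting a str. The helper name `use_sharded_state_dict` is hypothetical.

```python
def use_sharded_state_dict(plugin):
    """Switch a plugin to SHARDED_STATE_DICT and return the resulting settings.

    Assumption: set_state_dict_type accepts the type as a string, like the
    state_dict_type constructor parameter does.
    """
    plugin.set_state_dict_type("SHARDED_STATE_DICT")
    # The matching state dict config is derived by the plugin from the type.
    return plugin.state_dict_type, plugin.state_dict_config
```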
