Fully Sharded Data Parallel (FSDP) v2

Overview

When fine-tuning Large Language Models (LLMs) on TPUs, sharding the model across devices is essential for memory efficiency and training performance. The optimum.tpu.fsdp_v2 module provides utilities for implementing Fully Sharded Data Parallel training on top of SPMD (Single Program, Multiple Data), optimized specifically for TPU devices.

FSDP_v2 Features

  • Model weight sharding across TPU devices
  • Gradient checkpointing support
  • Automatic configuration for common model architectures
  • Integration with PyTorch/XLA’s SPMD implementation (see the sketch after this list)
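
Under the hood, fsdp_v2.use_fsdp_v2() builds on PyTorch/XLA’s SPMD mode. As a rough sketch (an assumption based on torch_xla >= 2.1, not a verbatim copy of the optimum-tpu internals), enabling SPMD directly looks like this:

import torch_xla.runtime as xr

# Switch PyTorch/XLA into SPMD mode; fsdp_v2.use_fsdp_v2() performs an
# equivalent step before any model is instantiated.
xr.use_spmd()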

Basic Usage

Here’s how to enable and configure FSDP_v2 for your training:

from optimum.tpu import fsdp_v2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Enable FSDP_v2 before instantiating the model
fsdp_v2.use_fsdp_v2()

# Load model and tokenizer (bfloat16 halves the memory footprint vs. float32)
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16
)

# Get the model-specific FSDP training configuration
fsdp_args = fsdp_v2.get_fsdp_training_args(model)
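
Before launching training, you may want to confirm how many TPU devices the weights will be sharded across. A minimal check using PyTorch/XLA’s runtime API (not part of optimum.tpu itself):

import torch_xla.runtime as xr

# FSDP_v2 shards the model parameters across all addressable TPU devices
num_devices = xr.global_runtime_device_count()
print(f"Sharding across {num_devices} TPU devices")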

Configuration Options

The get_fsdp_training_args() function returns a dictionary containing a model-specific configuration, for example:

{
    'fsdp': 'full_shard',
    'fsdp_config': {
        'transformer_layer_cls_to_wrap': ['LlamaDecoderLayer'],  # Model-specific
        'xla': True,
        'xla_fsdp_v2': True,
        'xla_fsdp_grad_ckpt': True
    }
}

Key Parameters

  • transformer_layer_cls_to_wrap: Specifies which model layer classes to wrap with FSDP (a sketch after this list shows one way to discover the name)
  • xla: Enables XLA optimization
  • xla_fsdp_v2: Activates FSDP_v2 implementation
  • xla_fsdp_grad_ckpt: Enables gradient checkpointing for memory efficiency
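
To discover the right class name for a given model, one illustrative approach (assuming a transformers causal LM whose decoder layers live under model.model.layers, as Llama’s do) is to inspect the first layer:

# Illustrative only: the attribute path below is model-dependent
decoder_layer_cls = type(model.model.layers[0]).__name__
print(decoder_layer_cls)  # e.g. "LlamaDecoderLayer"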

Advanced Usage

Custom Layer Wrapping

You can customize which layers get wrapped with FSDP:

custom_fsdp_args = fsdp_v2.get_fsdp_training_args(
    model,
    layer_cls_to_wrap=['CustomTransformerLayer']
)
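
This is mainly useful for architectures whose decoder layer class is not detected automatically; the names you pass must match the module class names used in the model’s implementation.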

Integration with Transformers Trainer

The FSDP_v2 configuration can be used directly with the Transformers Trainer:

from transformers import Trainer, TrainingArguments
# Or for instruction fine-tuning:
# from trl import SFTTrainer

trainer = Trainer(  # or SFTTrainer
    model=model,
    args=TrainingArguments(
        output_dir="output",  # output directory for checkpoints and logs
        **fsdp_args,          # Unpack the FSDP_v2 configuration
    ),
    train_dataset=dataset,
    # ... plus any other Trainer arguments (data collator, eval dataset, etc.)
)
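
With the configuration in place, training launches like any other Trainer run (assuming the script is executed on a TPU host with the PJRT runtime, e.g. PJRT_DEVICE=TPU):

# Start FSDP_v2 training on the TPU
trainer.train()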

Next steps