Fully Sharded Data Parallel (FSDP) v2
Overview
When fine-tuning Large Language Models (LLMs) on TPUs, sharding the model across devices is essential for memory efficiency and training performance. The `optimum.tpu.fsdp_v2` module provides utilities for implementing Fully Sharded Data Parallel (FSDP) training using SPMD (Single Program Multiple Data), optimized specifically for TPU devices.
FSDP_v2 Features
- Model weight sharding across TPU devices
- Gradient checkpointing support
- Automatic configuration for common model architectures
- Integration with PyTorch/XLA’s SPMD implementation
Basic Usage
Here’s how to enable and configure FSDP_v2 for your training:
```python
from optimum.tpu import fsdp_v2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Enable FSDP_v2 before loading the model
fsdp_v2.use_fsdp_v2()

# Load model and tokenizer
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)

# Get the FSDP_v2 training configuration
fsdp_args = fsdp_v2.get_fsdp_training_args(model)
```
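At this point you can inspect the returned dictionary to see exactly what will be passed to the trainer; a minimal check:

```python
# Print the generated FSDP_v2 configuration; an example of what this
# dictionary looks like is shown in the next section.
print(fsdp_args)
```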
Configuration Options
The `get_fsdp_training_args()` function returns a dictionary with a model-specific configuration such as:
```python
{
    'fsdp': 'full_shard',
    'fsdp_config': {
        'transformer_layer_cls_to_wrap': ['LlamaDecoderLayer'],  # Model-specific
        'xla': True,
        'xla_fsdp_v2': True,
        'xla_fsdp_grad_ckpt': True
    }
}
```
Key Parameters
- `transformer_layer_cls_to_wrap`: Specifies which model layers to wrap with FSDP
- `xla`: Enables XLA optimization
- `xla_fsdp_v2`: Activates the FSDP_v2 implementation
- `xla_fsdp_grad_ckpt`: Enables gradient checkpointing for memory efficiency
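Because `fsdp` and `fsdp_config` are regular `TrainingArguments` fields, the same configuration can also be written out explicitly instead of unpacking `fsdp_args`. A sketch for a Llama model, where the output directory and batch size are placeholder values:

```python
from transformers import TrainingArguments

# Equivalent to TrainingArguments(**fsdp_args) for a Llama model;
# output_dir and per_device_train_batch_size are placeholders.
args = TrainingArguments(
    output_dir="./fsdp_v2_output",
    per_device_train_batch_size=8,
    fsdp="full_shard",
    fsdp_config={
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    },
)
```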
Advanced Usage
Custom Layer Wrapping
You can customize which layers get wrapped with FSDP:
```python
custom_fsdp_args = fsdp_v2.get_fsdp_training_args(
    model,
    layer_cls_to_wrap=['CustomTransformerLayer']
)
```
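If you are unsure which class name to pass, one way to discover it is to read it from the loaded model. This is a sketch assuming the model follows the standard transformers decoder layout (`model.model.layers`) and uses the `layer_cls_to_wrap` keyword shown above:

```python
# Read the decoder layer class name from the model instead of hard-coding it.
# Assumes a standard transformers decoder layout (model.model.layers).
layer_cls_name = type(model.model.layers[0]).__name__
print(layer_cls_name)  # e.g. 'LlamaDecoderLayer' for Llama models

custom_fsdp_args = fsdp_v2.get_fsdp_training_args(
    model,
    layer_cls_to_wrap=[layer_cls_name],
)
```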
Integration with Transformers Trainer
The FSDP_v2 configuration can be used directly with the Transformers Trainer:
```python
from transformers import Trainer, TrainingArguments
# Or for instruction fine-tuning:
# from trl import SFTTrainer

trainer = Trainer(  # or SFTTrainer
    model=model,
    args=TrainingArguments(**fsdp_args),  # Unpack the FSDP_v2 configuration
    train_dataset=dataset,
    ...
)
```
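Putting the pieces together, here is a minimal end-to-end sketch. It assumes `dataset` is an already-tokenized `datasets.Dataset`; the output path and hyperparameters are placeholder values to adjust for your setup.

```python
from transformers import Trainer, TrainingArguments, default_data_collator

training_args = TrainingArguments(
    output_dir="./llama-fsdp-v2",   # placeholder output path
    per_device_train_batch_size=8,  # example value; tune for your TPU slice
    num_train_epochs=1,
    optim="adafactor",
    logging_steps=10,
    dataloader_drop_last=True,      # recommended with SPMD so every step sees identically shaped batches
    **fsdp_args,                    # the FSDP_v2 configuration from above
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=default_data_collator,
)
trainer.train()
```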
Next steps
- You can look at our example notebooks for best practices on training with optimum-tpu
- For more details on PyTorch/XLA’s FSDP implementation, refer to the official documentation.