First TPU Training on Google Cloud
This tutorial walks you through setting up and running model training on TPU using the optimum-tpu
package.
Prerequisites
Before starting, ensure you have a running TPU instance (see the TPU Setup Guide).
Environment Setup
First, create and activate a virtual environment:
python -m venv .venv
source .venv/bin/activate
Install the required packages:
# Install optimum-tpu with PyTorch/XLA support
pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html
# Install additional training dependencies
pip install transformers datasets accelerate trl peft evaluate
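Optionally, you can verify that PyTorch/XLA can see the TPU before moving on. This is a minimal sanity check, assuming the install above pulled in torch_xla alongside optimum-tpu:
# check_tpu.py - quick check that an XLA (TPU) device is visible to this process
import torch_xla.core.xla_model as xm

print(xm.xla_device())  # on a TPU VM this should print an XLA device such as xla:0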
Understanding FSDP for TPU Training
To speed up your training on TPU, you can rely on Optimum TPU’s integration with FSDP (Fully Sharded Data Parallel). When training large models, FSDP automatically shards (splits) your model across all available TPU workers, providing several key benefits:
- Memory efficiency: Each TPU worker only stores a portion of the model parameters, reducing per-device memory requirements
- Automatic scaling: FSDP handles the complexity of distributing the model and aggregating gradients
- Performance optimization: Optimum TPU’s implementation is specifically tuned for TPU hardware
This sharding happens automatically when you pass the arguments returned by fsdp_v2.get_fsdp_training_args(model)
to your trainer, making it easy to train larger models that wouldn't fit on a single TPU device.
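Under the hood, fsdp_v2.get_fsdp_training_args(model) returns a plain dictionary of TrainingArguments keyword arguments (typically the sharding strategy plus an XLA-specific FSDP config), which is why it can be unpacked with ** in the trainer setup shown below. Here is a minimal sketch of inspecting it; the exact keys may vary with your optimum-tpu version and model architecture:
# Inspect the FSDP arguments that Optimum TPU derives for a given model
from transformers import AutoModelForCausalLM
from optimum.tpu import fsdp_v2

fsdp_v2.use_fsdp_v2()
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
print(fsdp_training_args)
# Expect a dict of TrainingArguments kwargs, roughly of the form
# {'fsdp': 'full_shard', 'fsdp_config': {'transformer_layer_cls_to_wrap': [...], 'xla': True, ...}}
# (exact keys depend on the optimum-tpu version and the model architecture)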
How to Set Up FSDP
The key modification to enable FSDP is just these few lines:
+from optimum.tpu import fsdp_v2
+fsdp_v2.use_fsdp_v2()
+fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
Then include these arguments in your trainer configuration:
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
args=TrainingArguments(
...
+ dataloader_drop_last=True, # Required for FSDPv2
+ **fsdp_training_args,
),
...
)
Complete Example
Here’s a full working example that demonstrates TPU training with FSDP:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from optimum.tpu import fsdp_v2
# Enable FSDPv2 for TPU
fsdp_v2.use_fsdp_v2()
# Load model and dataset
model_id = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]")
# Get FSDP training arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
# Create trainer with minimal configuration
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
args=TrainingArguments(
output_dir="./output",
dataloader_drop_last=True, # Required for FSDPv2
**fsdp_training_args,
),
peft_config=LoraConfig(
r=8,
target_modules=["k_proj", "v_proj"],
task_type="CAUSAL_LM",
),
)
# Start training
trainer.train()
Save this code as train.py and run it:
python train.py
You should now see the loss decrease during training. When the training is done, you will have a fine-tuned model. Congrats - you’ve just trained your first model on TPUs! 🙌
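If you want to keep the result, you can save the trained LoRA adapter and reload it later for inference. This is a minimal sketch building on the training script above; the adapter path ./output/final_adapter is just an example name:
# Save the trained LoRA adapter (path name is just an example)
trainer.save_model("./output/final_adapter")

# Later, in a separate script: reload the base model and attach the adapter
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "./output/final_adapter")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")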
Next Steps
Continue your TPU training journey by exploring:
- More complex training scenarios in our examples
- Different model architectures supported by Optimum TPU