First TPU Training on Google Cloud

This tutorial walks you through setting up and running model training on TPU using the optimum-tpu package.

Prerequisites

Before starting, ensure you have a running TPU instance (see the TPU Setup Guide).
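
To work on that instance, connect to the TPU VM with the gcloud CLI. A minimal sketch, assuming a TPU VM named my-tpu in zone us-central2-b (replace both with your own values):

# SSH into the TPU VM
gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central2-b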

Environment Setup

First, create and activate a virtual environment:

python -m venv .venv
source .venv/bin/activate

Install the required packages:

# Install optimum-tpu with PyTorch/XLA support
pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html

# Install additional training dependencies
pip install transformers datasets accelerate trl peft evaluate
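
Optionally, verify that PyTorch/XLA can see the TPU before writing any training code. This is a quick sanity check to run inside the TPU VM; the exact device string printed may vary by torch_xla version:

# Should print an XLA device such as xla:0
python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"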

Understanding FSDP for TPU Training

To speed up your training on TPU, you can rely on Optimum TPU’s integration with FSDP (Fully Sharded Data Parallel). When training large models, FSDP automatically shards (splits) your model across all available TPU workers, providing several key benefits:

  1. Memory efficiency: Each TPU worker only stores a portion of the model parameters, reducing per-device memory requirements
  2. Automatic scaling: FSDP handles the complexity of distributing the model and aggregating gradients
  3. Performance optimization: Optimum TPU’s implementation is specifically tuned for TPU hardware

This sharding happens automatically when you use the fsdp_v2.get_fsdp_training_args(model) configuration in your training setup, making it easy to train larger models that wouldn’t fit on a single TPU device.
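
If you are curious about what this helper generates for your model, you can simply print the result. The sketch below assumes the same google/gemma-2b model used later in this tutorial; the exact contents of the returned dictionary depend on your optimum-tpu version and the model architecture:

import torch
from transformers import AutoModelForCausalLM
from optimum.tpu import fsdp_v2

fsdp_v2.use_fsdp_v2()
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)

# Typically a full-shard FSDP configuration with XLA FSDPv2 enabled
# and the model's decoder layers selected for wrapping
print(fsdp_v2.get_fsdp_training_args(model))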

How to Set Up FSDP

The key modification to enable FSDP is just these few lines:

+from optimum.tpu import fsdp_v2
+fsdp_v2.use_fsdp_v2()
+fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

Then include these arguments in your trainer configuration:

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        ...
+       dataloader_drop_last=True,  # Required for FSDPv2
+       **fsdp_training_args,
    ),
    ...
)
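
Roughly speaking, dataloader_drop_last=True is required because XLA compiles programs for fixed tensor shapes and FSDPv2 shards each batch evenly across the TPU workers; an incomplete final batch would break those assumptions, so it is simply dropped.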

Complete Example

Here’s a full working example that demonstrates TPU training with FSDP:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from optimum.tpu import fsdp_v2

# Enable FSDPv2 for TPU
fsdp_v2.use_fsdp_v2()

# Load model and dataset
model_id = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]")

# Get FSDP training arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

# Create trainer with minimal configuration
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        output_dir="./output",
        dataloader_drop_last=True,  # Required for FSDPv2
        **fsdp_training_args,
    ),
    peft_config=LoraConfig(
        r=8,
        target_modules=["k_proj", "v_proj"],
        task_type="CAUSAL_LM",
    ),
)

# Start training
trainer.train()

Save this code as train.py and run it:

python train.py
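
Note that google/gemma-2b is a gated checkpoint, so the download may fail unless you have accepted its license on the Hugging Face Hub and authenticated on the TPU VM, for example with:

huggingface-cli login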

You should now see the loss decrease during training. Once training completes, you will have a fine-tuned model. Congrats, you've just trained your first model on TPUs! 🙌

Next Steps

Continue your TPU training journey by exploring: