Online Batch Size Adaptation in Hugging Face Trainer

Community Article · Published February 24, 2025

🎉 New Feature

This repository provides a minimal extension of Hugging Face's Trainer to support dynamically changing batch sizes at each training step.

https://github.com/ifiaposto/hf_trainer_with_online_batch_size

🎯 Motivation

Training with batch sizes that can adapt on-the-fly has been shown to be beneficial in multiple ways:

  1. Improving training efficiency: Empirical studies suggest that gradually increasing the batch size, combined with learning rate decay, can achieve the convergence benefits of smaller batch sizes while leveraging the performance advantages of large, multi-GPU batch sizes [1]. This training scheme has been revisited in cutting-edge large language models like DeepSeek-V2 [2].
  2. Supporting advanced learning algorithms for adaptively mixing multiple training data sources. In such cases, the number of examples drawn from each data stream may be unknown a priori and should be dynamically balanced based on training metrics to optimize their impact on the loss function.
    • In multi-task learning, the data sources correspond to different tasks [3].
    • In incremental learning, one must balance knowledge retention with the ability to generalize to new data by effectively mixing novel and past training examples [4,5].

πŸ› οΈ Installation

The repo is built on Python 3.12.9. You can use the following command to install the requirements:

pip install -r requirements.txt

Below is the exact CUDA environment:

nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-cusparselt-cu12   0.6.2
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127

🚀 Quickstart

Using the trainer

You only need to define a callable that implements the batch size scheduler and pass it as an argument to the trainer:

from functools import partial

#   < ------- Write your own batch size scheduler ------- >
def custom_batch_size_scheduler(step: int,
                                batch_size: int,
                                interval=5,
                                increment=1):
    """
        step: current optimization step, provided by the trainer.
        batch_size: current batch size, provided by the trainer.
        interval: number of steps between batch size increments.
        increment: amount added to the batch size at each increment.
    """

    if step % interval == 0 and step > 0:
        return batch_size + increment
    return batch_size


# Extended trainer using the adaptive batch sampler.
trainer = AdaptiveBatchSizeTrainer(
    batch_size_scheduler=partial(custom_batch_size_scheduler,
                                 interval=5,
                                 increment=1),
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)
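
Training is then launched with the usual Trainer API:

trainer.train()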

AdaptiveBatchSizeTrainer builds on Hugging Face's Trainer and TrainingArguments. You can find the full list of training arguments for running your script in Hugging Face's TrainingArguments documentation.

The repo currently controls the optimization length exclusively via the argument num_train_epochs, i.e., the number of passes over the training dataset. Note that the number of optimization steps might not be known a priori if your batch size scheduler works online. To avoid overriding Trainer's _inner_training_loop and to maintain consistency, max_steps should be set to -1.
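
For reference, below is a minimal sketch of training arguments consistent with these constraints. The values mirror the demo commands further down, and per_device_train_batch_size is assumed to act as the initial batch size that the scheduler then adapts:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=1,
    save_strategy="no",
    num_train_epochs=10,            # number of passes over the training dataset
    per_device_train_batch_size=1,  # assumed starting batch size for the scheduler
    max_steps=-1,                   # required: steps are not known a priori
)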

Understanding the logs

AdaptiveBatchSizeTrainer extends and revises Trainer's training logs to account for the varying length of each epoch:

  • train_coverage: the fraction of training examples seen so far in the current epoch. Note that this might not increase linearly, depending on the batch size scheduler you use.
  • epoch: it is now updated not from the number of optimization steps performed so far, but from the fraction of training examples seen in total up to the current optimization step. This is equivalent to summing train_coverage across epochs; see the worked example below.
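
To make this concrete, here is a small worked example with illustrative variable names (not the repo's internals); the numbers match the demo below, where the dataset has 10 examples:

# Illustrative bookkeeping for the extended logs (variable names are hypothetical).
dataset_size = 10          # training examples in the dataset
examples_seen_total = 14   # examples consumed since the start of training
examples_seen_epoch = 4    # examples consumed in the current (second) epoch

# train_coverage: fraction of the dataset seen in the current epoch.
train_coverage = examples_seen_epoch / dataset_size   # 0.4

# epoch: total fraction of the dataset seen, i.e. the sum of train_coverage
# over all (variable-length) epochs so far.
epoch = examples_seen_total / dataset_size             # 1.4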

🧐 Demo

Below are some sample run commands for our test code using different configurations of single-/multi-GPU training and dataloader_drop_last (which drops the last batch of an epoch when there aren't enough training examples left to fill it). We inspect the output of the last one.

single-gpu, ignore the last batch.
python -m torch.distributed.run --nproc-per-node=1 --master_port=29619 -m test_replay_dataloader \
    --output_dir ./results \
    --logging_dir ./logs \
    --logging_steps 1 \
    --save_strategy no \
    --num_train_epochs 10 \
    --per_device_train_batch_size 1 \
    --max_steps -1 \
    --dataloader_drop_last
multi-gpu, ignore the last batch.
python -m torch.distributed.run --nproc-per-node=2 --master_port=29619 -m test_replay_dataloader \
    --output_dir ./results \
    --logging_dir ./logs \
    --logging_steps 1 \
    --save_strategy no \
    --num_train_epochs 10 \
    --per_device_train_batch_size 1 \
    --max_steps -1 \
    --dataloader_drop_last
single-gpu, use the last batch.
python -m torch.distributed.run --nproc-per-node=1 --master_port=29619 -m test_replay_dataloader \
    --output_dir ./results \
    --logging_dir ./logs \
    --logging_steps 1 \
    --save_strategy no \
    --num_train_epochs 10 \
    --per_device_train_batch_size 1 \
    --max_steps -1 \
    --dataloader_drop_last False
    
multi-gpu, use the last batch.
python -m torch.distributed.run --nproc-per-node=2 --master_port=29619 -m test_replay_dataloader \
    --output_dir ./results \
    --logging_dir ./logs \
    --logging_steps 1 \
    --save_strategy no \
    --num_train_epochs 10 \
    --per_device_train_batch_size 1 \
    --max_steps -1 \
    --dataloader_drop_last False
    
  • We notice that every 5 optimization steps (e.g., at epochs 1.4 and 3.6) the local batch size increases by one, in accordance with the scheduler.

  • Since dataloader_drop_last=False, when train_coverage=1.0 the actual batch size may differ from the scheduler's batch size, because only the remaining examples of the epoch are used. For example, at epoch 2.0 the scheduler's batch size is 2 per device, but only two examples are left in the epoch, so each device receives a single example.

    Output for multi-GPU training with the last batch included:
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([1], device='cuda:1')                                                                                  
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([4], device='cuda:0')                                                                                  
    {'loss': 0.5852, 'grad_norm': 11.941976547241211, 'learning_rate': 4.9500000000000004e-05, 'train coverage': 0.2, 'epoch': 0.2}                                                                             
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([7], device='cuda:0')                                                                                  
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([5], device='cuda:1')                                                                                  
    {'loss': 0.7744, 'grad_norm': 7.989934921264648, 'learning_rate': 4.9e-05, 'train coverage': 0.4, 'epoch': 0.4}                                                                                             
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([3], device='cuda:0')                                                                                  
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([9], device='cuda:1')                                                                                  
    {'loss': 0.6889, 'grad_norm': 3.215988874435425, 'learning_rate': 4.85e-05, 'train coverage': 0.6, 'epoch': 0.6}                                                                                            
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([0], device='cuda:0')                                                                                  
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([8], device='cuda:1')                                                                                  
    {'loss': 0.7219, 'grad_norm': 5.315413475036621, 'learning_rate': 4.8e-05, 'train coverage': 0.8, 'epoch': 0.8}                                                                                             
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([2], device='cuda:1')                                                                                  
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([6], device='cuda:0')                                                                                  
    {'loss': 0.7349, 'grad_norm': 9.186591148376465, 'learning_rate': 4.75e-05, 'train coverage': 1.0, 'epoch': 1.0}                                                                                            
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 2,  indices tensor([6, 2], device='cuda:1')                                                                               
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 2,  indices tensor([5, 1], device='cuda:0')                                                                               
    {'loss': 0.6667, 'grad_norm': 7.5692901611328125, 'learning_rate': 4.7e-05, 'train coverage': 0.4, 'epoch': 1.4}                                                                                            
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 2,  indices tensor([8, 3], device='cuda:1')                                                                               
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 2,  indices tensor([0, 9], device='cuda:0')                                                                               
    {'loss': 0.6908, 'grad_norm': 1.293290615081787, 'learning_rate': 4.6500000000000005e-05, 'train coverage': 0.8, 'epoch': 1.8}                                                                              
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([7], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([4], device='cuda:1')
    {'loss': 0.6941, 'grad_norm': 0.7848922610282898, 'learning_rate': 4.600000000000001e-05, 'train coverage': 1.0, 'epoch': 2.0}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 2,  indices tensor([8, 1], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 2,  indices tensor([7, 5], device='cuda:1')
    {'loss': 0.6856, 'grad_norm': 9.592278480529785, 'learning_rate': 4.55e-05, 'train coverage': 0.4, 'epoch': 2.4}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 2,  indices tensor([6, 0], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 2,  indices tensor([9, 4], device='cuda:1')
    {'loss': 0.6908, 'grad_norm': 0.9208756685256958, 'learning_rate': 4.5e-05, 'train coverage': 0.8, 'epoch': 2.8}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 1,  indices tensor([2], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 1,  indices tensor([3], device='cuda:1')
    {'loss': 0.7796, 'grad_norm': 20.710803985595703, 'learning_rate': 4.4500000000000004e-05, 'train coverage': 1.0, 'epoch': 3.0}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 3,  indices tensor([6, 3, 8], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 3,  indices tensor([0, 7, 5], device='cuda:1')
    {'loss': 0.5507, 'grad_norm': 9.137839317321777, 'learning_rate': 4.4000000000000006e-05, 'train coverage': 0.6, 'epoch': 3.6}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 2,  indices tensor([9, 4], device='cuda:1')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 2,  indices tensor([1, 2], device='cuda:0')
    {'loss': 0.8429, 'grad_norm': 9.582053184509277, 'learning_rate': 4.35e-05, 'train coverage': 1.0, 'epoch': 4.0}
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 0) local batch size 3,  indices tensor([0, 9, 7], device='cuda:0')
    compute loss Current GPU device: Tesla V100-SXM2-32GB (Device 1) local batch size 3,  indices tensor([4, 6, 3], device='cuda:1')
    {'loss': 0.7207, 'grad_norm': 5.453138828277588, 'learning_rate': 4.3e-05, 'train coverage': 0.6, 'epoch': 4.6}
    

🤓 Solution Outline

Challenges

💡 I preferred a minimally intrusive extension of Hugging Face's Trainer. Therefore, I rely on callbacks for resetting the batch size and determining training termination, rather than overriding existing methods to support the functionality.
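
For intuition, a callback-based hook could look roughly like the sketch below. This is an illustration rather than the repo's actual code; set_batch_size stands in for however the adaptive sampler exposes the update:

from transformers import TrainerCallback

class BatchSizeSchedulerCallback(TrainerCallback):
    """Illustrative sketch: apply the user's batch size scheduler after every step."""

    def __init__(self, batch_size_scheduler, batch_sampler):
        self.batch_size_scheduler = batch_size_scheduler
        self.batch_sampler = batch_sampler  # handle to the adaptive batch sampler

    def on_step_end(self, args, state, control, **kwargs):
        # Query the scheduler with the current step and batch size ...
        new_batch_size = self.batch_size_scheduler(
            step=state.global_step,
            batch_size=self.batch_sampler.batch_size,
        )
        # ... and push the (possibly unchanged) value back to the sampler.
        self.batch_sampler.set_batch_size(new_batch_size)  # hypothetical sampler method
        return control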

💡 The extension supports per-step rather than per-epoch batch size adaptation, for better flexibility. Therefore, the length of the DataLoader is not known a priori at the beginning of the epoch.

💡 Due to the varying length of the DataLoader, the logs and callbacks need to be carefully handled and/or redefined.

💡 Batch prefetching currently has to be disabled to make sure the batch size is updated synchronously.

Main Components

📌 AdaptiveBatchSampler inherits from PyTorch's BatchSampler. It overrides its iterator and adds a method for updating the batch size.
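
A minimal sketch of what such a sampler can look like is given below, assuming the batch size is simply re-read every time a new batch is assembled; the repo's actual implementation may differ:

from torch.utils.data import BatchSampler

class SimpleAdaptiveBatchSampler(BatchSampler):
    """Illustrative sketch of a batch sampler whose batch size can change between batches."""

    def set_batch_size(self, batch_size: int):
        # Takes effect when the next batch is assembled.
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # Re-check self.batch_size for every batch so mid-epoch updates are honored.
            if len(batch) >= self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

Note that the inherited __len__ is no longer meaningful here, since the number of batches per epoch depends on future scheduler decisions.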

📌 AdaptiveBatchSizeDataLoader inherits from PyTorch's DataLoader. It wraps AdaptiveBatchSampler and adds a set_epoch method for resetting the seed of the sampler at the beginning of each epoch.
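
Conceptually, set_epoch only needs to reseed the generator that drives the random sampler so that shuffling stays reproducible even though epochs contain a variable number of batches. Below is a hedged sketch with a plain BatchSampler standing in for the adaptive one:

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

class SimpleAdaptiveBatchSizeDataLoader(DataLoader):
    """Illustrative sketch: reseed the wrapped sampler's generator at each epoch start."""

    def set_epoch(self, epoch: int, base_seed: int = 42):
        sampler = self.batch_sampler.sampler  # the random sampler inside the batch sampler
        if getattr(sampler, "generator", None) is not None:
            sampler.generator.manual_seed(base_seed + epoch)

# Example wiring; in practice the adaptive batch sampler would be used instead.
dataset = TensorDataset(torch.arange(10))
generator = torch.Generator().manual_seed(42)
sampler = RandomSampler(dataset, generator=generator)
loader = SimpleAdaptiveBatchSizeDataLoader(
    dataset, batch_sampler=BatchSampler(sampler, batch_size=1, drop_last=False))
loader.set_epoch(epoch=0)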

📌 AdaptiveBatchSizeTrainer inherits from Hugging Face's Trainer. It:

  • uses AdaptiveBatchSizeDataLoader and handles distributed data-parallel training.
  • calls callbacks for the batch size scheduler.
  • handles the revised and extended logs.

Future Optimizations

βœ–οΈ Support transformers version >= 4.45.2.

βœ–οΈ Make it compatible with accelerate or support preemptive prefetching. Currently accelerate by default prefetches one batch ahead.

📚 References

[1] Devarakonda, A., Naumov, M. and Garland, M., 2017. Adabatch: Adaptive batch sizes for training deep neural networks. arXiv preprint arXiv:1712.02029.

[2] Liu, A., Feng, B., Wang, B., Wang, B., Liu, B., Zhao, C., Dengr, C., Ruan, C., Dai, D., Guo, D. and Yang, D., 2024. DeepSeek-V2: A strong, economical, and efficient mixture-of-experts language model. arXiv preprint arXiv:2405.04434.

[3] Li, Z., Deng, Y., Zhong, P., Razaviyayn, M. and Mirrokni, V., 2025. PiKE: Adaptive Data Mixing for Multi-Task Learning Under Low Gradient Conflicts. arXiv preprint arXiv:2502.06244.

[4] Chaudhry, A., Rohrbach, M., Elhoseiny, M., Ajanthan, T., Dokania, P.K., Torr, P.H. and Ranzato, M.A., 2019. On tiny episodic memories in continual learning. arXiv preprint arXiv:1902.10486.

[5] Wu, T., Luo, L., Li, Y.F., Pan, S., Vu, T.T. and Haffari, G., 2024. Continual learning for large language models: A survey. arXiv preprint arXiv:2402.01364.

πŸ’ Buy Me Flowers (or Cite)

If you use this repo in your research, please cite using the following BibTeX entry:

@software{adaptive_batch_size_hf_trainer,
  title        = {Online Batch Size Adaptation in Hugging Face Trainer},
  author       = {Apostolopoulou, Ifigeneia},
  howpublished = {\url{https://github.com/ifiaposto/hf_trainer_with_online_batch_size}},
  year         = {2025},
}
