FAT5 (Flash Attention T5) ⚡
Introduction
FAT5 (for Flash Attention T5) is a PyTorch implementation of T5 with a UL2 objective, optimized for GPGPU for both training and inference. It uses an experimental feature to run Flash Attention (v2) with relative position encoding biases, which makes it possible to train or finetune the model on longer sequences than the original T5. It also supports other positional embeddings such as RoPE, ALiBi or FIRE.
This methodology enabled us to efficiently pretrain, as a proof of concept, a 147M-parameter T5 in French in a reasonable time (1,461 hours to see 419B tokens) and with limited resources (a single A100, i.e. a computational budget of around €2,200). You'll find its weights in this repo.
To achieve this, we designed CUDA/Triton kernels to make Flash Attention compatible with T5, and to provide linear inference, thus extending the context size that can be taken into account by the model.
Other optimizations have also been implemented, as detailed in a subsequent blog post.
Motivation
While a lot of effort has been focused on optimizing decoder-only models, older architectures remain useful in many practical applications. We focus on T5, an encoder-decoder architecture that performs very well after instruction tuning and, when finetuned, sometimes even outperforms much larger models. Moreover, it is a natural architecture to consider when distilling much larger models.
A critical limitation of this architecture is the sequence length these models can handle, because attention memory grows quadratically with sequence length. While this quadratic term cannot be removed without switching to other forms of attention (as in LongT5), it can still be alleviated to accommodate longer sequences.
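To make the quadratic term concrete, here is a back-of-the-envelope sketch of the memory taken by one materialized attention score matrix (or a full attention bias tensor of the same shape) stored in a 16-bit format. The head count of 12 is a hypothetical value chosen for illustration, not necessarily FAT5's configuration, and attention_matrix_bytes is an illustrative name.

```python
def attention_matrix_bytes(batch, n_heads, seq_len, bytes_per_elem=2):
    """Memory of one (batch, n_heads, seq_len, seq_len) tensor, e.g. scores or a full bias."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

# Hypothetical setting: 12 heads, bf16/fp16 elements (2 bytes), one sample, one layer.
for seq_len in (512, 1024, 4096, 8192):
    mib = attention_matrix_bytes(1, 12, seq_len) / 2**20
    print(f"seq_len={seq_len:5d}: {mib:8.1f} MiB")
```

Doubling the sequence length multiplies this footprint by four, and it is paid per layer and per sample. Flash Attention avoids materializing this matrix, which is why handling the biases inside the kernel matters for T5.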
Another limitation is the pre-training time, since techniques such as Flash Attention are not available for this architecture.
Our work
We used the nanoT5 implementation as the base for our work.
We worked on optimizing the core component of the model, the attention. We used Flash Attention (v2), which optimizes both memory usage and the use of Tensor Cores.
We support different implementations of attention biases:
- Full attention biases with Flash Attention 2 using this PR
- T5-like relative position encoding biases with Flash Attention 2 using this PR
- Full attention biases with a triton implementation of Flash Attention 2
Other parts of the architecture were optimized using ad hoc Triton kernels for the cross-entropy (and z-loss) and the layernorm.
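As a reference for what the fused cross-entropy kernel computes, here is a plain-Python sketch of cross-entropy with the z-loss auxiliary term (the squared log of the softmax normalizer, as used in PaLM). The 1e-4 coefficient and the function names are assumptions for illustration; the actual Triton kernel works on batched GPU tensors and also computes gradients.

```python
import math

def logsumexp(logits):
    """Numerically stable log(sum(exp(x))) for a list of floats."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

def cross_entropy_with_z_loss(logits, target, z_loss_coef=1e-4):
    """Return (total, ce, z) for one row of logits and an integer target id."""
    log_z = logsumexp(logits)      # log of the softmax normalizer
    ce = log_z - logits[target]    # -log softmax(logits)[target]
    z = z_loss_coef * log_z ** 2   # penalizes large log_z for numerical stability
    return ce + z, ce, z

def batch_loss(rows, targets, z_loss_coef=1e-4):
    """Mean total loss over a batch of logit rows."""
    totals = [cross_entropy_with_z_loss(r, t, z_loss_coef)[0] for r, t in zip(rows, targets)]
    return sum(totals) / len(totals)

total, ce, z = cross_entropy_with_z_loss([2.0, 0.5, -1.0], target=0)
print(f"ce={ce:.4f} z={z:.6f} total={total:.4f}")
```

A naive implementation materializes the full softmax over the vocabulary for every token; fusing these steps into one kernel is what reduces memory.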
For pretext tasks during pre-training, we use the UL2 mixture of denoisers with the following 7 tasks:
denoiser_list=[
{"mu": 3.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
{"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
{"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},
{"mu": 3.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
{"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
{"mu": 64.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
{"mu": 64.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"}]
denoiser_proportions=[0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825]
where mu is the span size, r the percentage of masking, and prefix the type of pretext task (the meaning of the letters [R], [S] and [X] is described here).
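The mixture above can be read as a categorical distribution over denoisers. The sketch below samples one denoiser per example and computes how many tokens and spans an [R] or [X] denoiser would corrupt, using the rounding rules of T5's random_spans_noise_mask. It assumes, as in UL2, that mu is the mean span length and r the fraction of tokens corrupted; max_token_length = 1024 is a hypothetical value (the pre-training length given below), max_spans is ignored, sample_denoiser and span_budget are illustrative names, and the [S] denoiser is left out of span_budget since in UL2 it follows a different (prefix-LM) scheme.

```python
import random

max_token_length = 1024  # hypothetical: the 1,024-token pre-training length

# Mixture copied from the configuration above.
denoiser_list = [
    {"mu": 3.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
    {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
    {"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},
    {"mu": 3.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 64.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 64.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
]
denoiser_proportions = [0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825]

def sample_denoiser(rng):
    """Draw one denoiser config according to the mixture weights."""
    return rng.choices(denoiser_list, weights=denoiser_proportions, k=1)[0]

def span_budget(length, r, mu):
    """Number of noise tokens and noise spans, following T5's rounding rules."""
    num_noise_tokens = int(round(length * r))
    num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
    num_noise_spans = max(int(round(num_noise_tokens / mu)), 1)
    return num_noise_tokens, num_noise_spans

rng = random.Random(0)
counts = {"[R]": 0, "[S]": 0, "[X]": 0}
for _ in range(10_000):
    counts[sample_denoiser(rng)["prefix"]] += 1
print(counts)  # expected shares: 0.33 [R], 0.34 [S], 0.33 [X]

for d in denoiser_list:
    if d["prefix"] != "[S]":
        print(d["prefix"], d["mu"], d["r"], span_budget(max_token_length, d["r"], d["mu"]))
```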
As no UL2 implementation was available in PyTorch, we added one and adapted a dynamic batching mechanism to reduce padding.
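Dynamic batching can take several forms. The sketch below shows one common variant, which is an assumption and not necessarily the mechanism used in FAT5: sort examples by length, then greedily fill each batch until its padded size (longest sequence × number of sequences) would exceed a token budget. dynamic_batches, padding_tokens and max_tokens are hypothetical names.

```python
def dynamic_batches(lengths, max_tokens):
    """Group example indices so that max_len * batch_size <= max_tokens per batch."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, current_max = [], [], 0
    for i in order:
        candidate_max = max(current_max, lengths[i])
        # Close the current batch if adding this example would exceed the budget.
        if current and candidate_max * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, candidate_max = [], lengths[i]
        current.append(i)
        current_max = candidate_max
    if current:
        batches.append(current)
    return batches

def padding_tokens(lengths, batches):
    """Number of pad tokens needed when each batch is padded to its longest member."""
    total = 0
    for batch in batches:
        longest = max(lengths[i] for i in batch)
        total += sum(longest - lengths[i] for i in batch)
    return total

lengths = [5, 100, 7, 98, 6, 99]
sorted_batches = dynamic_batches(lengths, max_tokens=300)
naive_batches = [[0, 1, 2], [3, 4, 5]]  # fixed batch size 3, original order
print(sorted_batches, padding_tokens(lengths, sorted_batches))  # [[0, 4, 2], [3, 5, 1]] 6
print(naive_batches, padding_tokens(lengths, naive_batches))    # [[0, 1, 2], [3, 4, 5]] 282
```

Grouping sequences of similar length is what removes most of the padding; the token budget also keeps the memory per batch roughly constant.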
Benchmarks
TFLOPS
The number of TFLOPS (trillions of floating-point operations a processor performs per second) is probably the most telling measure of the impact of the optimizations carried out.
We therefore compare four approaches:
• the SDPA (Scaled Dot Product Attention) implementation with full bias,
• the same implementation but in Triton,
• the Flash Attention RPE implementation (our kernel),
• the Flash Attention implementation, i.e. without bias. We've included it here for reference only: since it cannot add T5's relative position biases, it is unusable in practice for a T5.
For the forward pass, we have:
For the forward pass, the Triton approach achieves 1.34 times the FLOPS of SDPA, and the Flash Attention RPE approach achieves 1.99 times the FLOPS of SDPA. Our bf16 implementation is also on par with fp16 (and even does better at size 512).
For the backward pass, we have:
For the backward pass, the Triton implementation is less efficient than SDPA, reaching only 0.71 times its FLOPS. The Flash Attention RPE implementation is roughly on par with SDPA (1.018 times its FLOPS). We can also observe that Triton with head_dim 64 is more efficient than Triton with head_dim 128.
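For context, attention TFLOPS are usually derived from an analytical FLOP count divided by the measured runtime. The sketch below uses the convention of the Flash Attention benchmark scripts: 4 · batch · seq_len² · n_heads · head_dim for the forward pass (the two matmuls QKᵀ and PV), and 2.5× that for the backward pass. Whether the FAT5 benchmarks use exactly this formula is an assumption; the bias addition is not counted, and the 0.25 ms runtime in the example is made up.

```python
def attention_flops(batch, seq_len, n_heads, head_dim, mode="fwd", causal=False):
    """Analytical FLOPs of one attention call (matmuls only, bias ignored)."""
    # Q @ K^T and P @ V each cost 2 * seq_len^2 * head_dim FLOPs per (batch, head).
    flops = 4 * batch * seq_len**2 * n_heads * head_dim
    if causal:
        flops //= 2  # only half of the score matrix is needed
    if mode == "fwd":
        return flops
    if mode == "bwd":
        return 2.5 * flops  # five matmuls in the backward vs two in the forward
    if mode == "fwd_bwd":
        return 3.5 * flops
    raise ValueError(f"unknown mode: {mode}")

def tflops(flops, seconds):
    """Achieved throughput in TFLOPS (10^12 floating-point operations per second)."""
    return flops / seconds / 1e12

# Hypothetical measurement: non-causal encoder attention, batch 8, 1,024 tokens, 12 heads of 64 dims.
fwd = attention_flops(8, 1024, 12, 64)
print(f"{fwd / 1e9:.1f} GFLOP, {tflops(fwd, 0.25e-3):.1f} TFLOPS if it ran in 0.25 ms")
```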
Torch vs Triton
We mentioned above that we had optimized parts of the architecture using ad hoc Triton kernels, namely the cross-entropy and the RMSNorm layer. The following benchmarks illustrate why. For cross-entropy, we obtain a forward pass 7 to 11.4 times faster, a backward pass 3.26 to 3.75 times faster, and memory usage reduced by a factor of 4:
For the RMSNorm layer, we obtain a forward pass 3 to 5 times faster, a backward pass 2.33 to 4.33 times faster, and memory usage reduced by a factor of 3.2:
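For reference, the operation being fused is simple: T5's layer norm is an RMSNorm, which rescales by the root mean square of the features, with no mean subtraction and no bias. A plain-Python sketch for one vector follows; the function name and the eps = 1e-6 default are assumptions, and the Torch baseline and Triton kernel operate on whole tensors.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm as in T5: y_i = w_i * x_i / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]

y = rms_norm([3.0, 4.0, 0.0, -5.0], [1.0, 1.0, 1.0, 1.0])
print([round(v, 4) for v in y])
```

In eager PyTorch, the square, mean, rsqrt and multiplications typically run as separate kernels with intermediates kept for the backward pass, which is the overhead a fused kernel avoids.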
Note that all benchmark graphs can be generated automatically using the following code.
Applications
To French
We've pretrained a small (147M parameters) FAT5-UL2 in French. This is the model you'll find in this Hugging Face repo.
The dataset we used is a mixture of CulturaX, Wikipedia, justice_fr and The Stack.
Our tokenizer, with a vocabulary size of 32,768 (8**5), was trained on CulturaX and The Stack.
Our model was pre-trained on sequences of 1,024 tokens on a single A100 for 1,461 hours (= 419.4B tokens seen), at an estimated cost of €2,200.
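The figures above imply an average training throughput, which can be checked with a few lines of arithmetic. Only the inputs come from this card; the derived values are simple divisions, and the hourly cost assumes the €2,200 estimate is spread evenly over the run.

```python
hours = 1_461
tokens_seen = 419.4e9
seq_len = 1_024
cost_eur = 2_200

seconds = hours * 3_600
tokens_per_second = tokens_seen / seconds
sequences_seen = tokens_seen / seq_len
cost_per_hour = cost_eur / hours
cost_per_billion_tokens = cost_eur / (tokens_seen / 1e9)

print(f"{tokens_per_second:,.0f} tokens/s")
print(f"{sequences_seen / 1e6:,.1f}M sequences of {seq_len} tokens")
print(f"€{cost_per_hour:.2f}/hour, €{cost_per_billion_tokens:.2f} per billion tokens")
```

This works out to roughly 80k tokens per second on the single A100, about 410M sequences of 1,024 tokens, and about €1.5 per GPU-hour.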
Carbon emissions for the pretraining of this model were estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact:
• Hardware Type: A100 PCIe 40/80GB
• Hours used: 1,461h
• Cloud Provider: Private Infrastructure
• Carbon Efficiency (kg/kWh): 0.03696 kg (estimated from electricitymaps as the average carbon intensity in France between October 18, 2024 and December 19, 2024)
• Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid): 13.5 kg eq. CO2.
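The 13.5 kg figure follows from the formula in the last bullet. Working backward, it corresponds to an average draw of about 0.25 kW, which matches the 250 W TDP of an A100 PCIe 40GB. Treating the GPU as drawing exactly its TDP for the whole run is an assumption of this sketch, not a measurement.

```python
power_kw = 0.25             # assumption: A100 PCIe 40GB TDP (250 W) for the whole run
hours = 1_461               # from the card
carbon_kg_per_kwh = 0.03696 # from the card (France, Oct-Dec 2024 average)

energy_kwh = power_kw * hours
emissions_kg = energy_kwh * carbon_kg_per_kwh
print(f"{energy_kwh:.2f} kWh -> {emissions_kg:.1f} kg CO2 eq.")  # 365.25 kWh -> 13.5 kg CO2 eq.
```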
To other languages
Our contribution focuses on French, with the pre-training and finetuning of models for comparison against French benchmarks. For other languages, we can't afford to do the same kind of work.
Nevertheless, to ensure that it can be used in other languages, we have developed code to adapt already pre-trained (m)T5/FLAN-T5 weights to our method. In this way, we hope that users of a specific language will be able to efficiently continue pre-training one of these models, for example to adapt it to more recent data.
Note, however, that this adaptation is limited, since the additional pre-training has to be carried out in the precision of the original model. For example, if the model's weights are in FP32 (as is the case for FLAN-T5), training will not be as fast as with FAT5, which is in BF16.
For English speakers, we have already adapted the weights of the various versions of FLAN-T5 to our method. All weights can be found in this Hugging Face collection.
To use one of the models, simply run:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("CATIE-AQ/FAT5-small-flan-en", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
Pretraining
If you want to pre-train your own model (to be specialized in a specific domain for example, and thus benefit from a custom tokenizer), we included a tutorial to pretrain a small model on minipile to show how it should be done.
You can find the documentation of the model configuration file here.
Note that we tested and trained the tutorial's model on an A100. It may or may not work on other GPUs.
Roadmap
We invite you to consult the “Next stage” section of the blog post.
Citation
@misc {FAT5,
title = { FAT5: Flash Attention T5 },
author = { Boris ALBAR and Loïck BOURDOIS },
organization = { Centre Aquitain des Technologies de l'Information et Electroniques },
year = 2025,
url = { https://huggingface.co/spaces/CATIE-AQ/FAT5-report },
doi = { 10.57967/hf/4160 },
publisher = { Hugging Face }
}
License
Acknowledgments
We use the following repos and thank their authors:
- nanoT5 for the simple implementation and the optimizer.
- Flash attention for the groundbreaking algorithm for computing attention.
- Hugging Face for their excellent library.
- FlagAttention for the implementation of FA2 in Triton.
- Unsloth for the simple Triton kernels of the cross-entropy and layernorm that we adapted to our usage.
- TurboT5 for the improvement of the February 2024 version of our work.
This work was supported by the Vaniila platform.