Spaces:
Running
Running
<!doctype html> | |
<html lang="en"> | |
<head> | |
<script src="distill.bundle.js" type="module" fetchpriority="high" blocking></script> | |
<script src="main.bundle.js" type="module" fetchpriority="low" defer></script> | |
<meta name="viewport" content="width=device-width, initial-scale=1"> | |
<meta charset="utf-8"> | |
<base target="_blank"> | |
<title>FAT5: Flash Attention T5</title> | |
<link rel="stylesheet" href="style.css"> | |
</head> | |
<body> | |
<d-front-matter> | |
<script id='distill-front-matter' type="text/json">{ | |
"title": "FAT5: Flash Attention T5", | |
"description": "", | |
"published": "May 28, 2024", | |
"authors": [ | |
{ | |
"author":"Boris ALBAR", | |
"authorURL":"https://github.com/b-albar", | |
"affiliation": [{"name": "CATIE", "url": "https://catie.fr"}] | |
}, | |
{ | |
"author":"Loïck BOURDOIS", | |
"authorURL":"https://github.com/lbourdois", | |
"affiliation": [{"name": "CATIE", "url": "https://catie.fr"}] | |
} | |
], | |
"color": "#9CA3AF", | |
"katex": { | |
"delimiters": [ | |
{"left": "$$", "right": "$$", "display": false} | |
] | |
} | |
} | |
</script> | |
</d-front-matter> | |
<d-title> | |
<h1 class="l-page" style="text-align: center;">FAT5: Flash Attention T5</h1> | |
<p><img src="./assets/FAT5_dark.gif" alt="FAT5" width="100%"></p> | |
</d-title> | |
<d-article> | |
<d-contents> | |
</d-contents> | |
<div class="note"> For a better experience, we do not recommend reading on a cell phone. </div> | |
<h2 id="motivation">Motivation</h2> | |
<p class="width_125"> | |
While much effort has been devoted to optimising decoder transformers, thus abandoning the encoder, we believe it is essential to maintain an encoder-decoder architecture.<br> | |
Indeed, this architecture, which offers interesting performance for instruction tuning <d-cite bibtex-key="chia2023instructeval"></d-cite>, is suitable for distillation <d-cite bibtex-key="hsieh2023distilling"></d-cite> and seems superior to decoder models when finetuned <d-cite bibtex-key="fu2024tiny"></d-cite>. | |
It has also been shown that encoder-decoder models trained with masked language modelling achieve better zero-shot performance after multitask finetuning than a decoder model <d-cite bibtex-key="wang2022languagemodelarchitecturepretraining"></d-cite>.<br> | |
Beyond NLP, which is the focus of this blog post, the encoder-decoder architecture is widely used in other fields such as audio or time series. Note also that the encoder of such an architecture is used in some diffusion models.<br> | |
That's why we've decided to focus on the T5 <d-cite bibtex-key="JMLR:v21:20-074"></d-cite>.<br><br> | |
This article presents the optimisations we have implemented to efficiently pre-train a 147M-parameter T5 in French in a reasonable time (1,461 hours for 419B tokens) and with limited resources (a single A100, i.e. a computing budget of around 2,200 euros). | |
To achieve this, we designed CUDA/Triton kernels to make Flash Attention compatible with T5 and to provide linear-memory inference, thus extending the context size that the model can take into account.<br><br> | |
<strong>The pre-training code is available in our <a class="link" href="https://github.com/catie-aq/flashT5">GitHub repository</a> under the Apache-2.0 license, and the weights are available on our <a class="link" href="https://hf.co/CATIE-AQ">Hugging Face</a> account.</strong></p> | |
<p class="width_125"><br><br><br></p> | |
<h2 id="vue-d-ensemble-de-notre-travail">Overview of our work</h2> | |
<p class="width_125">We therefore chose to work with a T5, and more specifically with nanoT5 <d-cite bibtex-key="nawrot2023nanot5"></d-cite>.<br> | |
For the pretext tasks used during pre-training, we followed those of UL2 <d-cite bibtex-key='tay2023ul2'></d-cite>, with the following 7 tasks:</p> | |
<pre><code class="lang-py"> | |
denoiser_list=[ | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">3.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.15</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[R]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">8.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.15</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[R]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">4.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.0</span>, <span class="hljs-string">"max_spans"</span>: <span class="hljs-number">1</span>, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[S]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">3.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.5</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[X]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">8.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.15</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[X]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">64.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.15</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[X]"</span>}, | |
{<span class="hljs-string">"mu"</span>: <span class="hljs-number">64.0</span>, <span class="hljs-string">"r"</span>: <span class="hljs-number">0.5</span>, <span class="hljs-string">"max_spans"</span>: max_token_length, <span class="hljs-string">"prefix"</span>: <span class="hljs-string">"[X]"</span>}] | |
denoiser_proportions=[<span class="hljs-number">0.165</span>, <span class="hljs-number">0.165</span>, <span class="hljs-number">0.34</span>, <span class="hljs-number">0.0825</span>, <span class="hljs-number">0.0825</span>, <span class="hljs-number">0.0825</span>, <span class="hljs-number">0.0825</span>] | |
</code></pre> | |
<p class="width_125">where <code>mu</code> is the mean n-gram (span) length, <code>r</code> the masking rate and <code>prefix</code> the type of pretext task. | |
The meaning of the letters <code>[R]</code>, <code>[S]</code> and <code>[X]</code> is described <a class="link" href="https://huggingface.co/google/ul2#mixture-of-denoisers">here</a>, | |
and we invite you to take a look at the following <a class="link" href="https://raw.githubusercontent.com/google-research/google-research/master/ul2/figs/mod.png">image</a> in particular.</p> | |
<p class="width_125">To speed up training, we decided to rely on Flash Attention <d-cite bibtex-key="dao2022flashattention"></d-cite>. | |
However, as it does not manage the attentional (additive) biases of the T5, we had to extend it by developing a custom kernel. | |
More specifically, we successively developed two versions of this kernel. | |
In the first version, at the start of our work, we transmitted the bias matrix to the kernel. | |
In the current version, inspired by TurboT5 <d-cite bibtex-key='turbot5'></d-cite>, we only communicate a tensor containing the merged biases in order to materialise the bias matrix on the fly. | |
This makes it possible to switch from a T5 with quadratic memory to a T5 with linear memory, and consequently greatly increases the size of context that the model can support.</p> | |
<p class="width_125">Our work resulted in the pre-training of a T5 in French with 147M parameters: the FAT5 <i>small</i>.<br> | |
The dataset we used is made up of the French part of the CulturaX corpus <d-cite bibtex-key='nguyen2023culturax'></d-cite> (the main source with over 1258 GB of text), | |
the French part of Wikipedia <d-cite bibtex-key="wikidump"></d-cite> (dump 20231101), | |
justice_fr (French legal texts) <d-cite bibtex-key="justice_fr"></d-cite> | |
and 25,000,000 lines from TheStack <d-cite bibtex-key="Kocetkov2022TheStack"></d-cite> | |
(the idea here is to show our model a bit of code, although this is not our main objective).<br> | |
This model was evaluated on five tasks: text summarization, binary classification, question answering, named entity recognition and sentence similarity.</p> | |
<p class="width_125"><br><br><br></p> | |
<h2 id="les-d-tails-de-la-recette">Recipe details</h2> | |
<p class="width_125">With only two A100s (one 80GB and one 40GB), we had to spend some time implementing optimisations to get the best out of our hardware. | |
Indeed, before even training a model, or even modifying its architecture, we need to ensure that we are optimising the use of our GPUs' computing capacity. | |
There are several factors that can explain sub-optimal training of a deep learning model:<br> | |
• Disk-bound<br> | |
• Memory-bound<br> | |
• Compute-bound</p> | |
<p class="width_125">Ideally, we would like the model to be limited by the speed of calculation, i.e. the GPU to be used at full capacity. | |
With this in mind, we worked on three main points: <br> | |
• GPU disk optimisation <br> | |
• GPU memory bandwidth optimisation <br> | |
• Optimisation of the use of Tensor Cores<br> | |
</p> | |
<p class="width_125">So it's a combination of hardware and software issues.</p> | |
<p></p> | |
<p class="width_125">In the rest of this section, everything we have done/implemented to address the limitations encountered is available in a green box. Notes/comments can be found in a blue box. | |
<br><br></p> | |
<h3 id="optimisation-du-disque-du-gpu">GPU disk optimisation</h3> | |
<p class="width_125">Disk limitation occurs either during data loading or during pre-processing operations. | |
In both cases, the problem manifests itself as slowness. | |
<br></p> | |
<h4 id="acc-s-disques">Disk access</h4> | |
<p class="width_125">If the limitation comes from disk access, there are several possible solutions:</p> | |
<ul> | |
<li><p class="width_125"><u>Put data in RAM</u><br> | |
This solves the problem radically, but assumes that the dataset fits into RAM, which is far from guaranteed given the limited size of RAM.</p> | |
<div class="tip"><p>So this is not the solution we have chosen.</p></div> | |
</li> | |
<li><p class="width_125"><u>Put data on a faster and/or less-used disk</u><br> | |
If you have physical access to your GPU server, it is very useful to integrate <a class="link" href="https://fr.wikipedia.org/wiki/NVM_Express">NVMe</a> in its configuration.</p> | |
<p class="width_125">You also need to be careful not to have too many processes from different training runs pulling on the same disk. | |
It is therefore preferable to have several small disks rather than one large one.</p> | |
<div class="note"><p>A beneficial indirect effect is that such a configuration costs less 😉</p></div> | |
</li> | |
</ul> | |
<ul> | |
<li class="width_125"><u>Use more efficient file formats, particularly in terms of random accesses</u><br> | |
For example, <code>.parquet</code> files are more efficient than <code>.csv</code> files. | |
We can also use formats specifically developed for this purpose, such as the <code>.beton</code> format from ffcv <d-cite bibtex-key="leclerc2023ffcv"></d-cite>. | |
<div class="tip"><p>We use the Datasets library <d-cite bibtex-key="lhoest2021datasets"></d-cite> to load and process our data. | |
With this library, the data is decompressed locally in the <code>Arrow</code> format. | |
Moreover, if the data loaded from the Hugging Face Hub has been added using the <code>push_to_hub()</code> function, | |
then the dataset is converted to <code>parquet</code> by default.</p></div></li> | |
</ul> | |
<ul> | |
<li class="width_125"><u>Pre-tokenise data</u><br> | |
The most effective option is probably to pre-tokenise the data in order to optimise access. | |
In other words, tokenisation takes place in a preliminary stage and not on the fly. | |
<div class="tip"><p>Readers are invited to consult the following | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/main/examples/minipile/pretokenize_minipile.py">code</a>, which | |
illustrates how we proceed in our FAT5 tutorial applied to the Minipile dataset <d-cite bibtex-key="kaddour2023minipilechallengedataefficientlanguage"></d-cite>.</p></div> | |
</li> | |
</ul> | |
<p><br></p> | |
<h4 id="traitement-des-donn-es">Data processing</h4> | |
<p class="width_125">If the limitation comes from the processing of the data after it has been loaded:</p> | |
<ul> | |
<li><p class="width_125"><u>Several processes can be used to process data in parallel</u><br> | |
For example, the <code>num_workers</code> parameter of PyTorch's <code>DataLoader</code> <d-cite bibtex-key="paszke2019pytorch"></d-cite>.</p> | |
<div class="tip"><p>You can find in our code the values we use for this parameter for our <a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/configs/fr/fat5-fr-small.yaml#L42">FAT5 small</a>.</p></div></li> | |
</ul> | |
<ul> | |
<li><p class="width_125"><u>The bottleneck can also come from the <code>DataCollator</code></u><br> | |
This is especially the case when there are complex tasks to perform (image masking or multiple denoisers on NLP tasks).<br> | |
We can then build a custom <code>DataCollator</code> for the task. | |
Traditional methods can then be applied to optimise its speed. | |
For example, using NumPy's vectorisation will allow lists to be processed more quickly than with <code>for</code> loops. | |
Generally speaking, NumPy is faster than PyTorch for this type of task. | |
You can also use just-in-time compilers for Python such as Numba <d-cite bibtex-key="10.1145/2833157.2833162"></d-cite>.</p> | |
<div class="tip"><p>We followed this principle and developed a custom <code>DataCollator</code> for our FAT5. | |
You can find the code <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/data/data_collator_ul2.py">here</a>. | |
It manages UL2 pretext tasks and has a dynamic batch mechanism to reduce padding (more information in the next section).</p></div> | |
<div class="note"><p>As there was no implementation of UL2's <code>DataCollator</code> available in PyTorch until now, | |
we hope this may be useful for other work.</p></div></li> | |
</ul> | |
<ul> | |
<li><p class="width_125"><u>Effective padding</u></p> | |
<p class="width_125">When working with sequences, there is a natural tendency to pad a set of sequences in order to build batches. | |
The padding tokens then generate unnecessary calculations.<br> | |
The first thing to do is to pad only up to the length of the longest sequence in the batch, rather than up to a fixed maximum length. | |
This is the <a class="link" href="https://huggingface.co/learn/nlp-course/chapter3/2?fw=pt#dynamic-padding">dynamic padding</a> technique.<br> | |
With this approach, padding tokens may still remain. There are two ways of managing them:<br> | |
• use a method for grouping data of similar sizes | |
(for example, <a class="link" href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.group_by_length">this parameter</a> | |
in the Transformers library <d-cite bibtex-key="wolf2020huggingfaces"></d-cite> or | |
<a class="link" href="https://discuss.huggingface.co/t/how-to-implement-trainers-group-by-length-in-pytorch/9232">by retrieving this sampler</a> for PyTorch)<br> | |
• concatenate different examples in a custom DataCollator.</p> | |
<div class="tip"><p>We have opted for the second option and refer the reader back to the | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/data/data_collator_ul2.py">code</a> of our DataCollator.</p></div> | |
<div class="note"><p>More optimised heuristics probably need to be put in place. | |
We carried out a test by proposing a | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/src/data/data_collator_ul2.py#L45">function</a> | |
in the <code>DataCollator</code> to sort <code>input_ids</code> and <code>labels</code> by descending length. | |
However, this is rather time-consuming for a minimal packing gain. | |
More work needs to be done on this point. | |
</p></div> | |
</ul> | |
<p class="width_125"><br><br></p> | |
<h3 id="optimisation-de-la-bande-passante-de-la-m-moire-du-gpu">GPU memory bandwidth optimisation</h3> | |
<p class="width_125">Memory bandwidth limitation is more difficult to deal with. | |
A memory-bound operation is one whose overall execution time is limited by memory accesses. | |
This is particularly the case for LLMs, especially during inference. | |
The diagnosis can be made from the <a class="link" href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html">PyTorch profiler</a>:</p> | |
<figure class="width_125"> | |
<img src="https://pytorch.org/tutorials/_static/img/profiler_overview1.png" alt="profiler_overview1.png" width="100%"> | |
<figcaption><center><i>Source: <a class="link" href="https://pytorch.org/tutorials/_static/img/profiler_overview1.png">https://pytorch.org/tutorials/_static/img/profiler_overview1.png</a></i></center></figcaption> | |
</figure> | |
<br><br><br> | |
<p class="width_125">Another way of establishing a diagnosis is to use a simple <code>nvidia-smi</code>:</p> | |
<figure class="width_125"> | |
<img src="./assets/nvidiasmi.png" alt="nvidiasmi.png" width="100%"> | |
</figure> | |
<br> | |
<p class="width_125">This is useful for finding out whether a problem is present, but it gives limited information about its nature. | |
That's why we prefer the profiler.</p> | |
<p><br></p> | |
<h4 id="noyau-cuda">CUDA kernel</h4> | |
<p class="width_125">The main technique for optimising GPU memory bandwidth is to develop a CUDA kernel that fuses several memory-bound operations so that they are performed in SRAM. | |
This avoids copying large matrices to the HBM only to immediately reload them into SRAM. | |
This is now a common feature of decoder transformers thanks to <a class="link" href="https://github.com/Dao-AILab/flash-attention">Flash Attention</a>.</p> | |
<div class="tip"><p> | |
As Flash Attention does not manage the (additive) attentional biases of the T5, we extended it by developing a custom CUDA kernel. | |
As mentioned in the introduction, we actually implemented two successive versions of this kernel. | |
Without going into the details of the 650 lines of code in the implementation of the first version (which can be consulted | |
<a class="link" href="https://github.com/Dao-AILab/flash-attention/pull/617">here</a>), | |
the general and simplified idea (for a forward pass) is as follows:</p> | |
<ul> | |
<li>The expected output O, initialised with 0's, is loaded from the HBM to the SRAM, as well as the query Q, the key K, the value V and the biases B.</li> | |
<li>Our CUDA kernel calculates the following steps:<br> | |
• Compute the matrix S using the matrix product of Q and the transpose of K<br> | |
• Compute S', which is the sum of the S matrix and the bias matrix B<br> | |
• Compute P, which is the softmax (cumulative under the hood) of S’<br> | |
• Compute the output O, which is the matrix product of P and V<br> | |
</li> | |
<li>Output O is written back to the HBM and the SRAM is cleared. | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/FAT5_dark.gif"> | |
<img alt="FAT5 animation" src="./assets/FAT5.gif" width="100%"> | |
</picture> | |
</ul> | |
<br> | |
<p>While the first version of the kernel is generic, the second (available <a class="link" href="https://github.com/Dao-AILab/flash-attention/pull/956">here</a>) | |
is specific to the working of models with relative positional encoding (which is the case of the T5). | |
The general and simplified idea (for a forward pass) is as follows:</p> | |
<ul> | |
<li>In the HBM, we have the expected output O initialised with 0s, as well as the query Q, the key K and the value V. | |
However, instead of the bias matrix B used previously, we have a tensor T containing the bucketed biases.</li> | |
<li>O is loaded from the HBM to the SRAM, along with the query Q, the key K, the value V and the tensor T.</li> | |
<li>Our CUDA kernel calculates the following steps:<br> | |
• Compute the matrix S using the matrix product of Q and the transpose of K<br> | |
• Compute S', which is the sum of the matrix S and a matrix filled with the elements of T<br> | |
• Compute P, which is the softmax (cumulative under the hood) of S’<br> | |
• Compute the output O, which is the matrix product of P and V<br> | |
</li> | |
<li>Output O is written back to the HBM and the SRAM is cleared. | |
</ul> | |
<p> | |
In this way, whereas the first version, which materialises the bias matrix B, required quadratic memory, | |
here we are back to linear memory, enabling inference to be performed on tens of thousands of tokens.<br> | |
To design this second version, we were inspired by the TurboT5's Triton kernel, which we ported to CUDA and extended to full BF16. | |
</p> | |
</div> | |
<br> | |
<div class="tip"><p>Note that the two versions developed can be used with several positional encodings.<br> | |
We invite the reader to consult this <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/utils/positional_encoding.py">file</a> | |
containing classes compatible with Flash Attention for the | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/src/utils/positional_encoding.py#L10">RelativePositionalEncoding</a> | |
<d-cite bibtex-key="shaw2018selfattention"></d-cite>, | |
the <a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/src/utils/positional_encoding.py#L113">ALiBiPositionalEncoding</a> | |
<d-cite bibtex-key="press2022train"></d-cite>, | |
the <a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/src/utils/positional_encoding.py#L205">RotaryPositionalEncoding</a> | |
<d-cite bibtex-key="su2023roformer"></d-cite> and | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/dfe10d498ae0b39082182f807acb509e91992360/src/utils/positional_encoding.py#L341">FIRE</a> <d-cite bibtex-key="li2024functional"></d-cite>.</p></div> | |
<div class="note"><p>At the time of writing, the two pull requests (one for each kernel version, | |
available <a class="link" href="https://github.com/Dao-AILab/flash-attention/pull/617">here</a> | |
and <a class="link" href="https://github.com/Dao-AILab/flash-attention/pull/956">here</a>) | |
opened on the official Flash Attention repository have not been merged. | |
Readers will therefore have to temporarily recompile our custom Flash Attention patches to be able to use our models.<br> | |
Readers are invited to consult the Benchmark section further below to see the improvements brought by these two kernels.</p></div> | |
<br> | |
<div class="note"><p>Although we didn't use them, it should be noted that some libraries contain merged implementations of common operators, for example Apex <d-cite bibtex-key="nvidiapex"></d-cite>.</p></div> | |
<p><br></p> | |
<h4 id="noyau-triton">Triton kernel</h4> | |
<p class="width_125">Triton <d-cite bibtex-key="10.1145/3315508.3329973"></d-cite> is an actively maintained programming language that allows Python code to be compiled into efficient GPU kernels, like CUDA, but with the advantage of being (from our point of view) easier to learn. Unlike CUDA, which requires an in-depth understanding of GPU hardware architecture, Triton abstracts away many low-level details such as memory coalescing, shared memory management and scheduling within CUDA thread blocks.</p> | |
<div class="tip"><p>A Triton implementation of the | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/model/ops/flash_attention_v2_bias.py">Flash Attention 2 managing attention bias</a> | |
is provided for those who do not wish to recompile a custom patch for Flash Attention. | |
To do this, we based ourselves on the FlagAttention repository <d-cite bibtex-key="flagattention"></d-cite>. | |
<br> | |
<br> | |
In addition to this implementation (whose use is optional), other parts of the architecture have been optimised using ad hoc Triton kernels, namely: | |
<br> | |
• the <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/model/ops/cross_entropy_loss.py">cross entropy loss</a> (and the z-loss <d-cite bibtex-key="debrébisson2016zloss"></d-cite>) <br> | |
• the <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/model/ops/rms_norm.py">RMSNorm layer</a> <d-cite bibtex-key="zhang2019root"></d-cite> <br> | |
<br> | |
We drew inspiration from <a class="link" href="https://github.com/unslothai/unsloth">Unsloth</a> <d-cite bibtex-key="unsloth"></d-cite>.<br> | |
<br> | |
Readers are invited to refer to the Benchmark section below to see the impact of this optimisation.</div> | |
<p><br></p> | |
<h4 id="utiliser-torch-compile-">Use <code>torch.compile</code></h4> | |
<p class="width_125">A simpler approach is to compile the models with <code>torch.compile</code>. | |
PyTorch then takes care of performing the possible fusions, possibly by reordering operations. | |
This involves hunting down graph breaks, i.e. points where execution falls back to eager mode, which have a negative impact on performance.</p> | |
<div class="note"><p>See the <a class="link" href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html">official documentation</a> for more details.</p></div> | |
<p class="width_125">Another possibility is to use both a custom kernel and <code>torch.compile</code>. | |
The implementation of this option has been greatly simplified since the | |
<a class="link" href="https://github.com/pytorch/pytorch/releases/tag/v2.4.0">version 2.4 of PyTorch</a>.</p> | |
<p class="width_125">Readers are invited to refer to the benchmark section at the end of the article to compare the memory | |
performance of the various techniques described.</p> | |
<p class="width_125"><br><br></p> | |
<h3 id="optimisation-de-l-utilisation-des-tensor-cores">Optimisation of the use of Tensor Cores</h3> | |
<p class="width_125">Recent GPUs have units dedicated to tensor operations: Tensor Cores. Using them correctly is essential.</p> | |
<p class="width_125">Once again, to establish a diagnosis, it is advisable to refer to the PyTorch profiler, which indicates the proportion of Tensor Core usage for each CUDA kernel:</p> | |
<figure class="width_125"> | |
<img src="https://pytorch.org/tutorials/_static/img/profiler_kernel_view.png" alt="profiler_kernel_view.png" width="100%"> | |
<figcaption><center><i>Source: <a class="link" href="https://pytorch.org/tutorials/_static/img/profiler_kernel_view.png">https://pytorch.org/tutorials/_static/img/profiler_kernel_view.png</a></i></center></figcaption> | |
</figure> | |
<br><br> | |
<p class="width_125">The optimisations that can be made are:<br></p> | |
<h4 id="puissances-de-2">Use multiples of 8 or 64</h4> | |
<p class="width_125">The first is to use tensor sizes that are multiples of 8 or 64. | |
Please refer to the Nvidia documentation, | |
in particular this <a class="link" href="https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc">article</a> | |
and this <a class="link" href="https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/">article</a> | |
to determine the multiple to select according to the desired precision.</p> | |
<div class="tip"><p>With this in mind, we trained a tokenizer with a vocabulary size of 32,768 (8**5), | |
following <a class="link" href="https://twitter.com/karpathy/status/1621578354024677377">this observation by Karpathy</a>. | |
This is a BPE tokenizer <d-cite bibtex-key="sennrich2016neuralmachinetranslationrare"></d-cite> trained on CulturaX and The Stack, using 256 extra tokens, and in which numbers are separated.<br> | |
Readers can find the code used <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/examples/fat5-fr/train_tokenizer.py">here</a>. | |
</p></div> | |
<p><br></p> | |
<h4 id="utiliser-le-bon-optimiseur">Use the right optimiser</h4> | |
<p class="width_125">Changing optimisers from the initial implementation of the model can be a good way of speeding up convergence of the model (although it may prevent the results of the original paper from being reproduced).<br> | |
Some optimisers speed up convergence by allowing larger batch sizes, as is the case with LAMB <d-cite bibtex-key="you2020large"></d-cite>, | |
or higher learning rates, as is the case with Sophia <d-cite bibtex-key="liu2024sophia"></d-cite>.<br> | |
More efficient versions of the optimisers can also be used, such as the <code>fused</code> option | |
in the <a class="link" href="https://pytorch.org/docs/stable/generated/torch.optim.Adam.html">Adam optimiser</a> available in PyTorch | |
or the optimisers available in <a class="link" href="https://github.com/NVIDIA/apex">Apex</a>.</p> | |
<div class="tip"><p> | |
We used the original T5 optimiser, <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/utils/adamw_scaled.py">AdamWScale</a>. | |
For hyperparameter values, we use <code>lr = 5e-3</code>, <code>betas = (0.9, 0.999)</code>, <code>eps = 1e-6</code> and <code>weight_decay = 0.0</code> | |
based on the observations of <a class="link" href="https://github.com/PiotrNawrot/nanoT5/issues/25#issuecomment-1922731400">Wilson Wongso</a>. | |
Indeed, it turns out that not all the alternative optimisers tested converged.</p></div> | |
<div class="note"><p>We have added the parameter <code>foreach</code> in our version of AdamWScale.</p> | |
</div> | |
<p><br></p> | |
<h4 id="entra-ner-ses-mod-les-en-bf16-ou-fp16-">Training models in <code>bf16</code></h4> | |
<p class="width_125">Recent GPUs make it possible to fully exploit reduced precision | |
(enabling a gain of a factor of 2 in throughput compared to <code>fp32</code> precision). | |
<code>bf16</code> is only available on Ampere or more recent architectures, but thanks to its wider dynamic range (its exponent is coded on 8 bits, like <code>fp32</code>), | |
it avoids the need for the loss scaling method <d-cite bibtex-key="micikevicius2018mixed"></d-cite>, | |
which is generally necessary in <code>fp16</code>.</p> | |
<div class="tip"><p>With this in mind, we train our models in <code>bf16</code>. | |
More specifically, whereas we used <code>bf16-mixed</code> at the beginning of our experiments, we now use the | |
<a class="link" href="https://en.wikipedia.org/wiki/Kahan_summation_algorithm">Kahan summation algorithm</a> | |
so that we can use <code>full bf16</code> in our optimizer.<br> | |
Once again, the code for our optimizer is accessible <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/utils/adamw_scaled.py">here</a>. | |
</p></div> | |
<p><br></p> | |
<h4 id="utiliser-moins-de-m-moire-du-gpu">Use less GPU memory</h4> | |
<p class="width_125">Certain techniques exist to limit the GPU memory used by the model, such as | |
<a class="link" href="https://pytorch.org/docs/stable/checkpoint.html">gradient checkpointing</a> | |
or ZeRO-type methods <d-cite bibtex-key="rajbhandari2020zero"></d-cite> implemented in | |
<a class="link" href="https://github.com/microsoft/DeepSpeed">DeepSpeed</a>. | |
By limiting the amount of memory used, larger batch sizes can be used to speed up model training.</p> | |
<p class="width_125"><br><br></p> | |
<h3 id="autres">Other</h3> | |
<h4 id="le-parall-lisme">Parallelism</h4> | |
<p class="width_125">Using several GPUs is tricky. | |
Done naively, it can result in lower performance than a single-GPU implementation, wasting computing resources. | |
This is particularly the case when bottlenecks occur in communications between GPUs. | |
The aim is to ensure that the model is not limited by the bandwidth between the cards, or to ensure that the cards are connected with sufficient | |
bandwidth via techniques such as <a class="link" href="https://en.wikipedia.org/wiki/NVLink">NVLink</a>, for example. </p> | |
<p class="width_125">It should also be noted that optimisation techniques generally require all the GPUs to be synchronised at the end of a batch. | |
As a result, if one GPU is slower than the others (or is being used by another process), the model is limited to the speed of the slowest GPU in the group. </p> | |
<div class="note"><p> | |
Having pre-trained our model on a single 80GB A100, we were unable to experiment with parallelism.</p> | |
</div> | |
<p><br></p> | |
<h4 id="les-t-tes-pour-le-finetuning">Finetuning heads</h4> | |
<p class="width_125">We looked at the elements listed above with a view to optimising the pre-training of our model. | |
In practice, we then need to fine-tune it to specialise on the final tasks that interest us. | |
To do this, we use heads. For the <a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5">vanilla T5</a>, | |
five are available in Transformers to perform all feasible tasks: | |
<a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForConditionalGeneration"><code>T5ForConditionalGeneration</code></a>, | |
<a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForSequenceClassification"><code>T5ForSequenceClassification</code></a>, | |
<a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForTokenClassification"><code>T5ForTokenClassification</code></a>, | |
<a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForQuestionAnswering"><code>T5ForQuestionAnswering</code></a> | |
and <a class="link" href="https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel"><code>T5EncoderModel</code></a>.<br><br> | |
Here again, optimisation work can be carried out.<br> | |
For conditional generation, the main point is to ensure that the generation process is efficient.<br> | |
For heads involved in classification tasks (sequence, NER and QA), it is necessary to ensure that only the encoder part | |
of T5 is used, since the decoder is not essential for these tasks, as shown in EncT5 <d-cite bibtex-key="liu2022enct5"></d-cite>. | |
The decoder weights take up unnecessary memory space, and the execution time of the finetuning code is doubled unnecessarily.<br> | |
The last head is simply used to retain only the encoder part of an encoder-decoder model. It therefore does not need to be optimised.</p> | |
<div class="tip"><p> | |
Regarding the <code>ForConditionalGeneration</code> head, our | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/684d02640464ea8bd2339689ce37da2d4e3b5f0b/src/model/modeling_flash_t5.py#L593">implementation</a> | |
is based on the generation process available in the | |
<a class="link" href="https://github.com/PiotrNawrot/nanoT5/blob/1c82d67bf8dea635be68a3b2a68a43b68b665193/nanoT5/utils/t5_model.py#L407">nanoT5</a> | |
because it is 14% faster than the Hugging Face implementation.<br> | |
For classification heads, the implementation is available in this | |
<a class="link" href="https://github.com/catie-aq/flashT5/blob/main/src/model/custom_heads_flash_t5.py">file</a>. | |
This file is separate from the modelling file because our implementations differ from those available in Transformers. | |
Indeed, the <code>T5ForSequenceClassification</code> and <code>T5ForQuestionAnswering</code> heads available in Transformers are based | |
on both the T5 encoder and decoder, which is inefficient. | |
We therefore recoded these two heads to use only the encoder. | |
To do so, we followed the same structure as the <code>T5ForTokenClassification</code> head available in Transformers, | |
which already uses only the encoder and which we therefore use as is.</p> | |
</div> | |
<p class="width_125"><br><br><br></p> | |
<h2 id="benchmark">Benchmark</h2> | |
<h3 id="TFLOPS">TFLOPS</h3> | |
<p class="width_125"> | |
The number of TFLOPS (trillions of floating-point calculations a processor can perform in one second) is probably the most telling metric to demonstrate the impact of the optimizations carried out.<br> | |
We compare four approaches:<br> | |
• SPDA (Scaled Dot Product Attention) implementation with full bias,<br> | |
• the same implementation but in Triton,<br> | |
• the Flash Attention RPE implementation, i.e. the second kernel we've developed (can be seen as turboT5 but in C++/CUDA),<br> | |
• the Flash Attention implementation, i.e. without bias. We've included it for reference, but it's unusable in practice for a T5.<br> | |
<br> | |
For the forward pass, we have: | |
</p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/FWD-causal-True_dark.png" width="100%"> | |
<img alt="TFLOPS benchmark, forward pass" src="./assets/FWD-causal-True.png" width="100%"> | |
</picture> | |
</p> | |
<div class="width_125"><p>For the forward pass, the Triton approach achieves 1.34 times more FLOPS than the SPDA approach, while the Flash Attention RPE approach achieves 1.99 times more FLOPS than the SPDA approach.<br> | |
We can also see that our bf16 implementation is equivalent to fp16 (doing even better at size 512).<br> | |
Following this benchmark, we decided to train our French model in bf16, head_dim = 128 and with a sequence of 1024.</p></div> | |
<br> | |
<p class="width_125"> | |
For the backward pass, we have: | |
</p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/BWD-causal-True_dark.png" width="100%"> | |
<img alt="TFLOPS benchmark, backward pass" src="./assets/BWD-causal-True.png" width="100%"> | |
</picture> | |
</p> | |
<div class="width_125"><p>For the backward pass, the Triton implementation performed worse than SPDA, with 0.71 times the FLOPS of SPDA. The Flash Attention RPE implementation is more or less equivalent to SPDA (1.018 times more FLOPS).<br> | |
We can also observe that Triton in head_dim 64 is more efficient than Triton in head_dim 128.</p></div> | |
<p><br></p> | |
<h4 id="torchvstriton">Torch vs Triton</h4> | |
<p class="width_125"> | |
We mentioned previously that we had optimised parts of the architecture using ad hoc Triton kernels, namely the cross-entropy and RMSNorm layer. | |
The following benchmarks should illustrate why.<br> | |
For cross-entropy, we get a forward pass 7 to 11.4 times faster, a backward pass 3.26 to 3.75 times faster, and memory usage reduced by a factor of 4:</p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/CE_dark.png" width="100%"> | |
<img alt="Torch vs Triton benchmark, cross-entropy" src="./assets/CE.png" width="100%"> | |
</picture> | |
</p> | |
<p class="width_125"> | |
For the RMSNorm layer, we get a forward pass 3 to 5 times faster, a backward pass 2.33 to 4.33 times faster, and memory usage reduced by a factor of 3.2:</p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/LN_dark.png" width="100%"> | |
<img alt="Torch vs Triton benchmark, RMSNorm" src="./assets/BLN.png" width="100%"> | |
</picture> | |
</p> | |
<p class="note"> | |
Note that all the benchmark graphs can be generated automatically using the following <a class="link" href="https://github.com/catie-aq/flashT5/tree/main/benchmarks">code</a>. | |
</p> | |
<p><br><br></p> | |
<h3 id="mod-le-en-fran-ais">Model in French</h3> | |
<p class="width_125">We applied our work to French by pre-training a 147M-parameter model. <br> | |
The dataset we used is a mixture of CulturaX, Wikipedia, justice_fr and The Stack. <br> | |
Our tokenizer, with a vocabulary of 32,768 (8<sup>5</sup>) tokens, is trained on CulturaX and The Stack.<br> | |
Our model is pre-trained on sequences of 1,024 tokens.</p> | |
<p class="width_125"> | |
We wanted to compare the performance of our model with other previously published French-language models, such as CamemBERT <d-cite bibtex-key="Martin_2020"></d-cite> for classification tasks and BARThez <d-cite bibtex-key="eddine2021barthez"></d-cite> for generation tasks.<br> | |
For this reason, we thought it important to make comparisons with an equivalent number of tokens seen. | |
We therefore tried to estimate the number of tokens seen by these two models using the formula: number of steps × sequence size × batch size. We couldn't find the information in the BARThez publication to do this. For CamemBERT, we estimate a maximum of 419.4B tokens. This figure could actually be lower, as we don't know the number of padding tokens seen by this model (whereas we don't use any). We therefore pre-trained our model on the maximum number of tokens seen by CamemBERT.<br></p> | |
<p><br></p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/loss_train.png" width="49%"> | |
<img alt="FAT5 training loss" src="./assets/loss_train.png" width="49%"> | |
</picture> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/loss_eval.png" width="49%"> | |
<img alt="FAT5 evaluation loss" src="./assets/loss_eval.png" width="49%"> | |
</picture> | |
</p> | |
<p><br></p> | |
<p class="width_125"> | |
We were also interested in comparing our model against itself, i.e. we evaluated its performance on downstream tasks every 100,000 steps (~26 billion tokens) during pre-training.<br> | |
In the table below, we have listed the number of tokens equivalent to each interval of 100,000 steps.<br> | |
</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>Number of tokens ✝</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>FAT5-small-100K</td> | |
<td>26,214,400,000 (100,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-200K</td> | |
<td>52,428,800,000 (200,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-300K</td> | |
<td>78,643,200,000 (300,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-400K</td> | |
<td>104,857,600,000 (400,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-500K</td> | |
<td>131,072,000,000 (500,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-600K</td> | |
<td>157,286,400,000 (600,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-700K</td> | |
<td>183,500,800,000 (700,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-800K</td> | |
<td>209,715,200,000 (800,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-900K</td> | |
<td>235,929,600,000 (900,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1000K</td> | |
<td>262,144,000,000 (1,000,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1100K</td> | |
<td>288,358,400,000 (1,100,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1200K</td> | |
<td>314,572,800,000 (1,200,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1300K</td> | |
<td>340,787,200,000 (1,300,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1400K</td> | |
<td>367,001,600,000 (1,400,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1500K</td> | |
<td>393,216,000,000 (1,500,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1600K</td> | |
<td>419,430,400,000 (1,600,000 × 1,024 × 256)</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://hf.co/almanach/camembert-base">camembert (base or large)</a></td> | |
<td>419,430,400,000 (100,000 × 512 × 8,192)</td> | |
</tr> | |
</tbody> | |
</table> | |
<p class="width_125">✝ equivalent to number of steps × sequence size × batch size</p> | |
<p><br></p> | |
<h4 id="finetuning">Finetuning</h4> | |
<p class="width_125">We focused on five tasks:<br> | |
• Summarising texts to illustrate the use of the head <code>T5ForConditionalGeneration</code>,<br> | |
• Binary classification to illustrate the use of the head <code>T5ForSequenceClassification</code>,<br> | |
• Named entity recognition to illustrate the use of the head <code>T5ForTokenClassification</code>,<br> | |
• Question answering to illustrate the use of the head <code>T5ForQuestionAnswering</code>,<br> | |
• Sentence similarity to illustrate the use of the head <code>T5EncoderModel</code>.</p> | |
<p class="width_125">We consider classification tasks important to evaluate, as they are generally ignored by benchmarks of generative language models, even though they are often used in practice by companies (document retrieval, classification for customer reviews, data anonymization, etc.). | |
This is illustrated by the fact that, 6 and a half years after its release, BERT <d-cite bibtex-key="devlin2019bert"></d-cite> alone is downloaded more times per month than the <a class="link" href="https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads">30 most downloaded text generation models</a> on Hugging Face at the time of writing: 38.5M versus 31.3M.</p> | |
<p class="width_125">In the following tables, we underline for FAT5 the line with the best result for each task. We interpret the results of the generation part after the text summarization table. The classification results are interpreted after the binary classification, QA, NER and sentence-similarity tables.</p> | |
<p><br></p> | |
<h5>Summarization</h5> | |
<p class="width_125">For this task, we used the dataset <a class="link" href="https://huggingface.co/datasets/orange_sum">orange_sum</a><d-cite bibtex-key="eddine2021barthez"></d-cite>.</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>ROUGE-1</th> | |
<th>ROUGE-2</th> | |
<th>ROUGE-L</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>FAT5-small-100K (147M)</td> | |
<td>28.17</td> | |
<td>10.60</td> | |
<td>20.62</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-200K (147M)</td> | |
<td>28.72</td> | |
<td>10.86</td> | |
<td>20.68</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-300K (147M)</td> | |
<td>28.76</td> | |
<td>10.85</td> | |
<td>20.63</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-400K (147M)</td> | |
<td>28.59</td> | |
<td>10.76</td> | |
<td>20.60</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-500K (147M)</td> | |
<td>28.98</td> | |
<td>10.97</td> | |
<td>20.72</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-600K (147M)</td> | |
<td>29.04</td> | |
<td>11.20</td> | |
<td>20.89</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-700K (147M)</td> | |
<td>28.72</td> | |
<td>10.87</td> | |
<td>20.77</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-800K (147M)</td> | |
<td>29.00</td> | |
<td>10.91</td> | |
<td>20.78</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-900K (147M)</td> | |
<td>29.30</td> | |
<td>11.34</td> | |
<td>21.22</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1000K (147M)</td> | |
<td>29.10</td> | |
<td>11.21</td> | |
<td>21.08</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1100K (147M)</td> | |
<td>29.43</td> | |
<td>11.40</td> | |
<td>21.15</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1200K (147M)</td> | |
<td>29.30</td> | |
<td>11.38</td> | |
<td>21.18</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1300K (147M)</td> | |
<td>29.38</td> | |
<td>11.38</td> | |
<td>21.18</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1400K (147M)</td> | |
<td>29.29</td> | |
<td>11.18</td> | |
<td>21.14</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1500K (147M)</td> | |
<td><u>29.48</u></td> | |
<td><u>11.48</u></td> | |
<td><u>21.22</u></td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1600K (147M)</td> | |
<td>29.30</td> | |
<td>11.27</td> | |
<td>21.10</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/moussaKam/barthez">Barthez<d-cite bibtex-key="eddine2021barthez"></d-cite></a> (165M)</td> | |
<td>31.44</td> | |
<td>12.77</td> | |
<td>22.23</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/moussaKam/mbarthez">mBarthez</a> (458M)</td> | |
<td>32.67</td> | |
<td>13.73</td> | |
<td>23.18</td> | |
</tr> | |
</tbody> | |
</table> | |
<p><br></p> | |
<p class="width_125">We can see that our model performs worse than Barthez. We can put forward a few hypotheses to explain this. <br> | |
Firstly, it's likely that our text generation process is not optimal. Not knowing the one used by Barthez's authors, we simply used the default parameters of the <a class="link" href="https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L1905">generate</a> function in Transformers to avoid giving our model an advantage with a more sophisticated generation process.<br> | |
Secondly, we didn't use a prompt to condition the generation, which could have benefited our model since T5 is the model that introduced this approach.<br> | |
Thirdly, Barthez most likely saw more tokens than our model. Although we can't determine this number from the authors' publication, it is indicated that this is a BART model <d-cite bibtex-key="lewis2019bartdenoisingsequencetosequencepretraining"></d-cite> which received additional pre-training on French. Yet BART's paper alone states that the original model was trained for 500,000 steps × a sequence of 1,024 tokens × a batch of size 8,000, i.e. 4,096,000,000,000 tokens, which is 9.76 times more than our model. | |
</p> | |
<p><br></p> | |
<h5 id="classification">Classification</h5> | |
<p class="width_125">We use a cleaned version of the allocine dataset <d-cite bibtex-key="allocine"></d-cite>: <a class="link" href="https://huggingface.co/datasets/CATIE-AQ/allocine_clean">allocine_clean</a>. Specifically, 0.6% of the test split was unreliable because it contained leaks or duplicate data. It is likely that the resulting dataset is still imperfect, with annotation problems requiring further proofreading/correction. | |
</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>Accuracy</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>FAT5-small-100K (67.4M)</td> | |
<td>96.05</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-200K (67.4M)</td> | |
<td>96.20</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-300K (67.4M)</td> | |
<td>96.48</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-400K (67.4M)</td> | |
<td>96.60</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-500K (67.4M)</td> | |
<td>96.60</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-600K (67.4M)</td> | |
<td>96.60</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-700K (67.4M)</td> | |
<td>96.68</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-800K (67.4M)</td> | |
<td>96.59</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-900K (67.4M)</td> | |
<td><u>96.75</u></td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1000K (67.4M)</td> | |
<td>96.62</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1100K (67.4M)</td> | |
<td>96.69</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1200K (67.4M)</td> | |
<td>96.71</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1300K (67.4M)</td> | |
<td>96.69</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1400K (67.4M)</td> | |
<td>96.65</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1500K (67.4M)</td> | |
<td>96.57</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1600K (67.4M)</td> | |
<td>96.69</td> | |
</tr> | |
<tr> | |
<td>distillcamembert (68.1M)</td> | |
<td>96.74</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/bourdoiscatie/camembert_base_cls">camembert-base</a> (111M)</td> | |
<td>97.27</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/bourdoiscatie/camembert_large_cls">camembert-large</a> (337M)</td> | |
<td>97.15</td> | |
</tr> | |
</tbody> | |
</table> | |
<p class="width_125">Note: in this and the following tables, distillcamembert refers to a <a class="link" href="https://huggingface.co/cmarkea/distilcamembert-base">distilcamembert-base</a> <d-cite bibtex-key="delestre2022distilcamembert"></d-cite> that we have finetuned.</p> | |
<p><br></p> | |
<h5>Named entity recognition</h5> | |
<p class="width_125">For this task, we used frenchNER in its <a class="link" href="https://huggingface.co/datasets/CATIE-AQ/frenchNER_4entities">4 entities</a> (PER, LOC, ORG, MISC) <d-cite bibtex-key="frenchNER2024"></d-cite> configuration.</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>F1 PER</th> | |
<th>F1 LOC</th> | |
<th>F1 ORG</th> | |
<th>F1 MISC</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>FAT5-small-100K (67.1M)</td> | |
<td>96.51</td> | |
<td>94.48</td> | |
<td>87.24</td> | |
<td>75.81</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-200K (67.1M)</td> | |
<td>96.90</td> | |
<td>94.83</td> | |
<td>88.78</td> | |
<td>76.82</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-300K (67.1M)</td> | |
<td>97.25</td> | |
<td>95.11</td> | |
<td>88.86</td> | |
<td><u>77.48</u></td> | |
</tr> | |
<tr> | |
<td>FAT5-small-400K (67.1M)</td> | |
<td>97.18</td> | |
<td>95.08</td> | |
<td>89.11</td> | |
<td>77.42</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-500K (67.1M)</td> | |
<td>97.25</td> | |
<td>95.16</td> | |
<td>89.16</td> | |
<td>76.91</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-600K (67.1M)</td> | |
<td>97.19</td> | |
<td>95.19</td> | |
<td>88.85</td> | |
<td>76.88</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-700K (67.1M)</td> | |
<td>97.17</td> | |
<td>95.14</td> | |
<td>89.39</td> | |
<td>76.82</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-800K (67.1M)</td> | |
<td><u>97.34</u></td> | |
<td>95.20</td> | |
<td>89.18</td> | |
<td>77.27</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-900K (67.1M)</td> | |
<td>97.19</td> | |
<td>95.21</td> | |
<td>89.04</td> | |
<td>76.83</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1000K (67.1M)</td> | |
<td>97.31</td> | |
<td>95.26</td> | |
<td>89.24</td> | |
<td>76.84</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1100K (67.1M)</td> | |
<td>97.11</td> | |
<td>94.99</td> | |
<td>88.52</td> | |
<td>76.30</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1200K (67.1M)</td> | |
<td>97.19</td> | |
<td>95.11</td> | |
<td>88.79</td> | |
<td>76.86</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1300K (67.1M)</td> | |
<td>97.15</td> | |
<td>95.00</td> | |
<td>88.62</td> | |
<td>76.58</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1400K (67.1M)</td> | |
<td>97.22</td> | |
<td>95.09</td> | |
<td>89.01</td> | |
<td>77.00</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1500K (67.1M)</td> | |
<td>97.32</td> | |
<td><u>95.34</u></td> | |
<td><u>89.39</u></td> | |
<td>77.30</td> | |
</tr> | |
<tr> | |
<td>FAT5-small-1600K (67.1M)</td> | |
<td>97.14</td> | |
<td>95.22</td> | |
<td>89.24</td> | |
<td>76.88</td> | |
</tr> | |
<tr> | |
<td>distillcamembert (67.5M)</td> | |
<td>97.26</td> | |
<td>95.24</td> | |
<td>89.10</td> | |
<td>79.88</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/CATIE-AQ/NERmembert-base-4entities">camembert-base</a> (110M)</td> | |
<td>97.80</td> | |
<td>95.78</td> | |
<td>90.27</td> | |
<td>81.38</td> | |
</tr> | |
<tr> | |
<td><a class="link" href="https://huggingface.co/CATIE-AQ/NERmembert-large-4entities">camembert-large</a> (336M)</td> | |
<td>98.17</td> | |
<td>96.37</td> | |
<td>91.87</td> | |
<td>83.35</td> | |
</tr> | |
</tbody> | |
</table> | |
<p><br></p> | |
<h5 id="question-answering">Question Answering</h5> | |
<p class="width_125"> | |
We wanted to finetune our model on this task but realized that our tokenizer has two problems.<br> | |
Firstly, we forgot to add the beginning-of-sequence (BOS) token. | |
Secondly, we decided to use a fast BPE tokenizer. We learned afterwards that the <code>add_special_tokens=True</code> argument doesn't work with this type of tokenizer. | |
Correcting these two points requires us to post-process the tokenizer's encodings before performing our finetuning task, which isn't elegant and requires time we don't have right now.</p> | |
<p><br></p> | |
<h5><i>Sentence Similarity</i></h5> | |
<p class="width_125"> | |
We invite the reader to take the results of this section with a grain of salt.<br> | |
We performed a finetuning on this task in order to verify that the <code>T5EncoderModel</code> head was working, but we are not focusing on the results obtained because we are questioning the quality of the benchmark on which we are evaluating the models, namely MTEB FR <d-cite bibtex-key="ciancone2024mtebfrenchresourcesfrenchsentence"></d-cite>, a French version of MTEB.<br> | |
Indeed, Nils Reimers, one of the creators of MTEB, recently questioned the relevance of this benchmark in a <a class="link" href="https://x.com/Nils_Reimers/status/1870812625505849849">tweet</a>, declaring it "dead". | |
Earlier in the year, we observed data leaks and duplications in this benchmark | |
(see <a class="link" href="https://huggingface.co/datasets/lbourdois/MTEB_leaks_and_duplications">here</a> and | |
<a class="link" href="https://github.com/embeddings-benchmark/mteb/issues/1036">here</a>). | |
Alexey Vatolin then extended these observations to include empty lines (see <a class="link" href="https://github.com/embeddings-benchmark/mteb/issues/1049#issuecomment-2463095122">here</a>). | |
<br> | |
In the table below, we finetuned on a cleaned version of the dataset <code>stsb_multi_mt</code> <d-cite bibtex-key="huggingface:dataset:stsb_multi_mt"></d-cite> (0.653% of the test split was unreliable because it contained leaks or duplicated data) before evaluating on MTEB FR. | |
<br> | |
</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>Average</th> | |
<th>Classification</th> | |
<th>Clustering</th> | |
<th>PairClassification</th> | |
<th>Reranking</th> | |
<th>Retrieval</th> | |
<th>STS</th> | |
<th>Summary</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>FAT5-small-400K (67.1M)</td> | |
<td>52.2</td> | |
<td>59.8</td> | |
<td>39.1</td> | |
<td>77.5</td> | |
<td>56.1</td> | |
<td>29.1</td> | |
<td>74</td> | |
<td>29.8</td> | |
</tr> | |
<tr> | |
<td>distillcamembert(68.1M)</td> | |
<td>51.3</td> | |
<td>60.7</td> | |
<td>37.4</td> | |
<td>77</td> | |
<td>51.1</td> | |
<td>25.2</td> | |
<td>76.4</td> | |
<td>31.3</td> | |
</tr> | |
</tbody> | |
</table> | |
<p><br><br><br></p> | |
<p class="width_125"> | |
We can see from the masked accuracy convergence graph that the performance of the encoder part of the model progresses initially before flattening out. | |
</p> | |
<p><br></p> | |
<p class="width_125"> | |
<picture> | |
<source media="(prefers-color-scheme: dark)" srcset="./assets/convergence_masked_accuracy_FAT5.png" width="100%"> | |
<img alt="Convergence masked accuracy FAT5" src="./assets/convergence_masked_accuracy_FAT5.png" width="100%"> | |
</picture> | |
</p> | |
<p><br></p> | |
<p class="width_125"> | |
This phenomenon can also be observed in the finetuning results: FAT5 matches the performance of distilcamembert at around 800 or 900K steps (except for the MISC entity for the NER task), but does no better beyond that. This is nevertheless encouraging in view of scaling up, since distilled models derived from larger models usually perform better than models trained from scratch.<br> | |
Note that this sort of performance plateau needs to be confirmed by carrying out several runs with different configurations (notably different seeds), in order to report results as an interval instead of a single value (for each step evaluated, we use a seed of 42).<br> | |
It should also be mentioned that this plateau for the encoder part has already been observed by other authors. One example is CamemBERT(a) 2.0 <d-cite bibtex-key="antoun2024camembert20smarterfrench"></d-cite>, which was also trained on the French-language part of CulturaX. CamemBERT 2.0 did not perform any better than CamemBERT 1.0, despite having seen more tokens. On the other hand, the authors obtained performance gains with CamemBERTa 2.0, suggesting that for encoders, the most important thing is to focus on the architecture (CamemBERTa 2.0 is a DeBERTaV3 <d-cite bibtex-key="he2023debertav3improvingdebertausing"></d-cite> while CamemBERT 2.0 is a RoBERTa <d-cite bibtex-key="liu2019robertarobustlyoptimizedbert"></d-cite>) rather than on data. This result invites us to think about updating the T5 encoder architecture.<br> | |
A final observation that can be made is that if performance plateaus, it is possible to afford to stop pre-training earlier and thus reduce costs.<br> | |
In the table below, we list cost estimates (in euros) for the pre-training of our model with various cloud providers. | |
For each of them, we use the hourly price of an A100 80GB as offered on December 20, 2024.<br> | |
We show two cases: pre-training on 262 billion tokens (the threshold at which performance on classification tasks begins to plateau and marginal gains become low) and on 419 billion tokens (the maximum number of tokens seen by CamemBERT). | |
<br> | |
</p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Cloud provider</th> | |
<th>Hourly rate for an A100</th> | |
<th>Price for 262B tokens</th> | |
<th>Price for 419B tokens</th> | |
<th>Note</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>AWS</td> | |
<td>1.77</td> | |
<td>1,616</td> | |
<td>2,586</td> | |
<td></td> | |
</tr> | |
<tr> | |
<td>OVH</td> | |
<td>2.75</td> | |
<td>2,475</td> | |
<td>3,960</td> | |
<td>By opting for monthly rather than hourly payment, the price in both cases is €2,200.</td> | |
</tr> | |
<tr> | |
<td>Azure</td> | |
<td>3.31</td> | |
<td>3,021</td> | |
<td>4,833</td> | |
<td>The hourly price was calculated from the monthly price of 8 A100.</td> | |
</tr> | |
<tr> | |
<td>Google Cloud</td> | |
<td>3.52</td> | |
<td>3,214</td> | |
<td>5,143</td> | |
<td></td> | |
</tr> | |
</tbody> | |
</table> | |
<p><br><br></p> | |
<h4>Pre-training times and emissions</h4> | |
<p class="width_125">Carbon emissions were estimated using the <a class="link" href="https://mlco2.github.io/impact#compute">Machine Learning Impact calculator</a> <d-cite bibtex-key="lacoste2019quantifying"></d-cite>.<br> | |
Our model was pre-trained on a single A100 PCIe 80GB, on a private infrastructure. | |
For carbon efficiency, we used the daily numbers given by <a class="link" href="https://app.electricitymaps.com/zone/FR">electricitymaps</a> for France during our pre-training period. | |
The finetunings were carried out on a single A100 PCIe 40GB. | |
As execution time is generally counted in hours or even minutes, for carbon efficiency we refer to the electricitymaps numbers for the hour in question rather than the daily number.<br> | |
We estimate the emissions of our model at 14.084 kg eq. CO2, including 13.5 kg eq. CO2 for pre-training and 0.584 kg eq. CO2 for the 49 finetunings.<br> | |
To this, we must add additional emissions estimated at 6.24 kg eq. CO2. | |
They correspond to the finetuning of models to establish baselines against which to compare (0.475 kg eq. CO2), to our preliminary work in <code>bf16-mixed</code> (4.735 kg eq. CO2 for the pre-training of three different models over 300K steps) and to tests in <code>full bf16</code> prior to the training of our final model (1.03 kg eq. CO2 for the pre-training of a model half the size over 400K steps).<br> | |
In total, we estimate the carbon footprint of our work at 20.324 kg eq. CO2. </p> | |
<p class="width_125">For the pre-training phase (we don't have enough information to make estimates for the other phases), it is then possible to compare us with the other French pre-trained models listed above: </p> | |
<table class="width_125"> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>Time (H)</th> | |
<th>Emissions (kg CO2 eq)</th>
<th>Note</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td>Camembert</td> | |
<td>6,144</td> | |
<td>106.91 ✝</td> | |
<td>24H × 256 Tesla V100-SXM2-32GB at 58g (average over 2019) <br>The authors do not specify the numbers for the large version.</td> | |
</tr> | |
<tr> | |
<td>Flaubert base <d-cite bibtex-key="le2020flaubert"></d-cite></td> | |
<td>13,120</td> | |
<td>190.24 to 228.29 ✝</td> | |
<td>410H × 32 V100 at 58g (average over 2019) <br>The V100 type is not specified<br>(V100-SXM2-32GB ? V100-SXM2-16GB ? V100-PCIE-16GB ?)</td> | |
</tr> | |
<tr> | |
<td>Flaubert large <d-cite bibtex-key="le2020flaubert"></d-cite></td> | |
<td>49,920</td> | |
<td>723.84 to 868.61 ✝</td> | |
<td>390H × 128 V100 at 58g (average over 2019) <br>The V100 type is not specified<br>(V100-SXM2-32GB ? V100-SXM2-16GB ? V100-PCIE-16GB ?)</td> | |
</tr> | |
<tr> | |
<td>Barthez</td> | |
<td>7,680 ★</td> | |
<td>107.52 to 129.02 ✝</td> | |
<td>60H × 128 V100 at 56g (average over 2020) <br>The V100 type is not specified<br>(V100-SXM2-32GB ? V100-SXM2-16GB ? V100-PCIE-16GB ?)</td> | |
</tr> | |
<tr> | |
<td>FAT5-small</td> | |
<td>1,461</td> | |
<td>13.5</td> | |
<td>1,461H × 1 A100 at 36.96g (average between 2024-10-18 and 2024-12-19)</td>
</tr> | |
</tbody> | |
</table> | |
<p class="width_125">✝ the numbers given are estimates based on the information provided by the authors in their publication<br> | |
★ we indicate only the hours for the French pre-training applied on top of the initial English pre-training on which the model is based</p> | |
<p><br></p> | |
<h3 id="mod-les-en-anglais">Models in other languages</h3> | |
<p class="width_125"> | |
Our contribution focuses on French, with the introduction of a new model. For other languages, we can't afford to carry out work of the same magnitude.<br>
Nevertheless, we provide <a class="link" href="https://github.com/catie-aq/flashT5/blob/main/convert_huggingface_t5.py">code</a> for adapting already pre-trained (m)T5/FLAN-T5 weights <d-cite bibtex-key="chung2022scaling"></d-cite> to our method. We hope that users will be able to continue pre-training one of these models, for example to adapt it to more recent data.<br>
Please note, however, that this adaptation is limited, since the additional pre-training will have to be carried out in the precision of the original model. For example, if the model's weights are in fp32 (which is the case for FLAN-T5), training will not be as fast as with FAT5, which is in bf16.<br><br>
For English speakers, we have already adapted the weights of the various FLAN-T5 versions to our method. All weights can be found in this <a class="link" href="https://huggingface.co/collections/CATIE-AQ/catie-english-fat5-flan-662b679a8e855c7c0137d69e">Hugging Face collection</a>.<br><br> | |
If you'd like to pre-train your own model (to specialize in a specific domain, for example, and thus benefit from a customized tokenizer), we refer you once again to the <a class="link" href="https://github.com/catie-aq/flashT5/tree/main/examples/minipile">tutorial</a> showing how to pre-train a model on minipile. Note that we tested and trained the tutorial's model on an A100; it may or may not work on other GPUs.</p>
<p class="width_125"><br><br><br></p> | |
<h2 id="la-suite">Next stage</h2> | |
<p class="width_125">Let's end this article by mentioning what we intend to do, or at least would like to do, as a follow-up to this work.<br></p> | |
<h3>Near future</h3> | |
<p class="width_125">These are things that should already have been in this article, but took more time than expected. | |
For example, we've finished building datasets but haven't yet had time to do the finetunings.<br>
The aim is to complete these tasks in the near future, so that we can include the results in an update to this blog post. | |
</p> | |
<h4>Fix the tokenizer</h4> | |
<p class="width_125"> | |
The current FAT5 is usable. However, due to problems with the tokenizer resulting in inelegant post-processing for certain tasks, we're not excluding the possibility of re-training a model (for 1M steps only) with a new tokenizer allowing simpler use of the model.
<br><br></p> | |
<h4>Instruct model</h4> | |
<p class="width_125">We'd like to test FAT5's text generation capabilities in a more optimal way, in particular through the use of prompts, by developing an instruct model.<br> | |
For this, we have <a class="link" href="https://huggingface.co/datasets/CATIE-AQ/DFP">DFP</a> (<i>Dataset of French Prompts</i>) <d-cite bibtex-key="centre_aquitain_des_technologies_de_l'information_et_electroniques_2023"></d-cite>, a dataset of over 100M rows covering thirty NLP tasks. It follows the methodology of the <a class="link" href="https://huggingface.co/datasets/bigscience/xP3">xP3</a> dataset used for mT0 <d-cite bibtex-key="muennighoff2023crosslingualgeneralizationmultitaskfinetuning"></d-cite>. We could also take this opportunity to check out BigScience's "Finding 2" <d-cite bibtex-key="wang2022languagemodelarchitecturepretraining"></d-cite> (page 9 of the publication) indicating that encoder-decoder models would have better 0-shot capabilities than decoder models. <br> | |
Beyond NLP tasks, we also have over 2M open QA prompt rows, which should enable us to test FAT5 on more general tasks/knowledge.<br><br> | |
The development of this instruct model should also enable us to work on its alignment, in particular via a dataset of 12M rows to perform DPO in French.<br><br></p> | |
<h4>Long sequences</h4> | |
<p class="width_125"> | |
Pre-training is performed on sequences of 1,024 tokens. However, the CUDA kernel we've developed supports positional encodings that greatly extend the context size, as well as linear inference.<br>
With this in mind, we've created two datasets of long sequences in French (one of QA, one of text summaries) on which we'd like to finetune our model.<br><br><br></p> | |
<h3>Distant future</h3> | |
<p class="width_125">The items listed below are longer-term ideas. In other words, they will take time to implement and will be the subject of a new blog post if necessary.</p> | |
<h4 id="memory-reduction">Memory reduction</h4>
<p class="width_125">Although we're already satisfied with the memory optimisations achieved via our CUDA kernel, we think we can take these results further using other techniques. For example, we can cite the CCE (Cut Cross-Entropy) method <d-cite bibtex-key="wijmans2024cut"></d-cite> with which we have already obtained interesting results on decoder models.<br> | |
In addition, while we have concentrated on pre-training, more work needs to be done on inference, which in practice consumes the most resources over time once the model is in production. We are thinking in particular of using SageAttention2 <d-cite bibtex-key="zhang2024sageattention2efficientattentionthorough"></d-cite>, which was released while our model was training.
<br><br></p> | |
<h4 id="calcul-lin-aire">Linear computation</h4> | |
<p class="width_125">In this work, we present a model with linear memory.
A further improvement would be for the model to also have linear computational complexity, in addition to linear memory.<br>
The idea is to replace traditional quadratic attention with another form of attention.<br> | |
We can think of some already applied to the T5, such as that of LongT5 <d-cite bibtex-key="guo2022longt5"></d-cite>. | |
It is also possible to test more recent forms such as Based <d-cite bibtex-key="arora2024simple"></d-cite>. | |
We are also interested in testing with Hedgehog <d-cite bibtex-key="zhang2024hedgehog"></d-cite>. | |
In fact, it is possible to combine them with the optimised kernels available in <a class="link" href="https://github.com/HazyResearch/ThunderKittens/tree/main/kernels">ThunderKittens</a> <d-cite bibtex-key="thunderkittens"></d-cite>. | |
The benefit is that it is then possible to keep the pre-trained model and, via additional finetuning, replace standard softmax attention with Hedgehog linear attention.<br>
LoLCATs <d-cite bibtex-key="zhang2024lolcatslowranklinearizinglarge"></d-cite> performs this finetuning via LoRA <d-cite bibtex-key="hu2021loralowrankadaptationlarge"></d-cite>. | |
<br><br></p> | |
<h4 id="passage-l-chelle">Model size</h4> | |
<p class="width_125"> T5/FLAN-T5 have been trained at sizes of up to 11 billion parameters, demonstrating that this architecture can scale.<br>
We would like to offer larger models with a FAT5-base and a FAT5-large with 305M and 973M parameters respectively, which we would then like to distil. The aim is to offer models that consume as little as possible in routine/inference.<br> | |
We also expect the distilled models to perform better than models of equivalent size trained from scratch.<br> | |
This should also allow us to offer models that will actually be used in practice. Indeed, given the current state of French models, users who care more about performance than about the model's memory footprint would be better off using CamemBERTa 2.0 for classification tasks. The present FAT5 should therefore be seen as a proof of concept, to be scaled up to make it competitive.
<br><br></p> | |
<h4 id="modeles-specialises">Training data</h4> | |
<p class="width_125"> | |
In this work, we used "generic" French data, mainly from CulturaX. During the training of our model, Hugging Face introduced the FineWeb2 dataset <d-cite bibtex-key="penedo2024fineweb-2"></d-cite> which includes French. We would like to pre-train a new model so that we can compare the impact of pre-training data on performance on downstream tasks.<br> | |
Beyond generic French, we particularly want to be able to apply our methodology to specific domains (medicine, regional variants of French, etc.).<br> | |
To do this, we would need to train a new dedicated tokenizer and perform a new pre-training for each of the chosen domains. | |
The advantage of the optimisations implemented and presented in this blog article is that they enable a significant reduction in the cost of pre-training.<br> | |
We would then like to compare these small specialised models with large generic models.<br><br></p>
<h4 id="t5-architecture-update">Update of the T5 architecture</h4>
<p class="width_125">The final direction we would like to explore is an update of the T5 architecture. As encoder-decoders have been neglected, they have not benefited from the improvements that decoder models have received in recent months (more recent activation or normalisation layers, multi-token prediction <d-cite bibtex-key="gloeckle2024betterfasterlarge"></d-cite>, etc.).</p> | |
<p class="width_125"><br><br><br></p> | |
<h2 id="conclusion">Conclusion</h2> | |
<p class="width_125"> | |
We introduced the FAT5 (Flash Attention T5) model, detailing our approach to optimizing various elements of the pre-training and finetuning processes. | |
This is based on kernels that enable Flash Attention to be used with a T5 and give the model a linear memory. | |
In particular, we've applied our work to French as a proof of concept, and made sure that it can also be used in any other language. | |
We hope that our method, which enables a model with 147M parameters to be pre-trained from scratch for €1,600, will be useful for people with limited computational resources. | |
It also opens the way for a possible comeback of encoder-decoder models, rather than only decoder models.<br> | |
<p class="width_125"><br><br></p> | |
<style>
/* Boxed, dark-background styling for `.citation` elements inside <d-appendix>:
   small monospace-friendly text that wraps instead of overflowing the box.
   NOTE(review): the BibTeX <pre class="citation long"> under the "Citation"
   heading is placed before <d-appendix> in this page's markup, so this
   selector may not match it — confirm whether it should also be targeted. */
d-appendix .citation {
  font-size: 11px;
  line-height: 15px;
  border: 1px solid rgba(0, 0, 0, 0.1);
  border-radius: 3px;
  padding: 10px;
  background: #0D1117;
  color: rgba(150, 150, 150, 1);
  overflow: hidden;
  margin-top: -12px;
  /* Keep the BibTeX line breaks but wrap lines longer than the box. */
  white-space: pre-wrap;
  word-wrap: break-word;
}
</style>
<h3 id="citation">Citation</h3> | |
<pre class="citation long">@misc {FAT5,
  title = { FAT5: Flash Attention T5 },
  author = { Boris ALBAR and Loïck BOURDOIS },
  organization = { Centre Aquitain des Technologies de l'Information et Electroniques },
  year = 2025,
  url = { https://huggingface.co/spaces/CATIE-AQ/FAT5-report },
  doi = { 10.57967/hf/4160 },
  publisher = { Hugging Face }
}</pre>
<d-appendix style="color: #9CA3AF;" > | |
<d-bibliography src="bibliography.bib"></d-bibliography> | |
</d-appendix> | |
</d-article> | |
<script>
// Build the table of contents (TOC) inside <d-contents> from the article's
// h2/h3/h4 headings, then keep the entry of the section being read marked
// with the "active" class while the page scrolls.
const article = document.querySelector('d-article');
const toc = document.querySelector('d-contents');
if (article && toc) {
  // Escape text before splicing it into the TOC markup so that a heading
  // containing "<", "&" or quotes cannot break the generated HTML.
  const escapeHtml = (text) => text
    .replaceAll('&', '&amp;')
    .replaceAll('<', '&lt;')
    .replaceAll('>', '&gt;')
    .replaceAll('"', '&quot;')
    .replaceAll("'", '&#39;');

  // Headings actually listed in the TOC, in document order:
  // tocHeadings[i] is the target of the i-th link of the generated nav.
  const tocHeadings = [];
  const usedIds = new Set();
  // NOTE(review): the TOC title is in French while the article is in English — confirm this is intended.
  let html = `<nav role="navigation" class="l-text figcaption" style="color: #9CA3AF;"><h3>Table des matières</h3>`;
  let prevLevel = 0;

  for (const el of article.querySelectorAll('h2, h3, h4')) {
    // Skip headings inside <d-title> and headings carrying a `no-toc`
    // attribute (hasAttribute, so a bare `<h2 no-toc>` is honoured too).
    if (el.parentElement.tagName === 'D-TITLE' || el.hasAttribute('no-toc')) continue;

    // Anchor id derived from the heading text (it replaces any id set in the
    // markup); a numeric suffix keeps ids unique when two headings share a title.
    const baseId = el.textContent.toLowerCase().replaceAll(' ', '_') || 'section';
    let id = baseId;
    for (let n = 2; usedIds.has(id); n++) id = `${baseId}_${n}`;
    usedIds.add(id);
    el.setAttribute('id', id);
    tocHeadings.push(el);

    const link = `<a target="_self" href="#${escapeHtml(id)}">${escapeHtml(el.textContent)}</a>`;
    // h2 -> 0, h3 -> 1, h4 -> 2: how many <ul> levels the entry is nested in.
    const level = el.tagName === 'H2' ? 0 : (el.tagName === 'H3' ? 1 : 2);
    for (; prevLevel < level; prevLevel++) html += '<ul>';
    for (; prevLevel > level; prevLevel--) html += '</ul>';
    html += level === 0 ? `<div>${link}</div>` : `<li>${link}</li>`;
  }
  for (; prevLevel > 0; prevLevel--) html += '</ul>';
  html += '</nav>';

  toc.innerHTML = html;
  // Presumably tells the distill <d-contents> component not to regenerate
  // the TOC — confirm against distill.bundle.js.
  toc.setAttribute('prerendered', 'true');

  const tocLinks = Array.from(toc.querySelectorAll(':scope > nav a'));
  // Make `current` the only highlighted TOC link (null clears every highlight).
  const highlight = (current) => {
    for (const link of tocLinks) link.classList.toggle('active', link === current);
  };

  window.addEventListener('scroll', () => {
    // Walk the listed headings from last to first: the first one whose top edge
    // is at most 50px below the top of the viewport is the current section.
    for (let i = tocHeadings.length - 1; i >= 0; i--) {
      if (tocHeadings[i].getBoundingClientRect().top - 50 <= 0) {
        if (!tocLinks[i].classList.contains('active')) highlight(tocLinks[i]);
        return;
      }
    }
    // Still above the first listed heading: nothing is highlighted.
    highlight(null);
  }, { passive: true });
}
</script>
</body> | |
</html> |