lvwerra HF staff commited on
Commit
8e943ac
·
1 Parent(s): 236446b

kernels + conclusion

Browse files
dist/assets/svg/figure-01.svg ADDED
dist/assets/svg/test-svg.html ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>Interactive SVG Hover Effect</title>
6
+ <style>
7
+ body {
8
+ font-family: Arial, sans-serif;
9
+ margin: 20px;
10
+ background: #f8f9fa;
11
+ }
12
+ .svg-container {
13
+ border: 1px solid #ccc;
14
+ padding: 10px;
15
+ border-radius: 8px;
16
+ background: #fff;
17
+ }
18
+ .info {
19
+ margin-top: 15px;
20
+ font-size: 16px;
21
+ color: #555;
22
+ }
23
+ </style>
24
+ </head>
25
+ <body>
26
+ <div class="svg-container" id="svg-container-01">
27
+ <!-- The enhanced SVG will be injected here -->
28
+ </div>
29
+ <div class="info" id="info">Hover over the network elements to see their details</div>
30
+
31
+ <script>
32
+ // Function to enhance the SVG content by adding styles and data attributes
33
+ function enhanceSVGContent(originalContent) {
34
+ const parser = new DOMParser();
35
+ const doc = parser.parseFromString(originalContent, 'image/svg+xml');
36
+
37
+ // Create a style element with hover effects and insert it as the first child of the SVG
38
+ const styleElement = doc.createElementNS('http://www.w3.org/2000/svg', 'style');
39
+ styleElement.textContent = `
40
+ path[data-element-type="layer"] {
41
+ transition: all 0.3s;
42
+ cursor: pointer;
43
+ }
44
+ path[data-element-type="layer"]:hover {
45
+ fill: #b197fc !important;
46
+ transform: translate(0, -2px);
47
+ }
48
+
49
+ path[data-element-type="gradient"] {
50
+ transition: all 0.3s;
51
+ cursor: pointer;
52
+ }
53
+ path[data-element-type="gradient"]:hover {
54
+ fill: #f06595 !important;
55
+ transform: translate(0, -2px);
56
+ }
57
+
58
+ path[data-element-type="forward"] {
59
+ transition: all 0.3s;
60
+ cursor: pointer;
61
+ }
62
+ path[data-element-type="forward"]:hover {
63
+ stroke: #0c8599 !important;
64
+ stroke-width: 4 !important;
65
+ }
66
+
67
+ path[data-element-type="backward"] {
68
+ transition: all 0.3s;
69
+ cursor: pointer;
70
+ }
71
+ path[data-element-type="backward"]:hover {
72
+ stroke: #e8590c !important;
73
+ stroke-width: 4 !important;
74
+ }
75
+
76
+ path[data-element-type="optimization"] {
77
+ transition: all 0.3s;
78
+ cursor: pointer;
79
+ }
80
+ path[data-element-type="optimization"]:hover {
81
+ stroke: #087f5b !important;
82
+ stroke-width: 4 !important;
83
+ }
84
+ `;
85
+ doc.documentElement.insertBefore(styleElement, doc.documentElement.firstChild);
86
+
87
+ // Process neural network layers (purple nodes)
88
+ doc.querySelectorAll('path[fill="#d0bfff"]').forEach((node, index) => {
89
+ node.setAttribute('data-element-id', `layer-${index}`);
90
+ node.setAttribute('data-element-type', 'layer');
91
+ });
92
+
93
+ // Process gradient nodes (pink nodes)
94
+ doc.querySelectorAll('path[fill="#f783ac"]').forEach((node, index) => {
95
+ node.setAttribute('data-element-id', `gradient-${index}`);
96
+ node.setAttribute('data-element-type', 'gradient');
97
+ });
98
+
99
+ // Process arrows by matching stroke colors
100
+ const arrowTypes = {
101
+ '#15aabf': 'forward',
102
+ '#fd7e14': 'backward',
103
+ '#099268': 'optimization'
104
+ };
105
+
106
+ Object.entries(arrowTypes).forEach(([color, type]) => {
107
+ doc.querySelectorAll(`path[stroke="${color}"]`).forEach((arrow, index) => {
108
+ arrow.setAttribute('data-element-id', `${type}-${index}`);
109
+ arrow.setAttribute('data-element-type', type);
110
+ });
111
+ });
112
+
113
+ // Make the SVG responsive
114
+ doc.documentElement.setAttribute('width', '100%');
115
+ doc.documentElement.setAttribute('height', 'auto');
116
+ doc.documentElement.setAttribute('preserveAspectRatio', 'xMidYMid meet');
117
+
118
+ return new XMLSerializer().serializeToString(doc);
119
+ }
120
+
121
+ // Function to load an SVG file via fetch
122
+ async function loadSVG(url, containerId) {
123
+ try {
124
+ const response = await fetch(url);
125
+ if (!response.ok) {
126
+ throw new Error(`HTTP error! Status: ${response.status}`);
127
+ }
128
+ const svgText = await response.text();
129
+ const enhancedSVG = enhanceSVGContent(svgText);
130
+ document.getElementById(containerId).innerHTML = enhancedSVG;
131
+ } catch (error) {
132
+ console.error('Error loading SVG:', error);
133
+ document.getElementById(containerId).innerHTML = '<p>Error loading SVG.</p>';
134
+ }
135
+ }
136
+
137
+ // Load the SVG file (adjust the path if needed)
138
+ loadSVG('figure-01.svg', 'svg-container-01');
139
+
140
+ // Set up event listeners to display a description of the hovered element
141
+ const svgContainer = document.getElementById('svg-container-01');
142
+ svgContainer.addEventListener('mouseover', function(event) {
143
+ const target = event.target;
144
+ if (target.tagName.toLowerCase() === 'path' && target.hasAttribute('data-element-id')) {
145
+ const elementId = target.getAttribute('data-element-id');
146
+ const elementType = target.getAttribute('data-element-type');
147
+ const descriptions = {
148
+ layer: 'Neural Network Layer',
149
+ gradient: 'Gradient Update Layer',
150
+ forward: 'Forward Pass Connection',
151
+ backward: 'Backward Pass Connection',
152
+ optimization: 'Optimization Step'
153
+ };
154
+ const description = descriptions[elementType] || elementType;
155
+ document.getElementById('info').textContent = `Hovering over: ${description} (${elementId})`;
156
+ }
157
+ });
158
+
159
+ svgContainer.addEventListener('mouseout', function() {
160
+ document.getElementById('info').textContent = 'Hover over the network elements to see their details';
161
+ });
162
+ </script>
163
+ </body>
164
+ </html>
dist/bibliography.bib CHANGED
@@ -466,4 +466,48 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
466
  archivePrefix={arXiv},
467
  primaryClass={cs.CL},
468
  url={https://arxiv.org/abs/2006.16668},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  }
 
466
  archivePrefix={arXiv},
467
  primaryClass={cs.CL},
468
  url={https://arxiv.org/abs/2006.16668},
469
+ }
470
+ @misc{dao2022flashattention,
471
+ title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
472
+ author={Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
473
+ year={2022},
474
+ eprint={2205.14135},
475
+ archivePrefix={arXiv},
476
+ primaryClass={cs.LG},
477
+ url={https://arxiv.org/abs/2205.14135},
478
+ }
479
+ @misc{micikevicius2018mixedprecisiontraining,
480
+ title={Mixed Precision Training},
481
+ author={Paulius Micikevicius and Sharan Narang and Jonah Alben and Gregory Diamos and Erich Elsen and David Garcia and Boris Ginsburg and Michael Houston and Oleksii Kuchaiev and Ganesh Venkatesh and Hao Wu},
482
+ year={2018},
483
+ eprint={1710.03740},
484
+ archivePrefix={arXiv},
485
+ primaryClass={cs.AI},
486
+ url={https://arxiv.org/abs/1710.03740},
487
+ }
488
+ @software{torchao,
489
+ title = {torchao: PyTorch native quantization and sparsity for training and inference},
490
+ author = {torchao maintainers and contributors},
491
+ url = {https://github.com/pytorch/torchao},
492
+ license = {BSD-3-Clause},
493
+ month = oct,
494
+ year = {2024}
495
+ }
496
+ @misc{peng2023fp8lmtrainingfp8large,
497
+ title={FP8-LM: Training FP8 Large Language Models},
498
+ author={Houwen Peng and Kan Wu and Yixuan Wei and Guoshuai Zhao and Yuxiang Yang and Ze Liu and Yifan Xiong and Ziyue Yang and Bolin Ni and Jingcheng Hu and Ruihang Li and Miaosen Zhang and Chen Li and Jia Ning and Ruizhe Wang and Zheng Zhang and Shuguang Liu and Joe Chau and Han Hu and Peng Cheng},
499
+ year={2023},
500
+ eprint={2310.18313},
501
+ archivePrefix={arXiv},
502
+ primaryClass={cs.LG},
503
+ url={https://arxiv.org/abs/2310.18313},
504
+ }
505
+ @misc{wortsman2023smallscaleproxieslargescaletransformer,
506
+ title={Small-scale proxies for large-scale Transformer training instabilities},
507
+ author={Mitchell Wortsman and Peter J. Liu and Lechao Xiao and Katie Everett and Alex Alemi and Ben Adlam and John D. Co-Reyes and Izzeddin Gur and Abhishek Kumar and Roman Novak and Jeffrey Pennington and Jascha Sohl-dickstein and Kelvin Xu and Jaehoon Lee and Justin Gilmer and Simon Kornblith},
508
+ year={2023},
509
+ eprint={2309.14322},
510
+ archivePrefix={arXiv},
511
+ primaryClass={cs.LG},
512
+ url={https://arxiv.org/abs/2309.14322},
513
  }
dist/distill.bundle.js.map CHANGED
The diff for this file is too large to render. See raw diff
 
dist/index.html CHANGED
@@ -51,115 +51,6 @@
51
  <d-article>
52
  <d-contents>
53
  </d-contents>
54
-
55
-
56
- <script>
57
- // Function to enhance the SVG content by adding styles and data attributes
58
- function enhanceSVGContent(originalContent) {
59
- const parser = new DOMParser();
60
- const doc = parser.parseFromString(originalContent, 'image/svg+xml');
61
-
62
- // Create a style element with hover effects and insert it as the first child of the SVG
63
- const styleElement = doc.createElementNS('http://www.w3.org/2000/svg', 'style');
64
- styleElement.textContent = `
65
- path[data-element-type="layer"] {
66
- transition: all 0.3s;
67
- cursor: pointer;
68
- }
69
- path[data-element-type="layer"]:hover {
70
- fill: #b197fc !important;
71
- transform: translate(0, -2px);
72
- }
73
-
74
- path[data-element-type="gradient"] {
75
- transition: all 0.3s;
76
- cursor: pointer;
77
- }
78
- path[data-element-type="gradient"]:hover {
79
- fill: #f06595 !important;
80
- transform: translate(0, -2px);
81
- }
82
-
83
- path[data-element-type="forward"] {
84
- transition: all 0.3s;
85
- cursor: pointer;
86
- }
87
- path[data-element-type="forward"]:hover {
88
- stroke: #0c8599 !important;
89
- stroke-width: 4 !important;
90
- }
91
-
92
- path[data-element-type="backward"] {
93
- transition: all 0.3s;
94
- cursor: pointer;
95
- }
96
- path[data-element-type="backward"]:hover {
97
- stroke: #e8590c !important;
98
- stroke-width: 4 !important;
99
- }
100
-
101
- path[data-element-type="optimization"] {
102
- transition: all 0.3s;
103
- cursor: pointer;
104
- }
105
- path[data-element-type="optimization"]:hover {
106
- stroke: #087f5b !important;
107
- stroke-width: 4 !important;
108
- }
109
- `;
110
- doc.documentElement.insertBefore(styleElement, doc.documentElement.firstChild);
111
-
112
- // Process neural network layers (purple nodes)
113
- doc.querySelectorAll('path[fill="#d0bfff"]').forEach((node, index) => {
114
- node.setAttribute('data-element-id', `layer-${index}`);
115
- node.setAttribute('data-element-type', 'layer');
116
- });
117
-
118
- // Process gradient nodes (pink nodes)
119
- doc.querySelectorAll('path[fill="#f783ac"]').forEach((node, index) => {
120
- node.setAttribute('data-element-id', `gradient-${index}`);
121
- node.setAttribute('data-element-type', 'gradient');
122
- });
123
-
124
- // Process arrows by matching stroke colors
125
- const arrowTypes = {
126
- '#15aabf': 'forward',
127
- '#fd7e14': 'backward',
128
- '#099268': 'optimization'
129
- };
130
-
131
- Object.entries(arrowTypes).forEach(([color, type]) => {
132
- doc.querySelectorAll(`path[stroke="${color}"]`).forEach((arrow, index) => {
133
- arrow.setAttribute('data-element-id', `${type}-${index}`);
134
- arrow.setAttribute('data-element-type', type);
135
- });
136
- });
137
-
138
- // Make the SVG responsive
139
- doc.documentElement.setAttribute('width', '100%');
140
- doc.documentElement.setAttribute('height', 'auto');
141
- doc.documentElement.setAttribute('preserveAspectRatio', 'xMidYMid meet');
142
-
143
- return new XMLSerializer().serializeToString(doc);
144
- }
145
-
146
- // Function to load an SVG file via fetch
147
- async function loadSVG(url) {
148
- try {
149
- const response = await fetch(url);
150
- if (!response.ok) {
151
- throw new Error(`HTTP error! Status: ${response.status}`);
152
- }
153
- const svgText = await response.text();
154
- const enhancedSVG = enhanceSVGContent(svgText);
155
- document.getElementById('svg-container').innerHTML = enhancedSVG;
156
- } catch (error) {
157
- console.error('Error loading SVG:', error);
158
- document.getElementById('svg-container').innerHTML = '<p>Error loading SVG.</p>';
159
- }
160
- }
161
- </script>
162
-
163
 
164
  <p>Fueled by the scaling laws<d-cite bibtex-key="kaplan2020scalinglaws"></d-cite><d-cite bibtex-key="hoffmann2022chinchilla"></d-cite>, the trend of training ever larger language models on vaster amounts of data has been driving progress in AI for the past couple years. Initially, the development of the largest models happened exclusively behind closed doors of a handful of research labs but recently opened up more with the release of models such as Llama 3.1 405B<d-cite bibtex-key="grattafiori2024llama3herdmodels"></d-cite> and DeepSeek R1<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. While these models have <a href="https://huggingface.co/meta-llama">openly shared</a> <a href="https://huggingface.co/deepseek-ai">weights</a> and their training recipes are described in <a href="https://ai.meta.com/research/publications/the-llama-3-herd-of-models/">technical</a> <a href="https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf">reports</a>, the challenging engineering to involved to train at the necessary infrastructure scale is still hidden between the lines of a handful of papers and complex training frameworks. This ~~long blog post~~ open-source book is here to open this black box!</p>
165
 
@@ -332,37 +223,7 @@
332
  </ol>
333
 
334
  <p>It looks generally like this: </p>
335
- <div class="svg-container" id="svg-container">
336
- </div>
337
- <div class="info" id="info">Hover over the network elements to see their details</div>
338
- <script>
339
- // Load the SVG file (adjust the path if needed)
340
- loadSVG('../assets/svg/figure-01.svg');
341
-
342
- // Set up event listeners to display a description of the hovered element
343
- const svgContainer = document.getElementById('svg-container');
344
- svgContainer.addEventListener('mouseover', function(event) {
345
- const target = event.target;
346
- if (target.tagName.toLowerCase() === 'path' && target.hasAttribute('data-element-id')) {
347
- const elementId = target.getAttribute('data-element-id');
348
- const elementType = target.getAttribute('data-element-type');
349
- const descriptions = {
350
- layer: 'Neural Network Layer',
351
- gradient: 'Gradient Update Layer',
352
- forward: 'Forward Pass Connection',
353
- backward: 'Backward Pass Connection',
354
- optimization: 'Optimization Step'
355
- };
356
- const description = descriptions[elementType] || elementType;
357
- document.getElementById('info').textContent = `Hovering over: ${description} (${elementId})`;
358
- }
359
- });
360
-
361
- svgContainer.addEventListener('mouseout', function() {
362
- document.getElementById('info').textContent = 'Hover over the network elements to see their details';
363
- });
364
-
365
- </script>
366
 
367
  <aside>As we’ll see later, these steps may be repeated or intertwined but for now we’ll start simple.</aside>
368
 
@@ -540,7 +401,7 @@
540
 
541
  <p>Is there a way to tame this “activation explosion”? Good question, reader!</p>
542
 
543
- <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em><em>–</em> </strong>**which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.</p>
544
 
545
  <h3>Activation recomputation</h3>
546
 
@@ -704,7 +565,7 @@
704
 
705
  <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
706
 
707
- <aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) **which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
708
  </aside>
709
 
710
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
@@ -841,9 +702,9 @@
841
 
842
 
843
 
844
- <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same ***reduce-scatter*** as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
845
 
846
- <p>Thankfully, although we added many more communication operations, **prefetching** helps us overlap them efficiently by all-gathering weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
847
 
848
  <p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
849
 
@@ -1619,46 +1480,611 @@
1619
 
1620
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1621
 
1622
- <p></p>
 
1623
 
1624
- <p></p>
1625
 
1626
- <p></p>
 
1627
 
1628
- <p></p>
1629
 
1630
- <p></p>
1631
 
 
 
 
 
1632
  <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
 
 
 
 
 
 
 
 
 
1633
 
 
1634
 
1635
  <h3>How to improve performance with Kernels ?</h3>
1636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1637
  <h4>Memory Coalescing</h4>
1638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1639
  <h4>Tiling</h4>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1640
 
 
 
 
 
1641
  <h4>Thread Coarsening</h4>
1642
 
 
 
 
 
 
 
 
 
 
 
 
 
1643
  <h4>Minimizing Control Divergence</h4>
1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1645
  <h3>Flash Attention 1-3</h3>
1646
 
1647
- <h3>Fused Kernels</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1648
 
1649
  <h3>Mixed Precision Training</h3>
1650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1651
  <h4>FP16 and BF16 training</h4>
 
 
 
 
 
 
 
 
 
 
 
 
1652
 
1653
  <h4>FP8 pretraining</h4>
1654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1655
  <h2>Conclusion</h2>
1656
 
 
 
 
 
 
 
 
 
 
 
 
1657
  <h3>What you learned</h3>
1658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1659
  <h3>What we learned</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1661
  <h3>What’s next?</h3>
 
 
 
 
 
 
 
 
 
 
1662
 
1663
  <h2>References</h2>
1664
 
@@ -1712,8 +2138,7 @@
1712
  }</pre>
1713
  </d-appendix>
1714
 
1715
-
1716
- <script>
1717
  const article = document.querySelector('d-article');
1718
  const toc = document.querySelector('d-contents');
1719
  if (toc) {
 
51
  <d-article>
52
  <d-contents>
53
  </d-contents>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  <p>Fueled by the scaling laws<d-cite bibtex-key="kaplan2020scalinglaws"></d-cite><d-cite bibtex-key="hoffmann2022chinchilla"></d-cite>, the trend of training ever larger language models on vaster amounts of data has been driving progress in AI for the past couple years. Initially, the development of the largest models happened exclusively behind closed doors of a handful of research labs but recently opened up more with the release of models such as Llama 3.1 405B<d-cite bibtex-key="grattafiori2024llama3herdmodels"></d-cite> and DeepSeek R1<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. While these models have <a href="https://huggingface.co/meta-llama">openly shared</a> <a href="https://huggingface.co/deepseek-ai">weights</a> and their training recipes are described in <a href="https://ai.meta.com/research/publications/the-llama-3-herd-of-models/">technical</a> <a href="https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf">reports</a>, the challenging engineering to involved to train at the necessary infrastructure scale is still hidden between the lines of a handful of papers and complex training frameworks. This ~~long blog post~~ open-source book is here to open this black box!</p>
56
 
 
223
  </ol>
224
 
225
  <p>It looks generally like this: </p>
226
+ <p><img alt="image.png" src="assets/images/placeholder.png" /></p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  <aside>As we’ll see later, these steps may be repeated or intertwined but for now we’ll start simple.</aside>
229
 
 
401
 
402
  <p>Is there a way to tame this “activation explosion”? Good question, reader!</p>
403
 
404
+ <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em><em>–</em> </strong>which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.</p>
405
 
406
  <h3>Activation recomputation</h3>
407
 
 
565
 
566
  <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
567
 
568
+ <aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
569
  </aside>
570
 
571
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
 
702
 
703
 
704
 
705
+ <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
706
 
707
+ <p>Thankfully, although we added many more communication operations, <strong>prefetching</strong> helps us overlap them efficiently by all-gathering weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
708
 
709
  <p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
710
 
 
1480
 
1481
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1482
 
1483
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1484
+ <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1485
 
1486
+ <p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
1487
 
1488
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1489
+ <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1490
 
1491
+ <p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
1492
 
1493
+ <p>A piece of code running on a core of the GPU is called a <strong>kernel</strong>. It can be written at a high-level in <strong>CUDA</strong> or <strong>Triton</strong> for instance, and is then compiled to Parallel Thread Execution, PTX, the low-level assembly used by NVIDIA GPUs.</p>
1494
 
1495
+ <p>To run the kernel, you will also need a specific code part, called <strong>host code</strong>, which is executed on the <strong>CPU/host</strong> and will take care of preparing data allocations and loading data and code.</p>
1496
+
1497
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1498
+ <p>Figure 5: Host code for a CUDA kernel for adding two vectors from https://blog.codingconfessions.com/p/gpu-computing</p>
1499
  <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1500
+ <p>Figure 6: Device code containing the definition of the vector addition kernel from https://blog.codingconfessions.com/p/gpu-computing</p>
1501
+
1502
+ <p>Kernels are generally scheduled as follow:</p>
1503
+
1504
+ <ul>
1505
+ <li>threads are grouped in <strong>warps</strong> of sizes of 32. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
1506
+ <li><strong>warps</strong> are grouped in larger <strong>blocks</strong> of more flexible size (e.g. size 256), each block still being assigned to a single SM. An SM may run several blocks in parallel, however, depending on the resources, not all the blocks may get assigned for execution immediately, some can be waitlisted waiting for resources.</li>
1507
+ </ul>
1508
+
1509
+ <p>The main thing to remember from these details is that there are various sizing and allocation constraints (size of the various memories, number of concurrent block and threads in the wraps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
1510
 
1511
+ <p>Most of the time you don’t need to go down to this level of precision and you can luckily reuse the kernels and code prepared by other members of the community. But in any case we want to give you a primer on how to get started with kernels! </p>
1512
 
1513
  <h3>How to improve performance with Kernels ?</h3>
1514
 
1515
+
1516
+ <p>If you’re looking to add a new operation that lacks an optimized kernel or to speed up an existing PyTorch function, writing kernels from scratch might seem like the most direct route. However, creating high-performance CUDA kernels from scratch requires extensive experience and a steep learning curve. Generally a better way to get started is to leverage <code>torch.compile</code>, which dynamically optimizes PyTorch code by capturing your operations and generating lower-level, high-performance kernels in triton.</p>
1517
+
1518
+ <p>Let’s suppose you want to write a kernel for an activation function called Exponential Linear Unit:</p>
1519
+
1520
+ <d-math block>
1521
+ \text{ELU}(x) = \begin{cases}
1522
+ e^x - 1 & \text{if } x < 0 \\
1523
+ x & \text{if } x \geq 0
1524
+ \end{cases}
1525
+ </d-math>
1526
+ <p>TODO: something off with spacing but seems the rendering engine</p>
1527
+
1528
+ <p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
1529
+
1530
+ <d-code block language="python">
1531
+ @torch.compile
1532
+ def elu(x, alpha=1.0):
1533
+ return torch.where(x < 0, alpha * (torch.exp(x) - 1), x)
1534
+ </d-code>
1535
+
1536
+ <p>The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):</p>
1537
+
1538
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1539
+
1540
+
1541
+ <p>However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by @torch.compile . To do so, you simply need to set the environment variable <code>TORCH_LOGS</code> to <code>"output_code"</code>:</p>
1542
+
1543
+ <d-code block language="bash">
1544
+ export TORCH_LOGS="output_code"
1545
+ </d-code>
1546
+
1547
+ <p>Once you run the Python script with the <code>@torch.compile</code> decorator, it will generate and output the corresponding Triton kernel, which, in this case, is:</p>
1548
+
1549
+ <d-code block language="python">
1550
+ @triton.jit
1551
+ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
1552
+ xnumel = 100000000
1553
+ xoffset = tl.program_id(0) * XBLOCK
1554
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
1555
+ xmask = xindex < xnumel
1556
+ x0 = xindex
1557
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
1558
+ tmp1 = 0.0
1559
+ tmp2 = tmp0 < tmp1
1560
+ tmp3 = tl_math.exp(tmp0)
1561
+ tmp4 = 1.0
1562
+ tmp5 = tmp3 - tmp4
1563
+ tmp6 = tl.where(tmp2, tmp5, tmp0)
1564
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
1565
+ </d-code>
1566
+
1567
+ <p>To enhance readability, we can modify the variable names, add comments, and make slight adjustments, as demonstrated below:</p>
1568
+
1569
+ <d-code block language="python">
1570
+ @triton.jit
1571
+ def elu_kernel(input_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
1572
+ # Calculate the starting index for this block
1573
+ block_start = tl.program_id(0) * BLOCK_SIZE
1574
+ # Create an array of indices for this block
1575
+ block_indices = block_start + tl.arange(0, BLOCK_SIZE)[:]
1576
+ # Create a mask to ensure only valid indices are processed
1577
+ valid_mask = block_indices < num_elements
1578
+ # Load input values from the input pointer based on valid indices
1579
+ input_values = tl.load(input_ptr + block_indices, valid_mask)
1580
+ # Define the ELU parameters
1581
+ zero_value = 0.0 # Threshold for ELU activation
1582
+ negative_mask = input_values < zero_value
1583
+ exp_values = tl.math.exp(input_values)
1584
+ # Define the ELU output shift
1585
+ one_value = 1.0
1586
+ shifted_exp_values = exp_values - one_value
1587
+
1588
+ output_values = tl.where(negative_mask, shifted_exp_values, input_values)
1589
+
1590
+ # Store the computed output values back to the output pointer
1591
+ tl.store(output_ptr + block_indices, output_values, valid_mask)
1592
+ </d-code>
1593
+
1594
+ <p>Here, <code>tl.program_id(0)</code> provides a unique block ID, that we use to determine which section of data that block will process. Using this block ID, <code>block_start</code> calculates the starting index for each block’s section, while <code>block_indices</code> specifies the range of indices within that section. A <code>valid_mask</code> ensures that only indices within <code>num_elements</code> are processed, safely loading the data with <code>tl.load</code>. The ELU function is then applied, modifying values based on whether they're negative, and results are written back to memory with <code>tl.store</code>.</p>
1595
+
1596
+ <p>When we benchmark the generated kernel using <code>triton.testing.Benchmark</code> we have the following performance:</p>
1597
+
1598
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1599
+
1600
+ <p>This standalone kernel demonstrates superior performance with smaller sizes compared to <code>@torch.compile</code> but this is likely here just an artifact from the compilation time of <code>torch.compile</code>. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process. </p>
1601
+
1602
+ <p>However, in Triton, sometimes, we cannot fully achieve the peak performance of the device due to limitations in handling shared memory and scheduling within streaming multiprocessors (SMs). Our access is restricted to blocks, allowing us only to manage the scheduling of blocks across SMs. To gain even more control, we will need to implement kernels in CUDA, where we have access to all the underlying components.</p>
1603
+
1604
+ <p>In CUDA, there are various techniques that can be employed to make kernels more efficient; we will present just a few. These include optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times. In summary, the tools for writing code to execute instructions on the GPU are:</p>
1605
+
1606
+ <ul>
1607
+ <li>Pytorch: easy but slow</li>
1608
+ <li>torch.compile: easy, fast, but not flexible</li>
1609
+ <li>triton: harder, faster, and more flexible</li>
1610
+ <li>CUDA: hardest, fastest, and flexiblest (if you get it right)</li>
1611
+
1612
+ </ul>
1613
+
1614
+ <p>Let’s talk about one of the most frequent technique we can use: optimizing memory access. The global memory in GPUs (the largest memory in our above graph) has a long latency and low bandwidth in comparison to the cache which often creates a major bottleneck for most applications. Efficiently accessing data from global memory can improve a lot the performance.</p>
1615
+
1616
  <h4>Memory Coalescing</h4>
1617
 
1618
+ <p>To effectively utilize the bandwidth of global memory, it is essential to understand its architecture. In CUDA devices, global memory is implemented using DRAM.</p>
1619
+
1620
+ <p>Memory coalescing takes advantage of how DRAM delivers data in bursts, or ranges of consecutive memory locations, whenever a memory address is accessed. Each time a DRAM location is accessed, a sequence of consecutive locations, including the requested one, is read in parallel by multiple sensors in the DRAM chip. Once read, this data can then be quickly transferred to the processor as a burst. In CUDA, coalescing uses this burst behavior to maximize memory access efficiency by ensuring that threads in a warp—32 threads that execute the same instruction in lockstep (SIMD)—access consecutive memory locations. For instance, if thread 0 accesses location M, thread 1 accesses M + 1, thread 2 accesses M + 2, and so forth, the GPU hardware coalesces or combines these requests into one large, efficient access request for the DRAM burst, rather than handling each access individually. </p>
1621
+
1622
+ <p>Let’s take the example of matrix multiplication. A simple, straightforward implementation would have each thread compute a single element of the output matrix, like this:</p>
1623
+
1624
+ <d-code block language="clike">
1625
+ __global__ void matmul_naive(int M, int N, int K, const float *A, const float *B, float *C) {
1626
+ const uint x = blockIdx.x * blockDim.x + threadIdx.x;
1627
+ const uint y = blockIdx.y * blockDim.y + threadIdx.y;
1628
+
1629
+ if (x < M && y < N) {
1630
+ float tmp = 0.0;
1631
+ for (int i = 0; i < K; ++i) {
1632
+ tmp += A[x * K + i] * B[i * N + y];
1633
+ }
1634
+ C[x * N + y] = tmp;
1635
+ }
1636
+ }
1637
+ </d-code>
1638
+
1639
+ <p>Here’s an excellent visualization of the kernel from this <a href="https://siboehm.com/articles/22/CUDA-MMM">fantastic blogpost</a>: </p>
1640
+
1641
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1642
+
1643
+ <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
1644
+
1645
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1646
+
1647
+
1648
+ <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load <d-math>A_{0,0}</d-math>, and thread <code>(1, 0)</code> will load <d-math>A_{1,0}</d-math>. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
1649
+
1650
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1651
+
1652
+
1653
+ <p>To improve our kernel we can change the way the coordinates x and y are calculated like the following : </p>
1654
+
1655
+ <d-code block language="clike">
1656
+ const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
1657
+ const int y = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);
1658
+
1659
+ if (x < M && y < N) {
1660
+ float tmp = 0.0;
1661
+ for (int i = 0; i < K; ++i) {
1662
+ tmp += A[x * K + i] * B[i * N + y];
1663
+ }
1664
+ C[x * N + y] = tmp;
1665
+ }
1666
+ </d-code>
1667
+
1668
+ <p>Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of <code>x</code> and <code>y</code>. In this new method, threads within the same warp (which have close <code>threadIdx.x</code> values) will share the same <code>x</code> value but have different <code>y</code> values. This means that they will load the same row of matrix <code>A</code> but different columns of matrix <code>B</code>. As a result, memory accesses can be coalesced for a row-major matrix.</p>
1669
+
1670
+ <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
1671
+
1672
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1673
+
1674
+
1675
+ <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong> !</p>
1676
+ <p>Let’s cover another technique you will often see mentioned in the litterature: tiling.</p>
1677
+
1678
+
1679
  <h4>Tiling</h4>
1680
+
1681
+
1682
+ <p>Tiling is a technique that leverages <em>shared memory</em> to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.</p>
1683
+
1684
+ <p>In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.</p>
1685
+
1686
+ <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.</p>
1687
+
1688
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1689
+ <p>From https://cnugteren.github.io/tutorial/pages/page4.html</p>
1690
+
1691
+ <p>The important parts to understand the implementation are below (for simplicity we consider a square shaped tile) : </p>
1692
+
1693
+ <d-code block language="clike">
1694
+ // Set pointers to the starting elements
1695
+ A += blockRow * TILE_SIZE * K; // Start at row = blockRow, column = 0
1696
+ B += blockCol * TILE_SIZE; // Start at row = 0, column = blockCol
1697
+ C += blockRow * TILE_SIZE * N + blockCol * TILE_SIZE; // Start at row = blockRow, column = blockCol
1698
+ float sum = 0.0;
1699
+ // The outer loop moves through tiles of A (across columns) and B (down rows)
1700
+ for (int tileIdx = 0; tileIdx < K; tileIdx += TILE_SIZE) {
1701
+ sharedA[localRow * TILE_SIZE + localCol] = A[localRow * K + localCol];
1702
+ sharedB[localRow * TILE_SIZE + localCol] = B[localRow * N + localCol];
1703
+
1704
+ // Ensure all threads in the block have completed data loading
1705
+ __syncthreads();
1706
+
1707
+ // Shift pointers to the next tile
1708
+ A += TILE_SIZE;
1709
+ B += TILE_SIZE * N;
1710
+
1711
+ // Compute the partial dot product for this tile
1712
+ for (int i = 0; i < TILE_SIZE; ++i) {
1713
+ sum += sharedA[localRow * TILE_SIZE + i] * sharedB[i * TILE_SIZE + localCol];
1714
+ }
1715
+ // Synchronize again to prevent any thread from loading new data
1716
+ // into shared memory before others have completed their calculations
1717
+ __syncthreads();
1718
+ }
1719
+ C[localRow * N + localCol] = sum;
1720
+ </d-code>
1721
+
1722
+ <p>Each thread begins by loading one element from both <strong>Matrix A</strong> and <strong>Matrix B</strong> into shared memory. In this scenario, achieving coalesced memory access is straightforward, by assigning <code>threadIdx.x</code> as the <strong>local column index (localCol)</strong>, threads within the same warp will access adjacent elements of both matrices. After each thread in the block completes loading its elements into shared memory (ensured by calling <code>__syncthreads()</code>), they proceed to compute the dot product of the two tiles. Once the threads have iterated through all the tiles—horizontally for <strong>Matrix A</strong> and vertically for <strong>Matrix B</strong>—the resulting sum is stored in the corresponding location of <strong>Matrix C</strong>.</p>
1723
 
1724
+ <p>When benchmarking this kernel using ncu, we noticed that the memory throughput increased to 410 Gb / s, and the kernel execution time decreased by ~43% achieving a ~6.6 TFLOPs performance</p>
1725
+
1726
+
1727
+
1728
  <h4>Thread Coarsening</h4>
1729
 
1730
+
1731
+ <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
1732
+
1733
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1734
+
1735
+
1736
+ <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that:</p>
1737
+
1738
+ <p><em><code>smsp__pcsamp_warps_issue_stalled_mio_throttle</code>: Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure.</em></p>
1739
+
1740
+ <p>So it seems warps are stalling waiting for shared memory accesses to return ! To resolve this issue we can apply the <strong>Thread Coarsening</strong> technique by merging several threads into a single coarsened thread, we can significantly reduce shared memory accesses because each coarsened thread can handle multiple output elements which would increase the arithmetic intensity of the kernel.</p>
1741
+
1742
  <h4>Minimizing Control Divergence</h4>
1743
 
1744
+ <p>A Streaming Multiprocessor (SM) is built to execute all threads in a warp using the Single Instruction, Multiple Data (SIMD) model. This means that at any given moment, one instruction is fetched and executed simultaneously for all threads within the warp. When a warp is executed, the threads within it operate on different segments of the data but follow the same instruction, hence the name Single Instruction, Multiple Data. The primary advantage of SIMD is its efficiency; the control hardware responsible for instruction fetching and dispatching is shared among multiple execution units. This design minimizes the hardware overhead associated with control functions, allowing a greater portion of the hardware to focus on improving arithmetic throughput.</p>
1745
+
1746
+ <p>Control divergence occurs when threads within the same warp take different execution paths. For instance, if a conditional statement (like an <code>if</code> statement) leads to some threads executing one block of code while others execute a different block, the warp must serialize these executions, resulting in idle threads waiting for others to complete. To minimize control divergence, we need to design kernels to ensure that threads within the same warp follow the same execution path. This can be achieved by restructuring code to reduce branching, using data structures that ensure all threads follow similar execution paths, or employing techniques such as predication.</p>
1747
+
1748
+ <p>We have covered some of the main considerations when writing custom kernels and improving the performance and memory footprint of GPU operations. But there’s one more important concept before moving to a real example which is “fusing kernels”.</p>
1749
+
1750
+ <h3>Fused Kernels</h3>
1751
+
1752
+ <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
1753
+
1754
+ <p>Non-blocking can be useful for overlapping communication and computation as we saw at several part along this blog post but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
1755
+
1756
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1757
+ <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
1758
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1759
+ <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
1760
+
1761
+ <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
1762
+
1763
+
1764
+ <p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the succession of computation has been performed.</p>
1765
+
1766
+ <p>What are many places in a Transformer model were this can be advantageous, for instance when. a succession of point-wise operations is performed, e.g. in the computation involved in the Layer norms.</p>
1767
+
1768
+ <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
1769
+
1770
  <h3>Flash Attention 1-3</h3>
1771
 
1772
+ <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computations by writing custom CUDA kernels to make it much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid using too much the slowest global memory of the GPU (confusingly called the High Bandwidth Memory, HBM 🫠)</p>
1773
+
1774
+ <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1775
+
1776
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1777
+
1778
+ <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1779
+
1780
+ <p>The key element is to compute the S matrices in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix all together in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not even do we make use of the shared memory but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.</p>
1781
+
1782
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1783
+ <p>From the FLASH-ATTENTION paper<d-cite bibtex-key="dao2022flashattention"></d-cite></p>
1784
+
1785
+ <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
1786
+ <ul>
1787
+ <li>By avoiding to materialize the S matrix we <strong>reduce the memory burden of attention</strong></li>
1788
+ <li>We also remove a large part of the <strong>naive impact of the S^2 cost of attention</strong></li>
1789
+ </ul>
1790
+
1791
+ <p>As a result as well, all variants of linear attention and sub-quadratic approaches to approximate attention –developed shortly after the invention of the transformers architecture– have been mostly put aside in favor of this exact and fast flash attention implementation and mechanism.</p>
1792
+
1793
+ <p>Following Flash-attention 1, two successive improved versions have been released by the same lab: Flash-attention 2 and 3. In comparison to Flash-attention 1, the improvements in Flash-attention 2 and 3 are less about the general attention mechanism than about tailoring its low level implementation more specifically to the GPU by (1) reducing the number of non-matmul operations as much as possible (2) partitioning carefully the workload among wraps and thread blocks (for Flash Attention 2) and carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture for Flash Attention 3.</p>
1794
+
1795
+ <aside>Flash attention puts some restrictions on which attention patterns can be sped up. Check out <a href="https://pytorch.org/blog/flexattention/">FlexAttention</a> which is a fast <em>and</em> flexible variant.</aside>
1796
+
1797
+ <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
1798
+
1799
+ <p>The techniques described so far in this section require specific modeling code changes and writing custom kernels for certain operations in order to speed up training. In this section we take a look at a range of methods that are agnostic to the modeling code and can be used for any model!</p>
1800
 
1801
  <h3>Mixed Precision Training</h3>
1802
 
1803
+ <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
1804
+
1805
+ <ul>
1806
+ <li>Sign: the first bit determines if the number is positive or negative</li>
1807
+ <li>Mantissa: determines the significant figures of a number</li>
1808
+ <li>Exponent: controls the magnitude of the number</li>
1809
+ </ul>
1810
+
1811
+ <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
1812
+
1813
+ <p></p>
1814
+
1815
+ <table>
1816
+ <thead>
1817
+ <tr>
1818
+ <th><strong>Format</strong></th>
1819
+ <th><strong>Total bits</strong></th>
1820
+ <th><strong>Sign</strong></th>
1821
+ <th><strong>Mantissa</strong></th>
1822
+ <th><strong>Exponent</strong></th>
1823
+ </tr>
1824
+ </thead>
1825
+ <tbody>
1826
+ <tr>
1827
+ <td>float32</td>
1828
+ <td>32</td>
1829
+ <td>1</td>
1830
+ <td>23</td>
1831
+ <td>8</td>
1832
+ </tr>
1833
+ <tr>
1834
+ <td>float16</td>
1835
+ <td>16</td>
1836
+ <td>1</td>
1837
+ <td>10</td>
1838
+ <td>5</td>
1839
+ </tr>
1840
+ <tr>
1841
+ <td>bfloat16</td>
1842
+ <td>16</td>
1843
+ <td>1</td>
1844
+ <td>7</td>
1845
+ <td>8</td>
1846
+ </tr>
1847
+ <tr>
1848
+ <td>float8 (e4m3)</td>
1849
+ <td>8</td>
1850
+ <td>1</td>
1851
+ <td>3</td>
1852
+ <td>4</td>
1853
+ </tr>
1854
+ <tr>
1855
+ <td>float8 (e5m2)</td>
1856
+ <td>8</td>
1857
+ <td>1</td>
1858
+ <td>2</td>
1859
+ <td>5</td>
1860
+ </tr>
1861
+ </tbody>
1862
+ </table>
1863
+
1864
+ <aside>Note: You might be wondering where the “b” in bfloat16 comes from. The format was developed at Google Brain and thus the “b” stands for “brain”. </aside>
1865
+
1866
+ <p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. Either we can sacrifice more bits on the mantissa or exponent. For this reason there exist also two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:</p>
1867
+
1868
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1869
+
1870
+
1871
+ <p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further where e5e2 can maintain float16 range and e4m3 has an even smaller ranger.</p>
1872
+
1873
+ <p>How come some format are able to maintain the range and other not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
1874
+
1875
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1876
+
1877
+ <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
1878
+
1879
+ <p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
1880
+
1881
+ <p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision.</p>
1882
+
1883
+ <p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
1884
+
1885
+ <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
1886
+
1887
+
1888
+
1889
  <h4>FP16 and BF16 training</h4>
1890
+
1891
+ <p>Naively switching all the tensors and operations to float16 unfortunately doesn’t work and the result is usually diverging losses. However, the original mixed precision training paper<d-cite bitex-key="micikevicius2018mixedprecisiontraining"></d-cite> came up with three tricks to match float32 trainings:</p>
1892
+
1893
+ <ol>
1894
+ <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
1895
+ <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
1896
+ <li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in float16 such as in dot products, we can also face under or overflows. Does targeting certain types of arithmetic operations to accumulate the intermediate results in float32 during the operation and then casting the accumulated result back to fp16. For the same reason gradients are also accumulated in float32.</li>
1897
+ </ol>
1898
+
1899
+ <p>With these techniques, you get consistently stable training while benefitting from higher throughput due to the faster, lower precision operations. Naturally, as the curious reader you are and by now slightly addicted to maximizing the throughput, you ask the question: can we go further and faster? </p>
1900
+
1901
+ <p>Maybe!</p>
1902
 
1903
  <h4>FP8 pretraining</h4>
1904
 
1905
+ <p>Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.</p>
1906
+
1907
+ <p>Recent research - including FP8-LM<d-cite bibtex-key="peng2023fp8lmtrainingfp8large"></d-cite>, torchao<d-cite bibtex-key="torchao"></d-cite>, and DeepSeek-V3<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite> - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.</p>
1908
+
1909
+ <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
1910
+
1911
+ <p>The first, successful, very large scale training with FP8 mixed precision was publicly reported on DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward pass. Similar to BF16 mixed precision training, some aggregation and master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
1912
+
1913
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1914
+
1915
+ <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme, where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There is a number of additional tricks they deploy to also reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
1916
+
1917
+ <p>Here’s a summary of a few known approaches to FP8 training:</p>
1918
+
1919
+ <table>
1920
+ <thead>
1921
+ <tr>
1922
+ <th></th>
1923
+ <th>GEMM's precision</th>
1924
+ <th>Master model weights</th>
1925
+ <th>Accumulated gradients</th>
1926
+ <th>Model weights</th>
1927
+ <th>Gradients</th>
1928
+ <th>Optimizer States</th>
1929
+ <th>Total Memory</th>
1930
+ </tr>
1931
+ </thead>
1932
+ <tbody>
1933
+ <tr>
1934
+ <td>bfloat16 with fp32 mixed precision baseline</td>
1935
+ <td>bf16</td>
1936
+ <td>fp32</td>
1937
+ <td>fp32</td>
1938
+ <td>bf16</td>
1939
+ <td>bf16</td>
1940
+ <td>fp32 + fp32</td>
1941
+ <td>4 + 4 + 2 + 2 + 4 + 4 = 20 bytes</td>
1942
+ </tr>
1943
+ <tr>
1944
+ <td>Above without FP32 grad accumulation</td>
1945
+ <td>bf16</td>
1946
+ <td>fp32</td>
1947
+ <td></td>
1948
+ <td>bf16</td>
1949
+ <td>bf16</td>
1950
+ <td>fp32 + fp32</td>
1951
+ <td>4 + 2 + 2 + 4 + 4 = 16 bytes</td>
1952
+ </tr>
1953
+ <tr>
1954
+ <td>Transformer Engine</td>
1955
+ <td>fp8</td>
1956
+ <td></td>
1957
+ <td></td>
1958
+ <td>fp32</td>
1959
+ <td>fp32</td>
1960
+ <td>fp32 + fp32</td>
1961
+ <td>4 + 4 + 4 + 4 = 16 bytes (20% reduction)</td>
1962
+ </tr>
1963
+ <tr>
1964
+ <td>FP8-LM's O3 level</td>
1965
+ <td>fp8</td>
1966
+ <td>fp16</td>
1967
+ <td>fp16</td>
1968
+ <td>fp8</td>
1969
+ <td>fp8</td>
1970
+ <td>fp8 + fp16</td>
1971
+ <td>2 + 2 + 1 + 1 + 1 + 2 = 9 bytes (55%)</td>
1972
+ </tr>
1973
+ <tr>
1974
+ <td>DeepSeek-V3</td>
1975
+ <td>fp8</td>
1976
+ <td>fp32</td>
1977
+ <td>fp32</td>
1978
+ <td>fp8</td>
1979
+ <td>bf16</td>
1980
+ <td>bf16 + bf16</td>
1981
+ <td>4+4+1+2+2+2 = 15 (25%)</td>
1982
+ </tr>
1983
+ <tr>
1984
+ <td>nanotron's FP8</td>
1985
+ <td>fp8</td>
1986
+ <td>bf16</td>
1987
+ <td>fp32</td>
1988
+ <td>fp8</td>
1989
+ <td>fp8</td>
1990
+ <td>fp8 + fp8</td>
1991
+ <td>2 + 4 + 1 + 1 + 1 + 1 = 10 bytes (50%)</td>
1992
+ </tr>
1993
+ </tbody>
1994
+ </table>
1995
+
1996
+ <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in [TODO: link to appendix]. </p>
1997
+
1998
+ <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
1999
+
2000
+ <p>We now arrived at the end of the distributed training journey. Let’s take a step back and conclude.</p>
2001
+
2002
  <h2>Conclusion</h2>
2003
 
2004
+
2005
+ <p>Congratulations! You've completed quite a journey - from understanding how to train a simple model on a single GPU, all the way to mastering the complex techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3. By now, you should feel confident interpreting advanced parallelism diagrams like the one below, which would have seemed daunting when you first started.</p>
2006
+
2007
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2008
+
2009
+ <p>In distributed training, many concepts sound easy enough when you first hear them, like “Pipeline parallelism just distributes layers on different GPUs”, but we also worked through all the challenging details when implementing those methods. </p>
2010
+
2011
+ <p>However, not only did you learn something in the process, but we also want to share some insights we gained along the way, as well as give you ideas on what to work on next if you want to gain more experience in distributed training.</p>
2012
+
2013
+ <p>Let’s start with a brief recap of all the things we covered in these past hours and days!</p>
2014
+
2015
  <h3>What you learned</h3>
2016
 
2017
+ <p>Working through this whole blog post you mastered a ranged of concepts:</p>
2018
+
2019
+ <ul>
2020
+ <li>Basic principle of model training</li>
2021
+ <li>Collective communication primitives </li>
2022
+ <li>Memory anatomy of a LLM</li>
2023
+ <li>Distributed training with DP and ZeRO </li>
2024
+ <li>Model parallelism with TP, SP, CP and PP</li>
2025
+ <li>Fast kernels and mixed precision training</li>
2026
+ <li>Overlapping communication and computation</li>
2027
+ <li>Profiling distributed training</li>
2028
+ </ul>
2029
+
2030
+ <p>Furthermore, you saw code implementations of most methods and how to benchmark a distributed training. But it hasn’t been only a learning experience for you, also we learned a thing or two!</p>
2031
+
2032
  <h3>What we learned</h3>
2033
+
2034
+ <p>Running benchmarks on a cluster turned out to be much more challenging than we initially expected! What seemed like straightforward tasks often became complex debugging sessions:
2035
+ </p>
2036
+
2037
+ <ul>
2038
+ <li>PyTorch processes would sometimes fail to clean up properly</li>
2039
+ <li>Slurm job manager would forcefully terminate jobs, leading to node failures </li>
2040
+ <li>Simple benchmarks that should take minutes would stretch into hours</li>
2041
+ <li>We had to spend significant time:</li>
2042
+ <ul>
2043
+ <li>Minimizing cluster restart times and optimize idle time</li>
2044
+ <li>Analyzing detailed NCCL debug logs</li>
2045
+ <li>Understand memory usage patterns and CUDA memory allocator behaviors</li>
2046
+ <li>Improving pipeline parallelism performance on multi-node</li>
2047
+ </ul>
2048
+ </ul>
2049
 
2050
+ <p>These challenges deserve their own story, but they taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
2051
+
2052
+ <p>Let's analyze the results of our benchmarks and understand how different configurations affect each other. All benchmarks were run with a sequence length of 4096 and a global batch size of 1M tokens. We'll look at two key visualizations that help illustrate our findings.
2053
+ </p>
2054
+
2055
+ <p>First, let's examine this heatmap visualization:</p>
2056
+
2057
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2058
+ <p>Heatmap visualization showing the optimal training configurations across different model sizes and compute node counts. For each combination, the configuration details include Data Parallelism (DP), Tensor Parallelism (TP), Pipeline Parallelism (PP), Gradient Accumulation Steps (GAS), Micro Batch Size (MBS), and ZeRO optimization stage. The color intensity indicates the Model FLOPs Utilization (MFU), with brighter colors representing higher efficiency.</p>
2059
+
2060
+ <p>To complement this, let's look at the relationships between different parameters:</p>
2061
+
2062
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2063
+ <p>Parallel coordinates plot showing the relationship between different model parallelism configurations (Data Parallel degree, Tensor Parallel degree, Pipeline Parallel degree), training hyperparameters (gradient accumulation steps, micro batch size), ZeRO stage and the resulting Model FLOPs Utilization (MFU). Each line represents a different training configuration, with colors indicating the MFU value - warmer colors show higher efficiency.</p>
2064
+
2065
+ <p>From these visualizations, we can draw several important insights:
2066
+ </p>
2067
+
2068
+ <ol>
2069
+ <li>As we increase the number of nodes (higher parallelism), we observe a decrease in efficiency. This effect is particularly pronounced for smaller models, which have a lower compute-to-model-size ratio. While we might typically compensate for small model size by increasing the batch size, we're constrained by our global batch size limit of 1M.
2070
+ </li>
2071
+ <li>Larger models present a different challenge. As model size increases, memory requirements grow substantially. This creates two scenarios with fewer nodes: either the model doesn't fit at all, or it barely fits but runs inefficiently due to operating near the GPU memory limits.</li>
2072
+ <li>Our benchmarks demonstrate how performance heavily depends on implementation quality. When we first implemented both parallelism strategies, Tensor Parallelism (TP) outperformed Pipeline Parallelism (PP). After optimizing our PP code, it became the faster option. Now that we're improving the communication overlap in our TP implementation, we expect it to regain the performance lead.</li>
2073
+ </ol>
2074
+
2075
+ <p>These findings highlight the challenges of reproducing theoretical results in practice, especially given the limited availability of production training code. Through open-source projects like picotron and nanotron, we hope to make these distributed training techniques more accessible and foster collaboration on simpler, more efficient codebases that help researchers and practitioners make the most of their hardware resources.</p>
2076
+
2077
  <h3>What’s next?</h3>
2078
+
2079
+ <p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
2080
+
2081
+ <ul>
2082
+ <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in [TODO References]</li>
2083
+ <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
2084
+ <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
2085
+ </ul>
2086
+
2087
+ <p>We hope this blog helps you get started in distributed training or helps you to better understand methods that you may already be applying by using some distributed training frameworks.</p>
2088
 
2089
  <h2>References</h2>
2090
 
 
2138
  }</pre>
2139
  </d-appendix>
2140
 
2141
+ <script>
 
2142
  const article = document.querySelector('d-article');
2143
  const toc = document.querySelector('d-contents');
2144
  if (toc) {
src/bibliography.bib CHANGED
@@ -466,4 +466,48 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
466
  archivePrefix={arXiv},
467
  primaryClass={cs.CL},
468
  url={https://arxiv.org/abs/2006.16668},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  }
 
466
  archivePrefix={arXiv},
467
  primaryClass={cs.CL},
468
  url={https://arxiv.org/abs/2006.16668},
469
+ }
470
+ @misc{dao2022flashattention,
471
+ title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
472
+ author={Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
473
+ year={2022},
474
+ eprint={2205.14135},
475
+ archivePrefix={arXiv},
476
+ primaryClass={cs.LG},
477
+ url={https://arxiv.org/abs/2205.14135},
478
+ }
479
+ @misc{micikevicius2018mixedprecisiontraining,
480
+ title={Mixed Precision Training},
481
+ author={Paulius Micikevicius and Sharan Narang and Jonah Alben and Gregory Diamos and Erich Elsen and David Garcia and Boris Ginsburg and Michael Houston and Oleksii Kuchaiev and Ganesh Venkatesh and Hao Wu},
482
+ year={2018},
483
+ eprint={1710.03740},
484
+ archivePrefix={arXiv},
485
+ primaryClass={cs.AI},
486
+ url={https://arxiv.org/abs/1710.03740},
487
+ }
488
+ @software{torchao,
489
+ title = {torchao: PyTorch native quantization and sparsity for training and inference},
490
+ author = {torchao maintainers and contributors},
491
+ url = {https://github.com/pytorch/torchao},
492
+ license = {BSD-3-Clause},
493
+ month = oct,
494
+ year = {2024}
495
+ }
496
+ @misc{peng2023fp8lmtrainingfp8large,
497
+ title={FP8-LM: Training FP8 Large Language Models},
498
+ author={Houwen Peng and Kan Wu and Yixuan Wei and Guoshuai Zhao and Yuxiang Yang and Ze Liu and Yifan Xiong and Ziyue Yang and Bolin Ni and Jingcheng Hu and Ruihang Li and Miaosen Zhang and Chen Li and Jia Ning and Ruizhe Wang and Zheng Zhang and Shuguang Liu and Joe Chau and Han Hu and Peng Cheng},
499
+ year={2023},
500
+ eprint={2310.18313},
501
+ archivePrefix={arXiv},
502
+ primaryClass={cs.LG},
503
+ url={https://arxiv.org/abs/2310.18313},
504
+ }
505
+ @misc{wortsman2023smallscaleproxieslargescaletransformer,
506
+ title={Small-scale proxies for large-scale Transformer training instabilities},
507
+ author={Mitchell Wortsman and Peter J. Liu and Lechao Xiao and Katie Everett and Alex Alemi and Ben Adlam and John D. Co-Reyes and Izzeddin Gur and Abhishek Kumar and Roman Novak and Jeffrey Pennington and Jascha Sohl-dickstein and Kelvin Xu and Jaehoon Lee and Justin Gilmer and Simon Kornblith},
508
+ year={2023},
509
+ eprint={2309.14322},
510
+ archivePrefix={arXiv},
511
+ primaryClass={cs.LG},
512
+ url={https://arxiv.org/abs/2309.14322},
513
  }
src/index.html CHANGED
@@ -401,7 +401,7 @@
401
 
402
  <p>Is there a way to tame this “activation explosion”? Good question, reader!</p>
403
 
404
- <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em><em>–</em> </strong>**which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.</p>
405
 
406
  <h3>Activation recomputation</h3>
407
 
@@ -565,7 +565,7 @@
565
 
566
  <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
567
 
568
- <aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) **which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
569
  </aside>
570
 
571
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
@@ -702,9 +702,9 @@
702
 
703
 
704
 
705
- <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same ***reduce-scatter*** as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
706
 
707
- <p>Thankfully, although we added many more communication operations, **prefetching** helps us overlap them efficiently by all-gathering weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
708
 
709
  <p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
710
 
@@ -1480,46 +1480,611 @@
1480
 
1481
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1482
 
1483
- <p></p>
 
1484
 
1485
- <p></p>
1486
 
1487
- <p></p>
 
1488
 
1489
- <p></p>
1490
 
1491
- <p></p>
 
 
1492
 
1493
  <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
 
 
 
 
 
 
 
 
 
 
 
1494
 
 
1495
 
1496
  <h3>How to improve performance with Kernels ?</h3>
1497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1498
  <h4>Memory Coalescing</h4>
1499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1500
  <h4>Tiling</h4>
 
 
 
 
 
1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1502
  <h4>Thread Coarsening</h4>
1503
 
 
 
 
 
 
 
 
 
 
 
 
 
1504
  <h4>Minimizing Control Divergence</h4>
1505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1506
  <h3>Flash Attention 1-3</h3>
1507
 
1508
- <h3>Fused Kernels</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1509
 
1510
  <h3>Mixed Precision Training</h3>
1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1512
  <h4>FP16 and BF16 training</h4>
 
 
 
 
 
 
 
 
 
 
 
 
1513
 
1514
  <h4>FP8 pretraining</h4>
1515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1516
  <h2>Conclusion</h2>
1517
 
 
 
 
 
 
 
 
 
 
 
 
1518
  <h3>What you learned</h3>
1519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1520
  <h3>What we learned</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1522
  <h3>What’s next?</h3>
 
 
 
 
 
 
 
 
 
 
1523
 
1524
  <h2>References</h2>
1525
 
 
401
 
402
  <p>Is there a way to tame this “activation explosion”? Good question, reader!</p>
403
 
404
+ <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em><em>–</em> </strong>which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.</p>
405
 
406
  <h3>Activation recomputation</h3>
407
 
 
565
 
566
  <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
567
 
568
+ <aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
569
  </aside>
570
 
571
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
 
702
 
703
 
704
 
705
+ <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
706
 
707
+ <p>Thankfully, although we added many more communication operations, <strong>prefetching</strong> helps us overlap them efficiently by all-gathering weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
708
 
709
  <p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
710
 
 
1480
 
1481
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1482
 
1483
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1484
+ <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1485
 
1486
+ <p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
1487
 
1488
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1489
+ <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1490
 
1491
+ <p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
1492
 
1493
+ <p>A piece of code running on a core of the GPU is called a <strong>kernel</strong>. It can be written at a high-level in <strong>CUDA</strong> or <strong>Triton</strong> for instance, and is then compiled to Parallel Thread Execution, PTX, the low-level assembly used by NVIDIA GPUs.</p>
1494
+
1495
+ <p>To run the kernel, you will also need a specific code part, called <strong>host code</strong>, which is executed on the <strong>CPU/host</strong> and will take care of preparing data allocations and loading data and code.</p>
1496
 
1497
  <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1498
+ <p>Figure 5: Host code for a CUDA kernel for adding two vectors from https://blog.codingconfessions.com/p/gpu-computing</p>
1499
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1500
+ <p>Figure 6: Device code containing the definition of the vector addition kernel from https://blog.codingconfessions.com/p/gpu-computing</p>
1501
+
1502
+ <p>Kernels are generally scheduled as follow:</p>
1503
+
1504
+ <ul>
1505
+ <li>threads are grouped in <strong>warps</strong> of sizes of 32. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
1506
+ <li><strong>warps</strong> are grouped in larger <strong>blocks</strong> of more flexible size (e.g. size 256), each block still being assigned to a single SM. An SM may run several blocks in parallel, however, depending on the resources, not all the blocks may get assigned for execution immediately, some can be waitlisted waiting for resources.</li>
1507
+ </ul>
1508
+
1509
+ <p>The main thing to remember from these details is that there are various sizing and allocation constraints (size of the various memories, number of concurrent block and threads in the wraps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
1510
 
1511
+ <p>Most of the time you don’t need to go down to this level of precision and you can luckily reuse the kernels and code prepared by other members of the community. But in any case we want to give you a primer on how to get started with kernels! </p>
1512
 
1513
  <h3>How to improve performance with Kernels ?</h3>
1514
 
1515
+
1516
+ <p>If you’re looking to add a new operation that lacks an optimized kernel or to speed up an existing PyTorch function, writing kernels from scratch might seem like the most direct route. However, creating high-performance CUDA kernels from scratch requires extensive experience and a steep learning curve. Generally a better way to get started is to leverage <code>torch.compile</code>, which dynamically optimizes PyTorch code by capturing your operations and generating lower-level, high-performance kernels in triton.</p>
1517
+
1518
+ <p>Let’s suppose you want to write a kernel for an activation function called Exponential Linear Unit:</p>
1519
+
1520
+ <d-math block>
1521
+ \text{ELU}(x) = \begin{cases}
1522
+ e^x - 1 & \text{if } x < 0 \\
1523
+ x & \text{if } x \geq 0
1524
+ \end{cases}
1525
+ </d-math>
1526
+ <p>TODO: something off with spacing but seems the rendering engine</p>
1527
+
1528
+ <p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
1529
+
1530
+ <d-code block language="python">
1531
+ @torch.compile
1532
+ def elu(x, alpha=1.0):
1533
+ return torch.where(x < 0, alpha * (torch.exp(x) - 1), x)
1534
+ </d-code>
1535
+
1536
+ <p>The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):</p>
1537
+
1538
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1539
+
1540
+
1541
+ <p>However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by @torch.compile . To do so, you simply need to set the environment variable <code>TORCH_LOGS</code> to <code>"output_code"</code>:</p>
1542
+
1543
+ <d-code block language="bash">
1544
+ export TORCH_LOGS="output_code"
1545
+ </d-code>
1546
+
1547
+ <p>Once you run the Python script with the <code>@torch.compile</code> decorator, it will generate and output the corresponding Triton kernel, which, in this case, is:</p>
1548
+
1549
+ <d-code block language="python">
1550
+ @triton.jit
1551
+ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
1552
+ xnumel = 100000000
1553
+ xoffset = tl.program_id(0) * XBLOCK
1554
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
1555
+ xmask = xindex < xnumel
1556
+ x0 = xindex
1557
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
1558
+ tmp1 = 0.0
1559
+ tmp2 = tmp0 < tmp1
1560
+ tmp3 = tl_math.exp(tmp0)
1561
+ tmp4 = 1.0
1562
+ tmp5 = tmp3 - tmp4
1563
+ tmp6 = tl.where(tmp2, tmp5, tmp0)
1564
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
1565
+ </d-code>
1566
+
1567
+ <p>To enhance readability, we can modify the variable names, add comments, and make slight adjustments, as demonstrated below:</p>
1568
+
1569
+ <d-code block language="python">
1570
+ @triton.jit
1571
+ def elu_kernel(input_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
1572
+ # Calculate the starting index for this block
1573
+ block_start = tl.program_id(0) * BLOCK_SIZE
1574
+ # Create an array of indices for this block
1575
+ block_indices = block_start + tl.arange(0, BLOCK_SIZE)[:]
1576
+ # Create a mask to ensure only valid indices are processed
1577
+ valid_mask = block_indices < num_elements
1578
+ # Load input values from the input pointer based on valid indices
1579
+ input_values = tl.load(input_ptr + block_indices, valid_mask)
1580
+ # Define the ELU parameters
1581
+ zero_value = 0.0 # Threshold for ELU activation
1582
+ negative_mask = input_values < zero_value
1583
+ exp_values = tl.math.exp(input_values)
1584
+ # Define the ELU output shift
1585
+ one_value = 1.0
1586
+ shifted_exp_values = exp_values - one_value
1587
+
1588
+ output_values = tl.where(negative_mask, shifted_exp_values, input_values)
1589
+
1590
+ # Store the computed output values back to the output pointer
1591
+ tl.store(output_ptr + block_indices, output_values, valid_mask)
1592
+ </d-code>
1593
+
1594
+ <p>Here, <code>tl.program_id(0)</code> provides a unique block ID, that we use to determine which section of data that block will process. Using this block ID, <code>block_start</code> calculates the starting index for each block’s section, while <code>block_indices</code> specifies the range of indices within that section. A <code>valid_mask</code> ensures that only indices within <code>num_elements</code> are processed, safely loading the data with <code>tl.load</code>. The ELU function is then applied, modifying values based on whether they're negative, and results are written back to memory with <code>tl.store</code>.</p>
1595
+
1596
+ <p>When we benchmark the generated kernel using <code>triton.testing.Benchmark</code> we have the following performance:</p>
1597
+
1598
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1599
+
1600
+ <p>This standalone kernel demonstrates superior performance with smaller sizes compared to <code>@torch.compile</code> but this is likely here just an artifact from the compilation time of <code>torch.compile</code>. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process. </p>
1601
+
1602
+ <p>However, in Triton, sometimes, we cannot fully achieve the peak performance of the device due to limitations in handling shared memory and scheduling within streaming multiprocessors (SMs). Our access is restricted to blocks, allowing us only to manage the scheduling of blocks across SMs. To gain even more control, we will need to implement kernels in CUDA, where we have access to all the underlying components.</p>
1603
+
1604
+ <p>In CUDA, there are various techniques that can be employed to make kernels more efficient; we will present just a few. These include optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times. In summary, the tools for writing code to execute instructions on the GPU are:</p>
1605
+
1606
+ <ul>
1607
+ <li>Pytorch: easy but slow</li>
1608
+ <li>torch.compile: easy, fast, but not flexible</li>
1609
+ <li>triton: harder, faster, and more flexible</li>
1610
+ <li>CUDA: hardest, fastest, and flexiblest (if you get it right)</li>
1611
+
1612
+ </ul>
1613
+
1614
+ <p>Let’s talk about one of the most frequent technique we can use: optimizing memory access. The global memory in GPUs (the largest memory in our above graph) has a long latency and low bandwidth in comparison to the cache which often creates a major bottleneck for most applications. Efficiently accessing data from global memory can improve a lot the performance.</p>
1615
+
1616
  <h4>Memory Coalescing</h4>
1617
 
1618
+ <p>To effectively utilize the bandwidth of global memory, it is essential to understand its architecture. In CUDA devices, global memory is implemented using DRAM.</p>
1619
+
1620
+ <p>Memory coalescing takes advantage of how DRAM delivers data in bursts, or ranges of consecutive memory locations, whenever a memory address is accessed. Each time a DRAM location is accessed, a sequence of consecutive locations, including the requested one, is read in parallel by multiple sensors in the DRAM chip. Once read, this data can then be quickly transferred to the processor as a burst. In CUDA, coalescing uses this burst behavior to maximize memory access efficiency by ensuring that threads in a warp—32 threads that execute the same instruction in lockstep (SIMD)—access consecutive memory locations. For instance, if thread 0 accesses location M, thread 1 accesses M + 1, thread 2 accesses M + 2, and so forth, the GPU hardware coalesces or combines these requests into one large, efficient access request for the DRAM burst, rather than handling each access individually. </p>
1621
+
1622
+ <p>Let’s take the example of matrix multiplication. A simple, straightforward implementation would have each thread compute a single element of the output matrix, like this:</p>
1623
+
1624
+ <d-code block language="clike">
1625
+ __global__ void matmul_naive(int M, int N, int K, const float *A, const float *B, float *C) {
1626
+ const uint x = blockIdx.x * blockDim.x + threadIdx.x;
1627
+ const uint y = blockIdx.y * blockDim.y + threadIdx.y;
1628
+
1629
+ if (x < M && y < N) {
1630
+ float tmp = 0.0;
1631
+ for (int i = 0; i < K; ++i) {
1632
+ tmp += A[x * K + i] * B[i * N + y];
1633
+ }
1634
+ C[x * N + y] = tmp;
1635
+ }
1636
+ }
1637
+ </d-code>
1638
+
1639
+ <p>Here’s an excellent visualization of the kernel from this <a href="https://siboehm.com/articles/22/CUDA-MMM">fantastic blogpost</a>: </p>
1640
+
1641
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1642
+
1643
+ <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
1644
+
1645
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1646
+
1647
+
1648
+ <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load <d-math>A_{0,0}</d-math>, and thread <code>(1, 0)</code> will load <d-math>A_{1,0}</d-math>. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
1649
+
1650
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1651
+
1652
+
1653
+ <p>To improve our kernel we can change the way the coordinates x and y are calculated like the following : </p>
1654
+
1655
+ <d-code block language="clike">
1656
+ const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
1657
+ const int y = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);
1658
+
1659
+ if (x < M && y < N) {
1660
+ float tmp = 0.0;
1661
+ for (int i = 0; i < K; ++i) {
1662
+ tmp += A[x * K + i] * B[i * N + y];
1663
+ }
1664
+ C[x * N + y] = tmp;
1665
+ }
1666
+ </d-code>
1667
+
1668
+ <p>Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of <code>x</code> and <code>y</code>. In this new method, threads within the same warp (which have close <code>threadIdx.x</code> values) will share the same <code>x</code> value but have different <code>y</code> values. This means that they will load the same row of matrix <code>A</code> but different columns of matrix <code>B</code>. As a result, memory accesses can be coalesced for a row-major matrix.</p>
1669
+
1670
+ <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
1671
+
1672
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1673
+
1674
+
1675
+ <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong> !</p>
1676
+ <p>Let’s cover another technique you will often see mentioned in the litterature: tiling.</p>
1677
+
1678
+
1679
  <h4>Tiling</h4>
1680
+
1681
+
1682
+ <p>Tiling is a technique that leverages <em>shared memory</em> to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.</p>
1683
+
1684
+ <p>In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.</p>
1685
 
1686
+ <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.</p>
1687
+
1688
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1689
+ <p>From https://cnugteren.github.io/tutorial/pages/page4.html</p>
1690
+
1691
+ <p>The important parts to understand the implementation are below (for simplicity we consider a square shaped tile) : </p>
1692
+
1693
+ <d-code block language="clike">
1694
+ // Set pointers to the starting elements
1695
+ A += blockRow * TILE_SIZE * K; // Start at row = blockRow, column = 0
1696
+ B += blockCol * TILE_SIZE; // Start at row = 0, column = blockCol
1697
+ C += blockRow * TILE_SIZE * N + blockCol * TILE_SIZE; // Start at row = blockRow, column = blockCol
1698
+ float sum = 0.0;
1699
+ // The outer loop moves through tiles of A (across columns) and B (down rows)
1700
+ for (int tileIdx = 0; tileIdx < K; tileIdx += TILE_SIZE) {
1701
+ sharedA[localRow * TILE_SIZE + localCol] = A[localRow * K + localCol];
1702
+ sharedB[localRow * TILE_SIZE + localCol] = B[localRow * N + localCol];
1703
+
1704
+ // Ensure all threads in the block have completed data loading
1705
+ __syncthreads();
1706
+
1707
+ // Shift pointers to the next tile
1708
+ A += TILE_SIZE;
1709
+ B += TILE_SIZE * N;
1710
+
1711
+ // Compute the partial dot product for this tile
1712
+ for (int i = 0; i < TILE_SIZE; ++i) {
1713
+ sum += sharedA[localRow * TILE_SIZE + i] * sharedB[i * TILE_SIZE + localCol];
1714
+ }
1715
+ // Synchronize again to prevent any thread from loading new data
1716
+ // into shared memory before others have completed their calculations
1717
+ __syncthreads();
1718
+ }
1719
+ C[localRow * N + localCol] = sum;
1720
+ </d-code>
1721
+
1722
+ <p>Each thread begins by loading one element from both <strong>Matrix A</strong> and <strong>Matrix B</strong> into shared memory. In this scenario, achieving coalesced memory access is straightforward, by assigning <code>threadIdx.x</code> as the <strong>local column index (localCol)</strong>, threads within the same warp will access adjacent elements of both matrices. After each thread in the block completes loading its elements into shared memory (ensured by calling <code>__syncthreads()</code>), they proceed to compute the dot product of the two tiles. Once the threads have iterated through all the tiles—horizontally for <strong>Matrix A</strong> and vertically for <strong>Matrix B</strong>—the resulting sum is stored in the corresponding location of <strong>Matrix C</strong>.</p>
1723
+
1724
+ <p>When benchmarking this kernel using ncu, we noticed that the memory throughput increased to 410 Gb / s, and the kernel execution time decreased by ~43% achieving a ~6.6 TFLOPs performance</p>
1725
+
1726
+
1727
+
1728
  <h4>Thread Coarsening</h4>
1729
 
1730
+
1731
+ <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
1732
+
1733
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1734
+
1735
+
1736
+ <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that:</p>
1737
+
1738
+ <p><em><code>smsp__pcsamp_warps_issue_stalled_mio_throttle</code>: Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure.</em></p>
1739
+
1740
+ <p>So it seems warps are stalling waiting for shared memory accesses to return ! To resolve this issue we can apply the <strong>Thread Coarsening</strong> technique by merging several threads into a single coarsened thread, we can significantly reduce shared memory accesses because each coarsened thread can handle multiple output elements which would increase the arithmetic intensity of the kernel.</p>
1741
+
1742
  <h4>Minimizing Control Divergence</h4>
1743
 
1744
+ <p>A Streaming Multiprocessor (SM) is built to execute all threads in a warp using the Single Instruction, Multiple Data (SIMD) model. This means that at any given moment, one instruction is fetched and executed simultaneously for all threads within the warp. When a warp is executed, the threads within it operate on different segments of the data but follow the same instruction, hence the name Single Instruction, Multiple Data. The primary advantage of SIMD is its efficiency; the control hardware responsible for instruction fetching and dispatching is shared among multiple execution units. This design minimizes the hardware overhead associated with control functions, allowing a greater portion of the hardware to focus on improving arithmetic throughput.</p>
1745
+
1746
+ <p>Control divergence occurs when threads within the same warp take different execution paths. For instance, if a conditional statement (like an <code>if</code> statement) leads to some threads executing one block of code while others execute a different block, the warp must serialize these executions, resulting in idle threads waiting for others to complete. To minimize control divergence, we need to design kernels to ensure that threads within the same warp follow the same execution path. This can be achieved by restructuring code to reduce branching, using data structures that ensure all threads follow similar execution paths, or employing techniques such as predication.</p>
1747
+
1748
+ <p>We have covered some of the main considerations when writing custom kernels and improving the performance and memory footprint of GPU operations. But there’s one more important concept before moving to a real example which is “fusing kernels”.</p>
1749
+
1750
+ <h3>Fused Kernels</h3>
1751
+
1752
+ <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
1753
+
1754
+ <p>Non-blocking can be useful for overlapping communication and computation as we saw at several part along this blog post but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
1755
+
1756
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1757
+ <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
1758
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1759
+ <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
1760
+
1761
+ <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
1762
+
1763
+
1764
+ <p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the succession of computation has been performed.</p>
1765
+
1766
+ <p>What are many places in a Transformer model were this can be advantageous, for instance when. a succession of point-wise operations is performed, e.g. in the computation involved in the Layer norms.</p>
1767
+
1768
+ <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
1769
+
1770
  <h3>Flash Attention 1-3</h3>
1771
 
1772
+ <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computations by writing custom CUDA kernels to make it much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid using too much the slowest global memory of the GPU (confusingly called the High Bandwidth Memory, HBM 🫠)</p>
1773
+
1774
+ <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1775
+
1776
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1777
+
1778
+ <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1779
+
1780
+ <p>The key element is to compute the S matrices in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix all together in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not even do we make use of the shared memory but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.</p>
1781
+
1782
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1783
+ <p>From the FLASH-ATTENTION paper<d-cite bibtex-key="dao2022flashattention"></d-cite></p>
1784
+
1785
+ <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
1786
+ <ul>
1787
+ <li>By avoiding to materialize the S matrix we <strong>reduce the memory burden of attention</strong></li>
1788
+ <li>We also remove a large part of the <strong>naive impact of the S^2 cost of attention</strong></li>
1789
+ </ul>
1790
+
1791
+ <p>As a result as well, all variants of linear attention and sub-quadratic approaches to approximate attention –developed shortly after the invention of the transformers architecture– have been mostly put aside in favor of this exact and fast flash attention implementation and mechanism.</p>
1792
+
1793
+ <p>Following Flash-attention 1, two successive improved versions have been released by the same lab: Flash-attention 2 and 3. In comparison to Flash-attention 1, the improvements in Flash-attention 2 and 3 are less about the general attention mechanism than about tailoring its low level implementation more specifically to the GPU by (1) reducing the number of non-matmul operations as much as possible (2) partitioning carefully the workload among wraps and thread blocks (for Flash Attention 2) and carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture for Flash Attention 3.</p>
1794
+
1795
+ <aside>Flash attention puts some restrictions on which attention patterns can be sped up. Check out <a href="https://pytorch.org/blog/flexattention/">FlexAttention</a> which is a fast <em>and</em> flexible variant.</aside>
1796
+
1797
+ <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
1798
+
1799
+ <p>The techniques described so far in this section require specific modeling code changes and writing custom kernels for certain operations in order to speed up training. In this section we take a look at a range of methods that are agnostic to the modeling code and can be used for any model!</p>
1800
 
1801
  <h3>Mixed Precision Training</h3>
1802
 
1803
+ <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
1804
+
1805
+ <ul>
1806
+ <li>Sign: the first bit determines if the number is positive or negative</li>
1807
+ <li>Mantissa: determines the significant figures of a number</li>
1808
+ <li>Exponent: controls the magnitude of the number</li>
1809
+ </ul>
1810
+
1811
+ <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
1812
+
1813
+ <p></p>
1814
+
1815
+ <table>
1816
+ <thead>
1817
+ <tr>
1818
+ <th><strong>Format</strong></th>
1819
+ <th><strong>Total bits</strong></th>
1820
+ <th><strong>Sign</strong></th>
1821
+ <th><strong>Mantissa</strong></th>
1822
+ <th><strong>Exponent</strong></th>
1823
+ </tr>
1824
+ </thead>
1825
+ <tbody>
1826
+ <tr>
1827
+ <td>float32</td>
1828
+ <td>32</td>
1829
+ <td>1</td>
1830
+ <td>23</td>
1831
+ <td>8</td>
1832
+ </tr>
1833
+ <tr>
1834
+ <td>float16</td>
1835
+ <td>16</td>
1836
+ <td>1</td>
1837
+ <td>10</td>
1838
+ <td>5</td>
1839
+ </tr>
1840
+ <tr>
1841
+ <td>bfloat16</td>
1842
+ <td>16</td>
1843
+ <td>1</td>
1844
+ <td>7</td>
1845
+ <td>8</td>
1846
+ </tr>
1847
+ <tr>
1848
+ <td>float8 (e4m3)</td>
1849
+ <td>8</td>
1850
+ <td>1</td>
1851
+ <td>3</td>
1852
+ <td>4</td>
1853
+ </tr>
1854
+ <tr>
1855
+ <td>float8 (e5m2)</td>
1856
+ <td>8</td>
1857
+ <td>1</td>
1858
+ <td>2</td>
1859
+ <td>5</td>
1860
+ </tr>
1861
+ </tbody>
1862
+ </table>
1863
+
1864
+ <aside>Note: You might be wondering where the “b” in bfloat16 comes from. The format was developed at Google Brain and thus the “b” stands for “brain”. </aside>
1865
+
1866
+ <p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. Either we can sacrifice more bits on the mantissa or exponent. For this reason there exist also two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:</p>
1867
+
1868
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1869
+
1870
+
1871
+ <p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further where e5e2 can maintain float16 range and e4m3 has an even smaller ranger.</p>
1872
+
1873
+ <p>How come some format are able to maintain the range and other not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
1874
+
1875
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1876
+
1877
+ <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
1878
+
1879
+ <p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
1880
+
1881
+ <p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision.</p>
1882
+
1883
+ <p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
1884
+
1885
+ <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
1886
+
1887
+
1888
+
1889
  <h4>FP16 and BF16 training</h4>
1890
+
1891
+ <p>Naively switching all the tensors and operations to float16 unfortunately doesn’t work and the result is usually diverging losses. However, the original mixed precision training paper<d-cite bitex-key="micikevicius2018mixedprecisiontraining"></d-cite> came up with three tricks to match float32 trainings:</p>
1892
+
1893
+ <ol>
1894
+ <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
1895
+ <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
1896
+ <li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in float16 such as in dot products, we can also face under or overflows. Does targeting certain types of arithmetic operations to accumulate the intermediate results in float32 during the operation and then casting the accumulated result back to fp16. For the same reason gradients are also accumulated in float32.</li>
1897
+ </ol>
1898
+
1899
+ <p>With these techniques, you get consistently stable training while benefitting from higher throughput due to the faster, lower precision operations. Naturally, as the curious reader you are and by now slightly addicted to maximizing the throughput, you ask the question: can we go further and faster? </p>
1900
+
1901
+ <p>Maybe!</p>
1902
 
1903
  <h4>FP8 pretraining</h4>
1904
 
1905
+ <p>Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.</p>
1906
+
1907
+ <p>Recent research - including FP8-LM<d-cite bibtex-key="peng2023fp8lmtrainingfp8large"></d-cite>, torchao<d-cite bibtex-key="torchao"></d-cite>, and DeepSeek-V3<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite> - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.</p>
1908
+
1909
+ <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
1910
+
1911
+ <p>The first, successful, very large scale training with FP8 mixed precision was publicly reported on DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward pass. Similar to BF16 mixed precision training, some aggregation and master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
1912
+
1913
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1914
+
1915
+ <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme, where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There is a number of additional tricks they deploy to also reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
1916
+
1917
+ <p>Here’s a summary of a few known approaches to FP8 training:</p>
1918
+
1919
+ <table>
1920
+ <thead>
1921
+ <tr>
1922
+ <th></th>
1923
+ <th>GEMM's precision</th>
1924
+ <th>Master model weights</th>
1925
+ <th>Accumulated gradients</th>
1926
+ <th>Model weights</th>
1927
+ <th>Gradients</th>
1928
+ <th>Optimizer States</th>
1929
+ <th>Total Memory</th>
1930
+ </tr>
1931
+ </thead>
1932
+ <tbody>
1933
+ <tr>
1934
+ <td>bfloat16 with fp32 mixed precision baseline</td>
1935
+ <td>bf16</td>
1936
+ <td>fp32</td>
1937
+ <td>fp32</td>
1938
+ <td>bf16</td>
1939
+ <td>bf16</td>
1940
+ <td>fp32 + fp32</td>
1941
+ <td>4 + 4 + 2 + 2 + 4 + 4 = 20 bytes</td>
1942
+ </tr>
1943
+ <tr>
1944
+ <td>Above without FP32 grad accumulation</td>
1945
+ <td>bf16</td>
1946
+ <td>fp32</td>
1947
+ <td></td>
1948
+ <td>bf16</td>
1949
+ <td>bf16</td>
1950
+ <td>fp32 + fp32</td>
1951
+ <td>4 + 2 + 2 + 4 + 4 = 16 bytes</td>
1952
+ </tr>
1953
+ <tr>
1954
+ <td>Transformer Engine</td>
1955
+ <td>fp8</td>
1956
+ <td></td>
1957
+ <td></td>
1958
+ <td>fp32</td>
1959
+ <td>fp32</td>
1960
+ <td>fp32 + fp32</td>
1961
+ <td>4 + 4 + 4 + 4 = 16 bytes (20% reduction)</td>
1962
+ </tr>
1963
+ <tr>
1964
+ <td>FP8-LM's O3 level</td>
1965
+ <td>fp8</td>
1966
+ <td>fp16</td>
1967
+ <td>fp16</td>
1968
+ <td>fp8</td>
1969
+ <td>fp8</td>
1970
+ <td>fp8 + fp16</td>
1971
+ <td>2 + 2 + 1 + 1 + 1 + 2 = 9 bytes (55%)</td>
1972
+ </tr>
1973
+ <tr>
1974
+ <td>DeepSeek-V3</td>
1975
+ <td>fp8</td>
1976
+ <td>fp32</td>
1977
+ <td>fp32</td>
1978
+ <td>fp8</td>
1979
+ <td>bf16</td>
1980
+ <td>bf16 + bf16</td>
1981
+ <td>4+4+1+2+2+2 = 15 (25%)</td>
1982
+ </tr>
1983
+ <tr>
1984
+ <td>nanotron's FP8</td>
1985
+ <td>fp8</td>
1986
+ <td>bf16</td>
1987
+ <td>fp32</td>
1988
+ <td>fp8</td>
1989
+ <td>fp8</td>
1990
+ <td>fp8 + fp8</td>
1991
+ <td>2 + 4 + 1 + 1 + 1 + 1 = 10 bytes (50%)</td>
1992
+ </tr>
1993
+ </tbody>
1994
+ </table>
1995
+
1996
+ <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in [TODO: link to appendix]. </p>
1997
+
1998
+ <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
1999
+
2000
+ <p>We now arrived at the end of the distributed training journey. Let’s take a step back and conclude.</p>
2001
+
2002
  <h2>Conclusion</h2>
2003
 
2004
+
2005
+ <p>Congratulations! You've completed quite a journey - from understanding how to train a simple model on a single GPU, all the way to mastering the complex techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3. By now, you should feel confident interpreting advanced parallelism diagrams like the one below, which would have seemed daunting when you first started.</p>
2006
+
2007
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2008
+
2009
+ <p>In distributed training, many concepts sound easy enough when you first hear them, like “Pipeline parallelism just distributes layers on different GPUs”, but we also worked through all the challenging details when implementing those methods. </p>
2010
+
2011
+ <p>However, not only did you learn something in the process, but we also want to share some insights we gained along the way, as well as give you ideas on what to work on next if you want to gain more experience in distributed training.</p>
2012
+
2013
+ <p>Let’s start with a brief recap of all the things we covered in these past hours and days!</p>
2014
+
2015
  <h3>What you learned</h3>
2016
 
2017
+ <p>Working through this whole blog post you mastered a ranged of concepts:</p>
2018
+
2019
+ <ul>
2020
+ <li>Basic principle of model training</li>
2021
+ <li>Collective communication primitives </li>
2022
+ <li>Memory anatomy of a LLM</li>
2023
+ <li>Distributed training with DP and ZeRO </li>
2024
+ <li>Model parallelism with TP, SP, CP and PP</li>
2025
+ <li>Fast kernels and mixed precision training</li>
2026
+ <li>Overlapping communication and computation</li>
2027
+ <li>Profiling distributed training</li>
2028
+ </ul>
2029
+
2030
+ <p>Furthermore, you saw code implementations of most methods and how to benchmark a distributed training. But it hasn’t been only a learning experience for you, also we learned a thing or two!</p>
2031
+
2032
  <h3>What we learned</h3>
2033
+
2034
+ <p>Running benchmarks on a cluster turned out to be much more challenging than we initially expected! What seemed like straightforward tasks often became complex debugging sessions:
2035
+ </p>
2036
+
2037
+ <ul>
2038
+ <li>PyTorch processes would sometimes fail to clean up properly</li>
2039
+ <li>Slurm job manager would forcefully terminate jobs, leading to node failures </li>
2040
+ <li>Simple benchmarks that should take minutes would stretch into hours</li>
2041
+ <li>We had to spend significant time:</li>
2042
+ <ul>
2043
+ <li>Minimizing cluster restart times and optimize idle time</li>
2044
+ <li>Analyzing detailed NCCL debug logs</li>
2045
+ <li>Understand memory usage patterns and CUDA memory allocator behaviors</li>
2046
+ <li>Improving pipeline parallelism performance on multi-node</li>
2047
+ </ul>
2048
+ </ul>
2049
+
2050
+ <p>These challenges deserve their own story, but they taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
2051
+
2052
+ <p>Let's analyze the results of our benchmarks and understand how different configurations affect each other. All benchmarks were run with a sequence length of 4096 and a global batch size of 1M tokens. We'll look at two key visualizations that help illustrate our findings.
2053
+ </p>
2054
+
2055
+ <p>First, let's examine this heatmap visualization:</p>
2056
+
2057
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2058
+ <p>Heatmap visualization showing the optimal training configurations across different model sizes and compute node counts. For each combination, the configuration details include Data Parallelism (DP), Tensor Parallelism (TP), Pipeline Parallelism (PP), Gradient Accumulation Steps (GAS), Micro Batch Size (MBS), and ZeRO optimization stage. The color intensity indicates the Model FLOPs Utilization (MFU), with brighter colors representing higher efficiency.</p>
2059
+
2060
+ <p>To complement this, let's look at the relationships between different parameters:</p>
2061
 
2062
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2063
+ <p>Parallel coordinates plot showing the relationship between different model parallelism configurations (Data Parallel degree, Tensor Parallel degree, Pipeline Parallel degree), training hyperparameters (gradient accumulation steps, micro batch size), ZeRO stage and the resulting Model FLOPs Utilization (MFU). Each line represents a different training configuration, with colors indicating the MFU value - warmer colors show higher efficiency.</p>
2064
+
2065
+ <p>From these visualizations, we can draw several important insights:
2066
+ </p>
2067
+
2068
+ <ol>
2069
+ <li>As we increase the number of nodes (higher parallelism), we observe a decrease in efficiency. This effect is particularly pronounced for smaller models, which have a lower compute-to-model-size ratio. While we might typically compensate for small model size by increasing the batch size, we're constrained by our global batch size limit of 1M.
2070
+ </li>
2071
+ <li>Larger models present a different challenge. As model size increases, memory requirements grow substantially. This creates two scenarios with fewer nodes: either the model doesn't fit at all, or it barely fits but runs inefficiently due to operating near the GPU memory limits.</li>
2072
+ <li>Our benchmarks demonstrate how performance heavily depends on implementation quality. When we first implemented both parallelism strategies, Tensor Parallelism (TP) outperformed Pipeline Parallelism (PP). After optimizing our PP code, it became the faster option. Now that we're improving the communication overlap in our TP implementation, we expect it to regain the performance lead.</li>
2073
+ </ol>
2074
+
2075
+ <p>These findings highlight the challenges of reproducing theoretical results in practice, especially given the limited availability of production training code. Through open-source projects like picotron and nanotron, we hope to make these distributed training techniques more accessible and foster collaboration on simpler, more efficient codebases that help researchers and practitioners make the most of their hardware resources.</p>
2076
+
2077
  <h3>What’s next?</h3>
2078
+
2079
+ <p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
2080
+
2081
+ <ul>
2082
+ <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in [TODO References]</li>
2083
+ <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
2084
+ <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
2085
+ </ul>
2086
+
2087
+ <p>We hope this blog helps you get started in distributed training or helps you to better understand methods that you may already be applying by using some distributed training frameworks.</p>
2088
 
2089
  <h2>References</h2>
2090