thomwolf HF staff committed on
Commit
8e93807
·
verified ·
1 Parent(s): 66bdca2
Files changed (2)
  1. dist/index.html +26 -24
  2. src/index.html +26 -24
dist/index.html CHANGED
@@ -770,7 +770,7 @@
770
 
771
  <p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
772
 
773
- <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
774
 
775
  <aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
776
 
@@ -782,7 +782,7 @@
782
  <li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
783
  </ul>
784
 
785
- <aside>When we say partitioning, it means alongside the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition alongside other axes.</aside>
786
 
787
  <p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
788
 
@@ -790,13 +790,13 @@
790
 
791
  <h4>Memory usage revisited</h4>
792
 
793
- <p>Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as <d-math>\Psi</d-math> (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
794
 
795
  <ul>
796
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
797
  <li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
798
  <li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
799
- <li>- Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted if we want to accumulate grads in fp32)</li>
800
  </ul>
801
 
802
  <p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
@@ -804,7 +804,7 @@
804
  <p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
805
 
806
  <p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
807
- <p>Memory consumption of DP and three stages of Zero-DP. <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam), and <d-math>N_d</d-math> denotes DP degree.</p>
808
 
809
 
810
  <p>Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
@@ -813,25 +813,26 @@
813
 
814
  <p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
815
 
816
- <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica that’s distributed on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated, which we cast to get the corresponding <d-math>\frac{1}{N_d}</d-math> portion of the bfloat16 parameters.</p>
817
 
818
- <p>However for the forward pass, we need all our bfloat16 parameters, we thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
819
 
820
  <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step</p>
821
 
822
  <ul>
823
- <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
824
- <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
825
- <li>Perform an reduce-scatter on the gradients (reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em>)</li>
826
- <li>- Each replica perform an optimizer step (has only <d-math>\frac{1}{N_d}</d-math> optimizer states) updates only on <d-math>\frac{1}{N_d}</d-math> of fp32 parameters, and then <d-math>\frac{1}{N_d}</d-math> of bf16 parameters.</li>
827
- <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
828
  </ul>
 
829
 
830
- <p>See the figure below for all the necessary steps in one forward/backward pass cycle:</p>
831
 
832
  <p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
833
 
834
- <p>So in practice, compared to vanilla DP, Zero-1 adds an all-gather over all parameters after the optimizer step as we can see below:</p>
835
 
836
  <p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
837
 
@@ -845,15 +846,15 @@
845
  <div class="note-box">
846
  <p class="note-box-title">📝 Note</p>
847
  <div class="note-box-content">
848
- <p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use ZeRO-3/FSDP implementation where the FSDPUnit is the entire model, more details about this later.</p>
849
  </div>
850
  </div>
851
 
852
- <p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step. Meet ZeRO-2!</p>
853
 
854
  <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
855
 
856
- <p>The idea of ZeRO to is to not only shard the optimizer states but also the gradients. We actually only need the gradient shard corresponding to the optimizer state shard, so it makes sense to shard both the same way. [TODO: update] During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! Where we only spread the <d-math>\frac{1}{N_d}</d-math> gradients needed in memory, thus saving more memory compared to ZeRO-1.</p>
857
 
858
  <aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
859
 
@@ -867,7 +868,7 @@
867
 
868
  <aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
869
 
870
- <p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward, we can already start the forward once we get the first layer.. here comes ZeRO-3!</p>
871
 
872
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
873
 
@@ -894,16 +895,15 @@
894
 
895
  <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
896
 
897
- <p>Thankfully, although we added many more communication operations, <strong>prefetching</strong> helps us overlap them efficiently by all-gathering weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
898
 
899
- <p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
900
-
901
- <aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
902
 
903
  <p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
904
  </strong></p>
 
905
 
906
- <p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only with a short sequence length. </p>
907
 
908
  <iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
909
  <script>
@@ -915,7 +915,9 @@
915
  </script>
916
  <!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
917
 
918
- <p>Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism - Tensor Parallelism. Unlike ZeRO3 that relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
 
 
919
 
920
  <h2>Tensor Parallelism</h2>
921
 
 
770
 
771
  <p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
772
 
773
+ <p>While Data Parallelism is an efficient way to scale training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
774
 
775
  <aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
776
 
 
782
  <li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
783
  </ul>
784
 
785
+ <aside>When we say partitioning, it means along the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition along other axes.</aside>
786
 
787
  <p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch, the activations on each DP rank also differ, so they are not duplicated and thus can’t be sharded!</p>
788
 
 
790
 
791
  <h4>Memory usage revisited</h4>
792
 
793
+ <p>You likely remember from <a target="_self" href="#memory_usage_in_transformers">our previous section</a> the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count <d-math>\Psi</d-math> (previously N, but here we use the original ZeRO paper notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
794
 
795
  <ul>
796
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
797
  <li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
798
  <li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
799
+ <li>Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted for if we want to accumulate grads in fp32)</li>
800
  </ul>
801
 
802
  <p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
 
804
  <p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
805
 
806
  <p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
807
+ <p>Here <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam as we've just seen), and <d-math>N_d</d-math> denotes the DP degree.</p>
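<p>To make these formulas concrete, here is a small illustrative helper (a sketch added for this recap, not code from any training framework) that computes the per-GPU memory of the graph above for a given parameter count and DP degree, assuming Adam (<d-math>k=12</d-math>) and no fp32 gradient accumulation:</p>
<pre><code class="language-python">
def zero_memory_gb(psi, n_d, k=12):
    """Per-GPU memory (in GB) for vanilla DP and ZeRO stages 1-3.

    psi: number of model parameters, n_d: data-parallel degree,
    k: optimizer-state multiplier (k=12 for mixed-precision Adam).
    Follows the formulas of the graph above; fp32 gradient accumulation is ignored.
    """
    gb = 1024**3
    return {
        "DP":     (2*psi + 2*psi + k*psi) / gb,
        "ZeRO-1": (2*psi + 2*psi + k*psi / n_d) / gb,
        "ZeRO-2": (2*psi + (2*psi + k*psi) / n_d) / gb,
        "ZeRO-3": (2*psi + 2*psi + k*psi) / n_d / gb,
    }

# Example: a 7B-parameter model trained with a DP degree of 8
print(zero_memory_gb(psi=7e9, n_d=8))
</code></pre>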
808
 
809
 
810
  <p>Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
 
813
 
814
  <p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
815
 
816
+ <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that the model replica on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated.</p>
817
 
818
+ <p>However, during the forward pass each replica needs all the parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
819
 
820
  <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step:</p>
821
 
822
  <ul>
823
+ <li>Forward pass with the same, full set of bf16 parameters on each replica, but different microbatches across replicas</li>
824
+ <li>Backward pass with the same, full set of gradients on each replica, but different microbatches across replicas</li>
825
+ <li>Perform a reduce-scatter on the gradients (we'll explain the reduce-scatter primitive in the figure below)</li>
826
+ <li>Each replica performs an optimizer step on its local optimizer state shard (only <d-math>\frac{1}{N_d}</d-math> of the optimizer states) to get <d-math>\frac{1}{N_d}</d-math> of the updated fp32 parameters, which can then be cast to the corresponding <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
827
+ <li>Perform an all-gather of the bf16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO, not used in vanilla DP (steps 3 to 5 are sketched in code right after this list).</li>
828
  </ul>
829
+ <aside>Note: reduce-scatter is 2 times faster than all-reduce! <em>Yay, a third communication primitive!</em></aside>
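<p>Here is a minimal PyTorch sketch of steps 3 to 5 above for a single flattened parameter tensor. It is purely illustrative (the function and argument names are ours, and real implementations bucket parameters and overlap communication with compute), but it shows the reduce-scatter, the local optimizer step and the final all-gather:</p>
<pre><code class="language-python">
import torch
import torch.distributed as dist

def zero1_update(flat_bf16_params, flat_bf16_grads, fp32_param_shard, optimizer):
    """One ZeRO-1 update for a single flattened parameter tensor (illustrative only).

    `optimizer` is assumed to have been built over [fp32_param_shard] only,
    so it holds just 1/N_d of the optimizer states.
    """
    world_size = dist.get_world_size()
    shard_size = flat_bf16_params.numel() // world_size

    # 3) reduce-scatter: each rank receives only the averaged gradient slice it owns
    grad_shard = torch.empty(shard_size, dtype=torch.bfloat16,
                             device=flat_bf16_grads.device)
    dist.reduce_scatter_tensor(grad_shard, flat_bf16_grads, op=dist.ReduceOp.AVG)

    # 4) local optimizer step on the fp32 master shard (1/N_d of the fp32 weights)
    fp32_param_shard.grad = grad_shard.float()
    optimizer.step()
    optimizer.zero_grad()

    # cast the updated fp32 shard back to bf16 ...
    bf16_shard = fp32_param_shard.detach().to(torch.bfloat16)

    # 5) ... and all-gather so that every rank gets the full set of updated bf16 weights
    dist.all_gather_into_tensor(flat_bf16_params, bf16_shard)
</code></pre>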
830
 
831
+ <p>You may be wondering what this "reduce-scatter" operation is and how this all looks in practice, so let's make it more graphical with the figure below, which walks through all the steps of a forward/backward pass cycle:</p>
832
 
833
  <p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
834
 
835
+ <p>In terms of practical communication, compared to vanilla DP, ZeRO-1 changes our "all-reduce" gradient communication to a "reduce-scatter" operation and adds an all-gather over all parameters after the optimizer step. Here is how it looks:</p>
836
 
837
  <p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
838
 
 
846
  <div class="note-box">
847
  <p class="note-box-title">📝 Note</p>
848
  <div class="note-box-content">
849
+ <p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use PyTorch's native ZeRO-3/FSDP implementation and set the FSDPUnit to be the entire model; more details on this later.</p>
850
  </div>
851
  </div>
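<p>For illustration, this is roughly what "one FSDP unit for the whole model" looks like with PyTorch's built-in FSDP (FSDP1 API). The toy model and settings below are our own example, assuming the process group has already been initialized (e.g. via <code>torchrun</code>) and one GPU per rank:</p>
<pre><code class="language-python">
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# A toy model; in practice this would be your transformer.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# Without an auto-wrap policy the entire model becomes a single FSDP unit.
# SHARD_GRAD_OP shards gradients and optimizer states (ZeRO-2-like);
# use ShardingStrategy.FULL_SHARD for full ZeRO-3 behaviour.
model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
</code></pre>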
852
 
853
+ <p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place as only a subset is needed for the optimization step. Meet ZeRO-2!</p>
854
 
855
  <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
856
 
857
+ <p>Since each replica only needs the gradient shard corresponding to its optimizer state shard, it makes sense to shard the gradients in the same way as the optimizer states. During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! We then only keep the <d-math>\frac{1}{N_d}</d-math> gradient shard we need in memory, thus saving more memory compared to ZeRO-1.</p>
858
 
859
  <aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
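<p>As a small, hypothetical sketch of the aside above: the fp32 accumulation buffer only needs to cover the local <d-math>\frac{1}{N_d}</d-math> gradient shard, for example:</p>
<pre><code class="language-python">
import torch

def accumulate_grad_shard(fp32_grad_shard: torch.Tensor,
                          bf16_grad_shard: torch.Tensor,
                          fp32_param_shard: torch.Tensor,
                          optimizer: torch.optim.Optimizer,
                          is_last_micro_batch: bool) -> None:
    """Accumulate reduce-scattered bf16 gradients into a 1/N_d-sized fp32 buffer."""
    fp32_grad_shard.add_(bf16_grad_shard.float())   # buffer has size Psi / N_d, not Psi
    if is_last_micro_batch:
        # the optimizer only ever sees the local fp32 shard
        # (loss scaling / averaging across micro-batches omitted for brevity)
        fp32_param_shard.grad = fp32_grad_shard
        optimizer.step()
        fp32_grad_shard.zero_()
</code></pre>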
860
 
 
868
 
869
  <aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
870
 
871
+ <p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. Here comes ZeRO-3!</p>
872
 
873
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
874
 
 
895
 
896
  <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after using them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication, and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
897
 
898
+ <p>This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong>prefetching</strong>. With prefetching, we "all-gather" the weights for <em>Layer n+1</em> while we do the forward pass for <em>Layer n</em>, and similarly, we "all-gather" the weights for <em>Layer n-1</em> while doing the backward pass for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).</p>
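<p>Conceptually, prefetching can be written with asynchronous collectives. The sketch below is heavily simplified and hypothetical (real implementations such as FSDP use CUDA streams and module hooks, and here each "layer" is just a matrix multiply), but it shows how the all-gather for the next layer overlaps with the current layer's compute:</p>
<pre><code class="language-python">
import torch
import torch.distributed as dist
import torch.nn.functional as F

def forward_with_prefetch(weight_shards, full_weights, x):
    """ZeRO-3-style forward with prefetching (illustrative only).

    weight_shards[n]: this rank's 1/N_d shard of layer n's flattened weight.
    full_weights[n]: preallocated buffer for the gathered full weight of layer n.
    """
    # gather the first layer's weights before starting
    handle = dist.all_gather_into_tensor(full_weights[0], weight_shards[0], async_op=True)
    last = len(weight_shards) - 1
    for n in range(len(weight_shards)):
        handle.wait()                        # layer n's weights have now arrived
        if n != last:                        # prefetch layer n+1 while computing layer n
            handle = dist.all_gather_into_tensor(full_weights[n + 1], weight_shards[n + 1],
                                                 async_op=True)
        w = full_weights[n].view(x.shape[-1], x.shape[-1])   # flat buffer reshaped to a matrix
        x = F.relu(x @ w)                    # "layer n" forward, overlapping the all-gather
        # in real ZeRO-3 the gathered weights would be freed again here to save memory
    return x
</code></pre>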
899
 
900
+ <p>In terms of memory we can see that our equation has now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can drive memory usage down indefinitely if we can increase the DP degree, at least for the model-related tensors. Notice how it doesn’t help with the intermediate activations; for those we can use activation checkpointing and gradient accumulation as we’ve seen in the previous chapters.</p>
 
 
901
 
902
  <p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can significantly increase the training throughput with DP, simply by scaling training with more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizer states across DP, while incurring a small communication cost.
903
  </strong></p>
904
+ <aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
905
 
906
+ <p>However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from <a target="_self" href="#memory_usage_in_transformers">the activation memory discussion</a> that this part of the memory scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only a short sequence length.</p>
907
 
908
  <iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
909
  <script>
 
915
  </script>
916
  <!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
917
 
918
+ <p>To overcome these issues, it's time to explore a new, orthogonal axis of parallelism: Tensor Parallelism (TP). Unlike ZeRO-3, which relies on heavy parameter communication, TP proposes to shard parameters, gradients, optimizer states AND activations across devices without requiring any communication of model parameters between GPUs.</p>
919
+
920
+ <p>What? How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
921
 
922
  <h2>Tensor Parallelism</h2>
923
 
src/index.html CHANGED