Spaces:
Running
Running
continuing all updates (#40)
Browse files- update (1954885fcad56bfafc60101bec62d11c99235994)
- dist/index.html +26 -24
- src/index.html +26 -24
dist/index.html
CHANGED
@@ -770,7 +770,7 @@
|
|
770 |
|
771 |
<p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
|
772 |
|
773 |
-
<p>While Data Parallelism is
|
774 |
|
775 |
<aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
|
776 |
|
@@ -782,7 +782,7 @@
|
|
782 |
<li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
|
783 |
</ul>
|
784 |
|
785 |
-
<aside>When we say partitioning, it means
|
786 |
|
787 |
<p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
|
788 |
|
@@ -790,13 +790,13 @@
|
|
790 |
|
791 |
<h4>Memory usage revisited</h4>
|
792 |
|
793 |
-
<p>
|
794 |
|
795 |
<ul>
|
796 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
797 |
<li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
798 |
<li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
799 |
-
<li
|
800 |
</ul>
|
801 |
|
802 |
<p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
|
@@ -804,7 +804,7 @@
|
|
804 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
805 |
|
806 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
807 |
-
<p>
|
808 |
|
809 |
|
810 |
<p>Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
@@ -813,25 +813,26 @@
|
|
813 |
|
814 |
<p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
|
815 |
|
816 |
-
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica
|
817 |
|
818 |
-
<p>However
|
819 |
|
820 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step</p>
|
821 |
|
822 |
<ul>
|
823 |
-
<li>Forward pass with
|
824 |
-
<li>Backward pass with
|
825 |
-
<li>Perform an reduce-scatter on the gradients (
|
826 |
-
<li
|
827 |
-
<li>Perform an all-gather
|
828 |
</ul>
|
|
|
829 |
|
830 |
-
<p>
|
831 |
|
832 |
<p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
|
833 |
|
834 |
-
<p>
|
835 |
|
836 |
<p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
|
837 |
|
@@ -845,15 +846,15 @@
|
|
845 |
<div class="note-box">
|
846 |
<p class="note-box-title">📝 Note</p>
|
847 |
<div class="note-box-content">
|
848 |
-
<p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use ZeRO-3/FSDP implementation
|
849 |
</div>
|
850 |
</div>
|
851 |
|
852 |
-
<p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place
|
853 |
|
854 |
<h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
|
855 |
|
856 |
-
<p>
|
857 |
|
858 |
<aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
|
859 |
|
@@ -867,7 +868,7 @@
|
|
867 |
|
868 |
<aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
|
869 |
|
870 |
-
<p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of.
|
871 |
|
872 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
873 |
|
@@ -894,16 +895,15 @@
|
|
894 |
|
895 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
|
896 |
|
897 |
-
<p>
|
898 |
|
899 |
-
<p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in
|
900 |
-
|
901 |
-
<aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
902 |
|
903 |
<p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
|
904 |
</strong></p>
|
|
|
905 |
|
906 |
-
<p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory!
|
907 |
|
908 |
<iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
909 |
<script>
|
@@ -915,7 +915,9 @@
|
|
915 |
</script>
|
916 |
<!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
|
917 |
|
918 |
-
<p>
|
|
|
|
|
919 |
|
920 |
<h2>Tensor Parallelism</h2>
|
921 |
|
|
|
770 |
|
771 |
<p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
|
772 |
|
773 |
+
<p>While Data Parallelism is an efficient way to scale training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
|
774 |
|
775 |
<aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
|
776 |
|
|
|
782 |
<li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
|
783 |
</ul>
|
784 |
|
785 |
+
<aside>When we say partitioning, it means along the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition along other axes.</aside>
|
786 |
|
787 |
<p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
|
788 |
|
|
|
790 |
|
791 |
<h4>Memory usage revisited</h4>
|
792 |
|
793 |
+
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
|
794 |
|
795 |
<ul>
|
796 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
797 |
<li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
798 |
<li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
799 |
+
<li>Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted if we want to accumulate grads in fp32)</li>
|
800 |
</ul>
|
801 |
|
802 |
<p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
|
|
|
804 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
805 |
|
806 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
807 |
+
<p>Here <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
|
808 |
|
809 |
|
810 |
<p>Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
|
|
813 |
|
814 |
<p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
|
815 |
|
816 |
+
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica distributed on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated.</p>
|
817 |
|
818 |
+
<p>However during the forward pass, each replica@ need all the parameters, we thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
|
819 |
|
820 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step</p>
|
821 |
|
822 |
<ul>
|
823 |
+
<li>Forward pass with the same, full set of bf16 parameters on each replica, but different microbatches across replicas</li>
|
824 |
+
<li>Backward pass with the same, full set of gradients on each replica, but different microbatches across replicas</li>
|
825 |
+
<li>Perform an reduce-scatter on the gradients (we'll explain the reduce-scatter primitive in the graph below)</li>
|
826 |
+
<li>Each replica perform an optimizer step on its local optimizer steps (only <d-math>\frac{1}{N_d}</d-math> optimizer states) to get updated <d-math>\frac{1}{N_d}</d-math> fp32 parameters which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of bf16 parameters.</li>
|
827 |
+
<li>Perform an all-gather among the bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
828 |
</ul>
|
829 |
+
<aside>Note: reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em></aside>
|
830 |
|
831 |
+
<p>You may be wondering what is this "reduce-scatter" operation and how this all look so lets try to make this more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
|
832 |
|
833 |
<p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
|
834 |
|
835 |
+
<p>In terms of practical communications, compared to vanilla DP, Zero-1 change our "all-reduce" gradient communication to a "reduce-scatter" operation and adds an all-gather operation over all parameters after the optimizer step. Here is how it looks:</p>
|
836 |
|
837 |
<p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
|
838 |
|
|
|
846 |
<div class="note-box">
|
847 |
<p class="note-box-title">📝 Note</p>
|
848 |
<div class="note-box-content">
|
849 |
+
<p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use PyTorch native ZeRO-3/FSDP implementation and set the FSDPUnit to be the entire model, more details about this later.</p>
|
850 |
</div>
|
851 |
</div>
|
852 |
|
853 |
+
<p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place as only a subset is needed for the optimization step. Meet ZeRO-2!</p>
|
854 |
|
855 |
<h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
|
856 |
|
857 |
+
<p>Since we only need, on each replica, to have the gradient shard corresponding to the optimizer state shard, it makes sense to shard gradient as well similarly to the optimizer states. During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! Where we only spread the <d-math>\frac{1}{N_d}</d-math> gradients needed in memory, thus saving more memory compared to ZeRO-1.</p>
|
858 |
|
859 |
<aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
|
860 |
|
|
|
868 |
|
869 |
<aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
|
870 |
|
871 |
+
<p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. Here comes ZeRO-3!</p>
|
872 |
|
873 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
874 |
|
|
|
895 |
|
896 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
|
897 |
|
898 |
+
<p>This may sounds like a lot of communication overhead but it's actually pretty fine as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong>prefetching</strong>. With prefetching, we will "all-gather" weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, we will "all-gather" weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
|
899 |
|
900 |
+
<p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in the previous chapters.</p>
|
|
|
|
|
901 |
|
902 |
<p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
|
903 |
</strong></p>
|
904 |
+
<aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
905 |
|
906 |
+
<p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! We recall from <a target="_self" href="#memory_usage_in_transformers">the activation memory discussion</a> that this part of the memory scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only with a short sequence length. </p>
|
907 |
|
908 |
<iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
909 |
<script>
|
|
|
915 |
</script>
|
916 |
<!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
|
917 |
|
918 |
+
<p>To overcome this issues, it's time to explore a new, orthogonal axis of parallelism - Tensor Parallelism (TP). Unlike ZeRO3 which relies on heavy parameter communication, TP proposes to shard parameters, gradients, optimizer states AND activations across devices without requiring any communication of model parameters between GPUs.</p>
|
919 |
+
|
920 |
+
<p>What? How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
|
921 |
|
922 |
<h2>Tensor Parallelism</h2>
|
923 |
|
src/index.html
CHANGED
@@ -770,7 +770,7 @@
|
|
770 |
|
771 |
<p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
|
772 |
|
773 |
-
<p>While Data Parallelism is
|
774 |
|
775 |
<aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
|
776 |
|
@@ -782,7 +782,7 @@
|
|
782 |
<li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
|
783 |
</ul>
|
784 |
|
785 |
-
<aside>When we say partitioning, it means
|
786 |
|
787 |
<p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
|
788 |
|
@@ -790,13 +790,13 @@
|
|
790 |
|
791 |
<h4>Memory usage revisited</h4>
|
792 |
|
793 |
-
<p>
|
794 |
|
795 |
<ul>
|
796 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
797 |
<li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
798 |
<li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
799 |
-
<li
|
800 |
</ul>
|
801 |
|
802 |
<p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
|
@@ -804,7 +804,7 @@
|
|
804 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
805 |
|
806 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
807 |
-
<p>
|
808 |
|
809 |
|
810 |
<p>Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
@@ -813,25 +813,26 @@
|
|
813 |
|
814 |
<p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
|
815 |
|
816 |
-
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica
|
817 |
|
818 |
-
<p>However
|
819 |
|
820 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step</p>
|
821 |
|
822 |
<ul>
|
823 |
-
<li>Forward pass with
|
824 |
-
<li>Backward pass with
|
825 |
-
<li>Perform an reduce-scatter on the gradients (
|
826 |
-
<li
|
827 |
-
<li>Perform an all-gather
|
828 |
</ul>
|
|
|
829 |
|
830 |
-
<p>
|
831 |
|
832 |
<p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
|
833 |
|
834 |
-
<p>
|
835 |
|
836 |
<p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
|
837 |
|
@@ -845,15 +846,15 @@
|
|
845 |
<div class="note-box">
|
846 |
<p class="note-box-title">📝 Note</p>
|
847 |
<div class="note-box-content">
|
848 |
-
<p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use ZeRO-3/FSDP implementation
|
849 |
</div>
|
850 |
</div>
|
851 |
|
852 |
-
<p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place
|
853 |
|
854 |
<h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
|
855 |
|
856 |
-
<p>
|
857 |
|
858 |
<aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
|
859 |
|
@@ -867,7 +868,7 @@
|
|
867 |
|
868 |
<aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
|
869 |
|
870 |
-
<p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of.
|
871 |
|
872 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
873 |
|
@@ -894,16 +895,15 @@
|
|
894 |
|
895 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
|
896 |
|
897 |
-
<p>
|
898 |
|
899 |
-
<p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in
|
900 |
-
|
901 |
-
<aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
902 |
|
903 |
<p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
|
904 |
</strong></p>
|
|
|
905 |
|
906 |
-
<p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory!
|
907 |
|
908 |
<iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
909 |
<script>
|
@@ -915,7 +915,9 @@
|
|
915 |
</script>
|
916 |
<!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
|
917 |
|
918 |
-
<p>
|
|
|
|
|
919 |
|
920 |
<h2>Tensor Parallelism</h2>
|
921 |
|
|
|
770 |
|
771 |
<p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
|
772 |
|
773 |
+
<p>While Data Parallelism is an efficient way to scale training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
|
774 |
|
775 |
<aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
|
776 |
|
|
|
782 |
<li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
|
783 |
</ul>
|
784 |
|
785 |
+
<aside>When we say partitioning, it means along the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition along other axes.</aside>
|
786 |
|
787 |
<p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
|
788 |
|
|
|
790 |
|
791 |
<h4>Memory usage revisited</h4>
|
792 |
|
793 |
+
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
|
794 |
|
795 |
<ul>
|
796 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
797 |
<li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
798 |
<li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
799 |
+
<li>Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted if we want to accumulate grads in fp32)</li>
|
800 |
</ul>
|
801 |
|
802 |
<p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
|
|
|
804 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
805 |
|
806 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
807 |
+
<p>Here <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
|
808 |
|
809 |
|
810 |
<p>Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
|
|
813 |
|
814 |
<p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
|
815 |
|
816 |
+
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica distributed on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated.</p>
|
817 |
|
818 |
+
<p>However during the forward pass, each replica@ need all the parameters, we thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
|
819 |
|
820 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step</p>
|
821 |
|
822 |
<ul>
|
823 |
+
<li>Forward pass with the same, full set of bf16 parameters on each replica, but different microbatches across replicas</li>
|
824 |
+
<li>Backward pass with the same, full set of gradients on each replica, but different microbatches across replicas</li>
|
825 |
+
<li>Perform an reduce-scatter on the gradients (we'll explain the reduce-scatter primitive in the graph below)</li>
|
826 |
+
<li>Each replica perform an optimizer step on its local optimizer steps (only <d-math>\frac{1}{N_d}</d-math> optimizer states) to get updated <d-math>\frac{1}{N_d}</d-math> fp32 parameters which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of bf16 parameters.</li>
|
827 |
+
<li>Perform an all-gather among the bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
828 |
</ul>
|
829 |
+
<aside>Note: reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em></aside>
|
830 |
|
831 |
+
<p>You may be wondering what is this "reduce-scatter" operation and how this all look so lets try to make this more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
|
832 |
|
833 |
<p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
|
834 |
|
835 |
+
<p>In terms of practical communications, compared to vanilla DP, Zero-1 change our "all-reduce" gradient communication to a "reduce-scatter" operation and adds an all-gather operation over all parameters after the optimizer step. Here is how it looks:</p>
|
836 |
|
837 |
<p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
|
838 |
|
|
|
846 |
<div class="note-box">
|
847 |
<p class="note-box-title">📝 Note</p>
|
848 |
<div class="note-box-content">
|
849 |
+
<p>Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use PyTorch native ZeRO-3/FSDP implementation and set the FSDPUnit to be the entire model, more details about this later.</p>
|
850 |
</div>
|
851 |
</div>
|
852 |
|
853 |
+
<p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place as only a subset is needed for the optimization step. Meet ZeRO-2!</p>
|
854 |
|
855 |
<h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
|
856 |
|
857 |
+
<p>Since we only need, on each replica, to have the gradient shard corresponding to the optimizer state shard, it makes sense to shard gradient as well similarly to the optimizer states. During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! Where we only spread the <d-math>\frac{1}{N_d}</d-math> gradients needed in memory, thus saving more memory compared to ZeRO-1.</p>
|
858 |
|
859 |
<aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
|
860 |
|
|
|
868 |
|
869 |
<aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
|
870 |
|
871 |
+
<p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. Here comes ZeRO-3!</p>
|
872 |
|
873 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
874 |
|
|
|
895 |
|
896 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also <d-math>\Psi</d-math> in communication and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
|
897 |
|
898 |
+
<p>This may sounds like a lot of communication overhead but it's actually pretty fine as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong>prefetching</strong>. With prefetching, we will "all-gather" weights for *Layer n+1* while we do the current forward for <em>Layer n</em> in the forward, and similarly, we will "all-gather" weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)</p>
|
899 |
|
900 |
+
<p>In terms of memory we can see that our equation now reached it’s final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math> which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in the previous chapters.</p>
|
|
|
|
|
901 |
|
902 |
<p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
|
903 |
</strong></p>
|
904 |
+
<aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
905 |
|
906 |
+
<p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! We recall from <a target="_self" href="#memory_usage_in_transformers">the activation memory discussion</a> that this part of the memory scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only with a short sequence length. </p>
|
907 |
|
908 |
<iframe class="l-body-outset" id="plotFrame6" src="assets/data/benchmarks/zero3_memoryusage.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
909 |
<script>
|
|
|
915 |
</script>
|
916 |
<!-- <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p> -->
|
917 |
|
918 |
+
<p>To overcome this issues, it's time to explore a new, orthogonal axis of parallelism - Tensor Parallelism (TP). Unlike ZeRO3 which relies on heavy parameter communication, TP proposes to shard parameters, gradients, optimizer states AND activations across devices without requiring any communication of model parameters between GPUs.</p>
|
919 |
+
|
920 |
+
<p>What? How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
|
921 |
|
922 |
<h2>Tensor Parallelism</h2>
|
923 |
|