PyTorch, Gradient Accumulation, and the dreaded drop in speed
Introduction
Recently I was helping someone at work debug some distributed code, looking for ways to speed it up. Immediately I noticed something odd: gradient accumulation.
That in and of itself is not odd. But when it comes to distributed compute with PyTorch, if you are not careful you can see immense slowdowns in your code.
What follows is an exploratory analysis I performed using Hugging Face Accelerate, PyTorch Distributed, and three machines to test what the correct setup for gradient accumulation on multiple GPUs looks like, and by how much it matters.
Setup
First let’s discuss setup.
For these experiments, I used the following:
- Python: 3.9.13
- PyTorch: v1.12.1+cu113
- Accelerate: v0.16.0
- Transformers: v4.26.1
- Compute:
  - Two single-GPU T4 nodes from GCP that can communicate with each other
  - One node with two T4 GPUs
- Script-specific parameters:
  - Batch size per GPU: 16
  - Gradient accumulation steps: 4
  - Total observed batch size (16 × 2 × 4): 128
  - Mixed precision: fp16
- Scripts: available here
Why is gradient accumulation special?
Let’s talk about why gradient accumulation is different on multiple GPUs. On a single GPU everything happens on that device: you can accumulate, compute, and update the gradients all exceedingly quickly. However, when multiple GPUs get involved (whether across a network or inside a single machine), each time the backward pass is performed all of the GPUs communicate with each other. The gradients are averaged across the model replicas on each GPU, and every replica’s weights are updated with this averaged result so they all stay in sync.
As you can imagine, every time you need all of your GPUs to communicate there is a time cost, even if they are in the same machine!
This cost can be deadly to your training runs, because it can lead to as much as a 2x slowdown!
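To make that cost concrete, here is roughly what DDP is doing for you on every backward pass, written as a minimal sketch directly against torch.distributed (the real implementation buckets gradients and overlaps communication with computation, but the amount of data exchanged is the same):

```python
import torch.distributed as dist

def average_gradients(model):
    # After backward(), each process holds its own local gradients.
    # DDP conceptually all-reduces them: every GPU ends up with the sum,
    # which we divide by the world size so all replicas apply the same update.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```

Every round of this (or its equivalent inside DDP) is a full burst of GPU-to-GPU communication, which is exactly what we would rather not pay on every single micro-batch.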
So, what’s the cure?
In PyTorch distributed training, the model is wrapped in a `DistributedDataParallel` class. This module stores the model, understands how to update and process the weight changes, and communicates between all of the GPUs you are utilizing during training to do so. This update, as mentioned earlier, happens on `backward()`, but begins on the forward pass.
As a result, the `DistributedDataParallel` class provides a context manager called `no_sync()`. Essentially this tells PyTorch: while this block of code is running, do not synchronize the gradients with the other GPUs.
To make this work, this wrapper needs to be around both the forward and backward pass, such that:
```python
net = MyModel()
net = DistributedDataParallel(net, ...)

with net.no_sync():
    pred = net(input)
    loss = loss_func(pred)
    loss.backward()
```
To synchronize again, simply remove the `no_sync` wrapper for a batch and the processes will synchronize once more on that batch’s backward pass.
Translated, this is what gradient accumulation looks like when done properly in native PyTorch:
```python
for step, (x, y) in enumerate(dataloader):
    if step % gradient_accumulation_steps != 0:
        with model.no_sync():
            outputs = model(x)
            loss = loss_func(outputs, y)
            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
    else:
        outputs = model(x)
        loss = loss_func(outputs, y)
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
But just how important is this? Can I just wrap `.backward()` in `no_sync`?
I ran a few experiments to figure exactly that out.
The Experiments
Each experiment ran through 29 total batches, using `bert-base-cased` as the model and the `mrpc` dataset. Each attempt was then run 5 times and the average was taken.
I’ll highlight each individual result below, as well as their code changes.
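For the timings below, each configuration was measured as seconds per batch, roughly along these lines; this is a sketch of the measurement rather than the exact harness from the scripts, and `train_step` is a hypothetical helper standing in for one full training step:

```python
import time
import torch

def average_seconds_per_batch(train_dataloader, train_step):
    # train_step is a hypothetical helper that runs the forward, backward,
    # and (when applicable) optimizer logic for a single batch.
    times = []
    for batch in train_dataloader:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure prior GPU work has finished
        start = time.perf_counter()
        train_step(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)
```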
The Baseline
The baseline consists of nothing special. It calls `.backward()` at every step, and once we have finished accumulating, the optimizer and scheduler are stepped and the gradients are zeroed.
```python
for step, (x, y) in enumerate(train_dataloader):
    outputs = model(x)
    loss = loss_func(outputs, y)
    loss = loss / gradient_accumulation_steps
    accelerator.backward(loss)
    if step % gradient_accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
The `Accelerator` here is simply used to handle the standard DDP processes, and nothing more.
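Concretely, “handling the standard DDP processes” means a setup along these lines before the training loop; this is a sketch with a stand-in model rather than the exact script (which uses bert-base-cased, the mrpc dataset, and fp16 configured through Accelerate):

```python
import torch
from accelerate import Accelerator

# A tiny stand-in model/optimizer/dataloader just to show the wiring; the real
# scripts use bert-base-cased and a dataloader over the mrpc dataset.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

# Accelerate wraps the model in DistributedDataParallel when launched on
# multiple GPUs, moves everything to the right device, and makes
# accelerator.backward(loss) stand in for loss.backward().
accelerator = Accelerator()
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
```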
This baseline finished at:
Note: Times are in Seconds per Batch
| | Multi Node | Single Node |
|---|---|---|
| Run 1 | 1.95 | 0.52 |
| Run 2 | 2.11 | 0.5 |
| Run 3 | 1.94 | 0.5 |
| Average | 2±0.01s | 0.50±0.01s |
Overall, that’s 2 seconds per batch on multi-node and 0.5 seconds per batch on a single node: a 4x slowdown when going from a single node to multiple nodes. That is not efficient at all!
So, let’s try using this fancy `no_sync` thing.
Using `no_sync`, improperly
For `no_sync` to work correctly, it needs to be wrapped around both the forward pass and the backward pass. Otherwise, the processes will still synchronize during `.backward()`.
Here is an example of what not to do, and its results:
```python
for step, batch in enumerate(train_dataloader):
    batch.to(accelerator.device)
    outputs = model(**batch)
    loss = outputs.loss
    loss = loss / gradient_accumulation_steps
    if step % gradient_accumulation_steps != 0:
        with model.no_sync():
            accelerator.backward(loss)
    else:
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
Note: Times are in Seconds per Batch
| | Multi Node | Single Node |
|---|---|---|
| Run 1 | 2.08 | 0.52 |
| Run 2 | 2.09 | 0.5 |
| Run 3 | 2.23 | 0.5 |
| Average | 2.13±0.08s | 0.50±0.01s |
As you can see, there is a negligible difference, because no synchronization is actually being skipped! Everything is still being synced at the same points, and there is potentially even some extra communication added on top, considering it was 0.13s per batch slower on average.
What is the right way then?
The correct way to use `no_sync`, as mentioned earlier, is to wrap it around both the forward and the backward pass. This ensures that the gradients are only fully synchronized once we break out of the `no_sync` block.
The snippet and results are below:
```python
for step, (x, y) in enumerate(train_dataloader):
    if step % gradient_accumulation_steps != 0:
        with model.no_sync():
            outputs = model(x)
            loss = loss_function(outputs, y)
            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
    else:
        outputs = model(x)
        loss = loss_function(outputs, y)
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
Note: Times are in Seconds per Batch
| | Multi Node | Single Node |
|---|---|---|
| Run 1 | 0.84 | 0.4 |
| Run 2 | 1.04 | 0.43 |
| Run 3 | 0.86 | 0.41 |
| Average | 0.91±0.11s | 0.41±0.015s |
You can see that not only did we get a 2x speedup on the multi-node setup, but there was also a 25% speedup on the single node!
Reducing the amount of communication between all of your GPUs when training in a distributed process is paramount to training fast and efficiently.
The last script I will show is how Hugging Face Accelerate can do this for you automatically, using the `accumulate` context manager:
Using Accelerate!
Snippet:
```python
for step, (x, y) in enumerate(train_dataloader):
    with accelerator.accumulate(model):
        outputs = model(x)
        loss = loss_function(outputs, y)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
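The one piece not shown in the snippet is where the number of accumulation steps comes from; with this approach it lives on the `Accelerator` itself. A sketch, assuming the same values as the other experiments:

```python
from accelerate import Accelerator

# gradient_accumulation_steps tells accelerator.accumulate(model) when to skip
# gradient synchronization and when an optimizer.step() should actually run.
# The rest of the setup (accelerator.prepare(...)) matches the earlier sketch.
accelerator = Accelerator(gradient_accumulation_steps=4)
```

With this set, `accelerator.backward(loss)` also takes care of scaling the loss by the number of accumulation steps, which is why the snippet above no longer divides the loss manually.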
Timings:
Note: Times are in Seconds per Batch
| | Multi Node | Single Node |
|---|---|---|
| Run 1 | 0.84 | 0.4 |
| Run 2 | 1.04 | 0.43 |
| Run 3 | 0.86 | 0.41 |
| Average | 0.91±0.11s | 0.41±0.015s |
You can see that we get roughly the same times as the `no_sync` example shown earlier, but Accelerate lets us remove all of the if/else logic entirely!
This helpful bit of magic not only reduces lines of code, it also ensures you do not run into the slowdowns presented here.
Article Takeaways
What I would like for you to take away from this brief discussion is:
- First, you should be very careful when writing distributed code, and try to minimize the number of times all of your processes need to synchronize. This is one of the largest sources of slowdown, and it is not limited to multi-node setups over a network: it happens on a single machine too!
- Understand that even if code behaves the same on a single GPU, it may need behavioral changes and tweaks to run efficiently on a distributed system. Accelerate helps with this by ensuring the same code can be used across any distributed platform with minimal overhead for the user; however, it is also a good idea to be familiar with just what needs to be changed and why.
If you liked this article, please be sure to check out my Twitter, and if you are interested, check out the library I work on: Accelerate.