PyTorch, Gradient Accumulation, and the dreaded lack of reproducibility

pytorch
Published: October 29, 2024

Introduction

A few weeks ago the Unsloth team put out a pretty damning report showing that most training frameworks have critical issues when applying gradient accumulation while training language models (specifically in the use case of generation).

When performing gradient accumulation, the underlying assumption is that training with a batch size of 8 and 4 gradient accumulation steps should be exactly equivalent to training with a batch size of 32 and no accumulation. However, what was discovered is that when training language models for generation, the sequences in a batch are not all uniform (the same length), and this makes a drastic difference when calculating the loss.
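To make that assumption concrete, here is a minimal sketch (with made-up per-token losses, purely illustrative) showing that averaging micro-batch means only matches the full-batch mean when every micro-batch contributes the same number of tokens:

import torch

# Hypothetical per-token losses for a "full batch" of 32 tokens
losses = torch.rand(32)

# Loss over the full batch, no accumulation
full_batch_loss = losses.mean()

# Naive accumulation: average the means of 4 equally sized micro-batches of 8
micro_batch_means = torch.stack([chunk.mean() for chunk in losses.chunk(4)])
accumulated_loss = micro_batch_means.mean()

# Identical, because every micro-batch contributes exactly 8 tokens
print(torch.allclose(full_batch_loss, accumulated_loss))  # True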

In this blog, I’ll walk you through what the rest of the transformers team (Marc Sun, Yoach Lacombe, and many others) and I worked through to investigate this issue and break it down to its core parts in a reproducible case. I’ll also discuss how this fix is needed for distributed training as well, something the Unsloth team didn’t cover in their report.

Required Reading

Before reading this article, I recommend reading my prior article on gradient accumulation relative to multi-GPU training, as it will come into play later.

Setup

First let’s discuss setup.

For these experiments, I used the following:

  • Python: 3.10.13
  • PyTorch: v2.5.0
  • Accelerate: v1.0.1
  • Transformers: v4.46.0
  • Compute:
    • Single RTX 4090
    • 8x H100s for the DDP tests

Creating a baseline

Like all good experiments, we need a baseline. A benchmark.

For this experiment, we’ll use the following setup:

  • Dataset: A small chunk of wikitext-2-v1 from the Salesforce/wikitext repo hosted on the Hub.
  • Model: SmolLM-135M
  • Optimizer: AdamW
  • Scheduler: Constant LR
  • Actual batch size: 8 (it’s what could fit in 24GB of VRAM)

Core Code

Below is the basic code for setting up:

  • Reproducibility
  • The dataset
  • The model
  • The torch DataLoaders
  • Base training

Let’s start with some code that sets up our dataset/dataloaders, model, optimizer, and scheduler:

Show the code
import random
import torch
import numpy as np

from datasets import load_dataset
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM, get_constant_schedule

def set_seed():
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

set_seed()

model_name = "HuggingFaceTB/SmolLM-135M"
dataset_name = "Salesforce/wikitext"
batch_size = 8

datasets = load_dataset(
    dataset_name, "wikitext-2-v1", split={"train":"train[:800]", "validation":"validation[:80]"}
)
datasets = datasets.filter(lambda x: len(x)>0, input_columns="text")
assert len(datasets["train"]) >= 500 and len(datasets["train"]) < 600
assert len(datasets["validation"]) >= 50 and len(datasets["validation"]) < 60

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

def get_items(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_constant_schedule(optimizer=optimizer)
    return model, optimizer, scheduler

model, optimizer, scheduler = get_items(model_name)

def tokenize_func(data):
    return tokenizer(data["text"], max_length=None, return_attention_mask=False)
tokenized_datasets = datasets.map(tokenize_func, batched=True, remove_columns=["text"])

def collate_fn(examples):
    # Pad each batch to the length of its longest sequence (+1 so we can shift below)
    max_length = max([len(example["input_ids"]) for example in examples])
    batch = tokenizer.pad(
        examples, 
        padding="max_length", 
        max_length=max_length+1, 
        pad_to_multiple_of = None,
        return_tensors="pt",
    )
    # Shift so that each position is trained to predict the next token
    batch["labels"] = batch["input_ids"][:, 1:]
    batch["input_ids"] = batch["input_ids"][:, :-1]

    # Mask out padding in the labels so it doesn't contribute to the loss
    batch["labels"] = torch.where(batch["labels"] == tokenizer.pad_token_id, -100, batch["labels"])
    return batch

def get_dataloaders(train_batch_size:int=8):
    train_dl = DataLoader(
        tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=train_batch_size,
    )
    eval_dl = DataLoader(
        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=4
    )
    return train_dl, eval_dl

train_dl, eval_dl = get_dataloaders(train_batch_size=batch_size)

And finally write a training loop (it relies on a loss_fn helper, which we’ll define and examine closely in a moment):

model.to("cuda")
losses_baseline = []
total_batched_samples = 0
for epoch in range(3):
    model.train()
    for batch in train_dl:
        batch = {k:v.to("cuda") for k,v in batch.items()}
        out = model(**batch)
        loss = loss_fn(
            out["logits"], batch["labels"], model.vocab_size
        )
        loss.backward()
        losses_baseline.append(loss.cpu().detach().item())
            
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

We can then graph our curve and see it’s fairly smooth:
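The figure isn’t reproduced here, but a minimal plotting sketch (assuming matplotlib is installed) looks like:

import matplotlib.pyplot as plt

plt.plot(losses_baseline, label="baseline (bs=8, no accumulation)")
plt.xlabel("optimizer step")
plt.ylabel("training loss")
plt.legend()
plt.show()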

Gradient Accumulation, the naive way

Now let’s modify our training loop to perform basic gradient accumulation, and go again.

(For this, the number of accumulation steps is 2)

I’ve highlighted the core change to the code below:

gradient_accumulation_steps = 2
batch_size = 4
...
losses_grad_accum = []
total_batched_samples = 0
grad_accum_loss = 0
for epoch in range(3):
    ...
    for i,batch in enumerate(train_dl):
        ...
        loss = ...
        loss = loss / gradient_accumulation_steps
        loss.backward()
        grad_accum_loss += loss.cpu().detach().item()

        # Step once we've accumulated a full set of micro-batches (or hit the end of the epoch)
        if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_dl)-1):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            losses_grad_accum.append(grad_accum_loss)
            grad_accum_loss = 0
Show the full code
set_seed()
gradient_accumulation_steps = 2
batch_size = 4
model, optimizer, scheduler = get_items(model_name)
train_dl, eval_dl = get_dataloaders(train_batch_size=batch_size)

model.to("cuda")
losses_grad_accum = []
total_batched_samples = 0
grad_accum_loss = 0
for epoch in range(3):
    model.train()
    for i,batch in enumerate(train_dl):
        batch = {k:v.to("cuda") for k,v in batch.items()}
        out = model(**batch)
        loss = loss_fn(
            out["logits"], batch["labels"], model.vocab_size
        )
        loss = loss / gradient_accumulation_steps
        loss.backward()
        grad_accum_loss += loss.cpu().detach().item()

        # Step once we've accumulated a full set of micro-batches (or hit the end of the epoch)
        if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_dl)-1):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            losses_grad_accum.append(grad_accum_loss)
            grad_accum_loss = 0

And plot the results:
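Again, the figure isn’t reproduced here; reusing the plotting sketch from above, one way to overlay the two curves is:

plt.plot(losses_baseline, label="baseline (bs=8)")
plt.plot(losses_grad_accum, label="naive accumulation (bs=4, ga=2)")
plt.xlabel("optimizer step")
plt.ylabel("training loss")
plt.legend()
plt.show()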

As you can see, they’re close but… not exact.

What’s going on?

The Problem: Loss

Let’s go back to how we defined our loss function:

def loss_fn(logits, labels, vocab_size):
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    shift_labels = shift_labels.to(shift_logits.device)
    return cross_entropy(
        shift_logits, shift_labels, ignore_index=-100, reduction="mean"
    )

If you notice, we explicitly set the reduction to "mean" (which is also the default).

What this means is that we are assuming that across all steps of gradient accumulation, the total number of label tokens seen is exactly the same. In a generation problem though, this is not the case once we start messing with the batch sizes. For a quick dumb TL;DR:

Say the batch is:

[[0],[0,1],[0,1,2], [0,1,2,3]]

The average length of the first two items is 1.5, while the average of the last two is 3.5.

This tiny numerical difference means the world when it comes to calculating our loss here, as that "mean" isn’t taking into account the rest of the items our gradient accumulation step is seeing!
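To see how that throws off the loss, here’s a quick numeric sketch with made-up per-token losses for those four sequences, split into two micro-batches of two sequences each:

import torch

# Made-up per-token losses for sequences of length 1, 2, 3, and 4
per_token_losses = [
    torch.tensor([1.0]),
    torch.tensor([1.0, 1.0]),
    torch.tensor([2.0, 2.0, 2.0]),
    torch.tensor([2.0, 2.0, 2.0, 2.0]),
]

# True mean over the full batch of 10 tokens
all_tokens = torch.cat(per_token_losses)
print(all_tokens.mean())  # tensor(1.7000)

# Naive accumulation: mean per micro-batch, then average the two means
first_mean = torch.cat(per_token_losses[:2]).mean()   # 3 tokens
second_mean = torch.cat(per_token_losses[2:]).mean()  # 7 tokens
print((first_mean + second_mean) / 2)  # tensor(1.5000), not the same!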

So what’s the fix?

The Fix: Loss

The first fix is to rewrite our loss function to take into account the total number of items seen across all gradient accumulation steps. The Unsloth crew go into more detail on why that matters; below I’ve defined a new loss function which reflects this:

def loss_fn(logits, labels, vocab_size, num_items_in_batch=None):
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    shift_labels = shift_labels.to(shift_logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = cross_entropy(
        shift_logits, shift_labels, ignore_index=-100, reduction=reduction
    )
    if reduction == "sum":
        return loss / num_items_in_batch
    return loss

Essentially, if we pass in num_items_in_batch, we take the "sum" of everything and then divide by the total ourselves, rather than letting PyTorch average each micro-batch on its own.
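Continuing the numeric sketch from above, summing each micro-batch and dividing by the global token count recovers the true full-batch mean:

# With the fix: sum each micro-batch, divide by the total token count across both
num_items_in_batch = all_tokens.numel()  # 10

first_fixed = torch.cat(per_token_losses[:2]).sum() / num_items_in_batch
second_fixed = torch.cat(per_token_losses[2:]).sum() / num_items_in_batch

# Adding the two contributions (which is what accumulating gradients effectively does)
# gives back the full-batch mean
print(first_fixed + second_fixed)  # tensor(1.7000)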

But that’s not the only fix we need. How do we get num_items_in_batch?

The Fix: Prefetching

The second fix is figuring out num_items_in_batch. We need to be careful about:

  1. Making sure we prefetch gradient_accumulation_steps batches of data at a time
  2. Calculating the total number of non-pad tokens across all labels.

Let’s rewrite our training loop to do just that:

num_update_steps_per_epoch = math.ceil(len(train_dl) / gradient_accumulation_steps)
remainder = len(train_dl) % gradient_accumulation_steps
if remainder == 0:
    remainder = gradient_accumulation_steps

losses_fixed_ga = []
actual_loss = 0
for epoch in range(3):
    ...
    iterator = iter(train_dl)
    for update_step in range(num_update_steps_per_epoch):
        batch_samples = []
        num_batches = gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder
        # Prefetch and calculate the number of non-padded items seen across one gradient accumulation "step"
        for _ in range(num_batches):
            batch_samples += [next(iterator)]
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])

        for batch in batch_samples:
            ...
            loss = loss_fn(
                out["logits"], batch["labels"], 
                vocab_size=model.vocab_size, num_items_in_batch=num_items_in_batch
            )
            loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
Show the full code
import math

set_seed()
gradient_accumulation_steps = 2
batch_size = 4
model, optimizer, scheduler = get_items(model_name)
train_dl, eval_dl = get_dataloaders(train_batch_size=batch_size)

model.to("cuda")

num_update_steps_per_epoch = math.ceil(len(train_dl) / gradient_accumulation_steps)
remainder = len(train_dl) % gradient_accumulation_steps
if remainder == 0:
    remainder = gradient_accumulation_steps

losses_fixed_ga = []
actual_loss = 0
total_batched_samples = 0
for epoch in range(3):
    model.train()
    iterator = iter(train_dl)
    for update_step in range(num_update_steps_per_epoch):
        batch_samples = []
        num_batches = gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder
        # Prefetch and calculate the number of non-padded items seen across one gradient accumulation "step"
        for _ in range(num_batches):
            batch_samples += [next(iterator)]
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])

        for batch in batch_samples:
            total_batched_samples += 1
            batch = {k:v.to("cuda") for k,v in batch.items()}
            out = model(**batch)
            loss = loss_fn(
                out["logits"], batch["labels"], 
                vocab_size=model.vocab_size, num_items_in_batch=num_items_in_batch
            )
            loss.backward()
            actual_loss += loss.detach().cpu().item()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        losses_fixed_ga.append(actual_loss)
        actual_loss = 0

And also rerun our baseline:

Show the code
set_seed()
gradient_accumulation_steps = 1
batch_size = 8
model, optimizer, scheduler = get_items(model_name)
train_dl, eval_dl = get_dataloaders(train_batch_size=batch_size)

model.to("cuda")

num_update_steps_per_epoch = math.ceil(len(train_dl) / gradient_accumulation_steps)
remainder = len(train_dl) % gradient_accumulation_steps
if remainder == 0:
    remainder = gradient_accumulation_steps

losses_baseline = []
actual_loss = 0
total_batched_samples = 0
for epoch in range(3):
    model.train()
    iterator = iter(train_dl)
    for update_step in range(num_update_steps_per_epoch):
        batch_samples = []
        num_batches = gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder
        # Prefetch and calculate the number of non-padded items seen across one gradient accumulation "step"
        for _ in range(num_batches):
            batch_samples += [next(iterator)]
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])

        for batch in batch_samples:
            total_batched_samples += 1
            batch = {k:v.to("cuda") for k,v in batch.items()}
            out = model(**batch)
            loss = loss_fn(
                out["logits"], batch["labels"], 
                vocab_size=model.vocab_size, num_items_in_batch=num_items_in_batch
            )
            loss.backward()
            actual_loss += loss.detach().cpu().item()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        losses_baseline.append(actual_loss)
        actual_loss = 0

And now we find they are nearly identical! (I found they agreed to within ~5 decimal places, not bad at all.)
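If you want to check this yourself, a quick comparison of the two loss lists (assuming both runs recorded the same number of optimizer steps) looks something like:

import numpy as np

baseline = np.array(losses_baseline)
fixed = np.array(losses_fixed_ga)

# Should be tiny: in my runs the curves agree to roughly 5 decimal places
print(np.abs(baseline - fixed).max())
print(np.allclose(baseline, fixed, atol=1e-4))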

That’s it, we’re done right?

Wrong

Problem: Distributed Training

That’s great, but what about during distributed training?

Since the data is split across n GPUs, no single GPU knows how many total items are seen across a step, leading to the same issue.

The solution is to call a gather() across the inputs and use the result to help calculate the loss. The problem here is that this involves a communication step between all of the GPUs, which can get costly if we do it at every gradient accumulation step (rather than a single communication when we call backward(), we’re now doubling it to two).
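In practice the extra communication is just an all-gather of a single token count per process. A minimal sketch with accelerate (the full script below does exactly this), assuming batch_samples holds the prefetched micro-batches for the current accumulation window:

# Per-process count of non-padded label tokens in this accumulation window
local_num_items = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])

# Gather and sum across processes so every GPU normalizes by the same global count
num_items_in_batch = accelerator.gather(local_num_items).sum().item()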

Below is an experiment I ran across 8 GPUs (with a much larger batch size) showcasing how the results change based on whether we call gather() or not.

The full solution is below, utilizing accelerate solely to handle DDP and splitting the data between each GPU; just make sure to run this via torchrun or accelerate launch.

If you want to be 100% exact, I recommend you do this. However, without it we’re extremely close (much closer than before), so it’s up to you, your compute budget, and whether you find the extra .gather() adds too much time.

Show the code
import pandas as pd
import torch
import argparse
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_constant_schedule
import random
import numpy as np
from accelerate import Accelerator
from accelerate.utils import reduce
import math
import contextlib


random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

def main(args):
    accelerator = Accelerator()
    accelerator.print("Loading dataset")
    datasets = load_dataset("Salesforce/wikitext", "wikitext-2-v1")
    datasets = datasets.filter(lambda x: len(x)>0, input_columns="text")

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    tokenizer.pad_token = tokenizer.eos_token

    accelerator.print("Creating model")
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_constant_schedule(optimizer=optimizer)

    def tokenize_func(data):
        return tokenizer(data["text"], max_length=None, return_attention_mask=False)
    tokenized_datasets = datasets.map(tokenize_func, batched=True, remove_columns=["text"])

    def collate_fn(examples):
        max_length = max([len(example["input_ids"]) for example in examples])
        batch = tokenizer.pad(
            examples, 
            padding="max_length", 
            max_length=max_length+1, 
            pad_to_multiple_of = None,
            return_tensors="pt",
        )
        batch["labels"] = batch["input_ids"][:, 1:]
        batch["input_ids"] = batch["input_ids"][:, :-1]

        batch["labels"] = torch.where(batch["labels"] == tokenizer.pad_token_id, -100, batch["labels"])
        return batch

    def get_dataloaders(train_batch_size:int=8):
        train_dl = DataLoader(
            tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=train_batch_size,
        )
        eval_dl = DataLoader(
            tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=4
        )
        return train_dl, eval_dl

    accelerator.print("Making dataloaders")
    train_dl, eval_dl = get_dataloaders(train_batch_size=args.bs)

    def loss_fn(logits, labels, vocab_size, num_items_in_batch=None):
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        shift_labels = shift_labels.to(shift_logits.device)
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss = cross_entropy(
            shift_logits, shift_labels, ignore_index=-100, reduction=reduction
        )
        if reduction == "sum":
            return loss / num_items_in_batch
        return loss

    accelerator.print("Calling prepare")
    model, train_dl = accelerator.prepare(model, train_dl)

    losses_baseline = []
    actual_loss = 0
    num_update_steps_per_epoch = math.ceil(len(train_dl) / args.ga)
    remainder = len(train_dl) % args.ga
    if remainder == 0:
        remainder = args.ga

    total_batched_samples = 0
    accelerator.print("Starting training")
    for epoch in range(3):
        model.train()
        iterator = iter(train_dl)
        for update_step in range(num_update_steps_per_epoch):
            batch_samples = []
            num_batches = args.ga if update_step != (num_update_steps_per_epoch - 1) else remainder
            for _ in range(num_batches):
                batch_samples += [next(iterator)]
            num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()

            for i,batch in enumerate(batch_samples):
                # Skip gradient sync (no_sync) on all but the last micro-batch to avoid extra communication
                ctx = model.no_sync if i != len(batch_samples) - 1 else contextlib.nullcontext
                total_batched_samples += 1
                with ctx():
                    out = model(**batch)
                    loss = loss_fn(
                        out["logits"], batch["labels"], 
                        vocab_size=model.module.vocab_size, num_items_in_batch=num_items_in_batch
                    )
                    loss = loss * accelerator.num_processes
                    loss.backward()
                actual_loss += loss.detach()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            actual_loss = accelerator.gather(actual_loss)
            actual_loss = actual_loss.cpu().sum().item()
            losses_baseline.append(actual_loss)
            actual_loss = 0
    
    df = pd.DataFrame({"loss": losses_baseline})
    if args.ga == 1:
        name = "losses_baseline"
    else:
        name = f"losses_bs{args.bs}_ga{args.ga}_fixed"
    df.to_csv(f"{name}.csv", index=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a language model with optional gradient accumulation")
    parser.add_argument("--bs", type=int, default=8, help="Training batch size")
    parser.add_argument("--ga", type=int, default=1, help="Gradient accumulation steps")

    args = parser.parse_args()
    main(args)

Conclusion

As we continue to see, gradient accumulation seems simple on the surface but hard to get right! Hopefully this article helps teach you how to stay reproducible as you scale training with gradient accumulation.

I’d like to thank the Unsloth team who helped us figure out how to change the code in the Trainer, and Yoach and Marc for getting down in the weeds with me as we worked towards coming up with minimal reproducible examples to help educate all of us.