Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give example on how to handle gradient accumulation with cross-entropy #3193

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
234 changes: 232 additions & 2 deletions docs/source/usage_guides/gradient_accumulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ set_seed(0)
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
batch_size = len(x) // gradient_accumulation_steps
per_device_batch_size = len(x) // gradient_accumulation_steps

# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=batch_size)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size)

# define model, optimizer and loss function
class SimpleLinearModel(torch.nn.Module):
Expand Down Expand Up @@ -238,3 +238,233 @@ initial model weight is 0.00000
w/ accumulation, the final model weight is 2.04000
w/o accumulation, the final model weight is 2.04000
```

## Gradient accumulation on training samples of variable size

As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accumulation), which points out a common error that occurs when performing gradient accumulation on training samples of variable size:

> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.

In other words, some adjustements must be made on losses that operate on a token-level basis.

### Skeleton code

```python
from accelerate import Accelerator
import math
import contextlib

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
model, optimizer, training_dataloader, scheduler
)

training_iterator = iter(training_dataloader)
num_samples_in_epoch = len(training_dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)


total_batched_samples = 0
for update_step in range(total_updates):
# In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss
# we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples
batch_samples = []
num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
for _ in range(num_batches_in_step):
batch_samples += [next(training_iterator)]

# get local num items in batch
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
# to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()

for i, batch in enumerate(batch_samples):
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
# cf: https://muellerzr.github.io/blog/gradient_accumulation.html
if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
ctx = model.no_sync
else:
ctx = contextlib.nullcontext

total_batched_samples += 1

with ctx():
inputs, targets = batch
outputs = model(inputs)
loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging

# We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices
# Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps
loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch

accelerator.backward(loss)

# Sync gradients and perform optimization steps once every gradient_accumulation_steps
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```

### Self-contained causal LM example

```py
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DataLoader
import math
import contexlib

# seed
set_seed(0)
logger = get_logger(__name__)

class MyDataset(Dataset):
def __init__(self, num_samples):
super().__init__()
self.len = num_samples

def __getitem__(self, index):
input_ids = torch.arange(1, index+2, dtype=torch.float32)
labels = torch.remainder(input_ids, 2)
return {"input_ids": input_ids, "labels": labels}

def __len__(self):
return self.len

def collate_fn(features):
input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100)
labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100)
return {"input_ids": input_ids[..., None], "labels": labels[..., None]}

# define toy inputs and labels
gradient_accumulation_steps = 2
per_device_batch_size = 4

# define accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
# for this toy example, we'll compute gradient descent over one single global batch
dataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)

# define model, model_optimizer and loss function
model = torch.nn.Linear(1, 2, bias=False)
model_clone = copy.deepcopy(model)
criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples rather than averaging
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)


logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}")
logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}")

# prepare artifacts - accelerator handles device placement and dataloader splitting
model, model_optimizer = accelerator.prepare(model, model_optimizer)
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
training_iterator = iter(dataloader)

num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
# In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss
# we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples
batch_samples = []
num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
for _ in range(num_batches_in_step):
batch_samples += [next(training_iterator)]
Comment on lines +372 to +384
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works when we know the size of the dataloader. Can we think of a solution that doesn't require this information ? I think we can just iter on the dataloader until we have gradient_accumulation_steps to create the batch_sample. If we can't iter anymore, then we stop also. I think that code will be easier to understand.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes agreed :) (What we do in the Trainer)


# get local num items in batch
local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False)

# to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()
logger.warning(f"Total num items {num_items_in_batch}")

for i, batch in enumerate(batch_samples):
inputs, labels = batch["input_ids"], batch["labels"]
total_batched_samples += 1
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
# cf: https://muellerzr.github.io/blog/gradient_accumulation.html
if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
ctx = model.no_sync
else:
ctx = contextlib.nullcontext
with ctx():

outputs = model(inputs)
loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))

# We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices
# Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps
loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
accelerator.backward(loss)
model_optimizer.step()
model_optimizer.zero_grad()


logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than logger.warning, we can do print() here or change the default logging level :) (Just logging.warning rather than logging.info weirds me out)


# We know do the same operation but on a single device and without gradient accumulation

if accelerator.is_main_process:
# prepare one single entire batch
dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)
full_batch_without_accum = next(iter(dataloader))
total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"]
model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)

# train the cloned model
loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()

# We should have the same final weights.
logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}")

```

Results on a single device - gradient accumulation steps set to 1 and batch_size set to 8:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```

Results on a two devices set-up - gradient accumulation steps set to 2 and batch_size set to 4.
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```

### To go further:

Please find a complete example script on a real world training run in the examples folder at the path [`accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py`](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py).

Running it on several training configurations with constant global batch size equal to 32 gives the following graph:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/hf-audio/gradient_accumulation_example/resolve/main/training_losses.png">
</div>

Note that the training losses are exactly the same up to training step 20. The small deviation after this training step occurs at the very end of the first epoch, because, by [default](https://huggingface.co/docs/accelerate/en/package_reference/torch_wrappers#accelerate.data_loader.prepare_data_loader.even_batches), the dataloader duplicates the samples at the beginning of the dataset when the total batch size doesn't exactly divide the dataset.
Loading
Loading