Give example on how to handle gradient accumulation with cross-entropy #3193
Open
ylacombe wants to merge 11 commits into huggingface:main from ylacombe:add-cross-entropy-accumulation-example
Changes from all commits (11 commits):
- 2eb684e Add cross-entropy example in the gradient accumulation docs
- 4d4ed80 add example of logs
- 3b8c887 correct skeleton code
- c01827c replace gather_for_metrics with gather
- 22cbf9c batch_size -> per_device_batch_size
- 395c572 remove main_process_only=True
- 2e80bf0 add autoregressive example in examples/
- 5e3e811 Update docs/source/usage_guides/gradient_accumulation.md
- c56c780 ruff format
- 80c720a add grad accum test
- e5d2c50 update docs
Diff of `docs/source/usage_guides/gradient_accumulation.md`:
```diff
@@ -187,11 +187,11 @@ set_seed(0)
 x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
 y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
 gradient_accumulation_steps = 4
-batch_size = len(x) // gradient_accumulation_steps
+per_device_batch_size = len(x) // gradient_accumulation_steps

 # define dataset and dataloader
 dataset = TensorDataset(x, y)
-dataloader = DataLoader(dataset, batch_size=batch_size)
+dataloader = DataLoader(dataset, batch_size=per_device_batch_size)

 # define model, optimizer and loss function
 class SimpleLinearModel(torch.nn.Module):
```

````diff
@@ -238,3 +238,233 @@ initial model weight is 0.00000
 w/ accumulation, the final model weight is 2.04000
 w/o accumulation, the final model weight is 2.04000
 ```
````

The second hunk then adds the following new section to the guide:
## Gradient accumulation on training samples of variable size

As pointed out in this [blog post](https://huggingface.co/blog/gradient_accumulation), a common error occurs when performing gradient accumulation on training samples of variable size:

> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.

In other words, some adjustments must be made to losses that operate on a token-level basis.
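To make the pitfall concrete, here is a toy numeric illustration (the numbers are made up for the example): with two micro-batches containing different numbers of non-padded tokens, averaging the per-batch mean losses is not the same as dividing the total loss by the total number of non-padded tokens.

```python
# Two micro-batches with summed token-level losses of 6.0 and 7.0,
# containing 3 and 7 non-padded tokens respectively (hypothetical values).
per_batch_sum_loss = [6.0, 7.0]
per_batch_num_tokens = [3, 7]

# naive accumulation: average of the per-batch mean losses
mean_of_means = sum(l / n for l, n in zip(per_batch_sum_loss, per_batch_num_tokens)) / 2
# correct objective: total loss divided by the total number of non-padded tokens
global_token_mean = sum(per_batch_sum_loss) / sum(per_batch_num_tokens)

print(mean_of_means)      # 1.5
print(global_token_mean)  # 1.3
```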
### Skeleton code

```python
from accelerate import Accelerator
import math
import contextlib

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)

training_iterator = iter(training_dataloader)
num_samples_in_epoch = len(training_dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_updates):
    # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss,
    # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
    batch_samples = []
    num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
    for _ in range(num_batches_in_step):
        batch_samples += [next(training_iterator)]

    # get local num items in batch
    num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
    # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch
    num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()

    for i, batch in enumerate(batch_samples):
        # if we perform gradient accumulation in a multi-device setup, we want to avoid unnecessary communications when accumulating
        # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
        if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
            ctx = model.no_sync
        else:
            ctx = contextlib.nullcontext

        total_batched_samples += 1

        with ctx():
            inputs, targets = batch
            outputs = model(inputs)
            loss = loss_function(outputs, targets)  # the loss function should sum over samples rather than averaging

            # We multiply by num_processes because DDP averages the gradients across all devices, whereas dividing by num_items_in_batch already accounts for all devices.
            # Same reasoning for gradient_accumulation_steps, but this time it's Accelerate that averages the gradients over the accumulated steps.
            loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch

            accelerator.backward(loss)

    # Sync gradients and perform optimization steps once every gradient_accumulation_steps
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```
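To see why the rescaling `loss * gradient_accumulation_steps * accelerator.num_processes / num_items_in_batch` recovers the per-token average, recall (as the comments above state) that DDP averages gradients across processes and Accelerate averages the accumulated loss over `gradient_accumulation_steps`. Gradients are linear in the loss, so a sanity check on loss values is enough; the numbers below are made up for the illustration.

```python
# Hypothetical summed cross-entropy losses and token counts, indexed as [process][micro-batch]
num_processes, gradient_accumulation_steps = 2, 2
summed_losses = [[6.0, 10.0], [3.0, 21.0]]
num_tokens = [[3, 5], [2, 10]]

# what accelerator.gather(...).sum() would return: the global token count
num_items_in_batch = sum(sum(row) for row in num_tokens)  # 20

# rescale every micro-batch loss as in the skeleton above
rescaled = [
    loss * gradient_accumulation_steps * num_processes / num_items_in_batch
    for row in summed_losses
    for loss in row
]

# DDP + Accelerate effectively average over processes and accumulation steps
effective = sum(rescaled) / (num_processes * gradient_accumulation_steps)

# identical to the correct objective: total summed loss / total non-padded tokens (both 2.0 here)
assert abs(effective - sum(sum(row) for row in summed_losses) / num_items_in_batch) < 1e-9
```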
### Self-contained causal LM example

```py
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DataLoader
import math
import contextlib

# seed
set_seed(0)
logger = get_logger(__name__)


class MyDataset(Dataset):
    def __init__(self, num_samples):
        super().__init__()
        self.len = num_samples

    def __getitem__(self, index):
        input_ids = torch.arange(1, index + 2, dtype=torch.float32)
        labels = torch.remainder(input_ids, 2)
        return {"input_ids": input_ids, "labels": labels}

    def __len__(self):
        return self.len


def collate_fn(features):
    input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100)
    labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100)
    return {"input_ids": input_ids[..., None], "labels": labels[..., None]}


# define toy inputs and labels
gradient_accumulation_steps = 2
per_device_batch_size = 4

# define accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
# for this toy example, we'll compute gradient descent over one single global batch
dataset = MyDataset(per_device_batch_size * gradient_accumulation_steps * accelerator.num_processes)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)

# define model, model_optimizer and loss function
model = torch.nn.Linear(1, 2, bias=False)
model_clone = copy.deepcopy(model)
criterion = torch.nn.CrossEntropyLoss(reduction="sum")  # must sum over samples rather than averaging
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)

logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}")
logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}")

# prepare artifacts - accelerator handles device placement and dataloader splitting
model, model_optimizer = accelerator.prepare(model, model_optimizer)
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
training_iterator = iter(dataloader)

num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
    # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss,
    # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
    batch_samples = []
    num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
    for _ in range(num_batches_in_step):
        batch_samples += [next(training_iterator)]

    # get local num items in batch
    local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
    logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False)

    # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch
    num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()
    logger.warning(f"Total num items {num_items_in_batch}")

    for i, batch in enumerate(batch_samples):
        inputs, labels = batch["input_ids"], batch["labels"]
        total_batched_samples += 1
        # if we perform gradient accumulation in a multi-device setup, we want to avoid unnecessary communications when accumulating
        # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
        if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
            ctx = model.no_sync
        else:
            ctx = contextlib.nullcontext
        with ctx():
            outputs = model(inputs)
            loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))

            # We multiply by num_processes because DDP averages the gradients across all devices, whereas dividing by num_items_in_batch already accounts for all devices.
            # Same reasoning for gradient_accumulation_steps, but this time it's Accelerate that averages the gradients over the accumulated steps.
            loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
            accelerator.backward(loss)
    model_optimizer.step()
    model_optimizer.zero_grad()


logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)

# We now do the same operation but on a single device and without gradient accumulation

if accelerator.is_main_process:
    # prepare one single entire batch
    dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)
    full_batch_without_accum = next(iter(dataloader))
    total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"]
    model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)

    # train the cloned model
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))
    model_clone_optimizer.zero_grad()
    loss.backward()
    model_clone_optimizer.step()

    # We should have the same final weights
    logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}")
```
Results on a single device - gradient accumulation steps set to 1 and per_device_batch_size set to 8:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```

Results on a two-device setup - gradient accumulation steps set to 2 and per_device_batch_size set to 4:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```
### To go further:

Please find a complete example script for a real-world training run in the examples folder at the path [`accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py`](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py).

Running it on several training configurations with a constant global batch size of 32 gives the following graph:

<div style="text-align: center">
  <img src="https://huggingface.co/datasets/hf-audio/gradient_accumulation_example/resolve/main/training_losses.png">
</div>

Note that the training losses are exactly the same up to training step 20. The small deviation after this training step occurs at the very end of the first epoch because, by [default](https://huggingface.co/docs/accelerate/en/package_reference/torch_wrappers#accelerate.data_loader.prepare_data_loader.even_batches), the dataloader duplicates samples from the beginning of the dataset when the total batch size doesn't exactly divide the dataset.
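If you want the last update of the epoch to see only genuine samples instead, one option is to disable the `even_batches` behavior linked above when preparing the dataloader yourself. The snippet below is only a sketch reusing the `dataloader` from the example, not part of this PR; it assumes the module-level `prepare_data_loader` helper keeps its documented `even_batches` flag, and with uneven batches processes may receive final batches of different sizes.

```python
# Sketch: disable even_batches so the sampler does not re-use samples from the
# start of the dataset to pad out the last global batch of the epoch.
from accelerate.data_loader import prepare_data_loader

dataloader = prepare_data_loader(dataloader, even_batches=False)
```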
Review comment: This only works when we know the size of the dataloader. Can we think of a solution that doesn't require this information? I think we can just iterate on the dataloader until we have `gradient_accumulation_steps` batches to create `batch_samples`. If we can't iterate anymore, then we also stop. I think that code will be easier to understand.

Reply: Yes, agreed :) (What we do in the Trainer)
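A minimal sketch of what that suggestion could look like, reusing the names from the skeleton code above (`itertools.islice` is just one hypothetical way to collect up to `gradient_accumulation_steps` micro-batches per update; this is not code from the PR):

```python
import itertools

training_iterator = iter(training_dataloader)
while True:
    # take up to gradient_accumulation_steps micro-batches; the last update of the
    # epoch simply gets fewer of them
    batch_samples = list(itertools.islice(training_iterator, gradient_accumulation_steps))
    if not batch_samples:
        break  # the dataloader is exhausted, the epoch is over

    # ... then compute num_items_in_batch, rescale the loss, and step exactly as in
    # the skeleton code above
```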