Give example on how to handle gradient accumulation with cross-entropy #3193

Open
wants to merge 11 commits into main
216 changes: 214 additions & 2 deletions docs/source/usage_guides/gradient_accumulation.md
@@ -187,11 +187,11 @@ set_seed(0)
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
-batch_size = len(x) // gradient_accumulation_steps
+per_device_batch_size = len(x) // gradient_accumulation_steps

# define dataset and dataloader
dataset = TensorDataset(x, y)
-dataloader = DataLoader(dataset, batch_size=batch_size)
+dataloader = DataLoader(dataset, batch_size=per_device_batch_size)

# define model, optimizer and loss function
class SimpleLinearModel(torch.nn.Module):
@@ -238,3 +238,215 @@ initial model weight is 0.00000
w/ accumulation, the final model weight is 2.04000
w/o accumulation, the final model weight is 2.04000
```

## Gradient accumulation on training samples of variable size

As pointed out in this [blog post](https://huggingface.co/blog/gradient_accumulation), a common error occurs when performing gradient accumulation on training samples of variable size:

> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.

In other words, some adjustments must be made to losses that operate on a token-level basis.
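
A tiny numeric illustration of why the naive approach goes wrong (the numbers are made up, purely for intuition): suppose one accumulated micro-batch has 10 non-padding tokens with a summed loss of 20.0, and the next has 90 non-padding tokens with a summed loss of 90.0.

```python
# Hypothetical per-micro-batch numbers, purely illustrative
sum_loss_1, tokens_1 = 20.0, 10
sum_loss_2, tokens_2 = 90.0, 90

# Averaging the per-batch means weights both micro-batches equally...
mean_of_means = ((sum_loss_1 / tokens_1) + (sum_loss_2 / tokens_2)) / 2   # = 1.5

# ...but the correct token-level loss weights every token equally
global_mean = (sum_loss_1 + sum_loss_2) / (tokens_1 + tokens_2)           # = 1.1

print(mean_of_means, global_mean)  # the two disagree whenever token counts differ
```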

### Skeleton code

```python
from accelerate import Accelerator
import math

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)

training_iterator = iter(training_dataloader)
num_samples_in_epoch = len(training_dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)


total_batched_samples = 0
for update_step in range(total_updates):
    # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss
    # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
    batch_samples = []
    num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
    for _ in range(num_batches_in_step):
        batch_samples += [next(training_iterator)]

    # get local num items in batch
    num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
    # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
    num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()

    for batch in batch_samples:
        total_batched_samples += 1

        with accelerator.accumulate(model):
            # Since we performed prefetching, we need to manually set sync_gradients
            if total_batched_samples % gradient_accumulation_steps != 0:
                accelerator.gradient_state._set_sync_gradients(False)
            else:
                accelerator.gradient_state._set_sync_gradients(True)

            inputs, targets = batch
            outputs = model(inputs)
            loss = loss_function(outputs, targets)  # the loss function should sum over samples rather than averaging

            # We multiply by num_processes because DDP computes the average gradient across all devices, whereas dividing by num_items_in_batch already takes all devices into account
            # Same reason for gradient_accumulation_steps, but this time it's Accelerate that computes the average gradient across the accumulated steps
            loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
```
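
The `loss_function` above is left abstract on purpose. For a token-level task such as causal LM training, one possible choice (a sketch, not the only way to do it) is a cross-entropy that sums over tokens and skips padding:

```python
import torch

# A possible token-summing loss for the skeleton above (illustrative; adapt the
# ignore_index and the label/logit shapes to your own data pipeline).
loss_function = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)

# Usage sketch: logits of shape (batch * seq_len, vocab_size), labels of shape (batch * seq_len,),
# with padded positions set to -100 so they don't contribute to the sum:
# loss = loss_function(logits.view(-1, vocab_size), labels.view(-1))
```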

### Self-contained causal LM example

```py
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DataLoader
import math

# seed
set_seed(0)
logger = get_logger(__name__)

class MyDataset(Dataset):
    def __init__(self, num_samples):
        super().__init__()
        self.len = num_samples

    def __getitem__(self, index):
        input_ids = torch.arange(1, index + 2, dtype=torch.float32)
        labels = torch.remainder(input_ids, 2)
        return {"input_ids": input_ids, "labels": labels}

    def __len__(self):
        return self.len

def collate_fn(features):
    input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100)
    labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100)
    return {"input_ids": input_ids[..., None], "labels": labels[..., None]}

# define toy inputs and labels
gradient_accumulation_steps = 2
per_device_batch_size = 4

# define accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
# for this toy example, we'll compute gradient descent over one single global batch
dataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)

# define model, model_optimizer and loss function
model = torch.nn.Linear(1, 2, bias=False)
model_clone = copy.deepcopy(model)
criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples rather than averaging
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)


logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}")
logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}")

# prepare artifacts - accelerator handles device placement and dataloader splitting
model, model_optimizer = accelerator.prepare(model, model_optimizer)
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
training_iterator = iter(dataloader)

num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
    # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss
    # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
    batch_samples = []
    num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
    for _ in range(num_batches_in_step):
        batch_samples += [next(training_iterator)]

    # get local num items in batch
    local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
    logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False)

    # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
    num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()
    logger.warning(f"Total num items {num_items_in_batch}")

    for batch in batch_samples:
        inputs, labels = batch["input_ids"], batch["labels"]
        total_batched_samples += 1

        with accelerator.accumulate(model):
            # Since we performed prefetching, we need to manually set sync_gradients
            if total_batched_samples % gradient_accumulation_steps != 0:
                accelerator.gradient_state._set_sync_gradients(False)
            else:
                accelerator.gradient_state._set_sync_gradients(True)

            outputs = model(inputs)
            loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))

            # We multiply by num_processes because DDP computes the average gradient across all devices, whereas dividing by num_items_in_batch already takes all devices into account
            # Same reason for gradient_accumulation_steps, but this time it's Accelerate that computes the average gradient across the accumulated steps
            loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
            accelerator.backward(loss)
            model_optimizer.step()
            model_optimizer.zero_grad()


logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)
# We now do the same operation but on a single device and without gradient accumulation

if accelerator.is_main_process:
    # prepare one single entire batch
    dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)
    full_batch_without_accum = next(iter(dataloader))
    total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"]
    model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)

    # train the cloned model
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))
    model_clone_optimizer.zero_grad()
    loss.backward()
    model_clone_optimizer.step()

    # We should have the same final weights.
    logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}")

```
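
If you prefer an explicit check over reading the logs, a small assertion can be appended to the script (a sketch that reuses the `model`, `model_clone`, and `accelerator` objects from the example above; the tolerance is an arbitrary choice):

```python
# Optional sanity check (not part of the original example): both training paths
# should end up with the same final weights up to floating point noise.
if accelerator.is_main_process:
    assert torch.allclose(
        accelerator.unwrap_model(model).weight.detach().cpu(),
        model_clone.weight.detach().cpu(),
        atol=1e-5,
    ), "accumulated and non-accumulated training diverged"
```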

Results on a single device:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```

Results on a two-device setup:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```