Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give example on how to handle gradient accumulation with cross-entropy #3193

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

ylacombe
Copy link

What does this PR do?

Following the recent highlights on how gradient accumulation with the cross-entropy loss is usually off, it could be great to have it mentioned in the doc. I've thus added some code and explanation of it in the gradient accumulation page.

cc @SunMarc and @muellerzr, let me know what you think of it or if I can make this any clearer!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! (We need to do the gather() + div by num processes in the trainer still).

Left a few nits, I think it'd be really cool if we can show full training graphs. After doing stuff with FP8 just taking "the end result is the same" at face value I don't fully trust :)

Comment on lines 432 to 452
Results on a single device:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```

Results on a two devices set-up:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly if we can let's even toss up some wandb graphs 🔥

Copy link
Author

@ylacombe ylacombe Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, it'd be great, but here we do only one single global batch size, I don't think it's worth adding a graph. Maybe should I modify the current code snippet to do this with multiple global steps ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or add some wandb graphs from the upcoming modif of examples/by_feature/gradient_accumulation ?

model_optimizer.zero_grad()


logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than logger.warning, we can do print() here or change the default logging level :) (Just logging.warning rather than logging.info weirds me out)

docs/source/usage_guides/gradient_accumulation.md Outdated Show resolved Hide resolved
docs/source/usage_guides/gradient_accumulation.md Outdated Show resolved Hide resolved
docs/source/usage_guides/gradient_accumulation.md Outdated Show resolved Hide resolved
Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job @ylacombe ! Left a few suggestions !

docs/source/usage_guides/gradient_accumulation.md Outdated Show resolved Hide resolved
Comment on lines +366 to +378
num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
# In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss
# we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples
batch_samples = []
num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
for _ in range(num_batches_in_step):
batch_samples += [next(training_iterator)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works when we know the size of the dataloader. Can we think of a solution that doesn't require this information ? I think we can just iter on the dataloader until we have gradient_accumulation_steps to create the batch_sample. If we can't iter anymore, then we stop also. I think that code will be easier to understand.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes agreed :) (What we do in the Trainer)

docs/source/usage_guides/gradient_accumulation.md Outdated Show resolved Hide resolved
Comment on lines 431 to 432

Results on a single device:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can precise the exact setup ? I think that we are doing the following ?

  • dp=1 grad_acc= 2 batch_size = 4 vs dp=1 grad_acc= 1 batch_size = 8 ?
    If we are only doing one update, then we won't be able to get a graph. Maybe we do this on a larger dataset where batch_size != len(data_loader) and add the graphs.

Comment on lines 442 to 443
Results on a two devices set-up:
```
Copy link
Member

@SunMarc SunMarc Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a two devices set-up, the modification you did to take into account the dp won't be reflected here as we are only changing grad acc and batch_size. So the loss will be the same nevertheless. However, it's nice to see that the total_num_items really changed:

  • dp=2 grad_acc= 2 batch_size = 4 vs dp=2 grad_acc=1 batch_size=8

Maybe we should probably do a separate section/experiment to show the following will have the same loss graph

  • dp=2 batch_size =2 is the same as dp=1 batch_size=4. See this experiment for clarification

Comment on lines +249 to +251
def test_gradient_accumulation_for_autoregressive_models(self):
testargs = ["examples/by_feature/gradient_accumulation_for_autoregressive_models.py"]
run_command(self.launch_args + testargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nit: this doesn't use gradient accumulation here since it uses the default of 1

"--per_device_batch_size",
type=int,
default=2,
help="The number of minibatches to be ran before gradients are accumulated.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be "The size of each minibatch"?

Suggested change
help="The number of minibatches to be ran before gradients are accumulated.",
help="The size of each minibatch",

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants