-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stage3: Use new torch grad accumulation hooks API #6773
Open
deepcharm
wants to merge
3
commits into
microsoft:master
Choose a base branch
from
deepcharm:stage3-use-new-grad-acc-api
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* This commit addresses an issue reported in: microsoft#6718 * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as param.data = replicated_tensor.data used in allgather_params(..) are compiled into param.set() causing the hook assigned to the grad_acc node not being called. * This is a known torch issue pytorch/pytorch#139742. * The above caused accuracy issues and could be temporarily solved by simply disabling the torch compile when activation checkpointing is used. * This commit provides a clean solution by replacing the hook on a grad_acc node to a hook using a new and robust hook API on a param itself: param.register_post_accumulate_grad_hook(..)
tjruwase
reviewed
Nov 21, 2024
self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) | ||
self.grad_accs.append(grad_acc) | ||
self._grad_acc_hooks.append( | ||
param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which pytorch version introduced this API? How should we handle older versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @tjruwase
- This API was introduced starting pytorch v2.1.
- We can add check what version and use the older API when needed.
- However, compilation should be disabled for the older versions to avoid the accuracy issues (if activation checkpointing is enabled).
Please advice what should be the approach. Thanks.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The constructs such as
param.data = replicated_tensor.data
used inallgather_params(..)
are compiled into
param.set()
causing the hook assigned to the grad_acc node not being called.param.register_post_accumulate_grad_hook(..)