[DTensor] Supported 2D `clip_grad_norm_` (pytorch#121945)
- This PR adds support for 2D `clip_grad_norm_` (`foreach=True`).
- This PR changes `OpSchema.args_spec` to use pytree if the runtime schema info specifies it.
- This PR includes a unit test for 2D FSDP2 + SP with `clip_grad_norm_` enabled, which serves as a complete numerics test for 2D.

Note: With this PR patched, 2-way SP + 4-way FSDP matches 8-way FSDP numerics on Llama-7B (doubling local batch size for the 2-way SP run).

Pull Request resolved: pytorch#121945
Approved by: https://github.com/wanchaol
ghstack dependencies: pytorch#121747, pytorch#121869
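For context, a minimal sketch of the kind of 2D (FSDP2 + tensor parallel) setup this enables, ending in a `clip_grad_norm_(..., foreach=True)` call over DTensor gradients. The mesh shape, module layout, and hyperparameters below are illustrative and not taken from the PR's unit test; it assumes 8 GPUs launched via `torchrun` and the FSDP2 `fully_shard` API from `torch.distributed._composable.fsdp`.

```python
# Illustrative sketch (not the PR's test): 2D parallelism where FSDP2 shards
# over the "dp" mesh dim and tensor parallel shards over the "tp" mesh dim,
# then clip_grad_norm_ with foreach=True runs over the 2D DTensor gradients.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2D mesh: 4-way FSDP x 2-way tensor parallel (assumes 8 ranks).
mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# Tensor-parallelize the two linear layers over the "tp" sub-mesh.
parallelize_module(
    model,
    mesh_2d["tp"],
    {"0": ColwiseParallel(), "2": RowwiseParallel()},
)
# Shard the (already TP-sharded) parameters with FSDP2 over the "dp" sub-mesh.
fully_shard(model, mesh=mesh_2d["dp"])

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()

# With this PR, the foreach path of clip_grad_norm_ works on the resulting
# 2D DTensor gradients; the returned total norm may itself be a DTensor.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, foreach=True
)
optim.step()
optim.zero_grad()
```

If the returned total norm is a DTensor, it can be materialized for logging with `.full_tensor()`.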
1 parent 2c33e3a, commit f4dd2fd
Showing 3 changed files with 103 additions and 33 deletions.