Using torch.bfloat16 to prevent overflow instead of default fp16 in AMP #345

Open
wants to merge 1 commit into base: main
Conversation

rajeevgl01

Using torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16, which causes NaN loss and NaN gradient norms during AMP training. This seems to be a common issue when training the Swin Transformer.
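
For reference, a minimal sketch of what this change amounts to in a standard PyTorch AMP loop, with autocast running in bfloat16 instead of the default float16 (the tiny model and random data below are placeholders, not code from this repository):

```python
import torch
import torch.nn as nn

# Placeholder model, optimizer, and data; only the autocast dtype matters here.
device = "cuda"
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for step in range(10):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 10, (32,), device=device)
    optimizer.zero_grad()

    # bfloat16 keeps FP32's 8 exponent bits, so activations and gradients do
    # not overflow the way float16 can; a GradScaler is typically unnecessary.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = criterion(model(inputs), targets)

    loss.backward()
    optimizer.step()
```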

BFloat16 has the same number of exponent bits as FP32 but fewer precision (mantissa) bits. If we want higher precision while still saving GPU memory, TensorFloat-32 (TF32) can be used instead.

TF32 has fewer precision bits than FP32, but three more exponent bits than FP16. However, TF32 is only available on NVIDIA Ampere GPUs or newer.
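
Unlike bfloat16 autocast, TF32 is enabled through global PyTorch backend flags; a short sketch of how that would look (these flags are standard PyTorch settings, not part of this PR's diff):

```python
import torch

# Allow TF32 math on Ampere or newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True  # matrix multiplications use TF32
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions use TF32
```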
