Half precision fixes #606

Merged: 4 commits merged into main on Oct 24, 2023

Conversation

carmocca (Contributor) commented on Oct 2, 2023:

Fixes #602

carmocca self-assigned this on Oct 2, 2023
@@ -158,15 +157,15 @@ def forward(
         h = self.attn(n_1, cos, sin, mask, input_pos)
         if self.config.parallel_residual:
             n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
-            x = x + h + self.mlp(n_2)
+            x = self.mlp(n_2) + h + x
carmocca (Contributor, author):

Addition is not associative in fp16, so the order of the summands matters, and this is the order the GPT-NeoX implementation uses.
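
For illustration only (not part of the PR): float16 addition is commutative but not associative, so summing the same three tensors in a different left-to-right order can round differently. A minimal sketch:

```python
import torch

# Minimal sketch, not from the PR: fp16 addition is not associative, so the
# left-to-right evaluation order changes how intermediate sums are rounded.
torch.manual_seed(0)
x = torch.randn(1024, dtype=torch.float16)
h = torch.randn(1024, dtype=torch.float16)
mlp_out = torch.randn(1024, dtype=torch.float16)

a = x + h + mlp_out        # evaluated as (x + h) + mlp_out
b = mlp_out + h + x        # evaluated as (mlp_out + h) + x
print(torch.equal(a, b))   # typically False: some elements differ by one ulp
print((a - b).abs().max()) # small but usually nonzero
```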

],
torch.device("cuda"), torch.float16, marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
carmocca (Contributor, author):

I wasn't able to figure out why the final layernorm input is different. If you print x.sum(), the values match, but some positions are not the same.
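
A tiny, hypothetical illustration (values made up) of that observation: a summary statistic like x.sum() can agree even though individual positions disagree, so an element-wise comparison is still needed:

```python
import torch

# Hypothetical values: the sums agree while individual positions differ.
ours = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
reference = torch.tensor([1.5, 1.5, 3.0], dtype=torch.float16)

print(ours.sum() == reference.sum())   # tensor(True)
print(torch.equal(ours, reference))    # False
print((ours != reference).nonzero())   # positions 0 and 1 differ
```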

Collaborator:

If I change the dtype for the CPU test from float32 to bfloat16, some combinations fail.
(Screenshot attached: 2023-10-24 at 7:24:22 PM)

carmocca (Contributor, author):

Probably for the same reason that the float16 tests are xfailed
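
For context, a sketch of how such a dtype/device combination can be marked as an expected failure in a parametrization like the excerpts above and below; the test name, body, and exact marks here are assumptions, not the repo's actual test:

```python
import pytest
import torch

# Sketch only; the test name, body, and exact marks are assumptions.
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # expected to fail: the reference upcasts the softmax to fp32
                pytest.mark.xfail(raises=AssertionError, strict=False),
                pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
            ],
        ),
    ],
)
def test_against_reference(device, dtype):
    x = torch.randn(2, 4, device=device, dtype=dtype)
    assert torch.isfinite(x).all()
```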

pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
carmocca (Contributor, author):

Doing this upscaling would require giving up scaled_dot_product_attention
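
Roughly what that trade-off looks like (shapes and function names are assumptions, not the repo's code): upcasting the softmax to fp32 means writing the attention out by hand instead of calling the fused kernel:

```python
import math
import torch
import torch.nn.functional as F

# Sketch only: q, k, v are (batch, heads, seq, head_dim) half-precision tensors.
def attention_fp32_softmax(q, k, v):
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale
    # upcast the softmax to fp32 like the reference, then cast back
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    return probs @ v

def attention_fused(q, k, v):
    # fused kernel; the softmax stays in the input dtype
    return F.scaled_dot_product_attention(q, k, v)

if torch.cuda.is_available():
    q = k = v = torch.randn(1, 4, 8, 16, dtype=torch.float16, device="cuda")
    diff = (attention_fp32_softmax(q, k, v) - attention_fused(q, k, v)).abs().max()
    print(diff)  # small but typically nonzero
```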

carmocca force-pushed the carmocca/half-precision-fixes branch from 8afa6e7 to d35d5e0 on October 3, 2023 at 01:28
carmocca marked this pull request as ready for review on October 3, 2023 at 01:30
Comment on lines -333 to -336
    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin
Reviewer (Contributor):

What are the implications? These changes were necessary to have parity with the original llama rope cache implementation.

carmocca (Contributor, author):

Not doing this matches the HF RoPE implementations: HF keeps these in float32.

This piece of code comes from lit-llama, where it was ported from the original facebookresearch repo. Since this repo uses the HF implementations as references instead, I think it's fine to remove this.

The mistral original release still uses complex numbers: https://github.com/mistralai/mistral-src/blob/main/mistral/rope.py

I don't think there's a solution that is numerically precise and performant across implementations. We just need to choose which reference implementation we compare this with.
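
A rough sketch of the two conventions being discussed (the function below is an assumption for illustration, not the repo's actual build_rope_cache): keep the cos/sin cache in float32 as HF does, or cast it to half to mimic complex32 as the removed code did:

```python
import torch

# Illustrative only; not the repo's actual implementation.
def build_rope_cache(seq_len: int, n_elem: int, base: int = 10000):
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    idx = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(idx, theta)
    return torch.cos(freqs), torch.sin(freqs)

cos, sin = build_rope_cache(seq_len=8, n_elem=64)

# HF-style: keep the cache in float32 even when the model runs in fp16/bf16.
cos_fp32, sin_fp32 = cos, sin

# Removed lit-llama-style behaviour: cast to half to mimic complex32.
cos_half, sin_half = cos.half(), sin.half()

print(cos_fp32.dtype, cos_half.dtype)            # torch.float32 torch.float16
print((cos_fp32 - cos_half.float()).abs().max()) # rounding difference between the two
```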

carmocca mentioned this pull request on Oct 19, 2023
carmocca merged commit 6178c7c into main on Oct 24, 2023 (4 of 5 checks passed)
carmocca deleted the carmocca/half-precision-fixes branch on October 24, 2023 at 15:25
carmocca assigned carmocca and unassigned carmocca on Nov 1, 2023

Linked issue closed by this pull request: Add modelling tests with 16 bit precision