Half precision fixes #606
Conversation
@@ -158,15 +157,15 @@ def forward(
         h = self.attn(n_1, cos, sin, mask, input_pos)
         if self.config.parallel_residual:
             n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
-            x = x + h + self.mlp(n_2)
+            x = self.mlp(n_2) + h + x
Floating-point addition is not associative, so in fp16 the order of the summands changes the result, and this is the order GPT-NeoX uses.
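For illustration, a standalone sketch (the tensor names are stand-ins, not taken from the diff) showing that regrouping a three-way sum can change the low-order bits in fp16:

```python
import torch

torch.manual_seed(0)
# stand-ins for the residual stream, the attention output and the MLP output
x = torch.randn(1024, dtype=torch.float16)
h = torch.randn(1024, dtype=torch.float16)
m = torch.randn(1024, dtype=torch.float16)

a = x + h + m  # evaluated as (x + h) + m
b = m + h + x  # evaluated as (m + h) + x

print(torch.equal(a, b))    # typically False in fp16
print((a - b).abs().max())  # small but nonzero rounding difference
```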
],
torch.device("cuda"), torch.float16, marks=[
    # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
    # is slightly different
I wasn't able to find out why the final layernorm input is different. If you print `x.sum()`, the value matches, but some positions are not the same.
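To make that observation concrete, a small helper (the names `ours` and `reference` are hypothetical, standing for the final layernorm input from the two implementations):

```python
import torch

def compare(ours: torch.Tensor, reference: torch.Tensor) -> None:
    # the aggregate statistic can agree even when individual elements do not
    print(ours.sum().item(), reference.sum().item())
    diff = (ours - reference).abs()
    print("mismatched positions:", (diff > 0).sum().item(), "max diff:", diff.max().item())
```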
Probably for the same reason that the float16 tests are xfailed
        pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
    ],
    torch.device("cuda"), torch.float16, marks=[
        # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
Doing this upscaling would require giving up `scaled_dot_product_attention`.
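Roughly what the trade-off looks like (a sketch, not the repo's code; the function names and the boolean-mask convention are assumptions): computing the softmax in fp32 means spelling the attention out manually instead of calling the fused kernel.

```python
import math
import torch
import torch.nn.functional as F

def attention_manual_fp32_softmax(q, k, v, mask=None):
    # q, k, v: (B, n_head, T, head_dim) in fp16; softmax upcast to fp32 like the reference
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # mask convention: True = attend
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return probs @ v

def attention_fused(q, k, v, mask=None):
    # the fused kernel does not expose control over the internal softmax dtype
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```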
8afa6e7 to d35d5e0
    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin
What are the implications? These changes were necessary to have parity with the original LLaMA RoPE cache implementation.
Not doing this matches the HF RoPE implementations, which keep these in float32.
This piece of code comes from lit-llama, where it was ported from the original facebookresearch repo. Since this repo uses the HF implementations as references instead, I think it's fine to remove it.
The original Mistral release still uses complex numbers: https://github.com/mistralai/mistral-src/blob/main/mistral/rope.py
I don't think there's a solution that is numerically precise and performant across implementations; we just need to choose which reference implementation to compare against.
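To make the two conventions concrete, a simplified sketch (not the repo's actual `build_rope_cache`): build the cos/sin cache in float32 as HF does, then optionally downcast to half the way the lit-llama port did to mimic complex32.

```python
import torch

def build_rope_cache(seq_len: int, n_elem: int, base: int = 10000):
    # simplified HF-style cache: computed and kept in float32
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    idx_theta = torch.outer(torch.arange(seq_len).float(), theta)
    return torch.cos(idx_theta), torch.sin(idx_theta)

cos32, sin32 = build_rope_cache(2048, 64)
# lit-llama-style downcast, mimicking the complex32 path of the original repo
cos16, sin16 = cos32.half(), sin32.half()
print((cos32 - cos16.float()).abs().max())  # rounding gap between the two conventions
```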
Fixes #602