[TEST] Fixing precision bug in fp8 test_dot (#4131)
There is a precision bug in test_core.py::test_dot for the chain-dot cases.
Inside the kernel, the first dot product is fp8 x fp8 -> fp32 (`z = xy`).
For chain-dot, the output `z` is then cast back to fp8 and fed into a second
fp8 x fp8 -> fp32 dot product (`z = zw`).
However, the reference NumPy computation does *not* cast the intermediate
output `z_ref` to fp8, so its second dot product is effectively
fp32 x fp8 -> fp32, and its result differs from the kernel output.
For some fp8 formats (float8e4nv in our case), this mismatch occasionally
causes the test to fail.
This change fixes the reference computation by reducing the precision of the
intermediate output to fp8.
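
For illustration only, the standalone sketch below (not part of the commit;
the tensor names, shapes, and values are made up for the example) shows how
rounding the intermediate product through `torch.float8_e4m3fn` changes the
result of the second matmul compared with keeping it in fp32, which is the
mismatch described above.

```python
import numpy as np
import torch  # requires a PyTorch build with float8 dtypes (>= 2.1)

# Hypothetical inputs standing in for the test's x, y, w operands.
rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16), dtype=np.float32)
y = rng.standard_normal((16, 16), dtype=np.float32)
w = rng.standard_normal((16, 16), dtype=np.float32)

# First dot product; the kernel accumulates this in fp32.
z_fp32 = np.matmul(x, y)

# Round-trip the intermediate through fp8, as the kernel effectively does
# before the second dot product of the chain-dot epilogue.
z_rounded = torch.from_numpy(z_fp32).to(torch.float8_e4m3fn).to(torch.float32).numpy()

ref_without_cast = np.matmul(z_fp32, w)  # old reference: fp32 intermediate
ref_with_cast = np.matmul(z_rounded, w)  # fixed reference: fp8-rounded intermediate

# The difference below is the fp8 rounding error that made the old
# reference disagree with the kernel output.
print(np.max(np.abs(ref_without_cast - ref_with_cast)))
```

In the test itself the same round-trip is done with `torch.float8_e4m3fn` or
`torch.float8_e5m2` depending on `in_dtype`, as shown in the diff below.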

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)
hwnam831 authored Jun 15, 2024
1 parent 83a9b34 commit 4f94c88
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/test/unit/language/test_core.py
```diff
@@ -3206,6 +3206,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
         z_ref = num / denom
     if epilogue == 'chain-dot':
         if 'float8' in in_dtype:
+            # Reduce z_ref's precision to fp8 to match the kernel behavior
+            if in_dtype == 'float8e4nv':
+                z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn)
+            elif in_dtype == 'float8e5':
+                z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2)
+            else:
+                assert False, "Unsupported float8 dtype"
+            z_ref = to_numpy(z_fp8.to(torch.float32))
             w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype))
         z_ref = np.matmul(z_ref, w)
     # compare
```
