Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TEST] Fixing precision bug in fp8 test_dot (#4131)
There is a precision bug in test_core.py::test_dot for chain-dot cases. Inside the kernel, it does fp8xfp8->fp32 dot-product (`z=xy`) first. Then, for chain-dot, it casts the output `z` back to fp8 and do the fp8xfp8->fp32 dot-product again `z=zw`. However, the reference numpy computation *does not* cast the intermediate output `z_ref` to fp8. Therefore, the second dot-product becomes fp32xfp8->fp32, whose result is different from the kernel output. In some fp8 setup (float8e4nv in our case), it sometimes causes a test failure due to this precision issue. I have fixed the reference computation to reduce the precision of the intermediate output. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
- Loading branch information