Pass tests/gpu/jit.cpp with BF16 #3639

Draft
wants to merge 77 commits into
base: develop
c51c1ce
first pass at integrating generic float
richagadgil Oct 10, 2024
134b408
fix namespaces
richagadgil Oct 10, 2024
d4fa6eb
fix mantissa
richagadgil Oct 10, 2024
0b60841
refactor
richagadgil Oct 11, 2024
7a646f1
refactor
richagadgil Oct 11, 2024
ebe819b
add fp
richagadgil Oct 11, 2024
379a77a
fixed generic float class
richagadgil Oct 14, 2024
174384c
add fp32 test
richagadgil Oct 14, 2024
787b651
remove import
richagadgil Oct 14, 2024
1d1fa1c
update tests
richagadgil Oct 15, 2024
1791092
fp16 tests that work
richagadgil Oct 17, 2024
a2eb005
update tests
richagadgil Oct 18, 2024
ff8ffc7
updated fp16 and fp32 tests
richagadgil Oct 18, 2024
e36fd65
half tests
richagadgil Oct 22, 2024
9ac4e2a
underflow and overflow tests
richagadgil Oct 22, 2024
f05fd31
generate map
richagadgil Oct 22, 2024
cb4d92d
add more tests
richagadgil Oct 22, 2024
0cc1946
fix names
richagadgil Oct 22, 2024
85a761b
update tests
richagadgil Oct 23, 2024
65cf9ae
remove and
richagadgil Oct 24, 2024
fbabf54
disable warning
richagadgil Oct 24, 2024
549f5e6
fix tidy warning
richagadgil Oct 24, 2024
d302e5d
migraphx py fix
richagadgil Oct 25, 2024
8d475e3
add increments
richagadgil Oct 25, 2024
a0fd055
fix warnings
richagadgil Oct 25, 2024
41379fe
disable duplicate branch warning
richagadgil Oct 25, 2024
0c29c7b
add countzero_std
richagadgil Oct 28, 2024
4b012a8
ci error
richagadgil Oct 28, 2024
dbaa3a8
simplify countl
richagadgil Oct 28, 2024
b2bd2a0
fix ci
richagadgil Oct 28, 2024
6f328f0
src
richagadgil Oct 29, 2024
e6d9763
remove flag
richagadgil Oct 29, 2024
6538050
hide abi warning
richagadgil Oct 29, 2024
4e96d4d
revert changes
richagadgil Oct 29, 2024
ef11f1f
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
e4a25bd
change half in tests
richagadgil Oct 29, 2024
3354c6e
Update generic_float.hpp
richagadgil Oct 29, 2024
6de079b
format
richagadgil Oct 29, 2024
7750874
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
801f485
Merge branch 'develop' into generic_float
causten Oct 30, 2024
33e2c8d
fix bug
richagadgil Oct 30, 2024
9bb7198
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Oct 30, 2024
b3c345d
fix err
richagadgil Oct 30, 2024
03df6f9
edits
richagadgil Oct 31, 2024
ad817b2
tidy and format
richagadgil Oct 31, 2024
898417b
tidy etc
richagadgil Oct 31, 2024
aa5b9c9
gf
richagadgil Oct 31, 2024
6f72370
fix tidy errs
richagadgil Nov 1, 2024
0aab1a0
bf16 changes
richagadgil Nov 4, 2024
7b965c0
add flag to trace quantization passes (#3571)
shivadbhavsar Oct 30, 2024
5f5f13d
bf16
richagadgil Oct 30, 2024
d64b124
Update bf16.cpp
richagadgil Nov 1, 2024
a064eaa
Update bf16.hpp
richagadgil Nov 2, 2024
befbd9e
Update bf16.hpp
richagadgil Nov 2, 2024
08b9511
update files with working version
richagadgil Nov 4, 2024
b9d204e
Update bf16.cpp
richagadgil Nov 4, 2024
fb6df2d
Update generic_float.hpp
richagadgil Nov 4, 2024
bb78138
Merge branch 'develop' into bf16
richagadgil Nov 8, 2024
8e1f99e
add extra common type
richagadgil Nov 8, 2024
6192970
tidy
richagadgil Nov 8, 2024
c0d6bc4
Update bf16.hpp
richagadgil Nov 11, 2024
7bfc407
Update generic_float.hpp
richagadgil Nov 11, 2024
4cb96ad
Merge branch 'develop' into bf16
richagadgil Nov 11, 2024
ffd4ba2
remove imports
richagadgil Nov 12, 2024
8a10da3
Merge branch 'develop' into bf16
richagadgil Nov 12, 2024
1565a0e
ref tests
richagadgil Nov 13, 2024
e6d1155
migraphx_py fix
richagadgil Nov 13, 2024
867e960
fix test cae by index
richagadgil Nov 13, 2024
9852da5
add rocblas type
richagadgil Nov 13, 2024
bf50653
fix tgts err
richagadgil Nov 13, 2024
0ebd220
address changes
richagadgil Nov 18, 2024
043e322
Merge branch 'develop' into bf16
richagadgil Nov 18, 2024
21746a5
Merge branch 'develop' into bf16
causten Nov 19, 2024
511d37e
add new type
richagadgil Nov 19, 2024
6701db5
change math types
richagadgil Nov 20, 2024
3c39126
Merge branch 'develop' into gpu_jit_bf16
richagadgil Nov 21, 2024
7a9c3c4
fix assert test
richagadgil Nov 21, 2024
1 change: 1 addition & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
@@ -27,6 +27,7 @@
#ifndef MIGRAPHX_USE_HIPRTC
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/math_functions.h>
#endif

52 changes: 52 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
@@ -40,6 +40,8 @@
constexpr float as_float(migraphx::fp8::fp8e4m3fn x) { return x; }
constexpr float as_float(migraphx::fp8::fp8e5m2 x) { return x; }

constexpr float as_float(migraphx::bf16 x) { return x; }

template <class T>
constexpr T as_float(T x)
{
@@ -78,6 +80,12 @@
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))

// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BF16(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::bf16 x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))

// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
@@ -166,6 +174,20 @@
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)

// Builtin bf16 functions (currently disabled)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, abs, ::__habs)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, ceil, ::hceil)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, cos, ::hcos)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, exp, ::hexp)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, floor, ::hfloor)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, isinf, ::__hisinf)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, isnan, ::__hisnan)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, log, ::hlog)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, log2, ::hlog2)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, rsqrt, ::hrsqrt)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, sin, ::hsin)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, sqrt, ::hsqrt)

// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
@@ -184,6 +206,34 @@
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)

// Use float to compute bf16 overload
MIGRAPHX_DEVICE_MATH_BF16(abs, ::abs)
MIGRAPHX_DEVICE_MATH_BF16(acos, ::acos)
MIGRAPHX_DEVICE_MATH_BF16(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_BF16(asin, ::asin)
MIGRAPHX_DEVICE_MATH_BF16(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_BF16(atan, ::atan)
MIGRAPHX_DEVICE_MATH_BF16(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_BF16(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_BF16(cos, ::cos)
MIGRAPHX_DEVICE_MATH_BF16(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_BF16(erf, ::erf)
MIGRAPHX_DEVICE_MATH_BF16(exp, ::exp)
MIGRAPHX_DEVICE_MATH_BF16(floor, ::floor)
MIGRAPHX_DEVICE_MATH_BF16(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_BF16(log, ::log)
MIGRAPHX_DEVICE_MATH_BF16(log2, ::log2)
MIGRAPHX_DEVICE_MATH_BF16(pow, ::pow)
MIGRAPHX_DEVICE_MATH_BF16(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_BF16(round, ::round)
MIGRAPHX_DEVICE_MATH_BF16(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH_BF16(sin, ::sin)
MIGRAPHX_DEVICE_MATH_BF16(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_BF16(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH_BF16(tan, ::tan)
MIGRAPHX_DEVICE_MATH_BF16(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_BF16(fmod, ::fmod)
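
The overloads generated above all follow one pattern: widen the bf16 argument to float, call the float routine, and return whatever that routine returns. A minimal host-side sketch of that pattern follows. It assumes bf16 is stored as the upper 16 bits of an IEEE-754 binary32 value; `bf16_bits`, `to_float`, `from_float`, and `sqrt_via_float` are hypothetical names for illustration, not part of MIGraphX or HIP.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: the upper 16 bits of a binary32 float.
struct bf16_bits
{
    std::uint16_t raw;
};

// Widen to float by placing the 16 stored bits in the high half of a 32-bit word.
inline float to_float(bf16_bits x)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(x.raw) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Narrow by truncation (assumption for brevity; production code usually
// rounds to nearest even instead of dropping the low 16 bits).
inline bf16_bits from_float(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return {static_cast<std::uint16_t>(bits >> 16)};
}

// Same shape as the macro-generated overloads: widen, then call the float routine.
inline float sqrt_via_float(bf16_bits x) { return std::sqrt(to_float(x)); }
```

As in the macro, the result here is a float rather than a bf16, so any narrowing back to bf16 is left to the caller.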

// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
@@ -242,8 +292,10 @@
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)

Check warning on line 295 in src/targets/gpu/kernels/include/migraphx/kernels/math.hpp

GitHub Actions / tidy

conversion from 'migraphx::half' (aka '_Float16') to '__hip_bfloat16' is ambiguous [clang-diagnostic-error]

MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)

Check warning on line 296 in src/targets/gpu/kernels/include/migraphx/kernels/math.hpp

GitHub Actions / tidy

conversion from 'migraphx::half' (aka '_Float16') to '__hip_bfloat16' is ambiguous [clang-diagnostic-error]

// MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::bf16, max, ::__hmax)
// MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::bf16, min, ::__hmin)

template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
@@ -246,6 +246,8 @@ constexpr T numeric_max()
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
else if constexpr(is_same<T, migraphx::bf16>{})
return 338953138925153547590470800371487866880.000000;
else
return 0;
}
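
The decimal literal returned for bf16 can be cross-checked against the bit pattern of the largest finite bf16 value. The sketch below assumes the usual bf16 layout (1 sign bit, 8 exponent bits, 7 mantissa bits), which makes the largest finite pattern 0x7F7F, equal to 2^128 - 2^120. `bf16_max_as_float` is a hypothetical helper written for this check only.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Largest finite bf16: sign 0, exponent 0xFE, mantissa 0x7F, i.e. pattern 0x7F7F.
// Widening to float appends 16 zero bits, so the value is exactly representable.
inline float bf16_max_as_float()
{
    const std::uint32_t bits = 0x7F7F0000u;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```

The widened value compares equal to the literal in the diff and is slightly below the float maximum, since float keeps 16 more mantissa bits.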
@@ -76,7 +76,7 @@ using vec = T __attribute__((ext_vector_type(N)));

using half = _Float16;
using half2 = migraphx::vec<half, 2>;

using bf16 = __bf16;
} // namespace migraphx

#endif
5 changes: 2 additions & 3 deletions test/gpu/jit.cpp
@@ -408,13 +408,12 @@ TEST_CASE(assert_type_min_max)
if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type,
- migraphx::shape::bf16_type},
+ },
t))
continue;
auto name = migraphx::shape::cpp_type(t);
- if(t == migraphx::shape::half_type)
+ if((t == migraphx::shape::half_type) or (t == migraphx::shape::bf16_type))
name.insert(0, "migraphx::");

migraphx::shape::visit(t, [&](auto as) {
std::string min = "";
std::string max = "";