Refactor GPU math functions #3657
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@           Coverage Diff           @@
##           develop    #3657   +/-   ##
========================================
  Coverage    92.20%   92.20%
========================================
  Files          513      513
  Lines        21658    21658
========================================
  Hits         19970    19970
  Misses        1688     1688

☔ View full report in Codecov by Sentry.
This build is not recommended to merge 🔴
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
Appreciate the comments you added, especially with half2.
This creates a generic `math::wrap` function to handle overloads of math functions. It will:

- use `double` for integer parameters
- use `float` when there isn't an overload (i.e. for low-precision floats), so we don't need to define a bunch of macros to handle these cases

There is a `MIGRAPHX_DEVICE_MATH_WRAP` macro to set up the overloads. This will pass the overloads for each of the math functions. There is not a simple way to take a function pointer to the math functions since they are overloaded for host and device, so we need to "lift" them into a function object (see `BOOST_HOF_LIFT`). We have the `MIGRAPHX_LIFT` macro which can do that, but it loses the type information (since all the parameters are templated), so the compiler can no longer figure out which overload is the best match. That is, when a `float` is passed with overloads for `float` and `double`, the compiler can't see that the `float` overload is an exact match once everything is templated, because they all look like exact matches.

So instead we define `MIGRAPHX_DEVICE_MATH_LIFT` to lift the function, which takes a `type` as a parameter to the macro and defines the function with that type. Since we need to pass the type, we can't just pass the functions directly (such as `MIGRAPHX_DEVICE_MATH_WRAP(cos, ::cos, ::cosf, ::hcos)`). Therefore we pass the type in parentheses preceding the function name: `MIGRAPHX_DEVICE_MATH_WRAP(cos, (double)::cos, (float)::cosf, (half)::hcos)`. The macro will unwrap the type in parentheses and pass it to `MIGRAPHX_DEVICE_MATH_LIFT`.

I also updated the `abs` function to use a generic version for unsupported types (similar to min/max), since it's probably faster than converting to double/float and back.

The tests have also been updated to check that there is a `float` overload, since we were missing several `float` overloads. There might be a small perf boost from this since we will no longer be using double precision for these functions.
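
For readers who want to see the idea in isolation, here is a minimal host-only C++17 sketch of the approach described above: lifting functions with an explicit parameter type so overload resolution still works, then dispatching integers to `double` and unsupported types to `float`. It is not the code from this PR. Every name in it (`sketch::wrap`, `SKETCH_LIFT`, `half_like`, `cos_overloads`) is hypothetical, and `half_like` is only a stand-in for a low-precision type with no overload of its own.

```cpp
#include <math.h> // guarantees the global ::cos and ::cosf names
#include <type_traits>

namespace sketch {

// Hypothetical low-precision type that has no cos overload of its own.
struct half_like
{
    float value;
    explicit half_like(float v) : value(v) {}
    explicit operator float() const { return value; }
};

// Merge several function objects into a single overload set.
template <class... Fs>
struct overload_set : Fs...
{
    using Fs::operator()...;
};
template <class... Fs>
overload_set(Fs...) -> overload_set<Fs...>;

// "(double)::cos" becomes "double, ::cos": the parenthesized type is consumed
// as a macro argument and re-emitted followed by a comma.
#define SKETCH_SPLIT(...) __VA_ARGS__,
// Lift a function into a lambda whose parameter type is fixed (not templated),
// so an exact match can still be picked by overload resolution.
#define SKETCH_LIFT_TYPED(type, fname) [](type x) -> type { return fname(x); }
#define SKETCH_LIFT_EXPAND(...) SKETCH_LIFT_TYPED(__VA_ARGS__)
#define SKETCH_LIFT(typed_fn) SKETCH_LIFT_EXPAND(SKETCH_SPLIT typed_fn)

// Assumed dispatch rules, mirroring the description above:
//  - integers are computed in double
//  - types the overload set accepts directly are passed through
//  - anything else is converted to float and back
template <class F, class T>
auto wrap(F f, T x)
{
    if constexpr(std::is_integral_v<T>)
        return f(static_cast<double>(x));
    else if constexpr(std::is_invocable_v<F, T>)
        return f(x);
    else
        return T(f(static_cast<float>(x)));
}

// Overload set for cos built from typed lifts.
const auto cos_overloads =
    overload_set{SKETCH_LIFT((double)::cos), SKETCH_LIFT((float)::cosf)};

} // namespace sketch
```

With this sketch, `sketch::wrap(sketch::cos_overloads, 0.0f)` stays in `float`, `sketch::wrap(sketch::cos_overloads, 0)` is computed in `double`, and a `half_like` argument goes through `float` and is converted back. Calling `cos_overloads(0)` directly would be ambiguous, which is why the integer case is handled before the overload set is consulted.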