Refactor GPU math functions #3657

Merged: 17 commits merged from gpu-math-wrap into develop on Dec 4, 2024
Conversation

@pfultz2 (Collaborator) commented Nov 25, 2024

This creates a generic math::wrap function to handle overloads of the math functions. It will:

  • Use double for integer parameters
  • Fall back to float when there isn't an overload (i.e., for low-precision floats), so we don't need to define a bunch of macros to handle these cases (a sketch of this dispatch follows the list)
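
A minimal sketch of this dispatch, assuming C++17 (the free function `wrap` and its exact shape are hypothetical, not the actual migraphx::math::wrap definition):

```cpp
#include <cmath>
#include <type_traits>

// Hypothetical sketch: promote integers to double, call the exact overload
// when one exists, and otherwise compute in float and convert back.
template <class F, class T>
auto wrap(F f, T x)
{
    if constexpr (std::is_integral<T>{})
        return f(static_cast<double>(x)); // integer parameters compute in double
    else if constexpr (std::is_invocable<F, T>{})
        return f(x); // an overload for T exists, so use it directly
    else
        return static_cast<T>(f(static_cast<float>(x))); // low-precision fallback
}
```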

There is a MIGRAPHX_DEVICE_MATH_WRAP macro to set up the overloads; it passes the overloads for each of the math functions. There is no simple way to take a function pointer to a math function, since they are overloaded for host and device, so we need to "lift" them into a function object (see BOOST_HOF_LIFT). We have the MIGRAPHX_LIFT macro, which can do that, but it loses the type information (since all the parameters are templated), so the compiler can no longer figure out which overload is the best match: when a float is passed with overloads for float and double, the templated lifts make every overload look like an exact match, so the compiler can't see that the float overload is the better one.
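
To make the ambiguity concrete, here is a simplified, hypothetical illustration of what happens when each overload is lifted with fully templated parameters:

```cpp
#include <cmath>

// Each lift deduces its own T, so the original parameter types are erased.
struct lift_cos  { template <class T> auto operator()(T x) const { return ::cos(x); } };
struct lift_cosf { template <class T> auto operator()(T x) const { return ::cosf(x); } };

struct cos_overloads : lift_cos, lift_cosf
{
    using lift_cos::operator();
    using lift_cosf::operator();
};

// cos_overloads{}(1.0f); // error: ambiguous -- both templates deduce T = float
// and look like equally exact matches, so the float overload is not preferred.
```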

So instead we define MIGRAPHX_DEVICE_MATH_LIFT to lift the function; it takes a type as a macro parameter and defines the function with that type. Since we need to pass the type, we can't just pass the functions directly (such as MIGRAPHX_DEVICE_MATH_WRAP(cos, ::cos, ::cosf, ::hcos)). Instead we pass the type in parentheses preceding the function name: MIGRAPHX_DEVICE_MATH_WRAP(cos, (double)::cos, (float)::cosf, (half)::hcos). The macro unwraps the type from the parentheses and passes it to MIGRAPHX_DEVICE_MATH_LIFT.
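
Roughly, the typed lift generates something like the following (a simplified sketch; the real macro output differs):

```cpp
#include <cmath>

// The parameter types are baked in, so ordinary overload resolution applies.
struct wrapped_cos
{
    double operator()(double x) const { return ::cos(x); }
    float  operator()(float x) const  { return ::cosf(x); }
};

// wrapped_cos{}(1.0f) now unambiguously selects the float overload (::cosf).
```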

I also updated the abs function to use a generic version for unsupported types (similar to min/max), since it's probably faster than converting to double/float and back.
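
The generic version is presumably along these lines (an assumed shape, analogous to the generic min/max; it requires only that T supports comparison and negation):

```cpp
// Compute abs in the type itself instead of round-tripping through double/float.
template <class T>
constexpr T generic_abs(T x)
{
    return x < 0 ? -x : x;
}
```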

The tests have also been updated to check that there is a float overload for each function, since we were missing several float overloads. There might be a small perf boost from this, since we will no longer be using double precision for these functions.
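
One way such a check can be expressed, reusing the hypothetical `wrap` and `wrapped_cos` sketches above (not the actual test code): a float input should yield a float result, proving the float overload rather than a double fallback was selected.

```cpp
#include <type_traits>

static_assert(std::is_same<float, decltype(wrap(wrapped_cos{}, 1.0f))>{},
              "expected the float overload to be selected for float inputs");
```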

@pfultz2 requested a review from causten as a code owner on November 25, 2024 23:19

codecov bot commented Nov 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.20%. Comparing base (6b886e3) to head (b05b10f).
Report is 1 commit behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3657   +/-   ##
========================================
  Coverage    92.20%   92.20%           
========================================
  Files          513      513           
  Lines        21658    21658           
========================================
  Hits         19970    19970           
  Misses        1688     1688           


@TedThemistokleous added the Cleanup label (cleans up code from stale bits/warnings/previous changes for a previous feature PR) Nov 27, 2024
@migraphx-bot (Collaborator) commented:

| Test | Batch | Rate new (841b21) | Rate old (241e24) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,256.73 | 3,253.39 | 0.10% |
| torchvision-resnet50_fp16 | 64 | 6,996.50 | 6,971.83 | 0.35% |
| torchvision-densenet121 | 32 | 2,435.09 | 2,432.57 | 0.10% |
| torchvision-densenet121_fp16 | 32 | 4,088.35 | 4,102.36 | -0.34% |
| torchvision-inceptionv3 | 32 | 1,630.51 | 1,627.63 | 0.18% |
| torchvision-inceptionv3_fp16 | 32 | 2,744.73 | 2,741.25 | 0.13% |
| cadene-inceptionv4 | 16 | 764.47 | 763.96 | 0.07% |
| cadene-resnext64x4 | 16 | 810.62 | 809.72 | 0.11% |
| slim-mobilenet | 64 | 7,469.79 | 7,381.84 | 1.19% |
| slim-nasnetalarge | 64 | 208.41 | 208.40 | 0.00% |
| slim-resnet50v2 | 64 | 3,440.45 | 3,439.64 | 0.02% |
| bert-mrpc-onnx | 8 | 1,145.22 | 1,145.80 | -0.05% |
| bert-mrpc-tf | 1 | 462.26 | 492.93 | -6.22% 🔴 |
| pytorch-examples-wlang-gru | 1 | 420.33 | 426.82 | -1.52% |
| pytorch-examples-wlang-lstm | 1 | 385.67 | 392.81 | -1.82% |
| torchvision-resnet50_1 | 1 | 790.29 | 777.08 | 1.70% |
| cadene-dpn92_1 | 1 | 403.33 | 429.78 | -6.15% 🔴 |
| cadene-resnext101_1 | 1 | 382.32 | 383.15 | -0.22% |
| onnx-taau-downsample | 1 | 345.51 | 346.23 | -0.21% |
| dlrm-criteoterabyte | 1 | 33.33 | 33.33 | -0.01% |
| dlrm-criteoterabyte_fp16 | 1 | 52.74 | 52.75 | -0.02% |
| agentmodel | 1 | 8,409.86 | 8,083.75 | 4.03% 🔆 |
| unet_fp16 | 2 | 58.73 | 58.78 | -0.08% |
| resnet50v1_fp16 | 1 | 1,003.05 | 943.59 | 6.30% 🔆 |
| resnet50v1_int8 | 1 | 1,008.59 | 1,004.54 | 0.40% |
| bert_base_cased_fp16 | 64 | 1,168.52 | 1,169.61 | -0.09% |
| bert_large_uncased_fp16 | 32 | 363.02 | 363.60 | -0.16% |
| bert_large_fp16 | 1 | 200.08 | 200.33 | -0.12% |
| distilgpt2_fp16 | 16 | 2,200.06 | 2,196.55 | 0.16% |
| yolov5s | 1 | 537.45 | 526.38 | 2.10% |
| tinyllama | 1 | 43.39 | 43.64 | -0.58% |
| vicuna-fastchat | 1 | 174.88 | 173.50 | 0.79% |
| whisper-tiny-encoder | 1 | 418.10 | 417.83 | 0.07% |
| whisper-tiny-decoder | 1 | 435.65 | 427.29 | 1.96% |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator) commented:


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@TedThemistokleous (Collaborator) left a comment:


Appreciate the comments you added, especially with half2.

@causten merged commit 871fd56 into develop on Dec 4, 2024
32 of 36 checks passed
@causten deleted the gpu-math-wrap branch on December 4, 2024 16:42
Labels: Cleanup (cleans up code from stale bits/warnings/previous changes for a previous feature PR), Perf Improve
7 participants