
bit_cast operator #3655

Open · wants to merge 10 commits into develop

Conversation

CharlieL7 (Collaborator)

  • Converts between types while keeping the underlying bits the same; essentially an extension of reinterpret_cast without the undefined behavior.
  • Needed for the conversion pass between fp8e4m3fn and fp8e4m3fnuz.

@CharlieL7 CharlieL7 self-assigned this Nov 22, 2024
@CharlieL7 CharlieL7 marked this pull request as ready for review November 25, 2024 20:40
typename From,
MIGRAPHX_REQUIRES(not is_any_vec<To>()),
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr auto bit_cast(From fr) noexcept
Collaborator:

There is no need for this overload, since vec_transform works with non-vector types as well.

Collaborator Author:

It doesn't work properly with the usage of bit_cast in float8_impl.

Collaborator:

We shouldn't be vectorizing fp8 types.

Collaborator:

Also, vec_transform already does the same check for is_any_vec and just calls the function directly, so I don't see how that would cause an error with fp8 types.

MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
Collaborator:

This function should return auto. To is not a valid return type with vector types: if you do bit_cast<int8_t>(vec<uint8_t, 2>{}) it should return vec<int8_t, 2>{}, not int8_t. The vec_transform functor already takes care of deducing the return type for you, so you can just return auto instead.

Collaborator Author:

Same issue: when compiling float8 for the GPU, this version is used.

Collaborator:

Same issue? The return type is wrong here: if you bit_cast a vector type you should get a vector type back, not the scalar type.

Collaborator Author:

The issue comes from the usage here:

f_inf = migraphx::bit_cast<float>(if_inf);

Collaborator:

That takes a scalar input, not a vector input.


codecov bot commented Nov 25, 2024

Codecov Report

Attention: Patch coverage is 85.00000% with 3 lines in your changes missing coverage. Please review.

Project coverage is 92.17%. Comparing base (cd37826) to head (7b40796).

Files with missing lines Patch % Lines
src/include/migraphx/op/bit_cast.hpp 85.00% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3655      +/-   ##
===========================================
- Coverage    92.18%   92.17%   -0.01%     
===========================================
  Files          513      514       +1     
  Lines        21596    21616      +20     
===========================================
+ Hits         19908    19925      +17     
- Misses        1688     1691       +3     


@migraphx-bot
Collaborator

| Test | Batch | Rate new (7b4079) | Rate old (bd6006) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,255.49 | 3,255.03 | 0.01% |
| torchvision-resnet50_fp16 | 64 | 6,982.91 | 6,995.38 | -0.18% |
| torchvision-densenet121 | 32 | 2,433.33 | 2,434.81 | -0.06% |
| torchvision-densenet121_fp16 | 32 | 4,041.82 | 4,064.36 | -0.55% |
| torchvision-inceptionv3 | 32 | 1,627.50 | 1,628.42 | -0.06% |
| torchvision-inceptionv3_fp16 | 32 | 2,743.18 | 2,747.11 | -0.14% |
| cadene-inceptionv4 | 16 | 764.41 | 765.07 | -0.09% |
| cadene-resnext64x4 | 16 | 809.18 | 809.98 | -0.10% |
| slim-mobilenet | 64 | 7,470.54 | 7,462.14 | 0.11% |
| slim-nasnetalarge | 64 | 208.46 | 208.42 | 0.02% |
| slim-resnet50v2 | 64 | 3,439.79 | 3,441.19 | -0.04% |
| bert-mrpc-onnx | 8 | 1,147.16 | 1,143.39 | 0.33% |
| bert-mrpc-tf | 1 | 460.98 | 468.19 | -1.54% |
| pytorch-examples-wlang-gru | 1 | 430.83 | 414.07 | 4.05% 🔆 |
| pytorch-examples-wlang-lstm | 1 | 385.83 | 389.56 | -0.96% |
| torchvision-resnet50_1 | 1 | 779.17 | 774.11 | 0.65% |
| cadene-dpn92_1 | 1 | 401.94 | 415.75 | -3.32% 🔴 |
| cadene-resnext101_1 | 1 | 382.87 | 383.64 | -0.20% |
| onnx-taau-downsample | 1 | 345.90 | 345.07 | 0.24% |
| dlrm-criteoterabyte | 1 | 33.32 | 33.35 | -0.09% |
| dlrm-criteoterabyte_fp16 | 1 | 52.70 | 52.74 | -0.06% |
| agentmodel | 1 | 8,135.97 | 9,248.55 | -12.03% 🔴 |
| unet_fp16 | 2 | 58.79 | 58.82 | -0.06% |
| resnet50v1_fp16 | 1 | 956.46 | 945.33 | 1.18% |
| resnet50v1_int8 | 1 | 1,031.88 | 1,023.31 | 0.84% |
| bert_base_cased_fp16 | 64 | 1,170.22 | 1,170.96 | -0.06% |
| bert_large_uncased_fp16 | 32 | 363.07 | 363.29 | -0.06% |
| bert_large_fp16 | 1 | 198.54 | 198.28 | 0.13% |
| distilgpt2_fp16 | 16 | 2,201.80 | 2,196.36 | 0.25% |
| yolov5s | 1 | 529.86 | 528.08 | 0.34% |
| tinyllama | 1 | 43.65 | 43.42 | 0.53% |
| vicuna-fastchat | 1 | 174.01 | 174.92 | -0.52% |
| whisper-tiny-encoder | 1 | 418.07 | 417.89 | 0.04% |
| whisper-tiny-decoder | 1 | 436.02 | 427.86 | 1.91% |

This build is not recommended to merge 🔴

@migraphx-bot
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@pfultz2
Collaborator

pfultz2 commented Nov 26, 2024

What exact error are you getting when using only one overload for bit_cast, like this?

template <typename To,
          typename From,
          MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr auto bit_cast(From fr) noexcept
{
    return vec_transform(fr)([](auto x) -> To {
        static_assert(sizeof(To) == sizeof(decltype(x)));
        return __builtin_bit_cast(To, x);
    });
}
