diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 334c9615f68..eda6ea626e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -145,6 +145,7 @@ register_migraphx_ops( as_shape atanh atan + bit_cast bitwise_and broadcast broadcast_for_dot diff --git a/src/include/migraphx/op/bit_cast.hpp b/src/include/migraphx/op/bit_cast.hpp new file mode 100644 index 00000000000..eb233ad8b36 --- /dev/null +++ b/src/include/migraphx/op/bit_cast.hpp @@ -0,0 +1,104 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP +#define MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +/** + * Obtain a value of type `target_type` by reinterpreting + * the object represnetaion of the input. Originally used + * for casting from fp8e4m3fn to fp8e4m3fnuz. + */ +struct bit_cast : unary +{ + shape::type_t target_type; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.target_type, "target_type")); + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1); + auto input = inputs.at(0); + std::size_t target_type_size; + shape::visit(target_type, [&](auto as) { target_type_size = as.size(); }); + if(input.type_size() != target_type_size) + { + MIGRAPHX_THROW("BIT_CAST: target_type has different type_size from input's"); + } + if(input.dynamic()) + { + return {target_type, input.dyn_dims()}; + } + else + { + return {target_type, input.lens(), input.strides()}; + } + } + + std::string point_op() const + { + return "${function:bit_cast}<" + shape::cpp_type(target_type) + ">(${0})"; + } + + argument compute(const dyn_output& dyn_out, std::vector args) const + { + argument result{dyn_out.computed_shape}; + result.visit([&](auto output) { + using otype = typename decltype(output)::value_type; + args[0].visit([&](auto input) { + using itype = typename decltype(input)::value_type; + if constexpr(sizeof(otype) == sizeof(itype)) + { + par_transform(input.begin(), input.end(), output.begin(), [&](auto x) { + return migraphx::bit_cast(x); + }); + } + else + { + // not possible to hit this unless somehow the types change after compute_shape + // is called + MIGRAPHX_THROW("BIT_CAST: type size mismatch"); + } + }); + }); + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index c98395bbe10..e559658a004 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -23,15 +23,20 @@ #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #include +#include namespace migraphx { + template {} and is_trivially_copyable{})> -inline constexpr To bit_cast(From fr) noexcept +inline constexpr auto bit_cast(From fr) noexcept { - static_assert(sizeof(To) == sizeof(From)); - return __builtin_bit_cast(To, fr); + return vec_transform(fr)([](auto x) -> To { + static_assert(sizeof(To) == sizeof(decltype(x))); + return __builtin_bit_cast(To, x); + }); } + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 8d08455d814..66d54c8a460 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -201,6 +201,21 @@ TEST_CASE(binary_dyn_static_error) throws_shape(migraphx::make_op("add"), a_shape, b_shape); } +TEST_CASE(bit_cast_typesize_mismatch) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {1, 4, 4}}; + throws_shape(migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::int32_type}}), + a_shape); +} + +TEST_CASE(bit_cast_dyn) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {{1, 1}, {4, 8}, {4, 8}}}; + expect_shape(migraphx::shape{migraphx::shape::uint8_type, {{1, 1}, {4, 8}, {4, 8}}}, + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), + a_shape); +} + TEST_CASE(bitwise_and_not_integral_error) { migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}}; diff --git a/test/ref/bit_cast.cpp b/test/ref/bit_cast.cpp new file mode 100644 index 00000000000..4f9438ef4fd --- /dev/null +++ b/test/ref/bit_cast.cpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(bit_cast_fp8) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::fp8e4m3fn_type, {2, 2}}; + std::vector data; + data.push_back(fp8e4m3fn{26.0f}); + data.push_back(fp8e4m3fn{3.0f}); + data.push_back(fp8e4m3fn{96.0f}); + data.push_back(fp8e4m3fn{-1.25f}); + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold; + gold.push_back(fp8e4m3fnuz{13.0f}); + gold.push_back(fp8e4m3fnuz{1.5f}); + gold.push_back(fp8e4m3fnuz{48.0f}); + gold.push_back(fp8e4m3fnuz{-0.625f}); + EXPECT(results_vector == gold); +} + +TEST_CASE(bit_cast_uint8) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + std::vector data = {23, -3, 0, -1}; + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {23, 253, 0, 255}; + EXPECT(results_vector == gold); +} diff --git a/test/verify/main.cpp b/test/verify/main.cpp index 5daa8a858d6..876db639644 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -129,14 +129,18 @@ int main(int argc, const char* argv[]) "test_block_reduce_small<67, migraphx::shape::int8_type>", "test_block_reduce_small<128, migraphx::shape::int8_type>", "test_block_reduce_small<129, migraphx::shape::int8_type>", + // disabled because CPU does eliminate_data_type to float for everything "test_bitwise_and", "test_bitwise_and", - "test_unpack_int4", "test_unpack_int4", "test_unpack_int4", - "test_unpack_int4"}); + "test_unpack_int4", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast"}); rv.disable_test_for("gpu", { // These passes on MI300 but fails on others, same issue as CPU. diff --git a/test/verify/test_bit_cast.cpp b/test/verify/test_bit_cast.cpp new file mode 100644 index 00000000000..24f9a7fc745 --- /dev/null +++ b/test/verify/test_bit_cast.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +template +struct test_bit_cast : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{From, {8}}; + auto pa = mm->add_parameter("a", s); + auto pb = mm->add_parameter("b", s); + auto ia = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa); + auto ib = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb); + auto ret = mm->add_instruction(migraphx::make_op("add"), ia, ib); + mm->add_return({ret}); + return p; + }; +}; + +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast;