Skip to content

Commit

Permalink
bit_cast operator (#3655)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Dec 4, 2024
1 parent 88327d7 commit dde7986
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ register_migraphx_ops(
as_shape
atanh
atan
bit_cast
bitwise_and
broadcast
broadcast_for_dot
Expand Down
104 changes: 104 additions & 0 deletions src/include/migraphx/op/bit_cast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP

#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
* Obtain a value of type `target_type` by reinterpreting
* the object represnetaion of the input. Originally used
* for casting from fp8e4m3fn to fp8e4m3fnuz.
*/
struct bit_cast : unary<bit_cast>
{
shape::type_t target_type;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.target_type, "target_type"));
}

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0);
std::size_t target_type_size;
shape::visit(target_type, [&](auto as) { target_type_size = as.size(); });
if(input.type_size() != target_type_size)
{
MIGRAPHX_THROW("BIT_CAST: target_type has different type_size from input's");
}
if(input.dynamic())
{
return {target_type, input.dyn_dims()};
}
else
{
return {target_type, input.lens(), input.strides()};
}
}

std::string point_op() const
{
return "${function:bit_cast}<" + shape::cpp_type(target_type) + ">(${0})";
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
result.visit([&](auto output) {
using otype = typename decltype(output)::value_type;
args[0].visit([&](auto input) {
using itype = typename decltype(input)::value_type;
if constexpr(sizeof(otype) == sizeof(itype))
{
par_transform(input.begin(), input.end(), output.begin(), [&](auto x) {
return migraphx::bit_cast<otype>(x);
});
}
else
{
// not possible to hit this unless somehow the types change after compute_shape
// is called
MIGRAPHX_THROW("BIT_CAST: type size mismatch");
}
});
});
return result;
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
11 changes: 8 additions & 3 deletions src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP

#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/vec.hpp>

namespace migraphx {

template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
inline constexpr auto bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
return __builtin_bit_cast(To, fr);
return vec_transform(fr)([](auto x) -> To {
static_assert(sizeof(To) == sizeof(decltype(x)));
return __builtin_bit_cast(To, x);
});
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
15 changes: 15 additions & 0 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,21 @@ TEST_CASE(binary_dyn_static_error)
throws_shape(migraphx::make_op("add"), a_shape, b_shape);
}

TEST_CASE(bit_cast_typesize_mismatch)
{
migraphx::shape a_shape{migraphx::shape::int8_type, {1, 4, 4}};
throws_shape(migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::int32_type}}),
a_shape);
}

TEST_CASE(bit_cast_dyn)
{
migraphx::shape a_shape{migraphx::shape::int8_type, {{1, 1}, {4, 8}, {4, 8}}};
expect_shape(migraphx::shape{migraphx::shape::uint8_type, {{1, 1}, {4, 8}, {4, 8}}},
migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}),
a_shape);
}

TEST_CASE(bitwise_and_not_integral_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
Expand Down
75 changes: 75 additions & 0 deletions test/ref/bit_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>

#include <test.hpp>

TEST_CASE(bit_cast_fp8)
{
using migraphx::fp8::fp8e4m3fn;
using migraphx::fp8::fp8e4m3fnuz;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::fp8e4m3fn_type, {2, 2}};
std::vector<fp8e4m3fn> data;
data.push_back(fp8e4m3fn{26.0f});
data.push_back(fp8e4m3fn{3.0f});
data.push_back(fp8e4m3fn{96.0f});
data.push_back(fp8e4m3fn{-1.25f});
auto lit = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(
migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), lit);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<fp8e4m3fnuz> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<fp8e4m3fnuz> gold;
gold.push_back(fp8e4m3fnuz{13.0f});
gold.push_back(fp8e4m3fnuz{1.5f});
gold.push_back(fp8e4m3fnuz{48.0f});
gold.push_back(fp8e4m3fnuz{-0.625f});
EXPECT(results_vector == gold);
}

TEST_CASE(bit_cast_uint8)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
std::vector<int8_t> data = {23, -3, 0, -1};
auto lit = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(
migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), lit);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint8_t> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold = {23, 253, 0, 255};
EXPECT(results_vector == gold);
}
8 changes: 6 additions & 2 deletions test/verify/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ int main(int argc, const char* argv[])
"test_block_reduce_small<67, migraphx::shape::int8_type>",
"test_block_reduce_small<128, migraphx::shape::int8_type>",
"test_block_reduce_small<129, migraphx::shape::int8_type>",

// disabled because CPU does eliminate_data_type to float for everything
"test_bitwise_and<migraphx::shape::int32_type>",
"test_bitwise_and<migraphx::shape::uint8_type>",

"test_unpack_int4<migraphx::shape::uint8_type>",
"test_unpack_int4<migraphx::shape::int8_type>",
"test_unpack_int4<migraphx::shape::uint8_type, 0>",
"test_unpack_int4<migraphx::shape::int8_type, 0>"});
"test_unpack_int4<migraphx::shape::int8_type, 0>",
"test_bit_cast<migraphx::shape::uint8_type, migraphx::shape::int8_type>",
"test_bit_cast<migraphx::shape::int8_type, migraphx::shape::uint8_type>",
"test_bit_cast<migraphx::shape::fp8e4m3fn_type, migraphx::shape::fp8e4m3fnuz_type>",
"test_bit_cast<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::fp8e4m3fn_type>"});
rv.disable_test_for("gpu",
{
// These passes on MI300 but fails on others, same issue as CPU.
Expand Down
55 changes: 55 additions & 0 deletions test/verify/test_bit_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>

#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t From, migraphx::shape::type_t To>
struct test_bit_cast : verify_program<test_bit_cast<From, To>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{From, {8}};
auto pa = mm->add_parameter("a", s);
auto pb = mm->add_parameter("b", s);
auto ia = mm->add_instruction(
migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa);
auto ib = mm->add_instruction(
migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb);
auto ret = mm->add_instruction(migraphx::make_op("add"), ia, ib);
mm->add_return({ret});
return p;
};
};

template struct test_bit_cast<migraphx::shape::uint8_type, migraphx::shape::int8_type>;
template struct test_bit_cast<migraphx::shape::int8_type, migraphx::shape::uint8_type>;
template struct test_bit_cast<migraphx::shape::fp8e4m3fn_type, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_bit_cast<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::fp8e4m3fn_type>;

0 comments on commit dde7986

Please sign in to comment.