From 5d6eeba0cbc3f35df32f54fe4d3cec2333372bdb Mon Sep 17 00:00:00 2001 From: richagadgil Date: Tue, 10 Dec 2024 18:33:54 -0600 Subject: [PATCH] working float equals --- .../device/include/migraphx/gpu/device/float_equal.hpp | 10 ++++++++++ .../gpu/device/include/migraphx/gpu/device/types.hpp | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp index 9fb6f858d18..392083c2646 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp @@ -44,6 +44,16 @@ __device__ bool float_equal_device(T x, T y) std::nextafter(x, std::numeric_limits::max()) >= y; } +template <> +__device__ bool float_equal_device(__bf16 x, __bf16 y) +{ + float xf = static_cast(x); + float yf = static_cast(y); + return std::isfinite(xf) and std::isfinite(yf) and + std::nextafter(xf, std::numeric_limits::lowest()) <= yf and + std::nextafter(xf, std::numeric_limits::max()) >= yf; +} + template {})> __device__ bool float_equal_device(T x, T y) { diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp index a2b5bc56f46..b6ed5a7786e 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp @@ -199,7 +199,7 @@ struct is_arithmetic<__fp16> : std::true_type {}; // Redo for __bf16 template <> -struct is_floating_point<__bf16> : std::false_type {}; +struct is_floating_point<__bf16> : std::true_type {}; template <> struct is_signed<__bf16> : std::true_type {}; template <>