Skip to content

Commit

Permalink
[SYCL] Fix ambiguity for bit_cast (intel#9398)
Browse files Browse the repository at this point in the history
After merging
kokkos/kokkos@ab41ef8
which implements `Kokkos::bit_cast`, we (@dalg24 and me) are seeing
ambiguous calls to `bit_cast` originating from calls from
https://github.com/intel/llvm/blob/28113ec691679dc0316a54c0da453014ff68a2c3/sycl/include/sycl/detail/spirv.hpp#L926-L937.

The call to `bit_cast` there is unqualified, and passes an object in the
`Kokkos` namespace which means that ADL kicks in. The compiler then
finds both Kokkos::bit_cast and sycl::bit_cast. This problem was likely
introduced by
intel@d4b66bd#diff-dad3832b9c0831f4d9d2fa5695efc8c44f58d3fb0c97716af6bfd4eea882eb2fR631.
To fix it, the calls should be qualified as proposed in the first
commit.

While working on a workaround in `Kokkos`, I also noticed that
`sycl::bit_cast` is unconstrained and always participates in overload
resolution although `std::bit_cast` should only participate in overload
resolution only if `sizeof(To) == sizeof(From)` and both `To` and `From`
are
[TriviallyCopyable](https://en.cppreference.com/w/cpp/named_req/TriviallyCopyable)
types, see, e.g., https://en.cppreference.com/w/cpp/numeric/bit_cast.
Thus, the second commit moves the `static_assert`s to template
constraints in terms of SFINAE.
  • Loading branch information
masterleinad authored May 12, 2023
1 parent fb38046 commit 834df47
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 27 deletions.
2 changes: 1 addition & 1 deletion sycl/include/sycl/atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class __SYCL2020_DEPRECATED(
cl_int, addressSpace, access::decorated::yes>::pointer>(Ptr);
cl_int TmpVal = __spirv_AtomicLoad(
TmpPtr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order));
cl_float ResVal = bit_cast<cl_float>(TmpVal);
cl_float ResVal = sycl::bit_cast<cl_float>(TmpVal);
return ResVal;
}
#else
Expand Down
11 changes: 4 additions & 7 deletions sycl/include/sycl/bit_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,11 @@ template <typename To, typename From>
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
constexpr
#endif
To
std::enable_if_t<sizeof(To) == sizeof(From) &&
std::is_trivially_copyable<From>::value &&
std::is_trivially_copyable<To>::value,
To>
bit_cast(const From &from) noexcept {
static_assert(sizeof(To) == sizeof(From),
"Sizes of To and From must be equal");
static_assert(std::is_trivially_copyable<From>::value,
"From must be trivially copyable");
static_assert(std::is_trivially_copyable<To>::value,
"To must be trivially copyable");
#if __cpp_lib_bit_cast
return std::bit_cast<To>(from);
#else // __cpp_lib_bit_cast
Expand Down
38 changes: 19 additions & 19 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@ GroupBroadcast(const ext::oneapi::experimental::opportunistic_group &g, T x,
template <typename Group, typename T, typename IdT>
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
auto BroadcastX = bit_cast<BroadcastT>(x);
auto BroadcastX = sycl::bit_cast<BroadcastT>(x);
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}
template <typename Group, typename T, typename IdT>
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
Expand Down Expand Up @@ -406,9 +406,9 @@ template <typename Group, typename T, int Dimensions>
EnableIfBitcastBroadcast<T> GroupBroadcast(Group g, T x,
id<Dimensions> local_id) {
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
auto BroadcastX = bit_cast<BroadcastT>(x);
auto BroadcastX = sycl::bit_cast<BroadcastT>(x);
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}
template <typename Group, typename T, int Dimensions>
EnableIfGenericBroadcast<T> GroupBroadcast(Group g, T x,
Expand Down Expand Up @@ -502,11 +502,11 @@ AtomicCompareExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
auto SPIRVFailure = getMemorySemanticsMask(Failure);
auto SPIRVScope = getScope(Scope);
auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
I DesiredInt = bit_cast<I>(Desired);
I ExpectedInt = bit_cast<I>(Expected);
I DesiredInt = sycl::bit_cast<I>(Desired);
I ExpectedInt = sycl::bit_cast<I>(Expected);
I ResultInt = __spirv_AtomicCompareExchange(
PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
return bit_cast<T>(ResultInt);
return sycl::bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace,
Expand All @@ -530,7 +530,7 @@ AtomicLoad(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
return bit_cast<T>(ResultInt);
return sycl::bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace,
Expand All @@ -553,7 +553,7 @@ AtomicStore(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ValueInt = bit_cast<I>(Value);
I ValueInt = sycl::bit_cast<I>(Value);
__spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
}

Expand All @@ -577,10 +577,10 @@ AtomicExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ValueInt = bit_cast<I>(Value);
I ValueInt = sycl::bit_cast<I>(Value);
I ResultInt =
__spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
return bit_cast<T>(ResultInt);
return sycl::bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace,
Expand Down Expand Up @@ -898,54 +898,54 @@ using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = bit_cast<ShuffleT>(x);
auto ShuffleX = sycl::bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleINTEL(
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = bit_cast<ShuffleT>(x);
auto ShuffleX = sycl::bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = bit_cast<ShuffleT>(x);
auto ShuffleX = sycl::bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
#else
ShuffleT Result =
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
#endif
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = bit_cast<ShuffleT>(x);
auto ShuffleX = sycl::bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
#else
ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
#endif
return bit_cast<T>(Result);
return sycl::bit_cast<T>(Result);
}

template <typename T>
Expand Down

0 comments on commit 834df47

Please sign in to comment.