From ec0130487be8d2af96b6ccac10677219a1efe5a4 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:23:46 +0800 Subject: [PATCH 1/2] Implement Aten::_foreach_norm when ord == inf (#908) Implementing https://github.com/intel/torch-xpu-ops/issues/710 --------- Co-authored-by: Yutao Xu --- src/ATen/native/xpu/ForeachReduceOp.cpp | 3 +- .../native/xpu/sycl/ForeachReduceKernels.cpp | 302 +++++++++++++----- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 42 +++ src/ATen/native/xpu/sycl/MultiTensorApply.h | 37 +-- 4 files changed, 282 insertions(+), 102 deletions(-) diff --git a/src/ATen/native/xpu/ForeachReduceOp.cpp b/src/ATen/native/xpu/ForeachReduceOp.cpp index 003f6ae14..a60eb600c 100644 --- a/src/ATen/native/xpu/ForeachReduceOp.cpp +++ b/src/ATen/native/xpu/ForeachReduceOp.cpp @@ -61,7 +61,8 @@ std::vector foreach_tensor_norm_xpu( at::isComplexType(scalar_type); }); if (!at::native::can_use_fast_route(tensors) || has_int_or_complex || - !(p == static_cast(1) || p == static_cast(2))) { + !(p == static_cast(1) || p == static_cast(2) || + p == std::numeric_limits::infinity())) { return at::native::foreach_tensor_norm_slow(tensors, ord, dtype); } check_foreach_norm_dtype( diff --git a/src/ATen/native/xpu/sycl/ForeachReduceKernels.cpp b/src/ATen/native/xpu/sycl/ForeachReduceKernels.cpp index cc90fa893..9e982e4a0 100644 --- a/src/ATen/native/xpu/sycl/ForeachReduceKernels.cpp +++ b/src/ATen/native/xpu/sycl/ForeachReduceKernels.cpp @@ -9,22 +9,21 @@ #include -enum class NormType { L1, L2 }; - +enum class NormType { L1, L2, LInf }; +#define SIMD16 16 +#define SIMD32 32 namespace at::native::xpu { template < typename T, NormType norm_type, typename opmath_t, + int SIMD, int depth = 1, int r_args_depth = 1, int res_arg_index = 0> -struct LpNormFunctor { - static_assert( - norm_type == NormType::L1 || norm_type == NormType::L2, - "foreach_norm supports only L1 and L2 norm"); +struct LpNormFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { template - void operator()( + [[intel::reqd_sub_group_size(SIMD)]] void operator()( const int64_t chunk_size, TLA tlAddress, TLW tlWGMeta, @@ -57,9 +56,13 @@ struct LpNormFunctor { #pragma unroll for (int ii = 0; ii < kILP; ii++) { opmath_t next = static_cast(r_x[ii]); - vals[ii] += norm_type == NormType::L1 - ? static_cast(std::fabs((opmath_t)next)) - : static_cast(next * next); + if constexpr (norm_type == NormType::LInf) { + vals[ii] = max_impl(vals[ii], std::fabs((opmath_t)next)); + } else { + vals[ii] += norm_type == NormType::L1 + ? static_cast(std::fabs((opmath_t)next)) + : static_cast(next * next); + } } } } else { @@ -70,9 +73,13 @@ struct LpNormFunctor { int i = i_start + item_idx + ii * item_range; if (i < n && i < chunk_size) { opmath_t next = static_cast(x[i]); - vals[ii] += norm_type == NormType::L1 - ? static_cast(std::fabs((opmath_t)next)) - : static_cast(next * next); + if constexpr (norm_type == NormType::LInf) { + vals[ii] = max_impl(vals[ii], ::abs(std::fabs((opmath_t)next))); + } else { + vals[ii] += norm_type == NormType::L1 + ? 
static_cast(std::fabs((opmath_t)next)) + : static_cast(next * next); + } } } } @@ -80,22 +87,35 @@ struct LpNormFunctor { auto val = opmath_t(0); for (int i = 0; i < kILP; i++) { - val += vals[i]; + if constexpr (norm_type == NormType::LInf) { + val = max_impl(val, vals[i]); + } else { + val += vals[i]; + } } - auto sum_val = sycl::reduce_over_group( - item_id.get_group(), val, sycl::plus()); + auto sum_val = norm_type == NormType::L1 || norm_type == NormType::L2 + ? GroupReduceSumWithoutBroadcast(item_id, val, shared_) + : GroupReduceMaxWithoutBroadcast(item_id, val, shared_); if (item_idx == 0) { output_per_tensor[tensor_loc * max_chunks_per_tensor + chunk_idx] = sum_val; } } + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = + sycl_local_acc_t(get_group_reduce_group_size(SIMD), cgh); + } + + private: + sycl_local_acc_t shared_; }; -template -struct lpnormChunkReduceKernelFunctor { - void operator()(sycl::nd_item<1> item_id) const { +template +struct lpnormChunkReduceKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<1> item_id) const { auto lid = item_id.get_local_linear_id(); auto group_id = item_id.get_group(0); @@ -103,15 +123,26 @@ struct lpnormChunkReduceKernelFunctor { output_per_tensor_ + group_id * max_chunks_per_tensor_; opmath_t val = 0; for (int i = lid; i < max_chunks_per_tensor_; i += wg_size_) { - val += output_this_tensor[i]; + if constexpr (norm_type == NormType::LInf) { + val = max_impl(val, output_this_tensor[i]); + } else { + val += output_this_tensor[i]; + } } - auto sum_val = sycl::reduce_over_group( - item_id.get_group(), val, sycl::plus()); + auto sum_val = norm_type == NormType::L1 || norm_type == NormType::L2 + ? GroupReduceSumWithoutBroadcast(item_id, val, shared_) + : GroupReduceMaxWithoutBroadcast(item_id, val, shared_); if (lid == 0) { *(ret_per_tensor_[group_id]) = - norm_type == NormType::L1 ? sum_val : std::sqrt((opmath_t)sum_val); + norm_type == NormType::L1 || norm_type == NormType::LInf + ? 
sum_val + : std::sqrt((opmath_t)sum_val); } } + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = + sycl_local_acc_t(get_group_reduce_group_size(SIMD), cgh); + } lpnormChunkReduceKernelFunctor( const opmath_t* output_per_tensor, out_t** ret_per_tensor, @@ -127,16 +158,17 @@ struct lpnormChunkReduceKernelFunctor { out_t** ret_per_tensor_; int max_chunks_per_tensor_; int wg_size_; + sycl_local_acc_t shared_; }; -template +template void launch_lpnorm_chunk_reduce_kernel( const out_opmath_t* output_per_tensor, out_t** ret_per_tensor, int wg_size, int max_chunks_per_tensor, int n_tensor) { - lpnormChunkReduceKernelFunctor kfn( + lpnormChunkReduceKernelFunctor kfn( output_per_tensor, ret_per_tensor, max_chunks_per_tensor, wg_size); sycl_kernel_submit( @@ -159,18 +191,18 @@ void launch_lpnorm_chunk_reduce_kernel( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::BFloat16, out_t, __VA_ARGS__)) -template void foreach_norn_kernel_config( TensorList tensors, TensorOptions output_per_tensor_option, + int64_t simd, int64_t& wg_size, int& max_chunks_per_tensor, Tensor& output_per_tensor) { const int ntensors = tensors.size(); max_chunks_per_tensor = -1; - wg_size = multi_tensor_apply_kernel_get_wg_size(); - int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(); + wg_size = multi_tensor_apply_kernel_get_wg_size(simd); + int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(simd); for (int t = 0; t < ntensors; t++) { int max_chunks_this_tensor = @@ -191,7 +223,6 @@ std::vector foreach_norm_kernel( double p, c10::optional dtype) { const int ntensors = tensors.size(); - const ScalarType output_dtype = // tensors[0].scalar_type(); dtype.has_value() ? dtype.value() : tensors[0].scalar_type(); const auto options = tensors[0].options(); @@ -217,6 +248,14 @@ std::vector foreach_norm_kernel( int64_t wg_size; int max_chunks_per_tensor; Tensor output_per_tensor; + int64_t simd = syclMaxSubGroupSize(); + foreach_norn_kernel_config( + tensors, + output_per_tensor_option, + simd, + wg_size, + max_chunks_per_tensor, + output_per_tensor); if (p == static_cast(1)) { AT_DISPATCH_FLOATING_TYPES_AND2( kHalf, @@ -227,23 +266,28 @@ std::vector foreach_norm_kernel( AT_DISPATCH_OUT_DTYPES( output_dtype, "foreach_norm_out_dtype_xpu", [&]() { using out_opmath_t = typename at::opmath_type; - using KernelClass = lpnormChunkReduceKernelFunctor< - out_t, - NormType::L1, - out_opmath_t>; - foreach_norn_kernel_config( - tensors, - output_per_tensor_option, - wg_size, - max_chunks_per_tensor, - output_per_tensor); - // sum temp val for each chunk - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + if (simd == SIMD32) { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L1, + out_opmath_t, + SIMD32>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } else { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L1, + out_opmath_t, + SIMD16>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } for (int i = 0; i < ntensors; i++) { tensor_list_addresses[i] = ret_per_tensor[i].mutable_data_ptr(); @@ -257,15 +301,29 @@ std::vector foreach_norm_kernel( (void*)tensor_list_addresses, tensor_list_addresses_dptr.get_context(), at::xpu::getCurrentXPUStream()); - launch_lpnorm_chunk_reduce_kernel< - out_t, - NormType::L1, - out_opmath_t>( - output_per_tensor.mutable_data_ptr(), - (out_t**)(metaAddress), - wg_size, - max_chunks_per_tensor, - 
ntensors); + if (simd == SIMD32) { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::L1, + out_opmath_t, + SIMD32>( + output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } else { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::L1, + out_opmath_t, + SIMD16>( + output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } }); }); } else if (p == static_cast(2)) { @@ -278,22 +336,96 @@ std::vector foreach_norm_kernel( AT_DISPATCH_OUT_DTYPES( output_dtype, "foreach_norm_out_dtype_xpu", [&]() { using out_opmath_t = typename at::opmath_type; - using KernelClass = lpnormChunkReduceKernelFunctor< - out_t, - NormType::L2, - out_opmath_t>; - foreach_norn_kernel_config( - tensors, - output_per_tensor_option, - wg_size, - max_chunks_per_tensor, - output_per_tensor); + if (simd == SIMD32) { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L2, + out_opmath_t, + SIMD32>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } else { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L2, + out_opmath_t, + SIMD16>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } + for (int i = 0; i < ntensors; i++) { + tensor_list_addresses[i] = + ret_per_tensor[i].mutable_data_ptr(); + } + q.memcpy( + (void*)metaAddress, + (void*)tensor_list_addresses, + sizeof(void*) * ntensors); - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + at::xpu::CachingHostAllocator_recordEvent( + (void*)tensor_list_addresses, + tensor_list_addresses_dptr.get_context(), + at::xpu::getCurrentXPUStream()); + if (simd == SIMD32) { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::L2, + out_opmath_t, + SIMD32>( + output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } else { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::L2, + out_opmath_t, + SIMD16>( + output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } + }); + }); + } else if (p == std::numeric_limits::infinity()) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + tensor_lists[0][0].scalar_type(), + "foreach_norm", + [&]() { + AT_DISPATCH_OUT_DTYPES( + output_dtype, "foreach_norm_out_dtype_xpu", [&]() { + using out_opmath_t = typename at::opmath_type; + if (simd == SIMD32) { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::LInf, + out_opmath_t, + SIMD32>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } else { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::LInf, + out_opmath_t, + SIMD16>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + } for (int i = 0; i < ntensors; i++) { tensor_list_addresses[i] = ret_per_tensor[i].mutable_data_ptr(); @@ -307,15 +439,29 @@ std::vector foreach_norm_kernel( (void*)tensor_list_addresses, tensor_list_addresses_dptr.get_context(), at::xpu::getCurrentXPUStream()); - launch_lpnorm_chunk_reduce_kernel< - out_t, - NormType::L2, - out_opmath_t>( - output_per_tensor.mutable_data_ptr(), - (out_t**)(metaAddress), - wg_size, - max_chunks_per_tensor, - ntensors); + if (simd == SIMD32) { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::LInf, + out_opmath_t, + SIMD32>( + 
output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } else { + launch_lpnorm_chunk_reduce_kernel< + out_t, + NormType::LInf, + out_opmath_t, + SIMD16>( + output_per_tensor.mutable_data_ptr(), + (out_t**)(metaAddress), + wg_size, + max_chunks_per_tensor, + ntensors); + } }); }); } else { diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 3eb1cd08b..07aecd092 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -69,6 +70,47 @@ inline T& GroupReduceSumWithoutBroadcast( return val; } +template +inline T& SubgroupReduceMaxWithoutBroadcast(sycl::nd_item& item, T& val) { + auto sg = item.get_sub_group(); + auto sg_tid = sg.get_local_linear_id(); +#pragma unroll + for (int offset = 1; offset < SIMD; offset <<= 1) { + T temp = sycl::shift_group_left(sg, val, offset); + if (sg_tid < SIMD - offset) { + val = max_impl(temp, val); + } + } + return val; +} + +template +inline T& GroupReduceMaxWithoutBroadcast( + sycl::nd_item& item, + T& val, + shared_t shared) { + auto sg = item.get_sub_group(); + int sg_tid = sg.get_local_linear_id(); + int sg_id = sg.get_group_linear_id(); + int n_sg = get_local_linear_range(item) / SIMD; + val = SubgroupReduceMaxWithoutBroadcast(item, val); + item.barrier(sycl_local_fence); // prevent races when GroupReduceSum are + // called in a row. + if (n_sg == 1) { + return val; + } + if (sg_tid == 0) { + shared[sg_id] = val; + } + item.barrier(sycl_local_fence); + if (sg_id == 0) { + for (int i = 1; i < n_sg; i++) { + val = max_impl(val, shared[i]); + } + } + return val; +} + template inline T& SubgroupReduceWithoutBroadcast( sycl::nd_item& item, diff --git a/src/ATen/native/xpu/sycl/MultiTensorApply.h b/src/ATen/native/xpu/sycl/MultiTensorApply.h index ca0feb3c0..51ee195a9 100644 --- a/src/ATen/native/xpu/sycl/MultiTensorApply.h +++ b/src/ATen/native/xpu/sycl/MultiTensorApply.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -48,14 +49,12 @@ struct TLMetaForWG { uint32_t wg_to_chunk; }; -template -static int64_t multi_tensor_apply_kernel_get_wg_size() { - return syclMaxWorkGroupSize(); +static int64_t multi_tensor_apply_kernel_get_wg_size(int simd) { + return get_group_reduce_group_size(simd); } -template -static int64_t multi_tensor_apply_kernel_get_chunk_size() { - int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(); +static int64_t multi_tensor_apply_kernel_get_chunk_size(int simd) { + int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(simd); return max_wg_size * kElementPerThread; } @@ -118,18 +117,19 @@ void launch_multi_tensor_apply_kernel( U callable, int num_wg, ArgTypes... 
args) { - using KernelClass = MultiTensorApplyKernelFunctor; auto& q = getCurrentSYCLQueue(); - int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(); - int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(); + int64_t simd = syclMaxSubGroupSize(); + int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(simd); + int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(simd); if constexpr (fused_kernel) { max_wg_size = multi_tensor_apply_fused_kernel_get_wg_size(); kChunkSize = multi_tensor_apply_fused_kernel_get_chunk_size(); } - KernelClass kfn(kChunkSize, tlAddressMeta, tlWGMeta, callable, args...); + MultiTensorApplyKernelFunctor kfn( + kChunkSize, tlAddressMeta, tlWGMeta, callable, args...); sycl_kernel_submit( sycl::range<1>(num_wg * max_wg_size), @@ -145,19 +145,14 @@ void multi_tensor_apply( T callable, ArgTypes... args) { using scalar_vals_t = typename T::opmath_t; - using KernelClass = MultiTensorApplyKernelFunctor< - TLMetaForAddressScalar*, - TLMetaForWG*, - T, - ArgTypes...>; - TORCH_CHECK( tensor_lists.size() == depth, "Number of tensor lists has to match he depth"); size_t n_tensors = tensor_lists[0].size(); auto& q = getCurrentSYCLQueue(); - int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(); + int64_t simd = syclMaxSubGroupSize(); + int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(simd); auto addressStorage = at::empty( {(int)(sizeof(TLMetaForAddressScalar) * n_tensors)}, @@ -231,11 +226,6 @@ void multi_tensor_apply( std::vector>& tensor_lists, T callable, ArgTypes... args) { - using KernelClass = MultiTensorApplyKernelFunctor< - TLMetaForAddress*, - TLMetaForWG*, - T, - ArgTypes...>; TORCH_CHECK( tensor_lists.size() == depth, @@ -243,7 +233,8 @@ void multi_tensor_apply( size_t n_tensors = tensor_lists[0].size(); auto& q = getCurrentSYCLQueue(); - int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(); + int64_t simd = syclMaxSubGroupSize(); + int64_t kChunkSize = multi_tensor_apply_kernel_get_chunk_size(simd); auto addressStorage = at::empty( {(int)(sizeof(TLMetaForAddress) * n_tensors)}, From 000d3439cec482d051b82ceb558dba22b2fe00b3 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Tue, 19 Nov 2024 13:10:36 +0800 Subject: [PATCH 2/2] Rebase tensor mode operator for avoiding SYCL intrinsics (#1095) Resolve issue: https://github.com/intel/torch-xpu-ops/issues/1093 --- src/ATen/native/xpu/sycl/TensorModeKernel.cpp | 1270 +++++++---------- 1 file changed, 535 insertions(+), 735 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp index d9dfbe69a..7ae95e36b 100644 --- a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp @@ -1,11 +1,13 @@ #include #include +#include #include #include +#include +#include +#include #include #include - -#include #include #include #include @@ -14,673 +16,547 @@ namespace at::native::xpu { using namespace at::xpu::detail; -// this helper's meaning: -// The [status] contains the judgement result comapred with the adjecent -// elements's equivalence. In reduce, the [status] is used to help to record the -// max appearance times's index after scan. The [value] is the initial number -// according to the status after comparison. It is used in the following scan. 
-struct ModeOpHelper { - // why not int64_t: ModeOpHelper is used for the condition that problem - // size is equal to or smaller than the work group max size, so the - // accumulation value will not exceed the int32_t range - int32_t status; - int32_t value; -}; - -// this is used to record the sorted value(T) and the associated index(int64_t) -// for the fused mode kernel -template -struct ModeOpValueIndex { - T value; - int64_t index; -}; +constexpr int64_t MAX_GROUP_SIZE = 256; +constexpr int64_t MAX_GRID_SIZE = 65535LL; -// value handler for non standard bool value -template -inline scalar_t value_ptr_convert(const scalar_t* base, int64_t index) { - return base[index]; +template +inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; } -template <> -inline bool value_ptr_convert(const bool* base, int64_t index) { - return reinterpret_cast(base)[index] > 0 ? true : false; +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +inline uint64_t next_highest_power_of_2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + return n; } -// problem size >> array size -template -static inline void ConditionalInclusiveScanForMode( - T* memStatus, - T* memValue, - Functor functor, - const int64_t problem_size, - const int64_t outer_offset, - const int64_t inner_limit, - const item_t& item) { - auto id = item.get_local_id(0); - auto group_size = item.get_local_range(0); - - for (auto inner_id = id; inner_id < inner_limit; inner_id += group_size) { - // x x x x | x x x x | x x x x | x x x x - // ^ ^ --- one item solely compute these two values exclude the - // first one - auto global_id = outer_offset + inner_id; - if (id == 0 && inner_id != 0 && inner_id < problem_size) { - std::tie(memStatus[global_id], memValue[global_id]) = functor( - memStatus[global_id - 1], - memValue[global_id - 1], - memStatus[global_id], - memValue[global_id]); - } - item.barrier(sycl_global_fence); - - // x x x x | x x x x | x x x x | x x x x - // ^ ^ ^ ^ --- one group scan one piece of the full data - for (auto stride = 1; stride < group_size; stride <<= 1) { - T PreStatus = 0; - T PreValue = 0; - T CurStatus = 0; - T CurValue = 0; - if (inner_id < problem_size && id >= stride) { - PreStatus = memStatus[global_id - stride]; - PreValue = memValue[global_id - stride]; - CurStatus = memStatus[global_id]; - CurValue = memValue[global_id]; - } - item.barrier(sycl_global_fence); - if (inner_id < problem_size && id >= stride) { - std::tie(memStatus[global_id], memValue[global_id]) = - functor(PreStatus, PreValue, CurStatus, CurValue); - } - item.barrier(sycl_global_fence); - } +std::tuple get_workgroup_number_from_tiles( + int64_t gridTiles) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + TORCH_INTERNAL_ASSERT(false); } -} -// problem size = array size = slm size -// No implement it as down sweep and up sweep -template -static inline void ConditionalInclusiveScanForFusedMode( - T* memHelper, - Functor functor, - const int64_t problem_size, - const item_t& item) { - auto id = item.get_local_id(0); - for (auto stride = 1; stride < problem_size; stride <<= 1) { - T PreElem; - if (id >= stride) { - PreElem = memHelper[id - stride]; - } - item.barrier(sycl_local_fence); - if (id >= stride) { - memHelper[id] = functor(PreElem, memHelper[id]); + int64_t gridX = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; } - item.barrier(sycl_local_fence); } + return std::make_tuple(gridX, gridY, gridZ); } -// reduce, problem size >> array size -template -static inline void ReduceHelperForMode( - T* mem, - Functor functor, - const int64_t problem_size, - const int64_t outer_offset, - const int64_t group_size, - const item_t& item) { - auto id = item.get_local_id(0); - - // first add - for (auto inner_id = id + group_size; inner_id < problem_size; - inner_id += group_size) { - mem[outer_offset + id] = - functor(mem[outer_offset + id], mem[outer_offset + inner_id]); - } - item.barrier(sycl_global_fence); +template +inline index_t get_linear_group_id(sycl::nd_item<3> item) { + return item.get_group(0) * item.get_group_range(1) * item.get_group_range(2) + + item.get_group(1) * item.get_group_range(2) + item.get_group(2); +} - // naive tree - for (auto stride = group_size / 2; stride > 0; stride >>= 1) { - if (id < stride) { - auto tree_id = outer_offset + id; - mem[tree_id] = functor(mem[tree_id], mem[tree_id + stride]); - } - item.barrier(sycl_global_fence); +template +inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +inline void bitonicSwapKeys( + K& kA, + bool& validA, + K& kB, + bool& validB, + bool dir, + const Comparator& comp) { + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(validA, validB); } } -// reduce, problem size = array size = slm size -template -static inline void ReduceHelperForFusedMode( - ModeOpHelper* mem, - Functor functor, - const int64_t group_size, - const item_t& item) { - auto id = item.get_local_id(0); - - // naive tree - for (auto stride = group_size / 2; stride > 0; stride >>= 1) { - if (id < stride) { - mem[id].value = functor(mem[id], mem[id + stride]); +template < + typename K, + typename IndexType, + int Power2SortSize, + typename Comparator> +inline void bitonicSortKeys( + sycl::nd_item<3> item, + K keys[Power2SortSize], + bool valid[Power2SortSize], + const Comparator& comp) { + auto tx = item.get_local_id(2); +#pragma unroll + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((tx & (size / 2)) != 0); + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + item.barrier(sycl_local_fence); + unsigned int pos = 2 * tx - (tx & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + flag, + comp); } + } +#pragma unroll + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { item.barrier(sycl_local_fence); + unsigned int pos = 2 * tx - (tx & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + false, + comp); + } + item.barrier(sycl_local_fence); +} - // odd stride - if (stride % 2 == 1) { - if (id == 0) { - mem[0].value = functor(mem[0], mem[stride - 1]); - } +template +struct BitonicSortFn { + bool operator()(const T& a, const T& b) const { + return a < b; + } +}; + +// Used for a segmented reduction +struct ModeUnsignedBoolPair { + unsigned int val; + bool flag; +}; + +// In the kernel below, we have a common pattern of reducing (unsigned int, +// 
unsigned int) pairs of data +struct ModeUnsignedPair { + unsigned int val; + unsigned int index; +}; + +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to +// the nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * groupDim.x * sizeof(T). +// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template +inline void inclusivePrefixScan( + sycl::nd_item<3> item, + T* smem, + BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (item.get_local_id(2) + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); } item.barrier(sycl_local_fence); } - // odd group_size - if (group_size % 2 == 1) { - if (id == 0) { - mem[0].value = functor(mem[0], mem[group_size - 1]); + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (item.get_local_id(2) + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); } + item.barrier(sycl_local_fence); } - item.barrier(sycl_local_fence); } -template -static inline void ReduceGetMaxElemIndexForFusedMode( - ModeOpHelper* mem, - const int64_t group_size, - const item_t& item) { - auto id = item.get_local_id(0); - // value : 0 1 0 1 [2] 0 0 1 - // status: 0 1 0 1 2 0 0 1 - // reduce to find the [maximal] value, it means the most appear times of Mode - ReduceHelperForFusedMode( - mem, - [&](const ModeOpHelper& a, const ModeOpHelper& b) { - return (a.value > b.value) ? (a.value) : (b.value); - }, - group_size, - item); - item.barrier(sycl_local_fence); - - auto max_appearance_time = mem[0].value; - item.barrier(sycl_local_fence); +template +struct InclusivePrefixScanFunctor { + ModeUnsignedBoolPair operator()(const T& a, const T& b) const { + ModeUnsignedBoolPair c; + c.val = a.flag ? a.val : a.val + b.val; + c.flag = a.flag | b.flag; + return c; + } +}; - // id: 0 1 2 3 4 5 6 7 - // value : [2] x x x x x x x (x means ignore value) - // status: 0 1 0 1 2 0 0 1 - // new value: M M M M [4] M M M (M means the max value) - mem[id].value = (mem[id].status == max_appearance_time) - ? (id) - : (std::numeric_limits::max()); - item.barrier(sycl_local_fence); +template +inline T reduceGroupWithNThreadLocalReductions( + sycl::nd_item<3> item, + T* smem, + T threadVals[N], + const unsigned int numVals, + ReduceOp reduceOp, + T init) { + int offset = item.get_local_id(2) * N; + T local = offset < numVals ? threadVals[0] : init; + +#pragma unroll + for (int i = 1; i < N; ++i) { + ++offset; + T next = offset < numVals ? threadVals[i] : init; + local = reduceOp.combine(local, next); + } - // reduce again to find the [minimal] index and put it into mem[0].value - // value: M M M M [4] M M M - ReduceHelperForFusedMode( - mem, - [&](const ModeOpHelper& a, const ModeOpHelper& b) { - return (a.value < b.value) ? 
(a.value) : (b.value); - }, - group_size, - item); + return GroupReduceWithoutBroadcast( + item, local, reduceOp, smem); } -template < - typename scalar_t, - typename value_info_t, - typename indice_info_t, - typename item_t> -void mode_impl( - const scalar_t* problem_values_ptr, - const int64_t* problem_indices_ptr, - value_info_t answer_values, - indice_info_t answer_indices, - scalar_t* slm_ptr, - int64_t* scratch_status_ptr, - int64_t* scratch_value_ptr, - const int64_t problem_time, - const int64_t problem_size, - const int64_t wg_number, - const int64_t wg_size, - const int64_t inner_limit, - const item_t& item) { - auto group_id = item.get_group_linear_id(); - auto item_id = item.get_local_id(0); - - // outer loop, problem time level - for (auto outer_id = group_id; outer_id < problem_time; - outer_id += wg_number) { - auto outer_offset = outer_id * problem_size; - // inner loop, problem size level - for (auto inner_id = item_id; inner_id < inner_limit; inner_id += wg_size) { - auto global_index = outer_offset + inner_id; - - // load piece of data into slm - if (inner_id < problem_size) { - slm_ptr[item_id] = value_ptr_convert(problem_values_ptr, global_index); - } - item.barrier(sycl_local_fence); - - // compare and record the status using true and false into scratch pad - // buffer. 0 means begin a new sequence, 1 means the duplicated values. - // sorted values: 0 0 1 1 1 3 4 4 (here problem value is sorted) - // associated indices: 4 6 0 5 7 2 1 3 - // scratch status: 0 1 0 1 1 0 0 1 - // scratch value: 0 1 0 1 1 0 0 1 - // the first value is always status 0 and value 0 - if (inner_id == 0) { - scratch_status_ptr[outer_offset] = 0; - scratch_value_ptr[outer_offset] = 0; - } else { - if (inner_id < problem_size) { - // kick out the first item - auto judgeEqual = false; - if (item_id == 0) { - // for the first one, its pre value is not in slm - // slm: 0 1 1 - // global mem: 0 ^ ---- the pre one is in global mem - judgeEqual = bool( - value_ptr_convert(problem_values_ptr, global_index - 1) == - slm_ptr[item_id]); - } else { - judgeEqual = bool(slm_ptr[item_id - 1] == slm_ptr[item_id]); - } - auto status = (judgeEqual) ? (1) : (0); - scratch_status_ptr[global_index] = status; - scratch_value_ptr[global_index] = status; - } - } - item.barrier(sycl_global_fence); +template +struct ComputeModeKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(32)]] void operator()( + sycl::nd_item<3> item) const { + int tidx = item.get_local_id(2); + int stidx = item.get_local_range(2) + + item.get_local_id(2); // Second index this thread responsible for + + // First, we need to calculate the offset into the sorted Tensor that + // represents the start of the slice for this group to calculate the mode + // for. This offset is a combination of the gridIndices, and the number of + // elements in the slice. + unsigned int groupId = get_linear_group_id(item); + unsigned int linearOffset = groupId * sliceSize_; + + if (groupId >= slices_) { + return; } - // index: 0 1 2 3 4 5 6 7 - // scratch status: 0 1 0 1 1 0 0 1 - // scratch value: 0 1 0 1 1 0 0 1 - // conditional scan rule: - // 1. according to current status, if 0, keep the current value. If 1, do - // accumulation - // 2. new status = previous scratch status & current scratch status to - // record if there is duplicated value. 
- // The conditional inclusive scan result value should be: 0 1 0 1 2 0 0 1 - ConditionalInclusiveScanForMode( - scratch_status_ptr, - scratch_value_ptr, - [&](const int64_t& PreStatus, - const int64_t& PreValue, - const int64_t& CurStatus, - const int64_t& CurValue) { - auto TempValue = CurStatus ? (PreValue + CurValue) : (CurValue); - auto TempStatus = PreStatus & CurStatus; - return std::make_tuple(TempStatus, TempValue); - }, - problem_size, - outer_offset, - inner_limit, - item); - item.barrier(sycl_global_fence); - - // copy scratch value into status - // scratch status: 0 1 0 1 2 0 0 1 - // scratch value: 0 1 0 1 2 0 0 1 - // now the status is changed to be used to record the scan result - for (auto inner_id = item_id; inner_id < problem_size; - inner_id += wg_size) { - auto global_index = outer_offset + inner_id; - scratch_status_ptr[global_index] = scratch_value_ptr[global_index]; - } - item.barrier(sycl_global_fence); - - // scratch status: 0 1 0 1 2 0 0 1 - // scratch value: 0 1 0 1 [2] 0 0 1 - // reduce scratch value to find the max number - ReduceHelperForMode( - scratch_value_ptr, - [&](const int64_t& a, const int64_t& b) { return (a < b) ? (b) : (a); }, - problem_size, - outer_offset, - wg_size, - item); - - // the reduced maximal number is stored in the first slot - auto most_appearance_time = scratch_value_ptr[outer_offset]; - item.barrier(sycl_global_fence); - - // update the value by comparing that if the status is equal to the found - // maximul appearance times, if equal, assign the global index scratch - // index: 0 1 2 3 4 5 6 7 - // max appearance: [2] - // scratch status: 0 1 0 1 [2] 0 0 1 - // update scratch value: M M M M 4 M M M - // M is max int number - for (auto inner_id = item_id; inner_id < problem_size; - inner_id += wg_size) { - auto global_index = outer_offset + inner_id; - scratch_value_ptr[global_index] = - (scratch_status_ptr[global_index] == most_appearance_time) - ? (inner_id) - : (std::numeric_limits::max()); + // smem represents a proportion of the shared memory buffer that is used to + // store the elements from the slice: + T* smem = reinterpret_cast( + shmem_.template get_multi_ptr().get()); + + // Each thread loads up to two elements from the Tensor into shared memory + if (tidx < sliceSize_) { + smem[tidx] = c10::load(&input_[linearOffset + tidx]); } - item.barrier(sycl_global_fence); - - // appearance in scratch value - // scratch value: M M M M [4] M M M - // reduce scratch value to find the min number - ReduceHelperForMode( - scratch_value_ptr, - [&](const int64_t& a, const int64_t& b) { return (a < b) ? 
(a) : (b); }, - problem_size, - outer_offset, - wg_size, - item); - - // only one elem operation - if (item_id == 0) { - // the reduced minimal number is stored in the first index - auto reduce_min_index = scratch_value_ptr[outer_offset]; - - // index: 0 1 2 3 [4] 5 6 7 - // sorted indices: 4 6 0 5 [7] 2 1 3 - // find out the first-appeared and most-appeared element's original - // indices, it is 7 - auto answer_mode_index = - problem_indices_ptr[outer_offset + reduce_min_index]; - - // index: 0 1 2 3 [4] 5 6 7 - // sorted values: 0 0 1 1 [1] 3 4 4 - // find out the most-appeared value, it is 1 - auto answer_mode_value = value_ptr_convert( - problem_values_ptr, outer_offset + reduce_min_index); - - // write back - auto output_index = - IndexToOffset::get(outer_id, answer_values); - answer_values.data[output_index] = answer_mode_value; - answer_indices.data[output_index] = answer_mode_index; + if (stidx < sliceSize_) { + smem[stidx] = c10::load(&input_[linearOffset + stidx]); } - item.barrier(sycl_global_fence); - } -} -template < - typename scalar_t, - typename value_info_t, - typename indice_info_t, - typename item_t> -void mode_fused_impl( - const scalar_t* problem_values_ptr, - value_info_t answer_values, - indice_info_t answer_indices, - ModeOpHelper* slm_helper_ptr, - ModeOpValueIndex* slm_value_indice_ptr, - std::byte* sort_scratch_pointer, - const int64_t sort_scratch_memory_size, - const int64_t problem_time, - const int64_t problem_size, - const int64_t wg_number, - const int64_t wg_size, - const item_t& item) { - // read problem values into slm of the group - auto group_id = item.get_group_linear_id(); - auto item_id = item.get_local_id(0); - - // outer loop, problem time level - for (auto outer_id = group_id; outer_id < problem_time; - outer_id += wg_number) { - auto global_index = outer_id * problem_size + item_id; - - // load values and record indices into slm - // slm value 1 4 3 4 0 1 0 1 - // slm indices 0 1 2 3 4 5 6 7 - slm_value_indice_ptr[item_id].value = - value_ptr_convert(problem_values_ptr, global_index); - slm_value_indice_ptr[item_id].index = item_id; - item.barrier(sycl_local_fence); + // Next, we initialize a boolean region of the buffer, offset by the loaded + // element smem region + bool* bmem = reinterpret_cast(&smem[Power2Size]); + + // The first use of this region stores bmem[i] = i < sliceSize to mark the + // valid components in the smem buffer + bmem[tidx] = tidx < sliceSize_; + bmem[stidx] = stidx < sliceSize_; + item.barrier(sycl_local_fence); // barrier for smem, bmem initialization + + // First, sort the input slice in ascending order. smem contains the input + // elements, and bmem marks the valid indices + bitonicSortKeys( + item, smem, bmem, BitonicSortFn()); + item.barrier( + sycl_local_fence); // make no assumptions that the sort syncs at end + + // The next step of our algorithm is performing a group-wide comparison of + // neighboring elements. In particular, given an sorted input slice A, we + // produce an output slice B, such that B[i] = 1 if A[i-i] != A[i], + // otherwise 0. + // + // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8] + // B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // + // In particular, we can think of B[i] true indicating the start of a + // sequence of equal values in the sorted list. Similarly, we will also + // store the negation of B, which we'll call C. In particular, we can think + // of C[i] = true iff A[i-1] == A[i] in our original sorted slice. 
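    // (Illustrative aside, not part of the original comment: by this
    // convention B[0] = 1 and C[0] = 0, since the first element always starts
    // a new segment; that is why ubpmem[0] is explicitly initialized to
    // {val = 0, flag = true} a few lines below.)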
+ // + // C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + + // We overwrite bmem, and treat the rest of shared memory as a buffer of + // (index, flag) pairs where the index represents values from C, and the + // flag represents values from B. + // + // [smem (sorted slice) | ubpmem (index, flag pairs)] + + struct ModeUnsignedBoolPair* ubpmem = + reinterpret_cast(&smem[Power2Size]); + + if (tidx == 0) { + ubpmem[0].flag = true; + ubpmem[0].val = 0; + } - // sort - slm_value_indice_ptr[item_id] = - sycl::ext::oneapi::experimental::sort_over_group( - sycl::ext::oneapi::experimental::group_with_scratchpad( - item.get_group(), - sycl::span{ - sort_scratch_pointer, - static_cast(sort_scratch_memory_size)}), - slm_value_indice_ptr[item_id], - [&](const ModeOpValueIndex& A, - const ModeOpValueIndex& B) { - return A.value < B.value; - }); - item.barrier(sycl_local_fence); + // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ... + ubpmem[tidx * 2 + 1].flag = + smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc. + ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag; - // compare and compute the status/value using 0/1. - // sorted values: 0 0 1 1 1 3 4 4 - // sorted indices: 4 6 0 5 7 2 1 3 - // slm helper status: 0 1 0 1 1 0 0 1 - // slm helper value: 0 1 0 1 1 0 0 1 - // 0 means a new sequence. 1 means the value is duplicated with the pre one. - // kick out the first one - if (item_id == 0) { - slm_helper_ptr[item_id].status = 0; - slm_helper_ptr[item_id].value = 0; - } else { - auto judgeEqual = bool( - slm_value_indice_ptr[item_id - 1].value == - slm_value_indice_ptr[item_id].value); - auto status = (judgeEqual) ? (1) : (0); - slm_helper_ptr[item_id].status = status; - slm_helper_ptr[item_id].value = status; + // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ... + if (((tidx + 1) * 2) < Power2Size) { + ubpmem[(tidx + 1) * 2].flag = + smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2]; + ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag; } + item.barrier(sycl_local_fence); // barrier for ubpmem initialization + + // Next, we perform a segmented prefix sum on the neighboring elements, + // where + // the presence of a one indicates the start of a segment. In this case B + // acts as the segment start flags, and C is the buffer to be summed: + // + // Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + // Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0] + // + // Afterwards, the (index) components of the ubpmem buffer contain the + // lengths of the segments (minus 1), i.e. the counts of each element in the + // original input. + inclusivePrefixScan( + item, ubpmem, InclusivePrefixScanFunctor()); + // assumes scan syncs at the end + + // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers + // (i.e. we treat the boolean flag regions as integers). We initialize these + // to represent indices, and we'll call this buffer I + struct ModeUnsignedPair* uupmem = + reinterpret_cast(ubpmem); + + // At this point, we need to find the maximum element in lengths buffer C. + // This element will represent the count (-1) of the mode. Because of the + // way we have set up the problem, the index where this mode occurs will + // also be the location of the mode value in the sorted array, e.g. 
+ // + // smem = [0, 0, 1, 1, 1, 2] + // C = [0, 1, 0, 1, 2, 0] + // I = [0, 1, 2, 3, 4, 5] + // ^ + // maximum value, also aligned with mode = 1 + // + // We perform a group wide max-reduction of the C buffer, but we also need + // the indices to come along with it, so we utilize the uupmem construction. + // + // At the end we need to return the ModeUnsignedPair containing index = 4, + // val = 2, which represents the max + + // In practice, we will make each thread locally reduce 2 values in its + // registers prior to the global group-wide reduction. Note that instead of + // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with + // adjacent elements. This is because the reduce code below relies on thread + // elements to be adjacent. + struct ModeUnsignedPair uup[2]; + uup[0].index = tidx * 2; + uup[0].val = ubpmem[tidx * 2].val; + uup[1].index = tidx * 2 + 1; + uup[1].val = ubpmem[tidx * 2 + 1].val; item.barrier(sycl_local_fence); - // index: 0 1 2 3 4 5 6 7 - // slm helper status: 0 1 0 1 1 0 0 1 - // slm helper value: 0 1 0 1 1 0 0 1 - // conditional scan rule: - // 1. according to current status, if 0, keep the current value. If 1, do - // accumulation - // 2. new status = previous scratch status & current scratch status to - // record if there is duplicated value. - // The conditional inclusive scan result value should be: 0 1 0 1 2 0 0 1 - ConditionalInclusiveScanForFusedMode( - slm_helper_ptr, - [&](const ModeOpHelper& Pre, const ModeOpHelper& Cur) { - ModeOpHelper Temp{}; - Temp.value = Cur.status ? (Pre.value + Cur.value) : (Cur.value); - Temp.status = Pre.status & Cur.status; - return Temp; - }, - problem_size, - item); - item.barrier(sycl_local_fence); + struct ModeUnsignedPair max = {0, 0}; - // truncate the status with value and use it for reduce - // status: 0 1 0 1 2 0 0 1 - // value: 0 1 0 1 2 0 0 1 - // [watch out] status is used for following reduce and now has totally same - // number with value - slm_helper_ptr[item_id].status = slm_helper_ptr[item_id].value; - item.barrier(sycl_local_fence); + struct MaxOp { + inline ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) + const { + return b.val > a.val ? b : a; + } + } max_op; - // index: 0 1 2 3 [4] 5 6 7 - // conditional inclusive scan result: 0 1 0 1 2 0 0 1 - // [watch out] reduce: ^ <- reduce to get index 4 - // reduce_min_index means this position contains the first-appeared and - // most-appeared element's index. 
- ReduceGetMaxElemIndexForFusedMode(slm_helper_ptr, wg_size, item); - item.barrier(sycl_local_fence); + max = reduceGroupWithNThreadLocalReductions<2>( + item, uupmem, uup, sliceSize_, max_op, max); - // only one elem operation - if (item_id == 0) { - auto reduce_min_index = slm_helper_ptr[0].value; - - // index: 0 1 2 3 [4] 5 6 7 - // sorted indices: 4 6 0 5 [7] 2 1 3 - // according to the reduce_min_index, find out the first-appeared and - // most-appeared element's original indices, is 7 - auto answer_mode_index = slm_value_indice_ptr[reduce_min_index].index; - - // index: 0 1 2 3 [4] 5 6 7 - // sorted values: 0 0 1 1 [1] 3 4 4 - // find out the most-appeared value is 1 - auto answer_mode_value = slm_value_indice_ptr[reduce_min_index].value; - - // write back - auto output_index = - IndexToOffset::get(outer_id, answer_values); - answer_values.data[output_index] = answer_mode_value; - answer_indices.data[output_index] = answer_mode_index; + // Given the above constraints, the mode is the value at the reduced index + // in the original sorted element buffer + if (tidx == 0) { + mode_[0] = smem[max.index]; + } + item.barrier(sycl_local_fence); // broadcast mode + + // Finally, we need to find "an" index of the mode in the input + // Tensor. The API does not constrain which index we pick, but here + // we always pick the largest index. We store the index if the value + // is the mode, or 0 otherwise. Then find the maximum value. + // + // Again we reduce 2 elements in the thread's registers prior to the + // group-wide reduction + unsigned mode_index[2] = {0u, 0u}; + if (tidx * 2 < sliceSize_) { + const unsigned idx = tidx * 2; + mode_index[0] = + c10::load(&input_[linearOffset + idx]) == mode_[0] ? idx : 0u; + } + if (tidx * 2 + 1 < sliceSize_) { + const unsigned idx = tidx * 2 + 1; + mode_index[1] = + c10::load(&input_[linearOffset + idx]) == mode_[0] ? idx : 0u; } - item.barrier(sycl_local_fence); - } -} -template -struct ModeKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { - void operator()(sycl::nd_item<1> item) const { - mode_impl( - problem_values_ptr_, - problem_indices_ptr_, - values_info_, - indices_info_, - static_cast( - slm_.template get_multi_ptr().get()), - scratch_status_ptr_, - scratch_value_ptr_, - problem_time_, - problem_size_, - group_number_, - group_size_, - problem_upper_limit_, - item); + struct MaxIndexOp { + inline unsigned combine(unsigned a, unsigned b) const { + return b > a ? b : a; + } + } max_index_op; + + int64_t index = reduceGroupWithNThreadLocalReductions<2>( + item, + reinterpret_cast( + shmem_.template get_multi_ptr().get()), + mode_index, + sliceSize_, + max_index_op, + 0u); + + // Finally, we have the mode, and an index where it occurs. 
We use a single + // thread to place this in the appropriate output position + if (tidx == 0) { + unsigned int outputOffset = + at::xpu::detail::IndexToOffset::get( + groupId, values_); + values_.data[outputOffset] = mode_[0]; + indices_.data[outputOffset] = index; + } } void sycl_ker_config_convention(sycl::handler& cgh) { - // SLM(group size) is used for adjecent element comparing - slm_ = sycl_local_acc_t(group_size_, cgh); + shmem_ = sycl_local_acc_t(memsize_, cgh); + mode_ = sycl_local_acc_t(1, cgh); } - ModeKernelFunctor( - const scalar_t* problem_values_ptr, - const int64_t* problem_indices_ptr, - TensorInfo values_info, - TensorInfo indices_info, - int64_t* scratch_status_ptr, - int64_t* scratch_value_ptr, - int64_t problem_time, - int64_t problem_size, - int64_t group_number, - int64_t group_size, - int64_t problem_upper_limit) - : problem_values_ptr_(problem_values_ptr), - problem_indices_ptr_(problem_indices_ptr), - values_info_(values_info), - indices_info_(indices_info), - scratch_status_ptr_(scratch_status_ptr), - scratch_value_ptr_(scratch_value_ptr), - problem_time_(problem_time), - problem_size_(problem_size), - group_number_(group_number), - group_size_(group_size), - problem_upper_limit_(problem_upper_limit) {} + ComputeModeKernelFunctor( + const T* input, + at::xpu::detail::TensorInfo values, + at::xpu::detail::TensorInfo indices, + int64_t sliceSize, + int64_t slices, + int64_t memsize) + : input_(input), + values_(values), + indices_(indices), + sliceSize_(sliceSize), + slices_(slices), + memsize_(memsize) {} private: - const scalar_t* problem_values_ptr_; - const int64_t* problem_indices_ptr_; - TensorInfo values_info_; - TensorInfo indices_info_; - int64_t* scratch_status_ptr_; - int64_t* scratch_value_ptr_; - int64_t problem_time_; - int64_t problem_size_; - int64_t group_number_; - int64_t group_size_; - int64_t problem_upper_limit_; - - sycl_local_acc_t slm_; + const T* input_; + at::xpu::detail::TensorInfo values_; + at::xpu::detail::TensorInfo indices_; + int64_t sliceSize_; + int64_t slices_; + int64_t memsize_; + sycl_local_acc_t shmem_; + sycl_local_acc_t mode_; }; -template -struct ModeFusedKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { - void operator()(sycl::nd_item<1> item) const { - mode_fused_impl( - problem_values_ptr_, - values_info_, - indices_info_, - static_cast( - slm_helper_.template get_multi_ptr() - .get()), - static_cast*>( - slm_value_indice_ - .template get_multi_ptr() - .get()), - static_cast( - sort_scratch_.template get_multi_ptr() - .get()), - sort_scratch_memory_size_, - problem_time_, - problem_size_, - group_number_, - group_size_, - item); - } - - void sycl_ker_config_convention(sycl::handler& cgh) { - // SLM used for record status for mode - slm_helper_ = sycl_local_acc_t(group_size_, cgh); - - // SLM used for store value and its associated indice - slm_value_indice_ = - sycl_local_acc_t, 1>(group_size_, cgh); +template +void handle_fused_mode( + std::tuple nwgs, + const TensorBase& self, + at::xpu::detail::TensorInfo& ti_values, + at::xpu::detail::TensorInfo& ti_indices, + int64_t slice_size, + int64_t slices) { + constexpr int num_threads = size / 2; + constexpr int sg_size = 32; + TORCH_INTERNAL_ASSERT( + num_threads % sg_size == 0 && num_threads <= (sg_size * sg_size), ""); + const auto memsize = + (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); + auto gx = std::get<0>(nwgs); + auto gy = std::get<1>(nwgs); + auto gz = std::get<2>(nwgs); + sycl::range<3> local_range(1, 1, num_threads); + sycl::range<3> 
global_range(gz, gy, gx * num_threads); + auto caller = ComputeModeKernelFunctor( + self.const_data_ptr(), + ti_values, + ti_indices, + slice_size, + slices, + memsize); + sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), caller); +} - // SLM used for sort - sort_scratch_ = - sycl_local_acc_t(sort_scratch_memory_size_, cgh); +template +void fused_mode( + const TensorBase& values, + const TensorBase& indices, + const TensorBase& self, + int64_t slice_size, + int64_t slices) { + // Set-up TensorInfo structs for passing to kernel + auto ti_values = + at::xpu::detail::getTensorInfo(values); + auto ti_indices = + at::xpu::detail::getTensorInfo(indices); + + // The number of work group is the number of slices that we need to calculate + // the mode for. Each group is responsible for computing a single mode + auto nwgs = get_workgroup_number_from_tiles(slices); + + // The groupsize is two elements per thread, rounded up to the nearest power + // of 2 + auto ceilPowerOf2 = next_highest_power_of_2(slice_size); + + // Tradeoff between compilation time and the number of specializations. + // Ideally we would have one handle_fused_mode for each power of 2 + switch (ceilPowerOf2) { + case 2048: + handle_fused_mode<2048, scalar_t>( + nwgs, self, ti_values, ti_indices, slice_size, slices); + break; + case 1024: + case 512: + case 256: + handle_fused_mode<1024, scalar_t>( + nwgs, self, ti_values, ti_indices, slice_size, slices); + break; + case 128: + case 64: + case 32: + case 16: + case 8: + case 4: + case 2: + handle_fused_mode<128, scalar_t>( + nwgs, self, ti_values, ti_indices, slice_size, slices); + break; + case 1: + default: + TORCH_INTERNAL_ASSERT(false); } +} - ModeFusedKernelFunctor( - const scalar_t* problem_values_ptr, - TensorInfo values_info, - TensorInfo indices_info, - int64_t sort_scratch_memory_size, - int64_t problem_time, - int64_t problem_size, - int64_t group_number, - int64_t group_size) - : problem_values_ptr_(problem_values_ptr), - values_info_(values_info), - indices_info_(indices_info), - sort_scratch_memory_size_(sort_scratch_memory_size), - problem_time_(problem_time), - problem_size_(problem_size), - group_number_(group_number), - group_size_(group_size) {} - - private: - const scalar_t* problem_values_ptr_; - TensorInfo values_info_; - TensorInfo indices_info_; - int64_t sort_scratch_memory_size_; - int64_t problem_time_; - int64_t problem_size_; - int64_t group_number_; - int64_t group_size_; - - sycl_local_acc_t slm_helper_; - sycl_local_acc_t, 1> slm_value_indice_; - sycl_local_acc_t sort_scratch_; -}; +void launch_fused_mode_kernel( + const TensorBase& values, + const TensorBase& indices, + const TensorBase& self, + int64_t slice_size, + int64_t slices) { + AT_DISPATCH_ALL_TYPES_AND3( + kBool, kBFloat16, kHalf, self.scalar_type(), "xpu_mode", [&] { + fused_mode(values, indices, self, slice_size, slices); + }); +} -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * -The answer rule of the cornor condition is: -1. indice need to be the max indice of the most-appeared value -2. if values appear same times, the returned value should be the smaller one -The implementation idea overview: -1. sort the input values -2. compare the adjecent value and record the status when checking equality -3. conditional scan to calculate the appear times for each kind of value -4. reduce to get the most-appeared value's most appear times -5. reduce again to get the most-appeared value's minimal indice -6. 
get the answer from the sorted value/indice according to the second reduce -result -* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ -template -void mode_kernel_impl( +void mode_kernel( Tensor& values, Tensor& indices, const Tensor& self, @@ -688,19 +564,16 @@ void mode_kernel_impl( bool keepdim) { auto self_sizes = ensure_nonempty_vec(self.sizes().vec()); int64_t ndim = ensure_nonempty_dim(self.dim()); - // problem size, the element size of the tensor at this dim - auto problem_size = ensure_nonempty_size(self, dim); - // calculation times needed for each problem - auto problem_time = self.numel() / problem_size; - - // make sure the passed dim does make sense - TORCH_CHECK( - (0 <= dim && static_cast(dim) < self_sizes.size()), - "The chosen dim should be between [0, ", - self_sizes.size() - 1, - "], but got unexpected ", - dim); - // problem dim suqeeze to 1 + int64_t slice_size = ensure_nonempty_size(self, dim); + int64_t slices = self.numel() / slice_size; + + bool use_fast_path = slice_size <= 2 * MAX_GROUP_SIZE && + slices <= MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE && + canUse32BitIndexMath(self); + + // Resize output value, index Tensors to appropriate sizes (i.e. the same as + // the input Tensor, except at dim=dimension, the size is 1) + assert(0 <= dim && static_cast(dim) < self_sizes.size()); self_sizes[dim] = 1; if (!keepdim) { @@ -712,13 +585,11 @@ void mode_kernel_impl( } } - // Resize output value, index Tensors to sizes after execution at::native::resize_output(values, self_sizes); at::native::resize_output(indices, self_sizes); - // If sliceSize is 1, it means the chosen dim has one value, - // then copy input to values and set indices to 0 - if (problem_size == 1) { + // If sliceSize is 1, copy input to values and set indices + if (slice_size == 1) { values.copy_(self); indices.fill_(0); if (!keepdim) { @@ -728,101 +599,47 @@ void mode_kernel_impl( return; } - // exchange the problem dim to the last dim for mem coalescing + if (!use_fast_path) { + const auto empty_cpu = [](const Tensor& t) { + return at::empty({0}, t.options().device(kCPU).pinned_memory(true)); + }; + auto values_ = empty_cpu(values); + auto indices_ = empty_cpu(indices); + const auto self_ = self.to(self.options().device(kCPU).pinned_memory(true)); + mode_stub(self_.device().type(), values_, indices_, self_, dim, keepdim); + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } + values.copy_(values_, /*non_blocking*/ true); + indices.copy_(indices_, /*non_blocking*/ true); + return; + } + + // Beginning our optimized implementation. First thing we want to do is to + // transpose the input Tensor along the sort dimension, and then make it + // contiguous. auto transposed = self.transpose(dim, ndim - 1); auto contiguous = transposed.contiguous(); + + // We also need to view the values and indices Tensors as transposed in order + // to properly determine the offset into the underlying storage in which to + // place the mode and index for a particular set of dimension values. 
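  // (Illustrative note with an assumed shape, not part of the original
  // comment: if self were a [4, 6] tensor and dim == 0, then slice_size == 4
  // and slices == 6; after the transpose + contiguous below, each length-4
  // slice is stored consecutively, so the work group with groupId i reads
  // elements [i * slice_size, (i + 1) * slice_size) with unit stride, while
  // values_transposed / indices_transposed give the matching per-slice output
  // offsets via IndexToOffset.)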
auto values_transposed = values.transpose(dim, ndim - 1); auto indices_transposed = indices.transpose(dim, ndim - 1); - // max wg size - auto max_WG_Size = std::min( - syclMaxWorkGroupSize>(), - syclMaxWorkGroupSize>()); - - // one wg is responsible for one problem batch - auto group_number = problem_time; - - // When the problem size is larger than the max wg size, - // the wg is set the upper limitation of a wg size - if (problem_size > max_WG_Size) { - auto group_size = max_WG_Size; - - // sorted values and associated indices - auto sort_tuple_ret = at::sort( - contiguous, - /*stable*/ true, - /*dim*/ -1, /*descending*/ - false); - auto problem_values = std::get<0>(sort_tuple_ret); - auto problem_indices = std::get<1>(sort_tuple_ret); - - auto scratch_status_tensor = - at::zeros_like(self, TensorOptions(ScalarType::Long)); - auto scratch_value_tensor = - at::zeros_like(self, TensorOptions(ScalarType::Long)); - - auto values_info = getTensorInfo(values_transposed); - auto indices_info = getTensorInfo(indices_transposed); - - // be used to set the limitation for the inner loop wg - auto problem_upper_limit = ((problem_size % group_size) == 0) - ? (problem_size) - : ((problem_size / group_size + 1) * group_size); - - auto problem_values_ptr = problem_values.const_data_ptr(); - auto problem_indices_ptr = problem_indices.const_data_ptr(); - auto scratch_status_ptr = scratch_status_tensor.data_ptr(); - auto scratch_value_ptr = scratch_value_tensor.data_ptr(); - ModeKernelFunctor kfn( - problem_values_ptr, - problem_indices_ptr, - values_info, - indices_info, - scratch_status_ptr, - scratch_value_ptr, - problem_time, - problem_size, - group_number, - group_size, - problem_upper_limit); - sycl_kernel_submit( - group_number * group_size, group_size, getCurrentSYCLQueue(), kfn); - } else { - // problem_size <= max_WG_Size, wg size is set the problem size - auto group_size = problem_size; - - // scratch memory size needed by built-in sort -#if defined(__INTEL_LLVM_COMPILER_VERSION) && \ - __INTEL_LLVM_COMPILER_VERSION >= 20250000 - auto sort_scratch_memory_size = - sycl::ext::oneapi::experimental::default_sorters::group_sorter< - ModeOpValueIndex, - std::greater, - 1>::memory_required(sycl::memory_scope::work_group, group_size); -#else - auto sort_scratch_memory_size = sycl::ext::oneapi::experimental:: - default_sorter>::template memory_required< - ModeOpValueIndex>( - sycl::memory_scope::work_group, - sycl::range<1>{static_cast(group_size)}); -#endif - - auto values_info = getTensorInfo(values_transposed); - auto indices_info = getTensorInfo(indices_transposed); - - auto problem_values_ptr = contiguous.const_data_ptr(); - - ModeFusedKernelFunctor kfn( - problem_values_ptr, - values_info, - indices_info, - sort_scratch_memory_size, - problem_time, - problem_size, - group_number, - group_size); - sycl_kernel_submit( - group_number * group_size, group_size, getCurrentSYCLQueue(), kfn); + // Requirements for fused kernel implementation: + // + // 1. sliceSize <= 2 * max threads per group + // 2. uses one group per slice, so number of slices must be less than the + // maximum number of groups for a kernel launch + // 3. 
Can use 32-bit index math for indexing (mainly just for implementation + // conciseness, could be changed) + // + TORCH_INTERNAL_ASSERT(use_fast_path == true); + { + launch_fused_mode_kernel( + values_transposed, indices_transposed, contiguous, slice_size, slices); } if (!keepdim) { @@ -831,21 +648,4 @@ void mode_kernel_impl( } } -void mode_kernel( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t dim, - bool keepdim) { - AT_DISPATCH_ALL_TYPES_AND3( - at::ScalarType::Bool, - at::ScalarType::Half, - at::ScalarType::BFloat16, - self.scalar_type(), - "mode_xpu", - [&]() { - mode_kernel_impl(values, indices, self, dim, keepdim); - }); -} - } // namespace at::native::xpu
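Both patches sit behind ordinary ATen entry points, so a small driver is enough to exercise them. The sketch below is illustrative only and is not part of either patch: it assumes an XPU-enabled PyTorch build (so tensors on at::kXPU dispatch to these kernels) and uses arbitrary shapes. at::_foreach_norm with ord == inf takes the new multi-tensor max path from the first patch, and at::mode with a small slice size takes the rebased fused kernel from the second.

#include <ATen/ATen.h>

#include <limits>
#include <vector>

int main() {
  // Assumption for this sketch: the build has XPU support, so at::kXPU
  // tensors route through the kernels changed above. The calls themselves
  // are backend-agnostic.
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kFloat);
  std::vector<at::Tensor> tensors = {
      at::randn({64, 64}, opts),
      at::randn({128}, opts),
      at::randn({3, 5, 7}, opts)};

  // Patch 1/2: ord == inf now uses the fast multi-tensor route, i.e. a
  // per-chunk max of |x| followed by a per-tensor max over the chunk results
  // (no sqrt in the final reduction, unlike the L2 path).
  const double inf = std::numeric_limits<double>::infinity();
  std::vector<at::Tensor> inf_norms = at::_foreach_norm(tensors, inf);

  // Patch 2/2: mode along dim 1. With slice_size == 16 (<= 2 * MAX_GROUP_SIZE)
  // and 32-bit indexable input, the fused one-work-group-per-slice kernel is
  // used; larger slices fall back to the CPU implementation.
  auto self =
      at::randint(/*low=*/0, /*high=*/4, {1000, 16}, opts.dtype(at::kLong));
  auto [values, indices] = at::mode(self, /*dim=*/1, /*keepdim=*/false);

  return 0;
}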