diff --git a/CHANGELOG.md b/CHANGELOG.md index 829a5bdd5..b843b5371 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,12 @@ Documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). -## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0 +## rocPRIM-3.2.1 for ROCm 6.2.1 + +### Optimizations +* Improved performance of block_reduce_warp_reduce when warp size == block size. + +## rocPRIM-3.2.0 for ROCm 6.2.0 ### Additions diff --git a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp index 590b57220..a6957f6a8 100644 --- a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp +++ b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp @@ -178,21 +178,25 @@ class block_reduce_warp_reduce input, output, num_valid, reduce_op ); - // i-th warp will have its partial stored in storage_.warp_partials[i-1] - if(lane_id == 0) + // Final reduction across warps is only required if there is more than 1 warp + if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) { - storage_.warp_partials[warp_id] = output; - } - ::rocprim::syncthreads(); - - if(flat_tid < warps_no_) - { - // Use warp partial to calculate the final reduce results for every thread - auto warp_partial = storage_.warp_partials[lane_id]; - - warp_reduce( - warp_partial, output, warps_no_, reduce_op - ); + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + warp_reduce( + warp_partial, output, warps_no_, reduce_op + ); + } } } @@ -244,22 +248,26 @@ class block_reduce_warp_reduce input, output, num_valid, reduce_op ); - // i-th warp will have its partial stored in storage_.warp_partials[i-1] - if(lane_id == 0) + // Final reduction across warps is only required if there is more than 1 warp + if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) { - storage_.warp_partials[warp_id] = output; - } - ::rocprim::syncthreads(); - - if(flat_tid < warps_no_) - { - // Use warp partial to calculate the final reduce results for every thread - auto warp_partial = storage_.warp_partials[lane_id]; - - unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; - warp_reduce_output_type().reduce( - warp_partial, output, valid_warps_no, reduce_op - ); + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; + warp_reduce_output_type().reduce( + warp_partial, output, valid_warps_no, reduce_op + ); + } } } };