From 0f70d73db8f6fe3e8d1c41a048c44fc52ea9cad9 Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Wed, 12 Oct 2022 09:08:21 -0500 Subject: [PATCH] EC/ROCM: add host execution capability (#609) Use host based reduction and copy operations for short messages. This avoids the cost of a kernel launch. Co-authored-by: valentin petrov --- src/components/ec/rocm/ec_rocm.c | 15 +++- src/components/ec/rocm/ec_rocm.h | 4 + src/components/ec/rocm/ec_rocm_executor.c | 43 +++++++++- .../ec/rocm/ec_rocm_executor_interruptible.c | 79 +++++++++++++++++++ .../ec/rocm/kernel/ec_rocm_executor_kernel.cu | 9 ++- .../ec/rocm/kernel/ec_rocm_wait_kernel.cu | 2 +- 6 files changed, 143 insertions(+), 9 deletions(-) diff --git a/src/components/ec/rocm/ec_rocm.c b/src/components/ec/rocm/ec_rocm.c index cae2d81a7c..4ef65a5567 100644 --- a/src/components/ec/rocm/ec_rocm.c +++ b/src/components/ec/rocm/ec_rocm.c @@ -65,7 +65,7 @@ static ucc_config_field_t ucc_ec_rocm_config_table[] = { ucc_offsetof(ucc_ec_rocm_config_t, exec_max_tasks), UCC_CONFIG_TYPE_ULUNITS}, - {"EXEC_NUM_STREAMS", "16", + {"EXEC_NUM_STREAMS", "8", "Number of streams used by interruptible executor", ucc_offsetof(ucc_ec_rocm_config_t, exec_num_streams), UCC_CONFIG_TYPE_ULUNITS}, @@ -75,6 +75,19 @@ static ucc_config_field_t ucc_ec_rocm_config_table[] = { ucc_offsetof(ucc_ec_rocm_config_t, reduce_num_blocks), UCC_CONFIG_TYPE_ULUNITS}, + {"REDUCE_HOST_LIMIT", "256", + "Maximum data size for which to use host-based reduction operations", + ucc_offsetof(ucc_ec_rocm_config_t, reduce_host_limit), + UCC_CONFIG_TYPE_MEMUNITS}, + + /* Disabled by default. + * Recommended settings: MI100: 64 bytes, MI200: 4kbytes + */ + {"COPY_HOST_LIMIT", "0", + "Maximum data size for which to use host-based copy operations", + ucc_offsetof(ucc_ec_rocm_config_t, copy_host_limit), + UCC_CONFIG_TYPE_MEMUNITS}, + {NULL} }; diff --git a/src/components/ec/rocm/ec_rocm.h b/src/components/ec/rocm/ec_rocm.h index db6c6f2c0a..c22370df03 100644 --- a/src/components/ec/rocm/ec_rocm.h +++ b/src/components/ec/rocm/ec_rocm.h @@ -10,6 +10,7 @@ #include "components/ec/base/ucc_ec_base.h" #include "components/ec/ucc_ec_log.h" +#include "core/ucc_ee.h" #include "utils/ucc_mpool.h" #include "utils/arch/rocm_def.h" #include @@ -80,6 +81,8 @@ typedef struct ucc_ec_rocm_config { unsigned long exec_num_streams; unsigned long reduce_num_blocks; int reduce_num_threads; + int reduce_host_limit; + int copy_host_limit; } ucc_ec_rocm_config_t; typedef struct ucc_ec_rocm { @@ -97,6 +100,7 @@ typedef struct ucc_ec_rocm { ucc_ec_rocm_task_stream_type_t task_strm_type; ucc_ec_rocm_task_post_fn post_strm_task; ucc_spinlock_t init_spinlock; + ucc_ee_executor_t *cpu_executor; } ucc_ec_rocm_t; typedef struct ucc_rocm_ec_event { diff --git a/src/components/ec/rocm/ec_rocm_executor.c b/src/components/ec/rocm/ec_rocm_executor.c index 03b7d8d00f..6861a5ff8c 100644 --- a/src/components/ec/rocm/ec_rocm_executor.c +++ b/src/components/ec/rocm/ec_rocm_executor.c @@ -6,6 +6,7 @@ */ #include "ec_rocm_executor.h" +#include "components/ec/ucc_ec.h" ucc_status_t ucc_rocm_executor_persistent_start(ucc_ee_executor_t *executor, void *ee_context); @@ -19,7 +20,12 @@ ucc_status_t ucc_rocm_executor_interruptible_stop(ucc_ee_executor_t *executor); ucc_status_t ucc_rocm_executor_init(const ucc_ee_executor_params_t *params, ucc_ee_executor_t **executor) { - ucc_ec_rocm_executor_t *eee = ucc_mpool_get(&ucc_ec_rocm.executors); + ucc_ec_rocm_executor_t *eee = ucc_mpool_get(&ucc_ec_rocm.executors); + ucc_status_t status; + ucc_ee_executor_params_t cpu_params = { + .mask = UCC_EE_EXECUTOR_PARAM_FIELD_TYPE, + .ee_type = UCC_EE_CPU_THREAD + }; if (ucc_unlikely(!eee)) { ec_error(&ucc_ec_rocm.super, "failed to allocate executor"); @@ -30,6 +36,12 @@ ucc_status_t ucc_rocm_executor_init(const ucc_ee_executor_params_t *params, eee->super.ee_type = params->ee_type; eee->state = UCC_EC_ROCM_EXECUTOR_INITIALIZED; + status = ucc_ee_executor_init(&cpu_params, &ucc_ec_rocm.cpu_executor); + if (status != UCC_OK) { + ec_error(&ucc_ec_rocm.super, + "Error initializing CPU executor from ROCm component"); + } + *executor = &eee->super; return UCC_OK; } @@ -56,12 +68,19 @@ ucc_status_t ucc_rocm_executor_finalize(ucc_ee_executor_t *executor) { ucc_ec_rocm_executor_t *eee = ucc_derived_of(executor, ucc_ec_rocm_executor_t); + ucc_status_t status; ec_debug(&ucc_ec_rocm.super, "executor free, eee: %p", eee); ucc_assert(eee->state == UCC_EC_ROCM_EXECUTOR_INITIALIZED); ucc_mpool_put(eee); - return UCC_OK; + status = ucc_ee_executor_finalize(ucc_ec_rocm.cpu_executor); + if (status != UCC_OK) { + ec_error(&ucc_ec_rocm.super, + "Error finalizing CPU executor from ROCm component"); + } + + return status; } ucc_status_t ucc_rocm_executor_task_post(ucc_ee_executor_t *executor, @@ -91,7 +110,16 @@ ucc_status_t ucc_rocm_executor_task_finalize(ucc_ee_executor_task_t *task) ucc_status_t ucc_rocm_executor_start(ucc_ee_executor_t *executor, void *ee_context) { - if (!ee_context) { + ucc_status_t status; + + status = ucc_ee_executor_start(ucc_ec_rocm.cpu_executor, ee_context); + if (status != UCC_OK) { + ec_error(&ucc_ec_rocm.super, + "Error starting CPU executor from ROCm component"); + return status; + } + + if (!ee_context) { return ucc_rocm_executor_interruptible_start(executor); } else { return ucc_rocm_executor_persistent_start(executor, ee_context); @@ -102,6 +130,15 @@ ucc_status_t ucc_rocm_executor_stop(ucc_ee_executor_t *executor) { ucc_ec_rocm_executor_t *eee = ucc_derived_of(executor, ucc_ec_rocm_executor_t); + ucc_status_t status; + + status = ucc_ee_executor_stop(ucc_ec_rocm.cpu_executor); + if (status != UCC_OK) { + ec_error(&ucc_ec_rocm.super, + "Error stopping CPU executor from ROCm component"); + return status; + } + if (eee->mode == UCC_EC_ROCM_EXECUTOR_MODE_INTERRUPTIBLE) { return ucc_rocm_executor_interruptible_stop(executor); } else { diff --git a/src/components/ec/rocm/ec_rocm_executor_interruptible.c b/src/components/ec/rocm/ec_rocm_executor_interruptible.c index 79eb98c36b..0fd438726f 100644 --- a/src/components/ec/rocm/ec_rocm_executor_interruptible.c +++ b/src/components/ec/rocm/ec_rocm_executor_interruptible.c @@ -7,8 +7,78 @@ #include "ec_rocm_executor.h" #include "components/mc/ucc_mc.h" +#include "components/ec/ucc_ec.h" #include "utils/ucc_atomic.h" +static bool ucc_ec_rocm_copy_multi_use_host (const ucc_ee_executor_task_args_t* task_args) +{ + bool result = true; + + for (int i = 0; i < task_args->copy_multi.num_vectors; i++) { + if (task_args->copy_multi.counts[i] > EC_ROCM_CONFIG->copy_host_limit) { + result = false; + break; + } + } + + return result; +} + +static int ucc_ec_rocm_total_reduce_len(const ucc_ee_executor_task_args_t* task_args) +{ + int total_len = 0; + ucc_datatype_t dt; + size_t count; + + if (task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) { + dt = task_args->reduce.dt; + count = task_args->reduce.count; + } else { + ucc_assert(task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED); + dt = task_args->reduce_strided.dt; + count = task_args->reduce_strided.count; + } + total_len += count * ucc_dt_size(dt); + + return total_len; +} + +static bool ucc_ec_rocm_host_dt_supported(const ucc_ee_executor_task_args_t* task_args) +{ + bool result = false; + ucc_datatype_t dt; + + if (task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) { + dt = task_args->reduce.dt; + } else { + ucc_assert(task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED); + dt = task_args->reduce_strided.dt; + } + if (dt != UCC_DT_BFLOAT16 && + dt != UCC_DT_FLOAT16 && + dt != UCC_DT_FLOAT32_COMPLEX && + dt != UCC_DT_FLOAT64_COMPLEX) { + result = true; + } + return result; +} + +static inline +bool ec_rocm_use_host_ops(const ucc_ee_executor_task_args_t *_task_args) +{ + if ( (_task_args->task_type == UCC_EE_EXECUTOR_TASK_COPY && + _task_args->copy.len <= EC_ROCM_CONFIG->copy_host_limit) || + (_task_args->task_type == UCC_EE_EXECUTOR_TASK_COPY_MULTI && + ucc_ec_rocm_copy_multi_use_host(_task_args)) || + ((_task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE || + _task_args->task_type == UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED) && + ucc_ec_rocm_total_reduce_len(_task_args) <= EC_ROCM_CONFIG->reduce_host_limit && + ucc_ec_rocm_host_dt_supported(_task_args) )) { + return true; + } + return false; +} + ucc_status_t ucc_rocm_executor_interruptible_get_stream(hipStream_t *stream) { static uint32_t last_used = 0; @@ -56,6 +126,15 @@ ucc_rocm_executor_interruptible_task_post(ucc_ee_executor_t *executor, hipStream_t stream; ucc_status_t status; + if (ec_rocm_use_host_ops(task_args)) { + status = ucc_ee_executor_task_post(ucc_ec_rocm.cpu_executor, task_args, + task); + if (ucc_unlikely(status != UCC_OK)) { + ec_error(&ucc_ec_rocm.super, "failed to execute host ops from ROCm component"); + } + return status; + } + status = ucc_rocm_executor_interruptible_get_stream(&stream); if (ucc_unlikely(status != UCC_OK)) { return status; diff --git a/src/components/ec/rocm/kernel/ec_rocm_executor_kernel.cu b/src/components/ec/rocm/kernel/ec_rocm_executor_kernel.cu index 15e86ab8f4..8b168f89e4 100644 --- a/src/components/ec/rocm/kernel/ec_rocm_executor_kernel.cu +++ b/src/components/ec/rocm/kernel/ec_rocm_executor_kernel.cu @@ -46,9 +46,11 @@ __device__ void executor_copy_aligned(T* __restrict__ d, T* __restrict__ s, char1 *s1 = (char1*)s; char1 *d1 = (char1*)d; -#pragma unroll 4 - for(int i = 0; i < num_iter; i++) { - d[i * step + idx] = s[i * step + idx]; + for(int i = 0; i < num_iter; i+=4) { + d[i * step + idx] = s[i * step + idx]; + d[(i+1) * step + idx] = s[(i+1) * step + idx]; + d[(i+2) * step + idx] = s[(i+2) * step + idx]; + d[(i+3) * step + idx] = s[(i+3) * step + idx]; } if (idx < count % sizeof(T)) { @@ -229,7 +231,6 @@ __device__ void executor_copy_multi(ucc_eee_task_copy_multi_t *task) const int num_iter = n / step + ((threadIdx.x < n % step) ? 1 : 0); for (size_t i = 0; i < num_iter; i++) { -#pragma unroll for (int j = 0; j < task->num_vectors; j++) { dsts[j][idx] = srcs[j][idx]; } diff --git a/src/components/ec/rocm/kernel/ec_rocm_wait_kernel.cu b/src/components/ec/rocm/kernel/ec_rocm_wait_kernel.cu index a1d7055ad3..b33edff060 100644 --- a/src/components/ec/rocm/kernel/ec_rocm_wait_kernel.cu +++ b/src/components/ec/rocm/kernel/ec_rocm_wait_kernel.cu @@ -12,7 +12,7 @@ __global__ void wait_kernel(volatile uint32_t *status) { *status = UCC_EC_ROCM_TASK_STARTED; do { st = (ucc_status_t)*status; - } while(st != UCC_EC_ROCM_TASK_COMPLETED); + } while(st != (ucc_status_t)UCC_EC_ROCM_TASK_COMPLETED); *status = UCC_EC_ROCM_TASK_COMPLETED_ACK; return; }