From 2d9f3845331fc9a1072881874975268d9062d085 Mon Sep 17 00:00:00 2001
From: Nick Sarkauskas
Date: Wed, 8 Nov 2023 09:00:51 -0800
Subject: [PATCH] Add sliding window support to urom cl

---
 src/components/cl/urom/Makefile.am           |   7 +-
 src/components/cl/urom/allreduce/allreduce.c | 273 ++++++++++++++++++-
 src/components/cl/urom/allreduce/allreduce.h |  38 +--
 src/components/cl/urom/cl_urom_coll.c        |   6 +
 src/components/cl/urom/cl_urom_coll.h        |  12 +
 src/components/cl/urom/cl_urom_context.c     |   2 +-
 src/components/cl/urom/cl_urom_team.c        |   2 +-
 7 files changed, 302 insertions(+), 38 deletions(-)

diff --git a/src/components/cl/urom/Makefile.am b/src/components/cl/urom/Makefile.am
index 8562ca5e58..331b00d015 100644
--- a/src/components/cl/urom/Makefile.am
+++ b/src/components/cl/urom/Makefile.am
@@ -6,6 +6,10 @@ alltoall = \
     alltoall/alltoall.h \
     alltoall/alltoall.c
 
+allreduce = \
+    allreduce/allreduce.h \
+    allreduce/allreduce.c
+
 sources = \
     cl_urom.h \
     cl_urom.c \
@@ -13,7 +17,8 @@ sources = \
     cl_urom_context.c \
     cl_urom_team.c \
     cl_urom_coll.c \
-    $(alltoall)
+    $(alltoall) \
+    $(allreduce)
 
 module_LTLIBRARIES = libucc_cl_urom.la
 libucc_cl_urom_la_SOURCES = $(sources)
diff --git a/src/components/cl/urom/allreduce/allreduce.c b/src/components/cl/urom/allreduce/allreduce.c
index c69cc4db36..3488ddfdee 100644
--- a/src/components/cl/urom/allreduce/allreduce.c
+++ b/src/components/cl/urom/allreduce/allreduce.c
@@ -8,16 +8,265 @@
 #include "../allreduce/allreduce.h"
 
 ucc_base_coll_alg_info_t
-    ucc_cl_hier_allreduce_algs[UCC_CL_HIER_ALLREDUCE_ALG_LAST + 1] = {
-        [UCC_CL_HIER_ALLREDUCE_ALG_RAB] =
-            {.id   = UCC_CL_HIER_ALLREDUCE_ALG_RAB,
-             .name = "rab",
-             .desc = "intra-node reduce, followed by inter-node allreduce,"
-                     " followed by innode broadcast"},
-        [UCC_CL_HIER_ALLREDUCE_ALG_SPLIT_RAIL] =
-            {.id   = UCC_CL_HIER_ALLREDUCE_ALG_SPLIT_RAIL,
-             .name = "split_rail",
-             .desc = "intra-node reduce_scatter, followed by PPN concurrent "
-                     " inter-node allreduces, followed by intra-node allgather"},
-        [UCC_CL_HIER_ALLREDUCE_ALG_LAST] = {
+    ucc_cl_urom_allreduce_algs[UCC_CL_UROM_ALLREDUCE_ALG_LAST + 1] = {
+        [UCC_CL_UROM_ALLREDUCE_ALG_FULL] =
+            {.id   = UCC_CL_UROM_ALLREDUCE_ALG_FULL,
+             .name = "urom_full_offload",
+             .desc = "full offload of allreduce"},
+        [UCC_CL_UROM_ALLREDUCE_ALG_LAST] = {
             .id = 0, .name = NULL, .desc = NULL}};
+
+
+static int buffer_export_ucc(ucp_context_h ucp_context, void *buf, size_t len,
+                             struct export_buf *ebuf)
+{
+    ucs_status_t           ucs_status;
+    ucp_mem_map_params_t   params;
+    ucp_memh_pack_params_t pack_params;
+
+    ebuf->ucp_context = ucp_context;
+
+    params.field_mask =
+        UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
+    params.address = buf;
+    params.length  = len;
+
+    ucs_status = ucp_mem_map(ucp_context, &params, &ebuf->memh);
+    assert(ucs_status == UCS_OK);
+
+    pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
+    pack_params.flags      = UCP_MEMH_PACK_FLAG_EXPORT;
+
+    ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh,
+                               &ebuf->packed_memh_len);
+    if (ucs_status != UCS_OK) {
+        printf("ucp_memh_pack() returned error: %s\n",
+               ucs_status_string(ucs_status));
+        ebuf->packed_memh     = NULL;
+        ebuf->packed_memh_len = 0;
+    }
+    ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key,
+                               &ebuf->packed_key_len);
+    if (UCS_OK != ucs_status) {
+        printf("ucp_rkey_pack() returned error: %s\n",
+               ucs_status_string(ucs_status));
+        return UROM_ERR_NO_RESOURCE;
+    }
+
+    return 0;
+}
+
+ucc_status_t ucc_cl_urom_allreduce_triggered_post_setup(ucc_coll_task_t *task)
+{
+    return UCC_OK;
+}
+
+static size_t dt_size(ucc_datatype_t ucc_dt)
+{
+    size_t size_mod = 8;
+
+    switch(ucc_dt) {
+        case UCC_DT_INT8:
+        case UCC_DT_UINT8:
+            size_mod = sizeof(char);
+            break;
+        case UCC_DT_INT32:
+        case UCC_DT_UINT32:
+        case UCC_DT_FLOAT32:
+            size_mod = sizeof(int);
+            break;
+        case UCC_DT_INT64:
+        case UCC_DT_UINT64:
+        case UCC_DT_FLOAT64:
+            size_mod = sizeof(uint64_t);
+            break;
+        case UCC_DT_INT128:
+        case UCC_DT_UINT128:
+        case UCC_DT_FLOAT128:
+            size_mod = sizeof(__int128_t);
+            break;
+        default:
+            break;
+    }
+
+    return size_mod;
+}
+
+static ucc_status_t ucc_cl_urom_allreduce_full_start(ucc_coll_task_t *task)
+{
+    ucc_cl_urom_team_t    *cl_team   = ucc_derived_of(task->team, ucc_cl_urom_team_t);
+    ucc_cl_urom_context_t *ctx       = UCC_CL_UROM_TEAM_CTX(cl_team);
+    ucc_cl_urom_lib_t     *cl_lib    = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
+    ucc_coll_args_t       *coll_args = &task->bargs.args;
+    urom_status_t          urom_status;
+    int                    ucp_index = cl_lib->tl_ucp_index;
+    ucc_tl_ucp_context_t  *tl_ctx    = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t);
+    urom_worker_cmd_t coll_cmd = {
+        .cmd_type                                  = UROM_WORKER_CMD_UCC,
+        .ucc.dpu_worker_id                         = UCC_CL_TEAM_RANK(cl_team),
+        .ucc.cmd_type                              = UROM_WORKER_CMD_UCC_COLL,
+        .ucc.coll_cmd.coll_args                    = coll_args,
+        .ucc.coll_cmd.team                         = cl_team->teams[0],
+        .ucc.coll_cmd.use_xgvmi                    = 0,
+        .ucc.coll_cmd.use_sliding_window_allreduce = 1,
+    };
+    ucc_memory_type_t prev_src, prev_dst;
+    ucc_cl_urom_schedule_t *schedule =
+        ucc_derived_of(task, ucc_cl_urom_schedule_t);
+    struct export_buf *src_ebuf = &schedule->src_ebuf;
+    struct export_buf *dst_ebuf = &schedule->dst_ebuf;
+
+    prev_src = coll_args->src.info.mem_type;
+    prev_dst = coll_args->dst.info.mem_type;
+    coll_args->src.info.mem_type = UCC_MEMORY_TYPE_HOST;
+    coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
+
+    buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->src.info.buffer, coll_args->src.info.count * dt_size(coll_args->src.info.datatype), src_ebuf);
+    buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->dst.info.buffer, coll_args->dst.info.count * dt_size(coll_args->dst.info.datatype), dst_ebuf);
+
+    memcpy(coll_cmd.ucc.coll_cmd.src_rkey_packed, src_ebuf->packed_key, src_ebuf->packed_key_len);
+    coll_cmd.ucc.coll_cmd.src_rkey_packed_len = src_ebuf->packed_key_len;
+
+    memcpy(coll_cmd.ucc.coll_cmd.dst_rkey_packed, dst_ebuf->packed_key, dst_ebuf->packed_key_len);
+    coll_cmd.ucc.coll_cmd.dst_rkey_packed_len = dst_ebuf->packed_key_len;
+
+    urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd);
+    if (UROM_OK != urom_status) {
+        cl_debug(&cl_lib->super, "failed to push collective to urom");
+        return UCC_ERR_NO_MESSAGE;
+    }
+    coll_args->src.info.mem_type = prev_src;
+    coll_args->dst.info.mem_type = prev_dst;
+
+/*
+    if (coll_args->src.info.mem_type != UCC_MEMORY_TYPE_CUDA) {
+        urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd);
+        if (UROM_OK != urom_status) {
+            cl_debug(&cl_lib->super, "failed to push collective to urom");
+            return UCC_ERR_NO_MESSAGE;
+        }
+    } else {
+        coll_args->src.info.mem_type = UCC_MEMORY_TYPE_HOST;
+        coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
+        urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd);
+        if (UROM_OK != urom_status) {
+            cl_debug(&cl_lib->super, "failed to push collective to urom");
+            return UCC_ERR_NO_MESSAGE;
+        }
+        coll_args->src.info.mem_type = UCC_MEMORY_TYPE_CUDA;
+        coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_CUDA;
+    }
+*/
+    task->status = UCC_INPROGRESS;
+    cl_debug(&cl_lib->super, "pushed the collective to urom");
+    return ucc_progress_queue_enqueue(ctx->super.super.ucc_context->pq, task);
+}
+
+static ucc_status_t ucc_cl_urom_allreduce_full_finalize(ucc_coll_task_t *task)
+{
+    ucc_cl_urom_team_t    *cl_team   = ucc_derived_of(task->team, ucc_cl_urom_team_t);
+    ucc_cl_urom_context_t *ctx       = UCC_CL_UROM_TEAM_CTX(cl_team);
+    ucc_cl_urom_lib_t     *cl_lib    = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
+    int                    ucp_index = cl_lib->tl_ucp_index;
+    ucc_tl_ucp_context_t  *tl_ctx    = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t);
+    ucc_status_t           status;
+    ucc_cl_urom_schedule_t *schedule =
+        ucc_derived_of(task, ucc_cl_urom_schedule_t);
+    struct export_buf *src_ebuf = &schedule->src_ebuf;
+    struct export_buf *dst_ebuf = &schedule->dst_ebuf;
+
+    ucp_mem_unmap(tl_ctx->worker.ucp_context, src_ebuf->memh);
+    ucp_mem_unmap(tl_ctx->worker.ucp_context, dst_ebuf->memh);
+
+    status = ucc_schedule_finalize(task);
+    ucc_cl_urom_put_schedule(&schedule->super.super);
+    return status;
+}
+
+static void ucc_cl_urom_allreduce_full_progress(ucc_coll_task_t *ctask)
+{
+    ucc_cl_urom_team_t    *cl_team = ucc_derived_of(ctask->team, ucc_cl_urom_team_t);
+    ucc_cl_urom_context_t *ctx     = UCC_CL_UROM_TEAM_CTX(cl_team);
+    ucc_cl_urom_lib_t     *cl_lib  = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
+    urom_status_t          urom_status = 0;
+    urom_worker_notify_t  *notif;
+
+    urom_status = urom_worker_pop_notifyq(cl_lib->urom_ctx.urom_worker, 0, &notif);
+    if (UROM_ERR_QUEUE_EMPTY == urom_status) {
+        return;
+    }
+
+    if (urom_status < 0) {
+        cl_error(cl_lib, "error in urom worker notify queue: %d", (int)urom_status);
+        ctask->status = UCC_ERR_NO_MESSAGE;
+        return;
+    }
+
+    if (notif->notify_type != UROM_WORKER_NOTIFY_UCC) {
+        cl_debug(cl_lib, "WRONG NOTIFICATION (%ld != %d)", notif->notify_type, UROM_WORKER_NOTIFY_UCC);
+        return;
+    }
+
+    if (ctx->req_mc) {
+        size_t size_mod = dt_size(ctask->bargs.args.dst.info.datatype);
+
+        if ((ucc_status_t) notif->ucc.status == UCC_OK) {
+            ucc_mc_memcpy(ctx->old_dest, ctask->bargs.args.dst.info.buffer, ctask->bargs.args.dst.info.count * size_mod, ctask->bargs.args.dst.info.mem_type, UCC_MEMORY_TYPE_HOST);
+            ctask->bargs.args.dst.info.buffer = ctx->old_dest;
+            ctask->bargs.args.src.info.buffer = ctx->old_src;
+        }
+    }
+    cl_debug(&cl_lib->super, "completed the collective from urom");
+
+    ctask->status = (ucc_status_t) notif->ucc.status;
+}
+
+ucc_status_t ucc_cl_urom_allreduce_full_init(
+    ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
+    ucc_coll_task_t **task)
+{
+    ucc_cl_urom_team_t    *cl_team = ucc_derived_of(team, ucc_cl_urom_team_t);
+    ucc_cl_urom_context_t *ctx     = UCC_CL_UROM_TEAM_CTX(cl_team);
+    ucc_cl_urom_lib_t     *cl_lib  = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
+
+    ucc_cl_urom_schedule_t *cl_schedule;
+    ucc_base_coll_args_t    args;
+    ucc_schedule_t         *schedule;
+    ucc_status_t            status;
+
+    cl_schedule = ucc_cl_urom_get_schedule(cl_team);
+    if (ucc_unlikely(!cl_schedule)) {
+        return UCC_ERR_NO_MEMORY;
+    }
+    schedule = &cl_schedule->super.super;
+    if (ctx->req_mc) {
+        size_t size_mod = dt_size(coll_args->args.src.info.datatype);
+        size_t count    = coll_args->args.src.info.count * size_mod;
+        /* memcpy args to xgvmi buffer */
+        cl_debug(cl_lib, "memcpy args to xgvmi buffer");
+        void *ptr = ctx->xgvmi.xgvmi_buffer + (cl_lib->cfg.xgvmi_buffer_size * (schedule->super.seq_num % cl_lib->cfg.num_buffers));
+        ucc_mc_memcpy(ptr, coll_args->args.src.info.buffer, count, UCC_MEMORY_TYPE_HOST, coll_args->args.src.info.mem_type);
+
+        ctx->old_src = coll_args->args.src.info.buffer;
+        coll_args->args.src.info.buffer = ptr;
+        ctx->old_dest = coll_args->args.dst.info.buffer;
+        coll_args->args.dst.info.buffer = ptr + count;
+    }
+    memcpy(&args, coll_args, sizeof(args));
+    status = ucc_schedule_init(schedule, &args, team);
+    if (UCC_OK != status) {
+        ucc_cl_urom_put_schedule(schedule);
+        return status;
+    }
+
+    schedule->super.post           = ucc_cl_urom_allreduce_full_start;
+    schedule->super.progress       = ucc_cl_urom_allreduce_full_progress;
+    schedule->super.finalize       = ucc_cl_urom_allreduce_full_finalize;
+    schedule->super.triggered_post = ucc_triggered_post;
+    schedule->super.triggered_post_setup =
+        ucc_cl_urom_allreduce_triggered_post_setup;
+
+    *task = &schedule->super;
+    cl_debug(cl_lib, "urom coll init'd");
+    return UCC_OK;
+}
diff --git a/src/components/cl/urom/allreduce/allreduce.h b/src/components/cl/urom/allreduce/allreduce.h
index 0549f4ae3c..5d7ad880db 100644
--- a/src/components/cl/urom/allreduce/allreduce.h
+++ b/src/components/cl/urom/allreduce/allreduce.h
@@ -1,44 +1,36 @@
 /**
- * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
 
-#ifndef ALLREDUCE_H_
-#define ALLREDUCE_H_
-#include "../cl_hier.h"
+#ifndef UROM_ALLREDUCE_H_
+#define UROM_ALLREDUCE_H_
+#include "../cl_urom_coll.h"
+#include "../../../tl/ucp/tl_ucp.h"
 
 enum {
-    UCC_CL_HIER_ALLREDUCE_ALG_RAB,
-    UCC_CL_HIER_ALLREDUCE_ALG_SPLIT_RAIL,
-    UCC_CL_HIER_ALLREDUCE_ALG_LAST,
+    UCC_CL_UROM_ALLREDUCE_ALG_FULL,
+    UCC_CL_UROM_ALLREDUCE_ALG_LAST,
 };
 
 extern ucc_base_coll_alg_info_t
-    ucc_cl_hier_allreduce_algs[UCC_CL_HIER_ALLREDUCE_ALG_LAST + 1];
+    ucc_cl_urom_allreduce_algs[UCC_CL_UROM_ALLREDUCE_ALG_LAST + 1];
 
-#define UCC_CL_HIER_ALLREDUCE_DEFAULT_ALG_SELECT_STR "allreduce:0-4k:@rab"
+ucc_status_t ucc_cl_urom_allreduce_init(ucc_base_coll_args_t *coll_args,
+                                        ucc_base_team_t * team,
+                                        ucc_coll_task_t ** task);
 
-ucc_status_t ucc_cl_hier_allreduce_rab_init(ucc_base_coll_args_t *coll_args,
-                                            ucc_base_team_t *team,
-                                            ucc_coll_task_t **task);
-
-ucc_status_t
-ucc_cl_hier_allreduce_split_rail_init(ucc_base_coll_args_t *coll_args,
-                                      ucc_base_team_t *team,
-                                      ucc_coll_task_t **task);
-
-static inline int ucc_cl_hier_allreduce_alg_from_str(const char *str)
+static inline int ucc_cl_urom_allreduce_alg_from_str(const char *str)
 {
     int i;
-
-    for (i = 0; i < UCC_CL_HIER_ALLREDUCE_ALG_LAST; i++) {
-        if (0 == strcasecmp(str, ucc_cl_hier_allreduce_algs[i].name)) {
+    for (i = 0; i < UCC_CL_UROM_ALLREDUCE_ALG_LAST; i++) {
+        if (0 == strcasecmp(str, ucc_cl_urom_allreduce_algs[i].name)) {
             break;
         }
     }
     return i;
 }
 
-#endif
+#endif
\ No newline at end of file
diff --git a/src/components/cl/urom/cl_urom_coll.c b/src/components/cl/urom/cl_urom_coll.c
index 52decf3613..39fccaddc8 100644
--- a/src/components/cl/urom/cl_urom_coll.c
+++ b/src/components/cl/urom/cl_urom_coll.c
@@ -16,6 +16,10 @@ ucc_status_t ucc_cl_urom_alltoall_full_init(
     ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
     ucc_coll_task_t **task);
 
+ucc_status_t ucc_cl_urom_allreduce_full_init(
+    ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
+    ucc_coll_task_t **task);
+
 ucc_status_t ucc_cl_urom_coll_init(ucc_base_coll_args_t *coll_args,
                                    ucc_base_team_t *team,
                                    ucc_coll_task_t **task)
@@ -49,6 +53,8 @@ ucc_status_t ucc_cl_urom_coll_init(ucc_base_coll_args_t *coll_args,
     switch (coll_args->args.coll_type) {
     case UCC_COLL_TYPE_ALLTOALL:
         return ucc_cl_urom_alltoall_full_init(coll_args, team, task);
+    case UCC_COLL_TYPE_ALLREDUCE:
+        return ucc_cl_urom_allreduce_full_init(coll_args, team, task);
     default:
         cl_error(urom_lib, "coll_type %s is not supported", ucc_coll_type_str(coll_args->args.coll_type));
         break;
diff --git a/src/components/cl/urom/cl_urom_coll.h b/src/components/cl/urom/cl_urom_coll.h
index 544f3844d9..98a6ef39cd 100644
--- a/src/components/cl/urom/cl_urom_coll.h
+++ b/src/components/cl/urom/cl_urom_coll.h
@@ -16,9 +16,21 @@
 extern const char
     *ucc_cl_urom_default_alg_select_str[UCC_CL_UROM_N_DEFAULT_ALG_SELECT_STR];
 
+struct export_buf {
+    ucp_context_h ucp_context;
+    ucp_mem_h     memh;
+    void         *packed_memh;
+    size_t        packed_memh_len;
+    void         *packed_key;
+    size_t        packed_key_len;
+    uint64_t      memh_id;
+};
+
 typedef struct ucc_cl_urom_schedule_t {
     ucc_schedule_pipelined_t super;
     ucc_mc_buffer_header_t  *scratch;
+    struct export_buf        src_ebuf;
+    struct export_buf        dst_ebuf;
 } ucc_cl_urom_schedule_t;
 
 static inline ucc_cl_urom_schedule_t *
diff --git a/src/components/cl/urom/cl_urom_context.c b/src/components/cl/urom/cl_urom_context.c
index ac62e75632..7b4bc27a4f 100644
--- a/src/components/cl/urom/cl_urom_context.c
+++ b/src/components/cl/urom/cl_urom_context.c
@@ -139,7 +139,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_urom_context_t,
     urom_domain_params.workers = &urom_lib->urom_ctx.urom_worker;
     urom_domain_params.num_workers = 1,
     urom_domain_params.domain_size = params->params.oob.n_oob_eps;
-    self->req_mc = 1; /* requires a memcpy */
+    self->req_mc = 0; /* no staging memcpy required */
 
     if (params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB &&
         params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) {
diff --git a/src/components/cl/urom/cl_urom_team.c b/src/components/cl/urom/cl_urom_team.c
index 1fc3f77e0f..932460616d 100644
--- a/src/components/cl/urom/cl_urom_team.c
+++ b/src/components/cl/urom/cl_urom_team.c
@@ -80,7 +80,7 @@ ucc_status_t ucc_cl_urom_team_create_test(ucc_base_team_t *cl_team)
             team->teams[team->n_teams] = notif->ucc.team_create_nqe.team;
             ++team->n_teams;
             ucc_status = ucc_coll_score_build_default(cl_team, UCC_CL_UROM_DEFAULT_SCORE,
-                                          ucc_cl_urom_coll_init, UCC_COLL_TYPE_ALLTOALL,
+                                          ucc_cl_urom_coll_init, UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_ALLREDUCE,
                                           mem_types, mt_n, &score);
             if (UCC_OK != ucc_status) {
                 return ucc_status;