Skip to content

Commit

Permalink
Urom cl support for sliding window allreduce, alltoall, and alltoallv
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Jan 26, 2024
1 parent 7543cb4 commit 6862aee
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 256 deletions.
4 changes: 1 addition & 3 deletions src/components/cl/urom/allreduce/allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/

#include "allreduce.h"
#include "../allreduce/allreduce.h"

ucc_base_coll_alg_info_t
ucc_cl_urom_allreduce_algs[UCC_CL_UROM_ALLREDUCE_ALG_LAST + 1] = {
Expand Down Expand Up @@ -107,8 +106,7 @@ static ucc_status_t ucc_cl_urom_allreduce_full_start(ucc_coll_task_t *task)
.ucc.cmd_type = UROM_WORKER_CMD_UCC_COLL,
.ucc.coll_cmd.coll_args = coll_args,
.ucc.coll_cmd.team = cl_team->teams[0],
//.ucc.coll_cmd.use_xgvmi = 0,
//.ucc.coll_cmd.use_sliding_window_allreduce = 1,
.ucc.coll_cmd.use_xgvmi = 1,
};
ucc_memory_type_t prev_src, prev_dst;
ucc_cl_urom_schedule_t *schedule =
Expand Down
69 changes: 69 additions & 0 deletions src/components/cl/urom/alltoall/alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,45 @@ ucc_base_coll_alg_info_t
[UCC_CL_UROM_ALLTOALL_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

static int buffer_export_ucc(ucp_context_h ucp_context, void *buf, size_t len,
struct export_buf *ebuf)
{
ucs_status_t ucs_status;
ucp_mem_map_params_t params;
ucp_memh_pack_params_t pack_params;

ebuf->ucp_context = ucp_context;

params.field_mask =
UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
params.address = buf;
params.length = len;

ucs_status = ucp_mem_map(ucp_context, &params, &ebuf->memh);
assert(ucs_status == UCS_OK);

pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT;

ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh,
&ebuf->packed_memh_len);
if (ucs_status != UCS_OK) {
printf("ucp_memh_pack() returned error: %s\n",
ucs_status_string(ucs_status));
ebuf->packed_memh = NULL;
ebuf->packed_memh_len = 0;
}
ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key,
&ebuf->packed_key_len);
if (UCS_OK != ucs_status) {
printf("ucp_rkey_pack() returned error: %s\n",
ucs_status_string(ucs_status));
return UROM_ERR_NO_RESOURCE;
}

return 0;
}

ucc_status_t ucc_cl_urom_alltoall_triggered_post_setup(ucc_coll_task_t *task)
{
return UCC_OK;
Expand Down Expand Up @@ -58,20 +97,38 @@ static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task)
ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
ucc_coll_args_t *coll_args = &task->bargs.args;
urom_status_t urom_status;
int ucp_index = cl_lib->tl_ucp_index;
ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t);
urom_worker_cmd_t coll_cmd = {
.cmd_type = UROM_WORKER_CMD_UCC,
.ucc.dpu_worker_id = UCC_CL_TEAM_RANK(cl_team),
.ucc.cmd_type = UROM_WORKER_CMD_UCC_COLL,
.ucc.coll_cmd.coll_args = coll_args,
.ucc.coll_cmd.team = cl_team->teams[0],
//.ucc.coll_cmd.use_xgvmi = ctx->xgvmi_enabled,
.ucc.coll_cmd.use_xgvmi = 1,
};
ucc_memory_type_t prev_src, prev_dst;

ucc_cl_urom_schedule_t *schedule =
ucc_derived_of(task, ucc_cl_urom_schedule_t);
struct export_buf *src_ebuf = &schedule->src_ebuf;
struct export_buf *dst_ebuf = &schedule->dst_ebuf;

prev_src = coll_args->src.info.mem_type;
prev_dst = coll_args->dst.info.mem_type;
coll_args->src.info.mem_type = UCC_MEMORY_TYPE_HOST;
coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_HOST;

buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->src.info.buffer, coll_args->src.info.count * dt_size(coll_args->src.info.datatype), src_ebuf);
buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->dst.info.buffer, coll_args->dst.info.count * dt_size(coll_args->dst.info.datatype), dst_ebuf);

coll_cmd.ucc.coll_cmd.src_memh_packed = src_ebuf->packed_memh;
coll_cmd.ucc.coll_cmd.src_memh_packed_len = src_ebuf->packed_memh_len;

coll_cmd.ucc.coll_cmd.dst_memh_packed = dst_ebuf->packed_memh;
coll_cmd.ucc.coll_cmd.dst_memh_packed_len = dst_ebuf->packed_memh_len;

urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd);
if (UROM_OK != urom_status) {
cl_debug(&cl_lib->super, "failed to push collective to urom");
Expand Down Expand Up @@ -110,6 +167,18 @@ static ucc_status_t ucc_cl_urom_alltoall_full_finalize(ucc_coll_task_t *task)
ucc_derived_of(task, ucc_cl_urom_schedule_t);
ucc_status_t status;

ucc_cl_urom_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_urom_team_t);
ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team);
ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t);
int ucp_index = cl_lib->tl_ucp_index;
ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t);

struct export_buf *src_ebuf = &schedule->src_ebuf;
struct export_buf *dst_ebuf = &schedule->dst_ebuf;

ucp_mem_unmap(tl_ctx->worker.ucp_context, src_ebuf->memh);
ucp_mem_unmap(tl_ctx->worker.ucp_context, dst_ebuf->memh);

status = ucc_schedule_finalize(task);
ucc_cl_urom_put_schedule(&schedule->super.super);
return status;
Expand Down
4 changes: 1 addition & 3 deletions src/components/cl/urom/alltoall/alltoall.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef UROM_ALLTOALL_H_
#define UROM_ALLTOALL_H_
#include "../cl_urom_coll.h"
#include "../../../tl/ucp/tl_ucp.h"

enum
{
Expand All @@ -17,9 +18,6 @@ enum
extern ucc_base_coll_alg_info_t
ucc_cl_urom_alltoall_algs[UCC_CL_UROM_ALLTOALL_ALG_LAST + 1];

ucc_status_t ucc_cl_urom_alltoall_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task);

static inline int ucc_cl_urom_alltoall_alg_from_str(const char *str)
{
Expand Down
Loading

0 comments on commit 6862aee

Please sign in to comment.