Commit f86fd47

fixes for oshmem

ferrol aderholdt committed Sep 13, 2023
1 parent f677ee1 commit f86fd47

Showing 3 changed files with 139 additions and 136 deletions.
src/components/cl/urom/alltoall/alltoall.c: 160 changes (80 additions, 80 deletions)

@@ -21,6 +21,37 @@ ucc_status_t ucc_cl_urom_alltoall_triggered_post_setup(ucc_coll_task_t *task)
     return UCC_OK;
 }

+static size_t dt_size(ucc_datatype_t ucc_dt)
+{
+    size_t size_mod = 8;
+
+    switch(ucc_dt) {
+    case UCC_DT_INT8:
+    case UCC_DT_UINT8:
+        size_mod = sizeof(char);
+        break;
+    case UCC_DT_INT32:
+    case UCC_DT_UINT32:
+    case UCC_DT_FLOAT32:
+        size_mod = sizeof(int);
+        break;
+    case UCC_DT_INT64:
+    case UCC_DT_UINT64:
+    case UCC_DT_FLOAT64:
+        size_mod = sizeof(uint64_t);
+        break;
+    case UCC_DT_INT128:
+    case UCC_DT_UINT128:
+    case UCC_DT_FLOAT128:
+        size_mod = sizeof(__int128_t);
+        break;
+    default:
+        break;
+    }
+
+    return size_mod;
+}
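
The new dt_size() helper replaces two identical switch statements that were previously inlined in the progress and init paths below; anything it does not recognize silently falls back to 8 bytes. A standalone sketch of the mapping and of how the byte count for the staging copies is derived from it (the DT_* enum is an illustrative stand-in for the real UCC_DT_* values):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Illustrative stand-in for the UCC datatype enum (UCC_DT_*). */
    typedef enum { DT_INT8, DT_FLOAT32, DT_FLOAT64, DT_FLOAT128 } dt_t;

    /* Same shape as the dt_size() added above: datatype -> element size,
     * with an 8-byte fallback for anything unrecognized. */
    static size_t dt_size(dt_t dt)
    {
        switch (dt) {
        case DT_INT8:     return sizeof(char);       /* 1  */
        case DT_FLOAT32:  return sizeof(int);        /* 4  */
        case DT_FLOAT64:  return sizeof(uint64_t);   /* 8  */
        case DT_FLOAT128: return sizeof(__int128_t); /* 16, GCC/clang only */
        default:          return 8;                  /* fallback */
        }
    }

    int main(void)
    {
        /* 1024 float64 elements -> 1024 * 8 = 8192 bytes, the same
         * count * size_mod product the collective paths below compute. */
        printf("%zu bytes\n", (size_t)1024 * dt_size(DT_FLOAT64));
        return 0;
    }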

 static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task)
 {
     ucc_cl_urom_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_urom_team_t);
@@ -37,7 +68,14 @@ static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task)
         .ucc.coll_cmd.use_xgvmi = cl_lib->xgvmi_enabled,
     };
 
-    if (coll_args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) {
+    if (coll_args->src.info.mem_type != UCC_MEMORY_TYPE_CUDA) {
+        urom_status = urom_worker_push_cmdq(cl_lib->urom_worker, 0, &coll_cmd);
+        if (UROM_OK != urom_status) {
+            cl_debug(&cl_lib->super, "failed to push collective to urom");
+            return UCC_ERR_NO_MESSAGE;
+        }
+    } else {
+#if HAVE_CUDA
         // FIXME: a better way is to tweak args in urom
         cudaStreamSynchronize(cl_lib->cuda_stream);
 
@@ -50,12 +88,10 @@ static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task)
         }
         coll_args->src.info.mem_type = UCC_MEMORY_TYPE_CUDA;
         coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_CUDA;
-    } else {
-        urom_status = urom_worker_push_cmdq(cl_lib->urom_worker, 0, &coll_cmd);
-        if (UROM_OK != urom_status) {
-            cl_debug(&cl_lib->super, "failed to push collective to urom");
-            return UCC_ERR_NO_MESSAGE;
-        }
+#else
+        cl_error(&cl_lib->super, "attempting to use CUDA without CUDA support");
+        return UCC_ERR_NO_RESOURCE;
+#endif
     }
     task->status = UCC_INPROGRESS;
     cl_debug(&cl_lib->super, "pushed the collective to urom");
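
With this change, full_start pushes the command directly for host memory and only enters the CUDA staging path under #if HAVE_CUDA, so builds without CUDA fail the request with UCC_ERR_NO_RESOURCE instead of referencing unavailable CUDA symbols. A minimal self-contained sketch of the same guard pattern (function and status names here are illustrative, not the commit's):

    #include <stdio.h>

    #ifndef HAVE_CUDA
    #define HAVE_CUDA 0          /* normally supplied by the build system */
    #endif

    typedef enum { ST_OK = 0, ST_ERR_NO_RESOURCE = -2 } st_t;

    /* The CUDA call only exists in CUDA-enabled builds; otherwise the
     * request fails cleanly, mirroring the #else branch above. */
    static st_t sync_stream_if_supported(void)
    {
    #if HAVE_CUDA
        cudaStreamSynchronize(0);            /* flush pending device work */
        return ST_OK;
    #else
        fprintf(stderr, "attempting to use CUDA without CUDA support\n");
        return ST_ERR_NO_RESOURCE;
    #endif
    }

    int main(void) { return sync_stream_if_supported() == ST_OK ? 0 : 1; }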
@@ -97,43 +133,25 @@ static void ucc_cl_urom_alltoall_full_progress(ucc_coll_task_t *ctask)
         return;
     }
 
-    if (1 || cl_lib->xgvmi_enabled) {
-        size_t size_mod = 8;
-
-        switch(ctask->bargs.args.src.info.datatype) {
-        case UCC_DT_INT8:
-        case UCC_DT_UINT8:
-            size_mod = sizeof(char);
-            break;
-        case UCC_DT_INT32:
-        case UCC_DT_UINT32:
-        case UCC_DT_FLOAT32:
-            size_mod = sizeof(int);
-            break;
-        case UCC_DT_INT64:
-        case UCC_DT_UINT64:
-        case UCC_DT_FLOAT64:
-            size_mod = sizeof(uint64_t);
-            break;
-        case UCC_DT_INT128:
-        case UCC_DT_UINT128:
-        case UCC_DT_FLOAT128:
-            size_mod = sizeof(__int128_t);
-            break;
-        default:
-            printf("**** SCALAR UNKNOWN: %ld\n", ctask->bargs.args.src.info.datatype);
-            break;
-        }
-
-        if (ctask->bargs.args.dst.info.mem_type == UCC_MEMORY_TYPE_CUDA) {
-            cudaMemcpyAsync(cl_lib->old_dest, ctask->bargs.args.dst.info.buffer, ctask->bargs.args.src.info.count * size_mod, cudaMemcpyHostToDevice, cl_lib->cuda_stream);
-            cudaStreamSynchronize(cl_lib->cuda_stream);
-        } else {
-            memcpy(cl_lib->old_dest, ctask->bargs.args.dst.info.buffer, ctask->bargs.args.src.info.count * size_mod);
-        }
-
-        ctask->bargs.args.dst.info.buffer = cl_lib->old_dest;
-        ctask->bargs.args.src.info.buffer = cl_lib->old_src;
+    if (cl_lib->xgvmi_enabled) {
+        size_t size_mod = dt_size(ctask->bargs.args.dst.info.datatype);
+
+        if (cl_lib->req_mc) {
+            if (ctask->bargs.args.dst.info.mem_type != UCC_MEMORY_TYPE_CUDA) {
+                memcpy(cl_lib->old_dest, ctask->bargs.args.dst.info.buffer, ctask->bargs.args.src.info.count * size_mod);
+            } else {
+#if HAVE_CUDA
+                cudaMemcpyAsync(cl_lib->old_dest, ctask->bargs.args.dst.info.buffer, ctask->bargs.args.src.info.count * size_mod, cudaMemcpyHostToDevice, cl_lib->cuda_stream);
+                cudaStreamSynchronize(cl_lib->cuda_stream);
+#else
+                cl_error(&cl_lib->super, "attempting to use CUDA without CUDA support");
+                return UCC_ERR_NO_RESOURCE;
+#endif
+            }
+
+            ctask->bargs.args.dst.info.buffer = cl_lib->old_dest;
+            ctask->bargs.args.src.info.buffer = cl_lib->old_src;
+        }
     }
     cl_debug(&cl_lib->super, "completed the collective from urom");
 
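In the progress path the direction reverses: when req_mc is set, the worker's result lands in the scratch buffer, so count * size_mod bytes are copied back into the caller's saved destination and the original src/dst pointers are restored in the task arguments. A host-only sketch of that un-staging step (the struct names are illustrative stand-ins for the diff's cl_lib and bargs fields):

    #include <stdio.h>
    #include <string.h>

    /* Illustrative stand-ins for cl_lib and ctask->bargs.args fields. */
    struct lib_state { void *old_src, *old_dest; };
    struct task_args { void *src_buf, *dst_buf; size_t count, elem_size; };

    /* Copy the worker's result out of the scratch buffer, then hand the
     * caller back its own pointers, in the same order as the diff. */
    static void unstage_result(struct lib_state *lib, struct task_args *a)
    {
        memcpy(lib->old_dest, a->dst_buf, a->count * a->elem_size);
        a->dst_buf = lib->old_dest;     /* restore user's destination */
        a->src_buf = lib->old_src;      /* restore user's source */
    }

    int main(void)
    {
        char scratch[8] = "result!", user_dst[8] = {0}, user_src[8] = {0};
        struct lib_state lib = { user_src, user_dst };
        struct task_args a   = { scratch, scratch, 8, 1 };

        unstage_result(&lib, &a);
        printf("%s\n", user_dst);       /* prints "result!" */
        return 0;
    }
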
@@ -157,44 +175,26 @@ ucc_status_t ucc_cl_urom_alltoall_full_init(
         return UCC_ERR_NO_MEMORY;
     }
     schedule = &cl_schedule->super.super;
-    if (1 || cl_lib->xgvmi_enabled) {
-        size_t size_mod = 8;
-        switch(coll_args->args.src.info.datatype) {
-        case UCC_DT_INT8:
-        case UCC_DT_UINT8:
-            size_mod = sizeof(char);
-            break;
-        case UCC_DT_INT32:
-        case UCC_DT_UINT32:
-        case UCC_DT_FLOAT32:
-            size_mod = sizeof(int);
-            break;
-        case UCC_DT_INT64:
-        case UCC_DT_UINT64:
-        case UCC_DT_FLOAT64:
-            size_mod = sizeof(uint64_t);
-            break;
-        case UCC_DT_INT128:
-        case UCC_DT_UINT128:
-        case UCC_DT_FLOAT128:
-            size_mod = sizeof(__int128_t);
-            break;
-        default:
-            printf("**** SCALAR UNKNOWN: %ld\n", coll_args->args.src.info.datatype);
-            break;
-        }
-        //memcpy args to xgvmi buffer
-        void * ptr = cl_lib->xgvmi_buffer + (cl_lib->cfg.xgvmi_buffer_size * (schedule->super.seq_num % cl_lib->cfg.num_buffers));
-        if (coll_args->args.src.info.mem_type == UCC_MEMORY_TYPE_CUDA) {
-            cudaMemcpyAsync(ptr, coll_args->args.src.info.buffer, coll_args->args.src.info.count * size_mod, cudaMemcpyDeviceToHost, cl_lib->cuda_stream);
-        } else {
-            memcpy(ptr, coll_args->args.src.info.buffer, coll_args->args.src.info.count * size_mod);
-        }
-
-        cl_lib->old_src = coll_args->args.src.info.buffer;
-        coll_args->args.src.info.buffer = ptr;
-        cl_lib->old_dest = coll_args->args.dst.info.buffer;
-        coll_args->args.dst.info.buffer = ptr + coll_args->args.src.info.count * size_mod;
+    if (cl_lib->xgvmi_enabled) {
+        size_t size_mod = dt_size(coll_args->args.src.info.datatype);
+        if (cl_lib->req_mc) {
+            //memcpy args to xgvmi buffer
+            void * ptr = cl_lib->xgvmi_buffer + (cl_lib->cfg.xgvmi_buffer_size * (schedule->super.seq_num % cl_lib->cfg.num_buffers));
+            if (coll_args->args.src.info.mem_type != UCC_MEMORY_TYPE_CUDA) {
+                memcpy(ptr, coll_args->args.src.info.buffer, coll_args->args.src.info.count * size_mod);
+            } else {
+#if HAVE_CUDA
+                cudaMemcpyAsync(ptr, coll_args->args.src.info.buffer, coll_args->args.src.info.count * size_mod, cudaMemcpyDeviceToHost, cl_lib->cuda_stream);
+#else
+                cl_error(&cl_lib->super, "attempting to use CUDA without CUDA support");
+                return UCC_ERR_NO_RESOURCE;
+#endif
+            }
+            cl_lib->old_src = coll_args->args.src.info.buffer;
+            coll_args->args.src.info.buffer = ptr;
+            cl_lib->old_dest = coll_args->args.dst.info.buffer;
+            coll_args->args.dst.info.buffer = ptr + coll_args->args.src.info.count * size_mod;
+        }
     }
     memcpy(&args, coll_args, sizeof(args));
     status = ucc_schedule_init(schedule, &args, team);
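
The init path above stages the user's source into one slot of the scratch ring, ptr = xgvmi_buffer + xgvmi_buffer_size * (seq_num % num_buffers), and parks the staged destination right behind the staged source inside the same slot. A sketch of the slot arithmetic (sizes are illustrative; note the commit indexes a void * directly, a GCC extension, while the sketch casts to char * for portability):

    #include <stdint.h>
    #include <stdio.h>

    /* Slot selection for the scratch ring used by ..._full_init: one slot
     * per in-flight schedule, reused modulo num_slots. */
    static void *slot_base(void *ring, size_t slot_size, size_t num_slots,
                           uint64_t seq_num)
    {
        return (char *)ring + slot_size * (seq_num % num_slots);
    }

    int main(void)
    {
        static char ring[4 * 1024];   /* 4 slots x 1 KiB, illustrative */
        uint64_t    seq       = 5;    /* schedule->super.seq_num */
        size_t      src_bytes = 256;  /* count * dt_size(datatype) */
        char       *src_stage = slot_base(ring, 1024, 4, seq); /* slot 1 */
        char       *dst_stage = src_stage + src_bytes;  /* dst behind src */

        printf("slot %llu: src at +%td, dst at +%td\n",
               (unsigned long long)(seq % 4),
               src_stage - ring, dst_stage - ring);
        return 0;
    }
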
src/components/cl/urom/cl_urom.h: 1 change (1 addition, 0 deletions)

@@ -65,6 +65,7 @@ typedef struct ucc_cl_urom_lib {
     uint64_t packed_xgvmi_len;
     void    *xgvmi_buffer;
     size_t   xgvmi_size;
+    int      req_mc;
     void    *old_dest;
     void    *old_src;
     int      xgvmi_offsets[NUM_OFFSETS];
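
This single new field is the switch the alltoall code above keys on: req_mc == 1 means user buffers must be staged through the CL's scratch buffer with an explicit copy, while req_mc == 0 means XGVMI lets the worker operate on the exported segment directly. A compact sketch of how the context-init code below derives it (parameter names are illustrative; the logic follows the diff):

    #include <stddef.h>

    /* How cl_urom_context.c below derives the flag (an illustration, not
     * a drop-in excerpt; requiring the memcpy is the safe default). */
    int requires_memcpy(int cfg_use_xgvmi, size_t n_segments)
    {
        int req_mc = 1;                         /* default: stage via memcpy */
        if (cfg_use_xgvmi || n_segments == 0) { /* reuse the caller's segment */
            if (cfg_use_xgvmi) {
                req_mc = 0;  /* worker targets the exported segment directly */
            }
        }                    /* else: private scratch buffer, req_mc stays 1 */
        return req_mc;
    }
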
src/components/cl/urom/cl_urom_context.c: 114 changes (58 additions, 56 deletions)

@@ -43,6 +43,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_urom_context_t,
             .size = params->params.oob.n_oob_eps,
         },
     };
+
     ucc_tl_ucp_context_t *tl_ctx;
     ucc_status_t          status;
     urom_status_t         urom_status;
@@ -130,78 +131,77 @@ UCC_CLASS_INIT_FUNC(ucc_cl_urom_context_t,
     urom_domain_params.workers = &urom_lib->urom_worker;
     urom_domain_params.num_workers = 1,
     urom_domain_params.domain_size = params->params.oob.n_oob_eps;
+    urom_lib->req_mc = 1; /* requires a memcpy */
 
     printf("my rank %d with size %ld and worker id %ld\n", urom_domain_params.oob.oob_index, urom_domain_params.domain_size, urom_domain_params.domain_worker_id);
 
 /*
     if (params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB &&
         params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) {
         n_segments = ucc_mem_params.n_segments;
     }
 */
+    /* FIXME: rename xgvmi_buffer -> scratch_buffer */
     if (urom_lib->cfg.use_xgvmi || n_segments == 0) {
-        n_segments += 1; //xgvmi segment
-        domain_mem_map = ucc_calloc(n_segments, sizeof(urom_mem_map_t),
-                                    "urom_domain_mem_map");
-        if (!domain_mem_map) {
-            cl_error(&urom_lib->super.super, "Failed to allocate urom_mem_map");
-            return UCC_ERR_NO_MEMORY;
-        }
-
-        /* add xgvmi buffer */
+        /* remap the segments for xgvmi if enabled */
+        urom_lib->xgvmi_buffer = ucc_mem_params.segments[0].address;
+        urom_lib->xgvmi_size = ucc_mem_params.segments[0].len;
+        if (urom_lib->cfg.use_xgvmi) {
+            urom_lib->req_mc = 0;
+        }
+    } else {
         urom_lib->xgvmi_size = urom_lib->cfg.num_buffers * urom_lib->cfg.xgvmi_buffer_size;
         urom_lib->xgvmi_buffer = ucc_calloc(1, urom_lib->xgvmi_size, "xgvmi buffer");
         if (!urom_lib->xgvmi_buffer) {
             return UCC_ERR_NO_MEMORY;
         }
-        // mem_map the segment
-        mem_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
-        mem_params.address = urom_lib->xgvmi_buffer;
-        mem_params.length = urom_lib->xgvmi_size;
-
-        ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mem_params, &urom_lib->xgvmi_memh);
-        assert(ucs_status == UCS_OK);
-
-        if (urom_lib->cfg.use_xgvmi) {
-            pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
-            pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT;
-
-            ucs_status = ucp_memh_pack(urom_lib->xgvmi_memh, &pack_params, &urom_lib->packed_xgvmi_memh, &urom_lib->packed_xgvmi_len);
-            if (ucs_status != UCS_OK) {
-                cl_error(&urom_lib->super.super, "ucp_memh_pack() returned error: %s", ucs_status_string(ucs_status));
-                cl_error(&urom_lib->super.super, "xgvmi will be disabled");
-                xgvmi_level = 0;
-            }
-            xgvmi_level = 1;
-        }
-
-        ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, urom_lib->xgvmi_memh, &urom_lib->packed_mkey,
-                                   &urom_lib->packed_mkey_len);
-        if (UCS_OK != ucs_status) {
-            printf("ucp_rkey_pack() returned error: %s\n",
-                   ucs_status_string(ucs_status));
-            return UCC_ERR_NO_RESOURCE;
-        }
-
-        domain_mem_map[n_segments - 1].mask = UROM_WORKER_MEM_MAP_FIELD_BASE_VA | UROM_WORKER_MEM_MAP_FIELD_MKEY;
-        domain_mem_map[n_segments - 1].base_va = (uint64_t)urom_lib->xgvmi_buffer;
-        domain_mem_map[n_segments - 1].len = urom_lib->xgvmi_size;
-        domain_mem_map[n_segments - 1].mkey = urom_lib->packed_mkey;
-        domain_mem_map[n_segments - 1].mkey_len = urom_lib->packed_mkey_len;
-        if (1 || xgvmi_level) {
-            domain_mem_map[n_segments - 1].mask |= UROM_WORKER_MEM_MAP_FIELD_MEMH;
-            domain_mem_map[n_segments - 1].memh = urom_lib->packed_xgvmi_memh;
-            domain_mem_map[n_segments - 1].memh_len = urom_lib->packed_xgvmi_len;
-        }
-        urom_domain_params.mask |= UROM_DOMAIN_PARAM_FIELD_MEM_MAP;
-        urom_domain_params.mem_map.segments = domain_mem_map;
-        urom_domain_params.mem_map.n_segments = 1;
-        urom_lib->xgvmi_enabled = 1; //FIXME: for now, just use xgvmi buffers
-    } else { /* FIXME: shouldn't need an else here */
-        urom_lib->xgvmi_enabled = 0;
     }
+    n_segments = 1; /* FIXME: just for now */
+
+    domain_mem_map = ucc_calloc(n_segments, sizeof(urom_mem_map_t),
+                                "urom_domain_mem_map");
+    if (!domain_mem_map) {
+        cl_error(&urom_lib->super.super, "Failed to allocate urom_mem_map");
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    // mem_map the segment
+    mem_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
+    mem_params.address = urom_lib->xgvmi_buffer;
+    mem_params.length = urom_lib->xgvmi_size;
+
+    ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mem_params, &urom_lib->xgvmi_memh);
+    assert(ucs_status == UCS_OK);
+
+    if (urom_lib->cfg.use_xgvmi) {
+        pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
+        pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT;
+
+        ucs_status = ucp_memh_pack(urom_lib->xgvmi_memh, &pack_params, &urom_lib->packed_xgvmi_memh, &urom_lib->packed_xgvmi_len);
+        if (ucs_status != UCS_OK) {
+            cl_error(&urom_lib->super.super, "ucp_memh_pack() returned error: %s", ucs_status_string(ucs_status));
+            cl_error(&urom_lib->super.super, "xgvmi will be disabled");
+            xgvmi_level = 0;
+        } else {
+            xgvmi_level = 1;
+        }
+    }
+
+    ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, urom_lib->xgvmi_memh, &urom_lib->packed_mkey,
+                               &urom_lib->packed_mkey_len);
+    if (UCS_OK != ucs_status) {
+        printf("ucp_rkey_pack() returned error: %s\n",
+               ucs_status_string(ucs_status));
+        return UCC_ERR_NO_RESOURCE;
+    }
+    domain_mem_map[0].mask = UROM_WORKER_MEM_MAP_FIELD_BASE_VA | UROM_WORKER_MEM_MAP_FIELD_MKEY;
+    domain_mem_map[0].base_va = (uint64_t)urom_lib->xgvmi_buffer;
+    domain_mem_map[0].len = urom_lib->xgvmi_size;
+    domain_mem_map[0].mkey = urom_lib->packed_mkey;
+    domain_mem_map[0].mkey_len = urom_lib->packed_mkey_len;
+    if (xgvmi_level) {
+        domain_mem_map[0].mask |= UROM_WORKER_MEM_MAP_FIELD_MEMH;
+        domain_mem_map[0].memh = urom_lib->packed_xgvmi_memh;
+        domain_mem_map[0].memh_len = urom_lib->packed_xgvmi_len;
+    }
+    urom_domain_params.mask |= UROM_DOMAIN_PARAM_FIELD_MEM_MAP;
+    urom_domain_params.mem_map.segments = domain_mem_map;
+    urom_domain_params.mem_map.n_segments = 1;
+    urom_lib->xgvmi_enabled = xgvmi_level; //FIXME: for now, just use xgvmi buffers
 
     urom_status = urom_domain_create_post(&urom_domain_params, &self->urom_domain);
     if (urom_status < UROM_OK) {
         cl_error(&urom_lib->super.super, "failed to post urom domain: %s", urom_status_string(urom_status));
@@ -292,7 +292,9 @@ UCC_CLASS_CLEANUP_FUNC(ucc_cl_urom_context_t)
            (urom_status = urom_worker_pop_notifyq(urom_lib->urom_worker, 0, &notif))) {
         sched_yield();
     }
-    ucc_free(urom_lib->xgvmi_buffer);
+    if (urom_lib->req_mc) {
+        ucc_free(urom_lib->xgvmi_buffer);
+    }
 
     cudaStreamDestroy(urom_lib->cuda_stream);
 }
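
The context-init hunk reduces to one fixed sequence: pick the memory (the caller's oshmem segment when one is supplied, otherwise a private scratch allocation), register it with ucp_mem_map(), export it with ucp_memh_pack(UCP_MEMH_PACK_FLAG_EXPORT) when XGVMI is requested, always ucp_rkey_pack() it, and publish the single segment descriptor to the UROM domain; the cleanup hunk then frees the buffer only when the CL allocated it itself (req_mc still set). A condensed sketch of that sequence against the UCP API (error handling is collapsed into the xgvmi fallback; urom_mem_map_t and the UROM_* masks are the ones used in the diff):

    #include <ucp/api/ucp.h>   /* assumed available, as in the file above */

    /* Condensed order of operations from the init hunk; most error paths
     * are folded into the xgvmi_level fallback. */
    static int describe_segment(ucp_context_h ucp_ctx, void *buf, size_t len,
                                int use_xgvmi, urom_mem_map_t *seg)
    {
        ucp_mem_h  memh;
        void      *packed_mkey, *packed_memh = NULL;
        size_t     packed_mkey_len, packed_memh_len = 0;
        int        xgvmi_level = 0;

        ucp_mem_map_params_t mp = {
            .field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
                          UCP_MEM_MAP_PARAM_FIELD_LENGTH,
            .address    = buf,    /* oshmem segment or calloc'd scratch */
            .length     = len,
        };
        ucp_mem_map(ucp_ctx, &mp, &memh);            /* 1. register       */

        if (use_xgvmi) {                             /* 2. export (XGVMI) */
            ucp_memh_pack_params_t pp = {
                .field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS,
                .flags      = UCP_MEMH_PACK_FLAG_EXPORT,
            };
            xgvmi_level = (ucp_memh_pack(memh, &pp, &packed_memh,
                                         &packed_memh_len) == UCS_OK);
        }

        ucp_rkey_pack(ucp_ctx, memh, &packed_mkey,   /* 3. pack the rkey  */
                      &packed_mkey_len);

        seg->mask    = UROM_WORKER_MEM_MAP_FIELD_BASE_VA |
                       UROM_WORKER_MEM_MAP_FIELD_MKEY;  /* 4. describe it */
        seg->base_va = (uint64_t)buf;
        seg->len     = len;
        seg->mkey    = packed_mkey;
        seg->mkey_len = packed_mkey_len;
        if (xgvmi_level) {
            seg->mask    |= UROM_WORKER_MEM_MAP_FIELD_MEMH;
            seg->memh     = packed_memh;
            seg->memh_len = packed_memh_len;
        }
        return xgvmi_level;
    }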
