diff --git a/src/components/tl/ucp/alltoall/alltoall_onesided.c b/src/components/tl/ucp/alltoall/alltoall_onesided.c
index 05860f9fe6..99c56d281c 100644
--- a/src/components/tl/ucp/alltoall/alltoall_onesided.c
+++ b/src/components/tl/ucp/alltoall/alltoall_onesided.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -57,7 +57,7 @@ void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)
     if ((*pSync < gsize) ||
         (task->onesided.put_completed < task->onesided.put_posted)) {
-        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
+        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
         return;
     }
diff --git a/src/components/tl/ucp/alltoall/alltoall_pairwise.c b/src/components/tl/ucp/alltoall/alltoall_pairwise.c
index e7cc27029b..b37d742a4a 100644
--- a/src/components/tl/ucp/alltoall/alltoall_pairwise.c
+++ b/src/components/tl/ucp/alltoall/alltoall_pairwise.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -45,7 +45,7 @@ void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)
     while ((task->tagged.send_posted < gsize ||
             task->tagged.recv_posted < gsize) &&
            (polls++ < task->n_polls)) {
-        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
+        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
         while ((task->tagged.recv_posted < gsize) &&
                ((task->tagged.recv_posted - task->tagged.recv_completed) <
                 nreqs)) {
diff --git a/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c b/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c
index dcdbf248bb..5af6b59c10 100644
--- a/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c
+++ b/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -46,7 +46,7 @@ void ucc_tl_ucp_alltoallv_pairwise_progress(ucc_coll_task_t *coll_task)
     while ((task->tagged.send_posted < gsize ||
             task->tagged.recv_posted < gsize) &&
            (polls++ < task->n_polls)) {
-        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
+        ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
         while ((task->tagged.recv_posted < gsize) &&
                ((task->tagged.recv_posted - task->tagged.recv_completed) <
                 nreqs)) {
diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c b/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c
index 4f6f6ca82d..abf1829446 100644
--- a/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c
+++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c
@@ -88,7 +88,7 @@ static inline ucc_status_t ucc_tl_ucp_test_ring(ucc_tl_ucp_task_t *task)
             task->tagged.recv_posted == task->tagged.recv_completed) {
             return UCC_OK;
         }
-        ucp_worker_progress(TASK_CTX(task)->ucp_worker);
+        ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
     }
     return UCC_INPROGRESS;
 }
diff --git a/src/components/tl/ucp/reduce_scatterv/reduce_scatterv_ring.c b/src/components/tl/ucp/reduce_scatterv/reduce_scatterv_ring.c
index 1c30f12cce..e295d35890 100644
--- a/src/components/tl/ucp/reduce_scatterv/reduce_scatterv_ring.c
+++ b/src/components/tl/ucp/reduce_scatterv/reduce_scatterv_ring.c
@@ -96,7 +96,7 @@ static inline ucc_status_t ucc_tl_ucp_test_ring(ucc_tl_ucp_task_t *task)
             task->tagged.recv_posted == task->tagged.recv_completed) {
             return UCC_OK;
         }
-        ucp_worker_progress(TASK_CTX(task)->ucp_worker);
+        ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
     }
     return UCC_INPROGRESS;
 }
diff --git a/src/components/tl/ucp/tl_ucp.c b/src/components/tl/ucp/tl_ucp.c
index cd3ccaca75..ffd298afad 100644
--- a/src/components/tl/ucp/tl_ucp.c
+++ b/src/components/tl/ucp/tl_ucp.c
@@ -175,6 +175,20 @@ static ucs_config_field_t ucc_tl_ucp_context_config_table[] = {
      ucc_offsetof(ucc_tl_ucp_context_config_t, pre_reg_mem),
      UCC_CONFIG_TYPE_UINT},
 
+    {"SERVICE_WORKER", "n",
+     "If set to 0, uses the same worker for collectives and service. "
+     "Otherwise, creates a dedicated worker for service collectives, "
+     "whose UCX_TLS and UCX_NET_DEVICES are configured by the variables "
+     "UCC_TL_UCP_SERVICE_TLS and UCC_TL_UCP_SERVICE_NET_DEVICES, respectively",
+     ucc_offsetof(ucc_tl_ucp_context_config_t, service_worker),
+     UCC_CONFIG_TYPE_BOOL},
+
+    {"SERVICE_THROTTLING_THRESH", "100",
+     "Number of calls to the ucc_context_progress function between two "
+     "consecutive calls to the service worker progress function",
+     ucc_offsetof(ucc_tl_ucp_context_config_t, service_throttling_thresh),
+     UCC_CONFIG_TYPE_UINT},
+
     {NULL}};
 
 UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_ucp_lib_t, ucc_base_lib_t,
diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h
index 4f34c75bd8..7108abc858 100644
--- a/src/components/tl/ucp/tl_ucp.h
+++ b/src/components/tl/ucp/tl_ucp.h
@@ -74,6 +74,8 @@ typedef struct ucc_tl_ucp_context_config {
     uint32_t n_polls;
     uint32_t oob_npolls;
     uint32_t pre_reg_mem;
+    uint32_t service_worker;
+    uint32_t service_throttling_thresh;
 } ucc_tl_ucp_context_config_t;
 
 typedef struct ucc_tl_ucp_lib {
@@ -92,16 +94,22 @@ typedef struct ucc_tl_ucp_remote_info {
     size_t packed_key_len;
 } ucc_tl_ucp_remote_info_t;
 
+typedef struct ucc_tl_ucp_worker {
+    ucp_context_h     ucp_context;
+    ucp_worker_h      ucp_worker;
+    size_t            ucp_addrlen;
+    ucp_address_t    *worker_address;
+    tl_ucp_ep_hash_t *ep_hash;
+    ucp_ep_h         *eps;
+} ucc_tl_ucp_worker_t;
+
 typedef struct ucc_tl_ucp_context {
     ucc_tl_context_t            super;
     ucc_tl_ucp_context_config_t cfg;
-    ucp_context_h               ucp_context;
-    ucp_worker_h                ucp_worker;
-    size_t                      ucp_addrlen;
-    ucp_address_t              *worker_address;
+    ucc_tl_ucp_worker_t         worker;
+    ucc_tl_ucp_worker_t         service_worker;
+    int                         service_worker_throttling_count;
     ucc_mpool_t                 req_mp;
-    tl_ucp_ep_hash_t           *ep_hash;
-    ucp_ep_h                   *eps;
     ucc_tl_ucp_remote_info_t   *remote_info;
     ucp_rkey_h                 *rkeys;
     uint64_t                    n_rinfo_segs;
@@ -118,6 +126,7 @@ typedef struct ucc_tl_ucp_team {
     ucc_tl_ucp_task_t   *preconnect_task;
     void                *va_base[MAX_NR_SEGMENTS];
     size_t               base_length[MAX_NR_SEGMENTS];
+    ucc_tl_ucp_worker_t *worker;
 } ucc_tl_ucp_team_t;
 UCC_CLASS_DECLARE(ucc_tl_ucp_team_t, ucc_base_context_t *,
                   const ucc_base_team_params_t *);
@@ -135,15 +144,20 @@ UCC_CLASS_DECLARE(ucc_tl_ucp_team_t, ucc_base_context_t *,
 #define UCC_TL_UCP_TEAM_CTX(_team)                                             \
     (ucc_derived_of((_team)->super.super.context, ucc_tl_ucp_context_t))
 
-#define UCC_TL_UCP_WORKER(_team) UCC_TL_UCP_TEAM_CTX(_team)->ucp_worker
+#define IS_SERVICE_TEAM(_team)                                                 \
+    ((_team)->super.super.params.scope == UCC_CL_LAST + 1)
+
+#define USE_SERVICE_WORKER(_team)                                              \
+    (IS_SERVICE_TEAM(_team) && UCC_TL_UCP_TEAM_CTX(_team)->cfg.service_worker)
+
+#define UCC_TL_UCP_TASK_TEAM(_task)                                            \
+    (ucc_derived_of((_task)->super.team, ucc_tl_ucp_team_t))
 
 #define UCC_TL_CTX_HAS_OOB(_ctx)                                               \
     ((_ctx)->super.super.ucc_context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB)
 
 #define UCC_TL_CTX_OOB(_ctx) ((_ctx)->super.super.ucc_context->params.oob)
 
-#define IS_SERVICE_TEAM(_team) ((_team)->super.super.params.scope == UCC_CL_LAST + 1)
-
 #define UCC_TL_UCP_REMOTE_RKEY(_ctx, _rank, _seg)                              \
     ((_ctx)->rkeys[_rank * _ctx->n_rinfo_segs + _seg])
diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h
index bc7dcc4db5..57cd14f8cb 100644
--- a/src/components/tl/ucp/tl_ucp_coll.h
+++ b/src/components/tl/ucp/tl_ucp_coll.h
@@ -297,7 +297,7 @@ static inline ucc_status_t ucc_tl_ucp_test(ucc_tl_ucp_task_t *task)
         if (UCC_TL_UCP_TASK_P2P_COMPLETE(task)) {
             return UCC_OK;
         }
-        ucp_worker_progress(TASK_CTX(task)->ucp_worker);
+        ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker);
     }
     return UCC_INPROGRESS;
 }
diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c
index 75c9432679..00694efd46 100644
--- a/src/components/tl/ucp/tl_ucp_context.c
+++ b/src/components/tl/ucp/tl_ucp_context.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -9,10 +9,127 @@
 #include "tl_ucp_coll.h"
 #include "tl_ucp_ep.h"
 #include "utils/ucc_math.h"
+#include "utils/ucc_string.h"
 #include "utils/arch/cpu.h"
 #include "schedule/ucc_schedule_pipelined.h"
 #include <limits.h>
 
+#define UCP_CHECK(function, msg, go, ctx)                                      \
+    status = function;                                                         \
+    if (UCS_OK != status) {                                                    \
+        tl_error(ctx->super.super.lib, msg ", %s", ucs_status_string(status)); \
+        ucc_status = ucs_status_to_ucc_status(status);                         \
+        goto go;                                                               \
+    }
+
+#define CHECK(test, msg, go, return_status, ctx)                               \
+    if (test) {                                                                \
+        tl_error(ctx->super.super.lib, msg);                                   \
+        ucc_status = return_status;                                            \
+        goto go;                                                               \
+    }
+
+unsigned ucc_tl_ucp_service_worker_progress(void *progress_arg)
+{
+    ucc_tl_ucp_context_t *ctx = (ucc_tl_ucp_context_t *)progress_arg;
+    int throttling_count =
+        ucc_atomic_fadd32(&ctx->service_worker_throttling_count, 1);
+
+    if (throttling_count == ctx->cfg.service_throttling_thresh) {
+        ctx->service_worker_throttling_count = 0;
+        return ucp_worker_progress(ctx->service_worker.ucp_worker);
+    }
+
+    return 0;
+}
+
+static inline ucc_status_t
+ucc_tl_ucp_eps_ephash_init(const ucc_base_context_params_t *params,
+                           ucc_tl_ucp_context_t            *ctx,
+                           tl_ucp_ep_hash_t **ep_hash, ucp_ep_h **eps)
+{
+    if (params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB) {
+        /* Global ctx mode, we will have ctx_map so can use array for eps */
+        *eps = ucc_calloc(params->context->params.oob.n_oob_eps,
+                          sizeof(ucp_ep_h), "ucp_eps");
+        if (!(*eps)) {
+            tl_error(ctx->super.super.lib,
+                     "failed to allocate %zd bytes for ucp_eps",
+                     params->context->params.oob.n_oob_eps * sizeof(ucp_ep_h));
+            return UCC_ERR_NO_MEMORY;
+        }
+    } else {
+        *eps     = NULL;
+        *ep_hash = kh_init(tl_ucp_ep_hash);
+    }
+    return UCC_OK;
+}
+
+static inline ucc_status_t
+ucc_tl_ucp_context_service_init(const char *prefix, ucp_params_t ucp_params,
+                                ucp_worker_params_t worker_params,
+                                const ucc_base_context_params_t *params,
+                                ucc_tl_ucp_context_t            *ctx)
+{
+    ucc_status_t  ucc_status;
+    ucp_config_t *ucp_config;
+    ucp_context_h ucp_context_service;
+    ucp_worker_h  ucp_worker_service;
+    ucs_status_t  status;
+    char         *service_prefix;
+
+    ucc_status = ucc_str_concat(prefix, "_SERVICE", &service_prefix);
+    if (UCC_OK != ucc_status) {
+        tl_error(ctx->super.super.lib, "failed to concat service prefix str");
+        return ucc_status;
+    }
+    UCP_CHECK(ucp_config_read(service_prefix, NULL, &ucp_config),
+              "failed to read ucp configuration", err_cfg_read, ctx);
+    ucc_free(service_prefix);
+    service_prefix = NULL;
+
+    UCP_CHECK(ucp_init(&ucp_params, ucp_config, &ucp_context_service),
+              "failed to init ucp context for service worker", err_cfg, ctx);
+    ucp_config_release(ucp_config);
+
+    UCP_CHECK(ucp_worker_create(ucp_context_service, &worker_params,
+                                &ucp_worker_service),
+              "failed to create ucp service worker", err_worker_create, ctx);
+
+    ctx->service_worker.ucp_context    = ucp_context_service;
+    ctx->service_worker.ucp_worker     = ucp_worker_service;
+    ctx->service_worker.worker_address = NULL;
+
+    CHECK(UCC_OK != ucc_tl_ucp_eps_ephash_init(params, ctx,
+                                               &ctx->service_worker.ep_hash,
+                                               &ctx->service_worker.eps),
+          "failed to allocate memory for endpoint storage for service worker",
+          err_thread_mode, UCC_ERR_NO_MESSAGE, ctx);
+
+    ctx->service_worker_throttling_count = 0;
+    CHECK(UCC_OK !=
+              ucc_context_progress_register(
+                  params->context,
+                  (ucc_context_progress_fn_t)ucc_tl_ucp_service_worker_progress,
+                  ctx),
+          "failed to register progress function for service worker",
+          err_thread_mode, UCC_ERR_NO_MESSAGE, ctx);
+
+    return UCC_OK;
+
+err_thread_mode:
+    ucp_worker_destroy(ucp_worker_service);
+err_worker_create:
+    ucp_cleanup(ucp_context_service);
+err_cfg:
+    ucp_config_release(ucp_config);
+err_cfg_read:
+    if (service_prefix) {
+        ucc_free(service_prefix);
+    }
+    return ucc_status;
+}
+
 UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
                     const ucc_base_context_params_t *params,
                     const ucc_base_config_t *config)
@@ -28,17 +145,20 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
     ucp_context_h ucp_context;
     ucp_worker_h  ucp_worker;
     ucs_status_t  status;
+    char         *prefix;
 
     UCC_CLASS_CALL_SUPER_INIT(ucc_tl_context_t, &tl_ucp_config->super,
                               params->context);
     memcpy(&self->cfg, tl_ucp_config, sizeof(*tl_ucp_config));
-    status = ucp_config_read(params->prefix, NULL, &ucp_config);
-    if (UCS_OK != status) {
-        tl_error(self->super.super.lib, "failed to read ucp configuration, %s",
-                 ucs_status_string(status));
-        ucc_status = ucs_status_to_ucc_status(status);
-        goto err_cfg;
+
+    prefix = strdup(params->prefix);
+    if (!prefix) {
+        tl_error(self->super.super.lib, "failed to duplicate prefix str");
+        return UCC_ERR_NO_MEMORY;
     }
+    prefix[strlen(prefix) - 1] = '\0';
+    UCP_CHECK(ucp_config_read(prefix, NULL, &ucp_config),
+              "failed to read ucp configuration", err_cfg_read, self);
 
     ucp_params.field_mask =
         UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_TAG_SENDER_MASK;
@@ -58,24 +178,15 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
         ucp_params.estimated_num_eps = params->estimated_num_eps;
     }
 
-    status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+    UCP_CHECK(ucp_init(&ucp_params, ucp_config, &ucp_context),
+              "failed to init ucp context", err_cfg, self);
     ucp_config_release(ucp_config);
-    if (UCS_OK != status) {
-        tl_error(self->super.super.lib, "failed to init ucp context, %s",
-                 ucs_status_string(status));
-        ucc_status = ucs_status_to_ucc_status(status);
-        goto err_cfg;
-    }
 
     context_attr.field_mask = UCP_ATTR_FIELD_MEMORY_TYPES;
-    status = ucp_context_query(ucp_context, &context_attr);
-    if (UCS_OK != status) {
-        tl_error(self->super.super.lib,
-                 "failed to query supported memory types, %s",
-                 ucs_status_string(status));
-        ucc_status = ucs_status_to_ucc_status(status);
-        goto err_worker_create;
-    }
+    UCP_CHECK(ucp_context_query(ucp_context, &context_attr),
+              "failed to query supported memory types", err_worker_create,
+              self);
+
     self->ucp_memory_types = context_attr.memory_types;
 
     worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
     switch (params->thread_mode) {
@@ -91,13 +202,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
         ucc_assert(0);
         break;
     }
-    status = ucp_worker_create(ucp_context, &worker_params, &ucp_worker);
-    if (UCS_OK != status) {
-        tl_error(self->super.super.lib, "failed to create ucp worker, %s",
-                 ucs_status_string(status));
-        ucc_status = ucs_status_to_ucc_status(status);
-        goto err_worker_create;
-    }
+
+    UCP_CHECK(ucp_worker_create(ucp_context, &worker_params, &ucp_worker),
+              "failed to create ucp worker", err_worker_create, self);
 
     if (params->thread_mode == UCC_THREAD_MULTIPLE) {
         worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
@@ -110,9 +217,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
         }
     }
 
-    self->ucp_context    = ucp_context;
-    self->ucp_worker     = ucp_worker;
-    self->worker_address = NULL;
+    self->worker.ucp_context    = ucp_context;
+    self->worker.ucp_worker     = ucp_worker;
+    self->worker.worker_address = NULL;
 
     ucc_status = ucc_mpool_init(
         &self->req_mp, 0,
@@ -124,14 +231,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
                  "failed to initialize tl_ucp_req mpool");
         goto err_thread_mode;
     }
-    if (UCC_OK != ucc_context_progress_register(
-                      params->context,
-                      (ucc_context_progress_fn_t)ucp_worker_progress,
-                      self->ucp_worker)) {
-        tl_error(self->super.super.lib, "failed to register progress function");
-        ucc_status = UCC_ERR_NO_MESSAGE;
-        goto err_thread_mode;
-    }
+
+    CHECK(UCC_OK != ucc_context_progress_register(
+                        params->context,
+                        (ucc_context_progress_fn_t)ucp_worker_progress,
+                        self->worker.ucp_worker),
+          "failed to register progress function", err_thread_mode,
+          UCC_ERR_NO_MESSAGE, self);
 
     self->remote_info = NULL;
     self->n_rinfo_segs = 0;
@@ -145,21 +251,21 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
             goto err_thread_mode;
         }
     }
-    if (params->context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB) {
-        /* Global ctx mode, we will have ctx_map so can use array for eps */
-        self->eps = ucc_calloc(params->context->params.oob.n_oob_eps,
-                               sizeof(ucp_ep_h), "ucp_eps");
-        if (!self->eps) {
-            tl_error(self->super.super.lib,
-                     "failed to allocate %zd bytes for ucp_eps",
-                     params->context->params.oob.n_oob_eps * sizeof(ucp_ep_h));
-            ucc_status = UCC_ERR_NO_MEMORY;
-            goto err_thread_mode;
-        }
-    } else {
-        self->eps = NULL;
-        self->ep_hash = kh_init(tl_ucp_ep_hash);
+
+    CHECK(UCC_OK != ucc_tl_ucp_eps_ephash_init(
+                        params, self, &self->worker.ep_hash, &self->worker.eps),
+          "failed to allocate memory for endpoint storage", err_thread_mode,
+          UCC_ERR_NO_MESSAGE, self);
+
+    if (self->cfg.service_worker) {
+        CHECK(UCC_OK != ucc_tl_ucp_context_service_init(
+                            prefix, ucp_params, worker_params, params, self),
+              "failed to init service worker", err_cfg, UCC_ERR_NO_MESSAGE,
+              self);
     }
+    ucc_free(prefix);
+    prefix = NULL;
+
     tl_info(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
 
@@ -168,6 +274,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
 err_worker_create:
     ucp_cleanup(ucp_context);
 err_cfg:
+    ucp_config_release(ucp_config);
+err_cfg_read:
+    if (prefix) {
+        ucc_free(prefix);
+    }
     return ucc_status;
 }
 
@@ -194,7 +305,10 @@ static void ucc_tl_ucp_context_barrier(ucc_tl_ucp_context_t *ctx,
                                  &req)) {
         ucc_assert(req);
         while (UCC_OK != (status = oob->req_test(req))) {
-            ucp_worker_progress(ctx->ucp_worker);
+            ucp_worker_progress(ctx->worker.ucp_worker);
+            if (ctx->cfg.service_worker != 0) {
+                ucp_worker_progress(ctx->service_worker.ucp_worker);
+            }
             if (status < 0) {
                 tl_error(ctx->super.super.lib, "failed to test oob req");
                 break;
@@ -219,7 +333,7 @@ ucc_status_t ucc_tl_ucp_rinfo_destroy(ucc_tl_ucp_context_t *ctx)
     }
     for (i = 0; i < ctx->n_rinfo_segs; i++) {
         if (ctx->remote_info[i].mem_h) {
-            ucp_mem_unmap(ctx->ucp_context, ctx->remote_info[i].mem_h);
+            ucp_mem_unmap(ctx->worker.ucp_context, ctx->remote_info[i].mem_h);
         }
         if (ctx->remote_info[i].packed_key) {
             ucp_rkey_buffer_release(ctx->remote_info[i].packed_key);
@@ -233,30 +347,54 @@ ucc_status_t ucc_tl_ucp_rinfo_destroy(ucc_tl_ucp_context_t *ctx)
     return UCC_OK;
 }
 
-UCC_CLASS_CLEANUP_FUNC(ucc_tl_ucp_context_t)
+static inline void ucc_tl_ucp_eps_cleanup(ucc_tl_ucp_worker_t  *worker,
+                                          ucc_tl_ucp_context_t *ctx)
 {
-    tl_info(self->super.super.lib, "finalizing tl context: %p", self);
-    ucc_tl_ucp_close_eps(self);
-    if (self->eps) {
-        ucc_free(self->eps);
+    ucc_tl_ucp_close_eps(worker, ctx);
+    if (worker->eps) {
+        ucc_free(worker->eps);
     } else {
-        kh_destroy(tl_ucp_ep_hash, self->ep_hash);
+        kh_destroy(tl_ucp_ep_hash, worker->ep_hash);
     }
+}
+
+static inline void ucc_tl_ucp_worker_cleanup(ucc_tl_ucp_worker_t worker)
+{
+    if (worker.worker_address) {
+        ucp_worker_release_address(worker.ucp_worker, worker.worker_address);
+    }
+    ucp_worker_destroy(worker.ucp_worker);
+    ucp_cleanup(worker.ucp_context);
+}
+
+UCC_CLASS_CLEANUP_FUNC(ucc_tl_ucp_context_t)
+{
+    tl_info(self->super.super.lib, "finalizing tl context: %p", self);
     if (self->remote_info) {
         ucc_tl_ucp_rinfo_destroy(self);
     }
-    if (UCC_TL_CTX_HAS_OOB(self)) {
-        ucc_tl_ucp_context_barrier(self, &UCC_TL_CTX_OOB(self));
-    }
     ucc_context_progress_deregister(
         self->super.super.ucc_context,
-        (ucc_context_progress_fn_t)ucp_worker_progress, self->ucp_worker);
-    if (self->worker_address) {
-        ucp_worker_release_address(self->ucp_worker, self->worker_address);
+        (ucc_context_progress_fn_t)ucp_worker_progress,
+        self->worker.ucp_worker);
+    if (self->cfg.service_worker != 0) {
+        ucc_context_progress_deregister(
+            self->super.super.ucc_context,
+            (ucc_context_progress_fn_t)ucc_tl_ucp_service_worker_progress,
+            self);
     }
-    ucp_worker_destroy(self->ucp_worker);
     ucc_mpool_cleanup(&self->req_mp, 1);
-    ucp_cleanup(self->ucp_context);
+    ucc_tl_ucp_eps_cleanup(&self->worker, self);
+    if (self->cfg.service_worker != 0) {
+        ucc_tl_ucp_eps_cleanup(&self->service_worker, self);
+    }
+    if (UCC_TL_CTX_HAS_OOB(self)) {
+        ucc_tl_ucp_context_barrier(self, &UCC_TL_CTX_OOB(self));
+    }
+    ucc_tl_ucp_worker_cleanup(self->worker);
+    if (self->cfg.service_worker != 0) {
+        ucc_tl_ucp_worker_cleanup(self->service_worker);
+    }
 }
 
 UCC_CLASS_DEFINE(ucc_tl_ucp_context_t, ucc_tl_context_t);
@@ -277,12 +415,12 @@ ucc_status_t ucc_tl_ucp_populate_rcache(void *addr, size_t length,
     mmap_params.memory_type = mem_type;
 
     /* do map and umap to populate the cache */
-    status = ucp_mem_map(ctx->ucp_context, &mmap_params, &mh);
+    status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
     if (ucc_unlikely(status != UCS_OK)) {
         return ucs_status_to_ucc_status(status);
     }
 
-    status = ucp_mem_unmap(ctx->ucp_context, mh);
+    status = ucp_mem_unmap(ctx->worker.ucp_context, mh);
     if (ucc_unlikely(status != UCS_OK)) {
         return ucs_status_to_ucc_status(status);
     }
@@ -336,7 +474,7 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
         mmap_params.address     = map.segments[i].address;
         mmap_params.length      = map.segments[i].len;
 
-        status = ucp_mem_map(ctx->ucp_context, &mmap_params, &mh);
+        status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
         if (UCS_OK != status) {
             tl_error(ctx->super.super.lib,
                      "ucp_mem_map failed with error code: %d", status);
@@ -344,9 +482,9 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
             goto fail_mem_map;
         }
         ctx->remote_info[i].mem_h = (void *)mh;
-        status =
-            ucp_rkey_pack(ctx->ucp_context, mh, &ctx->remote_info[i].packed_key,
-                          &ctx->remote_info[i].packed_key_len);
+        status = ucp_rkey_pack(ctx->worker.ucp_context, mh,
+                               &ctx->remote_info[i].packed_key,
+                               &ctx->remote_info[i].packed_key_len);
         if (UCS_OK != status) {
             tl_error(ctx->super.super.lib,
                      "failed to pack UCP key with error code: %d", status);
             ucc_status = ucs_status_to_ucc_status(status);
@@ -362,7 +500,7 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
 fail_mem_map:
     for (i = 0; i < nsegs; i++) {
         if (ctx->remote_info[i].mem_h) {
-            ucp_mem_unmap(ctx->ucp_context, ctx->remote_info[i].mem_h);
+            ucp_mem_unmap(ctx->worker.ucp_context, ctx->remote_info[i].mem_h);
         }
         if (ctx->remote_info[i].packed_key) {
             ucp_rkey_buffer_release(ctx->remote_info[i].packed_key);
@@ -407,23 +545,44 @@ ucc_status_t ucc_tl_ucp_get_context_attr(const ucc_base_context_t *context,
                                          ucc_base_ctx_attr_t *attr)
 {
     ucc_tl_ucp_context_t *ctx = ucc_derived_of(context, ucc_tl_ucp_context_t);
+    uint64_t             *offset = (uint64_t *)attr->attr.ctx_addr;
     ucs_status_t          ucs_status;
     size_t                packed_length;
     int                   i;
 
-    if ((attr->attr.mask & (UCC_CONTEXT_ATTR_FIELD_CTX_ADDR_LEN |
-                            UCC_CONTEXT_ATTR_FIELD_CTX_ADDR)) &&
-        (NULL == ctx->worker_address)) {
-        ucs_status = ucp_worker_get_address(
-            ctx->ucp_worker, &ctx->worker_address, &ctx->ucp_addrlen);
-        if (UCS_OK != ucs_status) {
-            tl_error(ctx->super.super.lib, "failed to get ucp worker address");
-            return ucs_status_to_ucc_status(ucs_status);
+    if (attr->attr.mask & (UCC_CONTEXT_ATTR_FIELD_CTX_ADDR_LEN |
+                           UCC_CONTEXT_ATTR_FIELD_CTX_ADDR)) {
+        if (NULL == ctx->worker.worker_address) {
+            ucs_status = ucp_worker_get_address(ctx->worker.ucp_worker,
+                                                &ctx->worker.worker_address,
+                                                &ctx->worker.ucp_addrlen);
+            if (UCS_OK != ucs_status) {
+                tl_error(ctx->super.super.lib,
+                         "failed to get ucp worker address");
+                return ucs_status_to_ucc_status(ucs_status);
+            }
+            if (ctx->cfg.service_worker != 0 &&
+                (NULL == ctx->service_worker.worker_address)) {
+                ucs_status =
+                    ucp_worker_get_address(ctx->service_worker.ucp_worker,
+                                           &ctx->service_worker.worker_address,
+                                           &ctx->service_worker.ucp_addrlen);
+                if (UCS_OK != ucs_status) {
+                    tl_error(
+                        ctx->super.super.lib,
+                        "failed to get ucp special service worker address");
+                    return ucs_status_to_ucc_status(ucs_status);
+                }
+            }
         }
     }
     if (attr->attr.mask & UCC_CONTEXT_ATTR_FIELD_CTX_ADDR_LEN) {
-        packed_length = TL_UCP_EP_ADDRLEN_SIZE + ctx->ucp_addrlen;
+        packed_length = TL_UCP_EP_ADDRLEN_SIZE + ctx->worker.ucp_addrlen;
+        if (ctx->cfg.service_worker != 0) {
+            packed_length +=
+                TL_UCP_EP_ADDRLEN_SIZE + ctx->service_worker.ucp_addrlen;
+        }
         if (NULL != ctx->remote_info) {
             packed_length += ctx->n_rinfo_segs * (sizeof(size_t) * 3);
             for (i = 0; i < ctx->n_rinfo_segs; i++) {
@@ -433,12 +592,19 @@ ucc_status_t ucc_tl_ucp_get_context_attr(const ucc_base_context_t *context,
         attr->attr.ctx_addr_len = packed_length;
     }
     if (attr->attr.mask & UCC_CONTEXT_ATTR_FIELD_CTX_ADDR) {
-        *((uint64_t*)attr->attr.ctx_addr) = ctx->ucp_addrlen;
-        memcpy(TL_UCP_EP_ADDR_WORKER(attr->attr.ctx_addr),
-               ctx->worker_address, ctx->ucp_addrlen);
+        *offset = ctx->worker.ucp_addrlen;
+        offset  = TL_UCP_EP_ADDR_WORKER(offset);
+        memcpy(offset, ctx->worker.worker_address, ctx->worker.ucp_addrlen);
+        offset = PTR_OFFSET(offset, ctx->worker.ucp_addrlen);
+        if (ctx->cfg.service_worker != 0) {
+            *offset = ctx->service_worker.ucp_addrlen;
+            offset  = TL_UCP_EP_ADDR_WORKER(offset);
+            memcpy(offset, ctx->service_worker.worker_address,
+                   ctx->service_worker.ucp_addrlen);
+            offset = PTR_OFFSET(offset, ctx->service_worker.ucp_addrlen);
+        }
         if (NULL != ctx->remote_info) {
-            ucc_tl_ucp_ctx_remote_pack_data(ctx,
-                TL_UCP_EP_ADDR_ONESIDED_INFO(attr->attr.ctx_addr));
+            ucc_tl_ucp_ctx_remote_pack_data(ctx, offset);
         }
     }
     if (attr->attr.mask & UCC_CONTEXT_ATTR_FIELD_WORK_BUFFER_SIZE) {
diff --git a/src/components/tl/ucp/tl_ucp_ep.c b/src/components/tl/ucp/tl_ucp_ep.c
index 246b06c99e..096564b7b3 100644
--- a/src/components/tl/ucp/tl_ucp_ep.c
+++ b/src/components/tl/ucp/tl_ucp_ep.c
@@ -1,3 +1,9 @@
+/**
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
 #include "tl_ucp.h"
 #include "tl_ucp_ep.h"
@@ -10,9 +16,11 @@ static void ucc_tl_ucp_err_handler(void *arg, ucp_ep_h ep, ucs_status_t status)
 }
 
 static inline ucc_status_t ucc_tl_ucp_connect_ep(ucc_tl_ucp_context_t *ctx,
-                                                 ucp_ep_h *ep,
+                                                 int is_service, ucp_ep_h *ep,
                                                  void *ucp_address)
 {
+    ucp_worker_h worker =
+        (is_service) ? ctx->service_worker.ucp_worker : ctx->worker.ucp_worker;
     ucp_ep_params_t ep_params;
     ucs_status_t    status;
 
     if (*ep) {
@@ -29,7 +37,7 @@ static inline ucc_status_t ucc_tl_ucp_connect_ep(ucc_tl_ucp_context_t *ctx,
         ep_params.field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
                                 UCP_EP_PARAM_FIELD_ERR_HANDLER;
     }
-    status = ucp_ep_create(ctx->ucp_worker, &ep_params, ep);
+    status = ucp_ep_create(worker, &ep_params, ep);
 
     if (ucc_unlikely(UCS_OK != status)) {
         tl_error(ctx->super.super.lib, "ucp returned connect error: %s",
@@ -43,35 +51,41 @@ ucc_status_t ucc_tl_ucp_connect_team_ep(ucc_tl_ucp_team_t *team,
                                         ucc_rank_t core_rank, ucp_ep_h *ep)
 {
     ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
+    int use_service_worker    = USE_SERVICE_WORKER(team);
     void *addr;
 
     addr = ucc_get_team_ep_addr(UCC_TL_CORE_CTX(team), UCC_TL_CORE_TEAM(team),
                                 core_rank, ucc_tl_ucp.super.super.id);
-    return ucc_tl_ucp_connect_ep(ctx, ep, TL_UCP_EP_ADDR_WORKER(addr));
+    addr = use_service_worker ? TL_UCP_EP_ADDR_WORKER_SERVICE(addr)
+                              : TL_UCP_EP_ADDR_WORKER(addr);
+
+    return ucc_tl_ucp_connect_ep(ctx, use_service_worker, ep, addr);
 }
 
 /* Finds next non-NULL ep in the storage and returns that handle for closure.
    In case of "hash" storage it pops the item, in case of "array" sets it
   to NULL */
-static inline ucp_ep_h get_next_ep_to_close(ucc_tl_ucp_context_t *ctx, int *i)
+static inline ucp_ep_h get_next_ep_to_close(ucc_tl_ucp_worker_t  *worker,
+                                            ucc_tl_ucp_context_t *ctx, int *i)
 {
     ucp_ep_h   ep = NULL;
     ucc_rank_t size;
 
-    if (ctx->eps) {
+    if (worker->eps) {
         size = (ucc_rank_t)ctx->super.super.ucc_context->params.oob.n_oob_eps;
         while (NULL == ep && (*i) < size) {
-            ep           = ctx->eps[*i];
-            ctx->eps[*i] = NULL;
+            ep              = worker->eps[*i];
+            worker->eps[*i] = NULL;
             (*i)++;
         }
     } else {
-        ep = tl_ucp_hash_pop(ctx->ep_hash);
+        ep = tl_ucp_hash_pop(worker->ep_hash);
     }
     return ep;
 }
 
-void ucc_tl_ucp_close_eps(ucc_tl_ucp_context_t *ctx)
+void ucc_tl_ucp_close_eps(ucc_tl_ucp_worker_t  *worker,
+                          ucc_tl_ucp_context_t *ctx)
 {
     int              i = 0;
     ucp_ep_h         ep;
@@ -81,13 +95,16 @@ void ucc_tl_ucp_close_eps(ucc_tl_ucp_context_t *ctx)
     param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
     param.flags        = 0; // 0 means FLUSH
 
-    ep = get_next_ep_to_close(ctx, &i);
+    ep = get_next_ep_to_close(worker, ctx, &i);
     while (ep) {
         close_req = ucp_ep_close_nbx(ep, &param);
         if (UCS_PTR_IS_PTR(close_req)) {
             do {
-                ucp_worker_progress(ctx->ucp_worker);
+                ucp_worker_progress(ctx->worker.ucp_worker);
+                if (ctx->cfg.service_worker != 0) {
+                    ucp_worker_progress(ctx->service_worker.ucp_worker);
+                }
                 status = ucp_request_check_status(close_req);
             } while (status == UCS_INPROGRESS);
             ucp_request_free(close_req);
@@ -100,6 +117,6 @@ void ucc_tl_ucp_close_eps(ucc_tl_ucp_context_t *ctx)
                  "error during ucp ep close, ep %p, status %s", ep,
                  ucs_status_string(status));
         }
-        ep = get_next_ep_to_close(ctx, &i);
+        ep = get_next_ep_to_close(worker, ctx, &i);
     }
 }
diff --git a/src/components/tl/ucp/tl_ucp_ep.h b/src/components/tl/ucp/tl_ucp_ep.h
index 3bf063a3b6..6874ef6ccf 100644
--- a/src/components/tl/ucp/tl_ucp_ep.h
+++ b/src/components/tl/ucp/tl_ucp_ep.h
@@ -1,8 +1,9 @@
 /**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
+
 #include "config.h"
 
 #ifndef UCC_TL_UCP_EP_H_
@@ -12,25 +13,36 @@
 #include "tl_ucp.h"
 #include "core/ucc_team.h"
 
-/* TL/UCP endpoint address layout: (ucp_addrlen may very per proc)
+/* TL/UCP endpoint address layout: (ucp_addrlen may vary per proc)
 
-   [ucp_addrlen][ucp_worker_address][onesided_info]
+   [worker->ucp_addrlen][ucp_worker_address][onesided_info]
      8 bytes      ucp_addrlen bytes
+
+   If a service worker is enabled via UCC_TL_UCP_SERVICE_WORKER:
+   [worker->ucp_addrlen][ucp_worker_address][service_worker->ucp_addrlen][ucp_service_worker_address][onesided_info]
+     8 bytes      ucp_addrlen bytes           8 bytes      service_worker->ucp_addrlen bytes
 */
 
 #define TL_UCP_EP_ADDRLEN_SIZE 8
 #define TL_UCP_EP_ADDR_WORKER_LEN(_addr) (*((uint64_t*)(_addr)))
-#define TL_UCP_EP_ADDR_WORKER(_addr) PTR_OFFSET((_addr), 8)
-#define TL_UCP_EP_ADDR_ONESIDED_INFO(_addr)                                    \
-    PTR_OFFSET((_addr), 8 + TL_UCP_EP_ADDR_WORKER_LEN(_addr))
+#define TL_UCP_EP_ADDR_WORKER(_addr) PTR_OFFSET((_addr), TL_UCP_EP_ADDRLEN_SIZE)
+#define TL_UCP_EP_OFFSET_WORKER_INFO(_addr)                                    \
+    PTR_OFFSET((_addr),                                                        \
+               TL_UCP_EP_ADDRLEN_SIZE + TL_UCP_EP_ADDR_WORKER_LEN(_addr))
+#define TL_UCP_EP_ADDR_WORKER_SERVICE(_addr)                                   \
+    TL_UCP_EP_ADDR_WORKER(TL_UCP_EP_OFFSET_WORKER_INFO(_addr))
+#define TL_UCP_EP_ADDR_ONESIDED_INFO(_addr, _ctx)                              \
+    ((_ctx)->cfg.service_worker                                                \
+         ? TL_UCP_EP_OFFSET_WORKER_INFO(TL_UCP_EP_OFFSET_WORKER_INFO(_addr))   \
+         : TL_UCP_EP_OFFSET_WORKER_INFO(_addr))
 
 typedef struct ucc_tl_ucp_context ucc_tl_ucp_context_t;
 typedef struct ucc_tl_ucp_team    ucc_tl_ucp_team_t;
 
-ucc_status_t ucc_tl_ucp_connect_team_ep(ucc_tl_ucp_team_t *team,
-                                        ucc_rank_t team_rank,
-                                        ucp_ep_h *ep);
+ucc_status_t ucc_tl_ucp_connect_team_ep(ucc_tl_ucp_team_t *team,
+                                        ucc_rank_t team_rank, ucp_ep_h *ep);
 
-void ucc_tl_ucp_close_eps(ucc_tl_ucp_context_t *ctx);
+void ucc_tl_ucp_close_eps(ucc_tl_ucp_worker_t  *worker,
+                          ucc_tl_ucp_context_t *ctx);
 
 static inline ucc_context_addr_header_t *
 ucc_tl_ucp_get_team_ep_header(ucc_tl_ucp_team_t *team, ucc_rank_t core_rank)
@@ -40,27 +52,25 @@ ucc_tl_ucp_get_team_ep_header(ucc_tl_ucp_team_t *team, ucc_rank_t core_rank)
                                   core_rank);
 }
 
-static inline ucc_status_t ucc_tl_ucp_get_ep(ucc_tl_ucp_team_t *team, ucc_rank_t rank,
-                                             ucp_ep_h *ep)
+static inline ucc_status_t ucc_tl_ucp_get_ep(ucc_tl_ucp_team_t *team,
+                                             ucc_rank_t rank, ucp_ep_h *ep)
 {
-    ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
     ucc_context_addr_header_t *h        = NULL;
     ucc_rank_t                 ctx_rank = 0;
     ucc_status_t               status;
     ucc_rank_t                 core_rank;
 
     core_rank = ucc_ep_map_eval(UCC_TL_TEAM_MAP(team), rank);
-    if (ctx->eps) {
+    if (team->worker->eps) {
         ucc_team_t *core_team = UCC_TL_CORE_TEAM(team);
         /* Core super.super.team ptr is NULL for service_team
            which has scope == UCC_CL_LAST + 1*/
         ucc_assert((NULL != core_team) || IS_SERVICE_TEAM(team));
         ctx_rank = core_team ? ucc_get_ctx_rank(core_team, core_rank) : core_rank;
-        *ep = ctx->eps[ctx_rank];
+        *ep = team->worker->eps[ctx_rank];
     } else {
         h = ucc_tl_ucp_get_team_ep_header(team, core_rank);
-        *ep = tl_ucp_hash_get(ctx->ep_hash, h->ctx_id);
+        *ep = tl_ucp_hash_get(team->worker->ep_hash, h->ctx_id);
     }
     if (NULL == (*ep)) {
         /* Not connected yet */
@@ -70,10 +80,10 @@ static inline ucc_status_t ucc_tl_ucp_get_ep(ucc_tl_ucp_team_t *team, ucc_rank_t
             *ep = NULL;
             return status;
         }
-        if (ctx->eps) {
-            ctx->eps[ctx_rank] = *ep;
+        if (!h) {
+            team->worker->eps[ctx_rank] = *ep;
         } else {
-            tl_ucp_hash_put(ctx->ep_hash, h->ctx_id, *ep);
+            tl_ucp_hash_put(team->worker->ep_hash, h->ctx_id, *ep);
         }
     }
     return UCC_OK;
diff --git a/src/components/tl/ucp/tl_ucp_sendrecv.h b/src/components/tl/ucp/tl_ucp_sendrecv.h
index 87142b869a..3b3ea6c1b2 100644
--- a/src/components/tl/ucp/tl_ucp_sendrecv.h
+++ b/src/components/tl/ucp/tl_ucp_sendrecv.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * Copyright (c) Meta Platforms, Inc. and affiliates. 2022.
  *
  * See file LICENSE for terms.
@@ -56,10 +56,10 @@ void ucc_tl_ucp_recv_completion_cb(void *request, ucs_status_t status,
         if (ucc_unlikely(UCS_PTR_IS_ERR(ucp_status))) {                        \
             tl_error(UCC_TL_TEAM_LIB(team),                                    \
                      "tag %u; dest %d; team_id %u; errmsg %s",                 \
-                     task->tagged.tag,                                         \
-                     dest_group_rank, team->super.super.params.id,             \
+                     task->tagged.tag, dest_group_rank,                        \
+                     team->super.super.params.id,                              \
                      ucs_status_string(UCS_PTR_STATUS(ucp_status)));           \
-            ucp_request_cancel(UCC_TL_UCP_WORKER(team), ucp_status);           \
+            ucp_request_cancel(team->worker->ucp_worker, ucp_status);          \
             ucp_request_free(ucp_status);                                      \
             return ucs_status_to_ucc_status(UCS_PTR_STATUS(ucp_status));       \
         }                                                                      \
@@ -152,7 +152,7 @@ ucc_tl_ucp_recv_common(void *buffer, size_t msglen, ucc_memory_type_t mtype,
     req_param.memory_type = ucc_memtype_to_ucs[mtype];
     req_param.user_data   = (void *)task;
     task->tagged.recv_posted++;
-    return ucp_tag_recv_nbx(UCC_TL_UCP_WORKER(team), buffer, 1, ucp_tag,
+    return ucp_tag_recv_nbx(team->worker->ucp_worker, buffer, 1, ucp_tag,
                             ucp_tag_mask, &req_param);
 }
 
@@ -246,7 +246,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
     offset = ucc_get_team_ep_addr(UCC_TL_CORE_CTX(team), UCC_TL_CORE_TEAM(team),
                                   core_rank, ucc_tl_ucp.super.super.id);
-    base_offset = (ptrdiff_t)TL_UCP_EP_ADDR_ONESIDED_INFO(offset);
+    base_offset = (ptrdiff_t)(TL_UCP_EP_ADDR_ONESIDED_INFO(offset, ctx));
     rvas      = (uint64_t *)base_offset;
     key_sizes = PTR_OFFSET(base_offset, (section_offset * 2));
     keys      = PTR_OFFSET(base_offset, (section_offset * 3));
@@ -282,8 +282,7 @@ static inline ucc_status_t ucc_tl_ucp_flush(ucc_tl_ucp_team_t *team)
     ucp_request_param_t req_param = {0};
     ucs_status_ptr_t    req;
 
-    req =
-        ucp_worker_flush_nbx(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker, &req_param);
+    req = ucp_worker_flush_nbx(team->worker->ucp_worker, &req_param);
     if (UCS_OK != req) {
         if (UCS_PTR_IS_ERR(req)) {
             return ucs_status_to_ucc_status(UCS_PTR_STATUS(req));
diff --git a/src/components/tl/ucp/tl_ucp_team.c b/src/components/tl/ucp/tl_ucp_team.c
index fafd0becab..536ce2d13b 100644
--- a/src/components/tl/ucp/tl_ucp_team.c
+++ b/src/components/tl/ucp/tl_ucp_team.c
@@ -96,6 +96,12 @@ ucc_status_t ucc_tl_ucp_team_create_test(ucc_base_team_t *tl_team)
     ucc_tl_ucp_context_t *ctx  = UCC_TL_UCP_TEAM_CTX(team);
     ucc_status_t          status;
 
+    if (USE_SERVICE_WORKER(team)) {
+        team->worker = &ctx->service_worker;
+    } else {
+        team->worker = &ctx->worker;
+    }
+
     if (team->status == UCC_OK) {
         return UCC_OK;
     }
diff --git a/src/core/ucc_context.c b/src/core/ucc_context.c
index b48f8bf22d..ddc972c06a 100644
--- a/src/core/ucc_context.c
+++ b/src/core/ucc_context.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * See file LICENSE for terms.
  */
@@ -895,7 +895,7 @@ ucc_status_t ucc_context_progress_register(ucc_context_t *ctx,
     ucc_context_progress_entry_t *entry =
         ucc_malloc(sizeof(*entry), "progress_entry");
     if (!entry) {
-        ucc_error("failed to allocate %zd bytes for progress ntry",
+        ucc_error("failed to allocate %zd bytes for progress entry",
                   sizeof(*entry));
         return UCC_ERR_NO_MEMORY;
     }
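
Note on the packed context address (not part of the patch): the sketch below is a hypothetical, self-contained restatement of what the TL_UCP_EP_ADDR_WORKER, TL_UCP_EP_ADDR_WORKER_SERVICE and TL_UCP_EP_ADDR_ONESIDED_INFO macros compute; the helper name and signature are invented for illustration, and only the byte layout documented in tl_ucp_ep.h comes from the patch.

    #include <stdint.h>
    #include <string.h>

    /* Walks the packed TL/UCP context address:
     *   [addrlen:8][worker_address:addrlen]
     * optionally followed, when the service worker is enabled, by
     *   [service_addrlen:8][service_worker_address:service_addrlen],
     * and returns a pointer to the trailing onesided_info section. */
    static void *tl_ucp_addr_walk(void *packed, int has_service_worker,
                                  void **worker_addr, void **service_addr)
    {
        uint64_t len;

        memcpy(&len, packed, sizeof(len));     /* worker address length */
        *worker_addr = (char *)packed + 8;
        packed       = (char *)packed + 8 + len;

        *service_addr = NULL;
        if (has_service_worker) {
            memcpy(&len, packed, sizeof(len)); /* service worker address length */
            *service_addr = (char *)packed + 8;
            packed        = (char *)packed + 8 + len;
        }
        return packed;                         /* start of onesided_info */
    }

This is why ucc_tl_ucp_get_context_attr advances its offset pointer a second time when cfg.service_worker is set, and why TL_UCP_EP_ADDR_ONESIDED_INFO now takes the context as an extra argument.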
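A second note, on the throttling done by ucc_tl_ucp_service_worker_progress: the service worker is progressed only once per SERVICE_THROTTLING_THRESH calls to the context progress function, keeping it cheap to poll. A single-threaded sketch of the same idea, using a plain counter where the patch uses ucc_atomic_fadd32 (names invented for illustration):

    /* Returns nonzero once every 'thresh' calls; the caller then progresses
     * the service worker, otherwise the service worker is skipped. */
    static int service_progress_due(int *count, int thresh)
    {
        if (++(*count) >= thresh) {
            *count = 0; /* window elapsed: progress the service worker now */
            return 1;
        }
        return 0;       /* throttled: skip the service worker this pass */
    }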