Skip to content

Commit

Permalink
TL/UCP: add special service worker (#560)
Browse files Browse the repository at this point in the history
* TL/UCP: add special service worker

* TL/UCP: option to set devices for service worker

* TL/UCP: remove is_service from task

* TL/UCP: refactor context init

* TL/UCP: fix context cleanup

* TL/UCP: remove is_used from context

* TL/UCP: add worker struct and ptr in team

* TL/UCP: fix deadlock w/ service & default workers

* TL/UCP: add throttling param for progress fn

* TL/UCP: separate prefix for service worker

* TL/UCP: fix memory leak

* TL/UCP: separate progress functions

* TL/UCP: change config interface of service worker

* TL/UCP: pass ptr instead of worker in cleanup

* TL/UCP: fix warnings in linter

Co-authored-by: Valentin Petrov <[email protected]>
  • Loading branch information
samnordmann and Valentin Petrov authored Oct 12, 2022
1 parent ad9ff51 commit 6091d4f
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 151 deletions.
4 changes: 2 additions & 2 deletions src/components/tl/ucp/alltoall/alltoall_onesided.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -57,7 +57,7 @@ void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)

if ((*pSync < gsize) ||
(task->onesided.put_completed < task->onesided.put_posted)) {
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
return;
}

Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/ucp/alltoall/alltoall_pairwise.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -45,7 +45,7 @@ void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)
while ((task->tagged.send_posted < gsize ||
task->tagged.recv_posted < gsize) &&
(polls++ < task->n_polls)) {
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
while ((task->tagged.recv_posted < gsize) &&
((task->tagged.recv_posted - task->tagged.recv_completed) <
nreqs)) {
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/ucp/alltoallv/alltoallv_pairwise.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -46,7 +46,7 @@ void ucc_tl_ucp_alltoallv_pairwise_progress(ucc_coll_task_t *coll_task)
while ((task->tagged.send_posted < gsize ||
task->tagged.recv_posted < gsize) &&
(polls++ < task->n_polls)) {
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->ucp_worker);
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
while ((task->tagged.recv_posted < gsize) &&
((task->tagged.recv_posted - task->tagged.recv_completed) <
nreqs)) {
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ static inline ucc_status_t ucc_tl_ucp_test_ring(ucc_tl_ucp_task_t *task)
task->tagged.recv_posted == task->tagged.recv_completed) {
return UCC_OK;
}
ucp_worker_progress(TASK_CTX(task)->ucp_worker);
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
}
return UCC_INPROGRESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static inline ucc_status_t ucc_tl_ucp_test_ring(ucc_tl_ucp_task_t *task)
task->tagged.recv_posted == task->tagged.recv_completed) {
return UCC_OK;
}
ucp_worker_progress(TASK_CTX(task)->ucp_worker);
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
}
return UCC_INPROGRESS;
}
Expand Down
14 changes: 14 additions & 0 deletions src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ static ucs_config_field_t ucc_tl_ucp_context_config_table[] = {
ucc_offsetof(ucc_tl_ucp_context_config_t, pre_reg_mem),
UCC_CONFIG_TYPE_UINT},

{"SERVICE_WORKER", "n",
"If set to 0, uses the same worker for collectives and "
"service. If not, creates a special worker for service collectives "
"for which UCX_TL and UCX_NET_DEVICES are configured by the variables "
"UCC_TL_UCP_SERVICE_TLS and UCC_TL_UCP_SERVICE_NET_DEVICES respectively",
ucc_offsetof(ucc_tl_ucp_context_config_t, service_worker),
UCC_CONFIG_TYPE_BOOL},

{"SERVICE_THROTTLING_THRESH", "100",
"Number of call to ucc_context_progress function between two consecutive "
"calls to service worker progress function",
ucc_offsetof(ucc_tl_ucp_context_config_t, service_throttling_thresh),
UCC_CONFIG_TYPE_UINT},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_ucp_lib_t, ucc_base_lib_t,
Expand Down
32 changes: 23 additions & 9 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ typedef struct ucc_tl_ucp_context_config {
uint32_t n_polls;
uint32_t oob_npolls;
uint32_t pre_reg_mem;
uint32_t service_worker;
uint32_t service_throttling_thresh;
} ucc_tl_ucp_context_config_t;

typedef struct ucc_tl_ucp_lib {
Expand All @@ -92,16 +94,22 @@ typedef struct ucc_tl_ucp_remote_info {
size_t packed_key_len;
} ucc_tl_ucp_remote_info_t;

/* Per-worker bundle of UCP state. The TL/UCP context keeps two of these
 * (a default worker and an optional dedicated service worker, per this
 * commit), and each team points at the one it progresses. */
typedef struct ucc_tl_ucp_worker {
ucp_context_h ucp_context;   /* UCP context the worker belongs to */
ucp_worker_h ucp_worker;     /* UCX worker handle passed to ucp_worker_progress() */
size_t ucp_addrlen;          /* presumably the packed length of worker_address — confirm */
ucp_address_t * worker_address; /* packed worker address (for address exchange) */
tl_ucp_ep_hash_t *ep_hash;   /* NOTE(review): looks like an endpoint cache keyed by peer — verify */
ucp_ep_h * eps;              /* array of connected endpoints */
} ucc_tl_ucp_worker_t;

typedef struct ucc_tl_ucp_context {
ucc_tl_context_t super;
ucc_tl_ucp_context_config_t cfg;
ucp_context_h ucp_context;
ucp_worker_h ucp_worker;
size_t ucp_addrlen;
ucp_address_t * worker_address;
ucc_tl_ucp_worker_t worker;
ucc_tl_ucp_worker_t service_worker;
int service_worker_throttling_count;
ucc_mpool_t req_mp;
tl_ucp_ep_hash_t * ep_hash;
ucp_ep_h * eps;
ucc_tl_ucp_remote_info_t * remote_info;
ucp_rkey_h * rkeys;
uint64_t n_rinfo_segs;
Expand All @@ -118,6 +126,7 @@ typedef struct ucc_tl_ucp_team {
ucc_tl_ucp_task_t *preconnect_task;
void * va_base[MAX_NR_SEGMENTS];
size_t base_length[MAX_NR_SEGMENTS];
ucc_tl_ucp_worker_t * worker;
} ucc_tl_ucp_team_t;
UCC_CLASS_DECLARE(ucc_tl_ucp_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);
Expand All @@ -135,15 +144,20 @@ UCC_CLASS_DECLARE(ucc_tl_ucp_team_t, ucc_base_context_t *,
#define UCC_TL_UCP_TEAM_CTX(_team) \
(ucc_derived_of((_team)->super.super.context, ucc_tl_ucp_context_t))

#define UCC_TL_UCP_WORKER(_team) UCC_TL_UCP_TEAM_CTX(_team)->ucp_worker
#define IS_SERVICE_TEAM(_team) \
((_team)->super.super.params.scope == UCC_CL_LAST + 1)

#define USE_SERVICE_WORKER(_team) \
(IS_SERVICE_TEAM(_team) && UCC_TL_UCP_TEAM_CTX(_team)->cfg.service_worker)

#define UCC_TL_UCP_TASK_TEAM(_task) \
(ucc_derived_of((_task)->super.team, ucc_tl_ucp_team_t))

#define UCC_TL_CTX_HAS_OOB(_ctx) \
((_ctx)->super.super.ucc_context->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB)

#define UCC_TL_CTX_OOB(_ctx) ((_ctx)->super.super.ucc_context->params.oob)

#define IS_SERVICE_TEAM(_team) ((_team)->super.super.params.scope == UCC_CL_LAST + 1)

#define UCC_TL_UCP_REMOTE_RKEY(_ctx, _rank, _seg) \
((_ctx)->rkeys[_rank * _ctx->n_rinfo_segs + _seg])

Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ static inline ucc_status_t ucc_tl_ucp_test(ucc_tl_ucp_task_t *task)
if (UCC_TL_UCP_TASK_P2P_COMPLETE(task)) {
return UCC_OK;
}
ucp_worker_progress(TASK_CTX(task)->ucp_worker);
ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker);
}
return UCC_INPROGRESS;
}
Expand Down
Loading

0 comments on commit 6091d4f

Please sign in to comment.