Skip to content

Commit

Permalink
coll: add coll_group to treealgo routines
Browse files Browse the repository at this point in the history
The topology-aware tree utilities need check coll_group for correct
world ranks.
  • Loading branch information
hzhou committed Aug 23, 2024
1 parent 59e9b90 commit 295e2e0
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 60 deletions.
52 changes: 30 additions & 22 deletions src/mpi/coll/algorithms/treealgo/treealgo.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,49 +33,55 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm)
return mpi_errno;
}

static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k)
static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE &&
param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED &&
param->root == root && param->u.topo_aware.k == k);
}

static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k)
static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE;
param->coll_group = coll_group;
param->root = root;
param->u.topo_aware.k = k;
}

static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k)
static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K &&
param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED &&
param->root == root && param->u.topo_aware.k == k);
}

static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k)
static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K;
param->coll_group = coll_group;
param->root = root;
param->u.topo_aware.k = k;
}

static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param,
static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group,
int root, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE &&
param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED &&
param->root == root &&
param->u.topo_wave.overhead == overhead &&
param->u.topo_wave.lat_diff_groups == lat_diff_groups &&
param->u.topo_wave.lat_diff_switches == lat_diff_switches &&
param->u.topo_wave.lat_same_switches == lat_same_switches);
}

static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param,
static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group,
int root, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE;
param->coll_group = coll_group;
param->root = root;
param->u.topo_wave.overhead = overhead;
param->u.topo_wave.lat_diff_groups = lat_diff_groups;
Expand Down Expand Up @@ -125,7 +131,8 @@ int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int ro
}


int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root,
int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type,
int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct)
{
int mpi_errno = MPI_SUCCESS;
Expand All @@ -135,7 +142,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,
switch (tree_type) {
case MPIR_TREE_TYPE_TOPOLOGY_AWARE:
if (!comm->coll.cached_tree ||
!match_param_topo_aware(&comm->coll.cached_tree_param, root, k)) {
!match_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
Expand All @@ -144,11 +151,11 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,
MPL_MEM_BUFFER);
}
mpi_errno =
MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder,
comm->coll.cached_tree);
MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root,
enable_reorder, comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.cached_tree;
set_param_topo_aware(&comm->coll.cached_tree_param, root, k);
set_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k);
}
*ct = *comm->coll.cached_tree;
utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
Expand All @@ -160,7 +167,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,

case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K:
if (!comm->coll.cached_tree ||
!match_param_topo_aware_k(&comm->coll.cached_tree_param, root, k)) {
!match_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
Expand All @@ -169,11 +176,12 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,
MPL_MEM_BUFFER);
}
mpi_errno =
MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder,
MPII_Treeutil_tree_topology_aware_k_init(comm, coll_group, k, root,
enable_reorder,
comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.cached_tree;
set_param_topo_aware_k(&comm->coll.cached_tree_param, root, k);
set_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k);
}
*ct = *comm->coll.cached_tree;
utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
Expand Down Expand Up @@ -201,7 +209,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,
}


int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root,
int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches,
MPIR_Treealgo_tree_t * ct)
Expand All @@ -211,21 +219,21 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root,
MPIR_FUNC_ENTER;

if (!comm->coll.cached_tree ||
!match_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
lat_diff_groups, lat_diff_switches, lat_same_switches)) {
!match_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root,
overhead, lat_diff_groups, lat_diff_switches, lat_same_switches)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
comm->coll.cached_tree =
(MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER);
}
mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead,
lat_diff_groups, lat_diff_switches,
lat_same_switches,
comm->coll.cached_tree);
mpi_errno =
MPII_Treeutil_tree_topology_wave_init(comm, coll_group, k, root, enable_reorder,
overhead, lat_diff_groups, lat_diff_switches,
lat_same_switches, comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.cached_tree;
set_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
set_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root, overhead,
lat_diff_groups, lat_diff_switches, lat_same_switches);
}
*ct = *comm->coll.cached_tree;
Expand Down
5 changes: 3 additions & 2 deletions src/mpi/coll/algorithms/treealgo/treealgo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ int MPII_Treealgo_comm_init(MPIR_Comm * comm);
int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm);
int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root,
MPIR_Treealgo_tree_t * ct);
int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root,
int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type,
int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct);
int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root,
int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches,
MPIR_Treealgo_tree_t * ct);
Expand Down
1 change: 1 addition & 0 deletions src/mpi/coll/algorithms/treealgo/treealgo_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ typedef struct {

typedef struct {
MPIR_Tree_type_t type;
int coll_group;
int root;
union {
struct {
Expand Down
55 changes: 34 additions & 21 deletions src/mpi/coll/algorithms/treealgo/treeutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ static void MPII_Treeutil_hierarchy_reorder(UT_array * hierarchy, int rank)
}

/* tree init function is for building hierarchy of MPIR_Process::coords_dims */
static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nranks, int root,
bool enable_reorder, UT_array * hierarchy)
static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int coll_group, int rank, int nranks,
int root, bool enable_reorder, UT_array * hierarchy)
{
int mpi_errno = MPI_SUCCESS;

Expand Down Expand Up @@ -504,8 +504,12 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran
MPIR_Assert(upper_level != NULL);

/* Get wrank from the communicator as the coords are stored with wrank */
int comm_rank = r;
if (coll_group > 0) {
comm_rank = comm->subgroups[coll_group].proc_table[r];
}
uint64_t temp = 0;
MPID_Comm_get_lpid(comm, r, &temp, FALSE);
MPID_Comm_get_lpid(comm, comm_rank, &temp, FALSE);
int wrank = (int) temp;
if (wrank < 0)
goto fn_fail;
Expand Down Expand Up @@ -600,20 +604,22 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran
* build the hierarchy of the topology-aware tree.
* For the mentioned cases see tags 'goto fn_fallback;'. */

int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
MPIR_Treealgo_tree_t * ct)
int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct)
{
int mpi_errno = MPI_SUCCESS;
int rank = comm->rank;
int nranks = comm->local_size;

int rank, nranks;
MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

UT_array hierarchy[MAX_HIERARCHY_DEPTH];
int dim = MPIR_Process.coords_dims - 1;
for (dim = MPIR_Process.coords_dims - 1; dim >= 0; --dim)
tree_ut_hierarchy_init(&hierarchy[dim]);

if (k <= 0 ||
0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy))
0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder,
hierarchy))
goto fn_fallback;

ct->rank = rank;
Expand Down Expand Up @@ -695,16 +701,18 @@ int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bo
}

/* Implementation of 'Topology aware' algorithm with the branching factor k */
int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
MPIR_Treealgo_tree_t * ct)
int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct)
{
int mpi_errno = MPI_SUCCESS;
int rank = comm->rank;
int nranks = comm->local_size;

int rank, nranks;
MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

/* fall back to MPII_Treeutil_tree_topology_aware_init if k is less or equal to 2 */
if (k <= 2) {
return MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, ct);
return MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root, enable_reorder,
ct);
}

int *num_childrens = NULL;
Expand All @@ -719,7 +727,9 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root,
for (dim = MPIR_Process.coords_dims - 1; dim >= 0; --dim)
tree_ut_hierarchy_init(&hierarchy[dim]);

if (0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy))
if (0 !=
MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder,
hierarchy))
goto fn_fallback;

ct->rank = rank;
Expand Down Expand Up @@ -758,7 +768,7 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root,
/* Do an allgather to know the current num_children on each rank */
MPIR_Errflag_t errflag = MPIR_ERR_NONE;
MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT,
comm, MPIR_SUBGROUP_NONE, errflag);
comm, coll_group, errflag);
if (mpi_errno) {
goto fn_fail;
}
Expand Down Expand Up @@ -1111,13 +1121,12 @@ static int init_root_switch(const UT_array * hierarchy, heap_vector * minHeaps,
}

/* 'Topology Wave' implementation */
int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
int overhead, int lat_diff_groups, int lat_diff_switches,
int lat_same_switches, MPIR_Treealgo_tree_t * ct)
int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches,
MPIR_Treealgo_tree_t * ct)
{
int mpi_errno = MPI_SUCCESS;
int rank = comm->rank;
int nranks = comm->local_size;
int root_gr_sorted_idx = 0;
int root_sw_sorted_idx = 0;
int group_offset = 0;
Expand All @@ -1126,6 +1135,9 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo
UT_array hierarchy[MAX_HIERARCHY_DEPTH];
UT_array *unv_set = NULL;

int rank, nranks;
MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

heap_vector minHeaps;
heap_vector_init(&minHeaps);

Expand All @@ -1135,7 +1147,8 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo
tree_ut_hierarchy_init(&hierarchy[dim]);

if (overhead <= 0 || lat_diff_groups <= 0 || lat_diff_switches <= 0 || lat_same_switches <= 0 ||
0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy))
0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder,
hierarchy))
goto fn_fallback;

UT_icd intpair_icd = { sizeof(pair), NULL, NULL, NULL };
Expand Down
15 changes: 8 additions & 7 deletions src/mpi/coll/algorithms/treealgo/treeutil.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,16 @@ int MPII_Treeutil_tree_knomial_2_init(int rank, int nranks, int k, int root,
MPIR_Treealgo_tree_t * ct);

/* Generate topology_aware tree information */
int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
MPIR_Treealgo_tree_t * ct);
int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct);

int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
MPIR_Treealgo_tree_t * ct);
int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, MPIR_Treealgo_tree_t * ct);

/* Generate topology_wave tree information */
int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder,
int overhead, int lat_diff_groups, int lat_diff_switches,
int lat_same_switches, MPIR_Treealgo_tree_t * ct);
int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root,
bool enable_reorder, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches,
MPIR_Treealgo_tree_t * ct);

#endif /* TREEUTIL_H_INCLUDED */
4 changes: 2 additions & 2 deletions src/mpi/coll/allreduce/allreduce_intra_tree.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf,
/* initialize the tree */
if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) {
mpi_errno =
MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, k, root,
MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, k, root,
MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE, &my_tree);
} else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) {
MPIR_Csel_coll_sig_s coll_sig = {
Expand Down Expand Up @@ -96,7 +96,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf,
}

mpi_errno =
MPIR_Treealgo_tree_create_topo_wave(comm_ptr, k, root,
MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, k, root,
MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE,
overhead, lat_diff_groups, lat_diff_switches,
lat_same_switches, &my_tree);
Expand Down
Loading

0 comments on commit 295e2e0

Please sign in to comment.