diff --git a/nemo/utils/get_rank.py b/nemo/utils/get_rank.py index 9b36eed6246b..37d3906760e7 100644 --- a/nemo/utils/get_rank.py +++ b/nemo/utils/get_rank.py @@ -32,9 +32,14 @@ def is_global_rank_zero(): if slurm_rank is not None: return slurm_rank == 0 - # if neither pytorch and SLURM env vars are set + # Try to get the MPI global rank env var + mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None) + if mpi_rank is not None: + return mpi_rank == 0 + + # if neither pytorch, SLURM nor MPI env vars are set # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars - # asume global_rank is zero if undefined + # assume global_rank is zero if undefined node_rank = get_envint("NODE_RANK", get_envint("GROUP_RANK", 0)) local_rank = get_envint("LOCAL_RANK", 0) return node_rank == 0 and local_rank == 0