Use NCCL bootstrap backend for TP communication overlaps (#10622)
* Use NCCL bootstrap backend for TP communication overlaps

Signed-off-by: Sangkug Lym <[email protected]>

* Apply isort and black reformatting

Signed-off-by: erhoo82 <[email protected]>

* fix for 2.0

Signed-off-by: Sangkug Lym <[email protected]>

---------

Signed-off-by: Sangkug Lym <[email protected]>
Signed-off-by: erhoo82 <[email protected]>
Co-authored-by: erhoo82 <[email protected]>
erhoo82 authored Nov 20, 2024
1 parent 4633db6 commit 8b7999d
Showing 5 changed files with 20 additions and 4 deletions.
@@ -200,7 +200,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size', None),
use_fp8=cfg.get('fp8', False),
- init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
+ init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False)
+ and cfg.get('ub_tp_comm_bootstrap_backend', 'nccl') == 'mpi',
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False),
@@ -1173,6 +1174,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
"grad_sync_func": None, # set dynamically during training
"param_sync_func": None, # set dynamically during training
"tp_comm_overlap": self.cfg.get('ub_tp_comm_overlap', False),
"tp_comm_bootstrap_backend": self.cfg.get('ub_tp_comm_bootstrap_backend', 'nccl'),
}

# instantiate ModelParallelConfig from this dict
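
Taken together, these two hunks make the MPI process group opt-in: it is requested only when ub_tp_comm_bootstrap_backend is set to 'mpi', and the chosen backend is threaded into the model parallel config. Below is a minimal sketch of how the guard evaluates against an OmegaConf-style config; the keys come from this diff, while the toy config itself is illustrative.

# Sketch only: evaluates the same guard as the hunk above against a toy config.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "ub_tp_comm_overlap": True,              # enable TP communication overlap
        "ub_tp_comm_bootstrap_backend": "nccl",  # 'nccl' (default) or 'mpi'
    }
)

init_mpi_proc_group = (
    cfg.get("ub_tp_comm_overlap", False)
    and cfg.get("ub_tp_comm_bootstrap_backend", "nccl") == "mpi"
)
print(init_mpi_proc_group)  # False: the NCCL bootstrap path needs no MPI process group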
@@ -803,6 +803,7 @@ def initialize_ub_func(self):
tp_size=self.cfg.get('tensor_model_parallel_size'),
use_fp8=self.cfg.get('fp8'),
ub_cfgs=ub_cfgs,
+ bootstrap_backend=self.cfg.get('ub_tp_comm_bootstrap_backend', 'nccl'),
)
self.initialize_ub = False
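
The new bootstrap_backend kwarg added above is forwarded to Transformer Engine's userbuffers initializer, which is not itself shown in this hunk. A hedged sketch of that call follows; the initialize_ub import path, its signature, and the shape values are assumptions based on TE >= 1.9, not facts taken from this diff.

# Sketch only; assumes transformer_engine >= 1.9 exposes initialize_ub() with a
# bootstrap_backend argument. The shape values are illustrative placeholders and
# a distributed GPU job is required for this to actually run.
import transformer_engine.pytorch.module.base as te_base

seq_len, micro_batch_size, hidden_size, tp_size = 2048, 1, 4096, 8

te_base.initialize_ub(
    shape=[seq_len * micro_batch_size, hidden_size],  # flattened activation shape
    tp_size=tp_size,
    use_fp8=False,
    ub_cfgs=None,              # per-GEMM overlap settings; None keeps TE defaults
    bootstrap_backend="nccl",  # bootstrap userbuffers over NCCL instead of MPI
)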

9 changes: 7 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -176,9 +176,14 @@ def init_model_parallel(
app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()

- # create MPI process group for UCX-based communication APIs
if app_state.init_mpi_proc_group:
- torch.distributed.new_group(backend='mpi')
+ import packaging
+
+ te_version = packaging.version.Version(version('transformer_engine'))
+ if te_version < packaging.version.Version("1.9"):
+ # Create the MPI process group for bootstrapping on older TE versions.
+ # From TE v1.9 on, the process group is initialized inside TE.
+ torch.distributed.new_group(backend='mpi')


class NLPDDPStrategy(DDPStrategy):
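
The version gate added in the hunk above keeps the manual MPI group only for older Transformer Engine releases. A standalone sketch of the same check is shown here; it assumes that version comes from importlib.metadata, since that import sits outside the visible hunk.

# Standalone sketch of the TE version gate; `version` is assumed to be
# importlib.metadata.version (its import is not visible in this hunk).
from importlib.metadata import version

import packaging.version
import torch

te_version = packaging.version.Version(version("transformer_engine"))
if te_version < packaging.version.Version("1.9"):
    # Older TE: create the MPI process group ourselves for userbuffers bootstrap
    # (requires torch.distributed to be initialized and built with MPI support).
    torch.distributed.new_group(backend="mpi")
else:
    # TE >= 1.9 sets up its own bootstrap process group internally.
    pass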
3 changes: 2 additions & 1 deletion nemo/lightning/_strategy_lib.py
@@ -89,7 +89,8 @@ def init_parallel_ranks(
seed=seed,
pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
use_fp8=fp8,
- init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False),
+ init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False)
+ and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi',
# apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
)

7 changes: 7 additions & 0 deletions nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
@@ -43,6 +43,7 @@ class _CommOverlapConfig:
# Tensor parallel communication overlap (experimental)
tp_comm_overlap: bool = None
tp_comm_overlap_cfg: dict = None
+ tp_comm_bootstrap_backend: str = None
# Pipeline parallel communication overlap
overlap_p2p_comm: bool = None
batch_p2p_comm: bool = None
@@ -88,6 +89,7 @@ def __init__(
self,
tp_comm_overlap: bool = None,
tp_comm_overlap_cfg: TransformerLayerTPOverlapCfg = None,
+ tp_comm_bootstrap_backend: str = None,
overlap_p2p_comm: bool = None,
batch_p2p_comm: bool = None,
overlap_grad_reduce: bool = None,
@@ -102,6 +104,7 @@ def __init__(
self.user_comm_overlap_cfg = _CommOverlapConfig(
tp_comm_overlap=tp_comm_overlap,
tp_comm_overlap_cfg=tp_comm_overlap_cfg,
+ tp_comm_bootstrap_backend=tp_comm_bootstrap_backend,
overlap_p2p_comm=overlap_p2p_comm,
batch_p2p_comm=batch_p2p_comm,
overlap_grad_reduce=overlap_grad_reduce,
@@ -114,6 +117,7 @@ def __init__(
)

self.tp_comm_overlap_cfg = None
+ self.tp_comm_bootstrap_backend = None
self.need_tp_overlap_ub_init = False

def _get_model_comm_overlap_cfgs(
@@ -129,6 +133,7 @@
# Optimizations disabled by default, can be overridden by user
comm_overlap_cfg.tp_comm_overlap = False
comm_overlap_cfg.tp_comm_overlap_cfg = None
+ comm_overlap_cfg.tp_comm_bootstrap_backend = None
comm_overlap_cfg.defer_embedding_wgrad_compute = False
comm_overlap_cfg.wgrad_deferral_limit = -1

Expand Down Expand Up @@ -216,6 +221,7 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)

if trainer.model.config.tp_comm_overlap:
self.tp_comm_overlap_cfg = comm_overlap_cfg.tp_comm_overlap_cfg
+ self.tp_comm_bootstrap_backend = comm_overlap_cfg.tp_comm_bootstrap_backend
self.need_tp_overlap_ub_init = True

# Data parallel overlap is only available with the Megatron DDP and Distributed optimizer
@@ -258,6 +264,7 @@ def _init_te_userbuffers(self, model_parallel_cfg: ModelParallelConfig):
tp_size=parallel_state.get_tensor_model_parallel_world_size(),
use_fp8=fp8,
ub_cfgs=self.tp_comm_overlap_cfg,
+ bootstrap_backend=self.tp_comm_bootstrap_backend,
)
except Exception as error:
raise Exception(f"Tensor parallel overlap: userbuffer initialization failed with {error}")
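
The new option is also exposed on the callback constructor in this file, so a NeMo 2.0 run can select the bootstrap backend directly. A hedged usage sketch follows; the class name MegatronCommOverlapCallback and the Trainer wiring are assumptions, since only the constructor arguments appear in this diff.

# Usage sketch; MegatronCommOverlapCallback is assumed to be the callback class
# defined in nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py, and the
# Trainer wiring below is illustrative rather than a complete training setup.
import pytorch_lightning as pl

from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback

comm_overlap = MegatronCommOverlapCallback(
    tp_comm_overlap=True,              # enable TP compute/communication overlap
    tp_comm_bootstrap_backend="nccl",  # bootstrap userbuffers over NCCL (or "mpi")
)

trainer = pl.Trainer(callbacks=[comm_overlap])  # plus the usual strategy/plugins/precision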
