diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index e588e94a6720..c9651da74c3b 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -1,3 +1,6 @@
+defaults:
+  - optional tp_overlap@model.ub_tp_comm_overlap_cfg:
+
 name: megatron_gpt
 restore_from_path: null # used when starting from a .nemo file
 
diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml
new file mode 100644
index 000000000000..c6e25c087ffc
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml
@@ -0,0 +1,53 @@
+# UB communicator configurations
+# Model configs: A100/175B/TP4/MBS1/SeqLen2K/BF16
+
+# Bulk overlap with AllGather
+qkv_dgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+qkv_wgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+fc1_dgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+fc1_wgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+## Ring-exchange overlap with AllGather
+qkv_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+proj_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+fc1_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+fc2_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+# Chunked-collective overlap with ReduceScatter
+proj_fprop:
+  method: pipeline
+  num_sm: 4
+  num_splits: 4
+  set_sm_margin: 0
+
+fc2_fprop:
+  method: pipeline
+  num_sm: 4
+  num_splits: 4
+  set_sm_margin: 0
diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml
new file mode 100644
index 000000000000..434e0a29f42c
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml
@@ -0,0 +1,53 @@
+# UB communicator configurations
+# Model configs: A100/175B/TP4/MBS2/SeqLen2K/BF16
+
+# Bulk overlap with AllGather
+qkv_dgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+qkv_wgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+fc1_dgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+fc1_wgrad:
+  method: bulk
+  num_sm: 2
+  set_sm_margin: 0
+
+## Ring-exchange overlap with AllGather
+qkv_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+proj_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+fc1_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+fc2_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+# Chunked-collective overlap with ReduceScatter
+proj_fprop:
+  method: pipeline
+  num_sm: 8
+  num_splits: 4
+  set_sm_margin: 0
+
+fc2_fprop:
+  method: pipeline
+  num_sm: 4
+  num_splits: 4
+  set_sm_margin: 0
diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml
new file mode 100644
index 000000000000..21d02f3dd22c
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml
@@ -0,0 +1,59 @@
+# UB communicator configurations
+# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8
+
+# Bulk overlap with AllGather / ReduceScatter
+qkv_dgrad:
+  method: bulk
+  num_sm: 4
+  cga_size: 2
+  set_sm_margin: 0
+
+qkv_wgrad:
+  method: bulk
+  num_sm: 8
+  cga_size: 2
+  set_sm_margin: 0
+
+fc1_dgrad:
+  method: bulk
+  num_sm: 2
+  cga_size: 2
+  set_sm_margin: 0
+
+fc1_wgrad:
+  method: bulk
+  num_sm: 4
+  cga_size: 2
+  set_sm_margin: 0
+
+## Ring-exchange overlap with AllGather
+qkv_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+proj_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+fc1_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+fc2_dgrad:
+  method: ring_exchange
+  aggregate: 1
+
+# Chunked-collective overlap with ReduceScatter
+proj_fprop:
+  method: pipeline
+  num_sm: 24
+  cga_size: 2
+  num_splits: 4
+  set_sm_margin: 1
+
+fc2_fprop:
+  method: pipeline
+  num_sm: 20
+  cga_size: 2
+  num_splits: 4
+  set_sm_margin: 1
diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml
new file mode 100644
index 000000000000..444c8245e02c
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml
@@ -0,0 +1,59 @@
+# UB communicator configurations
+# Model configs: H100/175B/TP8/MBS2/SeqLen2K/FP8
+
+# Bulk overlap with AllGather
+qkv_dgrad:
+  method: bulk
+  num_sm: 8
+  cga_size: 2
+  set_sm_margin: 0
+
+qkv_wgrad:
+  method: bulk
+  num_sm: 16
+  cga_size: 2
+  set_sm_margin: 0
+
+fc1_dgrad:
+  method: bulk
+  num_sm: 4
+  cga_size: 2
+  set_sm_margin: 0
+
+fc1_wgrad:
+  method: bulk
+  num_sm: 16
+  cga_size: 2
+  set_sm_margin: 0
+
+## Ring-exchange overlap with AllGather
+qkv_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+proj_dgrad:
+  method: ring_exchange
+  aggregate: 1
+
+fc1_fprop:
+  method: ring_exchange
+  aggregate: 0
+
+fc2_dgrad:
+  method: ring_exchange
+  aggregate: 0
+
+# Chunked-collective overlap with ReduceScatter
+proj_fprop:
+  method: pipeline
+  num_sm: 16
+  cga_size: 2
+  num_splits: 4
+  set_sm_margin: 1
+
+fc2_fprop:
+  method: pipeline
+  num_sm: 24
+  cga_size: 2
+  num_splits: 4
+  set_sm_margin: 1
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 97d217c5c47a..7874b35dd83b 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -511,20 +511,16 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
         return loss_mean
 
     def initialize_ub_func(self):
+        ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None)
+        if ub_cfgs is None:
+            warnings.warn(
+                "Couldn't find a TP comm overlap config. Please check that ub_tp_comm_overlap_cfg is set correctly. Initializing TP comm overlap with the default config."
+            )
+
         input_shape = [
             self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
             self.cfg.get('hidden_size'),
         ]
-        ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
-        ub_cfgs = None
-        if ub_cfg_file_name is not None:
-            try:
-                import yaml
-
-                with open(ub_cfg_file_name, 'r') as ub_cfg_file:
-                    ub_cfgs = yaml.safe_load(ub_cfg_file)
-            except (ImportError, TypeError):
-                logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.")
         te_module.base.initialize_ub(
             shape=input_shape,
             tp_size=self.cfg.get('tensor_model_parallel_size'),
diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py
index 6c6c9b47e0fd..385db95a7299 100644
--- a/nemo/core/config/hydra_runner.py
+++ b/nemo/core/config/hydra_runner.py
@@ -23,6 +23,24 @@
 from hydra.types import TaskFunction
 from omegaconf import DictConfig, OmegaConf
 
+
+def _get_gpu_name():
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+    pynvml.nvmlShutdown()
+    if cuda_capability == 8:
+        return "a100"
+    elif cuda_capability == 9:
+        return "h100"
+    else:
+        return None
+
+
+OmegaConf.register_new_resolver("gpu_name", _get_gpu_name)
+
 # multiple interpolated values in the config
 OmegaConf.register_new_resolver("multiply", lambda x, y: x * y)
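
Usage sketch (not part of the patch above; the override syntax, paths, and flag names below are assumptions inferred from the Hydra defaults entry and the tp_overlap files added in this diff). The optional tp_overlap@model.ub_tp_comm_overlap_cfg default composes one of the new userbuffers configs into the model config only when a config name is supplied, and initialize_ub_func() then reads it straight from self.cfg. The gpu_name resolver registered in hydra_runner.py is intended to let a config name interpolate a100 or h100 for the local GPU. A minimal Hydra compose example, assuming a local NeMo checkout:

    # Hypothetical sketch: compose megatron_gpt_config.yaml with one of the new tp_overlap
    # configs placed under model.ub_tp_comm_overlap_cfg. The directory path is a placeholder,
    # and enabling the overlap itself is a separate model flag not shown in this diff.
    from hydra import compose, initialize_config_dir

    CONF_DIR = "/path/to/NeMo/examples/nlp/language_modeling/conf"  # hypothetical absolute path

    with initialize_config_dir(config_dir=CONF_DIR):
        cfg = compose(
            config_name="megatron_gpt_config",
            overrides=[
                # Overrides the null-valued optional default added to megatron_gpt_config.yaml.
                "tp_overlap@model.ub_tp_comm_overlap_cfg=ub_cfg_h100_h12288_tp4_mbs1_seqlen2048",
            ],
        )

    # The per-GEMM overlap settings now sit where initialize_ub_func() looks for them,
    # i.e. self.cfg.get('ub_tp_comm_overlap_cfg', None).
    print(cfg.model.ub_tp_comm_overlap_cfg.proj_fprop.method)  # -> pipeline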