From e572bfce9f19da495b7ebcd47fffefbfb474da40 Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Fri, 7 Jul 2023 17:17:02 +0200 Subject: [PATCH 01/11] Pass tp config via hydra Signed-off-by: Jan Baczek --- .../language_modeling/megatron_gpt_model.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 97d217c5c47a..23c7488f99fe 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -282,6 +282,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True) self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False) + self.ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None) def get_gpt_module_list(self): if isinstance(self.model, list): @@ -515,21 +516,11 @@ def initialize_ub_func(self): self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), self.cfg.get('hidden_size'), ] - ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None) - ub_cfgs = None - if ub_cfg_file_name is not None: - try: - import yaml - - with open(ub_cfg_file_name, 'r') as ub_cfg_file: - ub_cfgs = yaml.safe_load(ub_cfg_file) - except (ImportError, TypeError): - logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.") - te_module.base.initialize_ub( + te_module.initialize_ub( shape=input_shape, tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), - ub_cfgs=ub_cfgs, + ub_cfgs=self.ub_cfgs, ) self.initialize_ub = False From 24444200de4157f242f846821b4c21155f148d56 Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Mon, 10 Jul 2023 15:43:07 +0200 Subject: [PATCH 02/11] Remove self.ub_cfgs field - it isn't used anywhere else Signed-off-by: Jan Baczek --- .../nlp/models/language_modeling/megatron_gpt_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 23c7488f99fe..6381888c4fce 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -282,7 +282,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True) self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False) - self.ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None) def get_gpt_module_list(self): if isinstance(self.model, list): @@ -520,7 +519,7 @@ def initialize_ub_func(self): shape=input_shape, tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), - ub_cfgs=self.ub_cfgs, + ub_cfgs=self.cfg.get('ub_tp_comm_overlap_cfg', None), ) self.initialize_ub = False From 2b2ed389fd8920c8a15cb4b28ae36c7ce039853b Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Thu, 13 Jul 2023 20:47:26 +0200 Subject: [PATCH 03/11] Allow tp_overlap tree substitution in hydra config Signed-off-by: Jan Baczek --- .../conf/megatron_gpt_config.yaml | 3 + ...b_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml | 53 +++++++++++++++++ ...b_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml | 53 +++++++++++++++++ ...b_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml | 59 +++++++++++++++++++ 
...b_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml | 59 +++++++++++++++++++ 5 files changed, 227 insertions(+) create mode 100644 examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml create mode 100644 examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml create mode 100644 examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml create mode 100644 examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index e588e94a6720..c9651da74c3b 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -1,3 +1,6 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + name: megatron_gpt restore_from_path: null # used when starting from a .nemo file diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml new file mode 100644 index 000000000000..c6e25c087ffc --- /dev/null +++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs1_seqlen2048.yaml @@ -0,0 +1,53 @@ +# UB communicator configurations +# Model configs: A100/175B/TP4/MBS1/SeqLen2K/BF16 + +# Bulk overlap with AllGather +qkv_dgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +qkv_wgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +fc1_dgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +fc1_wgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +## Ring-exchange overlap with AllGather +qkv_fprop: + method: ring_exchange + aggregate: 0 + +proj_dgrad: + method: ring_exchange + aggregate: 0 + +fc1_fprop: + method: ring_exchange + aggregate: 0 + +fc2_dgrad: + method: ring_exchange + aggregate: 0 + +# Chunked-collective overlap with ReduceScatter +proj_fprop: + method: pipeline + num_sm: 4 + num_splits: 4 + set_sm_margin: 0 + +fc2_fprop: + method: pipeline + num_sm: 4 + num_splits: 4 + set_sm_margin: 0 diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml new file mode 100644 index 000000000000..434e0a29f42c --- /dev/null +++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_a100_h12288_tp4_mbs2_seqlen2048.yaml @@ -0,0 +1,53 @@ +# UB communicator configurations +# Model configs: A100/175B/TP4/MBS2/SeqLen2K/BF16 + +# Bulk overlap with AllGather +qkv_dgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +qkv_wgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +fc1_dgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +fc1_wgrad: + method: bulk + num_sm: 2 + set_sm_margin: 0 + +## Ring-exchange overlap with AllGather +qkv_fprop: + method: ring_exchange + aggregate: 0 + +proj_dgrad: + method: ring_exchange + aggregate: 0 + +fc1_fprop: + method: ring_exchange + aggregate: 0 + +fc2_dgrad: + method: ring_exchange + aggregate: 0 + +# Chunked-collective overlap with ReduceScatter +proj_fprop: + method: pipeline + num_sm: 8 + num_splits: 4 + set_sm_margin: 0 + +fc2_fprop: + method: pipeline + num_sm: 4 + num_splits: 4 + set_sm_margin: 0 diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml 
b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml new file mode 100644 index 000000000000..21d02f3dd22c --- /dev/null +++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp4_mbs1_seqlen2048.yaml @@ -0,0 +1,59 @@ +# UB communicator configurations +# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8 + +# Bulk overlap with AllGather / ReduceScatter +qkv_dgrad: + method: bulk + num_sm: 4 + cga_size: 2 + set_sm_margin: 0 + +qkv_wgrad: + method: bulk + num_sm: 8 + cga_size: 2 + set_sm_margin: 0 + +fc1_dgrad: + method: bulk + num_sm: 2 + cga_size: 2 + set_sm_margin: 0 + +fc1_wgrad: + method: bulk + num_sm: 4 + cga_size: 2 + set_sm_margin: 0 + +## Ring-exchange overlap with AllGather +qkv_fprop: + method: ring_exchange + aggregate: 0 + +proj_dgrad: + method: ring_exchange + aggregate: 0 + +fc1_fprop: + method: ring_exchange + aggregate: 0 + +fc2_dgrad: + method: ring_exchange + aggregate: 1 + +# Chunked-collective overlap with ReduceScatter +proj_fprop: + method: pipeline + num_sm: 24 + cga_size: 2 + num_splits: 4 + set_sm_margin: 1 + +fc2_fprop: + method: pipeline + num_sm: 20 + cga_size: 2 + num_splits: 4 + set_sm_margin: 1 diff --git a/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml new file mode 100644 index 000000000000..444c8245e02c --- /dev/null +++ b/examples/nlp/language_modeling/conf/tp_overlap/ub_cfg_h100_h12288_tp8_mbs2_seqlen2048.yaml @@ -0,0 +1,59 @@ +# UB communicator configurations +# Model configs: H100/175B/TP8/MBS2/SeqLen2K/FP8 + +# Bulk overlap with AllGather +qkv_dgrad: + method: bulk + num_sm: 8 + cga_size: 2 + set_sm_margin: 0 + +qkv_wgrad: + method: bulk + num_sm: 16 + cga_size: 2 + set_sm_margin: 0 + +fc1_dgrad: + method: bulk + num_sm: 4 + cga_size: 2 + set_sm_margin: 0 + +fc1_wgrad: + method: bulk + num_sm: 16 + cga_size: 2 + set_sm_margin: 0 + +## Ring-exchange overlap with AllGather +qkv_fprop: + method: ring_exchange + aggregate: 0 + +proj_dgrad: + method: ring_exchange + aggregate: 1 + +fc1_fprop: + method: ring_exchange + aggregate: 0 + +fc2_dgrad: + method: ring_exchange + aggregate: 0 + +# Chunked-collective overlap with ReduceScatter +proj_fprop: + method: pipeline + num_sm: 16 + cga_size: 2 + num_splits: 4 + set_sm_margin: 1 + +fc2_fprop: + method: pipeline + num_sm: 24 + cga_size: 2 + num_splits: 4 + set_sm_margin: 1 From ba6b50c84d86ac68f364d1a67eee998ceffe8e77 Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Wed, 19 Jul 2023 10:36:30 +0200 Subject: [PATCH 04/11] Add warning in case of usage of the default tp config Signed-off-by: Jan Baczek --- .../nlp/models/language_modeling/megatron_gpt_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 6381888c4fce..d3ff2523d198 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -511,6 +511,10 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): return loss_mean def initialize_ub_func(self): + ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None) + if ub_cfgs is None: + warnings.warn("TP config not provided. 
Initializing TP comm overlap with the default config.") + input_shape = [ self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), self.cfg.get('hidden_size'), @@ -519,7 +523,7 @@ def initialize_ub_func(self): shape=input_shape, tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), - ub_cfgs=self.cfg.get('ub_tp_comm_overlap_cfg', None), + ub_cfgs=ub_cfgs ) self.initialize_ub = False From 8c224c65fabfbb014587e80588d446e8244a1e87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jul 2023 08:37:31 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jan Baczek --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index d3ff2523d198..9c6428882766 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -523,7 +523,7 @@ def initialize_ub_func(self): shape=input_shape, tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), - ub_cfgs=ub_cfgs + ub_cfgs=ub_cfgs, ) self.initialize_ub = False From d406888b6d4fc99e9661fa74ef7aea4edc9e16ef Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Fri, 21 Jul 2023 11:12:34 +0200 Subject: [PATCH 06/11] Change warning message Signed-off-by: Jan Baczek --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9c6428882766..4994a6ed6299 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -513,7 +513,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): def initialize_ub_func(self): ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None) if ub_cfgs is None: - warnings.warn("TP config not provided. Initializing TP comm overlap with the default config.") + warnings.warn("Couldn't find TP config. Please check the path correctness. 
Initializing TP comm overlap with the default config.") input_shape = [ self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), From a262a86e1ceada4f557b4c218c0f0881415d8bdf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jul 2023 09:13:32 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jan Baczek --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 4994a6ed6299..78f0e0912ab5 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -513,7 +513,9 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): def initialize_ub_func(self): ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None) if ub_cfgs is None: - warnings.warn("Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config.") + warnings.warn( + "Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config." + ) input_shape = [ self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), From 0fe4b1fafc73a847352183360ca3406cdb0da10d Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Thu, 10 Aug 2023 18:25:53 +0200 Subject: [PATCH 08/11] Add compute capability resolver Signed-off-by: Jan Baczek --- nemo/core/config/hydra_runner.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py index 6c6c9b47e0fd..f29660ca993a 100644 --- a/nemo/core/config/hydra_runner.py +++ b/nemo/core/config/hydra_runner.py @@ -23,6 +23,21 @@ from hydra.types import TaskFunction from omegaconf import DictConfig, OmegaConf +def _get_gpu_name(): + import pynvml + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + pynvml.nvmlShutdown() + if cuda_capability == 8: + return "a100" + elif cuda_capability == 9: + return "h100" + else: + return None + +omegaconf.OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) + # multiple interpolated values in the config OmegaConf.register_new_resolver("multiply", lambda x, y: x * y) From 1cff863719e7f635af4aa2571e12ccd4aa9dabf8 Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Fri, 11 Aug 2023 11:22:44 +0200 Subject: [PATCH 09/11] Bugfix Signed-off-by: Jan Baczek --- nemo/core/config/hydra_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py index f29660ca993a..2845fcc67305 100644 --- a/nemo/core/config/hydra_runner.py +++ b/nemo/core/config/hydra_runner.py @@ -36,7 +36,7 @@ def _get_gpu_name(): else: return None -omegaconf.OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) +OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) # multiple interpolated values in the config OmegaConf.register_new_resolver("multiply", lambda x, y: x * y) From 08b35b5970478b4ba6fa9a889d8470399dee6a85 Mon Sep 17 00:00:00 2001 From: Jan Baczek Date: Thu, 14 Sep 2023 13:20:59 +0200 Subject: [PATCH 10/11] Fix cherry pick Signed-off-by: Jan Baczek --- 
.../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 78f0e0912ab5..7874b35dd83b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -521,7 +521,7 @@ def initialize_ub_func(self): self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), self.cfg.get('hidden_size'), ] - te_module.initialize_ub( + te_module.base.initialize_ub( shape=input_shape, tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), From 5ca2f5c57cf442ec4046b78e44722311e6da5f3a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:59:59 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/core/config/hydra_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py index 2845fcc67305..385db95a7299 100644 --- a/nemo/core/config/hydra_runner.py +++ b/nemo/core/config/hydra_runner.py @@ -23,8 +23,10 @@ from hydra.types import TaskFunction from omegaconf import DictConfig, OmegaConf + def _get_gpu_name(): import pynvml + pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle) @@ -36,6 +38,7 @@ def _get_gpu_name(): else: return None + OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) # multiple interpolated values in the config
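
Taken together, these patches replace the old file-path plumbing for the userbuffers (UB) tensor-parallel communication-overlap config with a Hydra config group: the tp_overlap YAMLs added above can be composed into model.ub_tp_comm_overlap_cfg through the optional defaults entry in megatron_gpt_config.yaml, and the new gpu_name OmegaConf resolver makes it possible to build a GPU-specific config name at resolution time. The snippet below is a minimal illustrative sketch, not part of the patch series: it assumes pynvml is installed, GPU 0 is visible, and that importing nemo.core.config.hydra_runner is enough to register the resolver; the ${gpu_name:} interpolation syntax is an assumption for illustration, not taken from these patches.

    # Illustrative sketch only -- assumes pynvml is installed and an NVIDIA GPU is visible.
    from omegaconf import OmegaConf

    # Importing this module runs OmegaConf.register_new_resolver("gpu_name", _get_gpu_name).
    from nemo.core.config.hydra_runner import _get_gpu_name

    # Direct call: "a100" on compute capability 8.x, "h100" on 9.x, None otherwise.
    print(_get_gpu_name())

    # The resolver can also be used inside an interpolated value, e.g. to pick one of the
    # tp_overlap files added above by detected GPU (interpolation syntax is an assumption).
    cfg = OmegaConf.create({"tp_overlap_cfg": "ub_cfg_${gpu_name:}_h12288_tp4_mbs1_seqlen2048"})
    print(cfg.tp_overlap_cfg)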