From 5ff4f0b9b20ed25816bdfd3c328c9533140e23b3 Mon Sep 17 00:00:00 2001
From: Maanu Grover <109391026+maanug-nv@users.noreply.github.com>
Date: Tue, 16 Jul 2024 19:45:46 -0500
Subject: [PATCH] Fix missing parallelisms (#9725)

* pass cp and ep cfg to init mp

Signed-off-by: Maanu Grover

* update test

Signed-off-by: Maanu Grover

---------

Signed-off-by: Maanu Grover
---
 nemo/lightning/_strategy_lib.py      |  6 +++++-
 tests/lightning/test_strategy_lib.py | 19 +++++++++++++++++--
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index 5d7910f70f03..c3fce6aa9987 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -61,12 +61,14 @@ def init_parallel_ranks(
         global_rank=init_global_rank,
         local_rank=init_local_rank,
         tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        expert_model_parallel_size=parallel_config.expert_model_parallel_size,
         pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
+        context_parallel_size=parallel_config.context_parallel_size,
         seed=seed,
         pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
         use_fp8=fp8,
-        init_mpi_proc_group=getattr(parallel_config, "ub_tp_comm_overlap", False),
+        init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False),
         # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
     )
 
@@ -92,6 +94,8 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
                 pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
                 virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
                 pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+                context_parallel_size=app_state.context_parallel_size,
+                expert_model_parallel_size=app_state.expert_model_parallel_size,
             )
 
             # assert that fake tp and pp rank match after model parallel init
diff --git a/tests/lightning/test_strategy_lib.py b/tests/lightning/test_strategy_lib.py
index b59930ab023d..2b9f41674a15 100644
--- a/tests/lightning/test_strategy_lib.py
+++ b/tests/lightning/test_strategy_lib.py
@@ -23,6 +23,8 @@ def test_init_parallel_ranks(mock_initialize_model_parallel) -> None:
     app_state.tensor_model_parallel_size = 2
     app_state.pipeline_model_parallel_size = 3
+    app_state.context_parallel_size = 2
+    app_state.expert_model_parallel_size = 2
     app_state.global_rank = 1
     app_state.local_rank = 0
 
@@ -30,11 +32,18 @@ def test_init_parallel_ranks(mock_initialize_model_parallel) -> None:
     mock_parallel_config.tensor_model_parallel_size = 2
     mock_parallel_config.pipeline_model_parallel_size = 3
     mock_parallel_config.virtual_pipeline_model_parallel_size = 4
-    mock_parallel_config.ub_tp_comm_overlap = False
+    mock_parallel_config.context_parallel_size = 2
+    mock_parallel_config.expert_model_parallel_size = 2
+    mock_parallel_config.tp_comm_overlap = False
     mock_parallel_config.pipeline_model_parallel_split_rank = None
 
     _strategy_lib.init_parallel_ranks(
-        world_size=2, global_rank=1, local_rank=0, parallel_config=mock_parallel_config, seed=1234, fp8=False,
+        world_size=2,
+        global_rank=1,
+        local_rank=0,
+        parallel_config=mock_parallel_config,
+        seed=1234,
+        fp8=False,
     )
     mock_initialize_model_parallel.assert_called_once_with(
         world_size=2,
@@ -43,6 +52,8 @@ def test_init_parallel_ranks(mock_initialize_model_parallel) -> None:
         tensor_model_parallel_size=2,
         pipeline_model_parallel_size=3,
         virtual_pipeline_model_parallel_size=4,
+        context_parallel_size=2,
+        expert_model_parallel_size=2,
         seed=1234,
         pipeline_model_parallel_split_rank=None,
         use_fp8=False,
@@ -60,6 +71,8 @@ def test_init_model_parallel(mock_mpu, *args):
     app_state.tensor_model_parallel_size = 2
     app_state.pipeline_model_parallel_size = 1
     app_state.pipeline_model_parallel_split_rank = None
+    app_state.context_parallel_size = 2
+    app_state.expert_model_parallel_size = 2
     app_state.init_mpi_proc_group = False
     app_state.tensor_model_parallel_rank = 2
     app_state.pipeline_model_parallel_rank = 0
@@ -72,6 +85,8 @@ def test_init_model_parallel(mock_mpu, *args):
         pipeline_model_parallel_size=1,
         virtual_pipeline_model_parallel_size=None,
         pipeline_model_parallel_split_rank=None,
+        context_parallel_size=2,
+        expert_model_parallel_size=2,
     )
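
For context, the bug class this patch fixes is a forwarding wrapper that silently drops configuration fields: `context_parallel_size` and `expert_model_parallel_size` were set on the config object but never passed through to the model-parallel initializer, so the backend fell back to its defaults. The snippet below is a minimal, self-contained sketch of how the `MagicMock`-based test above catches this; the `init_ranks` and `backend_init` names are hypothetical stand-ins, not NeMo or Megatron APIs.

```python
from unittest.mock import MagicMock


def init_ranks(parallel_config, backend_init):
    # Forward parallelism settings to the backend initializer. If a field
    # such as context_parallel_size were omitted here, the backend would
    # silently use its default -- the failure mode fixed in this patch.
    backend_init(
        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
        context_parallel_size=parallel_config.context_parallel_size,
        expert_model_parallel_size=parallel_config.expert_model_parallel_size,
    )


def test_init_ranks_forwards_all_parallelisms():
    cfg = MagicMock()
    cfg.tensor_model_parallel_size = 2
    cfg.context_parallel_size = 2
    cfg.expert_model_parallel_size = 2

    backend_init = MagicMock()
    init_ranks(cfg, backend_init)

    # assert_called_once_with fails if any expected kwarg is missing or
    # wrong, which is how the updated test guards against this regression.
    backend_init.assert_called_once_with(
        tensor_model_parallel_size=2,
        context_parallel_size=2,
        expert_model_parallel_size=2,
    )
```

This is also why the updated test sets every parallelism field on the mock config and then asserts the full kwarg set with `assert_called_once_with`: any future field that is configured but not forwarded will fail the test rather than silently default.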