Fix _strategy_lib tests (#11033) (#11039)
* fix world size and don't mock

* cleanup global state

* check app state instead

* fix syntax nemo logger test

---------

Signed-off-by: Maanu Grover <[email protected]>
Co-authored-by: Maanu Grover <[email protected]>
pablo-garay and maanug-nv authored Oct 25, 2024
1 parent e34e04b commit 395c502
Showing 3 changed files with 35 additions and 18 deletions.
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -25,6 +25,8 @@

import pytest

from nemo.utils.metaclasses import Singleton

# Those variables probably should go to main NeMo configuration file (config.yaml).
__TEST_DATA_FILENAME = "test_data.tar.gz"
__TEST_DATA_URL = "https://github.com/NVIDIA/NeMo/releases/download/v1.0.0rc1/"
@@ -115,6 +117,11 @@ def cleanup_local_folder():
rmtree('./nemo_experiments', ignore_errors=True)


@pytest.fixture(autouse=True)
def reset_singletons():
Singleton._Singleton__instances = {}


@pytest.fixture(scope="session")
def test_data_dir():
"""
@@ -173,6 +180,7 @@ def k2_cuda_is_enabled(k2_is_appropriate) -> Tuple[bool, str]:
return k2_is_appropriate

import torch # noqa: E402

from nemo.core.utils.k2_guard import k2 # noqa: E402

if torch.cuda.is_available() and k2.with_cuda:
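For context (this sketch is not part of the commit's diff): the new autouse reset_singletons fixture clears the instance cache kept by NeMo's Singleton metaclass, so singletons such as AppState no longer leak state between tests. A minimal, self-contained illustration of the pattern follows; the Singleton and AppState classes below are simplified stand-ins, assuming only that the metaclass caches instances in a class-private __instances dict (name-mangled to _Singleton__instances).

import pytest


class Singleton(type):
    """Simplified stand-in for nemo.utils.metaclasses.Singleton (assumed to
    cache instances in a class-private __instances dict)."""

    __instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in Singleton.__instances:
            Singleton.__instances[cls] = super().__call__(*args, **kwargs)
        return Singleton.__instances[cls]


class AppState(metaclass=Singleton):
    """Illustrative singleton holding per-test parallel configuration."""

    def __init__(self):
        self.world_size = None


@pytest.fixture(autouse=True)
def reset_singletons():
    # Python name mangling turns __instances into _Singleton__instances,
    # which is why the conftest fixture clears exactly that attribute.
    Singleton._Singleton__instances = {}


def test_gets_fresh_app_state():
    # Because the fixture is autouse, each test sees a brand-new AppState
    # rather than one configured by an earlier test.
    assert AppState().world_size is None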
3 changes: 2 additions & 1 deletion tests/lightning/test_nemo_logger.py
@@ -115,7 +115,8 @@ def test_resume(self, trainer, tmp_path):
resume_ignore_no_checkpoint=True,
).setup(trainer)

path = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints").mkdir(parents=True)
path = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints")
path.mkdir(parents=True)
# Error because checkpoints do not exist in folder
with pytest.raises(NotFoundError):
nl.AutoResume(
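Context on the test_nemo_logger.py change (illustration only, not part of the diff): Path.mkdir() returns None, so the old one-liner bound path to None rather than to the checkpoints directory. A minimal sketch of the difference, using a hypothetical temporary directory:

import tempfile
from pathlib import Path

tmp_path = Path(tempfile.mkdtemp())

# Old pattern: mkdir() returns None, so `path` ends up as None.
path = Path(tmp_path / "before" / "checkpoints").mkdir(parents=True)
assert path is None

# Fixed pattern: keep the Path object, then create the directory on it.
path = Path(tmp_path / "after" / "checkpoints")
path.mkdir(parents=True)
assert path.exists()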
42 changes: 25 additions & 17 deletions tests/lightning/test_strategy_lib.py
@@ -57,8 +57,10 @@ def configure_model(self):
assert model.config.pipeline_dtype == torch.float32


@patch('nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo')
def test_init_parallel_ranks(mock_initialize_model_parallel) -> None:
def test_init_parallel_ranks() -> None:
from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
from megatron.core.parallel_state import destroy_model_parallel

from nemo.utils import AppState

app_state = AppState()
@@ -80,27 +82,33 @@ def test_init_parallel_ranks(mock_initialize_model_parallel) -> None:
mock_parallel_config.pipeline_model_parallel_split_rank = None

_strategy_lib.init_parallel_ranks(
world_size=2,
world_size=24,
global_rank=1,
local_rank=0,
parallel_config=mock_parallel_config,
seed=1234,
fp8=False,
)
mock_initialize_model_parallel.assert_called_once_with(
world_size=2,
global_rank=1,
local_rank=0,
tensor_model_parallel_size=2,
pipeline_model_parallel_size=3,
virtual_pipeline_model_parallel_size=4,
context_parallel_size=2,
expert_model_parallel_size=2,
seed=1234,
pipeline_model_parallel_split_rank=None,
use_fp8=False,
init_mpi_proc_group=False,
)
expected_app_state = {
"world_size": 24,
"global_rank": 1,
"local_rank": 0,
"tensor_model_parallel_size": 2,
"pipeline_model_parallel_size": 3,
"virtual_pipeline_model_parallel_size": 4,
"context_parallel_size": 2,
"expert_model_parallel_size": 2,
"pipeline_model_parallel_split_rank": None,
"use_fp8": False,
"init_mpi_proc_group": False,
}
for k, v in expected_app_state.items():
assert hasattr(app_state, k), f"Expected to find {k} in AppState"
app_attr = getattr(app_state, k)
assert app_attr == v, f"{k} in AppState is incorrect, Expected: {v} Actual: {app_attr}"

destroy_model_parallel()
destroy_num_microbatches_calculator()


@patch('torch.distributed.is_initialized', return_value=True)
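Context on the test_strategy_lib.py change (illustration only, not part of the diff): the rewritten test_init_parallel_ranks no longer mocks initialize_model_parallel_for_nemo; it lets _strategy_lib.init_parallel_ranks run for real, checks the resulting AppState attributes against an expected dict, and then cleans up megatron-core's global parallel state with destroy_model_parallel and destroy_num_microbatches_calculator. The world size also changes from 2 to 24, presumably so it is divisible by the configured tensor (2) x pipeline (3) x context (2) parallel sizes (12), which a world size of 2 cannot satisfy. A minimal, framework-free sketch of the dict-driven assertion pattern follows; the state object and field names here are illustrative only, not NeMo's API:

from types import SimpleNamespace


def assert_state_matches(state, expected: dict) -> None:
    # Verify that every expected attribute exists on `state` with the right value.
    for key, value in expected.items():
        assert hasattr(state, key), f"Expected to find {key} on state"
        actual = getattr(state, key)
        assert actual == value, f"{key} is incorrect. Expected: {value} Actual: {actual}"


# Illustrative stand-in for the AppState populated by init_parallel_ranks.
state = SimpleNamespace(world_size=24, global_rank=1, tensor_model_parallel_size=2)
assert_state_matches(state, {"world_size": 24, "tensor_model_parallel_size": 2})

Asserting on real state rather than on mocked call arguments keeps the test valid even if the initializer's signature changes, which appears to be the motivation behind the "check app state instead" item in the commit message.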
