diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py
index 575f69a58caf..71ca4fb78e07 100644
--- a/nemo/lightning/fabric/strategies.py
+++ b/nemo/lightning/fabric/strategies.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from contextlib import ExitStack, contextmanager
+from dataclasses import fields
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
@@ -39,6 +40,7 @@
 from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading
 from lightning_fabric.utilities.types import _PATH, _Stateful
 from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.optimizer import OptimizerConfig
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.loops.fetchers import _DataFetcher
@@ -58,8 +60,6 @@
 from nemo.lightning.pytorch.strategies import MegatronStrategy

 if TYPE_CHECKING:
-    from megatron.core.model_parallel_config import ModelParallelConfig
-
     from nemo.lightning.pytorch.plugins.data_sampler import DataSampler


@@ -405,20 +405,17 @@ def checkpoint_io(self) -> CheckpointIO:
         return self._checkpoint_io

     @property
-    def parallelism(self):
-        from nemo.lightning.pytorch.strategies.megatron_strategy import ParallelismConfig
-
-        return ParallelismConfig(
-            tensor_model_parallel_size=self.tensor_model_parallel_size,
-            pipeline_model_parallel_size=self.pipeline_model_parallel_size,
-            virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
-            microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
-            context_parallel_size=self.context_parallel_size,
-            sequence_parallel=self.sequence_parallel,
-            expert_model_parallel_size=self.expert_model_parallel_size,
-            moe_extended_tp=self.moe_extended_tp,
-            pipeline_dtype=self.pipeline_dtype,
-        )
+    def parallelism(self) -> ModelParallelConfig:
+        # Get fields from ModelParallelConfig dataclass
+        config_fields = {}
+        for field in fields(ModelParallelConfig):
+            # Only include field if it exists in self
+            if hasattr(self, field.name):
+                config_fields[field.name] = getattr(self, field.name)
+
+        # Initialize ModelParallelConfig with only available fields
+        model_parallel_config = ModelParallelConfig(**config_fields)
+        return model_parallel_config

     # TODO: Fix this
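For reference, below is a minimal sketch of the attribute-filtering pattern the new `parallelism` property relies on, written with hypothetical `TargetConfig`/`Source`/`build_config` names so it runs without megatron-core installed. Any field name that is not set on the source object is simply omitted and falls back to the target dataclass's default, which is what makes the `hasattr` guard safe.

```python
from dataclasses import dataclass, fields


@dataclass
class TargetConfig:
    # Hypothetical stand-in for megatron-core's ModelParallelConfig.
    tensor_model_parallel_size: int = 1
    sequence_parallel: bool = False


class Source:
    # Hypothetical stand-in for the strategy: carries some, but not
    # necessarily all, of the target dataclass's field names as attributes.
    def __init__(self):
        self.tensor_model_parallel_size = 2
        self.unrelated_attr = "ignored"


def build_config(src: object) -> TargetConfig:
    # Copy only the attributes whose names match TargetConfig's dataclass fields;
    # anything missing on `src` keeps TargetConfig's declared default.
    kwargs = {f.name: getattr(src, f.name) for f in fields(TargetConfig) if hasattr(src, f.name)}
    return TargetConfig(**kwargs)


print(build_config(Source()))
# TargetConfig(tensor_model_parallel_size=2, sequence_parallel=False)
```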
diff --git a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
index fc4312e2ff84..b75081e1ed94 100644
--- a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
+++ b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
@@ -13,15 +13,15 @@
 # limitations under the License.
 from dataclasses import asdict, dataclass, fields

-import pytorch_lightning as pl
+import pytorch_lightning as pl
 from megatron.core import ModelParallelConfig
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
 from pytorch_lightning.callbacks.callback import Callback

 from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import TransformerLayerTPOverlapCfg
-from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy, ParallelismConfig
+from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.utils import logging

 try:
@@ -118,7 +118,7 @@ def __init__(

     def _get_model_comm_overlap_cfgs(
         self,
-        parallelism_cfg: ParallelismConfig,
+        parallelism_cfg: ModelParallelConfig,
     ) -> _CommOverlapConfig:
         comm_overlap_cfg = _CommOverlapConfig()

@@ -159,7 +159,7 @@ def _get_model_comm_overlap_cfgs(
         comm_overlap_cfg = self._override_user_cfgs(comm_overlap_cfg)
         return comm_overlap_cfg

-    def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ParallelismConfig) -> _CommOverlapConfig:
+    def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ModelParallelConfig) -> _CommOverlapConfig:
         from nemo.utils import AppState

         app_state = AppState()
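A hedged sketch of what the callback-side change amounts to: `parallelism_cfg` is now annotated as a megatron-core `ModelParallelConfig` rather than the NeMo `ParallelismConfig` POD, but attribute reads such as `tensor_model_parallel_size` and `sequence_parallel` keep the same names, so call sites are unaffected. The helper name below is hypothetical and not part of the callback.

```python
from megatron.core import ModelParallelConfig


def tp_comm_overlap_applicable(parallelism_cfg: ModelParallelConfig) -> bool:
    # Hypothetical helper: the attribute names the overlap logic reads are
    # unchanged, only the annotated type of `parallelism_cfg` differs.
    return parallelism_cfg.tensor_model_parallel_size > 1 and parallelism_cfg.sequence_parallel
```

Something like `tp_comm_overlap_applicable(strategy.parallelism)` would work identically before and after this change, assuming the attributes in question are populated on the strategy.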
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index c62a90313b45..3497c491755e 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -18,7 +18,7 @@
 import shutil
 from collections import OrderedDict
 from contextlib import ExitStack, contextmanager
-from dataclasses import dataclass
+from dataclasses import fields
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -41,6 +41,7 @@
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
 from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.optimizer import OptimizerConfig
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop
@@ -80,27 +81,6 @@
 DDPLiteral = Literal["megatron", "pytorch"]


-@dataclass
-class ParallelismConfig:
-    """
-    POD containing parallelism configuration.
-    Parallelism configuration is passed to MegatronStrategy via constructor arguments,
-    then copied to model's config during model setup.
-    """
-
-    tensor_model_parallel_size: int
-    pipeline_model_parallel_size: int
-    virtual_pipeline_model_parallel_size: int
-    microbatch_group_size_per_vp_stage: int
-    context_parallel_size: int
-    sequence_parallel: bool
-    expert_model_parallel_size: int
-    moe_extended_tp: bool
-    pipeline_dtype: torch.dtype
-    encoder_tensor_model_parallel_size: int = 0
-    encoder_pipeline_model_parallel_size: int = 0
-
-
 class MegatronStrategy(DDPStrategy, io.IOMixin):
     """Megatron plugin for Pytorch Lightning.
@@ -232,6 +212,7 @@ def __init__(
         self.expert_model_parallel_size = expert_model_parallel_size
         self.moe_extended_tp = moe_extended_tp
         self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
+        self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage
         self.sequence_parallel = sequence_parallel
         self.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
         self.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
@@ -849,21 +830,17 @@ def restore_checkpoint_after_setup(self) -> bool:
         return True

     @property
-    def parallelism(self) -> ParallelismConfig:
-        """Returns parallelism config from class attrs as a POD"""
-        return ParallelismConfig(
-            tensor_model_parallel_size=self.tensor_model_parallel_size,
-            pipeline_model_parallel_size=self.pipeline_model_parallel_size,
-            virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
-            microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
-            context_parallel_size=self.context_parallel_size,
-            sequence_parallel=self.sequence_parallel,
-            expert_model_parallel_size=self.expert_model_parallel_size,
-            moe_extended_tp=self.moe_extended_tp,
-            encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
-            encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
-            pipeline_dtype=self.pipeline_dtype,
-        )
+    def parallelism(self) -> ModelParallelConfig:
+        # Get fields from ModelParallelConfig dataclass
+        config_fields = {}
+        for field in fields(ModelParallelConfig):
+            # Only include field if it exists in self
+            if hasattr(self, field.name):
+                config_fields[field.name] = getattr(self, field.name)
+
+        # Initialize ModelParallelConfig with only available fields
+        model_parallel_config = ModelParallelConfig(**config_fields)
+        return model_parallel_config

     @contextmanager
     @override
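Since the deleted `ParallelismConfig` POD is replaced wholesale by megatron-core's `ModelParallelConfig`, a quick introspection check (a sketch, assuming megatron-core is importable) can confirm which of the old field names the megatron dataclass actually carries. Any name reported as missing will no longer appear on the object returned by the new `parallelism` properties, because the `hasattr`/`fields` filter silently drops it, so downstream readers of that attribute would need another source.

```python
from dataclasses import fields

from megatron.core.model_parallel_config import ModelParallelConfig

# Field names of the ParallelismConfig dataclass removed in this diff.
OLD_PARALLELISM_FIELDS = {
    "tensor_model_parallel_size",
    "pipeline_model_parallel_size",
    "virtual_pipeline_model_parallel_size",
    "microbatch_group_size_per_vp_stage",
    "context_parallel_size",
    "sequence_parallel",
    "expert_model_parallel_size",
    "moe_extended_tp",
    "pipeline_dtype",
    "encoder_tensor_model_parallel_size",
    "encoder_pipeline_model_parallel_size",
}

# Compare against the installed megatron-core definition rather than assuming coverage.
mpc_fields = {f.name for f in fields(ModelParallelConfig)}
print("covered by ModelParallelConfig:", sorted(OLD_PARALLELISM_FIELDS & mpc_fields))
print("not present on ModelParallelConfig:", sorted(OLD_PARALLELISM_FIELDS - mpc_fields))
```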