Use Mcore ModelParallelConfig in strategy parallelism property #11232

Open. Wants to merge 3 commits into base: main.
29 changes: 13 additions & 16 deletions nemo/lightning/fabric/strategies.py
@@ -13,6 +13,7 @@
# limitations under the License.

from contextlib import ExitStack, contextmanager
from dataclasses import fields
from datetime import timedelta
from typing import (
TYPE_CHECKING,
@@ -39,6 +40,7 @@
from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading
from lightning_fabric.utilities.types import _PATH, _Stateful
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loops.fetchers import _DataFetcher
@@ -58,8 +60,6 @@
from nemo.lightning.pytorch.strategies import MegatronStrategy

if TYPE_CHECKING:
from megatron.core.model_parallel_config import ModelParallelConfig

from nemo.lightning.pytorch.plugins.data_sampler import DataSampler


@@ -405,20 +405,17 @@ def checkpoint_io(self) -> CheckpointIO:
return self._checkpoint_io

@property
def parallelism(self):
from nemo.lightning.pytorch.strategies.megatron_strategy import ParallelismConfig

return ParallelismConfig(
tensor_model_parallel_size=self.tensor_model_parallel_size,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
context_parallel_size=self.context_parallel_size,
sequence_parallel=self.sequence_parallel,
expert_model_parallel_size=self.expert_model_parallel_size,
moe_extended_tp=self.moe_extended_tp,
pipeline_dtype=self.pipeline_dtype,
)
def parallelism(self) -> ModelParallelConfig:
# Get fields from ModelParallelConfig dataclass
config_fields = {}
for field in fields(ModelParallelConfig):
# Only include field if it exists in self
if hasattr(self, field.name):
config_fields[field.name] = getattr(self, field.name)

# Initialize ModelParallelConfig with only available fields
model_parallel_config = ModelParallelConfig(**config_fields)
return model_parallel_config


# TODO: Fix this
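For reference, a minimal, self-contained sketch of the field-filtering pattern the new `parallelism` property relies on. The `_FakeStrategy` class and its attribute values are made up for illustration; only `dataclasses.fields` and `ModelParallelConfig` come from the diff, and running it assumes `megatron-core` is installed.

```python
from dataclasses import fields

from megatron.core.model_parallel_config import ModelParallelConfig


class _FakeStrategy:
    """Hypothetical stand-in for a strategy holding parallelism attributes."""

    tensor_model_parallel_size = 2
    pipeline_model_parallel_size = 1
    sequence_parallel = False
    some_unrelated_attr = "ignored"  # not a ModelParallelConfig field, so it is skipped


strategy = _FakeStrategy()

# Copy only the attributes that ModelParallelConfig actually declares; anything
# the strategy does not define keeps Megatron-core's default value.
overlap = {
    f.name: getattr(strategy, f.name)
    for f in fields(ModelParallelConfig)
    if hasattr(strategy, f.name)
}
config = ModelParallelConfig(**overlap)

assert config.tensor_model_parallel_size == 2
assert config.context_parallel_size == 1  # default; the fake strategy never set it
```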
8 changes: 4 additions & 4 deletions nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
@@ -13,15 +13,15 @@
# limitations under the License.

from dataclasses import asdict, dataclass, fields
import pytorch_lightning as pl

import pytorch_lightning as pl
from megatron.core import ModelParallelConfig
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import TransformerLayerTPOverlapCfg
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy, ParallelismConfig
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.utils import logging

try:
@@ -118,7 +118,7 @@ def __init__(

def _get_model_comm_overlap_cfgs(
self,
parallelism_cfg: ParallelismConfig,
parallelism_cfg: ModelParallelConfig,
) -> _CommOverlapConfig:
comm_overlap_cfg = _CommOverlapConfig()

@@ -159,7 +159,7 @@ def _get_model_comm_overlap_cfgs(
comm_overlap_cfg = self._override_user_cfgs(comm_overlap_cfg)
return comm_overlap_cfg

def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ParallelismConfig) -> _CommOverlapConfig:
def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ModelParallelConfig) -> _CommOverlapConfig:
from nemo.utils import AppState

app_state = AppState()
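Because the helpers are now annotated with `ModelParallelConfig`, they can keep reading the same attribute names they previously read from `ParallelismConfig`. Below is a rough sketch of that kind of access with made-up values; the simplified gating condition is illustrative and not the callback's exact logic.

```python
from megatron.core import ModelParallelConfig

# Attribute names such as tensor_model_parallel_size and sequence_parallel exist
# on ModelParallelConfig as well, so the annotation swap leaves call sites intact.
cfg = ModelParallelConfig(
    tensor_model_parallel_size=4,
    sequence_parallel=True,
)

# Simplified example condition: only consider TP communication overlap when
# TP > 1 and sequence parallelism is enabled.
tp_overlap_candidate = cfg.tensor_model_parallel_size > 1 and cfg.sequence_parallel
print(tp_overlap_candidate)  # True
```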
51 changes: 14 additions & 37 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -18,7 +18,7 @@
import shutil
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from dataclasses import fields
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -41,6 +41,7 @@
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop
@@ -80,27 +81,6 @@
DDPLiteral = Literal["megatron", "pytorch"]


@dataclass
class ParallelismConfig:
"""
POD containing parallelism configuration.
Parallelism configuration is passed to MegatronStrategy via constructor arguments,
then copied to model's config during model setup.
"""

tensor_model_parallel_size: int
pipeline_model_parallel_size: int
virtual_pipeline_model_parallel_size: int
microbatch_group_size_per_vp_stage: int
context_parallel_size: int
sequence_parallel: bool
expert_model_parallel_size: int
moe_extended_tp: bool
pipeline_dtype: torch.dtype
encoder_tensor_model_parallel_size: int = 0
encoder_pipeline_model_parallel_size: int = 0


class MegatronStrategy(DDPStrategy, io.IOMixin):
"""Megatron plugin for Pytorch Lightning.

@@ -232,6 +212,7 @@ def __init__(
self.expert_model_parallel_size = expert_model_parallel_size
self.moe_extended_tp = moe_extended_tp
self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage
self.sequence_parallel = sequence_parallel
self.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
self.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
@@ -849,21 +830,17 @@ def restore_checkpoint_after_setup(self) -> bool:
return True

@property
def parallelism(self) -> ParallelismConfig:
"""Returns parallelism config from class attrs as a POD"""
return ParallelismConfig(
tensor_model_parallel_size=self.tensor_model_parallel_size,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
context_parallel_size=self.context_parallel_size,
sequence_parallel=self.sequence_parallel,
expert_model_parallel_size=self.expert_model_parallel_size,
moe_extended_tp=self.moe_extended_tp,
encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
pipeline_dtype=self.pipeline_dtype,
)
def parallelism(self) -> ModelParallelConfig:
# Get fields from ModelParallelConfig dataclass
config_fields = {}
for field in fields(ModelParallelConfig):
# Only include field if it exists in self
if hasattr(self, field.name):
config_fields[field.name] = getattr(self, field.name)

# Initialize ModelParallelConfig with only available fields
model_parallel_config = ModelParallelConfig(**config_fields)
return model_parallel_config

@contextmanager
@override
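A rough usage sketch of the reworked property, assuming `MegatronStrategy` can be instantiated standalone with only parallelism kwargs and that `megatron-core` is installed; the attribute values are arbitrary.

```python
from megatron.core.model_parallel_config import ModelParallelConfig

from nemo.lightning.pytorch.strategies import MegatronStrategy

strategy = MegatronStrategy(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
)

cfg = strategy.parallelism

# The property now returns a Megatron-core config object instead of the
# removed NeMo-local ParallelismConfig dataclass.
assert isinstance(cfg, ModelParallelConfig)
assert cfg.tensor_model_parallel_size == 2
# ModelParallelConfig fields the strategy does not carry keep their defaults.
print(cfg.params_dtype)
```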