Add MCore FSDP2 support #11216

Open · wants to merge 20 commits into base: main
Changes from 7 commits
44 changes: 42 additions & 2 deletions nemo/lightning/megatron_parallel.py
@@ -45,6 +45,7 @@
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed.torch_fsdp2 import FullyShardedDataParallel as McoreFSDP
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.transformer_config import TransformerConfig
from pytorch_lightning.utilities import move_data_to_device
@@ -142,6 +143,7 @@ class MegatronParallel(nn.ModuleList, Generic[ModelT]):
        vp_size (Optional[int]): Virtual pipeline parallel size.
        ddp_config (Optional[DistributedDataParallelConfig]): An instance of Megatron core's
            DistributedDataParallelConfig which controls the Megatron DDP configuration.
        fsdp (bool): Whether model should run Torch FSDP2 instead of DDP.
        cpu (bool): Whether model should reside on CPU.
        convert_module_fn (Optional[Callable[[ModelT], nn.Module]]): An optional function to
            apply to the model parameters after initialization.
@@ -176,6 +178,7 @@ def __init__(
        loss_reduction: Optional[Callable[[ModelT], "MegatronLossReduction"]] = None,
        vp_size: Optional[int] = None,
        ddp_config: Optional[DistributedDataParallelConfig] = None,
        fsdp: bool = False,
        cpu: bool = False,
        convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
    ) -> None:
@@ -210,6 +213,7 @@ def __init__(
        self.forward_step = forward_step or default_forward_step
        self.loss_reduction: MegatronLossReduction = loss_reduction
        self.ddp_config = ddp_config
        self.fsdp = fsdp
        self.convert_module_fn = convert_module_fn

    def forward(
@@ -568,6 +572,11 @@ def init_ddp(self):

        from megatron.core import parallel_state

        if self.fsdp:
            DP = FSDP
        else:
            DP = DDP

        for model_chunk_idx, model_chunk in enumerate(self):
            module = model_chunk.module

@@ -584,13 +593,15 @@ def init_ddp(self):
            disable_bucketing = (model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step

            with init_ddp_context():
                ddp = DDP(
                ddp = DP(
                    module.config,
                    self.ddp_config,
                    module,
                    data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
                    disable_bucketing=disable_bucketing,
                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
                    # model chunks is overlapped with compute anyway.
                    disable_bucketing=(model_chunk_idx > 0),
                )

            model_chunk.module = ddp
@@ -787,6 +798,35 @@ def __getattr__(self, item: Any) -> Any:
        return getattr_proxy(self, item)


class FSDP(McoreFSDP):
    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
        **kwargs,
    ):
        init_parameters = inspect.signature(McoreDDP.__init__).parameters
        # Updates to the McoreDDP class have removed some parameters, so we need to
        # filter out any kwargs that are not part of the updated signature, if a new
        # version of mcore is being used.
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in init_parameters}
        super().__init__(
            config=config,
            ddp_config=ddp_config,
            module=module,
            disable_bucketing=disable_bucketing,
            **filtered_kwargs,
        )

    def state_dict(self, prefix='', keep_vars=False, **kwargs):
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, **kwargs)

    def __getattr__(self, item: Any) -> Any:
        return getattr_proxy(self, item)


class CallbackConnector:
"""
A connector for managing and invoking callbacks.
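Aside: the FSDP wrapper above filters **kwargs against the current signature of McoreDDP.__init__ so that the call keeps working even when newer Megatron-Core releases drop parameters. A minimal standalone sketch of that pattern, with illustrative names that are not part of this diff:

```python
import inspect


def filter_kwargs_for(target, kwargs):
    # Keep only the keyword arguments that `target` still accepts, silently
    # dropping anything a newer version of the target no longer supports.
    accepted = inspect.signature(target).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}


class Wrapper:
    def __init__(self, module, disable_bucketing=False):
        self.module = module
        self.disable_bucketing = disable_bucketing


# 'bucket_size' is dropped because Wrapper.__init__ does not accept it.
kwargs = {"disable_bucketing": True, "bucket_size": 8}
wrapper = Wrapper(module=None, **filter_kwargs_for(Wrapper.__init__, kwargs))
```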
25 changes: 25 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -188,6 +188,7 @@ def __init__(
        ckpt_load_optimizer: bool = True,
        ckpt_save_optimizer: bool = True,
        ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
        fsdp: bool = False,
A collaborator commented:
should we rename it torch_fsdp because we know we'll have mcore_fsdp eventually?

@BoxiangW (collaborator, author) replied on Nov 22, 2024:
Do you think we can make this arg a string, ['mcore', 'torch'], or None?
Or expand this into two args: fsdp, defaulting to False, and use_torch_fsdp, defaulting to False?
So:
fsdp=True, use_torch_fsdp=True -> torch_fsdp
fsdp=True -> mcore_fsdp (to be supported)
nothing -> ddp

Just trying to make the interface less confusing.
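A hypothetical sketch of the mapping proposed in this reply; the use_torch_fsdp flag and the returned labels are illustrative and not part of this PR:

```python
def resolve_data_parallel_impl(fsdp: bool = False, use_torch_fsdp: bool = False) -> str:
    # Hypothetical two-flag interface from the discussion above.
    if fsdp and use_torch_fsdp:
        return "torch_fsdp"  # Torch FSDP2
    if fsdp:
        return "mcore_fsdp"  # Megatron-Core FSDP, to be supported
    return "ddp"  # default: Megatron DDP
```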

        lazy_init: bool = False,
        pipeline_dtype: Optional[torch.dtype] = None,
        save_ckpt_format: str = "torch_dist",
@@ -250,11 +251,14 @@ def __init__(
        self.restore_config = restore_config

        self._ddp = ddp
        self._fsdp = fsdp
        if ddp == "megatron":
            self.ddp_config = DistributedDataParallelConfig(check_for_nan_in_grad=True)
        elif isinstance(ddp, DistributedDataParallelConfig):
            self.ddp_config = ddp
        elif ddp == "pytorch":
            if fsdp:
                raise ValueError("Please set ddp to megatron to run Torch FSDP2.")
            self.ddp_config = None
            self.no_ddp_communication_hook = False
        else:
@@ -407,6 +411,7 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
            vp_size=self.virtual_pipeline_model_parallel_size,
            cpu=isinstance(trainer.accelerator, CPUAccelerator),
            ddp_config=self.ddp_config,
            fsdp=self._fsdp,
            convert_module_fn=convert_module_fn,
        )
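Taken together, a user would enable the Torch FSDP2 path roughly as follows. This is a sketch inferred from the diff rather than documented usage; it assumes MegatronStrategy is imported from the module shown above:

```python
from megatron.core.distributed import DistributedDataParallelConfig
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

# fsdp=True routes each model chunk through the new FSDP wrapper in
# megatron_parallel.py. The ddp argument must stay on the Megatron config path;
# combining ddp="pytorch" with fsdp=True raises a ValueError above.
strategy = MegatronStrategy(
    ddp=DistributedDataParallelConfig(check_for_nan_in_grad=True),
    fsdp=True,
)
```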

@@ -732,8 +737,25 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any], selective_res
        if not self.should_restore_optimizer_states(selective_restore=selective_restore):
            return

        from megatron.core import parallel_state
        from torch.distributed import DeviceMesh
        from torch.distributed._tensor import DTensor, Shard

        mesh = DeviceMesh.from_group(parallel_state.get_data_parallel_group(), "cuda")

        optimizer_states = checkpoint["optimizer"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            if self._fsdp:
                opt_state['fp32_from_fp16_params'] = OrderedDict()
                for opt_param in opt_state['optimizer']['state'].values():
                    if isinstance(opt_param, Dict):
                        for opt_param_state_key, opt_param_state in opt_param.items():
                            opt_param[opt_param_state_key] = DTensor.from_local(
                                opt_param_state,
                                mesh,
                                (Shard(dim=0),),
                            )

            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)
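The block above re-wraps each rank's local optimizer shard as a DTensor sharded along dim 0 over the data-parallel mesh before calling load_state_dict. A minimal sketch of that conversion in isolation; the helper name is illustrative, and it assumes torch.distributed is initialized and a mesh has been built as in the diff:

```python
import torch
from torch.distributed._tensor import DTensor, Shard


def reshard_param_state(param_state: dict, mesh) -> dict:
    # Wrap every tensor in a per-parameter optimizer state dict (e.g. exp_avg,
    # exp_avg_sq) as a dim-0 sharded DTensor on the given device mesh, leaving
    # non-tensor entries untouched.
    return {
        key: DTensor.from_local(value, mesh, (Shard(dim=0),)) if torch.is_tensor(value) else value
        for key, value in param_state.items()
    }
```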

@@ -746,6 +768,9 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
            shutil.rmtree(ckpt)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
        if self._fsdp:
            return

        assert self.megatron_parallel is not None

        _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict)