diff --git a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
index cdfa523b3c5b10..06f4ac5789a68f 100644
--- a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
+++ b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -1,10 +1,12 @@
 # Owner(s): ["oncall: distributed"]
-
 import functools
+import io
+from copy import deepcopy
 from typing import Any
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
@@ -25,7 +27,11 @@
 
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -51,6 +57,25 @@ def forward(self, x):
         x = F.relu(self.net3(x))
         return x
 
+    def get_input(self):
+        return torch.rand(4, 5, device="cuda")
+
+
+class SimpleModelUneven(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(0)
+        self.net1 = nn.Sequential(nn.Linear(5, 10), nn.ReLU())
+        self.net2 = nn.Sequential(nn.Linear(10, 15), nn.ReLU())
+        self.net3 = nn.Linear(15, 30)
+        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(30, 5))
+
+    def forward(self, x):
+        return self.net4(self.net3(self.net2(self.net1(x))))
+
+    def get_input(self):
+        return torch.rand(4, 5, device="cuda")
+
 
 def _wrap_module(
     module,
@@ -298,7 +323,7 @@ def test_2d_fsdp_integration_fsdp_nested_param_groups(self) -> None:
         )
 
 
-class TestNew2dParallelIntegration(DTensorTestBase):
+class TestNew2dParallelTraining(DTensorTestBase):
     # TODO: this is duplicate code from above, but once we remove the enable_2d_with_fsdp(),
     # we will remove the above test class Test2dParallelIntegration.
     def _compare_params(self, m1, m2):
@@ -401,13 +426,19 @@ def test_2d_e2e_training_use_orig_params(self):
     def test_2d_e2e_training_not_use_orig_params(self):
         self._test_2d_e2e_training(recompute_activation=True)
 
+
+class TestNew2dParallelStateDict(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_state_dict(self):
+    @parametrize("is_even_sharded_model", [True, False])
+    @parametrize("use_orig_params", [True, False])
+    def test_2d_state_dict(self, is_even_sharded_model, use_orig_params):
+        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
+
         # Create a model without wrapper
         torch.manual_seed(0)
-        simple_model = SimpleModel().cuda(self.rank)
-        no_wrap_state_dict = simple_model.state_dict()
+        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_state_dict = no_wrap_model.state_dict()
 
         # Create a model and sharded it with 2D FSDP + TP
         torch.manual_seed(0)
@@ -416,11 +447,13 @@ def test_2d_state_dict(self):
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]
-        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel())
+        model_2d = parallelize_module(
+            simple_model().cuda(), tp_mesh, PairwiseParallel()
+        )
         model_2d = FSDP(
             model_2d,
             device_mesh=dp_mesh,
-            use_orig_params=True,
+            use_orig_params=use_orig_params,
         )
 
         FSDP.set_state_dict_type(
@@ -452,6 +485,63 @@
                 torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
             )
 
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @parametrize("is_even_sharded_model", [True, False])
+    @parametrize("use_orig_params", [True, False])
+    def test_2d_load_state_dict(self, is_even_sharded_model, use_orig_params):
+        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
+
+        torch.manual_seed(0)
+        mesh_2d = init_device_mesh(
+            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+        )
+        tp_mesh = mesh_2d["tp"]
+        dp_mesh = mesh_2d["dp"]
+        model_2d = parallelize_module(
+            simple_model().cuda(), tp_mesh, PairwiseParallel()
+        )
+        model_2d = FSDP(
+            model_2d,
+            device_mesh=dp_mesh,
+            use_orig_params=use_orig_params,
+        )
+        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
+
+        FSDP.set_state_dict_type(
+            model_2d,
+            StateDictType.SHARDED_STATE_DICT,
+        )
+        checkpoint = io.BytesIO()
+        torch.save(model_2d.state_dict(), checkpoint)
+        # Deepcopy to save current state_dict to compare with the state_dict loaded back below.
+        ref_state_dict = deepcopy(model_2d.state_dict())
+
+        # Update the parameters so model.state_dict() will be different from ref_state_dict.
+        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        optim_2d.step()
+
+        # Load ref_state_dict back.
+        checkpoint.seek(0)
+        load_ref_state_dict = torch.load(checkpoint)
+        model_2d.load_state_dict(load_ref_state_dict)
+        new_state_dict = model_2d.state_dict()
+
+        # Check whether new_state_dict is the same as ref_state_dict.
+        for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
+            # check whether the FQNs are the same
+            self.assertEqual(k1, k2)
+
+            self.assertEqual(type(v1), DT)
+            self.assertEqual(type(v2), DT)
+            # check whether the DTensors are the same
+            # TODO: 2D DTensor comparison is not supported at this time, so we are comparing the spec and the local tensor for now.
+            # TODO: Update it to compare the two DTensors once 2D DTensor comparison is supported.
+            self.assertEqual(v1.to_local(), v2.to_local())
+            self.assertEqual(v1.device_mesh, v2.device_mesh)
+            self.assertEqual(v1.placements, v2.placements)
+
+
+instantiate_parametrized_tests(TestNew2dParallelStateDict)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py
index 1a158cf79dd00e..26f1f361f0b077 100644
--- a/torch/distributed/fsdp/_fsdp_extensions.py
+++ b/torch/distributed/fsdp/_fsdp_extensions.py
@@ -5,8 +5,9 @@
 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.distributed.fsdp._shard_utils import (
+    _all_gather_dtensor,
     _create_chunk_dtensor,
     _create_chunk_sharded_tensor,
 )
@@ -70,6 +71,19 @@ def pre_load_state_dict_transform(
         """
         ...
 
+    @abstractmethod
+    def all_gather_dtensor(
+        self,
+        tensor: DTensor,
+        parent_mesh: Optional[DeviceMesh],
+    ) -> torch.Tensor:
+        """
+        This is to be called before loading a *sharded* DTensor state dict.
+        It all-gathers the tensor along the FSDP dimension and returns the
+        local tensor of the TP DTensor.
+        """
+        ...
+
 
 _extensions: Optional[FSDPExtensions] = None
 
@@ -143,3 +157,15 @@ def _ext_pre_load_state_dict_transform(
     assert type(tensor) is ShardedTensor
     shards = tensor.local_shards()
     return (tensor, shards)
+
+
+def _ext_all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+) -> torch.Tensor:
+    all_gather_dtensor_fn = (
+        _extensions.all_gather_dtensor
+        if _extensions is not None
+        else _all_gather_dtensor
+    )
+    return all_gather_dtensor_fn(tensor, parent_mesh)
diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py
index f4edbb1faf14f4..0e63ebd1b9c148 100644
--- a/torch/distributed/fsdp/_shard_utils.py
+++ b/torch/distributed/fsdp/_shard_utils.py
@@ -190,3 +190,24 @@ def _create_chunk_dtensor(
         device_mesh=device_mesh,
         placements=shard_placements,
     )
+
+
+def _all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+) -> torch.Tensor:
+    """
+    All-gather a DTensor along its sharded dimension and return the local tensor.
+    """
+    assert parent_mesh is None
+
+    placements = list(copy.deepcopy(tensor.placements))
+    # FSDP placements: [Shard(0)] -> [Replicate()]
+    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
+    placements[-1] = Replicate()
+    tensor = tensor.redistribute(
+        device_mesh=tensor.device_mesh,
+        placements=placements,
+    )
+
+    return tensor.to_local()
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index 69cbb484dec1e7..2ab27ddc9587e6 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -1,5 +1,4 @@
 import contextlib
-import copy
 import logging
 import math
 import warnings
@@ -17,7 +16,7 @@
     Shard,
     ShardedTensor,
 )
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.device_mesh import mesh_resources
 from torch.distributed.distributed_c10d import _get_pg_default_device
 
@@ -46,8 +45,10 @@
 from torch.distributed.utils import _replace_by_prefix
 
 from ._fsdp_extensions import (
+    _ext_all_gather_dtensor,
     _ext_chunk_dtensor,
     _ext_chunk_tensor,
+    _ext_post_unflatten_transform,
     _ext_pre_load_state_dict_transform,
 )
 from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM
@@ -606,6 +607,9 @@ def _sharded_pre_load_state_dict_hook(
             "load_sharded_state_dict can only be called when parameters "
             "are flattened and sharded."
         )
+    fqn_to_param_ext = dict(
+        zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
+    )
 
     device = fsdp_state.compute_device
     for fqn, _, _ in _param_name_infos(module, fsdp_state):
@@ -679,13 +683,14 @@
         else:
             if param.device != fsdp_state._device_mesh.device_type:
                 param = param.to(fsdp_state._device_mesh.device_type)
-            placements = list(copy.deepcopy(param.placements))
-            placements[-1] = Replicate()
-            param = param.redistribute(
-                device_mesh=param.device_mesh,
-                placements=placements,
-            )
-            state_dict[fqn_from_global_root] = param.to_local()
+
+            parent_mesh = mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
+            local_tensor = _ext_all_gather_dtensor(param, parent_mesh)
+
+            if fqn_to_param_ext.get(fqn) is not None:
+                ext = fqn_to_param_ext[fqn]
+                local_tensor = _ext_post_unflatten_transform(local_tensor, ext)
+            state_dict[fqn_from_global_root] = local_tensor
 
     with SimpleProfiler.profile("_enter_unshard_params_ctx"):
         _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py
index 2d7dc98ca14e1f..64d54e6c1fca77 100644
--- a/torch/distributed/tensor/parallel/fsdp.py
+++ b/torch/distributed/tensor/parallel/fsdp.py
@@ -293,6 +293,26 @@ def _pre_load_state_dict(
     return (tensor, shards if len(shards) > 0 else [])
 
 
+def _all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+) -> torch.Tensor:
+    """
+    All-gather a DTensor along its FSDP dimension and return the local tensor.
+    """
+    assert parent_mesh == tensor.device_mesh
+
+    placements = list(copy.deepcopy(tensor.placements))
+    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
+    placements[0] = Replicate()
+    tensor = tensor.redistribute(
+        device_mesh=tensor.device_mesh,
+        placements=placements,
+    )
+
+    return tensor.to_local()
+
+
 class DTensorExtensions(FSDPExtensions):
     """
     DTensorExtension is the TensorFlattener extension needed for 2D FSDP + TP.
@@ -338,6 +358,13 @@ def pre_load_state_dict_transform(
     ) -> Tuple[torch.Tensor, List[Shard]]:
         return _pre_load_state_dict(tensor)
 
+    def all_gather_dtensor(
+        self,
+        tensor: DTensor,
+        parent_mesh: Optional[DeviceMesh],
+    ) -> torch.Tensor:
+        return _all_gather_dtensor(tensor, parent_mesh)
+
 
 # TODO: remove enable_2d_with_fsdp() once we roll out the new 2D flow.
 def enable_2d_with_fsdp() -> bool: