[2D] Enable 2D FSDP+TP model.load_state_dict() (pytorch#110925)
This PR adds an `all_gather_dtensor()` method to `fsdp/_fsdp_extensions.py` and the actual implementation in `tensor/parallel/fsdp.py`. This enables FSDP to load a 2D DTensor state_dict into the model when calling `model.load_state_dict()`.

cc. @fegin

Pull Request resolved: pytorch#110925
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#110831, pytorch#110846
wz337 authored and pytorchmergebot committed Oct 11, 2023
1 parent fd4ba80 commit 80dfc97
Showing 5 changed files with 187 additions and 18 deletions.
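Before the file-by-file diff, a minimal sketch of the save/load round trip this change enables, condensed from the new `test_2d_load_state_dict` test below. It assumes four GPUs and a torchrun-launched process group; `ToyModel` is a hypothetical stand-in for the test's `SimpleModel`, and the import paths reflect the PyTorch version of this PR (they may have moved in later releases).

import io

import torch
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module


class ToyModel(nn.Module):  # hypothetical stand-in for the test's SimpleModel
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# 2D mesh: outer dim = data parallel (FSDP), inner dim = tensor parallel (TP).
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
model = parallelize_module(ToyModel().cuda(), mesh_2d["tp"], PairwiseParallel())
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)

# Save a sharded (2D DTensor) state dict ...
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# ... and load it back through the regular nn.Module API, which is what
# the new all_gather_dtensor() extension point makes possible for 2D FSDP + TP.
buffer.seek(0)
model.load_state_dict(torch.load(buffer))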
106 changes: 98 additions & 8 deletions test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -1,10 +1,12 @@
# Owner(s): ["oncall: distributed"]

import functools
import io
from copy import deepcopy
from typing import Any

import torch
import torch.distributed as dist
import torch.nn as nn

import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
@@ -25,7 +27,11 @@
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -51,6 +57,25 @@ def forward(self, x):
x = F.relu(self.net3(x))
return x

def get_input(self):
return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(5, 10), nn.ReLU())
self.net2 = nn.Sequential(nn.Linear(10, 15), nn.ReLU())
self.net3 = nn.Linear(15, 30)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(30, 5))

def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))

def get_input(self):
return torch.rand(4, 5, device="cuda")


def _wrap_module(
module,
@@ -298,7 +323,7 @@ def test_2d_fsdp_integration_fsdp_nested_param_groups(self) -> None:
)


class TestNew2dParallelIntegration(DTensorTestBase):
class TestNew2dParallelTraining(DTensorTestBase):
# TODO: this is duplicate code from above, but once we remove the enable_2d_with_fsdp(),
# we will remove the above test class Test2dParallelIntegration.
def _compare_params(self, m1, m2):
@@ -401,13 +426,19 @@ def test_2d_e2e_training_use_orig_params(self):
def test_2d_e2e_training_not_use_orig_params(self):
self._test_2d_e2e_training(recompute_activation=True)


class TestNew2dParallelStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_state_dict(self):
@parametrize("is_even_sharded_model", [True, False])
@parametrize("use_orig_params", [True, False])
def test_2d_state_dict(self, is_even_sharded_model, use_orig_params):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

# Create a model without wrapper
torch.manual_seed(0)
simple_model = SimpleModel().cuda(self.rank)
no_wrap_state_dict = simple_model.state_dict()
no_wrap_model = simple_model().cuda(self.rank)
no_wrap_state_dict = no_wrap_model.state_dict()

# Create a model and shard it with 2D FSDP + TP
torch.manual_seed(0)
@@ -416,11 +447,13 @@ def test_2d_state_dict(self):
)
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]
model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel())
model_2d = parallelize_module(
simple_model().cuda(), tp_mesh, PairwiseParallel()
)
model_2d = FSDP(
model_2d,
device_mesh=dp_mesh,
use_orig_params=True,
use_orig_params=use_orig_params,
)

FSDP.set_state_dict_type(
@@ -452,6 +485,63 @@ def test_2d_state_dict(self):
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
)

@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("is_even_sharded_model", [True, False])
@parametrize("use_orig_params", [True, False])
def test_2d_load_state_dict(self, is_even_sharded_model, use_orig_params):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

torch.manual_seed(0)
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]
model_2d = parallelize_module(
simple_model().cuda(), tp_mesh, PairwiseParallel()
)
model_2d = FSDP(
model_2d,
device_mesh=dp_mesh,
use_orig_params=use_orig_params,
)
optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)

FSDP.set_state_dict_type(
model_2d,
StateDictType.SHARDED_STATE_DICT,
)
checkpoint = io.BytesIO()
torch.save(model_2d.state_dict(), checkpoint)
# Deepcopy to save current state_dict to compare with the state_dict loaded back below.
ref_state_dict = deepcopy(model_2d.state_dict())

# Update the parameters so model.state_dict() will be different from ref_state_dict.
model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
optim_2d.step()

# Load ref_state_dict back.
checkpoint.seek(0)
load_ref_state_dict = torch.load(checkpoint)
model_2d.load_state_dict(load_ref_state_dict)
new_state_dict = model_2d.state_dict()

# Check whether new_state_dict is the same as ref_state_dict.
for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
# check whether fqn are the same
self.assertEqual(k1, k2)

self.assertEqual(type(v1), DT)
self.assertEqual(type(v2), DT)
# check whether the DTensors are the same
# TODO: 2D DTensor comparison is not supported at this time, so we are comparing the spec and the local tensor for now.
# TODO: Update it to compare the two DTensors once 2D DTensor comparison is supported.
self.assertEqual(v1.to_local(), v2.to_local())
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)


instantiate_parametrized_tests(TestNew2dParallelStateDict)
if __name__ == "__main__":
run_tests()
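The @parametrize decorators together with the instantiate_parametrized_tests() call above expand each test method into one variant per argument combination. A minimal sketch of that pattern with a hypothetical standalone test class (names here are illustrative, not from the diff):

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleParamTests(TestCase):  # hypothetical
    @parametrize("use_orig_params", [True, False])
    def test_flag(self, use_orig_params):
        self.assertIn(use_orig_params, (True, False))


# Generates one test per value, e.g. test_flag_use_orig_params_True.
instantiate_parametrized_tests(ExampleParamTests)

if __name__ == "__main__":
    run_tests()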
28 changes: 27 additions & 1 deletion torch/distributed/fsdp/_fsdp_extensions.py
@@ -5,8 +5,9 @@
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._shard_utils import (
_all_gather_dtensor,
_create_chunk_dtensor,
_create_chunk_sharded_tensor,
)
@@ -70,6 +71,19 @@ def pre_load_state_dict_transform(
"""
...

@abstractmethod
def all_gather_dtensor(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""
This is to be called before loading a *sharded* DTensor state dict.
It gathers the tensor along the FSDP dimension and returns the local tensor of the TP DTensor.
"""
...


_extensions: Optional[FSDPExtensions] = None

@@ -143,3 +157,15 @@ def _ext_pre_load_state_dict_transform(
assert type(tensor) is ShardedTensor
shards = tensor.local_shards()
return (tensor, shards)


def _ext_all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
all_gather_dtensor_fn = (
_extensions.all_gather_dtensor
if _extensions is not None
else _all_gather_dtensor
)
return all_gather_dtensor_fn(tensor, parent_mesh)
21 changes: 21 additions & 0 deletions torch/distributed/fsdp/_shard_utils.py
@@ -190,3 +190,24 @@ def _create_chunk_dtensor(
device_mesh=device_mesh,
placements=shard_placements,
)


def _all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""
All gather a DTensor in its sharded dimension and return the local tensor.
"""
assert parent_mesh is None

placements = list(copy.deepcopy(tensor.placements))
# FSDP placements: [Shard(0)] -> [Replicate()]
# HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
placements[-1] = Replicate()
tensor = tensor.redistribute(
device_mesh=tensor.device_mesh,
placements=placements,
)

return tensor.to_local()
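For intuition, a standalone sketch of the redistribute step this helper performs in the plain FSDP (1D) case. It is not part of the diff; it assumes two GPUs, a torchrun-launched process group, and import paths from this PyTorch version.

import torch
from torch.distributed._tensor import distribute_tensor, init_device_mesh, Replicate, Shard

# 1D mesh: pure FSDP sharding of the tensor along dim 0.
mesh = init_device_mesh("cuda", (2,))
dt = distribute_tensor(torch.arange(8.0).reshape(4, 2), mesh, [Shard(0)])
# Each rank holds a (2, 2) local shard here.

# The move _all_gather_dtensor makes: [Shard(0)] -> [Replicate()],
# i.e. all-gather along the FSDP dimension, then take the local tensor.
full = dt.redistribute(mesh, [Replicate()]).to_local()  # (4, 2) on every rank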
23 changes: 14 additions & 9 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -1,5 +1,4 @@
import contextlib
import copy
import logging
import math
import warnings
@@ -17,7 +16,7 @@
Shard,
ShardedTensor,
)
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.device_mesh import mesh_resources

from torch.distributed.distributed_c10d import _get_pg_default_device
@@ -46,8 +45,10 @@
from torch.distributed.utils import _replace_by_prefix

from ._fsdp_extensions import (
_ext_all_gather_dtensor,
_ext_chunk_dtensor,
_ext_chunk_tensor,
_ext_post_unflatten_transform,
_ext_pre_load_state_dict_transform,
)
from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM
@@ -606,6 +607,9 @@ def _sharded_pre_load_state_dict_hook(
"load_sharded_state_dict can only be called when parameters "
"are flattened and sharded."
)
fqn_to_param_ext = dict(
zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
)

device = fsdp_state.compute_device
for fqn, _, _ in _param_name_infos(module, fsdp_state):
@@ -679,13 +683,14 @@ def _sharded_pre_load_state_dict_hook(
else:
if param.device != fsdp_state._device_mesh.device_type:
param = param.to(fsdp_state._device_mesh.device_type)
placements = list(copy.deepcopy(param.placements))
placements[-1] = Replicate()
param = param.redistribute(
device_mesh=param.device_mesh,
placements=placements,
)
state_dict[fqn_from_global_root] = param.to_local()

parent_mesh = mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
local_tensor = _ext_all_gather_dtensor(param, parent_mesh)

if fqn_to_param_ext.get(fqn) is not None:
ext = fqn_to_param_ext[fqn]
local_tensor = _ext_post_unflatten_transform(local_tensor, ext)
state_dict[fqn_from_global_root] = local_tensor

with SimpleProfiler.profile("_enter_unshard_params_ctx"):
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
27 changes: 27 additions & 0 deletions torch/distributed/tensor/parallel/fsdp.py
@@ -293,6 +293,26 @@ def _pre_load_state_dict(
return (tensor, shards if len(shards) > 0 else [])


def _all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""
All gather a DTensor in its FSDP dimension and return the local tensor.
"""
assert parent_mesh == tensor.device_mesh

placements = list(copy.deepcopy(tensor.placements))
# FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
placements[0] = Replicate()
tensor = tensor.redistribute(
device_mesh=tensor.device_mesh,
placements=placements,
)

return tensor.to_local()


class DTensorExtensions(FSDPExtensions):
"""
DTensorExtension is the TensorFlattener extension needed for 2D FSDP + TP.
@@ -338,6 +358,13 @@ def pre_load_state_dict_transform(
) -> Tuple[torch.Tensor, List[Shard]]:
return _pre_load_state_dict(tensor)

def all_gather_dtensor(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
return _all_gather_dtensor(tensor, parent_mesh)


# TODO: remove enable_2d_with_fsdp() once we roll out the new 2D flow.
def enable_2d_with_fsdp() -> bool:
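Likewise, a standalone sketch of the 2D case handled by _all_gather_dtensor in this file: only the FSDP mesh dimension is replicated, while the TP placement is left untouched. It is not part of the diff; it assumes four GPUs, a torchrun-launched process group, and a row-wise Shard(0) TP placement chosen purely for illustration.

import torch
from torch.distributed._tensor import distribute_tensor, init_device_mesh, Replicate, Shard

# 2D mesh: outer dim = FSDP ("dp"), inner dim = TP ("tp").
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# A parameter sharded on dim 0 by FSDP (dp) and, for illustration, also on dim 0 by TP (tp).
param = distribute_tensor(torch.randn(8, 4), mesh_2d, [Shard(0), Shard(0)])
# Local shard per rank: (2, 4).

# The step _all_gather_dtensor performs:
# [Shard(0), tp_placement] -> [Replicate(), tp_placement].
tp_local = param.redistribute(mesh_2d, [Replicate(), Shard(0)]).to_local()
# tp_local is (4, 4): full along the FSDP dim, still sharded along TP.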
