[DTensor] Supported 2D clip_grad_norm_ (pytorch#121945)
This PR adds support for 2D `clip_grad_norm_` (`foreach=True`).
- This PR changes `OpSchema.args_spec` to use pytree flattening when the op's runtime schema info requests it (`needs_pytree=True`).
- This PR adds a unit test for 2D (FSDP2 + sequence parallel) with `clip_grad_norm_` enabled, which serves as a complete end-to-end numerics test for 2D.

Note: With this PR patched, 2-way SP + 4-way FSDP matches 8-way FSDP numerics on Llama-7B (doubling local batch size for the 2-way SP run).
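
As a rough illustration of the user-facing pattern this enables, here is a hedged sketch (not code from this PR): the toy `nn.Sequential` model, mesh shape, and assumed 8-GPU `torchrun` launch are illustrative only.

```python
# Hedged sketch of a 2D (FSDP2 + TP) setup that clip_grad_norm_(foreach=True)
# can now handle; model and mesh sizes are assumptions, not from this PR.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()
# Tensor-parallelize over the "tp" mesh dim, then fully shard over "dp",
# so gradients become 2D DTensors.
parallelize_module(model, mesh["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()})
fully_shard(model, mesh=mesh["dp"])

model(torch.randn(4, 16, device="cuda")).sum().backward()
# The foreach path of clip_grad_norm_ now works on these 2D DTensor gradients.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True
)
print(total_norm.full_tensor())  # total_norm comes back as a DTensor
```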

Pull Request resolved: pytorch#121945
Approved by: https://github.com/wanchaol
ghstack dependencies: pytorch#121747, pytorch#121869
awgu authored and pytorchmergebot committed Mar 15, 2024
1 parent 2c33e3a commit f4dd2fd
Showing 3 changed files with 103 additions and 33 deletions.
106 changes: 78 additions & 28 deletions test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py
@@ -2,12 +2,15 @@

import copy
import functools
from typing import Union
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import Shard
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
@@ -18,37 +21,21 @@
)


class TestClipGradNormMultiThread(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)

@skip_if_lt_x_gpu(2)
def test_clip_grad_norm_1d(self):
self.run_subtests(
{"max_norm": [1], "norm_type": [2, 1, float("inf")]},
self._test_clip_grad_norm_1d,
)

def _test_clip_grad_norm_1d(
class _TestClipGradNormBase(FSDPTest):
def _test_clip_grad_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int],
ref_model: nn.Module,
ref_optim: torch.optim.Optimizer,
model: nn.Module,
optim: torch.optim.Optimizer,
dp_mesh: Optional[DeviceMesh] = None,
):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
for iter_idx in range(10):
ref_optim.zero_grad()
ref_model(inp).sum().backward()
@@ -60,6 +47,10 @@ def _test_clip_grad_norm_1d(
p.grad.to_local().detach().clone() for p in model.parameters()
]
for ref_grad, param in zip(ref_grads, model.parameters()):
# TODO: Skip the check for the parameters since FSDP needs
# strided sharding for it to work with `full_tensor`
if tuple(param.placements) == (Shard(0), Shard(0)):
continue
self.assertEqual(ref_grad, param.grad.full_tensor())

# Check that all gradients have norm greater than the max norm
@@ -83,7 +74,12 @@ def _test_clip_grad_norm_1d(
foreach=True,
)
self.assertEqual(ref_total_norm, total_norm)
self.assertEqual(comm_mode.get_total_counts(), 1) # one all-reduce
# Expect one all-reduce per mesh dim for partial -> replicate
expected_all_reduces = len(total_norm.placements)
self.assertEqual(
comm_mode.get_comm_counts()[torch.ops.c10d_functional.all_reduce],
expected_all_reduces,
)
# For zero gradients, clipping has no effect
for param, grad in zip(ref_model.parameters(), ref_grads):
self.assertTrue(vector_norm_fn(param.grad).item() <= max_norm)
@@ -97,5 +93,59 @@ def _test_clip_grad_norm_1d(
self.assertFalse(torch.equal(param.grad.to_local(), grad))


class TestClipGradNormWorldSize2(_TestClipGradNormBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)

@skip_if_lt_x_gpu(2)
def test_clip_grad_norm_1d(self):
for norm_type in (2, 1, float("inf")):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
self._test_clip_grad_norm(1, norm_type, ref_model, ref_optim, model, optim)


class TestClipGradNormWorldSize4(_TestClipGradNormBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)

@skip_if_lt_x_gpu(4)
def test_clip_grad_norm_2d(self):
for norm_type in (2, 1, 3, float("inf")):
dp_size = 2
global_mesh = init_device_mesh(
"cuda",
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(
copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
model = Transformer.parallelize(model, tp_mesh, use_seq_parallel=True)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
self._test_clip_grad_norm(
1, norm_type, ref_model, ref_optim, model, optim, dp_mesh
)


if __name__ == "__main__":
run_tests()
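
For reference, a minimal sketch of the comm-counting pattern the test above relies on (the helper name is made up for illustration, and `model` is assumed to already hold DTensor gradients):

```python
# Hedged sketch: count the all-reduces issued while computing the grad norm.
# Partial -> Replicate should cost one all-reduce per mesh dim of the norm DTensor.
import torch
from torch.distributed._tensor.debug import CommDebugMode


def clip_and_count_all_reduces(model, max_norm=1.0, norm_type=2.0):
    comm_mode = CommDebugMode()
    with comm_mode:
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm, norm_type=norm_type, foreach=True
        )
    num_all_reduces = comm_mode.get_comm_counts()[
        torch.ops.c10d_functional.all_reduce
    ]
    return total_norm, num_all_reduces
```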
22 changes: 18 additions & 4 deletions torch/distributed/_tensor/op_schema.py
@@ -1,16 +1,17 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.device_mesh import DeviceMesh

try:
from torch.utils._cxx_pytree import tree_map_only, TreeSpec
from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
except ImportError:
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
tree_leaves,
tree_map_only,
TreeSpec,
)
@@ -198,7 +199,7 @@ class RuntimeSchemaInfo:
static_kwargkey: Optional[List[str]] = None
# each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution, by default we don't need to do flatten/unflatten, only if the
# op indicate it needs to, this is to accelate eager performance.
# op indicate it needs to, this is to accelerate eager performance.
needs_pytree: bool = False


@@ -236,6 +237,12 @@ def args_spec(self) -> Tuple[DTensorSpec, ...]:
"""
# filter out non-relevant values from args schema to get a clean spec list
# this would mainly be used by sharding propagation rules
if self.schema_info is not None and self.schema_info.needs_pytree:
return tuple(
item
for item in tree_leaves(self.args_schema)
if isinstance(item, DTensorSpec)
)
return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))

def __repr__(self) -> str:
@@ -382,7 +389,14 @@ def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
suggestion_args_spec = self.args_spec
new_arg_schema: List[object] = []
idx_of_args_spec = 0
for arg in origin_schema.args_schema:
if (
origin_schema.schema_info is not None
and origin_schema.schema_info.needs_pytree
):
args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
else:
args_schema = origin_schema.args_schema
for arg in args_schema:
if isinstance(arg, DTensorSpec):
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
idx_of_args_spec += 1
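
To see why the pytree path above matters, here is a small standalone illustration (using a stand-in class rather than the real `DTensorSpec`): foreach ops pass `List[Tensor]` arguments, so the specs only become visible after flattening.

```python
# Illustrative only: `FakeSpec` stands in for DTensorSpec.
from dataclasses import dataclass

from torch.utils._pytree import tree_leaves


@dataclass
class FakeSpec:
    name: str


# Shaped like a foreach op's args_schema: (List[spec], scalar).
args_schema = ([FakeSpec("grad0"), FakeSpec("grad1")], 2.0)

flat_scan = tuple(a for a in args_schema if isinstance(a, FakeSpec))
pytree_scan = tuple(a for a in tree_leaves(args_schema) if isinstance(a, FakeSpec))
print(flat_scan)    # () -- nested specs are missed without pytree flattening
print(pytree_scan)  # (FakeSpec(name='grad0'), FakeSpec(name='grad1'))
```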
8 changes: 7 additions & 1 deletion torch/distributed/_tensor/ops/math_ops.py
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
from dataclasses import dataclass
from enum import Enum
from typing import cast, List, Optional, Sequence, Tuple, Union
@@ -87,7 +88,12 @@ def _partition_value(
if self.reduce_op in (c10d.ReduceOp.MAX, c10d.ReduceOp.MIN):
return tensor
elif self.reduce_op == c10d.ReduceOp.SUM:
return tensor / mesh.size(mesh_dim=mesh_dim)
if self.norm_type == 0:
raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}")
elif self.norm_type == 1:
return tensor / mesh.size(mesh_dim)
assert isinstance(self.norm_type, (int, float))
return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type)
raise NotImplementedError(self.reduce_op)

def _reduce_shard_value(
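
A quick numeric check of the partition rule added above, under the assumption that the matching reduction sums p-th powers across ranks and then takes the 1/p root (how a p-norm combines per-rank contributions):

```python
# Standalone sanity check: splitting a replicated value t across n ranks as
# t / n**(1/p) is the inverse of that p-norm-style reduction.
import math

t, n, p = 5.0, 4, 2.0
shard = t / math.pow(n, 1 / p)
reduced = math.pow(sum(shard**p for _ in range(n)), 1 / p)
assert abs(reduced - t) < 1e-12
print(reduced)  # 5.0 (recovers t)
```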
