Commit

[DTensor] Enable ASGD foreach optimizer and add the associated unit test (pytorch#121942)

Enable the ASGD foreach optimizer and add a DTensor optimizer unit test for ASGD.

Note that we still need to investigate why ASGD requires higher atol and rtol when comparing model parameters; this is listed as a TODO for now.

Pull Request resolved: pytorch#121942
Approved by: https://github.com/wanchaol
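
For context, a minimal sketch (not part of this commit) of the pattern the new test exercises: shard a module's parameters as DTensors via `distribute_module` and drive them with `torch.optim.ASGD(..., foreach=True)`. The device type, world-size handling, and the `shard_params` partition function below are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only -- assumes torch.distributed is already initialized
# (e.g. via torchrun) and uses nn.Linear in place of the test's MLPModule.
import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    Shard,
    distribute_module,
    distribute_tensor,
)

mesh = DeviceMesh("cpu", list(range(torch.distributed.get_world_size())))

def shard_params(name, module, device_mesh):
    # Hypothetical partition_fn: shard every parameter along dim 0.
    for p_name, param in module.named_parameters(recurse=False):
        dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
        module.register_parameter(p_name, dist_param)

dist_model = distribute_module(nn.Linear(10, 10), mesh, partition_fn=shard_params)

# foreach=True is the path this commit enables: ASGD's foreach step dispatches
# aten._foreach_* ops (including _foreach_sub.List) on lists of DTensor params.
dist_optim = torch.optim.ASGD(dist_model.parameters(), lr=0.1, lambd=0.001, foreach=True)
# After a forward/backward pass on DTensor inputs, dist_optim.step() runs the
# sharded foreach update.
```
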
wz337 authored and pytorchmergebot committed Mar 15, 2024
1 parent f4dd2fd commit b92daff
Showing 2 changed files with 56 additions and 1 deletion.
56 changes: 55 additions & 1 deletion test/distributed/_tensor/test_optimizers.py
@@ -59,6 +59,9 @@ def _assert_optimizer(
         dist_model,
         dist_optim,
         inputs,
+        *,
+        rtol: float = 1.3e-6,
+        atol: float = 1e-5,
     ):
         for iter_idx in range(2):
             # run forward/backward/optim for original model
@@ -78,7 +81,8 @@ def _assert_optimizer(
             # check that the optimizer update parameters with same numerics
             for p1, p2 in zip(model.parameters(), dist_model.parameters()):
                 p2 = p2.full_tensor()
-                self.assertEqual(p1, p2)
+                # Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5``
+                self.assertEqual(p1, p2, atol=atol, rtol=rtol)
 
     @with_comms
     def test_adam_1d_sharding(self):
@@ -467,6 +471,56 @@ def test_adamax_1d_sharding(self):
         inp = torch.ones(8, 10, device=self.device_type)
         self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)
 
+    @with_comms
+    def test_asgd_1d_sharding(self):
+        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+
+        asgd_configs = [
+            {"lr": 0.1},
+            {"lr": 0.1, "lambd": 0.001},
+            {"lr": 0.1, "lambd": 0.001, "alpha": 0.85},
+            {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5},
+            {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5, "weight_decay": 0.05},
+            {
+                "lr": 0.1,
+                "lambd": 0.001,
+                "alpha": 0.85,
+                "t0": 1e5,
+                "weight_decay": 0.05,
+                "foreach": True,
+            },
+            {
+                "lr": 0.1,
+                "lambd": 0.001,
+                "alpha": 0.85,
+                "t0": 1e5,
+                "weight_decay": 0.05,
+                "foreach": True,
+                "maximize": True,
+            },
+        ]
+
+        for config in asgd_configs:
+            mod = MLPModule(self.device_type)
+            opt = torch.optim.ASGD(mod.parameters(), **config)
+
+            dist_mod = distribute_module(
+                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
+            )
+            dist_opt = torch.optim.ASGD(dist_mod.parameters(), **config)
+
+            # use ones to make sure the single machine model have the same input
+            # on different ranks
+            inp = torch.ones(8, 10, device=self.device_type)
+
+            # TODO: We want to keep a unit test for ASGD optimizer for the time being, but we need to look into why
+            # when using ASGD we need higher atol and rtol when comparing model parameters.
+            # Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5``
+            # Pointer here: https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L65
+            self._assert_optimizer(
+                mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4
+            )
+
 
 if __name__ == "__main__":
     run_tests()
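
As a side note on the tolerances above: `self.assertEqual` in PyTorch's test suite is backed by `torch.testing.assert_close`, which accepts two tensors when `|actual - expected| <= atol + rtol * |expected|` holds elementwise. A standalone sketch for illustration only (the values below are invented, not taken from the PR):

```python
# Illustration of the rtol/atol check behind the looser ASGD tolerances.
# float32 defaults are rtol=1.3e-6, atol=1e-5; the test above uses
# rtol=1e-4, atol=1.3e-5.
import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 5e-5  # drift larger than the default tolerance band

# Passes under the loosened tolerances used by test_asgd_1d_sharding:
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1.3e-5)

# Would raise AssertionError under the float32 defaults:
# torch.testing.assert_close(actual, expected)
```
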
1 change: 1 addition & 0 deletions torch/distributed/_tensor/ops/pointwise_ops.py
@@ -547,6 +547,7 @@ def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strategy
     aten._foreach_neg.default,
     aten._foreach_neg_.default,
     aten._foreach_reciprocal_.default,
+    aten._foreach_sub.List,
     aten._foreach_sub_.Scalar,
     aten._foreach_sqrt.default,
     aten._foreach_sqrt_.default,
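
The single added line above registers `aten._foreach_sub.List` (tensor-list minus tensor-list) with DTensor's linear pointwise strategy; this is the overload the foreach ASGD update hits. A quick sketch of the op on plain tensors, for illustration only:

```python
# Illustration only: the _foreach_sub.List overload on ordinary tensors.
# With DTensor parameters, the new pointwise_ops.py entry lets the same call
# propagate shardings pointwise (previously the op had no DTensor sharding
# strategy registered).
import torch

params = [torch.ones(4), torch.full((4,), 2.0)]
updates = [torch.full((4,), 0.1), torch.full((4,), 0.2)]

# Tensor list - tensor list -> dispatches aten._foreach_sub.List.
result = torch._foreach_sub(params, updates)
print(result)  # two tensors, roughly [0.9, 0.9, 0.9, 0.9] and [1.8, 1.8, 1.8, 1.8]
```
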
