From 21144ce5704f5d95dff8d28e3a389c798b03afe3 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Sat, 1 Jun 2024 12:09:08 -0700
Subject: [PATCH] [dtensor] implement scatter op with simple replication (#126713)

As titled: implement the torch.scatter op with a simple replication strategy.
Need to follow up and see whether we can actually support any sharding
patterns.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126713
Approved by: https://github.com/tianyu-l
ghstack dependencies: #126712
---
 test/distributed/_tensor/test_dtensor_ops.py |  1 -
 test/distributed/_tensor/test_tensor_ops.py  | 34 ++++++++++++++++++++
 torch/distributed/_tensor/ops/tensor_ops.py  | 27 ++++++++++++++++
 torch/distributed/_tensor/ops/utils.py       |  8 +++++
 4 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 22a56118b2120b..83f0bb8751670c 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -403,7 +403,6 @@ def wrapped(fn):
     xfail("rsub"),
     xfail("scalar_tensor"),
     xfail("scatter_add"),
-    xfail("scatter"),
     xfail("scatter_reduce", "amax"),
     xfail("scatter_reduce", "amin"),
     xfail("scatter_reduce", "mean"),
diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py
index 24e52753331561..e86a702855c69a 100644
--- a/test/distributed/_tensor/test_tensor_ops.py
+++ b/test/distributed/_tensor/test_tensor_ops.py
@@ -390,6 +390,40 @@ def test_new_empty_strided(self):
         self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4))
         self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1))
 
+    @with_comms
+    def test_scatter(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        comm_mode = CommDebugMode()
+
+        # case 1 all replicate: input replicated, index/src replicated, output replicated
+        global_indexs = [
+            torch.tensor([[0, 1, 2, 0]]),
+            torch.tensor([[0, 1, 2], [0, 1, 4]]),
+        ]
+        for scatter_dim in [0, 1]:
+            srcs = [torch.arange(1, 11).reshape((2, 5)), 4]
+            for global_src in srcs:
+                global_input = torch.zeros(3, 5, dtype=torch.int64)
+                global_index = global_indexs[scatter_dim]
+
+                input_dt = distribute_tensor(
+                    global_input.clone(), device_mesh, [Replicate()]
+                )
+                index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
+                if isinstance(global_src, torch.Tensor):
+                    src_dt = distribute_tensor(global_src, device_mesh, [Replicate()])
+                else:
+                    src_dt = global_src
+                global_output = torch.scatter(
+                    global_input, scatter_dim, global_index, global_src
+                )
+                with comm_mode:
+                    output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt)
+
+                self.assertEqual(comm_mode.get_total_counts(), 0)
+                self.assertEqual(output_dt.placements, [Replicate()])
+                self.assertEqual(output_dt.to_local(), global_output)
+
     @with_comms
     def test_gather(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
index 7aa90f2ebcd774..40f75c151579a4 100644
--- a/torch/distributed/_tensor/ops/tensor_ops.py
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -4,6 +4,7 @@
 import torch
 
 from torch.distributed._tensor._op_schema import (
+    _is_inplace_op,
     OpSchema,
     OpStrategy,
     OutputSharding,
@@ -359,6 +360,32 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType
     return OpStrategy([PlacementStrategy(replicate_spec)])
 
 
+@register_op_strategy(
+    [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always allow the fully replicated strategy for inputs and output
+    if len(op_schema.args_strategy) < 3:
+        # scatter_.value/scatter.value variant, where src is a scalar value rather than a tensor
+        all_replicate: List[Placement] = [Replicate()] * 3
+    else:
+        all_replicate = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # TODO: see if we can support input sharding patterns
+    inplace_op = _is_inplace_op(op_schema.op)
+
+    op_strategy = expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op
+    )
+    return op_strategy
+
+
 @register_op_strategy(aten.gather.default)
 def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     input_strategy = cast(OpStrategy, op_schema.args_schema[0])
diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py
index b957842b427685..245298607c5e93 100644
--- a/torch/distributed/_tensor/ops/utils.py
+++ b/torch/distributed/_tensor/ops/utils.py
@@ -237,7 +237,9 @@ def expand_to_full_mesh_op_strategy(
     mesh: DeviceMesh,
     op_schema: OpSchema,
     single_mesh_dim_strategies: List[List[Placement]],
+    *,
     input_index: int = 1,
+    inplace_op: bool = False,
 ) -> OpStrategy:
     # Expand the single_mesh_dim_strategies to full mesh dim strategies.
     all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
@@ -253,6 +255,12 @@
         input_specs = spec_list[input_index:]
         input_args_strategy = op_schema.args_strategy
         assert len(input_specs) == len(input_args_strategy)
+        self_spec = input_args_strategy[0].strategies[0].output_spec
+        if inplace_op and self_spec.placements != input_specs[0].placements:
+            # for an in-place op, only add this placement strategy when input_spec
+            # matches the first argument's runtime sharding; otherwise skip it
+            continue
+
         # check inputs shardable
         inputs_shardable = all(
             is_tensor_shardable(inp.shape, s)
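
Note (not part of the patch): a minimal end-to-end sketch of what the new strategy enables from the user side, mirroring the all-replicate case in test_scatter above. It assumes a process group has already been initialized (e.g. launched via torchrun with 4 ranks) and uses the gloo/CPU backend purely for illustration; the mesh, shapes, and script name are illustrative, not part of the change.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

# assumes: torchrun --nproc_per_node=4 scatter_sketch.py (gloo backend, CPU tensors)
dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

global_input = torch.zeros(3, 5, dtype=torch.int64)
global_index = torch.tensor([[0, 1, 2, 0]])
global_src = torch.arange(1, 11).reshape((2, 5))

# all operands replicated: the new strategy keeps the output replicated
# and the op should run without any collective communication
input_dt = distribute_tensor(global_input, mesh, [Replicate()])
index_dt = distribute_tensor(global_index, mesh, [Replicate()])
src_dt = distribute_tensor(global_src, mesh, [Replicate()])

output_dt = torch.scatter(input_dt, 0, index_dt, src_dt)
assert output_dt.placements == (Replicate(),)
# each rank holds the full result, matching the single-device computation
assert torch.equal(
    output_dt.to_local(), torch.scatter(global_input, 0, global_index, global_src)
)
dist.destroy_process_group()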