[dtensor] implement scatter op with simple replication (pytorch#126713)
As titled: implement the torch.scatter op with a simple replication strategy.
Follow-up needed to see whether we can actually support any sharding
pattern.
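
A minimal usage sketch of what this enables (illustrative only, not part of the diff): with every operand replicated, torch.scatter on DTensors dispatches through the new replicate-only strategy and runs without collectives. The device type, world size, and an already-initialized process group are assumptions here; the shapes mirror the new test case.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

# Assumption: a process group with 4 ranks is already initialized.
mesh = DeviceMesh("cuda", list(range(4)))

# Replicate every operand, matching the single strategy this commit registers.
input_dt = distribute_tensor(torch.zeros(3, 5, dtype=torch.int64), mesh, [Replicate()])
index_dt = distribute_tensor(torch.tensor([[0, 1, 2, 0]]), mesh, [Replicate()])
src_dt = distribute_tensor(torch.arange(1, 11).reshape(2, 5), mesh, [Replicate()])

# No communication is needed; the output stays replicated on every rank.
output_dt = torch.scatter(input_dt, 0, index_dt, src_dt)
assert tuple(output_dt.placements) == (Replicate(),)
```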

Pull Request resolved: pytorch#126713
Approved by: https://github.com/tianyu-l
ghstack dependencies: pytorch#126712
wanchaol authored and pytorchmergebot committed Jun 3, 2024
1 parent ded580a commit 21144ce
Showing 4 changed files with 69 additions and 1 deletion.
1 change: 0 additions & 1 deletion test/distributed/_tensor/test_dtensor_ops.py
@@ -403,7 +403,6 @@ def wrapped(fn):
xfail("rsub"),
xfail("scalar_tensor"),
xfail("scatter_add"),
xfail("scatter"),
xfail("scatter_reduce", "amax"),
xfail("scatter_reduce", "amin"),
xfail("scatter_reduce", "mean"),
34 changes: 34 additions & 0 deletions test/distributed/_tensor/test_tensor_ops.py
@@ -390,6 +390,40 @@ def test_new_empty_strided(self):
self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4))
self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1))

@with_comms
def test_scatter(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
comm_mode = CommDebugMode()

# case 1 all replicate: input replicated, index/src replicated, output replicated
global_indexs = [
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[0, 1, 2], [0, 1, 4]]),
]
for scatter_dim in [0, 1]:
srcs = [torch.arange(1, 11).reshape((2, 5)), 4]
for global_src in srcs:
global_input = torch.zeros(3, 5, dtype=torch.int64)
global_index = global_indexs[scatter_dim]

input_dt = distribute_tensor(
global_input.clone(), device_mesh, [Replicate()]
)
index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
if isinstance(global_src, torch.Tensor):
src_dt = distribute_tensor(global_src, device_mesh, [Replicate()])
else:
src_dt = global_src
global_output = torch.scatter(
global_input, scatter_dim, global_index, global_src
)
with comm_mode:
output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt)

self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(output_dt.placements, [Replicate()])
self.assertEqual(output_dt.to_local(), global_output)

@with_comms
def test_gather(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
27 changes: 27 additions & 0 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -4,6 +4,7 @@
import torch

from torch.distributed._tensor._op_schema import (
_is_inplace_op,
OpSchema,
OpStrategy,
OutputSharding,
@@ -359,6 +360,32 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType
return OpStrategy([PlacementStrategy(replicate_spec)])


@register_op_strategy(
[aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src],
schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index, src]
# first we always have replicate all for inputs and output
if len(op_schema.args_strategy) < 3:
# scatter_.value/scatter.value overloads: src is a scalar value instead of a tensor
all_replicate: List[Placement] = [Replicate()] * 3
else:
all_replicate = [Replicate()] * 4
single_mesh_dim_strategies.append(all_replicate)

# TODO: see if we can support input sharding pattern
inplace_op = _is_inplace_op(op_schema.op)

op_strategy = expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op
)
return op_strategy


@register_op_strategy(aten.gather.default)
def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
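For context on how the single-mesh-dim placement list above becomes full-mesh strategies, here is an illustrative sketch (simplified standalone code, not the library's helper) of the cartesian-product expansion that expand_to_full_mesh_op_strategy performs; the real helper additionally builds DTensorSpecs, applies the in-place guard shown in the next file, and filters unshardable candidates.

```python
import itertools
from torch.distributed._tensor.placement_types import Replicate

# The only single-mesh-dim candidate scatter_strategy registers:
# placements for [output, input, index, src] on one mesh dimension.
single_mesh_dim_strategies = [[Replicate()] * 4]
mesh_ndim = 2  # e.g. a 2-D device mesh

# Each combination picks one candidate per mesh dimension; zip(*comb) then
# regroups the placements per tensor across all mesh dimensions.
for comb in itertools.product(*([single_mesh_dim_strategies] * mesh_ndim)):
    per_tensor_placements = [tuple(p) for p in zip(*comb)]
    print(per_tensor_placements)

# With only a replicate-only candidate, this yields a single full-mesh
# strategy: every tensor is Replicate() on every mesh dimension.
```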
8 changes: 8 additions & 0 deletions torch/distributed/_tensor/ops/utils.py
@@ -237,7 +237,9 @@ def expand_to_full_mesh_op_strategy(
mesh: DeviceMesh,
op_schema: OpSchema,
single_mesh_dim_strategies: List[List[Placement]],
*,
input_index: int = 1,
inplace_op: bool = False,
) -> OpStrategy:
# Expand the single_mesh_dim_strategies to full mesh dim strategies.
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
@@ -253,6 +255,12 @@
input_specs = spec_list[input_index:]
input_args_strategy = op_schema.args_strategy
assert len(input_specs) == len(input_args_strategy)
self_spec = input_args_strategy[0].strategies[0].output_spec
if inplace_op and self_spec.placements != input_specs[0].placements:
# for an in-place op, only add this placement strategy when the self
# input_spec matches the first argument's runtime sharding; otherwise skip it
continue

# check inputs shardable
inputs_shardable = all(
is_tensor_shardable(inp.shape, s)
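A small sketch of the in-place guard added above (hypothetical standalone check, not the library's API): for an in-place op such as scatter_, the result must alias the caller's self tensor, so a candidate strategy is only kept when its desired placement for self matches what self already has at runtime.

```python
from torch.distributed._tensor.placement_types import Replicate, Shard

def keep_candidate(inplace_op: bool, runtime_self, candidate_self) -> bool:
    # Mirrors the condition in expand_to_full_mesh_op_strategy: skip candidates
    # that would require redistributing `self` when the op mutates it in place.
    if inplace_op and runtime_self != candidate_self:
        return False
    return True

print(keep_candidate(True, (Shard(0),), (Replicate(),)))   # False: can't replicate self in place
print(keep_candidate(False, (Shard(0),), (Replicate(),)))  # True: out-of-place ops may redistribute
```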
