
Commit

Don't pass tuple to with statement (pytorch#110864)
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#110864
Approved by: https://github.com/Skylion007, https://github.com/awgu
ezyang authored and pytorchmergebot committed Oct 9, 2023
1 parent 4b881b0 commit 8ae623d
Showing 1 changed file with 8 additions and 14 deletions.
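For context (this illustration is not part of the commit): the deleted code wraps the two context managers in parentheses. On Python versions that predate parenthesized context managers (officially added in Python 3.10), that expression is parsed as a single tuple, and a tuple is not a context manager, so the with statement fails at runtime. A minimal sketch of the behavior, using a placeholder context manager rather than the helpers from the test file:

from contextlib import contextmanager

@contextmanager
def cm(name):
    # Placeholder context manager, for illustration only.
    print(f"enter {name}")
    yield
    print(f"exit {name}")

# Portable across supported Python versions: multiple managers, no parentheses.
with cm("a"), cm("b"):
    pass

# On older interpreters the parenthesized form below is parsed as
# "with <tuple>:", and the tuple has no __enter__, so it raises at runtime:
#
# with (
#     cm("a"),
#     cm("b"),
# ):
#     pass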
22 changes: 8 additions & 14 deletions test/distributed/test_inductor_collectives.py
@@ -314,13 +314,10 @@ def example(inp, input_split_sizes_tensor, output_split_sizes_tensor, *, tag, ranks, group_size):
             out = a2a / a2a.sum(dim=0)
             return out
 
-        with (
-            _dynamo_dist_per_rank_init(self.rank, self.world_size),
-            torch._dynamo.config.patch(
-                dynamic_shapes=True,
-                capture_dynamic_output_shape_ops=True,
-                capture_scalar_outputs=True,
-            )
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size), torch._dynamo.config.patch(
+            dynamic_shapes=True,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
         ):
             row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
             input_split_sizes_tensor = torch.tensor([(i + 1) * (self.rank + 1) for i in range(self.world_size)], dtype=torch.int64)
@@ -400,13 +397,10 @@ def example(inp, output_split_sizes_tensor, *, tag, ranks, group_size):
             out = a2a / a2a.sum(dim=0)
             return out
 
-        with (
-            _dynamo_dist_per_rank_init(self.rank, self.world_size),
-            torch._dynamo.config.patch(
-                dynamic_shapes=True,
-                capture_dynamic_output_shape_ops=True,
-                capture_scalar_outputs=True,
-            )
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size), torch._dynamo.config.patch(
+            dynamic_shapes=True,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
         ):
             output_split_sizes_tensor = torch.tensor([1] * self.world_size, dtype=torch.int64)
             inputs = (
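A side note, separate from this change: when a multi-line layout is still wanted without relying on parenthesized context managers, contextlib.ExitStack is a common alternative. A minimal sketch under that assumption, again with a placeholder manager rather than the helpers from the test file:

from contextlib import ExitStack, contextmanager

@contextmanager
def cm(name):
    # Placeholder context manager, for illustration only.
    yield name

with ExitStack() as stack:
    a = stack.enter_context(cm("a"))
    b = stack.enter_context(cm("b"))
    # Both managers are active here and are unwound in reverse order on exit.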
