diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index cacc1eb754d..2a4b3875431 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -30,7 +31,6 @@
 #include
 #include
 #include
-#include
 
 namespace nvfuser {
 
@@ -4359,9 +4359,8 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
   auto matmul_out = at::matmul(a, b);
   if (out()->hasAllocation()){
     auto matmul_sizes = matmul_out.sizes().vec();
-    auto strides = ir_utils::inferStrides(
-        out()->getLogicalDomain(),
-        out()->getMaybeAllocationDomain(),
+    auto strides = inferStrides(
+        out(),
         matmul_sizes
     );
     matmul_out = at::as_strided(matmul_out, matmul_sizes, strides);
diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp
index b03c1a8e52f..ae7e4ee8418 100644
--- a/csrc/ir/utils.cpp
+++ b/csrc/ir/utils.cpp
@@ -1347,29 +1347,4 @@ std::vector<IterDomain*> strideOrderToAllocation(
   return allocation_domain;
 }
 
-std::vector<int64_t> inferStrides(
-    const std::vector<IterDomain*>& logical_domain,
-    const std::vector<IterDomain*>& allocation_domain,
-    const std::vector<int64_t>& sizes
-){
-  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
-      TensorDomain::noReductions(logical_domain),
-      TensorDomain::noReductions(allocation_domain));
-  NVF_CHECK(out_order.has_value(), "Valid permute from logical to allocation domain was not found.");
-
-  auto rank = sizes.size();
-  std::vector<int64_t> sorted_strides (rank);
-  auto permuted_sizes = ir_utils::applyPermutation(sizes, *out_order);
-  sorted_strides[rank - 1] = 1;
-  for (int64_t idx = rank - 2; idx >= 0; idx--){
-    sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
-  }
-  // Rearrange the strides in correct order of allocation
-  std::vector<int64_t> strides (rank);
-  for (auto idx: c10::irange(rank)){
-    strides[out_order.value()[idx]] = sorted_strides[idx];
-  }
-  return strides;
-}
-
 } // namespace nvfuser::ir_utils
diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp
index 32fdee2de42..5ecb2b05be6 100644
--- a/csrc/tensor_metadata.cpp
+++ b/csrc/tensor_metadata.cpp
@@ -317,6 +317,34 @@ inferAndValidateAllocationSizesAndStrides(
   return {std::move(sizes), std::move(strides)};
 }
 
+std::vector<int64_t> inferStrides(
+    TensorView* tv,
+    const std::vector<int64_t>& sizes
+){
+  const auto& logical_domain = tv->getLogicalDomain();
+  const auto& allocation_domain = tv->getMaybeAllocationDomain();
+
+  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
+      TensorDomain::noReductions(logical_domain),
+      TensorDomain::noReductions(allocation_domain));
+  NVF_CHECK(out_order.has_value(), "Valid permute from logical to allocation domain was not found.");
+
+  auto rank = sizes.size();
+  std::vector<int64_t> sorted_strides (rank);
+  auto permuted_sizes = ir_utils::applyPermutation(sizes, *out_order);
+  sorted_strides[rank - 1] = 1;
+  for (int64_t idx = rank - 2; idx >= 0; idx--){
+    sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
+  }
+  // Rearrange the strides in correct order of allocation
+  std::vector<int64_t> strides (rank);
+  for (auto idx: c10::irange(rank)){
+    strides[out_order.value()[idx]] = sorted_strides[idx];
+  }
+  validateAllocationSizesAndStrides(allocation_domain, tv->getContiguity(), sizes, strides);
+  return strides;
+}
+
 namespace {
 std::pair<std::vector<int64_t>, std::vector<int64_t>> unshardedSizesAndStrides(
     TensorView* tv,
diff --git a/csrc/tensor_metadata.h b/csrc/tensor_metadata.h
index d238ef7af5e..1b31fb60814 100644
--- a/csrc/tensor_metadata.h
+++ b/csrc/tensor_metadata.h
@@ -100,6 +100,10 @@ struct TensorMetaData : public Struct {
   }
 };
 
+std::vector<int64_t> inferStrides(
+    TensorView* tv,
+    const std::vector<int64_t>& sizes);
+
 // Given an ATen tensor, whose sizes and strides are w.r.t to the logical domain
 // of its corresponding TensorView, compute the sizes and strides of the tensor
 // with respect to its allocation domain.
diff --git a/tests/python/test_matmul.py b/tests/python/test_matmul.py
index 1881f2b46ed..8d6a1193f0b 100644
--- a/tests/python/test_matmul.py
+++ b/tests/python/test_matmul.py
@@ -205,17 +205,18 @@ def fusion_func(fd: FusionDefinition) -> None:
     def test_matmul_stride(self):
         b, m, n, k = 3, 2, 5, 4
         inputs = [
-            torch.randn(b, m, k, device="cuda", dtype=torch.float16),
+            torch.randn(b, b, m, k, device="cuda", dtype=torch.float16),
             torch.randn(k, n, device="cuda", dtype=torch.float16)
         ]
-        def fusion_func(fd: FusionDefinition) -> None:
-            a = fd.from_pytorch(inputs[0])
-            b = fd.from_pytorch(inputs[1])
-            out = fd.ops.matmul(a, b)
-            fd.add_output(out, [1, 0, 2])
-        with FusionDefinition() as fd:
-            fusion_func(fd)
-        outputs = fd.execute(inputs)
-        print (outputs[0].stride())
-        print (outputs[0].shape)
-        verify_stride_order(outputs[0].stride(), [1, 0, 2])
\ No newline at end of file
+        for perm in itertools.permutations(range(4), 4):
+            def fusion_func(fd: FusionDefinition) -> None:
+                a = fd.from_pytorch(inputs[0])
+                b = fd.from_pytorch(inputs[1])
+                out = fd.ops.matmul(a, b)
+                fd.add_output(out, stride_order=perm)
+            with FusionDefinition() as fd:
+                fusion_func(fd)
+            out = fd.execute(inputs)
+            verify_stride_order(out[0].stride(), perm)
+            # Verify that setting the stride order does not change the logical shape
+            self.assertEqual(out[0].shape, torch.Size([b, b, m, n]))
\ No newline at end of file
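
Note (reviewer sketch, not part of the patch): the minimal Python snippet below mirrors the stride computation inferStrides performs, assuming alloc_order[i] names the logical axis stored at allocation position i, i.e. the permutation the patch obtains from ir_utils::computePermutation / applyPermutation. Function and variable names here are illustrative only, not nvFuser API.

def infer_strides(sizes, alloc_order):
    # Sizes rearranged into allocation order (outermost -> innermost).
    rank = len(sizes)
    permuted_sizes = [sizes[alloc_order[i]] for i in range(rank)]
    # Contiguous (row-major) strides over the allocation-ordered sizes.
    sorted_strides = [0] * rank
    sorted_strides[rank - 1] = 1
    for idx in range(rank - 2, -1, -1):
        sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1]
    # Scatter the strides back to logical-axis positions.
    strides = [0] * rank
    for idx in range(rank):
        strides[alloc_order[idx]] = sorted_strides[idx]
    return strides

# Logical shape (b, m, n) = (3, 2, 5) stored in allocation order (m, b, n):
# innermost axis n gets stride 1, b gets 5, m gets 15.
assert infer_strides([3, 2, 5], [1, 0, 2]) == [5, 15, 1]

This is the usual "permute sizes, compute contiguous strides, scatter back" approach; the patched C++ version additionally validates the result against the TensorView's contiguity via validateAllocationSizesAndStrides before returning.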