Skip to content

Commit

Permalink
move to tensor metadata, add validation check
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Nov 28, 2024
1 parent e27842c commit 2dee366
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 41 deletions.
7 changes: 3 additions & 4 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <kernel_ir.h>
#include <logical_domain_map.h>
#include <ops/arith.h>
#include <tensor_metadata.h>
#include <transform_iter.h>
#include <transform_rfactor.h>
#include <transform_view.h>
Expand All @@ -30,7 +31,6 @@
#include <numeric>
#include <sstream>
#include <string>
#include <tensor_metadata.h>

namespace nvfuser {

Expand Down Expand Up @@ -4359,9 +4359,8 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
auto matmul_out = at::matmul(a, b);
if (out()->hasAllocation()){
auto matmul_sizes = matmul_out.sizes().vec();
auto strides = ir_utils::inferStrides(
out()->getLogicalDomain(),
out()->getMaybeAllocationDomain(),
auto strides = inferStrides(
out(),
matmul_sizes
);
matmul_out = at::as_strided(matmul_out, matmul_sizes, strides);
Expand Down
25 changes: 0 additions & 25 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1347,29 +1347,4 @@ std::vector<IterDomain*> strideOrderToAllocation(
return allocation_domain;
}

// Computes strides for a tensor with the given `sizes` whose memory layout
// follows `allocation_domain`.
//
// `allocation_domain` must be a permutation of `logical_domain` once reduction
// dimensions are dropped from both; otherwise the NVF_CHECK below fails.
//
// The sizes are permuted with that permutation, strides are accumulated from
// the last permuted dimension (stride 1) towards the first, and the result is
// scattered back through the same permutation.
//
// A rank-0 `sizes` (e.g. a scalar result) yields an empty stride vector.
std::vector<int64_t> inferStrides(
    const std::vector<IterDomain*>& logical_domain,
    const std::vector<IterDomain*>& allocation_domain,
    const std::vector<int64_t>& sizes) {
  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
      TensorDomain::noReductions(logical_domain),
      TensorDomain::noReductions(allocation_domain));
  NVF_CHECK(
      out_order.has_value(),
      "Valid permute from logical to allocation domain was not found.");

  auto rank = sizes.size();
  std::vector<int64_t> sorted_strides(rank);
  auto permuted_sizes = ir_utils::applyPermutation(sizes, *out_order);
  // Guard against rank == 0: `rank` is unsigned, so `rank - 1` would wrap
  // around and index far outside the (empty) vector.
  if (rank > 0) {
    sorted_strides[rank - 1] = 1;
    // Do the subtraction in signed arithmetic so that rank == 1 yields -1 and
    // the loop is skipped without relying on unsigned wrap-around.
    for (int64_t idx = static_cast<int64_t>(rank) - 2; idx >= 0; idx--) {
      sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
    }
  }
  // Rearrange the strides in correct order of allocation
  std::vector<int64_t> strides(rank);
  for (auto idx : c10::irange(rank)) {
    strides[out_order.value()[idx]] = sorted_strides[idx];
  }
  return strides;
}

} // namespace nvfuser::ir_utils
28 changes: 28 additions & 0 deletions csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,34 @@ inferAndValidateAllocationSizesAndStrides(
return {std::move(sizes), std::move(strides)};
}

// Computes strides for `tv` given its `sizes`, following the layout described
// by tv's allocation domain, and validates the result.
//
// The allocation domain (tv->getMaybeAllocationDomain()) must be a permutation
// of the logical domain once reduction dimensions are dropped from both;
// otherwise the NVF_CHECK below fails.
//
// The sizes are permuted with that permutation, strides are accumulated from
// the last permuted dimension (stride 1) towards the first, and the result is
// scattered back through the same permutation. The sizes/strides pair is then
// passed to validateAllocationSizesAndStrides together with tv's contiguity.
//
// A rank-0 `sizes` (e.g. a scalar result) yields an empty stride vector.
std::vector<int64_t> inferStrides(
    TensorView* tv,
    const std::vector<int64_t>& sizes) {
  const auto& logical_domain = tv->getLogicalDomain();
  const auto& allocation_domain = tv->getMaybeAllocationDomain();

  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
      TensorDomain::noReductions(logical_domain),
      TensorDomain::noReductions(allocation_domain));
  NVF_CHECK(
      out_order.has_value(),
      "Valid permute from logical to allocation domain was not found.");

  auto rank = sizes.size();
  std::vector<int64_t> sorted_strides(rank);
  auto permuted_sizes = ir_utils::applyPermutation(sizes, *out_order);
  // Guard against rank == 0: `rank` is unsigned, so `rank - 1` would wrap
  // around and index far outside the (empty) vector.
  if (rank > 0) {
    sorted_strides[rank - 1] = 1;
    // Do the subtraction in signed arithmetic so that rank == 1 yields -1 and
    // the loop is skipped without relying on unsigned wrap-around.
    for (int64_t idx = static_cast<int64_t>(rank) - 2; idx >= 0; idx--) {
      sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
    }
  }
  // Rearrange the strides in correct order of allocation
  std::vector<int64_t> strides(rank);
  for (auto idx : c10::irange(rank)) {
    strides[out_order.value()[idx]] = sorted_strides[idx];
  }
  validateAllocationSizesAndStrides(
      allocation_domain, tv->getContiguity(), sizes, strides);
  return strides;
}

namespace {
std::pair<std::vector<int64_t>, std::vector<int64_t>> unshardedSizesAndStrides(
TensorView* tv,
Expand Down
4 changes: 4 additions & 0 deletions csrc/tensor_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ struct TensorMetaData : public Struct {
}
};

// Given a TensorView and the sizes of its tensor, compute strides that follow
// the TensorView's allocation domain. The allocation domain must be a
// permutation of the logical domain (reduction dimensions excluded); the
// computed sizes/strides are validated against the TensorView's contiguity
// before being returned.
std::vector<int64_t> inferStrides(
TensorView* tv,
const std::vector<int64_t>& sizes);

// Given an ATen tensor, whose sizes and strides are w.r.t to the logical domain
// of its corresponding TensorView, compute the sizes and strides of the tensor
// with respect to its allocation domain.
Expand Down
25 changes: 13 additions & 12 deletions tests/python/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,18 @@ def fusion_func(fd: FusionDefinition) -> None:
def test_matmul_stride(self):
b, m, n, k = 3, 2, 5, 4
inputs = [
torch.randn(b, m, k, device="cuda", dtype=torch.float16),
torch.randn(b, b, m, k, device="cuda", dtype=torch.float16),
torch.randn(k, n, device="cuda", dtype=torch.float16)
]
def fusion_func(fd: FusionDefinition) -> None:
a = fd.from_pytorch(inputs[0])
b = fd.from_pytorch(inputs[1])
out = fd.ops.matmul(a, b)
fd.add_output(out, [1, 0, 2])
with FusionDefinition() as fd:
fusion_func(fd)
outputs = fd.execute(inputs)
print (outputs[0].stride())
print (outputs[0].shape)
verify_stride_order(outputs[0].stride(), [1, 0, 2])
for perm in itertools.permutations(range(4), 4):
def fusion_func(fd: FusionDefinition) -> None:
a = fd.from_pytorch(inputs[0])
b = fd.from_pytorch(inputs[1])
out = fd.ops.matmul(a, b)
fd.add_output(out, stride_order=perm)
with FusionDefinition() as fd:
fusion_func(fd)
out = fd.execute(inputs)
verify_stride_order(out[0].stride(), perm)
# Verify that setting the stride order does not change the logical shape
self.assertEqual(out[0].shape, torch.Size([b, b, m, n]))

0 comments on commit 2dee366

Please sign in to comment.