From 9bf2645f38acac19b93a331cb16486e00e85a397 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Nov 2024 19:49:12 -0500 Subject: [PATCH 01/25] Move OptOutMutator tests to new file and add repro --- CMakeLists.txt | 1 + tests/cpp/test_dynamic_transform.cpp | 80 -------------- tests/cpp/test_mutator.cpp | 149 +++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 80 deletions(-) create mode 100644 tests/cpp/test_mutator.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 87f11f16658..9d7d7b32cdb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -560,6 +560,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_memory.cpp ${NVFUSER_ROOT}/tests/cpp/test_move_split_cat.cpp ${NVFUSER_ROOT}/tests/cpp/test_move_pad.cpp + ${NVFUSER_ROOT}/tests/cpp/test_mutator.cpp ${NVFUSER_ROOT}/tests/cpp/test_no_op.cpp ${NVFUSER_ROOT}/tests/cpp/test_persistent_buffer.cpp ${NVFUSER_ROOT}/tests/cpp/test_pointwise.cpp diff --git a/tests/cpp/test_dynamic_transform.cpp b/tests/cpp/test_dynamic_transform.cpp index 8eb468999b7..e6c6b1292b3 100644 --- a/tests/cpp/test_dynamic_transform.cpp +++ b/tests/cpp/test_dynamic_transform.cpp @@ -1174,86 +1174,6 @@ TEST_F(NVFuserTest, Issue249InputNegative1_CUDA) { executor_cache.fusion(), outputs, {at_x, 2, 4, 15}, __LINE__, __FILE__); } -// Test that OptOutMutator mutates expressions in a predictable way -// See https://github.com/NVIDIA/Fuser/issues/852 -TEST_F(NVFuserTest, OptOutMutatorMutatedOutput) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion* fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = neg(tv0); - - auto tv2 = set(tv1); - fusion->addOutput(tv2); - - auto tv3 = set(tv0); - - OptOutMutator mut; - mut.registerMutation(tv1, tv3); - - for (auto stmt : StmtSort::getStmts(fusion)) { - mut.dispatchMutate(stmt); - } - - EXPECT_NE(tv3->definition(), nullptr); - EXPECT_TRUE(tv3->definition()->isA()); - 
EXPECT_NE(tv2->definition(), nullptr); - EXPECT_TRUE(tv2->definition()->isA()); - EXPECT_EQ(tv2->definition()->input(0), tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({3}, options); - - inlineMost(); - - KernelExecutor ke; - ke.compile(fusion); - - auto outputs = ke.run({t0}); - - testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); -} - -// Another test related to https://github.com/NVIDIA/Fuser/issues/852 -TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion* fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto s0 = IrBuilder::create(DataType::Int); - fusion->addInput(s0); - auto s1 = neg(s0); - - auto tv0 = full({IrBuilder::create(2L)}, s1, DataType::Int); - fusion->addOutput(tv0); - - // After the following mutation, it's reasonable to expect the input scalar s0 - // to be ignored, and the output to just be ones. - OptOutMutator mut; - auto c = fusion->oneVal(DataType::Int); - mut.registerMutation(s1, c); - - for (auto stmt : StmtSort::getStmts(fusion)) { - mut.dispatchMutate(stmt); - } - - EXPECT_EQ( - c->definition(), nullptr); // Replacement value should not be redefined - EXPECT_EQ(tv0->definition()->as()->getFillValue(), c); - - inlineMost(); - - KernelExecutor ke; - ke.compile(fusion); - - auto outputs = ke.run({3L}); - - testValidate(fusion, outputs, {3L}, __LINE__, __FILE__); -} - // Test that we can squeeze Symbolic IterDomains and that we properly detect // improper concretizations where we have squeezed a dimension with extent // other than 1. diff --git a/tests/cpp/test_mutator.cpp b/tests/cpp/test_mutator.cpp new file mode 100644 index 00000000000..763c3b0c4ac --- /dev/null +++ b/tests/cpp/test_mutator.cpp @@ -0,0 +1,149 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { + +// Test that OptOutMutator mutates expressions in a predictable way +// See https://github.com/NVIDIA/Fuser/issues/852 +TEST_F(NVFuserTest, OptOutMutatorMutatedOutput) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = neg(tv0); + + auto tv2 = set(tv1); + fusion->addOutput(tv2); + + auto tv3 = set(tv0); + + OptOutMutator mut; + mut.registerMutation(tv1, tv3); + + for (auto stmt : StmtSort::getStmts(fusion)) { + mut.dispatchMutate(stmt); + } + + EXPECT_NE(tv3->definition(), nullptr); + EXPECT_TRUE(tv3->definition()->isA()); + EXPECT_NE(tv2->definition(), nullptr); + EXPECT_TRUE(tv2->definition()->isA()); + EXPECT_EQ(tv2->definition()->input(0), tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({3}, options); + + inlineMost(); + + KernelExecutor ke; + ke.compile(fusion); + + auto outputs = ke.run({t0}); + + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); +} + +// Another test related to https://github.com/NVIDIA/Fuser/issues/852 +TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto s0 = IrBuilder::create(DataType::Int); + fusion->addInput(s0); + auto s1 = neg(s0); + + auto tv0 = full({IrBuilder::create(2L)}, s1, DataType::Int); + fusion->addOutput(tv0); + + // After the following mutation, it's reasonable to expect the input scalar s0 + // to be ignored, and the output to just be ones. 
+ OptOutMutator mut; + auto c = fusion->oneVal(DataType::Int); + mut.registerMutation(s1, c); + + for (auto stmt : StmtSort::getStmts(fusion)) { + mut.dispatchMutate(stmt); + } + + EXPECT_EQ( + c->definition(), nullptr); // Replacement value should not be redefined + EXPECT_EQ(tv0->definition()->as()->getFillValue(), c); + + inlineMost(); + + KernelExecutor ke; + ke.compile(fusion); + + auto outputs = ke.run({3L}); + + testValidate(fusion, outputs, {3L}, __LINE__, __FILE__); +} + +// Test that additional IDs are preserved when mutating a TensorView +TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = exp(tv0); + + fusion->addOutput(tv1); + + // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to tv1->domain()->additionalIDs() + // logical: [ iS1{i0} ] + // loop: [ iS1{i0}, bS2{1} ] + // additional IDs: [ bS2{1} ] + tv1->broadcast(1); + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); + + // After this split we have + // logical: [ iS1{i0} ] + // loop: [ iS1{i0}, bS3{1}, bS4{2} ] + // additional IDs: [ bS2{1} ] + tv1->split(1, 2); + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); + + // Now register a mutation that will alter some IDs in the domain + OptOutMutator mut; + mut.registerMutation(tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); + TensorDomain* old_tensor_domain = tv1->domain(); + auto all_stmts = StmtSort::getStmts( + fusion, + /*traverse_members*/ true, + /*traverse_attributes*/ true, + /*traverse_siblings*/ true); + for (auto stmt : all_stmts) { + mut.dispatchMutate(stmt); + } + EXPECT_TRUE(tv1->domain() != old_tensor_domain) << "Mutation did not change the TensorDomain"; + + EXPECT_FALSE(tv1->domain()->additionalIDs().empty())<< "Mutation did not preserve additional IDs"; +} + + +} // namespace nvfuser From 
96dd201ee27d2d44b995e89687ba5617bc2070ed Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Nov 2024 19:49:28 -0500 Subject: [PATCH 02/25] Add additional_ids arg to big ctor --- csrc/ir/internal_base_nodes.h | 3 ++- csrc/ir/nodes.cpp | 4 +++- csrc/mutator.cpp | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index f9f422cd994..58086ec1b0d 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -441,7 +441,8 @@ class TensorDomain : public Val { std::vector logical_domain, std::vector allocation, std::vector loop_domain, - std::vector> contiguity = {}); + std::vector> contiguity = {}, + std::vector additional_ids = {}); TensorDomain(IrBuilderPasskey, const TensorDomain* src); diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3d4213f68fe..f88618e03e5 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3078,13 +3078,15 @@ TensorDomain::TensorDomain( std::vector logical_domain, std::vector allocation_domain, std::vector loop_domain, - std::vector> contiguity) + std::vector> contiguity, + std::vector additional_ids) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), logical_domain_(std::move(logical_domain)), allocation_domain_(std::move(allocation_domain)), loop_domain_(std::move(loop_domain)), initial_loop_domain_(loop_domain_), + additional_ids_(additional_ids), contiguity_( contiguity.empty() ? getContiguityFilledWith(maybeAllocation(), false) : std::move(contiguity)) { diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index 5f183bf4839..26d1ca82924 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -156,6 +156,7 @@ void OptOutMutator::mutate(TensorDomain* td) { ? 
updateIdVec(td->allocation()) : std::vector(); std::vector domain = updateIdVec(td->loop()); + std::vector additional_ids = updateIdVec(td->additionalIDs()); if (!mutated) { return; @@ -167,7 +168,8 @@ void OptOutMutator::mutate(TensorDomain* td) { logical_dom, allocation_dom, domain, - td->contiguity()); + td->contiguity(), + additional_ids); registerMutation(td, mutated_val); } From c7c790b24c1a187f0cc9072068addf8b2c164898 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 09:46:47 -0500 Subject: [PATCH 03/25] Only check actually used IDs in predicate elimination --- .../analysis/predicate_elimination.cpp | 44 +++- tests/cpp/test_matmul.cpp | 209 +++++++++++++++--- 2 files changed, 222 insertions(+), 31 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 347cb63222c..3159de8a5db 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -188,7 +188,37 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) .getReplay(); - ProducerConsumerPairAnalyzer analyzer(c2p); + // Find all IterDomains involved in index expressions + // TODO: do we need to find logical IDs in producer that are involved in + // its allocation domain too? 
+ std::vector mapped_root_vals, loop_vals; + for (IterDomain* id : consumer->getRootDomain()) { + if (c2p.find(id) != c2p.end()) { + mapped_root_vals.push_back(id); + } + } + for (IterDomain* id : consumer->getLoopDomain()) { + loop_vals.push_back(id); + } + + // Collect all IterDomains along path instead of Exprs + std::unordered_set index_ids; + for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( + mapped_root_vals, + loop_vals, + /*require_all_to_visited=*/false)) { + for (Val* v : expr->inputs()) { + if (auto* id = dynamic_cast(v)) { + index_ids.insert(id); + } + } + for (Val* v : expr->outputs()) { + if (auto* id = dynamic_cast(v)) { + index_ids.insert(id); + } + } + } + ProducerConsumerPairAnalyzer analyzer(c2p, index_ids); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -201,12 +231,19 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { private: ProducerConsumerPairAnalyzer( - const std::unordered_map& c2p) - : c2p_(c2p) {} + const std::unordered_map& c2p, + const std::unordered_set index_ids) + : c2p_(c2p), index_ids_(index_ids) {} // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { + // TODO: check that this consumer_id is actually involved in indexing the + // producer. If it is not connected to the producer allocation domain in + // the broadcast graph, then we can skip processing it. + if (index_ids_.find(consumer_id) == index_ids_.end()) { + return false; + } needs_predicate_ = false; handle(consumer_id); return needs_predicate_; @@ -297,6 +334,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { //! 
BestEffort map from consumer IDs to producer IDs const std::unordered_map& c2p_; bool needs_predicate_ = false; + std::unordered_set index_ids_; }; class PredicateChcker : public IterVisitor { diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index fa185665253..ec087e8fd6c 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -449,6 +449,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) { } } +// Check that mma op is not predicated. +class PredicateChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + bool found_mma = false; + + private: + void handle(kir::Asm* asm_) final { +#if IS_CPP20 + if (!asm_->code().starts_with("mma") && + !asm_->code().starts_with("wgmma")) { +#else + if (asm_->code().substr(0, 3) != "mma" && + asm_->code().substr(0, 5) != "wgmma") { +#endif + return; + } + found_mma = true; + for (auto expr : scope_exprs_) { + NVF_CHECK( + !expr->isA() || + expr->as()->predicate()->isTrivial(), + "MmaOp should't be predicated!", + " Get predicate ", + expr->as()->predicate()->toInlineString()); + } + } +}; + // Matmul test for Ampere MMA: checking CTA Swizzles TEST_P(MatmulTestWithLayout, AmpereSwizzle) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); @@ -549,35 +578,9 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) { runtime = 0; } - // Check that mma op is not predicated. 
This is a regression test for - // https://github.com/NVIDIA/Fuser/issues/95 - class PredicateChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - bool found_mma = false; - - private: - void handle(kir::Asm* asm_) final { -#if IS_CPP20 - if (!asm_->code().starts_with("mma")) { -#else - if (asm_->code().substr(0, 3) != "mma") { -#endif - return; - } - found_mma = true; - for (auto expr : scope_exprs_) { - NVF_CHECK( - !expr->isA() || - expr->as()->predicate()->isTrivial(), - "MmaOp should't be predicated!", - " Get predicate ", - expr->as()->predicate()->toInlineString()); - } - } - } pred_checker; - + // This is a regression test for https://github.com/NVIDIA/Fuser/issues/95 GpuLower gpulw(&fusion); + PredicateChecker pred_checker; pred_checker.handle(gpulw.run()->topLevelExprs()); ASSERT_TRUE(pred_checker.found_mma); }; @@ -3798,4 +3801,154 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +// Test scheduling a Hopper matmul where the operands are 2D +TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { + Fusion fusion; + FusionGuard fg(&fusion); + + // constexpr int64_t M = 2048, N = 2048, K = 8192; + constexpr auto macro = MmaMacro::Hopper_64_256_16; + // constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N] + constexpr auto swizzle = MmaInputSmemSwizzle::B128; + const auto dtype = DataType::Half; + + constexpr int64_t stages = 1; + constexpr int64_t prefetch = 3; + const int64_t cta_m = 2 * getM(macro); + const int64_t cta_n = 1 * getN(macro); + + auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // [K, M] + auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // [K, N] + fusion.addInput(tv0); + fusion.addInput(tv1); + + // The output is [M, N, K] (no reordering needed) + MmaOp::AxisMapping axis_mapping{.a_axes = {1, -1, 0}, .b_axes = {-1, 1, 0}}; + auto tv2 = + fusedMultiplySum(tv0, tv1, /*axes=*/{-1}, /*init=*/nullptr, axis_mapping); + + auto tv3 = 
castOp(DataType::Half, tv2); + + fusion.addOutput(tv3); + + auto mma_ops = ir_utils::getOpsOfType(&fusion); + NVF_CHECK( + 1 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + mma_ops.size()); + mma_ops.front()->setMacro(macro); + + // gmem [K, M] x gmem [K, N] -mma-> register [M, N, rK] + // register [M, N, rK] -cast-> gmem [M, N] + + auto tv0c = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv0c->setMemoryType(MemoryType::Shared); + auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv1c->setMemoryType(MemoryType::Shared); + auto tv3c = tv3->cacheBefore(); + + tv0c->broadcast(-1); // [K, M] -> [K, M, 1] + tv1c->broadcast(-2); // [K, N] -> [K, 1, N] + + // gmem [K, M, 1] -TMA-> smem [K, M, 1] + // gmem [K, 1, N] -TMA-> smem [K, 1, N] + // smem [K, M, 1] x smem [K, 1, N] -mma-> register [M, N, rK] + // register [M, N, rK] -cast-> register [M, N] -set-> gmem [M, N] + + // Create tiles + tv2->split(-3, cta_m); + tv2->split(-2, cta_n); + tv2->split(-1, getK(macro)); + // [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki] + tv2->reorder({{-5, -3}, {-3, -2}}); + tv2->axis(0)->parallelize(ParallelType::BIDy); + tv2->axis(1)->parallelize(ParallelType::BIDx); + + // NOTE: since in this case we do not have "proper" broadcast in the inputs, + // we cannot simply propagate transforms to the operands. Instead, we + // propagate forward to the outputs and manually schedule the smem operands. + + // ComputeAtMap in this case finds bS15 and bS16. 
They are in the loop domain + // at this point + + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + tv2, + -1, + {tv3}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + + // Schedule operands + for (TensorView* tv : {tv0c, tv1c}) { + tv->reorder({{-3, -1}}); // [K, M, N] -> [M, N, K] + // NOTE: above axes are given in MNK order, but inputs are in KMN + tv->split(-3, cta_m); + tv->split(-2, cta_n); + tv->split(-1, getK(macro)); + // [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki] + // [Ko, Ki, Mo, Mi, No, Ni] -> [Mo, No, Ko, Mi, Ni, Ki] + tv->reorder({{-5, -3}, {-3, -2}}); + tv->axis(0)->parallelize(ParallelType::BIDy); + tv->axis(1)->parallelize(ParallelType::BIDx); + } + + // [..., Mi, Ni, Ki] -> [..., Ni, Ki, Mi] + tv0c->reorder({{-3, -1}}); + tv0c->applyMmaSwizzleForTMALoad(swizzle); + // [..., Mi, Ni, Ki] -> [..., Mi, Ki, Ni] + tv1c->reorder({{-1, -2}}); + tv1c->applyMmaSwizzleForTMALoad(swizzle); + + { + tv2->split(-3, getM(macro)); + tv2->split(-2, getN(macro)); + // [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki] + // -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki] + tv2->reorder({{-4, -3}}); + tv2->merge(-5); + tv2->axis(-4)->parallelize(ParallelType::TIDy); + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + tv2, + -1, + {tv3}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + } + + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2->getLoopDomain()); + tv2->setAllocationDomain(s.as(), true); + tv2->axis(-1)->parallelize(ParallelType::Mma); + tv2->axis(-2)->parallelize(ParallelType::Mma); + tv2->axis(-3)->parallelize(ParallelType::Mma); + } + + for (auto tv : {tv3c, tv3}) { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv->getLoopDomain()); + tv->setLoopDomain(s.as()); + } + tv3->axis(-1)->parallelize(ParallelType::Vectorize); + + inlineMost(); + + 
if (stages > 1) { + tv0c->circularBuffer(stages, prefetch); + tv1c->circularBuffer(stages, prefetch); + } + + // Test that predicate elimination works when the MmaOp's operands have no + // logical broadcasts + GpuLower gpulw(&fusion); + kir::Kernel* kernel = gpulw.run(); + PredicateChecker pred_checker; + pred_checker.handle(kernel->topLevelExprs()); + ASSERT_TRUE(pred_checker.found_mma); + + // TODO: compile and run kernel once inlining is fixed +} + } // namespace nvfuser From 11083a17f7172bf959b193a0b31bbb118b4f8f5a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 13:11:46 -0500 Subject: [PATCH 04/25] clang-format --- tests/cpp/test_mutator.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_mutator.cpp b/tests/cpp/test_mutator.cpp index 763c3b0c4ac..ef1f3433a51 100644 --- a/tests/cpp/test_mutator.cpp +++ b/tests/cpp/test_mutator.cpp @@ -114,10 +114,9 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { fusion->addOutput(tv1); - // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to tv1->domain()->additionalIDs() - // logical: [ iS1{i0} ] - // loop: [ iS1{i0}, bS2{1} ] - // additional IDs: [ bS2{1} ] + // We add a broadcast domain bS2{1}. 
This adds the new Broadcast ID to + // tv1->domain()->additionalIDs() logical: [ iS1{i0} ] loop: [ iS1{i0}, bS2{1} + // ] additional IDs: [ bS2{1} ] tv1->broadcast(1); EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); @@ -130,7 +129,8 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { // Now register a mutation that will alter some IDs in the domain OptOutMutator mut; - mut.registerMutation(tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); + mut.registerMutation( + tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); TensorDomain* old_tensor_domain = tv1->domain(); auto all_stmts = StmtSort::getStmts( fusion, @@ -140,10 +140,11 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { for (auto stmt : all_stmts) { mut.dispatchMutate(stmt); } - EXPECT_TRUE(tv1->domain() != old_tensor_domain) << "Mutation did not change the TensorDomain"; + EXPECT_TRUE(tv1->domain() != old_tensor_domain) + << "Mutation did not change the TensorDomain"; - EXPECT_FALSE(tv1->domain()->additionalIDs().empty())<< "Mutation did not preserve additional IDs"; + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()) + << "Mutation did not preserve additional IDs"; } - } // namespace nvfuser From 11c43c41a398bfc749f4f6c0b341ffeb8fd98304 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 13:11:54 -0500 Subject: [PATCH 05/25] clang-tidy of TensorDomain ctor --- csrc/ir/nodes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index f88618e03e5..8758a8022cb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3086,7 +3086,7 @@ TensorDomain::TensorDomain( allocation_domain_(std::move(allocation_domain)), loop_domain_(std::move(loop_domain)), initial_loop_domain_(loop_domain_), - additional_ids_(additional_ids), + additional_ids_(std::move(additional_ids)), contiguity_( contiguity.empty() ? 
getContiguityFilledWith(maybeAllocation(), false) : std::move(contiguity)) { From 64be2c7a18cdb7d73e9134c39d0a354d674d6cdf Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 08:48:02 -0500 Subject: [PATCH 06/25] Remove debugging comment --- tests/cpp/test_matmul.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index ec087e8fd6c..c4fed527a42 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3867,10 +3867,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { // NOTE: since in this case we do not have "proper" broadcast in the inputs, // we cannot simply propagate transforms to the operands. Instead, we // propagate forward to the outputs and manually schedule the smem operands. - - // ComputeAtMap in this case finds bS15 and bS16. They are in the loop domain - // at this point - scheduler_utils::BoundedDirectionalTransformPropagator::forward( tv2, -1, From 05d5ca42b95d1bceb0794f35496c466cdeea02c3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 15:09:16 -0500 Subject: [PATCH 07/25] [DO NOT MERGE] added throw to test impact on existing tests --- csrc/device_lower/analysis/predicate_elimination.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 3159de8a5db..78c88bfea9a 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -242,6 +242,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
if (index_ids_.find(consumer_id) == index_ids_.end()) { + NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); return false; } needs_predicate_ = false; From 59745676a37d983d748c88626f8e55c101f7280d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 21:56:32 -0500 Subject: [PATCH 08/25] Refactor getting indexing IDs into utility This updates the NVF_THROW check to rule out the BroadcastOp case. --- .../analysis/predicate_elimination.cpp | 45 ++-------- csrc/device_lower/utils.cpp | 84 +++++++++++++++++++ csrc/device_lower/utils.h | 9 ++ 3 files changed, 101 insertions(+), 37 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 78c88bfea9a..4fd7a4a42b1 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -183,42 +183,11 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { return true; } - auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) - .getReplay(); - - // Find all IterDomains involved in index expressions - // TODO: do we need to find logical IDs in producer that are involved in - // its allocation domain too? 
- std::vector mapped_root_vals, loop_vals; - for (IterDomain* id : consumer->getRootDomain()) { - if (c2p.find(id) != c2p.end()) { - mapped_root_vals.push_back(id); - } - } - for (IterDomain* id : consumer->getLoopDomain()) { - loop_vals.push_back(id); - } - - // Collect all IterDomains along path instead of Exprs - std::unordered_set index_ids; - for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( - mapped_root_vals, - loop_vals, - /*require_all_to_visited=*/false)) { - for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { - index_ids.insert(id); - } - } - for (Val* v : expr->outputs()) { - if (auto* id = dynamic_cast(v)) { - index_ids.insert(id); - } - } - } - ProducerConsumerPairAnalyzer analyzer(c2p, index_ids); + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = + lower_utils::getIndexIDs(producer, consumer, &c2p); + ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -238,10 +207,12 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { - // TODO: check that this consumer_id is actually involved in indexing the + // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
- if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!consumer_id->isBroadcast() && + index_ids_.find(consumer_id) == index_ids_.end()) { + // TODO: Remove this line and the isBroadcast check in the condition above NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); return false; } diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index f77a7520e9a..d0a5b6488c2 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2004,6 +2004,90 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } +std::pair, std::unordered_set> +getIndexIDs( + TensorView* producer, + TensorView* consumer, + const std::unordered_map* c2p) { + // First we find the consumer root IDs that map to the producer + std::unordered_map c2p_tmp; + if (c2p == nullptr) { + auto c2p_tmp = + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + c2p = &c2p_tmp; + } + // Track the IDs involved in indexing in both producer and consumer + std::unordered_set consumer_indexing_ids; + std::unordered_set producer_indexing_ids; + for (IterDomain* id : consumer->getMaybeRootDomain()) { + auto it = c2p->find(id); + if (it == c2p->end()) { + continue; + } + // These are the immediately mapped consumer root and producer logical + // IDs. This is a starting point for our later traversals, which will fill + // these sets out. + consumer_indexing_ids.insert(it->first); + producer_indexing_ids.insert(it->second); + } + + // Now traverse from the starting set (which, as noted above is a subset of + // either the producer logical or consumer root) to the target which is + // either the producer allocation domain or the consumer loop domain. These + // are the IDs that will actually affect indexing. Any other IDs can be + // skipped. 
+ auto traverse = [](std::unordered_set& indexing_ids, + const std::vector& start_domain, + const std::vector& target_domain) { + for (auto [expr, dir] : IRBFS::getExprsBetween( + {start_domain.begin(), start_domain.end()}, + {target_domain.begin(), target_domain.end()}, + /*require_all_to_visited=*/false)) { + // If there are any indexing IDs in the inputs, count all outputs as + // indexing IDs + if (dir == Direction::Forward) { + if (std::any_of( + expr->inputs().begin(), expr->inputs().end(), [&](Val* input) { + auto* id = dynamic_cast(input); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : expr->outputs()) { + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); + } + } + } + } else if (dir == Direction::Backward) { + if (std::any_of( + expr->outputs().begin(), + expr->outputs().end(), + [&](Val* output) { + auto* id = dynamic_cast(output); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : expr->inputs()) { + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); + } + } + } + } else { + NVF_THROW("Found unexpected direction"); + } + } + }; + traverse( + producer_indexing_ids, + /*start_domain=*/producer->getLogicalDomain(), + /*target_domain=*/producer->getMaybeAllocationDomain()); + traverse( + consumer_indexing_ids, + /*start_domain=*/consumer->getMaybeRootDomain(), + /*target_domain=*/consumer->getLoopDomain()); + + return {producer_indexing_ids, consumer_indexing_ids}; +} + } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 2f53e7ed0ae..f78208243e8 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -19,6 +19,7 @@ #include #include +#include "logical_domain_map.h" // Provides utilities for dealing with nested ForLoop and IfThenElse scopes @@ -379,6 +380,14 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); +//! 
Get the set of IterDomains on the shortest path from the producer allocation +//! domain to the consumer loop domain. +std::pair, std::unordered_set> +getIndexIDs( + TensorView* producer, + TensorView* consumer, + const std::unordered_map* c2p = nullptr); + } // namespace lower_utils } // namespace nvfuser From e0ad380f7bbefa68dc0bdcfe07adda7511a12826 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 07:18:26 -0500 Subject: [PATCH 09/25] Put back accidentally removed replay --- csrc/device_lower/analysis/predicate_elimination.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 4fd7a4a42b1..34296fdf006 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -183,8 +183,10 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { return true; } + auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + .getReplay(); [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = lower_utils::getIndexIDs(producer, consumer, &c2p); ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); @@ -213,7 +215,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { if (!consumer_id->isBroadcast() && index_ids_.find(consumer_id) == index_ids_.end()) { // TODO: Remove this line and the isBroadcast check in the condition above - NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); + NVF_THROW("FOUND UNEXPECTED PATH IN TEST ", consumer_id->toString()); return false; } needs_predicate_ = false; From 3c2631fbc202e805cc6ba3c09ecbeb8042d42d20 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 09:25:52 -0500 Subject: [PATCH 10/25] Add skipped root->logical mappings 
in c2p --- csrc/device_lower/analysis/predicate_elimination.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 34296fdf006..fb1911a8425 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -185,8 +185,13 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + BestEffortReplay::replayPasC( + producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); + for (auto [c, p] : pairwise_map.mapConsumerToProducer()) { + // replayPasC skips mapping after the compute at position + c2p[c] = p; + } [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = lower_utils::getIndexIDs(producer, consumer, &c2p); ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); From 3342e77bfa5e45c42acf41f2dfeb9f41af9c25f4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 09:26:41 -0500 Subject: [PATCH 11/25] Simplify getIndexIDs --- csrc/device_lower/utils.cpp | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index d0a5b6488c2..3fccec517ac 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2045,32 +2045,24 @@ getIndexIDs( /*require_all_to_visited=*/false)) { // If there are any indexing IDs in the inputs, count all outputs as // indexing IDs - if (dir == Direction::Forward) { - if (std::any_of( - expr->inputs().begin(), expr->inputs().end(), [&](Val* input) { - auto* id = dynamic_cast(input); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : expr->outputs()) { + const auto processExpr = [&indexing_ids]( + const 
std::vector& prev_vals, + const std::vector& next_vals) { + if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { + auto* id = dynamic_cast(prev); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : next_vals) { if (auto* id = dynamic_cast(v)) { indexing_ids.insert(id); } } } + }; + if (dir == Direction::Forward) { + processExpr(expr->inputs(), expr->outputs()); } else if (dir == Direction::Backward) { - if (std::any_of( - expr->outputs().begin(), - expr->outputs().end(), - [&](Val* output) { - auto* id = dynamic_cast(output); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } - } + processExpr(expr->outputs(), expr->inputs()); } else { NVF_THROW("Found unexpected direction"); } From ee5329fc21b4f8c995b5788baaa5a4f84199886d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 12:32:13 -0500 Subject: [PATCH 12/25] Remove NVF_THROW and disable matmul test for codediff --- csrc/device_lower/analysis/predicate_elimination.cpp | 5 +---- tests/cpp/test_matmul.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index fb1911a8425..dd357a62608 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -217,10 +217,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
- if (!consumer_id->isBroadcast() && - index_ids_.find(consumer_id) == index_ids_.end()) { - // TODO: Remove this line and the isBroadcast check in the condition above - NVF_THROW("FOUND UNEXPECTED PATH IN TEST ", consumer_id->toString()); + if (index_ids_.find(consumer_id) == index_ids_.end()) { return false; } needs_predicate_ = false; diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index c4fed527a42..2f6709d1e18 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3802,7 +3802,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { } // Test scheduling a Hopper matmul where the operands are 2D -TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { +TEST_F(HopperMatmulTest, DISABLED_HSH_NT_128BSwizzle_NoBroadcasts) { Fusion fusion; FusionGuard fg(&fusion); From 0cf29e5126f3752e6bcef703cd2869e6ff5e6b7a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 15:37:48 -0500 Subject: [PATCH 13/25] Enable test codediff passed --- tests/cpp/test_matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 2f6709d1e18..c4fed527a42 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3802,7 +3802,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { } // Test scheduling a Hopper matmul where the operands are 2D -TEST_F(HopperMatmulTest, DISABLED_HSH_NT_128BSwizzle_NoBroadcasts) { +TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { Fusion fusion; FusionGuard fg(&fusion); From 381035fcace11805df0c4b9a813f53f29d533d93 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 12:24:23 -0500 Subject: [PATCH 14/25] Avoid processing non-indexing inputs to Merge If we have a non-indexing ID id1 and an indexing ID id2 and we merge them, we should only need to process id2. 
--- csrc/device_lower/analysis/predicate_elimination.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index dd357a62608..ce0cc1d5c5a 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -226,6 +226,9 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { + if (index_ids_.find(consumer_id) == index_ids_.end()) { + return; + } // The traversal should have ended if needs_predicate_ was true NVF_ERROR(!needs_predicate_); From 732b8738f9349757ad55c4ad4b047b5ce6ec4e87 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 13:21:41 -0500 Subject: [PATCH 15/25] Remove declaration that shadowed c2p_tmp This doesn't affect this PR but has an impact on the inlining use case --- csrc/device_lower/utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 3fccec517ac..c52267195ec 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2012,7 +2012,7 @@ getIndexIDs( // First we find the consumer root IDs that map to the producer std::unordered_map c2p_tmp; if (c2p == nullptr) { - auto c2p_tmp = + c2p_tmp = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); c2p = &c2p_tmp; } From 9feb8f8ca2f552eb142efb4ff7c40b5d38eb76d8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 13:25:47 -0500 Subject: [PATCH 16/25] Update in light of #3452 --- csrc/device_lower/utils.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index c52267195ec..75898dc0a21 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2040,9 +2040,10 @@ getIndexIDs( const std::vector& start_domain, const std::vector& 
target_domain) { for (auto [expr, dir] : IRBFS::getExprsBetween( - {start_domain.begin(), start_domain.end()}, - {target_domain.begin(), target_domain.end()}, - /*require_all_to_visited=*/false)) { + {start_domain.begin(), start_domain.end()}, + {target_domain.begin(), target_domain.end()}, + /*require_all_to_visited=*/false) + .first) { // If there are any indexing IDs in the inputs, count all outputs as // indexing IDs const auto processExpr = [&indexing_ids]( From 6f451f7ab294a48d551f421c6d07f534fec0d295 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 27 Nov 2024 09:36:57 -0500 Subject: [PATCH 17/25] Only check index IDs for MmaOp --- .../analysis/predicate_elimination.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index ce0cc1d5c5a..3f7e5cafed9 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -192,8 +192,14 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // replayPasC skips mapping after the compute at position c2p[c] = p; } - [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = - lower_utils::getIndexIDs(producer, consumer, &c2p); + std::unordered_set consumer_index_ids; + if (consumer->definition()->isA()) { + // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill + // it for MmaOp for now in order to limit our changes to the only op that + // currently requires this analysis. + consumer_index_ids = + lower_utils::getIndexIDs(producer, consumer, &c2p).second; + } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); for (auto id : consumer->getLoopDomain()) { @@ -217,7 +223,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Check that this consumer_id is actually involved in indexing the // producer. 
If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. - if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!index_ids_.empty() && + index_ids_.find(consumer_id) == index_ids_.end()) { return false; } needs_predicate_ = false; @@ -226,7 +233,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { - if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!index_ids_.empty() && + index_ids_.find(consumer_id) == index_ids_.end()) { return; } // The traversal should have ended if needs_predicate_ was true From 9fb9aad9d9681142f79f0dc588f9eaf13a99324e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 3 Dec 2024 15:19:37 -0500 Subject: [PATCH 18/25] Simplify utility to lower_utils::getIdsBetween --- .../analysis/predicate_elimination.cpp | 19 ++-- csrc/device_lower/utils.cpp | 96 +++++-------------- csrc/device_lower/utils.h | 15 ++- 3 files changed, 44 insertions(+), 86 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 3f7e5cafed9..e89a850451e 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -188,17 +188,24 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC( producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); - for (auto [c, p] : pairwise_map.mapConsumerToProducer()) { - // replayPasC skips mapping after the compute at position - c2p[c] = p; - } std::unordered_set consumer_index_ids; if (consumer->definition()->isA()) { // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill // it for MmaOp for now in order to limit our changes to the only op that // currently requires this analysis. 
- consumer_index_ids = - lower_utils::getIndexIDs(producer, consumer, &c2p).second; + + // We flow from mapped IDs to the consumer's loop domain + std::vector mapped_ids; + auto root2logical = pairwise_map.mapConsumerToProducer(); + for (IterDomain* id : consumer->getMaybeRootDomain()) { + if (root2logical.find(id) != root2logical.end()) { + mapped_ids.push_back(id); + } + } + // This set will omit loop IDs that are not mapped to the producer, such + // as N dimensions when the producer is an A operand without broadcasts. + consumer_index_ids = lower_utils::getIdsBetween( + /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index f4ea1c46596..52620dbd527 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,81 +1987,33 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::pair, std::unordered_set> -getIndexIDs( - TensorView* producer, - TensorView* consumer, - const std::unordered_map* c2p) { - // First we find the consumer root IDs that map to the producer - std::unordered_map c2p_tmp; - if (c2p == nullptr) { - c2p_tmp = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - c2p = &c2p_tmp; - } - // Track the IDs involved in indexing in both producer and consumer - std::unordered_set consumer_indexing_ids; - std::unordered_set producer_indexing_ids; - for (IterDomain* id : consumer->getMaybeRootDomain()) { - auto it = c2p->find(id); - if (it == c2p->end()) { - continue; - } - // These are the immediately mapped consumer root and producer logical - // IDs. This is a starting point for our later traversals, which will fill - // these sets out. 
- consumer_indexing_ids.insert(it->first); - producer_indexing_ids.insert(it->second); - } - - // Now traverse from the starting set (which, as noted above is a subset of - // either the producer logical or consumer root) to the target which is - // either the producer allocation domain or the consumer loop domain. These - // are the IDs that will actually affect indexing. Any other IDs can be - // skipped. - auto traverse = [](std::unordered_set& indexing_ids, - const std::vector& start_domain, - const std::vector& target_domain) { - for (auto [expr, dir] : IRBFS::getExprsBetween( - {start_domain.begin(), start_domain.end()}, - {target_domain.begin(), target_domain.end()}, - /*require_all_to_visited=*/false) - .first) { - // If there are any indexing IDs in the inputs, count all outputs as - // indexing IDs - const auto processExpr = [&indexing_ids]( - const std::vector& prev_vals, - const std::vector& next_vals) { - if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { - auto* id = dynamic_cast(prev); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : next_vals) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } +std::unordered_set getIdsBetween( + const std::vector& from, + const std::vector& to) { + std::unordered_set ids{from.begin(), from.end()}; + for (auto [expr, dir] : getExprsBetween( + {from.begin(), from.end()}, + {to.begin(), to.end()}, + /*require_all_to_visited=*/false) + .first) { + const std::vector& prev_vals = + dir == Direction::Forward ? expr->inputs() : expr->outputs(); + const std::vector& next_vals = + dir == Direction::Forward ? 
expr->outputs() : expr->inputs(); + // If there are _any_ IDs that were found in prev_vals then we count all the + // next vals as found + if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { + auto* id = dynamic_cast(prev); + return id && ids.count(id) != 0; + })) { + for (Val* v : next_vals) { + if (auto* id = dynamic_cast(v)) { + ids.insert(id); } - }; - if (dir == Direction::Forward) { - processExpr(expr->inputs(), expr->outputs()); - } else if (dir == Direction::Backward) { - processExpr(expr->outputs(), expr->inputs()); - } else { - NVF_THROW("Found unexpected direction"); } } - }; - traverse( - producer_indexing_ids, - /*start_domain=*/producer->getLogicalDomain(), - /*target_domain=*/producer->getMaybeAllocationDomain()); - traverse( - consumer_indexing_ids, - /*start_domain=*/consumer->getMaybeRootDomain(), - /*target_domain=*/consumer->getLoopDomain()); - - return {producer_indexing_ids, consumer_indexing_ids}; + } + return ids; } } // namespace lower_utils diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index f187cccb1f9..2ddf0463646 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -19,7 +19,6 @@ #include #include -#include "logical_domain_map.h" // Provides utilities for dealing with nested ForLoop and IfThenElse scopes @@ -374,13 +373,13 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get the set of IterDomains on the shortest path from the producer allocation -//! domain to the consumer loop domain. -std::pair, std::unordered_set> -getIndexIDs( - TensorView* producer, - TensorView* consumer, - const std::unordered_map* c2p = nullptr); +//! Get a set of IterDomains in TV between two given domains +//! (inclusive). If `from` is provided, IDs without any producers in `from` will +//! be omitted. +//! 
TODO: example: +std::unordered_set getIdsBetween( + const std::vector& from, + const std::vector& to); } // namespace lower_utils From e623561ee5b0a1d61d835ed44a1351c2d1a9c61c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 3 Dec 2024 16:01:26 -0500 Subject: [PATCH 19/25] Rename to getIdsAlongPathBetween and add example to comment --- .../analysis/predicate_elimination.cpp | 2 +- csrc/device_lower/utils.cpp | 2 +- csrc/device_lower/utils.h | 23 +++++++++++++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index e89a850451e..8e7a18ed095 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -204,7 +204,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } // This set will omit loop IDs that are not mapped to the producer, such // as N dimensions when the producer is an A operand without broadcasts. 
- consumer_index_ids = lower_utils::getIdsBetween( + consumer_index_ids = lower_utils::getIdsAlongPathBetween( /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 52620dbd527..0d16621daa3 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,7 +1987,7 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::unordered_set getIdsBetween( +std::unordered_set getIdsAlongPathBetween( const std::vector& from, const std::vector& to) { std::unordered_set ids{from.begin(), from.end()}; diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 2ddf0463646..b2f71730609 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -373,11 +373,24 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get a set of IterDomains in TV between two given domains -//! (inclusive). If `from` is provided, IDs without any producers in `from` will -//! be omitted. -//! TODO: example: -std::unordered_set getIdsBetween( +//! Get a set of IterDomains on a path between two given domains (inclusive). +//! +//! For example: +//! +//! i3 = merge(i0, i1) +//! i4, i5 = split(i3) +//! +//! If we are given +//! from = [ i0, i2 ] +//! to = [ i4 ] +//! This will return [ i0, i2, i3, i4, i5 ] +//! +//! If we are given +//! from = [ i4, i5 ] +//! to = [ i1 ] +//! This will return [ i4, i5, i3, i0, i1 ] +//! 
+std::unordered_set getIdsAlongPathBetween( const std::vector& from, const std::vector& to); From 0660f8c64d973efcc9c2ab10d593efe8bd186a0e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 4 Dec 2024 09:27:32 -0500 Subject: [PATCH 20/25] Use loop group traversal from alloc to loop --- .../analysis/predicate_elimination.cpp | 93 ++++++++++++++----- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 8e7a18ed095..9dc40160c39 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "val_graph_visitor.h" namespace nvfuser { @@ -188,26 +189,72 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC( producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); - std::unordered_set consumer_index_ids; + + // The variables graph and alloc_to_loop_groups are used to check whether we + // need to check a particular consumer ID. The alloc_to_loop_groups set + // contains ValGroups along a shortest path in the loop graph from + // non-trivial dimensions in the allocation domain of the producer to the + // consumer's loop domain. Other domains might exist in the loop domain of + // the consumer: for example, for MmaOp we sometimes do not map the N + // dimension of the output logical domain to any ID in the A operand. We + // use this set to avoid performing unnecessary checks on these types of + // irrelevant consumer IDs. + // + // NOTE: if graph is nullptr, it will be + // ignored. We only fill it for MmaOp for now in order to limit our changes + // to the only op that currently requires this analysis. 
+ const ValGraph* graph = nullptr; + std::unordered_set alloc_to_loop_groups; if (consumer->definition()->isA()) { - // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill - // it for MmaOp for now in order to limit our changes to the only op that - // currently requires this analysis. + // Fill ValGraph and grab all ValGroups on path from producer alloc to + // consumer loop. + IdModel& id_model = GpuLower::current()->idModel(); + id_model.maybeBuildGraph(IdMappingMode::LOOP); + const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); // We flow from mapped IDs to the consumer's loop domain - std::vector mapped_ids; - auto root2logical = pairwise_map.mapConsumerToProducer(); - for (IterDomain* id : consumer->getMaybeRootDomain()) { - if (root2logical.find(id) != root2logical.end()) { - mapped_ids.push_back(id); + ValGroups alloc_groups; + for (IterDomain* id : producer->getMaybeAllocationDomain()) { + if (!id->isBroadcast() && !id->isReduction()) { + alloc_groups.pushBack(loop_graph.toGroup(id)); + } + } + ValGroups loop_groups; + for (IterDomain* id : consumer->getLoopDomain()) { + loop_groups.pushBack(loop_graph.toGroup(id)); + } + + const auto [path, all_reached] = ValGraphBFS::getExprGroupsBetween( + loop_graph, + /*from=*/alloc_groups, + /*to=*/loop_groups, + /*require_all_to_visited=*/false); + + if (!all_reached) { + // If we reached all loop groups, there's no need to perform this check + graph = &loop_graph; + alloc_to_loop_groups.insert(alloc_groups.begin(), alloc_groups.end()); + for (const auto& [expr_group, direction] : path) { + const std::vector prev_groups = + direction == Direction::Forward + ? loop_graph.inputGroups(expr_group) + : loop_graph.outputGroups(expr_group); + const std::vector next_groups = + direction == Direction::Forward + ? 
loop_graph.outputGroups(expr_group) + : loop_graph.inputGroups(expr_group); + if (std::any_of( + prev_groups.begin(), + prev_groups.end(), + [&alloc_to_loop_groups](const ValGroup& group) { + return alloc_to_loop_groups.count(group) > 0; + })) { + alloc_to_loop_groups.insert(next_groups.begin(), next_groups.end()); + } } } - // This set will omit loop IDs that are not mapped to the producer, such - // as N dimensions when the producer is an A operand without broadcasts. - consumer_index_ids = lower_utils::getIdsAlongPathBetween( - /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } - ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); + ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -221,17 +268,18 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { private: ProducerConsumerPairAnalyzer( const std::unordered_map& c2p, - const std::unordered_set index_ids) - : c2p_(c2p), index_ids_(index_ids) {} + const ValGraph* graph, + const std::unordered_set alloc_to_loop_groups) + : c2p_(c2p), graph_(graph), alloc_to_loop_groups_(alloc_to_loop_groups) {} // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in - // the broadcast graph, then we can skip processing it. - if (!index_ids_.empty() && - index_ids_.find(consumer_id) == index_ids_.end()) { + // the indexing graph, then we can skip processing it. 
+ if (graph_ != nullptr && + alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) { return false; } needs_predicate_ = false; @@ -240,8 +288,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { - if (!index_ids_.empty() && - index_ids_.find(consumer_id) == index_ids_.end()) { + if (graph_ != nullptr && + alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) { return; } // The traversal should have ended if needs_predicate_ was true @@ -328,7 +376,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { //! BestEffort map from consumer IDs to producer IDs const std::unordered_map& c2p_; bool needs_predicate_ = false; - std::unordered_set index_ids_; + const ValGraph* graph_ = nullptr; + const std::unordered_set alloc_to_loop_groups_; }; class PredicateChcker : public IterVisitor { From 6e17d11232529cda7c3779283fad2a902d473a68 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 07:52:27 -0500 Subject: [PATCH 21/25] Remove getIdsAlongPathBetween --- csrc/device_lower/utils.cpp | 29 ----------------------------- csrc/device_lower/utils.h | 21 --------------------- 2 files changed, 50 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 0d16621daa3..0ff59a27b26 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,35 +1987,6 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::unordered_set getIdsAlongPathBetween( - const std::vector& from, - const std::vector& to) { - std::unordered_set ids{from.begin(), from.end()}; - for (auto [expr, dir] : getExprsBetween( - {from.begin(), from.end()}, - {to.begin(), to.end()}, - /*require_all_to_visited=*/false) - .first) { - const std::vector& prev_vals = - dir == Direction::Forward ? expr->inputs() : expr->outputs(); - const std::vector& next_vals = - dir == Direction::Forward ? 
expr->outputs() : expr->inputs(); - // If there are _any_ IDs that were found in prev_vals then we count all the - // next vals as found - if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { - auto* id = dynamic_cast(prev); - return id && ids.count(id) != 0; - })) { - for (Val* v : next_vals) { - if (auto* id = dynamic_cast(v)) { - ids.insert(id); - } - } - } - } - return ids; -} - } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index b2f71730609..fa62d3b3f76 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -373,27 +373,6 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get a set of IterDomains on a path between two given domains (inclusive). -//! -//! For example: -//! -//! i3 = merge(i0, i1) -//! i4, i5 = split(i3) -//! -//! If we are given -//! from = [ i0, i2 ] -//! to = [ i4 ] -//! This will return [ i0, i2, i3, i4, i5 ] -//! -//! If we are given -//! from = [ i4, i5 ] -//! to = [ i1 ] -//! This will return [ i4, i5, i3, i0, i1 ] -//! 
-std::unordered_set getIdsAlongPathBetween( - const std::vector& from, - const std::vector& to); - } // namespace lower_utils } // namespace nvfuser From ee6a89abdf797cfacbfd9999b84c6a60295c74fa Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 08:27:51 -0500 Subject: [PATCH 22/25] Use TensorIndexer and getValsBetween --- .../analysis/predicate_elimination.cpp | 51 ++++++------------- csrc/device_lower/lower2device.cpp | 10 ++-- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 9dc40160c39..c11efffde04 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "id_model/utils.h" #include "val_graph_visitor.h" namespace nvfuser { @@ -208,51 +209,29 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { if (consumer->definition()->isA()) { // Fill ValGraph and grab all ValGroups on path from producer alloc to // consumer loop. 
- IdModel& id_model = GpuLower::current()->idModel(); - id_model.maybeBuildGraph(IdMappingMode::LOOP); - const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + + const IdModel& id_model = GpuLower::current()->idModel(); + graph = &GpuLower::current()->tensorIndexer().traversalGraph(); // We flow from mapped IDs to the consumer's loop domain - ValGroups alloc_groups; + std::vector alloc_groups; for (IterDomain* id : producer->getMaybeAllocationDomain()) { + id = getLoopPromotion(id, id_model); if (!id->isBroadcast() && !id->isReduction()) { - alloc_groups.pushBack(loop_graph.toGroup(id)); + alloc_groups.push_back(graph->toGroup(id)); } } - ValGroups loop_groups; + std::vector loop_groups; for (IterDomain* id : consumer->getLoopDomain()) { - loop_groups.pushBack(loop_graph.toGroup(id)); + id = getLoopPromotion(id, id_model); + loop_groups.push_back(graph->toGroup(id)); } - const auto [path, all_reached] = ValGraphBFS::getExprGroupsBetween( - loop_graph, - /*from=*/alloc_groups, - /*to=*/loop_groups, - /*require_all_to_visited=*/false); - - if (!all_reached) { - // If we reached all loop groups, there's no need to perform this check - graph = &loop_graph; - alloc_to_loop_groups.insert(alloc_groups.begin(), alloc_groups.end()); - for (const auto& [expr_group, direction] : path) { - const std::vector prev_groups = - direction == Direction::Forward - ? loop_graph.inputGroups(expr_group) - : loop_graph.outputGroups(expr_group); - const std::vector next_groups = - direction == Direction::Forward - ? 
loop_graph.outputGroups(expr_group) - : loop_graph.inputGroups(expr_group); - if (std::any_of( - prev_groups.begin(), - prev_groups.end(), - [&alloc_to_loop_groups](const ValGroup& group) { - return alloc_to_loop_groups.count(group) > 0; - })) { - alloc_to_loop_groups.insert(next_groups.begin(), next_groups.end()); - } - } - } + std::vector indexing_groups = + getValsBetween(alloc_groups, loop_groups, *graph); + + alloc_to_loop_groups.insert( + indexing_groups.begin(), indexing_groups.end()); } ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups); diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 79652bc67c5..4ce0cb10546 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -570,11 +570,6 @@ void GpuLower::analysis(Fusion* fusion) { nonDivisibleSplitInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); - // Detects all exprssions that don't need predicates. Depends on - // nonDivisibleSplitInfo. - pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); - circularBufferInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build circularBufferInfo"); @@ -589,6 +584,11 @@ void GpuLower::analysis(Fusion* fusion) { tensor_indexer_ = std::make_unique(*id_model_); } + // Detects all exprssions that don't need predicates. Depends on + // nonDivisibleSplitInfo. 
+ pred_elimination_ = std::make_unique(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); + consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap"); } From 2959d8899c72beba6a441f41a34712234a0ab74e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 08:41:29 -0500 Subject: [PATCH 23/25] Don't need to promote allocation IDs --- csrc/device_lower/analysis/predicate_elimination.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index c11efffde04..e79cc177c17 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -216,7 +216,6 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // We flow from mapped IDs to the consumer's loop domain std::vector alloc_groups; for (IterDomain* id : producer->getMaybeAllocationDomain()) { - id = getLoopPromotion(id, id_model); if (!id->isBroadcast() && !id->isReduction()) { alloc_groups.push_back(graph->toGroup(id)); } From c41dec0e3c4af995b49ac187b46b1016b1dc513b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:04:35 -0500 Subject: [PATCH 24/25] Update csrc/device_lower/analysis/predicate_elimination.cpp Co-authored-by: Naoya Maruyama --- csrc/device_lower/analysis/predicate_elimination.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index e79cc177c17..b44f7057f67 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -211,7 +211,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // consumer loop. 
const IdModel& id_model = GpuLower::current()->idModel(); - graph = &GpuLower::current()->tensorIndexer().traversalGraph(); + graph = &id_model.idGraph(TensorIndexer::traversalIndexType()); // We flow from mapped IDs to the consumer's loop domain std::vector alloc_groups; From 2693456b225a78a1a802a1977a50ed0411b586e5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 13:08:01 -0500 Subject: [PATCH 25/25] Revert reordering of passes, fix comment, fix typo. --- csrc/device_lower/analysis/predicate_elimination.cpp | 7 +++++-- csrc/device_lower/lower2device.cpp | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index b44f7057f67..b9b9d0e94ae 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -211,9 +211,12 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // consumer loop. const IdModel& id_model = GpuLower::current()->idModel(); - graph = &id_model.idGraph(TensorIndexer::traversalIndexType()); + graph = &id_model.idGraph(TensorIndexer::traversalGraphType()); - // We flow from mapped IDs to the consumer's loop domain + // We flow from the producer's allocation domain to the consumer's loop + // domain. Here we assume that producer->getMaybeAllocationDomain() + // returns the actual indexed IDs, which is not always the case in + // general. However, it is always the case for MmaOp. 
std::vector alloc_groups; for (IterDomain* id : producer->getMaybeAllocationDomain()) { if (!id->isBroadcast() && !id->isReduction()) { diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 4ce0cb10546..79652bc67c5 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -570,6 +570,11 @@ void GpuLower::analysis(Fusion* fusion) { nonDivisibleSplitInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); + // Detects all exprssions that don't need predicates. Depends on + // nonDivisibleSplitInfo. + pred_elimination_ = std::make_unique(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); + circularBufferInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build circularBufferInfo"); @@ -584,11 +589,6 @@ void GpuLower::analysis(Fusion* fusion) { tensor_indexer_ = std::make_unique(*id_model_); } - // Detects all exprssions that don't need predicates. Depends on - // nonDivisibleSplitInfo. - pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); - consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap"); }