Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only check actually used IDs in predicate elimination for MmaOp #3414

Merged
merged 31 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9bf2645
Move OptOutMutator tests to new file and add repro
jacobhinkle Nov 15, 2024
96dd201
Add additional_ids arg to big ctor
jacobhinkle Nov 15, 2024
c7c790b
Only check actually used IDs in predicate elimination
jacobhinkle Nov 15, 2024
11083a1
clang-format
jacobhinkle Nov 15, 2024
11c43c4
clang-tidy of TensorDomain ctor
jacobhinkle Nov 15, 2024
683ae1e
Merge branch 'mutator_preserve_additional_ids' into mma_predicate_eli…
jacobhinkle Nov 15, 2024
c29470f
Merge branch 'main' into mma_predicate_elimination
jacobhinkle Nov 15, 2024
64be2c7
Remove debugging comment
jacobhinkle Nov 18, 2024
a9cc7aa
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 18, 2024
05d5ca4
[DO NOT MERGE] added throw to test impact on existing tests
jacobhinkle Nov 18, 2024
5974567
Refactor getting indexing IDs into utility
jacobhinkle Nov 19, 2024
e0ad380
Put back accidentally removed replay
jacobhinkle Nov 19, 2024
3c2631f
Add skipped root->logical mappings in c2p
jacobhinkle Nov 19, 2024
3342e77
Simplify getIndexIDs
jacobhinkle Nov 19, 2024
ee5329f
Remove NVF_THROW and disable matmul test for codediff
jacobhinkle Nov 19, 2024
0cf29e5
Enable test
jacobhinkle Nov 19, 2024
381035f
Avoid processing non-indexing inputs to Merge
jacobhinkle Nov 20, 2024
732b873
Remove declaration that shadowed c2p_tmp
jacobhinkle Nov 20, 2024
60b23a3
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 20, 2024
9feb8f8
Update in light of #3452
jacobhinkle Nov 20, 2024
2f89ab4
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 20, 2024
6f451f7
Only check index IDs for MmaOp
jacobhinkle Nov 27, 2024
9742b5b
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Dec 3, 2024
9fb9aad
Simplify utility to lower_utils::getIdsBetween
jacobhinkle Dec 3, 2024
e623561
Rename to getIdsAlongPathBetween and add example to comment
jacobhinkle Dec 3, 2024
0660f8c
Use loop group traversal from alloc to loop
jacobhinkle Dec 4, 2024
6e17d11
Remove getIdsAlongPathBetween
jacobhinkle Dec 5, 2024
ee6a89a
Use TensorIndexer and getValsBetween
jacobhinkle Dec 5, 2024
2959d88
Don't need to promote allocation IDs
jacobhinkle Dec 5, 2024
c41dec0
Update csrc/device_lower/analysis/predicate_elimination.cpp
jacobhinkle Dec 5, 2024
2693456
Revert reordering of passes, fix comment, fix typo.
jacobhinkle Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions csrc/device_lower/analysis/predicate_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,37 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getReplay();

ProducerConsumerPairAnalyzer analyzer(c2p);
// Find all IterDomains involved in index expressions
// TODO: do we need to find logical IDs in producer that are involved in
// its allocation domain too?
std::vector<Val*> mapped_root_vals, loop_vals;
for (IterDomain* id : consumer->getRootDomain()) {
if (c2p.find(id) != c2p.end()) {
mapped_root_vals.push_back(id);
}
}
for (IterDomain* id : consumer->getLoopDomain()) {
loop_vals.push_back(id);
}

// Collect all IterDomains along path instead of Exprs
std::unordered_set<IterDomain*> index_ids;
for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween(
mapped_root_vals,
loop_vals,
/*require_all_to_visited=*/false)) {
for (Val* v : expr->inputs()) {
if (auto* id = dynamic_cast<IterDomain*>(v)) {
index_ids.insert(id);
}
}
for (Val* v : expr->outputs()) {
if (auto* id = dynamic_cast<IterDomain*>(v)) {
index_ids.insert(id);
}
}
}
ProducerConsumerPairAnalyzer analyzer(c2p, index_ids);

for (auto id : consumer->getLoopDomain()) {
if (analyzer.needsPredicate(id)) {
Expand All @@ -201,12 +231,19 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {

private:
ProducerConsumerPairAnalyzer(
const std::unordered_map<IterDomain*, IterDomain*>& c2p)
: c2p_(c2p) {}
const std::unordered_map<IterDomain*, IterDomain*>& c2p,
const std::unordered_set<IterDomain*> index_ids)
: c2p_(c2p), index_ids_(index_ids) {}

// Returns true if no out-of-bound accesses could occur with a
// producer
bool needsPredicate(IterDomain* consumer_id) {
// TODO: check that this consumer_id is actually involved in indexing the
// producer. If it is not connected to the producer allocation domain in
// the broadcast graph, then we can skip processing it.
if (index_ids_.find(consumer_id) == index_ids_.end()) {
return false;
jacobhinkle marked this conversation as resolved.
Show resolved Hide resolved
}
needs_predicate_ = false;
handle(consumer_id);
return needs_predicate_;
Expand Down Expand Up @@ -297,6 +334,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
//! BestEffort map from consumer IDs to producer IDs
const std::unordered_map<IterDomain*, IterDomain*>& c2p_;
bool needs_predicate_ = false;
std::unordered_set<IterDomain*> index_ids_;
};

class PredicateChcker : public IterVisitor {
Expand Down
205 changes: 177 additions & 28 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
}
}

// Check that mma op is not predicated.
class PredicateChecker : public kir::IrVisitor {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose instead of using this we could possibly use PredicatedChecker::isPredicated instead. I kept it here to mirror the check in the AmpereSwizzle test.

public:
using kir::IrVisitor::handle;
bool found_mma = false;

private:
void handle(kir::Asm* asm_) final {
#if IS_CPP20
if (!asm_->code().starts_with("mma") &&
!asm_->code().starts_with("wgmma")) {
#else
if (asm_->code().substr(0, 3) != "mma" &&
asm_->code().substr(0, 5) != "wgmma") {
#endif
return;
}
found_mma = true;
for (auto expr : scope_exprs_) {
NVF_CHECK(
!expr->isA<kir::IfThenElse>() ||
expr->as<kir::IfThenElse>()->predicate()->isTrivial(),
"MmaOp should't be predicated!",
" Get predicate ",
expr->as<kir::IfThenElse>()->predicate()->toInlineString());
}
}
};

// Matmul test for Ampere MMA: checking CTA Swizzles
TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
Expand Down Expand Up @@ -549,35 +578,9 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
runtime = 0;
}

// Check that mma op is not predicated. This is a regression test for
// https://github.com/NVIDIA/Fuser/issues/95
class PredicateChecker : public kir::IrVisitor {
public:
using kir::IrVisitor::handle;
bool found_mma = false;

private:
void handle(kir::Asm* asm_) final {
#if IS_CPP20
if (!asm_->code().starts_with("mma")) {
#else
if (asm_->code().substr(0, 3) != "mma") {
#endif
return;
}
found_mma = true;
for (auto expr : scope_exprs_) {
NVF_CHECK(
!expr->isA<kir::IfThenElse>() ||
expr->as<kir::IfThenElse>()->predicate()->isTrivial(),
"MmaOp should't be predicated!",
" Get predicate ",
expr->as<kir::IfThenElse>()->predicate()->toInlineString());
}
}
} pred_checker;

// This is a regression test for https://github.com/NVIDIA/Fuser/issues/95
GpuLower gpulw(&fusion);
PredicateChecker pred_checker;
pred_checker.handle(gpulw.run()->topLevelExprs());
ASSERT_TRUE(pred_checker.found_mma);
};
Expand Down Expand Up @@ -3798,4 +3801,150 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

// Test scheduling a Hopper matmul where the operands are 2D
TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is rather large, but since MmaOp is the only way I currently know how to trigger this behavior, I decided to just put the bulk of the test from #3406 here.

Fusion fusion;
FusionGuard fg(&fusion);

// constexpr int64_t M = 2048, N = 2048, K = 8192;
constexpr auto macro = MmaMacro::Hopper_64_256_16;
// constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]
constexpr auto swizzle = MmaInputSmemSwizzle::B128;
const auto dtype = DataType::Half;

constexpr int64_t stages = 1;
constexpr int64_t prefetch = 3;
const int64_t cta_m = 2 * getM(macro);
const int64_t cta_n = 1 * getN(macro);

auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // [K, M]
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // [K, N]
fusion.addInput(tv0);
fusion.addInput(tv1);

// The output is [M, N, K] (no reordering needed)
MmaOp::AxisMapping axis_mapping{.a_axes = {1, -1, 0}, .b_axes = {-1, 1, 0}};
auto tv2 =
fusedMultiplySum(tv0, tv1, /*axes=*/{-1}, /*init=*/nullptr, axis_mapping);

auto tv3 = castOp(DataType::Half, tv2);

fusion.addOutput(tv3);

auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
NVF_CHECK(
1 == mma_ops.size(),
"Invalid number of MmaOp instances in fusion definition, expected 1, got ",
mma_ops.size());
mma_ops.front()->setMacro(macro);

// gmem [K, M] x gmem [K, N] -mma-> register [M, N, rK]
// register [M, N, rK] -cast-> gmem [M, N]

auto tv0c = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv0c->setMemoryType(MemoryType::Shared);
auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv1c->setMemoryType(MemoryType::Shared);
auto tv3c = tv3->cacheBefore();

tv0c->broadcast(-1); // [K, M] -> [K, M, 1]
tv1c->broadcast(-2); // [K, N] -> [K, 1, N]

// gmem [K, M, 1] -TMA-> smem [K, M, 1]
// gmem [K, 1, N] -TMA-> smem [K, 1, N]
// smem [K, M, 1] x smem [K, 1, N] -mma-> register [M, N, rK]
// register [M, N, rK] -cast-> register [M, N] -set-> gmem [M, N]

// Create tiles
tv2->split(-3, cta_m);
tv2->split(-2, cta_n);
tv2->split(-1, getK(macro));
// [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki]
tv2->reorder({{-5, -3}, {-3, -2}});
tv2->axis(0)->parallelize(ParallelType::BIDy);
tv2->axis(1)->parallelize(ParallelType::BIDx);

// NOTE: since in this case we do not have "proper" broadcast in the inputs,
// we cannot simply propagate transforms to the operands. Instead, we
// propagate forward to the outputs and manually schedule the smem operands.
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
tv2,
-1,
{tv3},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

// Schedule operands
for (TensorView* tv : {tv0c, tv1c}) {
tv->reorder({{-3, -1}}); // [K, M, N] -> [M, N, K]
// NOTE: above axes are given in MNK order, but inputs are in KMN
tv->split(-3, cta_m);
tv->split(-2, cta_n);
tv->split(-1, getK(macro));
// [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki]
// [Ko, Ki, Mo, Mi, No, Ni] -> [Mo, No, Ko, Mi, Ni, Ki]
tv->reorder({{-5, -3}, {-3, -2}});
tv->axis(0)->parallelize(ParallelType::BIDy);
tv->axis(1)->parallelize(ParallelType::BIDx);
}

// [..., Mi, Ni, Ki] -> [..., Ni, Ki, Mi]
tv0c->reorder({{-3, -1}});
tv0c->applyMmaSwizzleForTMALoad(swizzle);
// [..., Mi, Ni, Ki] -> [..., Mi, Ki, Ni]
tv1c->reorder({{-1, -2}});
tv1c->applyMmaSwizzleForTMALoad(swizzle);

{
tv2->split(-3, getM(macro));
tv2->split(-2, getN(macro));
// [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki]
// -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
tv2->reorder({{-4, -3}});
tv2->merge(-5);
tv2->axis(-4)->parallelize(ParallelType::TIDy);
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
tv2,
-1,
{tv3},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
tv2->setAllocationDomain(s.as<IterDomain*>(), true);
tv2->axis(-1)->parallelize(ParallelType::Mma);
tv2->axis(-2)->parallelize(ParallelType::Mma);
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

for (auto tv : {tv3c, tv3}) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

inlineMost();

if (stages > 1) {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
}

// Test that predicate elimination works when the MmaOp's operands have no
// logical broadcasts
GpuLower gpulw(&fusion);
kir::Kernel* kernel = gpulw.run();
PredicateChecker pred_checker;
pred_checker.handle(kernel->topLevelExprs());
ASSERT_TRUE(pred_checker.found_mma);

// TODO: compile and run kernel once inlining is fixed
}

} // namespace nvfuser
Loading