Skip to content

Commit

Permalink
Allow inlining past loop broadcasts for MmaOp (#3416)
Browse files Browse the repository at this point in the history
Stacked on #3414 

This PR enables us to inline an MmaOp properly when its inputs are
missing broadcast dimensions. We do this by always allowing inlining
past loop broadcasts or their transforms. For example:
```
tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]
```
As long as the operation `foo` properly maps its arguments despite the
missing logical dimensions (as `MmaOp` does as of #3391), then we should
be able to fully inline this case because the loop broadcasts `bS5` and
`bS6` are imaginary in the sense that they don't impact indexing.
  • Loading branch information
jacobhinkle authored Dec 11, 2024
1 parent 2749296 commit f5f2ab5
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 9 deletions.
90 changes: 89 additions & 1 deletion csrc/scheduler/tools/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <device_lower/utils.h>
#include <id_model/utils.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <scheduler/tools/inlining.h>
#include <transform_iter.h>
#include <val_graph_visitor.h>

#include <utility>

Expand Down Expand Up @@ -193,6 +195,46 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
}
return producer->nDims();
} else {
std::optional<std::unordered_set<ValGroup>> loop_path_groups = std::nullopt;
if (consumer->definition()->isA<MmaOp>()) {
// We handle MmaOp specially here since it is currently the only operation
// for which we generate code (i.e. not SdpaFwdOp or SdpaBwdOp) that has
// some output dimensions that do not map to input dimensions. For this
// case, we need to identify potential inlined pairs each ID of which is
// not mapped at all to the other TensorView (see example below).

// Get ValGroups in loop domains of producer and consumer that are
// connected to _mapped_ IterDomains in the pairwise map.
//
// Note that for MmaOp, it would be sufficient to traverse from the
// producer loop to the consumer loop and identify when _either_ the
// consumer or producer ID is not mapped. Here we are instead traversing
// from mapped domains to both roots so that we can check that _both_
// consumer and producer IDs are not mapped. This is slightly safer and this
// symmetry might be handy in handling new ops that use this feature in
// the future.
std::vector<ValGroup> pairwise_mapped_groups;
for (auto [c_id, p_id] : PairwiseLogicalDomainMap(producer, consumer)
.mapConsumerToProducer()) {
pairwise_mapped_groups.push_back(inliningGraph().toGroup(c_id));
}
// We propagate toward the loop groups from both consumer and producer
std::vector<ValGroup> all_loop_groups;
for (IterDomain* id : producer->getLoopDomain()) {
all_loop_groups.push_back(inliningGraph().toGroup(id));
}
for (IterDomain* id : consumer->getLoopDomain()) {
all_loop_groups.push_back(inliningGraph().toGroup(id));
}
// getValsBetween does not require all target groups to be visited. This
// means the result contains the subset of both loop groups that we are
// looking for.
std::vector<ValGroup> group_path = getValsBetween<ValGraphBFS>(
pairwise_mapped_groups, all_loop_groups, inliningGraph());
loop_path_groups =
std::unordered_set<ValGroup>(group_path.begin(), group_path.end());
}

auto consumer_it = consumer->getLoopDomain().begin();
for (const auto producer_pos : c10::irange(producer->nDims())) {
auto p_id = producer->getLoopDomain().at(producer_pos);
Expand All @@ -211,8 +253,54 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
}

IterDomain* c_id = *consumer_it;

// We can inline past positions in which both producer and consumer are
// not connected to any mapped logical IterDomain pairs.
//
// For example, an MmaOp can be constructed as follows:
//
// tv0:
// root/logical: [ iS0, iS1 ]
// loop: [ iS0, bS7, iS1 ]
// tv1:
// root/logical: [ iS2, iS3 ]
// loop: [ bS8, iS2, iS3 ]
// tv2:
// root/logical/loop: [ iS4, iS5, rS6 ]
//
// iS4 maps to iS0 so when producer==tv0 we can inline past iS0. When
// producer==tv1, iS4 doesn't map to anything in tv1 and bS8 is a loop
// broadcast in that position so we inline past the first ID in that
// case also. Similarly, we inline past iS5, iS2, and bS7.
if (loop_path_groups.has_value()) {
bool p_id_connected =
loop_path_groups->count(inliningGraph().toGroup(p_id));
bool c_id_connected =
loop_path_groups->count(inliningGraph().toGroup(c_id));
NVF_ERROR(
p_id_connected ||
(consumer->definition()->isA<MmaOp>() && p_id->isBroadcast()),
"Expected unmapped producer id to be broadcast domain in MmaOp input but found ",
p_id->toString());

if (!p_id_connected && !c_id_connected) {
NVF_ERROR(
p_id->isBroadcast(),
"Unmapped producer ID must be a broadcast created in scheduling but found ",
p_id->toString());
++consumer_it;
continue;
}
}

if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
!isAllowedID(c_id, consumer, best_effort, true, false, true)) {
!isAllowedID(
c_id,
consumer,
best_effort,
/*allow_reduction=*/true,
/*allow_vectorize=*/false,
/*allow_unmappable=*/true)) {
return producer_pos;
}

Expand Down
34 changes: 26 additions & 8 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3657,7 +3657,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
const auto dtype = DataType::Half;

constexpr bool use_smem_epilogue = false;
constexpr bool use_warp_specialization = true;

constexpr int64_t stages = 4;
constexpr int64_t prefetch = 3;
Expand Down Expand Up @@ -3801,13 +3800,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {

inlineMost();

if (use_warp_specialization) {
tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
} else {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
}
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);

auto inputs =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
Expand Down Expand Up @@ -3948,8 +3942,32 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
}
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

{
// Check, using a copy, that improperly aligned axes are not inlined
Fusion tmp_fusion;
IrCloner ir_cloner = Fusion::copy(&fusion, &tmp_fusion);
FusionGuard tmp_fg(&tmp_fusion);
// [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
// Swap the No and Ko axes, but only in tv2, the mma output
// [Mo, Ko, No, Mio, Nio, Mii, Nii, Ki]
// This should mean the smem operands are now inlined at position 1 instead
// of 3
ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}});
inlineMost();
tmp_fusion.printMath();
ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}});
EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1);
// The outermost loop dim of tv1c is a broadcast Mo axis, so
// tv1c->inlineAt(1) does not inline past that axis and we wind up with
// compute-at position 0.
EXPECT_EQ(ir_cloner.clone(tv1c)->getComputeAtPosition(), 0);
}

inlineMost();

EXPECT_EQ(tv0c->getComputeAtPosition(), 3);
EXPECT_EQ(tv1c->getComputeAtPosition(), 3);

if (stages > 1) {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
Expand Down

0 comments on commit f5f2ab5

Please sign in to comment.