Only check actually used IDs in predicate elimination for MmaOp (#3414)

Now that we can define MmaOp with unbroadcasted inputs (see #3391), it
is possible to have ops for which some consumer loop IDs are not used at
all when indexing some producers.

For example, the output of `MmaOp` has three logical dimensions: M, N,
and K. These are scheduled by splitting, merging, and swizzling, so the
final consumer loop domain can contain things like a split of the
N dimension into two other IterDomains. Now if we look at the producer
A, it has logical size [M, K], so there is no N dimension at all. Our
current predicate elimination pass places a predicate on this operation
when the N dimension is symbolic and we can't prove that the producer is
parallelized the same way as the consumer in this dimension. However,
since N cannot affect the indexing of producer A, which has no N
dimension, we should skip checking these IterDomains.
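
Concretely, with the unbroadcasted inputs enabled by #3391, the missing
dimension is recorded in the `MmaOp` axis mapping instead of as a broadcast
IterDomain. The snippet below is taken from the test added in this PR (see
the diff further down); `-1` marks an output axis that has no counterpart in
an operand:
```c++
// Output logical order is [M, N, K]. Operand A is [K, M]: output M is A's
// axis 1, output N has no A axis (-1), and output K is A's axis 0.
// Operand B is [K, N], so it has no M axis.
MmaOp::AxisMapping axis_mapping{.a_axes = {1, -1, 0}, .b_axes = {-1, 1, 0}};
auto tv2 =
    fusedMultiplySum(tv0, tv1, /*axes=*/{-1}, /*init=*/nullptr, axis_mapping);
```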

This PR does this by performing a BFS in the indexing graph from the
consumer root IDs that map to producer logical IDs to the consumer's
leaf (loop) domain. Only IDs along that path are checked using the
existing conditions.
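
In the implementation the traversal runs over the indexing `ValGraph`, from
the producer's non-trivial allocation groups to the consumer's (promoted)
loop groups. A condensed sketch of the change to
`csrc/device_lower/analysis/predicate_elimination.cpp` (the full diff is
below):
```c++
// Source of the traversal: non-broadcast, non-reduction producer
// allocation IDs, mapped into the indexing graph.
std::vector<ValGroup> alloc_groups;
for (IterDomain* id : producer->getMaybeAllocationDomain()) {
  if (!id->isBroadcast() && !id->isReduction()) {
    alloc_groups.push_back(graph->toGroup(id));
  }
}
// Destination: the consumer's loop domain after loop promotion.
std::vector<ValGroup> loop_groups;
for (IterDomain* id : consumer->getLoopDomain()) {
  loop_groups.push_back(graph->toGroup(getLoopPromotion(id, id_model)));
}
// Only ValGroups on the path between the two can affect producer indexing,
// so only they are subjected to the existing predicate checks.
std::vector<ValGroup> indexing_groups =
    getValsBetween<ValGraphBFS>(alloc_groups, loop_groups, *graph);
```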

## Detailed example

In the test included in this PR, we have shared memory operand tensors
that are scheduled like this:
```
Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math

T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
```
Notice that in `T4_s` the loop broadcasts `bblockIdx.x33{1}` and
`bS34{256}` are not derived from the logical domain. Instead, they are
both products of a `Split` involving an original "loop broadcast",
although this is not currently shown in `fusion->printTransforms()`:
```
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
```
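
These loop broadcasts are introduced during scheduling rather than in the
fusion definition. In the test added by this PR (see the diff below), they
come from broadcasting the shared-memory cache of a TMA-loaded operand:
```c++
// From the test below: because the broadcast is added after caching, it
// appears in the cache tensor's loop/allocation domains but is not derived
// from its logical domain.
auto tv0c = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv0c->setMemoryType(MemoryType::Shared);
tv0c->broadcast(-1); // [K, M] -> [K, M, 1]
```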
In the predicate elimination pass with `T4_s` as producer and `T2_l` as
consumer, the consumer ID `iblockIdx.x19{( ceilDiv(i4, 256) )}` would
_normally_ map to a logical broadcast ID in `T4_s`, but with these
loop-domain broadcasts we do not have such a mapping. Before this PR
that would cause predication. This PR notices that `iblockIdx.x19{(
ceilDiv(i4, 256) )}` is not actually used for indexing the producer
`T4_s`, so we do not need to worry about out-of-bounds accesses in this
dimension.
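
In the implementation, the skip is a short guard at the top of
`needsPredicate` (see the diff below):
```c++
// Consumer IDs whose ValGroup is not on the path from the producer's
// allocation domain to the consumer's loop domain cannot affect producer
// indexing, so they never require a predicate for this pair.
if (graph_ != nullptr &&
    alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) {
  return false;
}
```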

Without this PR, if we remove the check at
https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37
then we generate the following code:
```c++
__global__ void nvfuser_none_f0_c0_r0_g0(      
    Tensor<__half, 2, 2> T0,                                                          
    Tensor<__half, 2, 2> T1,                                                          
    const __grid_constant__ TensorMap var0,                                           
    const __grid_constant__ TensorMap var1,                                           
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);

#pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;

    // ... load operands ...

    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..." /*... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}
```
After this PR, the predicate around the `wgmma` call is removed and the
`assertOnWarpOps` check can be restored.
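
Schematically (this is a sketch, not verbatim generated code), the hot loop
then issues the MMA unconditionally:
```c++
// Sketch of the post-PR codegen: the (b20 && (i24 < ...)) guard around the
// wgmma is gone; the loop structure is unchanged.
#pragma unroll 1
for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
  // ... load operands ...
  __syncthreads();
  asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %130, 0;\n"
      "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..." /* ... */);
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
}
```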

---------

Co-authored-by: Naoya Maruyama <[email protected]>
jacobhinkle and naoyam authored Dec 5, 2024
1 parent ecabd46 commit 736a541
Showing 2 changed files with 243 additions and 32 deletions.
70 changes: 66 additions & 4 deletions csrc/device_lower/analysis/predicate_elimination.cpp
@@ -18,6 +18,8 @@
#include <predicate_compute.h>
#include <transform_iter.h>
#include <transform_replay.h>
#include "id_model/utils.h"
#include "val_graph_visitor.h"

namespace nvfuser {

@@ -185,10 +187,55 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {

auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer);
auto c2p =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
BestEffortReplay::replayPasC(
producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map)
.getReplay();

ProducerConsumerPairAnalyzer analyzer(c2p);
// The variables graph and alloc_to_loop_groups are used to check whether we
// need to check a particular consumer ID. The alloc_to_loop_groups set
// contains ValGroups along a shortest path in the loop graph from
// non-trivial dimensions in the allocation domain of the producer to the
// consumer's loop domain. Other domains might exist in the loop domain of
// the consumer: for example, for MmaOp we sometimes do not map the N
// dimension of the output logical domain to any ID in the A operand. We
// use this set to avoid performing unnecessary checks on these types of
// irrelevant consumer IDs.
//
// NOTE: if graph is nullptr, it will be
// ignored. We only fill it for MmaOp for now in order to limit our changes
// to the only op that currently requires this analysis.
const ValGraph* graph = nullptr;
std::unordered_set<ValGroup> alloc_to_loop_groups;
if (consumer->definition()->isA<MmaOp>()) {
// Fill ValGraph and grab all ValGroups on path from producer alloc to
// consumer loop.

const IdModel& id_model = GpuLower::current()->idModel();
graph = &id_model.idGraph(TensorIndexer::traversalGraphType());

// We flow from the producer's allocation domain to the consumer's loop
// domain. Here we assume that producer->getMaybeAllocationDomain()
// returns the actual indexed IDs, which is not always the case in
// general. However, it is always the case for MmaOp.
std::vector<ValGroup> alloc_groups;
for (IterDomain* id : producer->getMaybeAllocationDomain()) {
if (!id->isBroadcast() && !id->isReduction()) {
alloc_groups.push_back(graph->toGroup(id));
}
}
std::vector<ValGroup> loop_groups;
for (IterDomain* id : consumer->getLoopDomain()) {
id = getLoopPromotion(id, id_model);
loop_groups.push_back(graph->toGroup(id));
}

std::vector<ValGroup> indexing_groups =
getValsBetween<ValGraphBFS>(alloc_groups, loop_groups, *graph);

alloc_to_loop_groups.insert(
indexing_groups.begin(), indexing_groups.end());
}
ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups);

for (auto id : consumer->getLoopDomain()) {
if (analyzer.needsPredicate(id)) {
@@ -201,18 +248,31 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {

private:
ProducerConsumerPairAnalyzer(
const std::unordered_map<IterDomain*, IterDomain*>& c2p)
: c2p_(c2p) {}
const std::unordered_map<IterDomain*, IterDomain*>& c2p,
const ValGraph* graph,
const std::unordered_set<ValGroup> alloc_to_loop_groups)
: c2p_(c2p), graph_(graph), alloc_to_loop_groups_(alloc_to_loop_groups) {}

// Returns true if out-of-bound accesses could occur when indexing the
// producer at consumer_id, i.e. if a predicate is needed
bool needsPredicate(IterDomain* consumer_id) {
// Check that this consumer_id is actually involved in indexing the
// producer. If it is not connected to the producer allocation domain in
// the indexing graph, then we can skip processing it.
if (graph_ != nullptr &&
alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) {
return false;
}
needs_predicate_ = false;
handle(consumer_id);
return needs_predicate_;
}

void handle(IterDomain* consumer_id) override {
if (graph_ != nullptr &&
alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) {
return;
}
// The traversal should have ended if needs_predicate_ was true
NVF_ERROR(!needs_predicate_);

@@ -297,6 +357,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
//! BestEffort map from consumer IDs to producer IDs
const std::unordered_map<IterDomain*, IterDomain*>& c2p_;
bool needs_predicate_ = false;
const ValGraph* graph_ = nullptr;
const std::unordered_set<ValGroup> alloc_to_loop_groups_;
};

class PredicateChcker : public IterVisitor {
205 changes: 177 additions & 28 deletions tests/cpp/test_matmul.cpp
@@ -443,6 +443,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
}
}

// Check that mma op is not predicated.
class PredicateChecker : public kir::IrVisitor {
public:
using kir::IrVisitor::handle;
bool found_mma = false;

private:
void handle(kir::Asm* asm_) final {
#if IS_CPP20
if (!asm_->code().starts_with("mma") &&
!asm_->code().starts_with("wgmma")) {
#else
if (asm_->code().substr(0, 3) != "mma" &&
asm_->code().substr(0, 5) != "wgmma") {
#endif
return;
}
found_mma = true;
for (auto expr : scope_exprs_) {
NVF_CHECK(
!expr->isA<kir::IfThenElse>() ||
expr->as<kir::IfThenElse>()->predicate()->isTrivial(),
"MmaOp should't be predicated!",
" Get predicate ",
expr->as<kir::IfThenElse>()->predicate()->toInlineString());
}
}
};

// Matmul test for Ampere MMA: checking CTA Swizzles
TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
@@ -542,35 +571,9 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
runtime = 0;
}

// Check that mma op is not predicated. This is a regression test for
// https://github.com/NVIDIA/Fuser/issues/95
class PredicateChecker : public kir::IrVisitor {
public:
using kir::IrVisitor::handle;
bool found_mma = false;

private:
void handle(kir::Asm* asm_) final {
#if IS_CPP20
if (!asm_->code().starts_with("mma")) {
#else
if (asm_->code().substr(0, 3) != "mma") {
#endif
return;
}
found_mma = true;
for (auto expr : scope_exprs_) {
NVF_CHECK(
!expr->isA<kir::IfThenElse>() ||
expr->as<kir::IfThenElse>()->predicate()->isTrivial(),
"MmaOp should't be predicated!",
" Get predicate ",
expr->as<kir::IfThenElse>()->predicate()->toInlineString());
}
}
} pred_checker;

// This is a regression test for https://github.com/NVIDIA/Fuser/issues/95
GpuLower gpulw(&fusion);
PredicateChecker pred_checker;
pred_checker.handle(gpulw.run()->topLevelExprs());
ASSERT_TRUE(pred_checker.found_mma);
};
@@ -3798,4 +3801,150 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

// Test scheduling a Hopper matmul where the operands are 2D
TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
Fusion fusion;
FusionGuard fg(&fusion);

// constexpr int64_t M = 2048, N = 2048, K = 8192;
constexpr auto macro = MmaMacro::Hopper_64_256_16;
// constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]
constexpr auto swizzle = MmaInputSmemSwizzle::B128;
const auto dtype = DataType::Half;

constexpr int64_t stages = 1;
constexpr int64_t prefetch = 3;
const int64_t cta_m = 2 * getM(macro);
const int64_t cta_n = 1 * getN(macro);

auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // [K, M]
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // [K, N]
fusion.addInput(tv0);
fusion.addInput(tv1);

// The output is [M, N, K] (no reordering needed)
MmaOp::AxisMapping axis_mapping{.a_axes = {1, -1, 0}, .b_axes = {-1, 1, 0}};
auto tv2 =
fusedMultiplySum(tv0, tv1, /*axes=*/{-1}, /*init=*/nullptr, axis_mapping);

auto tv3 = castOp(DataType::Half, tv2);

fusion.addOutput(tv3);

auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
NVF_CHECK(
1 == mma_ops.size(),
"Invalid number of MmaOp instances in fusion definition, expected 1, got ",
mma_ops.size());
mma_ops.front()->setMacro(macro);

// gmem [K, M] x gmem [K, N] -mma-> register [M, N, rK]
// register [M, N, rK] -cast-> gmem [M, N]

auto tv0c = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv0c->setMemoryType(MemoryType::Shared);
auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv1c->setMemoryType(MemoryType::Shared);
auto tv3c = tv3->cacheBefore();

tv0c->broadcast(-1); // [K, M] -> [K, M, 1]
tv1c->broadcast(-2); // [K, N] -> [K, 1, N]

// gmem [K, M, 1] -TMA-> smem [K, M, 1]
// gmem [K, 1, N] -TMA-> smem [K, 1, N]
// smem [K, M, 1] x smem [K, 1, N] -mma-> register [M, N, rK]
// register [M, N, rK] -cast-> register [M, N] -set-> gmem [M, N]

// Create tiles
tv2->split(-3, cta_m);
tv2->split(-2, cta_n);
tv2->split(-1, getK(macro));
// [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki]
tv2->reorder({{-5, -3}, {-3, -2}});
tv2->axis(0)->parallelize(ParallelType::BIDy);
tv2->axis(1)->parallelize(ParallelType::BIDx);

// NOTE: since in this case we do not have "proper" broadcast in the inputs,
// we cannot simply propagate transforms to the operands. Instead, we
// propagate forward to the outputs and manually schedule the smem operands.
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
tv2,
-1,
{tv3},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

// Schedule operands
for (TensorView* tv : {tv0c, tv1c}) {
tv->reorder({{-3, -1}}); // [K, M, N] -> [M, N, K]
// NOTE: above axes are given in MNK order, but inputs are in KMN
tv->split(-3, cta_m);
tv->split(-2, cta_n);
tv->split(-1, getK(macro));
// [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki]
// [Ko, Ki, Mo, Mi, No, Ni] -> [Mo, No, Ko, Mi, Ni, Ki]
tv->reorder({{-5, -3}, {-3, -2}});
tv->axis(0)->parallelize(ParallelType::BIDy);
tv->axis(1)->parallelize(ParallelType::BIDx);
}

// [..., Mi, Ni, Ki] -> [..., Ni, Ki, Mi]
tv0c->reorder({{-3, -1}});
tv0c->applyMmaSwizzleForTMALoad(swizzle);
// [..., Mi, Ni, Ki] -> [..., Mi, Ki, Ni]
tv1c->reorder({{-1, -2}});
tv1c->applyMmaSwizzleForTMALoad(swizzle);

{
tv2->split(-3, getM(macro));
tv2->split(-2, getN(macro));
// [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki]
// -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
tv2->reorder({{-4, -3}});
tv2->merge(-5);
tv2->axis(-4)->parallelize(ParallelType::TIDy);
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
tv2,
-1,
{tv3},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
tv2->setAllocationDomain(s.as<IterDomain*>(), true);
tv2->axis(-1)->parallelize(ParallelType::Mma);
tv2->axis(-2)->parallelize(ParallelType::Mma);
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

for (auto tv : {tv3c, tv3}) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

inlineMost();

if (stages > 1) {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);
}

// Test that predicate elimination works when the MmaOp's operands have no
// logical broadcasts
GpuLower gpulw(&fusion);
kir::Kernel* kernel = gpulw.run();
PredicateChecker pred_checker;
pred_checker.handle(kernel->topLevelExprs());
ASSERT_TRUE(pred_checker.found_mma);

// TODO: compile and run kernel once inlining is fixed
}

} // namespace nvfuser
