Only check actually used IDs in predicate elimination for MmaOp #3414

Merged: 31 commits merged into main from mma_predicate_elimination, Dec 5, 2024

Conversation

@jacobhinkle (Collaborator) commented Nov 15, 2024

Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers.

For example, the output of MmaOp has three logical dimensions: M, N, and K. These are scheduled by splitting, merging, and swizzling, so the final consumer loop domain can contain, for instance, IterDomains resulting from a split of the N dimension. Now if we look at the producer A, its logical domain is [M, K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N extent is symbolic and we cannot prove that the producer is parallelized the same way as the consumer in that dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains.

This PR does this by performing a BFS from the collection of consumer root IDs that map to producer logical IDs to the consumer loop domain. Only IDs along that path are checked using the existing conditions.

Detailed example

In the test included in this PR, the shared memory operand tensors are scheduled like this:

```
Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math
```

```
T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
```

Notice that in T4_s the loop broadcasts bblockIdx.x33{1} and bS34{256} are not derived from the logical domain. Instead, they are both products of a Split of an original "loop broadcast", although this is not currently shown in fusion->printTransforms():

```
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
```

In the predicate elimination pass with T4_s as producer and T2_l as consumer, the consumer ID iblockIdx.x19{( ceilDiv(i4, 256) )} would normally map to a logical broadcast ID in T4_s, but with these loop-domain broadcasts we have no such mapping. Before this PR, that caused predication. This PR notices that iblockIdx.x19{( ceilDiv(i4, 256) )} is not actually used for indexing the producer T4_s, so we do not need to worry about out-of-bounds accesses in this dimension.

Without this PR, if we remove the check at

```cpp
NVF_ERROR(
    !expr->isA<MmaOp>(),
    "Mma op: cannot eliminate predicate for mma op, tiling not valid. ",
    expr->toString());
```

then we generate the following code:

```cpp
__global__ void nvfuser_none_f0_c0_r0_g0(
    Tensor<__half, 2, 2> T0,
    Tensor<__half, 2, 2> T1,
    const __grid_constant__ TensorMap var0,
    const __grid_constant__ TensorMap var1,
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);

#pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;

    // ... load operands ...

    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..." /*... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}
```

After this PR, the predicate around the wgmma call is removed and the assertOnWarpOps check can be restored.

@jacobhinkle

!test --diff

@@ -3798,4 +3801,154 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

// Test scheduling a Hopper matmul where the operands are 2D
TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
jacobhinkle (Collaborator Author) commented:

This test is rather large, but since MmaOp is currently the only op I know of that triggers this behavior, I decided to just put the bulk of the test from #3406 here.

@@ -449,6 +449,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
}
}

// Check that mma op is not predicated.
class PredicateChecker : public kir::IrVisitor {
jacobhinkle (Collaborator Author) commented:

Instead of using this, we could possibly use PredicatedChecker::isPredicated. I kept this version to mirror the check in the AmpereSwizzle test.

Base automatically changed from mutator_preserve_additional_ids to main November 15, 2024 19:00
@jacobhinkle

!build

@jacobhinkle

!test --diff

@jacobhinkle jacobhinkle marked this pull request as ready for review November 18, 2024 13:16
@jacobhinkle

!build

@jacobhinkle jacobhinkle requested review from zasdfgbnm and naoyam and removed request for naoyam November 18, 2024 14:52
@jacobhinkle

!test

@jacobhinkle jacobhinkle marked this pull request as draft November 18, 2024 21:06
This updates the NVF_THROW check to rule out the BroadcastOp case.
@jacobhinkle

!test

@jacobhinkle

!test

jacobhinkle added a commit that referenced this pull request Nov 25, 2024
jacobhinkle added a commit that referenced this pull request Nov 27, 2024
This is an alternative to #3414 that limits the changes to only affect
MmaOp.
@jacobhinkle jacobhinkle changed the title Only check actually used IDs in predicate elimination Only check actually used IDs in predicate elimination for MmaOp Nov 27, 2024
@jacobhinkle

@naoyam in the latest push I modified this check to only apply to MmaOp, which is the only op that currently requires special handling.

@jacobhinkle

!test --diff

csrc/device_lower/utils.cpp (two outdated review threads, resolved)
@jacobhinkle

!test --diff

```cpp
// Fill ValGraph and grab all ValGroups on path from producer alloc to
// consumer loop.
IdModel& id_model = GpuLower::current()->idModel();
id_model.maybeBuildGraph(IdMappingMode::LOOP);
```
@naoyam (Collaborator) commented Dec 4, 2024:

I am actually planning not to expose the loop graph but just the loop promotion map.

Here, IIUC, we want to find which IDs matter for indexing the producer, since this analysis is used by ProducerConsumerPairAnalysis::needsPredicate. The simplest approach I would try is:

```cpp
const auto& graph = tensorIndexer().traversalGraph();
ValGroups loop_groups;
for (auto loop_id : consumer->getLoopDomain()) {
  auto promotion = getLoopPromotion(loop_id, idModel());
  loop_groups.pushBack(graph.toGroup(promotion));
}

ValGroups producer_all_ids = graph.toGroups(producer->allIds());
auto used_vals =
    getValsBetween<ValGraphBFS>(loop_groups, producer_all_ids, graph);
```

This used_vals tells us which IDs matter for indexing the producer. I think this info is sufficient for the MmaOp case, because what we want to do here is ignore the broadcast domains that only appear in the consumer tensor.

Instead of using producer->allIds(), we could consider just the actually indexed IDs, which conceptually should be returned by getAllocationDomain(), but as I mentioned before that may not always be true.

@jacobhinkle (Collaborator Author) commented Dec 4, 2024:

Ah, thanks for the tip. Locally I am using TensorIndexer, but I forgot to promote the loop IDs. In the example here we start from the consumer and traverse to the producer. In the MmaOp case the producer will have fewer groups than the consumer, so I was planning to start at the producer allocation domain. However, I think I still need to promote the consumer loop IDs and use those as the target domain for the traversal.

jacobhinkle (Collaborator Author) commented:

I just pushed a change implementing something close to what you suggested. This simplifies the code a lot. I also removed the utility in device_lower/utils.cpp.

@jacobhinkle

!build

@jacobhinkle jacobhinkle requested a review from naoyam December 5, 2024 13:42
@@ -589,6 +584,11 @@ void GpuLower::analysis(Fusion* fusion) {
tensor_indexer_ = std::make_unique<TensorIndexer>(*id_model_);
}

// Detects all expressions that don't need predicates. Depends on
Collaborator comment:

Please mention this now also depends on tensor_indexer_.

jacobhinkle (Collaborator Author) replied:

I followed your suggestions, so I no longer depend on tensor_indexer_ and have reverted the changes to this file.

@naoyam (Collaborator) left a review:

LGTM (finally!). Just left some minor comments.

@jacobhinkle

!build

@jacobhinkle jacobhinkle merged commit 736a541 into main Dec 5, 2024
17 checks passed
@jacobhinkle jacobhinkle deleted the mma_predicate_elimination branch December 5, 2024 19:01
jacobhinkle added a commit that referenced this pull request Dec 11, 2024
Stacked on #3414 

This PR enables us to inline an MmaOp properly when its inputs are
missing broadcast dimensions. We do this by always allowing inlining
past loop broadcasts or their transforms. For example
```
tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]
```
As long as the operation `foo` properly maps its arguments despite the
missing logical dimensions (as `MmaOp` does as of #3391), then we should
be able to fully inline this case because the loop broadcasts `bS5` and
`bS6` are imaginary in the sense that they don't impact indexing.