# Only check actually used IDs in predicate elimination for MmaOp (#3414)
Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers. For example, the output of `MmaOp` has three logical dimensions, M, N, and K. These are scheduled by splitting, merging, and swizzling, so in the end the consumer loop domain can contain things like a split of the N dimension into two other IterDomains. Now if we look at the producer A, it has logical size [M, K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N dimension is symbolic and we cannot prove that the producer is parallelized the same way as this consumer in this dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains. This PR does so by performing a BFS from the collection of consumer root IDs that map to producer logical IDs down to the consumer loop domain. Only IDs along that path are checked using the existing conditions.
## Detailed example

In the test included in this PR, we have shared memory operand tensors that are scheduled like this:

```
Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math

T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
```

Notice that in `T4_s` the loop broadcasts `bblockIdx.x33{1}` and `bS34{256}` are not derived from the logical domain.
Instead, they are actually both products of a `Split` involving an original "loop broadcast", although this is not currently shown in `fusion->printTransforms()`:

```
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
```

In the predicate elimination pass with `T4_s` as producer and `T2_l` as consumer, the consumer ID `iblockIdx.x19{( ceilDiv(i4, 256) )}` _normally_ would map to a logical broadcast ID in `T4_s`, but with these loop domain broadcasts we do not have such a mapping. Before this PR, that would cause predication. This PR notices that `iblockIdx.x19{( ceilDiv(i4, 256) )}` is not actually used for indexing the producer `T4_s`, so we do not need to worry about out-of-bounds accesses in this dimension.

Without this PR, if we remove the check at https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37 then we generate the following code:

```c++
__global__ void nvfuser_none_f0_c0_r0_g0(
    Tensor<__half, 2, 2> T0,
    Tensor<__half, 2, 2> T1,
    const __grid_constant__ TensorMap var0,
    const __grid_constant__ TensorMap var1,
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);
  #pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;
    // ... load operands ...
    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..."
          /* ... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}
```

After this PR, the predicate around the `wgmma` call is removed and the `assertOnWarpOps` check can be restored.

---------

Co-authored-by: Naoya Maruyama <[email protected]>