Only check actually used IDs in predicate elimination for MmaOp #3414
Conversation
!test --diff
```diff
@@ -3798,4 +3801,154 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
 }
 
+// Test scheduling a Hopper matmul where the operands are 2D
+TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
```
This test is rather large, but since MmaOp is the only way I currently know how to trigger this behavior, I decided to just put the bulk of the test from #3406 here.
```diff
@@ -449,6 +449,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
   }
 }
 
+// Check that mma op is not predicated.
+class PredicateChecker : public kir::IrVisitor {
```
I suppose instead of using this we could possibly use `PredicatedChecker::isPredicated` instead. I kept it here to mirror the check in the AmpereSwizzle test; a rough sketch of what such a visitor does is below.
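For context, a minimal sketch of such a visitor, loosely mirroring the AmpereSwizzle-style check; the `wgmma` string match and the exact assertion are illustrative assumptions, not necessarily the PR's actual code:

```cpp
// Sketch (assumptions noted above): visit the lowered kernel IR and verify
// that no mma instruction sits inside a predicated scope.
class PredicateChecker : public kir::IrVisitor {
 public:
  using kir::IrVisitor::handle;
  bool found_mma = false;

 private:
  void handle(kir::Asm* asm_) final {
    // Assumed detection: Hopper mma ops are emitted as inline "wgmma" asm
    if (asm_->code().substr(0, 5) != "wgmma") {
      return;
    }
    found_mma = true;
    // scope_exprs_ holds the scopes enclosing the currently visited expr;
    // a real check might still allow trivially-true predicates here
    for (Expr* expr : scope_exprs_) {
      NVF_CHECK(
          !expr->isA<kir::IfThenElse>(),
          "Expected mma op to be unpredicated, but found an enclosing "
          "IfThenElse");
    }
  }
};
```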
!build
!test --diff
!build
!test
This updates the `NVF_THROW` check to rule out the `BroadcastOp` case.
!test |
!test |
This is an alternative to #3414 that limits the changes to only affect MmaOp.
@naoyam in the latest push I modified this check to only apply to `MmaOp`.
!test --diff
!test --diff
```cpp
// Fill ValGraph and grab all ValGroups on path from producer alloc to
// consumer loop.
IdModel& id_model = GpuLower::current()->idModel();
id_model.maybeBuildGraph(IdMappingMode::LOOP);
```
I actually am planning to not expose the loop graph but just expose the loop promotion map.

Here, IIUC, we want to find which IDs matter for indexing the producer, since this analysis is used by `ProducerConsumerPairAnalysis::needsPredicate`. The simplest way I would try is:

```cpp
const auto& graph = tensorIndexer().traversalGraph();
ValGroups loop_groups;
for (auto loop_id : consumer->getLoopDomain()) {
  auto promotion = getLoopPromotion(loop_id, idModel());
  loop_groups.pushBack(graph.toGroup(promotion));
}
ValGroups producer_all_ids = graph.toGroups(producer->allIds());
auto used_vals = getValsBetween<ValGraphBFS>(loop_groups, producer_all_ids, graph);
```

This `used_vals` tells us which IDs matter for indexing the producer. I think this info is sufficient for the MmaOp case because what we want to do here is to ignore the broadcast domains that only appear in the consumer tensor.
Instead of using `producer->allIds()`, we could just consider the actual indexed IDs, which conceptually should be returned by `getAllocationDomain()`, but as I mentioned before that may not always be true.
Ah, thanks for the tip. Locally I am using `TensorIndexer` but I forgot to promote the loop IDs. In the example here we are starting from the consumer and going to the producer. In the MmaOp case, the producer will have fewer groups than the consumer, so I was planning to start at the producer allocation domain. However, I think I still need to promote the consumer loop IDs and use those as the target domain for the traversal; a sketch is below.
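Concretely, a sketch of that reversed traversal under the same assumptions as the snippet above (with `getMaybeAllocationDomain()` as an assumed stand-in for the producer-side starting IDs, which as noted may not always reflect the actual indexed IDs):

```cpp
// Sketch: start from the producer's allocation-side groups and traverse
// toward the *promoted* consumer loop groups; only groups on this path
// can participate in indexing the producer.
const ValGraph& graph = tensorIndexer().traversalGraph();

ValGroups producer_groups =
    graph.toGroups(producer->getMaybeAllocationDomain());

ValGroups promoted_loop_groups;
for (IterDomain* loop_id : consumer->getLoopDomain()) {
  // Promotion maps each loop ID into the traversal graph's domain
  IterDomain* promotion = getLoopPromotion(loop_id, idModel());
  promoted_loop_groups.pushBack(graph.toGroup(promotion));
}

auto used_vals = getValsBetween<ValGraphBFS>(
    producer_groups, promoted_loop_groups, graph);
```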
I just pushed a change implementing something close to what you suggest. This simplifies the code a lot. I also removed the utility in `device_lower/utils.cpp`.
!build
csrc/device_lower/lower2device.cpp
```diff
@@ -589,6 +584,11 @@ void GpuLower::analysis(Fusion* fusion) {
   tensor_indexer_ = std::make_unique<TensorIndexer>(*id_model_);
 }
 
+// Detects all expressions that don't need predicates. Depends on
```
Please mention that this now also depends on `tensor_indexer_`.
I followed your suggestions, so I no longer depend on `tensor_indexer_`, and I reverted the changes to this file.
LGTM (finally!). Just left some minor comments.
Co-authored-by: Naoya Maruyama <[email protected]>
!build
Stacked on #3414

This PR enables us to inline an MmaOp properly when its inputs are missing broadcast dimensions. We do this by always allowing inlining past loop broadcasts or their transforms. For example:

```
tv0: logical [ iS1{i0} ]
     loop    [ iS1{i0} bS5{1} ]
tv1: logical [ iS2{i1} ]
     loop    [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
     logical [ iS3{i0} iS4{i1} ]
```

As long as the operation `foo` properly maps its arguments despite the missing logical dimensions (as `MmaOp` does as of #3391), then we should be able to fully inline this case because the loop broadcasts `bS5` and `bS6` are imaginary in the sense that they don't impact indexing.
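To make the "imaginary" property concrete, here is a hypothetical helper (the name `isLoopBroadcast` is illustrative, not an API from this PR) sketching how a loop ID that is not derived from the logical domain could be recognized:

```cpp
// Hypothetical sketch: a loop ID is a "loop broadcast" (or a transform of
// one) if none of the terminal inputs that produce it come from the
// tensor's logical domain; such an ID cannot influence indexing.
bool isLoopBroadcast(TensorView* tv, IterDomain* loop_id) {
  const std::vector<IterDomain*>& logical = tv->getLogicalDomain();
  // Walk back from the loop ID to the inputs it is derived from
  std::vector<Val*> inputs = IterVisitor::getInputsTo({loop_id});
  return std::none_of(inputs.begin(), inputs.end(), [&](Val* input) {
    return std::find(logical.begin(), logical.end(), input) != logical.end();
  });
}
```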
Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers.

For example, the output of `MmaOp` has three logical dimensions, M, N, and K. These are scheduled by splitting, merging, and swizzling, so in the end the consumer loop domain can contain things like a split of the N dimension into two other IterDomains. Now if we look at the producer A, it has logical size [M K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N dimension is symbolic and we can't prove that the producer is parallelized the same way as this consumer in this dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains.

This PR does this by performing a BFS from the collection of consumer root IDs that map to producer logical IDs to the consumer leaf domain. Only IDs along that path are checked using the existing conditions; a sketch of this filtering is below.
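A condensed sketch of that filtering, reusing the `ValGraph` utilities from the review discussion above; the loop structure and the `path_groups` name are illustrative rather than the PR's exact code:

```cpp
// Sketch: collect the ValGroups on the path between the producer's IDs and
// the promoted consumer loop domain; any consumer loop ID whose group is
// off this path cannot affect producer indexing and is skipped.
const ValGraph& graph = tensorIndexer().traversalGraph();

ValGroups producer_ids = graph.toGroups(producer->allIds());
ValGroups consumer_loop_groups;
for (IterDomain* loop_id : consumer->getLoopDomain()) {
  consumer_loop_groups.pushBack(
      graph.toGroup(getLoopPromotion(loop_id, idModel())));
}

const auto path_groups =
    getValsBetween<ValGraphBFS>(producer_ids, consumer_loop_groups, graph);

for (IterDomain* loop_id : consumer->getLoopDomain()) {
  const ValGroup& group = graph.toGroup(getLoopPromotion(loop_id, idModel()));
  if (std::find(path_groups.begin(), path_groups.end(), group) ==
      path_groups.end()) {
    continue; // e.g. the N dimension when indexing operand A of an MmaOp
  }
  // ... apply the existing parallelization/predication checks ...
}
```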
Detailed example
In the test included in this PR, we have shared memory operand tensors that are scheduled like this:
Notice that in `T4_s` the loop broadcasts `bblockIdx.x33{1}` and `bS34{256}` are not derived from the logical domain. Instead, they are actually both the products of a `Split` involving an original "loop broadcast", although this is not currently shown in `fusion->printTransforms()`:

In the predicate elimination pass with `T4_s` as producer and `T2_l` as consumer, the consumer ID `iblockIdx.x19{( ceilDiv(i4, 256) )}` normally would map to a logical broadcast ID in `T4_s`, but with these loop domain broadcasts we do not have such a mapping. Before this PR that would cause predication. This PR notices that `iblockIdx.x19{( ceilDiv(i4, 256) )}` is not actually used for indexing the producer `T4_s`, so we do not need to worry about out-of-bounds accesses in this dimension.

Without this PR, if we remove the check at
`Fuser/csrc/device_lower/analysis/predicate_elimination.cpp` (lines 34 to 37 at commit `3266b9d`)
After this PR, the predicate around the `wgmma` call is removed and the `assertOnWarpOps` check can be restored.