Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert freeze between masked loads and sdiv/srem instructions #2775

Merged
merged 18 commits into from
Dec 3, 2024

Conversation

alexbaden
Copy link
Contributor

Close #2726

From the code comments:

The Triton masked load pattern can generate instances where the
mask value causes undefined behavior in sdiv/srem instructions. The
language allows this UB as the result of those arithmetic
instructions is never used, and control flow to avoid computation
of these instructions would negatively affect performance. But,
LLVM SimplifyCFG aggressively marks code paths with undefined
behavior as dead. This can result in removal of the mask path and
incorrect results from legal Triton kernels due to masked elements
being used in computation. Run a pass to add a freeze instruction
between masked loads and sdiv/srem to signal to LLVM we consider
the sdiv/srem operands to be well defined.

The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a freeze instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations.

The pass needs to run after every instance of InstCombine because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any InstCombine.

Note that the directory structure for this pass is a little different than BreakStructPhiNodesPass because we are already using those directories in third_party for MLIR code. If we want to change that, I can open an issue but let's do it separately from this PR.

@alexbaden
Copy link
Contributor Author

@arunjose696 The idea of the algorithm is as follows:

Look through each basic block of the function to find one that starts with a PhiNode.

When we find a basic block that starts with a PhiNode, process that basic block by first checking to see if any of the PhiNode values are null/0 constants.

If no PhiNode values are null/zero, no further action is needed. If we have a null or zero, then we iterate the instructions in the BB to see if any sdiv/srem instructions use that null/zero value. If so, we freeze the output of the PhiNode and replace the operand in the sdiv/srem instruction with that frozen value.

The first loop only looks at the first instruction, but iterating all the instructions and breaking is a relatively easy way to do this (and is done in many other LLVM passes). The second loop has to look at all instructions in the BB.

@alexbaden alexbaden requested a review from a team November 22, 2024 12:30
Copy link
Contributor

@victor-eds victor-eds left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get a lit test under test/LLVMIR?

third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp Outdated Show resolved Hide resolved
third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp Outdated Show resolved Hide resolved
Comment on lines 20 to 30
for (Instruction &I : BB) {
if (I.getOpcode() == Instruction::SDiv ||
I.getOpcode() == Instruction::SRem) {
const size_t OpIdx = 1;
if (I.getOperand(OpIdx) == PhiNode) {
auto *freezePhi = new FreezeInst(
PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
I.setOperand(OpIdx, freezePhi);
Changed = true;
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to iterate on PhiNode's uses?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in that case, we don't need to pass in BB, and we can rename the function to processPHINode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to stay in the same basic block as the Phi node, so users() is not entirely straightforward - I think iterating basic block instructions and checking for the operand match is clearer.

python/test/regression/test_divide.py Outdated Show resolved Hide resolved
@victor-eds
Copy link
Contributor

I'd rather have a lit test than the current test. But I'm open to having both.

@alexbaden
Copy link
Contributor Author

I can work on a lit test, but the regression test is far more important as the concern is keeping the mask false path intact throughout the LLVM optimization pipeline.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to add to PostProcess folder that contains other LLVM passes we have?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not think we had any other LLVM passes. Did I miss some? Because this pass operates on LLVMIR and not the MLIR LLVM dialect, it needs to be separate from the MLIR LLVM Dialect passes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have some under third_party/intel/lib/Target/LLVMIR/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - I tried integrating with those but they seem to run as part of the MLIR -> LLVMIR lowering and not as part of the LLVMIR optimization pipeline, so the compiler target needs to be different. Let’s make an issue to follow up if we feel strongly and leave the current directory structure as is for this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we have to address that in this PR. Those passes are pure LLVM passes, should have no relation to MLIR. @etiotto correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes is correct. They post process the optimized LLVM IR produced by LLVM's opt.

Comment on lines 20 to 30
for (Instruction &I : BB) {
if (I.getOpcode() == Instruction::SDiv ||
I.getOpcode() == Instruction::SRem) {
const size_t OpIdx = 1;
if (I.getOperand(OpIdx) == PhiNode) {
auto *freezePhi = new FreezeInst(
PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
I.setOperand(OpIdx, freezePhi);
Changed = true;
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in that case, we don't need to pass in BB, and we can rename the function to processPHINode

third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp Outdated Show resolved Hide resolved
@alexbaden
Copy link
Contributor Author

Lit test added, all comments have been addressed in code or with a reply above.

Copy link
Contributor

@victor-eds victor-eds left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@alexbaden alexbaden merged commit 78c13a5 into main Dec 3, 2024
5 checks passed
@alexbaden alexbaden deleted the alex/2585 branch December 3, 2024 03:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
5 participants