Insert freeze between masked loads and sdiv/srem instructions #2775
@arunjose696 The idea of the algorithm is as follows. Look through each basic block of the function to find one that starts with a PhiNode. When we find such a basic block, first check whether any of the PhiNode's incoming values are null/zero constants. If none are, no further action is needed. If one is, iterate over the instructions in the BB to see whether any sdiv/srem instruction uses the PhiNode, and therefore possibly that null/zero value, as an operand. If so, we freeze the output of the PhiNode and replace the operand in the sdiv/srem instruction with the frozen value. The first loop only looks at the first instruction, but iterating over all the instructions and breaking is a relatively easy way to do this (and is done in many other LLVM passes). The second loop has to look at all instructions in the BB.
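As a rough illustration of the steps described above, here is a minimal, self-contained sketch that models a basic block with toy types rather than the LLVM API. `Inst`, `Block`, and `freezeZeroPhiDivisors` are hypothetical names invented for this example, and the string `"0"` stands in for a null/zero constant; the actual pass in this PR operates on real LLVM IR and may differ in detail.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-ins for LLVM IR; these are hypothetical types, not the LLVM API.
enum class Op { Phi, SDiv, SRem, Freeze, Other };

struct Inst {
  Op op;
  std::string name;                  // name of the value this instruction defines
  std::vector<std::string> operands; // operand names; "0" models a zero constant
};

using Block = std::vector<Inst>;

// Freezes the divisor of every sdiv/srem in `bb` that divides by the block's
// leading phi, provided the phi has a zero incoming value.
// Returns true if the block was modified.
bool freezeZeroPhiDivisors(Block &bb) {
  // Step 1: only blocks that start with a phi are of interest.
  if (bb.empty() || bb.front().op != Op::Phi)
    return false;

  // Copy the phi: inserting into the vector below may invalidate references.
  const Inst phi = bb.front();

  // Step 2: nothing to do unless some incoming value is a zero constant.
  bool hasZero = false;
  for (const std::string &value : phi.operands)
    hasZero = hasZero || value == "0";
  if (!hasZero)
    return false;

  // Step 3: scan the whole block for sdiv/srem whose divisor is the phi.
  bool changed = false;
  std::size_t freezeCount = 0;
  for (std::size_t i = 1; i < bb.size(); ++i) {
    if (bb[i].op != Op::SDiv && bb[i].op != Op::SRem)
      continue;
    assert(bb[i].operands.size() == 2 && "sdiv/srem take two operands");
    if (bb[i].operands[1] != phi.name) // operand 1 is the divisor
      continue;

    // Give each freeze a distinct name, loosely imitating LLVM's renaming.
    std::string frozen = phi.name + ".frozen";
    if (freezeCount > 0)
      frozen += std::to_string(freezeCount);
    ++freezeCount;

    bb[i].operands[1] = frozen;
    // Insert the freeze immediately before the sdiv/srem that uses it.
    bb.insert(bb.begin() + i, Inst{Op::Freeze, frozen, {phi.name}});
    ++i; // the sdiv/srem moved to index i + 1; skip past it
    changed = true;
  }
  return changed;
}
```

In this toy model, a block containing `x = phi [load, 0]` followed by `q = sdiv a, x` would end up with an `x.frozen` instruction placed directly before the `sdiv`, whose divisor is rewritten to `x.frozen`. Uses of `x` by other instructions are left untouched.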
Can we get a lit test under `test/LLVMIR`?
```cpp
for (Instruction &I : BB) {
  if (I.getOpcode() == Instruction::SDiv ||
      I.getOpcode() == Instruction::SRem) {
    const size_t OpIdx = 1;
    if (I.getOperand(OpIdx) == PhiNode) {
      auto *freezePhi = new FreezeInst(
          PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
      I.setOperand(OpIdx, freezePhi);
      Changed = true;
    }
  }
}
```
Wouldn't it be better to iterate on `PhiNode`'s uses?
In that case, we don't need to pass in `BB`, and we can rename the function to `processPHINode`.
We want to stay in the same basic block as the Phi node, so `users()` is not entirely straightforward. I think iterating over the basic block's instructions and checking for the operand match is clearer.
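To make the trade-off concrete, here is a small toy sketch (again not the LLVM API) of what a use-based variant would have to check. A phi's users can live in any basic block, so each user must be filtered by its parent block, by opcode, and by which operand slot holds the phi. The `Use` record, its fields, and `divisorUsesInBlock` are hypothetical names made up for this illustration.

```cpp
#include <string>
#include <vector>

enum class Op { SDiv, SRem, Other };

// Hypothetical toy record describing one use of the phi's value.
struct Use {
  std::string user;    // name of the instruction that uses the phi
  Op op;               // opcode of that instruction
  int block;           // id of the basic block containing the user
  unsigned operandIdx; // operand slot in which the phi appears
};

// Returns the names of the users that would need a frozen divisor:
// sdiv/srem instructions in the phi's own block that divide by the phi.
std::vector<std::string> divisorUsesInBlock(const std::vector<Use> &uses,
                                            int phiBlock) {
  std::vector<std::string> result;
  for (const Use &use : uses) {
    const bool isDivOrRem = use.op == Op::SDiv || use.op == Op::SRem;
    const bool sameBlock = use.block == phiBlock;
    const bool isDivisor = use.operandIdx == 1;
    if (isDivOrRem && sameBlock && isDivisor)
      result.push_back(use.user);
  }
  return result;
}
```

The three conditions are the same ones the block-iteration version checks; walking the block simply gets the same-block condition for free.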
I'd rather have a lit test than the current test. But I'm open to having both.
I can work on a lit test, but the regression test is far more important, as the concern is keeping the mask false path intact throughout the LLVM optimization pipeline.
third_party/intel/lib/CMakeLists.txt
Does it make sense to add this to the PostProcess folder, which contains the other LLVM passes we have?
I did not think we had any other LLVM passes. Did I miss some? Because this pass operates on LLVMIR and not the MLIR LLVM dialect, it needs to be separate from the MLIR LLVM Dialect passes.
We have some under third_party/intel/lib/Target/LLVMIR/
Right - I tried integrating with those but they seem to run as part of the MLIR -> LLVMIR lowering and not as part of the LLVMIR optimization pipeline, so the compiler target needs to be different. Let’s make an issue to follow up if we feel strongly and leave the current directory structure as is for this PR.
Sure, we have to address that in this PR. Those passes are pure LLVM passes and should have no relation to MLIR. @etiotto, correct?
Yes, that is correct. They post-process the optimized LLVM IR produced by LLVM's `opt`.
Lit test added; all comments have been addressed in code or with a reply above.
LGTM
Co-authored-by: Arun Jose <[email protected]>
Close #2726
From the code comments:
The Triton masked load pattern can generate instances where the
mask value causes undefined behavior in sdiv/srem instructions. The
language allows this UB as the result of those arithmetic
instructions is never used, and control flow to avoid computation
of these instructions would negatively affect performance. But,
LLVM SimplifyCFG aggressively marks code paths with undefined
behavior as dead. This can result in removal of the mask path and
incorrect results from legal Triton kernels due to masked elements
being used in computation. Run a pass to add a freeze instruction
between masked loads and sdiv/srem to signal to LLVM we consider
the sdiv/srem operands to be well defined.
The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a `freeze` instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations.
The pass needs to run after every instance of `InstCombine` because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any `InstCombine`.
Note that the directory structure for this pass is a little different than `BreakStructPhiNodesPass` because we are already using those directories in `third_party` for MLIR code. If we want to change that, I can open an issue, but let's do it separately from this PR.