Skip to content

Commit

Permalink
[GML] Update SDPA ReduceOpVariants pattern after rebase (#2)
Browse files Browse the repository at this point in the history
Signed-off-by: James Bartlett <[email protected]>
  • Loading branch information
JamesMBartlett authored Sep 25, 2024
1 parent 9713a52 commit 5984034
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ void TorchMatchSpecializedBackendOp::populateSpecializedConversions(
llvm::SmallVector<Value> newOperands{
oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[3],
oldOperands[4], oldOperands[5], oldOperands[7]};
Value enableGQA =
rewriter.create<ConstantBoolOp>(op->getLoc(), false);
newOperands.push_back(enableGQA);

auto newOp = rewriter.create<Torch::AtenScaledDotProductAttentionOp>(
op.getLoc(), op->getResultTypes()[0], newOperands,
Expand Down

0 comments on commit 5984034

Please sign in to comment.