Skip to content

Commit

Permalink
[GML] Broadcast attention mask to QK shape
Browse files Browse the repository at this point in the history
Signed-off-by: James Bartlett <[email protected]>
  • Loading branch information
JamesMBartlett committed Nov 26, 2024
1 parent 1e3d2a9 commit dcee4c2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
35 changes: 35 additions & 0 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand Down Expand Up @@ -1837,6 +1838,40 @@ class ConvertAtenScaledDotProductAttentionOp
reassociation);
};

// If an attention mask was supplied (i.e. it is not the Torch `none`
// value), broadcast it up to the attention-weight (QK^T) shape so it can
// later be applied elementwise. This must happen BEFORE the batch
// dimensions are collapsed below, while the full-rank shapes are known.
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
SmallVector<int64_t> attnWeightShape;
// attnWeight is (N, ..., L, S)
// Leading batch dims (N, ...) are taken from `value`, which shares them
// with query/key.
for (int i = 0; i < valueTy.getRank() - 2; ++i)
attnWeightShape.push_back(valueTy.getDimSize(i));
// get L from queryTy.
// L = target sequence length = second-to-last dim of query (..., L, E).
attnWeightShape.push_back(queryTy.getDimSize(queryTy.getRank() - 2));
// get S from keyTy.
// S = source sequence length = second-to-last dim of key (..., S, E);
// QK^T therefore has trailing dims (L, S).
attnWeightShape.push_back(keyTy.getDimSize(keyTy.getRank() - 2));

auto maskTy = cast<ShapedType>(mask.getType());
// Target type: same element type as the mask, reshaped to the QK^T shape.
auto broadcastTy = maskTy.clone(attnWeightShape);

// Materialize each target extent as a constant, and record which dims are
// statically known; `useBroadcastToShape[i] == false` tells the broadcast
// helper to derive dynamic extents at runtime instead of trusting the
// (placeholder) constant. NOTE(review): for a dynamic dim,
// attnWeightShape[i] is ShapedType::kDynamic and the i64 constant built
// from it is a sentinel — confirm broadcastToGivenShape ignores it when
// the corresponding flag is false.
SmallVector<Value> broadcastToShape;
SmallVector<bool> useBroadcastToShape;
for (int i = 0; i < broadcastTy.getRank(); ++i) {
broadcastToShape.push_back(rewriter.create<arith::ConstantOp>(
op->getLoc(), rewriter.getI64IntegerAttr(attnWeightShape[i])));
useBroadcastToShape.push_back(
!ShapedType::isDynamic(attnWeightShape[i]));
}

// Broadcast mask to attnWeightShape.
// NOTE(review): `broadcasted_mask` is snake_case while the surrounding
// code uses camelCase — consider renaming to `broadcastedMask`.
Value broadcasted_mask;
if (failed(torch_to_linalg::broadcastToGivenShape(
op, rewriter, mask, broadcastToShape,
RankedTensorType::get(attnWeightShape, maskTy.getElementType()),
broadcasted_mask, useBroadcastToShape))) {
op->emitError("failed to broadcast mask to attention weight shape");
return failure();
}
// Replace the original mask with its broadcast form for downstream use.
mask = broadcasted_mask;
}

// Collapse the leading batch dims of the operands (via the
// `collapseBatch` helper defined above) now that the mask has been
// expanded against the full-rank shapes.
query = collapseBatch(query);
key = collapseBatch(key);
value = collapseBatch(value);
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/torch-mlir-overlay/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ cc_library(
":TorchMLIRTMTensorDialect",
":TorchMLIRTorchBackendTypeConversion",
":TorchMLIRTorchConversionDialect",
":TorchMLIRTorchToLinalg",
"@llvm-project//mlir:LinalgDialect",
],
)
Expand Down

0 comments on commit dcee4c2

Please sign in to comment.