From b45ee5a79989eada68de5e0c7f6ab8f9dcfbbf9b Mon Sep 17 00:00:00 2001 From: James Bartlett Date: Thu, 26 Sep 2024 15:16:43 -0700 Subject: [PATCH] [GML] Convert boolean masks to index tensors when converting to hacked twin version of aten.index.Tensor (#3) Signed-off-by: James Bartlett --- .../Torch/Transforms/DecomposeComplexOps.cpp | 62 +++++++++++++++++++ .../Transforms/ReifyShapeCalculations.cpp | 20 ++++++ 2 files changed, 82 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9b24d0e959f3..e2b027916915 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9337,6 +9337,63 @@ static FailureOr createNewIndices(Operation *op, return newIndexList; } +static LogicalResult +createIndexTensorsFromBoolMask(Operation *op, PatternRewriter &rewriter, + Value tensor, + SmallVectorImpl &indexTensors) { + auto loc = op->getLoc(); + auto type = llvm::cast(tensor.getType()); + + auto rank = type.getSizes().size(); + + auto indicesResultType = ValueTensorType::get( + op->getContext(), ArrayRef{-1, rank}, + rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); + auto indices = + rewriter.create(loc, indicesResultType, tensor); + + // In PyTorch a boolean mask that is not 1D will expand beyond the dimension + // it is indexing. For example, given a tensor `x` of shape [1,2,3,4,5] and + // a boolean mask `mask` of shape [2,3,4], `x[:, mask, :]` would result in a + // [1, ?, 5] tensor where ? is the number of True values in `mask`. + // So we need to create a new index tensor for each dimension it expands to. 
+ auto dimResultType = ValueTensorType::get( + op->getContext(), ArrayRef{-1}, + rewriter.getIntegerType(/*width*/ 64, /*isSigned*/ true)); + auto dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + for (int64_t i = 0; i < rank; ++i) { + auto index = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + auto dimIndices = rewriter.create( + loc, dimResultType, indices, dim, index); + indexTensors.push_back(dimIndices); + } + + return success(); +} + +static LogicalResult +replaceBoolMasksWithIndices(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl &indices) { + SmallVector newIndices; + for (auto tensor : indices) { + auto tensorType = llvm::dyn_cast(tensor.getType()); + if (!tensorType || !tensorType.hasDtype() || !tensorType.hasSizes() || + !tensorType.getDtype().isSignlessInteger(1)) { + newIndices.push_back(tensor); + continue; + } + + if (failed( + createIndexTensorsFromBoolMask(op, rewriter, tensor, newIndices))) + return rewriter.notifyMatchFailure( + op, "failed to convert boolean mask to index tensor"); + } + indices = newIndices; + return success(); +} + // The goal of this pattern is to eliminate `None` index in aten.Index.Tensor's // `indices` param and transform it to aten.index.Tensor_hacked_twin, for the // ease of various backend. @@ -9366,6 +9423,11 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { return isa(v.getType()); }; + if (failed(replaceBoolMasksWithIndices(op, rewriter, indices))) { + return rewriter.notifyMatchFailure( + op, "failed to convert all boolean masks to index tensors"); + } + // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin if (llvm::all_of(indices, isTensor)) { // By default, we regard the first index type as the list element type. 
diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index fb9d33123a9c..0e9aebba2a3a 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/Support/MemoryBuffer.h" using namespace mlir; @@ -82,6 +83,25 @@ struct ReifyShapeCalculationsPass // in a `torch.shape.calculate` op. SmallVector functionsNeeded; WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { + if (auto index_op = llvm::dyn_cast(op)) { + // IndexTensor shape inference is incorrect for boolean masks. + // Unfortunately, torch-mlir's shape calculation infrastructure + // doesn't provide data types to the calculation functions, so we can't + // easily fix the problem. Instead, we disable shape inference for + // IndexTensor ops that have boolean inputs. + + SmallVector indices; + if (Torch::getListConstructElements(index_op.getIndices(), indices)) { + for (auto tensor : indices) { + auto tensorType = + llvm::dyn_cast(tensor.getType()); + if (tensorType && tensorType.hasDtype() && + tensorType.getDtype().isSignlessInteger(1)) { + return WalkResult::advance(); + } + } + } + } return wrapWithCalculateOpIfLibraryFunctionAvailable( op, *library, LibraryFunctionKind::ShapeFunction, functionsNeeded, shapeFunctionArgsBuilder);