From be65f1e302462a8f5432e4ccb991750d362fc29e Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 24 Dec 2024 17:30:39 +0000 Subject: [PATCH] Add patterns for memref::viewOp --- .../XeTileToXeGPU/XeTileToXeGPU.cpp | 55 ++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index cf7e0363e..f0fc569a1 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -456,6 +456,42 @@ class SCFYieldOpPattern : public mlir::OpConversionPattern { } }; +// TODO: this is a temporary solution to support memref::ViewOp. +// Since the upstream doesn't have lowering pattern for converting +// memref::ViewOp to SPIRV, so here we convert it with alloc instead. +// But it requires every alloc just has one view. It should be removed +// after enable the support in MemrefToSPIRV. +class MemRefViewOpPattern final + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(mlir::memref::ViewOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto memrefTy = op.getType(); + auto memSpace = + mlir::dyn_cast_if_present(memrefTy.getMemorySpace()); + + if (!memrefTy.hasStaticShape() || !memSpace || memSpace.getValue() != 3) + return mlir::failure(); + + // for simplicity, make sure source is an alloc op, and only has one use, + // otherwise skip it, since it is hard to guarantee the correctness. + auto src = op.getSource(); + if (!mlir::isa(src.getDefiningOp()) || + !src.hasOneUse()) + return mlir::failure(); + + auto alignmentAttr = rewriter.getI64IntegerAttr(32); + + auto allocOp = rewriter.create(op.getLoc(), memrefTy, + alignmentAttr); + + rewriter.replaceOp(op, allocOp); + return mlir::success(); + } +}; + class XeTileConversionTarget : public mlir::ConversionTarget { public: explicit XeTileConversionTarget(mlir::MLIRContext &context, @@ -488,6 +524,21 @@ class XeTileConversionTarget : public mlir::ConversionTarget { mlir::succeeded(uArchInterface->isLegalPrefetch2dOp(op))); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { + auto viewOp = mlir::dyn_cast(op); + auto memrefTy = viewOp.getType(); + auto byteShift = viewOp.getByteShift(); + auto sizes = viewOp.getSizes(); + if (sizes.size() > 0 || !mlir::isConstantIntValue(byteShift, 0)) + return true; + auto memSpace = mlir::dyn_cast_if_present( + memrefTy.getMemorySpace()); + if (!memSpace || memSpace.getValue() != 3) + return true; + return memrefTy.getRank() != 2; + }); + addIllegalDialect(); addLegalDialect(); addLegalOp(); @@ -640,8 +691,8 @@ void populateXeTileToXeGPUConversionPatterns( patterns.add(converter, - patterns.getContext()); + SCFForOpPattern, SCFYieldOpPattern, MemRefViewOpPattern>( + converter, patterns.getContext()); } /// Create a pass that convert XeTile to XeGPU