From ff515940c229a3af49671895d07b0295a320a16c Mon Sep 17 00:00:00 2001
From: Sang Ik Lee
Date: Fri, 20 Dec 2024 12:58:22 -0800
Subject: [PATCH] Add XeTile block operation fallback pass (#991)

Certain block operations that are legal at the XeTile dialect level cannot
be supported by a matching XeGPU dialect op because they do not meet HW
restrictions. This PR adds a new pass that provides a fallback for some of
these cases. The pass can be invoked with the command line arg
--xetile-blockop-fallback

The cases covered are: the source of the tile is a static shaped row major
memref, but
- the pitch is not a multiple of 16 bytes or is less than 64 bytes
- or the memory space indicates SLM memory

For each such case, this pass turns
- block tiles into scatter tiles
- load_tile into load
- store_tile into store
- update_tile_offset to use tile shaped indices instead of X, Y offsets
- impacted scf.for arguments from block tile type to scatter tile type
---
 .../imex/Dialect/XeTile/Transforms/Passes.h   |   1 +
 .../imex/Dialect/XeTile/Transforms/Passes.td  |  17 +
 .../XeTile/Transforms/BlockOpFallback.cpp     | 443 ++++++++++++++++++
 lib/Dialect/XeTile/Transforms/CMakeLists.txt  |   1 +
 lib/Transforms/RemoveSingleElemVector.cpp     |   4 +-
 .../XeTile/Transforms/block_op_fallback.mlir  | 391 ++++++++++++++++
 .../fallback/narrow_tile_one_elem_wide.mlir   |  64 +++
 .../fallback/narrow_tile_two_elem_wide.mlir   |  65 +++
 .../Dialect/XeTile/fallback/slm.mlir          |  80 ++++
 .../fallback/xetile-fallback-to-func-vc.pp    |  41 ++
 .../postop_reduce_n.mlir                      |   4 +-
 11 files changed, 1109 insertions(+), 2 deletions(-)
 create mode 100644 lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp
 create mode 100644 test/Dialect/XeTile/Transforms/block_op_fallback.mlir
 create mode 100644 test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir
 create mode 100644 test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir
 create mode 100644 test/Integration/Dialect/XeTile/fallback/slm.mlir
 create mode 100644 test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp

diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h
index 869732f3b..c24b30184 100644
--- a/include/imex/Dialect/XeTile/Transforms/Passes.h
+++ b/include/imex/Dialect/XeTile/Transforms/Passes.h
@@ -40,6 +40,7 @@ std::unique_ptr<mlir::Pass> createXeTileBlockingPass(const std::string &device = "pvc");
 std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
 std::unique_ptr<mlir::Pass> createXeTileCanonicalizationPass();
+std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass();

 #define GEN_PASS_DECL_XETILEBLOCKING
 #define GEN_PASS_DECL_XETILECANONICALIZATION

diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td
index 83c141718..6381c8887 100644
--- a/include/imex/Dialect/XeTile/Transforms/Passes.td
+++ b/include/imex/Dialect/XeTile/Transforms/Passes.td
@@ -96,4 +96,21 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{

 }

+def XeTileBlockOpFallback : Pass<"xetile-blockop-fallback", "::mlir::gpu::GPUModuleOp">{
+  let summary = "Transform unsuitable block ops to fallback scattered ops";
+
+  let description = [{
+    This transform pass rewrites XeTile block ops that are not suitable due
+    to HW restrictions into scattered XeTile ops.
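+
+    For example (illustrative, condensed from the tests in this PR), a 32x1
+    block tile of a memref<512x1xf32> has a pitch of only 4 bytes, so
+
+      %0 = xetile.init_tile %arg0 [0, 0]
+             : memref<512x1xf32> -> !xetile.tile<32x1xf32>
+      %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32> -> vector<32x1xf32>
+
+    becomes a scattered access over a flattened view, with an all-true mask:
+
+      %cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512],
+                strides: [1] : memref<512x1xf32> to memref<512xf32>
+      %step = vector.step : vector<32xindex>
+      %idx  = vector.shape_cast %step : vector<32xindex> to vector<32x1xindex>
+      %mask = arith.constant dense<true> : vector<32x1xi1>
+      %tile = xetile.init_tile %cast, %idx : memref<512xf32>,
+                vector<32x1xindex>
+                -> !xetile.tile<32x1xf32, #xetile.tile_attr<scattered = true>>
+      %val  = xetile.load %tile, %mask
+                : !xetile.tile<32x1xf32, #xetile.tile_attr<scattered = true>>,
+                vector<32x1xi1> -> vector<32x1xf32>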
+  }];
+
+  let constructor = "imex::createXeTileBlockOpFallbackPass()";
+  let dependentDialects = ["imex::xetile::XeTileDialect",
+                           "mlir::arith::ArithDialect",
+                           "mlir::gpu::GPUDialect",
+                           "mlir::index::IndexDialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::vector::VectorDialect"];
+}
+
 #endif // _XeTile_PASSES_TD_INCLUDED_

diff --git a/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp
new file mode 100644
index 000000000..a08b5fb43
--- /dev/null
+++ b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp
@@ -0,0 +1,443 @@
+//====-- BlockOpFallback.cpp - XeTile Block Op Fallback Pass ----*- C++-*-===//
+//
+// Copyright 2024 Intel Corporation
+// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This pass detects XeTile InitTile ops that do not meet HW restrictions
+/// and rewrites them into InitTile ops with a scattered description.
+/// This triggers a change to the tile type: the shape stays the same, but a
+/// scattered attribute is added. As a result, ops requiring a block
+/// description become invalid. Patterns that legalize those ops are added,
+/// and the GreedyPatternRewriteDriver is used to apply them.
+///
+//===----------------------------------------------------------------------===//
+
+#include "imex/Dialect/XeTile/IR/XeTileOps.h"
+#include "imex/Dialect/XeTile/Transforms/Passes.h"
+#include "imex/Utils/XeCommon.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace imex {
+#define GEN_PASS_DEF_XETILEBLOCKOPFALLBACK
+#include "imex/Dialect/XeTile/Transforms/Passes.h.inc"
+} // namespace imex
+
+//
+// Limitations and future plan:
+// Currently limited to 2D static shaped source memrefs of row major order.
+// Dynamic shapes with a known pitch value will be supported in the future.
+// Scattered offsets are currently calculated by a generated code sequence;
+// constant immediate values will be used for static shapes in the future.
+// The current code sequence is not optimal for supporting the blocking pass
+// and SIMT lowering. This will be improved in the future.
+// Correct mask generation requires getting absolute offsets from the source
+// memref. This is not supported in the current implementation and will be
+// added in the future.
+// For now, mask generation assumes all accesses are in bounds.
+//
+
+namespace blockopfallback {
+
+static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) {
+  auto tileAttr =
+      mlir::dyn_cast_or_null<imex::xetile::XeTileAttr>(tileTy.getEncoding());
+  int memorySpace = tileTy.getMemorySpaceAsInt();
+  if (!tileAttr) {
+    std::vector<int32_t> orderVec = {1, 0};
+    llvm::ArrayRef<int32_t> order(orderVec);
+    auto encoding = imex::xetile::XeTileAttr::get(tileTy.getContext(), order,
+                                                  memorySpace, true);
+    return imex::xetile::TileType::get(tileTy.getShape(),
+                                       tileTy.getElementType(), encoding);
+  }
+  auto sgMap = tileAttr.getSgMap();
+  auto wgMap = tileAttr.getWgMap();
+  auto order = tileAttr.getOrder().asArrayRef();
+  auto scatterTileAttr = imex::xetile::XeTileAttr::get(
+      tileTy.getContext(), sgMap, wgMap, order, memorySpace, true);
+  return imex::xetile::TileType::get(tileTy.getShape(), tileTy.getElementType(),
+                                     scatterTileAttr);
+}
+
+struct InitTileOpPattern final
+    : public mlir::OpRewritePattern<imex::xetile::InitTileOp> {
+  InitTileOpPattern(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(imex::xetile::InitTileOp initTileOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto tileTy = initTileOp.getType();
+    // Skip if tile is already scattered
+    if (tileTy.getScatterAttr()) {
+      return mlir::failure();
+    }
+    // Skip 1D tiles
+    if (tileTy.getRank() < 2) {
+      return mlir::failure();
+    }
+    // Skip if tile is column major
+    if (tileTy.getOrder().asArrayRef() != llvm::ArrayRef<int32_t>{1, 0}) {
+      return mlir::failure();
+    }
+    // Currently only supports memref sources
+    if (!initTileOp.isSourceMemRef()) {
+      return mlir::failure();
+    }
+    // Cannot handle non static shapes
+    if (!initTileOp.sourceMemRefHasStaticShape()) {
+      return mlir::failure();
+    }
+
+    auto srcShape = initTileOp.getSourceMemrefStaticShape();
+    // Cannot handle non 2D source memrefs
+    if (srcShape.size() != 2) {
+      return mlir::failure();
+    }
+    // Check if memspace is SLM
+    auto memorySpace = initTileOp.getSourceMemorySpaceAsInt();
+    bool isSLM = memorySpace == 3;
+    // Check that pitch >= 64 bytes and pitch is a multiple of 16 bytes
+    bool isValidPitch = true;
+    auto pitchNumElems = srcShape[srcShape.size() - 1];
+    auto elemBitwidth =
+        initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth();
+    auto pitchNumBytes = pitchNumElems * elemBitwidth / 8;
+    isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0);
+    // If memspace is not SLM and pitch is valid, no need to rewrite
+    if (!isSLM && isValidPitch) {
+      return mlir::failure();
+    }
+    // Get flat shape size
+    int64_t flatSize = 1;
+    for (auto dim : srcShape) {
+      flatSize *= dim;
+    }
+
+    // reinterpret_cast to a flat memref of flatSize elements
+    mlir::MemRefLayoutAttrInterface layout = {};
+    // Is the source offset always 0? No API to check.
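+    // The code below first flattens the 2D memref into a 1D view and then
+    // materializes, for every tile element (r, c), its linear index into
+    // that view:
+    //   idx(r, c) = (offsetX + r) * pitchNumElems + (offsetY + c)
+    // The row and column terms are built as index vectors broadcast to the
+    // tile shape and combined with a multiply and an add.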
+    auto flatMemref = rewriter.create<mlir::memref::ReinterpretCastOp>(
+        initTileOp.getLoc(),
+        mlir::MemRefType::get({flatSize}, initTileOp.getSourceMemrefElemType(),
+                              layout, initTileOp.getSourceMemorySpace()),
+        initTileOp.getSource(), 0, llvm::ArrayRef<int64_t>{flatSize},
+        llvm::ArrayRef<int64_t>{1});
+
+    // Create indices for the scatter
+    auto offsets = initTileOp.getMixedOffsets();
+    auto loc = initTileOp.getLoc();
+    auto offsetX = imex::getValueOrConstantOp(offsets[0], loc, rewriter,
+                                              rewriter.getIndexType());
+    auto offsetY = imex::getValueOrConstantOp(offsets[1], loc, rewriter,
+                                              rewriter.getIndexType());
+    auto indexVecTy =
+        mlir::VectorType::get(tileTy.getShape(), rewriter.getIndexType());
+    bool isSingleCol = tileTy.getShape().back() == 1;
+    bool isSingleRow = tileTy.getShape().front() == 1;
+    auto rowIndexVecTy = mlir::VectorType::get({tileTy.getShape().front()},
+                                               rewriter.getIndexType());
+    auto colIndexVecTy = mlir::VectorType::get({tileTy.getShape().back()},
+                                               rewriter.getIndexType());
+
+    // Create [0, ..., TileShape[1]-1] broadcast to TileShape.
+    // If isSingleCol, splat offsetY to TileShape.
+    mlir::Value stepOffsetTile;
+    if (isSingleCol) {
+      stepOffsetTile = rewriter.createOrFold<mlir::vector::SplatOp>(
+          loc, indexVecTy, offsetY);
+    } else {
+      auto stepVec =
+          rewriter.createOrFold<mlir::vector::StepOp>(loc, colIndexVecTy);
+      auto stepTile = rewriter.createOrFold<mlir::vector::BroadcastOp>(
+          loc, indexVecTy, stepVec);
+      auto offsetYTile = rewriter.createOrFold<mlir::vector::SplatOp>(
+          loc, indexVecTy, offsetY);
+      // Add offsetY to the step vector
+      stepOffsetTile = rewriter.createOrFold<mlir::arith::AddIOp>(
+          loc, indexVecTy, stepTile, offsetYTile);
+    }
+
+    // Create [0, 1, 2, ..., TileShape[0]-1]^T broadcast to TileShape.
+    // If isSingleRow, splat offsetX to TileShape.
+    mlir::Value rowOffsetTile;
+    if (isSingleRow) {
+      rowOffsetTile = rewriter.createOrFold<mlir::vector::SplatOp>(
+          loc, indexVecTy, offsetX);
+    } else {
+      auto rowVecT =
+          rewriter.createOrFold<mlir::vector::StepOp>(loc, rowIndexVecTy);
+      auto offsetXVec = rewriter.createOrFold<mlir::vector::SplatOp>(
+          loc,
+          mlir::VectorType::get({tileTy.getShape().front()},
+                                rewriter.getIndexType()),
+          offsetX);
+      // Add offsetX to rowVecT
+      auto rowOffsetVecT = rewriter.createOrFold<mlir::arith::AddIOp>(
+          loc, rowIndexVecTy, rowVecT, offsetXVec);
+      // Reshape to row x 1
+      auto rowOffsetVec = rewriter.createOrFold<mlir::vector::ShapeCastOp>(
+          loc,
+          mlir::VectorType::get({tileTy.getShape().front(), 1},
+                                rewriter.getIndexType()),
+          rowOffsetVecT);
+      // Broadcast to TileShape
+      rowOffsetTile = rewriter.createOrFold<mlir::vector::BroadcastOp>(
+          loc, indexVecTy, rowOffsetVec);
+    }
+
+    // Create [pitchNumElems] splatted to TileShape
+    auto stride = rewriter.createOrFold<mlir::arith::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(pitchNumElems));
+    auto strideTile =
+        rewriter.createOrFold<mlir::vector::SplatOp>(loc, indexVecTy, stride);
+    // Create a temporary holding rowOffsetTile * strideTile
+    auto rowStrideTile = rewriter.createOrFold<mlir::arith::MulIOp>(
+        loc, indexVecTy, rowOffsetTile, strideTile);
+    // Create the complete scatter indices: row * stride + step
+    auto indices = mlir::dyn_cast_or_null<mlir::TypedValue<mlir::VectorType>>(
+        rewriter.createOrFold<mlir::arith::AddIOp>(
+            loc, indexVecTy, rowStrideTile, stepOffsetTile));
+    if (!indices) {
+      return rewriter.notifyMatchFailure(initTileOp,
+                                         "Could not generate scatter indices.");
+    }
+    // Add the scatter attribute to the tile type
+    auto scatterTileTy = addScatterAttr(tileTy);
+    // Replace the InitTileOp
+    rewriter.replaceOpWithNewOp<imex::xetile::InitTileOp>(
+        initTileOp, scatterTileTy, flatMemref, indices);
+
+    return mlir::success();
+  }
+};
+
+struct LoadTileOpPattern final
+    : public mlir::OpRewritePattern<imex::xetile::LoadTileOp> {
+  LoadTileOpPattern(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(imex::xetile::LoadTileOp loadTileOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto tile = loadTileOp.getSource();
+    auto tileTy = tile.getType();
+    if (!tileTy.getScatterAttr()) {
+      return mlir::failure();
+    }
+    auto one = rewriter.createOrFold<mlir::arith::ConstantOp>(
+        loadTileOp.getLoc(), rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto mask = rewriter.createOrFold<mlir::vector::SplatOp>(
+        loadTileOp.getLoc(),
+        mlir::VectorType::get(tileTy.getShape(), rewriter.getI1Type()), one);
+    rewriter.replaceOpWithNewOp<imex::xetile::LoadGatherOp>(
+        loadTileOp, loadTileOp.getType(), tile, mask,
+        loadTileOp.getPaddingAttr(), loadTileOp.getL1HintAttr(),
+        loadTileOp.getL2HintAttr(), loadTileOp.getL3HintAttr());
+    return mlir::success();
+  }
+};
+
+struct StoreTileOpPattern final
+    : public mlir::OpRewritePattern<imex::xetile::StoreTileOp> {
+  StoreTileOpPattern(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(imex::xetile::StoreTileOp storeTileOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto tile = storeTileOp.getTile();
+    auto tileTy = tile.getType();
+    if (!tileTy.getScatterAttr()) {
+      return mlir::failure();
+    }
+    auto one = rewriter.createOrFold<mlir::arith::ConstantOp>(
+        storeTileOp.getLoc(), rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto mask = rewriter.createOrFold<mlir::vector::SplatOp>(
+        storeTileOp.getLoc(),
+        mlir::VectorType::get(tileTy.getShape(), rewriter.getI1Type()), one);
+    rewriter.replaceOpWithNewOp<imex::xetile::StoreScatterOp>(
+        storeTileOp, storeTileOp.getValue(), tile, mask,
+        storeTileOp.getL1HintAttr(), storeTileOp.getL2HintAttr(),
+        storeTileOp.getL3HintAttr());
+    return mlir::success();
+  }
+};
+
+static imex::xetile::InitTileOp
+getInitTileOp(mlir::TypedValue<imex::xetile::TileType> tile) {
+  // Three possible sources of a tile value
+  auto dop = tile.getDefiningOp();
+  // 1. BlockArgument of an scf.for
+  if (!dop) {
+    if (!mlir::isa<mlir::BlockArgument>(tile)) {
+      return nullptr;
+    }
+    auto blockArg = llvm::cast<mlir::BlockArgument>(tile);
+    auto blockArgNum = blockArg.getArgNumber();
+    auto parentOp = blockArg.getOwner()->getParentOp();
+    if (!mlir::isa<mlir::scf::ForOp>(parentOp)) {
+      return nullptr;
+    }
+    auto scfForOp = mlir::dyn_cast<mlir::scf::ForOp>(parentOp);
+    auto numInductionVars = scfForOp.getNumInductionVars();
+    auto init = scfForOp.getInits()[blockArgNum - numInductionVars];
+    if (!mlir::isa<mlir::TypedValue<imex::xetile::TileType>>(init)) {
+      return nullptr;
+    }
+    return getInitTileOp(
+        mlir::dyn_cast<mlir::TypedValue<imex::xetile::TileType>>(init));
+  }
+  // 2. InitTileOp
+  if (mlir::isa<imex::xetile::InitTileOp>(dop)) {
+    return mlir::dyn_cast<imex::xetile::InitTileOp>(dop);
+  }
+  // 3. UpdateTileOffsetOp
+  else if (mlir::isa<imex::xetile::UpdateTileOffsetOp>(dop)) {
+    auto updateTileOffsetOp =
+        mlir::dyn_cast<imex::xetile::UpdateTileOffsetOp>(dop);
+    return getInitTileOp(updateTileOffsetOp.getTile());
+  }
+  return nullptr;
+}
+
+struct UpdateTileOffsetOpPattern final
+    : public mlir::OpRewritePattern<imex::xetile::UpdateTileOffsetOp> {
+  UpdateTileOffsetOpPattern(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(imex::xetile::UpdateTileOffsetOp updateTileOffsetOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto tile = updateTileOffsetOp.getTile();
+    auto tileTy = tile.getType();
+    if (!tileTy.getScatterAttr()) {
+      return mlir::failure();
+    }
+    // Return if indices are already set
+    if (updateTileOffsetOp.getIndices()) {
+      return mlir::failure();
+    }
+    auto initTileOp = getInitTileOp(tile);
+    if (!initTileOp) {
+      return rewriter.notifyMatchFailure(updateTileOffsetOp,
+                                         "Could not find InitTileOp.");
+    }
+    auto srcMemref = initTileOp.getSource();
+    auto castOp = srcMemref.getDefiningOp();
+    if (!castOp || !mlir::isa<mlir::memref::ReinterpretCastOp>(castOp)) {
+      return rewriter.notifyMatchFailure(updateTileOffsetOp,
+                                         "Source is not a flat memref.");
+    }
+    auto reinterOp = mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(castOp);
+    auto baseMemref = reinterOp.getSource();
+    if (!mlir::isa<mlir::MemRefType>(baseMemref.getType())) {
+      return rewriter.notifyMatchFailure(updateTileOffsetOp,
+                                         "Source is not a ranked memref.");
+    }
+    auto baseMemrefTy = mlir::dyn_cast<mlir::MemRefType>(baseMemref.getType());
+    auto baseShape = baseMemrefTy.getShape();
+    auto pitchNumElems = baseShape[baseShape.size() - 1];
+    auto loc = updateTileOffsetOp.getLoc();
+    // Create the updated indices with a vector.splat of (offX * pitch + offY)
+    auto pitch = rewriter.create<mlir::arith::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(pitchNumElems));
+    auto offX = updateTileOffsetOp.getOffsetX();
+    auto stride = rewriter.createOrFold<mlir::arith::MulIOp>(
+        loc, rewriter.getIndexType(), offX, pitch);
+    auto offY = updateTileOffsetOp.getOffsetY();
+    auto index = rewriter.createOrFold<mlir::arith::AddIOp>(
+        loc, rewriter.getIndexType(), stride, offY);
+    auto indices = rewriter.createOrFold<mlir::vector::SplatOp>(
+        loc, mlir::VectorType::get(tileTy.getShape(), rewriter.getIndexType()),
+        index);
+    rewriter.replaceOpWithNewOp<imex::xetile::UpdateTileOffsetOp>(
+        updateTileOffsetOp, tile, nullptr, nullptr, indices);
+    return mlir::success();
+  }
+};
+
+struct SCFForOpPattern final
+    : public mlir::OpRewritePattern<mlir::scf::ForOp> {
+  SCFForOpPattern(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::scf::ForOp scfForOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto initArgs = scfForOp.getInitArgs();
+    auto regionIterArgs = scfForOp.getRegionIterArgs();
+    auto results = scfForOp.getResults();
+    bool isUpdated = false;
+    for (auto [init, arg, res] :
+         llvm::zip_equal(initArgs, regionIterArgs, results)) {
+      auto initTy = init.getType();
+      if (mlir::isa<imex::xetile::TileType>(initTy)) {
+        if (!mlir::isa<imex::xetile::TileType>(arg.getType()) ||
+            !mlir::isa<imex::xetile::TileType>(res.getType())) {
+          return rewriter.notifyMatchFailure(scfForOp, "TileType mismatch.");
+        }
+        auto initTileTy = mlir::dyn_cast<imex::xetile::TileType>(initTy);
+        if (initTileTy.getScatterAttr()) {
+          auto argTileTy =
+              mlir::dyn_cast<imex::xetile::TileType>(arg.getType());
+          if (argTileTy.getScatterAttr()) {
+            continue;
+          }
+          auto scatterTileTy = addScatterAttr(argTileTy);
+          arg.setType(scatterTileTy);
+          res.setType(scatterTileTy);
+          isUpdated = true;
+        }
+      }
+    }
+    if (!isUpdated) {
+      return mlir::failure();
+    }
+    return mlir::success();
+  }
+};
+
+struct XeTileBlockOpFallbackPass final
+    : public imex::impl::XeTileBlockOpFallbackBase<XeTileBlockOpFallbackPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    mlir::Operation *op = getOperation();
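+
+    // Patterns are applied top-down, and strict mode ExistingAndNewOps makes
+    // the driver also visit ops created during rewriting (e.g. the scattered
+    // init_tile), so ops introduced by one pattern can be matched by the
+    // others.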
+    mlir::RewritePatternSet patterns(context);
+    mlir::GreedyRewriteConfig config;
+    config.enableRegionSimplification =
+        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.useTopDownTraversal = true;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
+    patterns.add<InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
+                 UpdateTileOffsetOpPattern, SCFForOpPattern>(context);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace blockopfallback
+
+namespace imex {
+std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass() {
+  return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>();
+}
+} // namespace imex

diff --git a/lib/Dialect/XeTile/Transforms/CMakeLists.txt b/lib/Dialect/XeTile/Transforms/CMakeLists.txt
index 5d74fdda2..fb20cb9fd 100644
--- a/lib/Dialect/XeTile/Transforms/CMakeLists.txt
+++ b/lib/Dialect/XeTile/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_imex_dialect_library(IMEXXeTileTransforms
   Blocking.cpp
   BlockingAnalysis.cpp
+  BlockOpFallback.cpp
   InitDuplicate.cpp
   Canonicalization.cpp
   WgToSg.cpp

diff --git a/lib/Transforms/RemoveSingleElemVector.cpp b/lib/Transforms/RemoveSingleElemVector.cpp
index 29ef2475e..b39495828 100644
--- a/lib/Transforms/RemoveSingleElemVector.cpp
+++ b/lib/Transforms/RemoveSingleElemVector.cpp
@@ -253,7 +253,9 @@ struct RemoveSingleElemVectorPass final
     });

     mlir::RewritePatternSet patterns(context);
-    patterns.add(
         typeConverter, context);

diff --git a/test/Dialect/XeTile/Transforms/block_op_fallback.mlir b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir
new file mode 100644
index 000000000..bcf4d76a2
--- /dev/null
+++ b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir
@@ -0,0 +1,391 @@
+// RUN: imex-opt --split-input-file --xetile-blockop-fallback %s -verify-diagnostics -o -| FileCheck %s
+
+gpu.module @test_module {
+  // CHECK-LABEL: @test_pitch_one_elems_and_offset_attr
+  gpu.func @test_pitch_one_elems_and_offset_attr(%arg0: memref<512x1xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1>
+    // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32>
+    // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex>
+    // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex>
+    // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr>
+    %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr>
+    // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32>
+    %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32>
+    gpu.return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: @test_pitch_one_elems_and_offset_vars
+  gpu.func @test_pitch_one_elems_and_offset_vars(%arg0: memref<512x1xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1>
+    // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32>
+    // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex>
+    // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex>
+    // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr>
+    %cst0 = arith.constant 0 : index
+    %0 = xetile.init_tile %arg0 [%cst0, %cst0] : memref<512x1xf32> ->
!xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_one_elems_and_mixed_offsets + gpu.func @test_pitch_one_elems_and_mixed_offsets(%arg0: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %0 = xetile.init_tile %arg0 [%cst0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_two_elems + gpu.func @test_pitch_two_elems(%arg0: memref<512x2xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CST0:.*]] = arith.constant dense<2> : vector<32x1xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense<16> : vector<32xindex> + // CHECK: %[[CST2:.*]] = arith.constant dense<1> : vector<32x1xindex> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1] : memref<512x2xf32> to memref<1024xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = arith.addi %[[VAR0]], %[[CST1]] : vector<32xindex> + // CHECK: %[[VAR2:.*]] = vector.shape_cast %[[VAR1]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR2]], %[[CST0]] : vector<32x1xindex> + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR3]], %[[CST2]] : vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR4]] : memref<1024xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %0 = xetile.init_tile %arg0 [16, 1] : memref<512x2xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR6:.*]] = xetile.load %[[VAR5]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_store_tile + gpu.func @test_store_tile(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> 
!xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[CAST0:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR3:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR4:.*]] = vector.shape_cast %[[VAR3:.*]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST0]], %[[VAR4]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR6:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %2 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR6]], %[[VAR5]], %[[CST]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %2, %1 : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_update_tile_offset + gpu.func @test_update_tile_offset(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: %[[VAR4:.*]] = xetile.update_tile_offset %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %2 = xetile.update_tile_offset %0, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_multiple_update_tile_offset + gpu.func @test_multiple_update_tile_offset(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<16> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %cst16 = arith.constant 16 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: 
%[[VAR6:.*]] = xetile.load %[[VAR2]], %[[CST1]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: %[[VAR7:.*]] = xetile.update_tile_offset %[[VAR2]], %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %2 = xetile.update_tile_offset %0, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR8:.*]] = xetile.update_tile_offset %[[VAR7]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %3 = xetile.update_tile_offset %2, [%cst16, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_scf_for + gpu.func @test_scf_for(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C480:.*]] = arith.constant 480 : index + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR3:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR3]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %cst512 = arith.constant 512 : index + %cst480 = arith.constant 480 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR6:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR9:.*]] = vector.shape_cast %[[VAR6]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR11:.*]] = xetile.init_tile %[[CAST1]], %[[VAR9]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR12:.*]]:2 = scf.for %arg2 = %[[C0]] to %[[C480]] step %[[C32]] iter_args(%arg3 = %[[VAR5]], %arg4 = %[[VAR11]]) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %out:2 = scf.for %k = %cst0 to %cst480 step %cst32 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + // CHECK: %[[VAR14:.*]] = xetile.load %arg3, %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %a_value = xetile.load_tile %a_tile : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR14]], %arg4, %[[CST0]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %a_value, %b_tile : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR15:.*]] = xetile.update_tile_offset %arg3, %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR16:.*]] = 
xetile.update_tile_offset %arg4, %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: scf.yield %[[VAR15]], %[[VAR16]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr> + } + // CHECK: %[[VAR13:.*]] = xetile.load %[[VAR12]]#0, %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %2 = xetile.load_tile %out#0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR13]], %[[VAR12]]#1, %[[CST0]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %2, %out#1 : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +module attributes {gpu.container_module} { + gpu.module @test_module { + // CHECK-LABEL: func @test_nested_scf_for + gpu.func @test_nested_scf_for(%arg0: memref<16384x1xf32>, %arg1: memref<16384x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CST1:.*]] = arith.constant dense<128> : vector<32x1xindex> + // CHECK: %[[CST2:.*]] = arith.constant dense<512> : vector<32x1xindex> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C512:.*]] = arith.constant 512 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C16384:.*]] = arith.constant 16384 : index + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16384], strides: [1] : memref<16384x1xf32> to memref<16384xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR3:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR3]] : memref<16384xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c16384 = arith.constant 16384 : index + %1 = xetile.init_tile %arg0[%c0, %c0] : memref<16384x1xf32> -> !xetile.tile<32x1xf32> + // CHECK: %[[CAST3:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16384], strides: [1] : memref<16384x1xf32> to memref<16384xf32> + // CHECK: %[[VAR6:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR9:.*]] = vector.shape_cast %[[VAR6]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR11:.*]] = xetile.init_tile %[[CAST3]], %[[VAR9]] : memref<16384xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %2 = xetile.init_tile %arg1[%c0, %c0] : memref<16384x1xf32> -> !xetile.tile<32x1xf32> + // CHECK: %[[VAR12:.*]]:2 = scf.for %arg2 = %[[C0]] to %[[C16384]] step %[[C512]] iter_args(%arg3 = %[[VAR5]], %arg4 = %[[VAR11]]) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %3:2 = scf.for %arg3 = %c0 to %c16384 step %c512 iter_args(%arg4 = %1, %arg5 = %2) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %4 = xetile.update_tile_offset %arg4, [%c512, %c0] : !xetile.tile<32x1xf32> + %5 = xetile.update_tile_offset %arg5, [%c512, %c0] : !xetile.tile<32x1xf32> + // CHECK: %[[VAR15:.*]]:2 
= scf.for %arg5 = %[[C0]] to %[[C512]] step %[[C128]] iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %6:2 = scf.for %arg6 = %c0 to %c512 step %c128 iter_args(%arg7 = %arg4, %arg8 = %arg5) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %7 = xetile.update_tile_offset %arg7, [%c128, %c0] : !xetile.tile<32x1xf32> + %8 = xetile.update_tile_offset %arg8, [%c128, %c0] : !xetile.tile<32x1xf32> + // CHECK: %[[VAR18:.*]]:2 = scf.for %arg8 = %[[C0]] to %[[C128]] step %[[C32]] iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %9:2 = scf.for %arg9 = %c0 to %c128 step %c32 iter_args(%arg10 = %arg7, %arg11 = %arg8) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %10 = xetile.load_tile %arg10 : !xetile.tile<32x1xf32> -> vector<32x1xf32> + xetile.store_tile %10, %arg11 : vector<32x1xf32>, !xetile.tile<32x1xf32> + %11 = xetile.update_tile_offset %arg10, [%c32, %c0] : !xetile.tile<32x1xf32> + %12 = xetile.update_tile_offset %arg11, [%c32, %c0] : !xetile.tile<32x1xf32> + scf.yield %11, %12 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + scf.yield %7, %8 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + scf.yield %4, %5 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + gpu.return + } + } +} + +// ----- + +module attributes {gpu.container_module} { + func.func @postop_reduce_m_entry(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[8, 2, 4, 8]> : vector<4xi64>, gemm_tiles_y = dense<[1, 2, 8, 4]> : vector<4xi64>, physical_nd_range = dense<[8, 32]> : vector<2xi64>, region_partition = 0 : i64, region_size = 32 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<16384x12288xbf16>, tensor<2048x12288xbf16>) -> tensor<32x2048xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1003595802.6 : f64} { + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + gpu.launch_func @postop_reduce_m::@postop_reduce_m blocks in (%c8, %c32, %c1) threads in (%c8, %c4, %c1) args(%arg0 : memref<16384x12288xbf16>, %arg1 : memref<2048x12288xbf16>, %arg2 : memref<32x2048xf32>) + return + } + gpu.module @postop_reduce_m attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + // CHECK-LABEL: func @postop_reduce_m + gpu.func @postop_reduce_m(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<8x4xi1> + // CHECK: %[[CST0:.*]] = arith.constant dense<128> : vector<8x4xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense : vector<1x32xi1> + // CHECK: %[[CST2:.*]] = arith.constant dense<128> : vector<1x32xindex> + %c12288 = arith.constant 12288 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 
dense<0.000000e+00> : vector<32x32xf32> + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x32xf32> + %cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32> + %cst_2 = arith.constant dense<0.000000e+00> : vector<1x4xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.divsi %block_id_y, %c8 : index + %1 = arith.remsi %block_id_y, %c8 : index + %2 = arith.muli %block_id_x, %c4 : index + %3 = arith.addi %2, %0 : index + %4 = arith.muli %1, %c128 : index + %5 = gpu.subgroup_id : index + %6 = index.floordivs %5, %c32 + %7 = index.remu %5, %c32 + %8 = index.remu %6, %c1 + %9 = index.add %3, %8 + %10 = index.remu %7, %c32 + %11 = index.mul %10, %c4 + %12 = index.add %4, %11 + %13 = xetile.init_tile %arg2[%9, %12] : memref<32x2048xf32> -> !xetile.tile<1x4xf32> + %14 = arith.muli %block_id_x, %c2048 : index + %15 = arith.muli %0, %c256 : index + %16 = arith.addi %14, %15 : index + %17 = index.floordivs %5, %c4 + %18 = index.remu %5, %c4 + %19 = index.remu %17, %c8 + %20 = index.mul %19, %c32 + %21 = index.add %16, %20 + %22 = index.remu %18, %c1 + %23 = index.mul %22, %c32 + %24 = xetile.init_tile %arg0[%21, %23] : memref<16384x12288xbf16> -> !xetile.tile<32x32xbf16> + %25 = index.floordivs %5, %c8 + %26 = index.remu %5, %c8 + %27 = index.remu %26, %c4 + %28 = index.mul %27, %c32 + %29 = index.add %4, %28 + %30 = index.remu %25, %c1 + %31 = index.mul %30, %c32 + %32 = xetile.init_tile %arg1[%29, %31] : memref<2048x12288xbf16> -> !xetile.tile<32x32xbf16> + %33:2 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %13, %arg5 = %32) -> (!xetile.tile<1x4xf32>, !xetile.tile<32x32xbf16>) { + %34 = xetile.update_tile_offset %arg5, [%c1024, %c0] : !xetile.tile<32x32xbf16> + %35 = xetile.update_tile_offset %arg4, [%c0, %c1024] : !xetile.tile<1x4xf32> + %36:2 = scf.for %arg6 = %c0 to %c2 step %c1 iter_args(%arg7 = %cst_2, %arg8 = %24) -> (vector<1x4xf32>, !xetile.tile<32x32xbf16>) { + %37 = xetile.update_tile_offset %arg8, [%c1024, %c0] : !xetile.tile<32x32xbf16> + %38:3 = scf.for %arg9 = %c0 to %c12288 step %c32 iter_args(%arg10 = %cst, %arg11 = %arg8, %arg12 = %arg5) -> (vector<32x32xf32>, !xetile.tile<32x32xbf16>, !xetile.tile<32x32xbf16>) { + %56 = xetile.update_tile_offset %arg12, [%c0, %c32] : !xetile.tile<32x32xbf16> + %57 = xetile.update_tile_offset %arg11, [%c0, %c32] : !xetile.tile<32x32xbf16> + %58 = xetile.load_tile %arg11 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %59 = math.exp %58 : vector<32x32xbf16> + %60 = xetile.load_tile %arg12 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %61 = xetile.transpose %60, [1, 0] : vector<32x32xbf16> -> vector<32x32xbf16> + xegpu.compile_hint + %62 = xetile.tile_mma %59, %61, %arg10 : vector<32x32xbf16>, vector<32x32xbf16>, vector<32x32xf32> -> vector<32x32xf32> + xegpu.compile_hint + scf.yield %62, %57, %56 : vector<32x32xf32>, !xetile.tile<32x32xbf16>, !xetile.tile<32x32xbf16> + } + %39 = math.exp %38#0 : vector<32x32xf32> + %40 = vector.shape_cast %cst_0 : vector<1x32xf32> to vector<32xf32> + %41 = xetile.reduction , %39 [0] : vector<32x32xf32> -> vector<1x32xf32> + %42 = vector.shape_cast %41 : vector<1x32xf32> to vector<32xf32> + %43 = arith.addf %42, %40 : vector<32xf32> + %44 = vector.shape_cast %43 : vector<32xf32> to vector<1x32xf32> + %alloc = memref.alloc() : memref<4096xi8, 3> + %view = memref.view %alloc[%c0][] : memref<4096xi8, 3> to memref<8x128xf32, 3> + %45 = index.mul %18, %c32 + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[VIEW:.*]] to offset: [0], sizes: [1024], strides: [1] : 
memref<8x128xf32, 3> to memref<1024xf32, 3> + // CHECK: %[[VAR48:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR49:.*]] = vector.broadcast %[[VAR48]] : vector<32xindex> to vector<1x32xindex> + // CHECK: %[[VAR50:.*]] = vector.splat %[[VAR45:.*]] : vector<1x32xindex> + // CHECK: %[[VAR52:.*]] = arith.addi %[[VAR49]], %[[VAR50]] : vector<1x32xindex> + // CHECK: %[[VAR54:.*]] = vector.splat %[[VAR17:.*]] : vector<1x32xindex> + // CHECK: %[[VAR55:.*]] = arith.muli %[[VAR54]], %[[CST2]] : vector<1x32xindex> + // CHECK: %[[VAR56:.*]] = arith.addi %[[VAR55]], %[[VAR52]] : vector<1x32xindex> + // CHECK: %[[VAR57:.*]] = xetile.init_tile %[[CAST]], %[[VAR56]] : memref<1024xf32, 3>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %46 = xetile.init_tile %view[%17, %45] : memref<8x128xf32, 3> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + // CHECK: xetile.store %[[VAR44:.*]], %[[VAR57]], %[[CST1]] : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + xetile.store_tile %44, %46 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> + gpu.barrier + // CHECK: %[[VAR58:.*]] = index.mul + // CHECK: %[[VAR59:.*]] = index.mul + %47 = index.mul %6, %c8 + %48 = index.mul %7, %c4 + // CHECK: %[[CAST7:.*]] = memref.reinterpret_cast %[[VIEW]] to offset: [0], sizes: [1024], strides: [1] : memref<8x128xf32, 3> to memref<1024xf32, 3> + // CHECK: %[[VAR62:.*]] = vector.step : vector<4xindex> + // CHECK: %[[VAR63:.*]] = vector.broadcast %[[VAR62]] : vector<4xindex> to vector<8x4xindex> + // CHECK: %[[VAR61:.*]] = vector.splat %[[VAR59]] : vector<8x4xindex> + // CHECK: %[[VAR64:.*]] = arith.addi %[[VAR63]], %[[VAR61]] : vector<8x4xindex> + // CHECK: %[[VAR65:.*]] = vector.step : vector<8xindex> + // CHECK: %[[VAR60:.*]] = vector.splat %[[VAR58]] : vector<8xindex> + // CHECK: %[[VAR66:.*]] = arith.addi %[[VAR65]], %[[VAR60]] : vector<8xindex> + // CHECK: %[[VAR67:.*]] = vector.shape_cast %[[VAR66]] : vector<8xindex> to vector<8x1xindex> + // CHECK: %[[VAR68:.*]] = vector.broadcast %[[VAR67]] : vector<8x1xindex> to vector<8x4xindex> + // CHECK: %[[VAR69:.*]] = arith.muli %[[VAR68]], %[[CST0]] : vector<8x4xindex> + // CHECK: %[[VAR70:.*]] = arith.addi %[[VAR69]], %[[VAR64]] : vector<8x4xindex> + // CHECK: %[[VAR71:.*]] = xetile.init_tile %[[CAST7]], %[[VAR70]] : memref<1024xf32, 3>, vector<8x4xindex> -> !xetile.tile<8x4xf32, #xetile.tile_attr> + %49 = xetile.init_tile %view[%47, %48] : memref<8x128xf32, 3> -> !xetile.tile<8x4xf32, #xetile.tile_attr> + // CHECK: %[[VAR72:.*]] = xetile.load %[[VAR71]], %[[CST]] : !xetile.tile<8x4xf32, #xetile.tile_attr>, vector<8x4xi1> -> vector<8x4xf32> + %50 = xetile.load_tile %49 : !xetile.tile<8x4xf32, #xetile.tile_attr> -> vector<8x4xf32> + %51 = xetile.reduction , %50 [0] : vector<8x4xf32> -> vector<1x4xf32> + %52 = vector.shape_cast %51 : vector<1x4xf32> to vector<4xf32> + %53 = arith.addf %52, %cst_1 : vector<4xf32> + %54 = vector.shape_cast %53 : vector<4xf32> to vector<1x4xf32> + %55 = arith.addf %54, %arg7 : vector<1x4xf32> + scf.yield %55, %37 : vector<1x4xf32>, !xetile.tile<32x32xbf16> + } + xetile.store_tile %36#0, %arg4 : vector<1x4xf32>, !xetile.tile<1x4xf32> + scf.yield %35, %34 : !xetile.tile<1x4xf32>, !xetile.tile<32x32xbf16> + } + gpu.return + } + } +} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir new file mode 100644 index 000000000..445e96a82 --- /dev/null +++ 
b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir @@ -0,0 +1,64 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<64x1xf32>) -> memref<64x1xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<64x1xf32> + memref.copy %A, %A_gpu : memref<64x1xf32> to memref<64x1xf32> + %B_gpu = gpu.alloc host_shared() : memref<64x1xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x1xf32>, %B_gpu : memref<64x1xf32>) + %B = memref.alloc() : memref<64x1xf32> + memref.copy %B_gpu, %B : memref<64x1xf32> to memref<64x1xf32> + gpu.dealloc %A_gpu : memref<64x1xf32> + gpu.dealloc %B_gpu : memref<64x1xf32> + return %B : memref<64x1xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<64x1xf32>, %arg1: memref<64x1xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst16 = arith.constant 16 : index + %cst64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> + %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr>) { + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x1xf32, #xetile.tile_attr> -> vector<16x1xf32> + xetile.store_tile %a_value, %b_tile : vector<16x1xf32>, !xetile.tile<16x1xf32, #xetile.tile_attr> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %A = memref.alloc() : memref<64x1xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + %0 = index.castu %arg0 : index to i32 + %val = arith.uitofp %0 : i32 to f32 + memref.store %val, %A[%arg0, %c0] : memref<64x1xf32> + } + %C = call @test(%A) : (memref<64x1xf32>) -> memref<64x1xf32> + %cast_A = memref.cast %A : memref<64x1xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<64x1xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, 
memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir new file mode 100644 index 000000000..bf85261cc --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir @@ -0,0 +1,65 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<64x2xf32>) -> memref<64x2xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<64x2xf32> + memref.copy %A, %A_gpu : memref<64x2xf32> to memref<64x2xf32> + %B_gpu = gpu.alloc host_shared() : memref<64x2xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x2xf32>, %B_gpu : memref<64x2xf32>) + %B = memref.alloc() : memref<64x2xf32> + memref.copy %B_gpu, %B : memref<64x2xf32> to memref<64x2xf32> + gpu.dealloc %A_gpu : memref<64x2xf32> + gpu.dealloc %B_gpu : memref<64x2xf32> + return %B : memref<64x2xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<64x2xf32>, %arg1: memref<64x2xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst16 = arith.constant 16 : index + %cst64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> + %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr>) { + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x2xf32, #xetile.tile_attr> -> vector<16x2xf32> + xetile.store_tile %a_value, %b_tile : vector<16x2xf32>, !xetile.tile<16x2xf32, #xetile.tile_attr> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %c64 = arith.constant 64 : index + + %A = memref.alloc() : memref<64x2xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + %0 = index.castu %arg0 : index to i32 + %val = arith.uitofp %0 : i32 to f32 + memref.store %val, %A[%arg0, %c0] : memref<64x2xf32> + memref.store %val, %A[%arg0, %c1] : memref<64x2xf32> + } + %C = call @test(%A) : (memref<64x2xf32>) -> memref<64x2xf32> + %cast_A = memref.cast %A : memref<64x2xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<64x2xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/slm.mlir b/test/Integration/Dialect/XeTile/fallback/slm.mlir new file mode 100644 index 000000000..4bc81a351 --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/slm.mlir @@ -0,0 +1,80 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<32x32xf32> + memref.copy %A, %A_gpu : memref<32x32xf32> to memref<32x32xf32> + %B_gpu = gpu.alloc host_shared() : memref<32x32xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<32x32xf32>, %B_gpu : memref<32x32xf32>) + %B = memref.alloc() : memref<32x32xf32> + memref.copy %B_gpu, %B : memref<32x32xf32> to memref<32x32xf32> + gpu.dealloc %A_gpu : memref<32x32xf32> + gpu.dealloc %B_gpu : memref<32x32xf32> + return %B : memref<32x32xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst8 = arith.constant 8 : index + %cst16 = arith.constant 16 : index + %cst32 = arith.constant 32 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %slm = memref.alloc() : memref<8x16xf32, 3> + %slm_tile = xetile.init_tile %slm [0, 0] : memref<8x16xf32, 3> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %out:2 = scf.for %j = %cst0 to %cst32 step %cst8 + iter_args(%a_tile = %0, %b_tile = %1) + -> 
(!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { + %out:2 = scf.for %k = %cst0 to %cst32 step %cst16 + iter_args(%c_tile = %a_tile, %d_tile = %b_tile) + -> (!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { + %c_value = xetile.load_tile %c_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> + xetile.store_tile %c_value, %slm_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %d_value = xetile.load_tile %slm_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> + xetile.store_tile %d_value, %d_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %c_next_tile = xetile.update_tile_offset %c_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> + %d_next_tile = xetile.update_tile_offset %d_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> + scf.yield %c_next_tile, %d_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + } + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %A = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %0 = index.castu %arg0 : index to i32 + %1 = index.castu %arg1 : index to i32 + %2 = arith.addi %0, %1 : i32 + %val = arith.uitofp %2 : i32 to f32 + memref.store %val, %A[%arg0, %arg1] : memref<32x32xf32> + } + } + %C = call @test(%A) : (memref<32x32xf32>) -> memref<32x32xf32> + %cast_A = memref.cast %A : memref<32x32xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<32x32xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp new file mode 100644 index 000000000..3899a933c --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp @@ -0,0 +1,41 @@ +builtin.module( + cse + gpu.module(xetile-init-duplicate + xetile-canonicalization + xetile-blockop-fallback + xetile-blocking + cse + convert-xetile-to-xegpu + cse + imex-xegpu-hoist-transpose + imex-xegpu-apply-vnni-transformation + imex-xegpu-optimize-transpose) + cse + imex-vector-linearize + cse + imex-remove-single-elem-vector + canonicalize + cse + gpu.module(convert-xegpu-to-vc) + reconcile-unrealized-casts + bf16-to-gpu + cse + imex-convert-gpu-to-spirv + spirv.module(spirv-lower-abi-attrs + spirv-update-vce) + func.func(llvm-request-c-wrappers) + serialize-spirv + convert-vector-to-scf + convert-gpu-to-gpux + convert-scf-to-cf + expand-strided-metadata + 
finalize-memref-to-llvm
+  convert-cf-to-llvm
+  convert-vector-to-llvm
+  convert-index-to-llvm
+  convert-arith-to-llvm
+  convert-func-to-llvm
+  convert-math-to-llvm
+  convert-gpux-to-llvm
+  lower-affine
+  reconcile-unrealized-casts)

diff --git a/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir b/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir
index 464eb9507..0e4bde3b5 100644
--- a/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir
+++ b/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir
@@ -61,7 +61,9 @@ module {
     %34 = arith.remsi %11, %c4 : index
     %35 = scf.for %arg3 = %c0 to %c3 step %c1 iter_args(%arg4 = %cst) -> (vector<8x1xf32>) {
       %39 = vector.shape_cast %arg4 : vector<8x1xf32> to vector<8xf32>
-      //CHECK-COUNT-8: vector.extractelement {{.*}} : vector<8xf32>
+      // Disabling the remove-single-elem-vector rewrite of vector.extract_strided_slice for now.
+      // DISABLE-CHECK-COUNT-8: vector.extractelement {{.*}} : vector<8xf32>
+      // CHECK-COUNT-8: vector.extract_strided_slice
       %40 = vector.extract_strided_slice %39 {offsets = [0], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32>
       %41 = vector.extract_strided_slice %39 {offsets = [1], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32>
       %42 = vector.extract_strided_slice %39 {offsets = [2], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32>