From c2f21e84e4e9781a2ae54f673682d5280aeb5acd Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Fri, 20 Dec 2024 20:22:08 +0000 Subject: [PATCH] Add XeTile block operation fallback pass Certain block operation legal at XeTile dialect level cannot be supported by matching XeGPU dialect op due to not meeting HW restriction. This PR adds a new pass that provide a fallback for some cases. Pass can be called with command line arg --xetile-blockop-fallback The cases covered are: Source of Tile is a static shaped row major memref but - pitch is not a multiple of 16 bytes or less than 64 bytes - or memory space indicates SLM memory For such fitting case, this pass turns - block tile to scatter tile - load_tile to load - store_tile to store - update_tile_offset to use tile shaped indices instead of X, Y offset - impacted scf.for arguments from block tile type to scatter tile type --- .../imex/Dialect/XeTile/Transforms/Passes.h | 1 + .../imex/Dialect/XeTile/Transforms/Passes.td | 17 + .../XeTile/Transforms/BlockOpFallback.cpp | 443 ++++++++++++++++++ lib/Dialect/XeTile/Transforms/CMakeLists.txt | 1 + lib/Transforms/RemoveSingleElemVector.cpp | 4 +- .../XeTile/Transforms/block_op_fallback.mlir | 391 ++++++++++++++++ .../fallback/narrow_tile_one_elem_wide.mlir | 64 +++ .../fallback/narrow_tile_two_elem_wide.mlir | 65 +++ .../Dialect/XeTile/fallback/slm.mlir | 80 ++++ .../fallback/xetile-fallback-to-func-vc.pp | 41 ++ .../postop_reduce_n.mlir | 4 +- 11 files changed, 1109 insertions(+), 2 deletions(-) create mode 100644 lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp create mode 100644 test/Dialect/XeTile/Transforms/block_op_fallback.mlir create mode 100644 test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir create mode 100644 test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir create mode 100644 test/Integration/Dialect/XeTile/fallback/slm.mlir create mode 100644 test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h index 869732f3b..c24b30184 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.h +++ b/include/imex/Dialect/XeTile/Transforms/Passes.h @@ -40,6 +40,7 @@ std::unique_ptr createXeTileBlockingPass(const std::string &device = "pvc"); std::unique_ptr createXeTileWgToSgPass(); std::unique_ptr createXeTileCanonicalizationPass(); +std::unique_ptr createXeTileBlockOpFallbackPass(); #define GEN_PASS_DECL_XETILEBLOCKING #define GEN_PASS_DECL_XETILECANONICALIZATION diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index 83c141718..6381c8887 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -96,4 +96,21 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{ } +def XeTileBlockOpFallback : Pass<"xetile-blockop-fallback", "::mlir::gpu::GPUModuleOp">{ + let summary = "Transform unsuitable block ops to fallback scattered ops"; + + let description = [{ + This transform pass transforms XeTile block ops that are not suitable due to HW restrictions, + to scattered XeTile ops. + }]; + + let constructor = "imex::createXeTileBlockOpFallbackPass()"; + let dependentDialects = ["imex::xetile::XeTileDialect", + "mlir::arith::ArithDialect", + "mlir::gpu::GPUDialect", + "mlir::index::IndexDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect"]; +} + #endif // _XeTile_PASSES_TD_INCLUDED_ diff --git a/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp new file mode 100644 index 000000000..a08b5fb43 --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp @@ -0,0 +1,443 @@ +//====-- BlockOpFallback.cpp - XeTile Block Op Fallback Pass ----*- C++-*-===// +// +// Copyright 2024 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This pass detects XeTile InitTile ops that does not meet HW restrictions +/// and rewrite into InitTile ops with scattered description. +/// This triggers change to tile type. Shape is same but scattered attribute +/// is added. As a result, ops requiring block description gets invalid. +/// Patterns that legalizes those ops are added and GreedyPatternRewriteDriver +/// is used to apply the patterns. +/// +//===----------------------------------------------------------------------===// + +#include "imex/Dialect/XeTile/IR/XeTileOps.h" +#include "imex/Dialect/XeTile/Transforms/Passes.h" +#include "imex/Utils/XeCommon.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace imex { +#define GEN_PASS_DEF_XETILEBLOCKOPFALLBACK +#include "imex/Dialect/XeTile/Transforms/Passes.h.inc" +} // namespace imex + +// +// Limitations and future plan: +// Currently limited to 2D static shaped source memref of row major order. +// Dynamic shape with known pitch value will be supported in the future. +// Scattered offset is calculated by generated code sequence but will be +// using constant immediate values for static shapes in the future. +// Current code sequence is not optimal for supporting blocking pass and +// SIMT lowering. This will be improved in the future. +// Correct mask generation requires getting absolute offsets from source +// memref. This is not supported in the current implementation and will be +// added in the future. For now, mask generation assume all accesses are +// in bounds +// + +namespace blockopfallback { + +static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) { + auto tileAttr = + mlir::dyn_cast_or_null(tileTy.getEncoding()); + int memorySpace = tileTy.getMemorySpaceAsInt(); + if (!tileAttr) { + std::vector orderVec = {1, 0}; + llvm::ArrayRef order(orderVec); + auto encoding = imex::xetile::XeTileAttr::get(tileTy.getContext(), order, + memorySpace, true); + return imex::xetile::TileType::get(tileTy.getShape(), + tileTy.getElementType(), encoding); + } + auto sgMap = tileAttr.getSgMap(); + auto wgMap = tileAttr.getWgMap(); + auto order = tileAttr.getOrder().asArrayRef(); + auto scatterTileAttr = imex::xetile::XeTileAttr::get( + tileTy.getContext(), sgMap, wgMap, order, memorySpace, true); + return imex::xetile::TileType::get(tileTy.getShape(), tileTy.getElementType(), + scatterTileAttr); +} + +struct InitTileOpPattern final + : public mlir::OpRewritePattern { + InitTileOpPattern(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(imex::xetile::InitTileOp initTileOp, + mlir::PatternRewriter &rewriter) const override { + auto tileTy = initTileOp.getType(); + // Skip if tile is scattered + if (tileTy.getScatterAttr()) { + return mlir::failure(); + } + // Skip 1D tile + if (tileTy.getRank() < 2) { + return mlir::failure(); + } + // Skip if tile is column major + if (tileTy.getOrder().asArrayRef() != llvm::ArrayRef{1, 0}) { + return mlir::failure(); + } + // Currenty only supports memref source + if (!initTileOp.isSourceMemRef()) { + return mlir::failure(); + } + // Cannot handle non static shape + if (!initTileOp.sourceMemRefHasStaticShape()) { + return mlir::failure(); + } + + auto srcShape = initTileOp.getSourceMemrefStaticShape(); + // Cannot handle non 2D source memref + if (srcShape.size() != 2) { + return mlir::failure(); + } + // Check if memspace is SLM + auto memorySpace = initTileOp.getSourceMemorySpaceAsInt(); + bool isSLM = memorySpace == 3; + // Check pitch >= 64bytes and pitch is multiple of 16bytes + bool isValidPitch = true; + auto pitchNumElems = srcShape[srcShape.size() - 1]; + auto elemBitwidth = + initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth(); + auto pitchNumBytes = pitchNumElems * elemBitwidth / 8; + isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0); + // If memspace is not SLM and pitch is valid, no need to rewrite + if (!isSLM && isValidPitch) { + return mlir::failure(); + } + // Get flat shape size + int64_t flatSize = 1; + for (auto dim : srcShape) { + flatSize *= dim; + } + + // reinterpret_cast to flat memref of flatSize + mlir::MemRefLayoutAttrInterface layout = {}; + // Is source offset always 0? No API to check. + auto flatMemref = rewriter.create( + initTileOp.getLoc(), + mlir::MemRefType::get({flatSize}, initTileOp.getSourceMemrefElemType(), + layout, initTileOp.getSourceMemorySpace()), + initTileOp.getSource(), 0, llvm::ArrayRef{flatSize}, + llvm::ArrayRef{1}); + + // Create indices for scatter + auto offsets = initTileOp.getMixedOffsets(); + auto loc = initTileOp.getLoc(); + auto offsetX = imex::getValueOrConstantOp(offsets[0], loc, rewriter, + rewriter.getIndexType()); + auto offsetY = imex::getValueOrConstantOp(offsets[1], loc, rewriter, + rewriter.getIndexType()); + auto indexVecTy = + mlir::VectorType::get(tileTy.getShape(), rewriter.getIndexType()); + bool isSingleCol = tileTy.getShape().back() == 1; + bool isSingleRow = tileTy.getShape().front() == 1; + auto rowIndexVecTy = mlir::VectorType::get({tileTy.getShape().front()}, + rewriter.getIndexType()); + auto colIndexVecTy = mlir::VectorType::get({tileTy.getShape().back()}, + rewriter.getIndexType()); + + // Create + // [0, ...., TileShape[1]-1] broadcasted to TileShape + // if isSingleCol, splat offsetY to TileShape + mlir::Value stepOffsetTile; + if (isSingleCol) { + stepOffsetTile = rewriter.createOrFold( + loc, indexVecTy, offsetY); + } else { + auto stepVec = + rewriter.createOrFold(loc, colIndexVecTy); + auto stepTile = rewriter.createOrFold( + loc, indexVecTy, stepVec); + auto offsetYTile = rewriter.createOrFold( + loc, indexVecTy, offsetY); + // Add offsetY to step + stepOffsetTile = rewriter.createOrFold( + loc, indexVecTy, stepTile, offsetYTile); + } + + // create [0, 1, 2, ...., TileShape[0]-1]^T broadcasted to TileShape + // if isSingleRow, splat offsetX to TileShape + mlir::Value rowOffsetTile; + if (isSingleRow) { + rowOffsetTile = rewriter.createOrFold( + loc, indexVecTy, offsetX); + } else { + auto rowVecT = + rewriter.createOrFold(loc, rowIndexVecTy); + auto offsetXVec = rewriter.createOrFold( + loc, + mlir::VectorType::get({tileTy.getShape().front()}, + rewriter.getIndexType()), + offsetX); + // Add offsetX to rowVecT + auto rowOffsetVecT = rewriter.createOrFold( + loc, rowIndexVecTy, rowVecT, offsetXVec); + // reshape to row x 1 + auto rowOffsetVec = rewriter.createOrFold( + loc, + mlir::VectorType::get({tileTy.getShape().front(), 1}, + rewriter.getIndexType()), + rowOffsetVecT); + // broadcast to TileShape + rowOffsetTile = rewriter.createOrFold( + loc, indexVecTy, rowOffsetVec); + } + + // create [pitchNumElems] splatted to TileShape + auto stride = rewriter.createOrFold( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(pitchNumElems)); + auto strideTile = + rewriter.createOrFold(loc, indexVecTy, stride); + // Create a temp with just rowTile * strideTile + auto rowStrideTile = rewriter.createOrFold( + loc, indexVecTy, rowOffsetTile, strideTile); + // Create scatter indices complete row*stride + step + auto indices = mlir::dyn_cast_or_null>( + rewriter.createOrFold( + loc, indexVecTy, rowStrideTile, stepOffsetTile)); + if (!indices) { + return rewriter.notifyMatchFailure(initTileOp, + "Could not generate scatter indices."); + } + // Add scatter attribute to tile type + auto scatterTileTy = addScatterAttr(tileTy); + // Replace InitTileOp + rewriter.replaceOpWithNewOp( + initTileOp, scatterTileTy, flatMemref, indices); + + return mlir::success(); + } +}; + +struct LoadTileOpPattern final + : public mlir::OpRewritePattern { + LoadTileOpPattern(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(imex::xetile::LoadTileOp loadTileOp, + mlir::PatternRewriter &rewriter) const override { + auto tile = loadTileOp.getSource(); + auto tileTy = tile.getType(); + if (!tileTy.getScatterAttr()) { + return mlir::failure(); + } + auto one = rewriter.createOrFold( + loadTileOp.getLoc(), rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto mask = rewriter.createOrFold( + loadTileOp.getLoc(), + mlir::VectorType::get(tileTy.getShape(), rewriter.getI1Type()), one); + rewriter.replaceOpWithNewOp( + loadTileOp, loadTileOp.getType(), tile, mask, + loadTileOp.getPaddingAttr(), loadTileOp.getL1HintAttr(), + loadTileOp.getL2HintAttr(), loadTileOp.getL3HintAttr()); + return mlir::success(); + } +}; + +struct StoreTileOpPattern final + : public mlir::OpRewritePattern { + StoreTileOpPattern(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(imex::xetile::StoreTileOp storeTileOp, + mlir::PatternRewriter &rewriter) const override { + auto tile = storeTileOp.getTile(); + auto tileTy = tile.getType(); + if (!tileTy.getScatterAttr()) { + return mlir::failure(); + } + auto one = rewriter.createOrFold( + storeTileOp.getLoc(), rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto mask = rewriter.createOrFold( + storeTileOp.getLoc(), + mlir::VectorType::get(tileTy.getShape(), rewriter.getI1Type()), one); + rewriter.replaceOpWithNewOp( + storeTileOp, storeTileOp.getValue(), tile, mask, + storeTileOp.getL1HintAttr(), storeTileOp.getL2HintAttr(), + storeTileOp.getL3HintAttr()); + return mlir::success(); + } +}; + +static imex::xetile::InitTileOp +getInitTileOp(mlir::TypedValue tile) { + // Three sources of tile value + auto dop = tile.getDefiningOp(); + // 1. BlockArgument of scf.for + if (!dop) { + if (!mlir::isa(tile)) { + return nullptr; + } + auto blockArg = llvm::cast(tile); + auto blockArgNum = blockArg.getArgNumber(); + auto parentOp = blockArg.getOwner()->getParentOp(); + if (!mlir::isa(parentOp)) { + return nullptr; + } + auto scfForOp = mlir::dyn_cast(parentOp); + auto numInductionVars = scfForOp.getNumInductionVars(); + auto init = scfForOp.getInits()[blockArgNum - numInductionVars]; + if (!mlir::isa>(init)) { + return nullptr; + } + return getInitTileOp( + mlir::dyn_cast>(init)); + } + // 2. InitTileOp + if (mlir::isa(dop)) { + return mlir::dyn_cast(dop); + } + // 3. UpdateTileOffsetOp + else if (mlir::isa(dop)) { + auto updateTileOffsetOp = + mlir::dyn_cast(dop); + return getInitTileOp(updateTileOffsetOp.getTile()); + } + return nullptr; +} + +struct UpdateTileOffsetOpPattern final + : public mlir::OpRewritePattern { + UpdateTileOffsetOpPattern(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(imex::xetile::UpdateTileOffsetOp updateTileOffsetOp, + mlir::PatternRewriter &rewriter) const override { + auto tile = updateTileOffsetOp.getTile(); + auto tileTy = tile.getType(); + if (!tileTy.getScatterAttr()) { + return mlir::failure(); + } + // Return if indices are already set + if (updateTileOffsetOp.getIndices()) { + return mlir::failure(); + } + auto initTileOp = getInitTileOp(tile); + if (!initTileOp) { + return rewriter.notifyMatchFailure(updateTileOffsetOp, + "Could not find InitTileOp."); + } + auto srcMemref = initTileOp.getSource(); + auto castOp = srcMemref.getDefiningOp(); + if (!castOp || !mlir::isa(castOp)) { + return rewriter.notifyMatchFailure(updateTileOffsetOp, + "Source is not flat memref."); + } + auto reinterOp = mlir::dyn_cast(castOp); + auto baseMemref = reinterOp.getSource(); + if (!mlir::isa(baseMemref.getType())) { + return rewriter.notifyMatchFailure(updateTileOffsetOp, + "Source is not a ranked memref."); + } + auto baseMemrefTy = mlir::dyn_cast(baseMemref.getType()); + auto baseShape = baseMemrefTy.getShape(); + auto pitchNumElems = baseShape[baseShape.size() - 1]; + auto loc = updateTileOffsetOp.getLoc(); + // Create update indices by doing vector.splat with (offX*stride + offY) + auto pitch = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(pitchNumElems)); + auto offX = updateTileOffsetOp.getOffsetX(); + auto stride = rewriter.createOrFold( + loc, rewriter.getIndexType(), offX, pitch); + auto offY = updateTileOffsetOp.getOffsetY(); + auto index = rewriter.createOrFold( + loc, rewriter.getIndexType(), stride, offY); + auto indices = rewriter.createOrFold( + loc, mlir::VectorType::get(tileTy.getShape(), rewriter.getIndexType()), + index); + rewriter.replaceOpWithNewOp( + updateTileOffsetOp, tile, nullptr, nullptr, indices); + return mlir::success(); + } +}; + +struct SCFForOpPattern final : public mlir::OpRewritePattern { + SCFForOpPattern(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(mlir::scf::ForOp scfForOp, + mlir::PatternRewriter &rewriter) const override { + auto initArgs = scfForOp.getInitArgs(); + auto regionIterArgs = scfForOp.getRegionIterArgs(); + auto results = scfForOp.getResults(); + bool isUpdated = false; + for (auto [init, arg, res] : + llvm::zip_equal(initArgs, regionIterArgs, results)) { + auto initTy = init.getType(); + if (mlir::isa(initTy)) { + if (!mlir::isa(arg.getType()) || + !mlir::isa(res.getType())) { + return rewriter.notifyMatchFailure(scfForOp, "TileType mismatch."); + } + auto initTileTy = mlir::dyn_cast(initTy); + if (initTileTy.getScatterAttr()) { + auto argTileTy = + mlir::dyn_cast(arg.getType()); + if (argTileTy.getScatterAttr()) { + continue; + } + auto scatterTileTy = addScatterAttr(argTileTy); + arg.setType(scatterTileTy); + res.setType(scatterTileTy); + isUpdated = true; + } + } + } + if (!isUpdated) { + return mlir::failure(); + } + return mlir::success(); + } +}; + +struct XeTileBlockOpFallbackPass final + : public imex::impl::XeTileBlockOpFallbackBase { + void runOnOperation() override { + auto *context = &getContext(); + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(context); + mlir::GreedyRewriteConfig config; + config.enableRegionSimplification = + mlir::GreedySimplifyRegionLevel::Disabled; + config.useTopDownTraversal = true; + config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps; + patterns.add(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) { + return signalPassFailure(); + } + } +}; + +} // namespace blockopfallback + +namespace imex { +std::unique_ptr createXeTileBlockOpFallbackPass() { + return std::make_unique(); +} +} // namespace imex diff --git a/lib/Dialect/XeTile/Transforms/CMakeLists.txt b/lib/Dialect/XeTile/Transforms/CMakeLists.txt index 5d74fdda2..fb20cb9fd 100644 --- a/lib/Dialect/XeTile/Transforms/CMakeLists.txt +++ b/lib/Dialect/XeTile/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_imex_dialect_library(IMEXXeTileTransforms Blocking.cpp BlockingAnalysis.cpp + BlockOpFallback.cpp InitDuplicate.cpp Canonicalization.cpp WgToSg.cpp diff --git a/lib/Transforms/RemoveSingleElemVector.cpp b/lib/Transforms/RemoveSingleElemVector.cpp index 29ef2475e..b39495828 100644 --- a/lib/Transforms/RemoveSingleElemVector.cpp +++ b/lib/Transforms/RemoveSingleElemVector.cpp @@ -253,7 +253,9 @@ struct RemoveSingleElemVectorPass final }); mlir::RewritePatternSet patterns(context); - patterns.add( typeConverter, context); diff --git a/test/Dialect/XeTile/Transforms/block_op_fallback.mlir b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir new file mode 100644 index 000000000..bcf4d76a2 --- /dev/null +++ b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir @@ -0,0 +1,391 @@ +// RUN: imex-opt --split-input-file --xetile-blockop-fallback %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_one_elems_and_offset_attr + gpu.func @test_pitch_one_elems_and_offset_attr(%arg0: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_one_elems_and_offset_vars + gpu.func @test_pitch_one_elems_and_offset_vars(%arg0: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %0 = xetile.init_tile %arg0 [%cst0, %cst0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_one_elems_and_mixed_offsets + gpu.func @test_pitch_one_elems_and_mixed_offsets(%arg0: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %0 = xetile.init_tile %arg0 [%cst0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_two_elems + gpu.func @test_pitch_two_elems(%arg0: memref<512x2xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CST0:.*]] = arith.constant dense<2> : vector<32x1xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense<16> : vector<32xindex> + // CHECK: %[[CST2:.*]] = arith.constant dense<1> : vector<32x1xindex> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1] : memref<512x2xf32> to memref<1024xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = arith.addi %[[VAR0]], %[[CST1]] : vector<32xindex> + // CHECK: %[[VAR2:.*]] = vector.shape_cast %[[VAR1]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR2]], %[[CST0]] : vector<32x1xindex> + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR3]], %[[CST2]] : vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR4]] : memref<1024xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %0 = xetile.init_tile %arg0 [16, 1] : memref<512x2xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR6:.*]] = xetile.load %[[VAR5]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_store_tile + gpu.func @test_store_tile(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[CAST0:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR3:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR4:.*]] = vector.shape_cast %[[VAR3:.*]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST0]], %[[VAR4]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR6:.*]] = xetile.load %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %2 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR6]], %[[VAR5]], %[[CST]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %2, %1 : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_update_tile_offset + gpu.func @test_update_tile_offset(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR3:.*]] = xetile.load %[[VAR2]], %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: %[[VAR4:.*]] = xetile.update_tile_offset %[[VAR2]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %2 = xetile.update_tile_offset %0, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_multiple_update_tile_offset + gpu.func @test_multiple_update_tile_offset(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<16> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR1:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR2:.*]] = xetile.init_tile %[[CAST]], %[[VAR1]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %cst16 = arith.constant 16 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR6:.*]] = xetile.load %[[VAR2]], %[[CST1]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: %[[VAR7:.*]] = xetile.update_tile_offset %[[VAR2]], %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %2 = xetile.update_tile_offset %0, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR8:.*]] = xetile.update_tile_offset %[[VAR7]], %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %3 = xetile.update_tile_offset %2, [%cst16, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: @test_scf_for + gpu.func @test_scf_for(%arg0: memref<512x1xf32>, %arg1: memref<512x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C480:.*]] = arith.constant 480 : index + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR3:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR3]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %cst0 = arith.constant 0 : index + %cst32 = arith.constant 32 : index + %cst512 = arith.constant 512 : index + %cst480 = arith.constant 480 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [512], strides: [1] : memref<512x1xf32> to memref<512xf32> + // CHECK: %[[VAR6:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR9:.*]] = vector.shape_cast %[[VAR6]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR11:.*]] = xetile.init_tile %[[CAST1]], %[[VAR9]] : memref<512xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<512x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR12:.*]]:2 = scf.for %arg2 = %[[C0]] to %[[C480]] step %[[C32]] iter_args(%arg3 = %[[VAR5]], %arg4 = %[[VAR11]]) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %out:2 = scf.for %k = %cst0 to %cst480 step %cst32 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + // CHECK: %[[VAR14:.*]] = xetile.load %arg3, %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %a_value = xetile.load_tile %a_tile : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR14]], %arg4, %[[CST0]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %a_value, %b_tile : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR15:.*]] = xetile.update_tile_offset %arg3, %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: %[[VAR16:.*]] = xetile.update_tile_offset %arg4, %[[CST]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xindex> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst32, %cst0] : !xetile.tile<32x1xf32, #xetile.tile_attr> + // CHECK: scf.yield %[[VAR15]], %[[VAR16]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr> + } + // CHECK: %[[VAR13:.*]] = xetile.load %[[VAR12]]#0, %[[CST0]] : !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> -> vector<32x1xf32> + %2 = xetile.load_tile %out#0 : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<32x1xf32> + // CHECK: xetile.store %[[VAR13]], %[[VAR12]]#1, %[[CST0]] : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr>, vector<32x1xi1> + xetile.store_tile %2, %out#1 : vector<32x1xf32>, !xetile.tile<32x1xf32, #xetile.tile_attr> + gpu.return + } +} + +// ----- + +module attributes {gpu.container_module} { + gpu.module @test_module { + // CHECK-LABEL: func @test_nested_scf_for + gpu.func @test_nested_scf_for(%arg0: memref<16384x1xf32>, %arg1: memref<16384x1xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<32> : vector<32x1xindex> + // CHECK: %[[CST0:.*]] = arith.constant dense : vector<32x1xi1> + // CHECK: %[[CST1:.*]] = arith.constant dense<128> : vector<32x1xindex> + // CHECK: %[[CST2:.*]] = arith.constant dense<512> : vector<32x1xindex> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C512:.*]] = arith.constant 512 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C16384:.*]] = arith.constant 16384 : index + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16384], strides: [1] : memref<16384x1xf32> to memref<16384xf32> + // CHECK: %[[VAR0:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR3:.*]] = vector.shape_cast %[[VAR0]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR5:.*]] = xetile.init_tile %[[CAST]], %[[VAR3]] : memref<16384xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c16384 = arith.constant 16384 : index + %1 = xetile.init_tile %arg0[%c0, %c0] : memref<16384x1xf32> -> !xetile.tile<32x1xf32> + // CHECK: %[[CAST3:.*]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16384], strides: [1] : memref<16384x1xf32> to memref<16384xf32> + // CHECK: %[[VAR6:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR9:.*]] = vector.shape_cast %[[VAR6]] : vector<32xindex> to vector<32x1xindex> + // CHECK: %[[VAR11:.*]] = xetile.init_tile %[[CAST3]], %[[VAR9]] : memref<16384xf32>, vector<32x1xindex> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + %2 = xetile.init_tile %arg1[%c0, %c0] : memref<16384x1xf32> -> !xetile.tile<32x1xf32> + // CHECK: %[[VAR12:.*]]:2 = scf.for %arg2 = %[[C0]] to %[[C16384]] step %[[C512]] iter_args(%arg3 = %[[VAR5]], %arg4 = %[[VAR11]]) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %3:2 = scf.for %arg3 = %c0 to %c16384 step %c512 iter_args(%arg4 = %1, %arg5 = %2) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %4 = xetile.update_tile_offset %arg4, [%c512, %c0] : !xetile.tile<32x1xf32> + %5 = xetile.update_tile_offset %arg5, [%c512, %c0] : !xetile.tile<32x1xf32> + // CHECK: %[[VAR15:.*]]:2 = scf.for %arg5 = %[[C0]] to %[[C512]] step %[[C128]] iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %6:2 = scf.for %arg6 = %c0 to %c512 step %c128 iter_args(%arg7 = %arg4, %arg8 = %arg5) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %7 = xetile.update_tile_offset %arg7, [%c128, %c0] : !xetile.tile<32x1xf32> + %8 = xetile.update_tile_offset %arg8, [%c128, %c0] : !xetile.tile<32x1xf32> + // CHECK: %[[VAR18:.*]]:2 = scf.for %arg8 = %[[C0]] to %[[C128]] step %[[C32]] iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!xetile.tile<32x1xf32, #xetile.tile_attr>, !xetile.tile<32x1xf32, #xetile.tile_attr>) { + %9:2 = scf.for %arg9 = %c0 to %c128 step %c32 iter_args(%arg10 = %arg7, %arg11 = %arg8) -> (!xetile.tile<32x1xf32>, !xetile.tile<32x1xf32>) { + %10 = xetile.load_tile %arg10 : !xetile.tile<32x1xf32> -> vector<32x1xf32> + xetile.store_tile %10, %arg11 : vector<32x1xf32>, !xetile.tile<32x1xf32> + %11 = xetile.update_tile_offset %arg10, [%c32, %c0] : !xetile.tile<32x1xf32> + %12 = xetile.update_tile_offset %arg11, [%c32, %c0] : !xetile.tile<32x1xf32> + scf.yield %11, %12 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + scf.yield %7, %8 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + scf.yield %4, %5 : !xetile.tile<32x1xf32>, !xetile.tile<32x1xf32> + } + gpu.return + } + } +} + +// ----- + +module attributes {gpu.container_module} { + func.func @postop_reduce_m_entry(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[8, 2, 4, 8]> : vector<4xi64>, gemm_tiles_y = dense<[1, 2, 8, 4]> : vector<4xi64>, physical_nd_range = dense<[8, 32]> : vector<2xi64>, region_partition = 0 : i64, region_size = 32 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<16384x12288xbf16>, tensor<2048x12288xbf16>) -> tensor<32x2048xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1003595802.6 : f64} { + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + gpu.launch_func @postop_reduce_m::@postop_reduce_m blocks in (%c8, %c32, %c1) threads in (%c8, %c4, %c1) args(%arg0 : memref<16384x12288xbf16>, %arg1 : memref<2048x12288xbf16>, %arg2 : memref<32x2048xf32>) + return + } + gpu.module @postop_reduce_m attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + // CHECK-LABEL: func @postop_reduce_m + gpu.func @postop_reduce_m(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<8x4xi1> + // CHECK: %[[CST0:.*]] = arith.constant dense<128> : vector<8x4xindex> + // CHECK: %[[CST1:.*]] = arith.constant dense : vector<1x32xi1> + // CHECK: %[[CST2:.*]] = arith.constant dense<128> : vector<1x32xindex> + %c12288 = arith.constant 12288 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %cst = arith.constant dense<0.000000e+00> : vector<32x32xf32> + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x32xf32> + %cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32> + %cst_2 = arith.constant dense<0.000000e+00> : vector<1x4xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.divsi %block_id_y, %c8 : index + %1 = arith.remsi %block_id_y, %c8 : index + %2 = arith.muli %block_id_x, %c4 : index + %3 = arith.addi %2, %0 : index + %4 = arith.muli %1, %c128 : index + %5 = gpu.subgroup_id : index + %6 = index.floordivs %5, %c32 + %7 = index.remu %5, %c32 + %8 = index.remu %6, %c1 + %9 = index.add %3, %8 + %10 = index.remu %7, %c32 + %11 = index.mul %10, %c4 + %12 = index.add %4, %11 + %13 = xetile.init_tile %arg2[%9, %12] : memref<32x2048xf32> -> !xetile.tile<1x4xf32> + %14 = arith.muli %block_id_x, %c2048 : index + %15 = arith.muli %0, %c256 : index + %16 = arith.addi %14, %15 : index + %17 = index.floordivs %5, %c4 + %18 = index.remu %5, %c4 + %19 = index.remu %17, %c8 + %20 = index.mul %19, %c32 + %21 = index.add %16, %20 + %22 = index.remu %18, %c1 + %23 = index.mul %22, %c32 + %24 = xetile.init_tile %arg0[%21, %23] : memref<16384x12288xbf16> -> !xetile.tile<32x32xbf16> + %25 = index.floordivs %5, %c8 + %26 = index.remu %5, %c8 + %27 = index.remu %26, %c4 + %28 = index.mul %27, %c32 + %29 = index.add %4, %28 + %30 = index.remu %25, %c1 + %31 = index.mul %30, %c32 + %32 = xetile.init_tile %arg1[%29, %31] : memref<2048x12288xbf16> -> !xetile.tile<32x32xbf16> + %33:2 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %13, %arg5 = %32) -> (!xetile.tile<1x4xf32>, !xetile.tile<32x32xbf16>) { + %34 = xetile.update_tile_offset %arg5, [%c1024, %c0] : !xetile.tile<32x32xbf16> + %35 = xetile.update_tile_offset %arg4, [%c0, %c1024] : !xetile.tile<1x4xf32> + %36:2 = scf.for %arg6 = %c0 to %c2 step %c1 iter_args(%arg7 = %cst_2, %arg8 = %24) -> (vector<1x4xf32>, !xetile.tile<32x32xbf16>) { + %37 = xetile.update_tile_offset %arg8, [%c1024, %c0] : !xetile.tile<32x32xbf16> + %38:3 = scf.for %arg9 = %c0 to %c12288 step %c32 iter_args(%arg10 = %cst, %arg11 = %arg8, %arg12 = %arg5) -> (vector<32x32xf32>, !xetile.tile<32x32xbf16>, !xetile.tile<32x32xbf16>) { + %56 = xetile.update_tile_offset %arg12, [%c0, %c32] : !xetile.tile<32x32xbf16> + %57 = xetile.update_tile_offset %arg11, [%c0, %c32] : !xetile.tile<32x32xbf16> + %58 = xetile.load_tile %arg11 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %59 = math.exp %58 : vector<32x32xbf16> + %60 = xetile.load_tile %arg12 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %61 = xetile.transpose %60, [1, 0] : vector<32x32xbf16> -> vector<32x32xbf16> + xegpu.compile_hint + %62 = xetile.tile_mma %59, %61, %arg10 : vector<32x32xbf16>, vector<32x32xbf16>, vector<32x32xf32> -> vector<32x32xf32> + xegpu.compile_hint + scf.yield %62, %57, %56 : vector<32x32xf32>, !xetile.tile<32x32xbf16>, !xetile.tile<32x32xbf16> + } + %39 = math.exp %38#0 : vector<32x32xf32> + %40 = vector.shape_cast %cst_0 : vector<1x32xf32> to vector<32xf32> + %41 = xetile.reduction , %39 [0] : vector<32x32xf32> -> vector<1x32xf32> + %42 = vector.shape_cast %41 : vector<1x32xf32> to vector<32xf32> + %43 = arith.addf %42, %40 : vector<32xf32> + %44 = vector.shape_cast %43 : vector<32xf32> to vector<1x32xf32> + %alloc = memref.alloc() : memref<4096xi8, 3> + %view = memref.view %alloc[%c0][] : memref<4096xi8, 3> to memref<8x128xf32, 3> + %45 = index.mul %18, %c32 + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[VIEW:.*]] to offset: [0], sizes: [1024], strides: [1] : memref<8x128xf32, 3> to memref<1024xf32, 3> + // CHECK: %[[VAR48:.*]] = vector.step : vector<32xindex> + // CHECK: %[[VAR49:.*]] = vector.broadcast %[[VAR48]] : vector<32xindex> to vector<1x32xindex> + // CHECK: %[[VAR50:.*]] = vector.splat %[[VAR45:.*]] : vector<1x32xindex> + // CHECK: %[[VAR52:.*]] = arith.addi %[[VAR49]], %[[VAR50]] : vector<1x32xindex> + // CHECK: %[[VAR54:.*]] = vector.splat %[[VAR17:.*]] : vector<1x32xindex> + // CHECK: %[[VAR55:.*]] = arith.muli %[[VAR54]], %[[CST2]] : vector<1x32xindex> + // CHECK: %[[VAR56:.*]] = arith.addi %[[VAR55]], %[[VAR52]] : vector<1x32xindex> + // CHECK: %[[VAR57:.*]] = xetile.init_tile %[[CAST]], %[[VAR56]] : memref<1024xf32, 3>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %46 = xetile.init_tile %view[%17, %45] : memref<8x128xf32, 3> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + // CHECK: xetile.store %[[VAR44:.*]], %[[VAR57]], %[[CST1]] : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + xetile.store_tile %44, %46 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> + gpu.barrier + // CHECK: %[[VAR58:.*]] = index.mul + // CHECK: %[[VAR59:.*]] = index.mul + %47 = index.mul %6, %c8 + %48 = index.mul %7, %c4 + // CHECK: %[[CAST7:.*]] = memref.reinterpret_cast %[[VIEW]] to offset: [0], sizes: [1024], strides: [1] : memref<8x128xf32, 3> to memref<1024xf32, 3> + // CHECK: %[[VAR62:.*]] = vector.step : vector<4xindex> + // CHECK: %[[VAR63:.*]] = vector.broadcast %[[VAR62]] : vector<4xindex> to vector<8x4xindex> + // CHECK: %[[VAR61:.*]] = vector.splat %[[VAR59]] : vector<8x4xindex> + // CHECK: %[[VAR64:.*]] = arith.addi %[[VAR63]], %[[VAR61]] : vector<8x4xindex> + // CHECK: %[[VAR65:.*]] = vector.step : vector<8xindex> + // CHECK: %[[VAR60:.*]] = vector.splat %[[VAR58]] : vector<8xindex> + // CHECK: %[[VAR66:.*]] = arith.addi %[[VAR65]], %[[VAR60]] : vector<8xindex> + // CHECK: %[[VAR67:.*]] = vector.shape_cast %[[VAR66]] : vector<8xindex> to vector<8x1xindex> + // CHECK: %[[VAR68:.*]] = vector.broadcast %[[VAR67]] : vector<8x1xindex> to vector<8x4xindex> + // CHECK: %[[VAR69:.*]] = arith.muli %[[VAR68]], %[[CST0]] : vector<8x4xindex> + // CHECK: %[[VAR70:.*]] = arith.addi %[[VAR69]], %[[VAR64]] : vector<8x4xindex> + // CHECK: %[[VAR71:.*]] = xetile.init_tile %[[CAST7]], %[[VAR70]] : memref<1024xf32, 3>, vector<8x4xindex> -> !xetile.tile<8x4xf32, #xetile.tile_attr> + %49 = xetile.init_tile %view[%47, %48] : memref<8x128xf32, 3> -> !xetile.tile<8x4xf32, #xetile.tile_attr> + // CHECK: %[[VAR72:.*]] = xetile.load %[[VAR71]], %[[CST]] : !xetile.tile<8x4xf32, #xetile.tile_attr>, vector<8x4xi1> -> vector<8x4xf32> + %50 = xetile.load_tile %49 : !xetile.tile<8x4xf32, #xetile.tile_attr> -> vector<8x4xf32> + %51 = xetile.reduction , %50 [0] : vector<8x4xf32> -> vector<1x4xf32> + %52 = vector.shape_cast %51 : vector<1x4xf32> to vector<4xf32> + %53 = arith.addf %52, %cst_1 : vector<4xf32> + %54 = vector.shape_cast %53 : vector<4xf32> to vector<1x4xf32> + %55 = arith.addf %54, %arg7 : vector<1x4xf32> + scf.yield %55, %37 : vector<1x4xf32>, !xetile.tile<32x32xbf16> + } + xetile.store_tile %36#0, %arg4 : vector<1x4xf32>, !xetile.tile<1x4xf32> + scf.yield %35, %34 : !xetile.tile<1x4xf32>, !xetile.tile<32x32xbf16> + } + gpu.return + } + } +} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir new file mode 100644 index 000000000..445e96a82 --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir @@ -0,0 +1,64 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<64x1xf32>) -> memref<64x1xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<64x1xf32> + memref.copy %A, %A_gpu : memref<64x1xf32> to memref<64x1xf32> + %B_gpu = gpu.alloc host_shared() : memref<64x1xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x1xf32>, %B_gpu : memref<64x1xf32>) + %B = memref.alloc() : memref<64x1xf32> + memref.copy %B_gpu, %B : memref<64x1xf32> to memref<64x1xf32> + gpu.dealloc %A_gpu : memref<64x1xf32> + gpu.dealloc %B_gpu : memref<64x1xf32> + return %B : memref<64x1xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<64x1xf32>, %arg1: memref<64x1xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst16 = arith.constant 16 : index + %cst64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> + %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr>) { + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x1xf32, #xetile.tile_attr> -> vector<16x1xf32> + xetile.store_tile %a_value, %b_tile : vector<16x1xf32>, !xetile.tile<16x1xf32, #xetile.tile_attr> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %A = memref.alloc() : memref<64x1xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + %0 = index.castu %arg0 : index to i32 + %val = arith.uitofp %0 : i32 to f32 + memref.store %val, %A[%arg0, %c0] : memref<64x1xf32> + } + %C = call @test(%A) : (memref<64x1xf32>) -> memref<64x1xf32> + %cast_A = memref.cast %A : memref<64x1xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<64x1xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir new file mode 100644 index 000000000..bf85261cc --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir @@ -0,0 +1,65 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<64x2xf32>) -> memref<64x2xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<64x2xf32> + memref.copy %A, %A_gpu : memref<64x2xf32> to memref<64x2xf32> + %B_gpu = gpu.alloc host_shared() : memref<64x2xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x2xf32>, %B_gpu : memref<64x2xf32>) + %B = memref.alloc() : memref<64x2xf32> + memref.copy %B_gpu, %B : memref<64x2xf32> to memref<64x2xf32> + gpu.dealloc %A_gpu : memref<64x2xf32> + gpu.dealloc %B_gpu : memref<64x2xf32> + return %B : memref<64x2xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<64x2xf32>, %arg1: memref<64x2xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst16 = arith.constant 16 : index + %cst64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> + %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr>) { + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x2xf32, #xetile.tile_attr> -> vector<16x2xf32> + xetile.store_tile %a_value, %b_tile : vector<16x2xf32>, !xetile.tile<16x2xf32, #xetile.tile_attr> + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %A = memref.alloc() : memref<64x2xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + %0 = index.castu %arg0 : index to i32 + %val = arith.uitofp %0 : i32 to f32 + memref.store %val, %A[%arg0, %c0] : memref<64x2xf32> + memref.store %val, %A[%arg0, %c1] : memref<64x2xf32> + } + %C = call @test(%A) : (memref<64x2xf32>) -> memref<64x2xf32> + %cast_A = memref.cast %A : memref<64x2xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<64x2xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/slm.mlir b/test/Integration/Dialect/XeTile/fallback/slm.mlir new file mode 100644 index 000000000..4bc81a351 --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/slm.mlir @@ -0,0 +1,80 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @narrow_tile attributes {gpu.container_module} { + func.func @test(%A: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared() : memref<32x32xf32> + memref.copy %A, %A_gpu : memref<32x32xf32> to memref<32x32xf32> + %B_gpu = gpu.alloc host_shared() : memref<32x32xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<32x32xf32>, %B_gpu : memref<32x32xf32>) + %B = memref.alloc() : memref<32x32xf32> + memref.copy %B_gpu, %B : memref<32x32xf32> to memref<32x32xf32> + gpu.dealloc %A_gpu : memref<32x32xf32> + gpu.dealloc %B_gpu : memref<32x32xf32> + return %B : memref<32x32xf32> + } + gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_scf_for(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst0 = arith.constant 0 : index + %cst8 = arith.constant 8 : index + %cst16 = arith.constant 16 : index + %cst32 = arith.constant 32 : index + %0 = xetile.init_tile %arg0 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %slm = memref.alloc() : memref<8x16xf32, 3> + %slm_tile = xetile.init_tile %slm [0, 0] : memref<8x16xf32, 3> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %out:2 = scf.for %j = %cst0 to %cst32 step %cst8 + iter_args(%a_tile = %0, %b_tile = %1) + -> (!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { + %out:2 = scf.for %k = %cst0 to %cst32 step %cst16 + iter_args(%c_tile = %a_tile, %d_tile = %b_tile) + -> (!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { + %c_value = xetile.load_tile %c_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> + xetile.store_tile %c_value, %slm_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %d_value = xetile.load_tile %slm_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> + xetile.store_tile %d_value, %d_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %c_next_tile = xetile.update_tile_offset %c_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> + %d_next_tile = xetile.update_tile_offset %d_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> + scf.yield %c_next_tile, %d_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + } + %a_next_tile = xetile.update_tile_offset %a_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> + %b_next_tile = xetile.update_tile_offset %b_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> + scf.yield %a_next_tile, %b_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %A = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %0 = index.castu %arg0 : index to i32 + %1 = index.castu %arg1 : index to i32 + %2 = arith.addi %0, %1 : i32 + %val = arith.uitofp %2 : i32 to f32 + memref.store %val, %A[%arg0, %arg1] : memref<32x32xf32> + } + } + %C = call @test(%A) : (memref<32x32xf32>) -> memref<32x32xf32> + %cast_A = memref.cast %A : memref<32x32xf32> to memref<*xf32> + %cast_C = memref.cast %C : memref<32x32xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () + //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () + //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + return + } + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + //func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp new file mode 100644 index 000000000..3899a933c --- /dev/null +++ b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp @@ -0,0 +1,41 @@ +builtin.module( + cse + gpu.module(xetile-init-duplicate + xetile-canonicalization + xetile-blockop-fallback + xetile-blocking + cse + convert-xetile-to-xegpu + cse + imex-xegpu-hoist-transpose + imex-xegpu-apply-vnni-transformation + imex-xegpu-optimize-transpose) + cse + imex-vector-linearize + cse + imex-remove-single-elem-vector + canonicalize + cse + gpu.module(convert-xegpu-to-vc) + reconcile-unrealized-casts + bf16-to-gpu + cse + imex-convert-gpu-to-spirv + spirv.module(spirv-lower-abi-attrs + spirv-update-vce) + func.func(llvm-request-c-wrappers) + serialize-spirv + convert-vector-to-scf + convert-gpu-to-gpux + convert-scf-to-cf + expand-strided-metadata + finalize-memref-to-llvm + convert-cf-to-llvm + convert-vector-to-llvm + convert-index-to-llvm + convert-arith-to-llvm + convert-func-to-llvm + convert-math-to-llvm + convert-gpux-to-llvm + lower-affine + reconcile-unrealized-casts) diff --git a/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir b/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir index 464eb9507..0e4bde3b5 100644 --- a/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir +++ b/test/Transforms/RemoveSingleElemVector/postop_reduce_n.mlir @@ -61,7 +61,9 @@ module { %34 = arith.remsi %11, %c4 : index %35 = scf.for %arg3 = %c0 to %c3 step %c1 iter_args(%arg4 = %cst) -> (vector<8x1xf32>) { %39 = vector.shape_cast %arg4 : vector<8x1xf32> to vector<8xf32> - //CHECK-COUNT-8: vector.extractelement {{.*}} : vector<8xf32> + // Disabling remove single elem vector.extra_stride_slice for now. + // DISABLE-CHECK-COUNT-8: vector.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: vector.extract_strided_slice %40 = vector.extract_strided_slice %39 {offsets = [0], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32> %41 = vector.extract_strided_slice %39 {offsets = [1], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32> %42 = vector.extract_strided_slice %39 {offsets = [2], sizes = [1], strides = [1]} : vector<8xf32> to vector<1xf32>