From 8dfa7be07180d30f685120f01ef00ae972ae907d Mon Sep 17 00:00:00 2001 From: aeng-openai Date: Tue, 17 Dec 2024 10:09:08 -0800 Subject: [PATCH 01/14] [AMD] Improve matmul detection in reorder instructions pass (#5393) Previously the matmul problem checks whether there is a for loop with a single dot in a function. This doesn't work well for nested loops used for example in persistent matmul kernels. The matmul problem check is updated to consider nested for loops that contain a single tl.dot operation with at least two loads. Then, the `scheduleGlobalLoadLocalStore` transformation is applied to the whole function if the whole function is just a matmul problem. Otherwise it applies to each leaf for loop with limited scope. Also now we ensure it captures both the loop body and global loads that have been peeled out into a loop prologue by the pipeliner. --- test/TritonGPU/amd/amd-sched-2nd-load.mlir | 42 +++++ .../ReorderInstructions.cpp | 150 ++++++++++++------ 2 files changed, 147 insertions(+), 45 deletions(-) diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir index 24139f66be..f6653d9735 100644 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -61,6 +61,48 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> #smem = #ttg.shared_memory +// Should apply: tile size 256x256x128 with nested single dot +// CHECK-LABEL: nested_sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @nested_sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + scf.for %arg2 = %c0 to %c1 step %c1 : i32 { + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = 
#ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + // Should apply: tile size 256x256x64 with single dot // CHECK-LABEL: sink_2nd_load_256x256x64 // CHECK: %[[tileA:.*]] = tt.load diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index f15feb7b25..c4a2c8ea17 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -3,7 +3,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Verifier.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -17,9 +16,23 @@ namespace ttg = mlir::triton::gpu; // Utility functions //===----------------------------------------------------------------------===// -// Return true if the given moduleOp contains a pure matmul problem; i.e., -// single dot in the main loop. -static bool isPureMatmulProblem(triton::FuncOp funcOp) { +static SmallVector getLeafForOps(triton::FuncOp funcOp) { + SmallVector allOps; + funcOp->walk([&](scf::ForOp forOp) { allOps.push_back(forOp); }); + + SmallVector leafOps; + for (scf::ForOp forOp : allOps) { + auto searchResult = forOp.getBody()->walk( + [](scf::ForOp) { return WalkResult::interrupt(); }); + if (!searchResult.wasInterrupted()) + leafOps.push_back(forOp); + } + return leafOps; +} + +// Return true if the given funcOp is a pure matmul problem; i.e., +// a single main loop with a single dot. +static bool isPureMatmulFunc(triton::FuncOp funcOp) { bool isMatmul = true; bool foundLoop = false; funcOp.walk([&](scf::ForOp forOp) -> void { @@ -31,6 +44,20 @@ static bool isPureMatmulProblem(triton::FuncOp funcOp) { return foundLoop && isMatmul; } +// Return true if the given ForOp contains a pure matmul problem; i.e., +// single dot and at least 2 glboal loads in the main loop. +static bool isPureMatmulLoop(scf::ForOp forOp) { + int dotCounter = 0; + int loadCounter = 0; + forOp.walk([&](Operation *op) { + if (isa(op)) + ++dotCounter; + else if (isa(op)) + ++loadCounter; + }); + return dotCounter == 1 && loadCounter >= 2; +} + // Search through block to find earliest insertion point for move op. This can // be either an atomic op or last usage of source pointer. Search ends when move // op is encountered. @@ -214,14 +241,41 @@ static void moveUpTranspose(triton::FuncOp funcOp) { } // Schedule global load and local store ops for better GEMM performance. -static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) { +static void scheduleGlobalLoadLocalStore(Operation *parentOp) { SmallVector moveOps; - // Move local_stores early if dependence distance greater than one iteration. - // Best perf on GEMM when these precede global loads. - funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. 
- funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + + // Search through the forOp initArgs to find global loads for a GEMM that + // the pipeliner may have peeled into a loop prologue. + if (auto forOp = dyn_cast(parentOp)) { + SmallVector vals = forOp.getInitArgs(); + while (!vals.empty()) { + SmallVector nextVals; // Next set of values to search via BFS. + for (size_t i = 0; i < vals.size(); ++i) { + Operation *defOp = vals[i].getDefiningOp(); + if (isa_and_nonnull(defOp)) { + moveOps.push_back(defOp); + continue; + } + + // Find uses of the op that are local_store + for (Operation *op : vals[i].getUsers()) { + if (auto storeOp = dyn_cast(op)) { + // Recurse on operands of the local_store (to find a global_load). + nextVals.push_back(storeOp.getSrc()); + } + } + } + vals.swap(nextVals); + } + } + + // Move local_store ops inside the loop early if dependence distance greater + // than one iteration (i.e., num_stages > 2). For such case, better perf on + // GEMM when local_store ops precede global loads. + parentOp->walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + // Move global_load ops inside the loop early to prefetch. This may increase + // register pressure but it enables issuing global loads early. + parentOp->walk([&](triton::LoadOp op) { moveOps.push_back(op); }); for (auto op : llvm::reverse(moveOps)) { // Gather use-def chain in block. @@ -314,38 +368,36 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) { // are experimenting how to better control instruction scheduling and enable // such optimizations. //===-------------------------------------------------------------------===// -static void sinkSecondLoad(triton::FuncOp funcOp) { - funcOp.walk([&](scf::ForOp forOp) -> void { - SetVector loadOps; - triton::DotOp dotOp; - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - loadOps.insert(loadOp); - if (auto curOp = dyn_cast(&op)) - dotOp = curOp; - } - // Only apply the optimization when there are 2 load's in the loop - if (loadOps.size() != 2) - return; - // Only apply the optimization when tile size is large enough - // 1. nonKDim >= 128 - // 2. kDim >= 64 - auto ldAOp = loadOps[0]; - auto tileAShape = cast(ldAOp.getType()).getShape(); - auto ldBOp = loadOps[1]; - auto tileBShape = cast(ldBOp.getType()).getShape(); - if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) - return; - // Only apply the optimization when the moving is legal - // 1. Make sure the 2nd loadOp is before the dot - // 2. Make sure the first user of the 2nd loadOp is after the dot. - bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); - auto firstUser = *ldBOp.getResult().getUsers().begin(); - bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); - if (isBeforeDotOp && firstUserAfterDotOp) - // move ldBOp right before tt.dot - ldBOp->moveBefore(dotOp); - }); +static void sinkSecondLoad(scf::ForOp forOp) { + SetVector loadOps; + triton::DotOp dotOp; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) + dotOp = curOp; + } + // Only apply the optimization when there are 2 load's in the loop + if (loadOps.size() != 2) + return; + // Only apply the optimization when tile size is large enough + // 1. nonKDim >= 128 + // 2. 
kDim >= 64 + auto ldAOp = loadOps[0]; + auto tileAShape = cast(ldAOp.getType()).getShape(); + auto ldBOp = loadOps[1]; + auto tileBShape = cast(ldBOp.getType()).getShape(); + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) + return; + // Only apply the optimization when the moving is legal + // 1. Make sure the 2nd loadOp is before the dot + // 2. Make sure the first user of the 2nd loadOp is after the dot. + bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); + auto firstUser = *ldBOp.getResult().getUsers().begin(); + bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); + if (isBeforeDotOp && firstUserAfterDotOp) + // move ldBOp right before tt.dot + ldBOp->moveBefore(dotOp); } //===----------------------------------------------------------------------===// @@ -369,9 +421,17 @@ struct TritonAMDGPUReorderInstructionsPass moveUpTranspose(funcOp); - if (isPureMatmulProblem(funcOp)) { + if (isPureMatmulFunc(funcOp)) { scheduleGlobalLoadLocalStore(funcOp); - sinkSecondLoad(funcOp); + funcOp.walk([&](scf::ForOp forOp) -> void { sinkSecondLoad(forOp); }); + } else { + SmallVector leafForOps = getLeafForOps(funcOp); + for (auto forOp : leafForOps) { + if (isPureMatmulLoop(forOp)) { + scheduleGlobalLoadLocalStore(forOp); + sinkSecondLoad(forOp); + } + } } } } From e57b46897191b3b3061c78d0d60e58e94be565b6 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 17 Dec 2024 13:21:24 -0500 Subject: [PATCH 02/14] [BACKEND] Functional fixes for layout conversion that uses `stmatrix` (#5407) This PR: 1. Refactored construction logic in `LinearLayoutConversions.cpp` for `stmatrix` selection. Note that the heuristic-based approach will be replaced with LL-driven approach once we have `divideRight` and `divideLeft`. 2. Updated `SharedLayout` class and added `has_leading_offset` attribute. 3. Added comprehensive new test cases for MMA and shared layouts. 
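
As a companion to point 2 above, here is a minimal standalone sketch of how the new `has_leading_offset` flag surfaces in the Python test helper and in the printed shared encoding. It is trimmed for illustration from the `test_core.py` change later in this patch; the `"ttg"` dialect prefix is hardcoded here in place of the test suite's `GPU_DIALECT` constant, and the example values are taken from the new shared-layout test cases.

```python
# Sketch of the updated SharedLayout test helper (trimmed, illustrative only).
class SharedLayout:

    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order,
                 has_leading_offset=False):
        self.vec = vec
        self.per_phase = per_phase
        self.max_phase = max_phase
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order
        self.has_leading_offset = has_leading_offset

    def __str__(self):
        has_leading_offset_str = "true" if self.has_leading_offset else "false"
        return (f"#ttg.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, "
                f"order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, "
                f"CTAOrder={self.cta_order}, hasLeadingOffset={has_leading_offset_str}}}>")


# A swizzled layout that opts into the stmatrix path ("maximum contiguous bytes"
# case from the new tests) versus a default layout without a leading offset.
print(SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True))
print(SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]))
```
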
--- .../TritonGPU/IR/LinearLayoutConversions.cpp | 96 ++++++++++++------- python/test/unit/language/test_core.py | 95 ++++++++++++++---- 2 files changed, 142 insertions(+), 49 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 99a36841cf..aec87fd89e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -904,13 +904,6 @@ LinearLayout chooseStMatrixLayoutLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order, int swizzleByteSize) { - StringAttr kReg = S("register"); - StringAttr kLane = S("lane"); - StringAttr kWarp = S("warp"); - StringAttr kCol = S("dim1"); - StringAttr kRow = S("dim0"); - StringAttr kOffset = S("offset"); - int perPhase; int maxPhase; if (swizzleByteSize == 32) { @@ -930,45 +923,84 @@ LinearLayout chooseStMatrixLayoutLeadingOffset( // stmatrix only supports 16-bit elements, and each vector has 8 elements int elemBitWidth = 16; int vecSize = 8; - int numRows = 16; - int numCols = 8 * swizzleByteSize / elemBitWidth; + int numRowsPerTile = 16; + int numColsPerChunk = 8 * swizzleByteSize / elemBitWidth; // Construct a single stmatrix.x4 (16x16) tile std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; std::vector> basesLane; - for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) { int row = 1 << logRow; basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row}); } basesLane.push_back({8, 0}); - // Expand the tile's register dimension to fit swizzleByteSize, which is a - // "chunk" - for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) { - int chunk = 1 << logChunk; - basesReg.push_back({16 * chunk, 0}); + auto mma = cast(tensorTy.getEncoding()); + assert(mma.getVersionMajor() >= 3 && "Only MMAv3 is supported"); + int instrM = mma.getInstrShape()[0]; + int instrN = mma.getInstrShape()[1]; + + // TODO(Keren): The following logic can be simplified by using the + // `divideLeft` function in `LinearLayout` once it's available. + // Construct the bases for a single chunk + // In theory the following situation is valid but it will be + // suboptimal. Swizzling should happen within a warp. + assert(instrN >= numColsPerChunk && + "Each chunk is filled in with a single warp"); + for (int logCol = 0; logCol < llvm::Log2_32(numColsPerChunk / 16); logCol++) { + int col = 1 << logCol; + basesReg.push_back({16 * col, 0}); } - // Construct the layout for a single chunk - LinearLayout layout = - LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); + // Construct the bases for warpsPerCTA[0] + std::vector> basesWarp; + auto warpsPerCTA = mma.getWarpsPerCTA(); + auto shape = tensorTy.getShape(); + for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[0]); logWarp++) { + int warp = 1 << logWarp; + basesWarp.push_back({0, warp * instrM}); + } - // Expand the `warp` dimension according to warpsPerCTA. 
- auto mma = cast(tensorTy.getEncoding()); - layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) - .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + // Expand the `register` dimension so the size of columns matches `shape[1] / + // warpsPerCTA[1]` + auto numColsPerWarp = std::max(instrN, shape[1] / warpsPerCTA[1]); + assert(warpsPerCTA[1] * instrN >= shape[1] && + "There must be enough columns to use MMAv3"); + auto logNumCols = llvm::Log2_32(numColsPerWarp / numColsPerChunk); + for (int logCol = 0; logCol < logNumCols; logCol++) { + int chunk = 1 << logCol; + int basis = chunk * shape[0]; + basesReg.push_back({0, basis}); + } - // Expand the `register` dimension so the size of columns matches `n`. - int n = mma.getInstrShape()[1]; - int numWarpRows = layout.getOutDimSize(kRow); - layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) * - LinearLayout::identity1D(n / numCols, kReg, kOffset)) - .reshapeOuts({{kCol, n}, {kRow, numWarpRows}}); + // Expand the `register` dimension so that the size of rows matches `shape[0]` + assert(warpsPerCTA[0] * instrM <= shape[0] && + "There must be enough rows to use MMAv3"); + auto logNumRows = llvm::Log2_32(shape[0] / (warpsPerCTA[0] * instrM)); + for (int logRow = 0; logRow < logNumRows; logRow++) { + int chunk = 1 << logRow; + int basis = chunk * warpsPerCTA[0] * instrM; + basesReg.push_back({0, basis}); + } - auto ret = - combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape()); - return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) - .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}}); + // Expand the `warp` dimension so that the size of cols matches `shape[1]` + for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[1]); logWarp++) { + int warp = 1 << logWarp; + if (warp * numColsPerWarp >= shape[1]) { + basesWarp.push_back({0, 0}); + } else { + int basis = (warp * numColsPerWarp) / numColsPerChunk * shape[0]; + basesWarp.push_back({0, basis}); + } + } + + auto layout = LinearLayout({{S("register"), basesReg}, + {S("lane"), basesLane}, + {S("warp"), basesWarp}, + {S("block"), {}}}, + {S("offset1"), S("offset0")}); + return layout.reshapeOuts( + {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}}); } LinearLayout chooseStMatrixLayoutNoLeadingOffset( diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 00730f2418..26fcb146e4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -179,7 +179,8 @@ def __str__(self): class SharedLayout: - def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order, + has_leading_offset=False): self.vec = vec self.per_phase = per_phase self.max_phase = max_phase @@ -187,9 +188,11 @@ def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num self.ctas_per_cga = ctas_per_cga self.cta_split_num = cta_split_num self.cta_order = cta_order + self.has_leading_offset = has_leading_offset def __str__(self): - return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + has_leading_offset_str = "true" if self.has_leading_offset else "false" + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, 
maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, hasLeadingOffset={has_leading_offset_str}}}>" def is_layout_applicable(layout) -> bool: @@ -5418,7 +5421,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t k_width=1), ] -shared_layout_3d = [ +shared_layouts_3d = [ SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), @@ -5427,8 +5430,8 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t @pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) -@pytest.mark.parametrize("shared_layout", shared_layout_3d) -@pytest.mark.parametrize("dist_layout", layouts_3d) +@pytest.mark.parametrize("shared_layout", shared_layouts_3d) +@pytest.mark.parametrize("dist_layout", filter_layouts(layouts_3d)) def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): layouts = f""" #dist = {dist_layout} @@ -5500,6 +5503,72 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: assert torch.equal(z, x) +mma_layouts = [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 4 warps case + MmaLayout((3, 0), [8, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 8 warps case + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # multiple warps on the row + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # small instrN + MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # large number of warps +] + +shared_layouts = [ + SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(8, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True), # small contiguous bytes + SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True), # maximum contiguous bytes +] + + +@pytest.mark.parametrize("M, N", [[128, 128]]) +@pytest.mark.parametrize("mma_layout", filter_layouts(mma_layouts)) +@pytest.mark.parametrize("shared_layout", shared_layouts) +def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: pathlib.Path): + num_warps = np.prod(mma_layout.warps_per_cta) + + layouts = f""" + #dist = {mma_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> 
tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #dist>) -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem> -> tensor<{M}x{N}xf16, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N, device=device, dtype=torch.float16).reshape(M, N) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + if shared_layout.has_leading_offset == "true" and mma_layout.version[0] >= 3: + assert "stmatrix" in kernel.asm["ptx"] + + mma_pairs = [ [ MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), @@ -5546,18 +5615,10 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): - if is_hip(): - pytest.skip("test_mma2mma is not supported in HIP") - +@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs)) +def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): src_layout, _ = mma_pair - if is_cuda(): - cc = torch.cuda.get_device_capability() - if cc[0] < 9 and src_layout.version[0] >= 3: - pytest.skip("Skip testing MMAv3 on devices with CC < 9") - - num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + num_warps = np.prod(src_layout.warps_per_cta) def do_test(src_layout, dst_layout): layouts = f""" @@ -5593,7 +5654,7 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - temp_file = tmp_path / "test_convertmma2mma.ttgir" + temp_file = tmp_path / "test_convert_mma2mma.ttgir" temp_file.write_text(ir) kernel = triton.compile(str(temp_file)) From a52c88aa128eb5459e35c1809c3443bf9c7bd566 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 17 Dec 2024 21:00:15 +0000 Subject: [PATCH 03/14] [STANDARD] Fix inf handling in tl.flip (#5447) Fixes #5439 Currently we end up doing `0 * inf = nan`, the fix is to bitcast to int first where `x * 0 == 0` holds. 
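
To make the failure mode concrete: `tl.flip` flips each length-2 axis by multiplying with a 0/1 permutation pattern and summing, and in IEEE-754 float arithmetic `0 * inf = nan`, so the masked-out `inf` poisons the sum. Doing the same multiply-and-sum on the raw integer bits is exact, since `n * 0 == 0` always holds for integers. The following is an illustrative NumPy sketch of that idea, not the Triton implementation itself; all names are made up for the example.

```python
import numpy as np

# One length-2 axis of the flip, done as multiply-by-permutation-and-sum.
x = np.array([1.0, np.inf], dtype=np.float32)
perm = np.array([[0, 1], [1, 0]], dtype=np.float32)  # 0/1 "flip" pattern

# Naive float path: the second row computes 1*1 + inf*0, and 0 * inf = nan.
naive = (x[None, :] * perm).sum(axis=1)              # -> [inf, nan]

# Fixed path: reinterpret the float bits as integers, multiply-and-sum there
# (where multiplying by 0 is exact), then reinterpret back to float.
bits = x.view(np.int32)
flipped = (bits[None, :] * perm.astype(np.int32)).sum(axis=1).astype(np.int32)
fixed = flipped.view(np.float32)                     # -> [inf, 1.0]

print(naive, fixed)
```
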
--- python/test/unit/language/test_standard.py | 23 ++++++++++++++++++++++ python/triton/language/standard.py | 10 ++++++---- python/triton/runtime/interpreter.py | 4 +++- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py index b3392d4750..abc3223601 100644 --- a/python/test/unit/language/test_standard.py +++ b/python/test/unit/language/test_standard.py @@ -75,6 +75,29 @@ def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): assert (y == z).all(), (y, z) +@pytest.mark.interpreter +def test_flip_inf(device): + # Reproducer for https://github.com/triton-lang/triton/issues/5439 + + @triton.jit + def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr): + pid = tl.program_id(0) + x = tl.load(x_ptr + pid * N + tl.arange(0, N)) + shape: tl.constexpr = (N // 2, 2) + y = x.reshape(shape) + y = tl.flip(y, dim=1).reshape(x.shape) + tl.store(out_ptr + pid * N + tl.arange(0, N), y) + + x = torch.arange(0, 16, device=device).unsqueeze(0).float() + x[:, -1] = float('inf') + + expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16) + actual = torch.empty_like(x) + triton_flip_kernel[(x.shape[0], )](actual, x, x.shape[1]) + + torch.testing.assert_close(expect, actual) + + @pytest.mark.interpreter @pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]]) def test_swizzle2d(size_i, size_j, size_g, device): diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index eaaeffb687..ff9a5efe24 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -412,11 +412,13 @@ def flip(x, dim=None): """ core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) core.static_assert(_is_power_of_two(x.numel)) - # # reshape the tensor to have all dimensions be 2. - # # TODO: We shouldn't have to change the dimensions not sorted. + # reshape the tensor to have all dimensions be 2. + # TODO: We shouldn't have to change the dimensions not sorted. 
steps: core.constexpr = _log2(x.numel) start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) - y = core.reshape(x, [2] * steps) + + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + y = core.reshape(x.to(idtype, bitcast=True), [2] * steps) y = core.expand_dims(y, start) flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) for i in core.static_range(start, steps): @@ -425,7 +427,7 @@ def flip(x, dim=None): if j != i and j != i + 1: flip2 = core.expand_dims(flip2, j) y = sum(y * flip2, i + 1, keep_dims=True) - x = core.reshape(y, x.shape) + x = core.reshape(y, x.shape).to(x.dtype, bitcast=True) return x diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 3b94f55ea3..d3cc67afcf 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -726,10 +726,12 @@ def check_tensor(self, input): self.check_axis(arg.shape, self.axis) def to_tensor(self, ret, dtype): + np_dtype = _get_np_dtype(dtype) if hasattr(ret, "shape") and ret.shape: + ret = ret.astype(np_dtype) ret_type = tl.block_type(dtype, list(ret.shape)) else: - ret = np.array([ret]).astype(_get_np_dtype(dtype)) + ret = np.array([ret], dtype=np_dtype) ret_type = dtype return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) From 80e2abdfa359dbb8efc386efbd47c6ed359ad205 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 17 Dec 2024 21:56:58 -0500 Subject: [PATCH 04/14] [BACKEND] Remove decomposition of splat -> shared conversion (#5450) --- .../Conversion/TritonGPUToLLVM/Patterns.h | 4 --- .../DecomposeUnsupportedConversions.cpp | 26 ------------------- .../DecomposeUnsupportedConversions.cpp | 2 -- .../DecomposeUnsupportedConversions.cpp | 1 - 4 files changed, 33 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/include/triton/Conversion/TritonGPUToLLVM/Patterns.h index ac13ecc28c..c10ce46a3d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Patterns.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -13,10 +13,6 @@ namespace triton::gpu { /// |module| op because the codegen doesn't handle `blocked -> dot_op` directly. void decomposeBlockedToDotLayoutConversion(ModuleOp module); -/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given -/// |module| op. -void decomposeSplatOpToSharedLayoutConversion(ModuleOp module); - /// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the /// given |module| op, but bypass the decomposition if |shortcutFn| returns /// true. 
diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index d5afb6e2b1..9078c229ce 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -18,32 +18,6 @@ static void addAttrs(Operation *op, ArrayRef attrs) { namespace mlir::triton::gpu { -void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); - int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); - module.walk([&](triton::SplatOp splatOp) -> void { - auto dstType = cast(splatOp.getType()); - auto shared = - dyn_cast(dstType.getEncoding()); - if (shared) { - OpBuilder builder(splatOp); - SmallVector sizePerThread(dstType.getRank(), 1); - auto newType = RankedTensorType::get( - dstType.getShape(), dstType.getElementType(), - triton::gpu::BlockedEncodingAttr::get( - module.getContext(), dstType.getShape(), sizePerThread, - getOrder(shared), numWarps, threadsPerWarp, numCTAs)); - auto newSplat = builder.create(splatOp.getLoc(), newType, - splatOp.getSrc()); - auto newConvert = builder.create( - splatOp.getLoc(), dstType, newSplat.getResult()); - splatOp.replaceAllUsesWith(newConvert.getResult()); - splatOp.erase(); - } - }); -} - void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, ShortcutFn shortcutFn) { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index bbacde54b0..53b0847183 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -34,8 +34,6 @@ struct DecomposeUnsupportedAMDConversions int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - auto isShortcut = mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index e8ddf87104..f184ea2b47 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -76,7 +76,6 @@ struct DecomposeUnsupportedConversions auto nvidiaShortCutFn = [&](RankedTensorType srcTy, RankedTensorType dstTy) { return true; }; ModuleOp mod = getOperation(); - triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, nvidiaShortCutFn); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); From 137bc62102f4a261cc921998221cea2b046a6c1b Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:07:41 +0100 Subject: [PATCH 05/14] [LAYOUTS] Implement generic layout propagation through ReshapeOp (#5389) This PR also: - Enables backward rematerialisation and hoisting for LLs - Adds a fold reshape(cvt) -> reshape when the layouts are structurally the same - Removes an assert that was 
disallowing the use of LLs across broadcast. When this happens, the LL will not have the same shape as the tensor. We do this to match the legacy behaviour and avoid the proliferation of new layouts - Removes the layout-specific tests from before and instead we create functional tests that test the axioms for the reshape function. We see that all the legacy layouts pass these tests. - Temporarily tested that the legacy path and the new path agree in CI in https://github.com/triton-lang/triton/pull/5389/commits/e93638b98409f8dc4c7c41aa644d0501f2630a77 --- include/triton/Dialect/Triton/IR/Dialect.h | 18 +- lib/Dialect/Triton/IR/Ops.cpp | 30 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 119 ++++-- lib/Dialect/TritonGPU/IR/Ops.cpp | 10 + .../Transforms/RemoveLayoutConversions.cpp | 9 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 15 +- unittest/Dialect/TritonGPU/DialectTest.cpp | 372 +++--------------- 7 files changed, 184 insertions(+), 389 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 56a1aa7032..39d006cc65 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -55,13 +55,19 @@ class DialectInferLayoutInterface // Tries to compute the encoding for the result of a reshape operation that // makes the reshape a "nop", i.e. the same GPU threads contain the same - // elements as before the reshape. Note that this is not always possible (in - // which case you'd need to choose a different layout for the input to the - // reshape). + // elements as before the reshape using legacy layouts. This is not always + // possible (in which case we fallback to using LinearLayouts) + // In the future we'll always use LinearLayouts virtual LogicalResult - inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const = 0; + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + // Check if two layouts are structurally the same, even if their names are + // different + virtual LogicalResult verifyLayoutsAreEqual(ArrayRef shape, + Attribute expected, Attribute got, + Location loc) const = 0; virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ab32bd992b..12a237924a 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/LinearLayout.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { @@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() { "encodings, or (b) neither does."); } - if (srcEnc && !getAllowReorder()) { - Attribute inferredDstEnc; - if (cast(&srcEnc.getDialect()) - ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, - dstTy.getShape(), inferredDstEnc, - getLoc()) - .failed()) { - return emitError("This reshape is impossible without reordering, but " - "reordering is not allowed. 
Try choosing a different " - "encoding for the input tensor (or allow reordering)."); - } - if (inferredDstEnc != dstEnc) { - return emitError("Expected result encoding ") - << inferredDstEnc << " but was " << dstEnc; - } + if (!srcEnc || getAllowReorder()) { + return success(); } - return success(); + // Check that we can infer the dst encoding from the src encoding + // and that the inferred dst encoding is the same as the given dst encoding + Attribute inferredDstEnc; + auto result = + cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(), + inferredDstEnc, getLoc()); + assert(succeeded(result)); + return cast(&srcEnc.getDialect()) + ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc, + getLoc()); } //-- FpToFpOp -- diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a10174ad3e..337328f650 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1598,11 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { - // We can relax this assert by calling toLinearLayout rather than - // getLinearLayout - SmallVector shapeVec(shape.begin(), shape.end()); - assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); - auto ll = getLinearLayout(); + // When broadcasting the layout the shape changes, otherwise the shape is + // the same as the shape of the tensor + // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep + // the invariant that the shape of the LL is that of the tensor + // We choose the former for BC + auto ll = *toLinearLayout(shape); return basesPerDim(ll, StringAttr::get(getContext(), "register")); } @@ -2623,8 +2624,8 @@ struct TritonGPUInferLayoutInterface // contains elements [a,b,c,d] before the reshape, it contains those same // elements after the reshape, they're just "renamed". // - // A dst encoding that satisfies this property does not exist for all inputs. - // Here are some positive and negative examples. + // Using legacy layouts, a dst encoding that satisfies this property may not + // exist. Here are some positive and negative examples. // // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so // dim 1 is the fastest-changing in the dst, but the src has the opposite @@ -2638,17 +2639,19 @@ struct TritonGPUInferLayoutInterface // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will // contain the same elements as before. // + // With linear layouts, we can always find a dst encoding that satisfies + // this property. See inferReshapeOpEncoding. + // // Users of this function require that it is symmetrical: if // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => // srcEnc. - LogicalResult - inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { + LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) const { auto src = mlir::dyn_cast(srcEnc); if (!src) { - return emitOptionalError( - loc, "Non-reordering reshape only supports BlockedEncoding"); + return failure(); } // Nop reshape; we can always infer an encoding. @@ -2681,9 +2684,7 @@ struct TritonGPUInferLayoutInterface // to handle CTASplitNum. 
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { - return emitOptionalError( - loc, "Non-reordering reshape does not currently support multi-CTA " - "layouts other than the default layout."); + return failure(); } // Cowardly refuse to handle encodings where shape[dim] is not divisible by @@ -2693,12 +2694,7 @@ struct TritonGPUInferLayoutInterface for (int dim = 0; dim < srcShape.size(); dim++) { if (srcShape[dim] >= subblock[dim] && srcShape[dim] % subblock[dim] != 0) { - return emitOptionalError(loc, - "Can't do a non-reordering reshape because " - "the size of dimension ", - dim, " (", srcShape[dim], ")", - " is not divisible by ", name, "[", dim, "]", - " = ", subblock[dim]); + return failure(); } } return success(); @@ -2723,11 +2719,7 @@ struct TritonGPUInferLayoutInterface // physical order, with `a` being the most major. for (const auto &[srcDims, dstDims] : decomp) { if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { - return emitOptionalError(loc, - "Cannot do a non-reordering reshape given " - "this src encoding order. Dimensions [", - join(srcDims), - "] must be physically consecutive."); + return failure(); } } @@ -2774,11 +2766,7 @@ struct TritonGPUInferLayoutInterface // Check that more-minor dims all have 1 in shapeRemaining. for (int j = i + 1; j < srcDims.size(); j++) { if (shapeRemaining[j] != 1) { - return emitOptionalError( - loc, - "Invalid src encoding for non-reordering reshape. Must use " - "up sizePerThread / threadsPerWarp / warpsPerCTA for " - "more-minor dimensions before more major-dims can use them."); + return failure(); } } @@ -2793,13 +2781,7 @@ struct TritonGPUInferLayoutInterface // only if we're the most-major dimension of the chunk and in all // future chunks, only this most-major dim has a non-1 size. if (shapeRemaining[i] == 0 && i != 0) { - return emitOptionalError( - loc, - "Invalid src encoding for non-reordering reshape. Block " - "size in dimension ", - dim, - " is larger than the shape that dimension, but this is only " - "allowed for the most-major dimension of a reshape chunk"); + return failure(); } } return success(); @@ -2889,6 +2871,65 @@ struct TritonGPUInferLayoutInterface return success(); } + LogicalResult verifyLayoutsAreEqual(ArrayRef shape, + Attribute expected, Attribute got, + Location loc) const override { + if (expected == got) { + return success(); + } + // Check whether the encodings are structurally the same. + auto expectedLL = triton::gpu::toLinearLayout(shape, expected); + auto gotLL = triton::gpu::toLinearLayout(shape, got); + if (expectedLL != gotLL) { + return emitError(loc, "Expected result encoding ") + << expected << " but was " << got; + } + return success(); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto result = + inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); + if (succeeded(result)) { + return result; + } + + // If the legacy encoding failed use LinearLayouts. + // Once LinearLayouts are more widely used, we can remove + // inferReshapeOpLegacyEncoding and simply use LLs. 
+ auto *ctx = getContext(); + auto src = triton::gpu::toLinearLayout(srcShape, srcEnc); + if (!src) { + return emitOptionalError(loc, + "src encoding does not support linear layout"); + } + + if (product(srcShape) != product(dstShape)) { + return emitOptionalError(loc, "numel of dst shape does not match " + "numel of src shape"); + } + + auto newRank = dstShape.size(); + SmallVector> newOutDims; + for (auto [dim, size] : + llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { + newOutDims.emplace_back(dim, size); + } + auto srcOutDims = llvm::to_vector(src->getOutDimNames()); + // reshapeOp assumes minor-to-major, so we need to transpose the out dims + // before the reshape + std::reverse(srcOutDims.begin(), srcOutDims.end()); + std::reverse(newOutDims.begin(), newOutDims.end()); + auto dst = src->transposeOuts(srcOutDims) + .reshapeOuts(newOutDims) + .transposeOuts(standardOutDimNames(ctx, newRank)); + dstEnc = LinearEncodingAttr::get(ctx, dst); + return success(); + } + LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, std::optional loc) const override { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index ff9fded9ef..39d52ac891 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -42,6 +42,16 @@ struct CanonicalizeConvertFromReshape auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // If the layouts are structurally the same, the convert is trivial + auto srcType = convert.getSrc().getType(); + auto dstType = convert.getType(); + auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding()); + auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding()); + if (srcLL && dstLL && *srcLL == *dstLL) { + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder()); + return mlir::success(); + } if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); if (!op.getAllowReorder() || op.getEfficientLayout()) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index a56b6f7977..e0e048415d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1025,9 +1025,7 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf - if (isa(targetType.getEncoding())) + if (isa(targetType.getEncoding())) return; Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " @@ -1069,11 +1067,8 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa( - targetType.getEncoding())) + if (isa(targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp 
b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 20ac0954ad..7fe4576cb8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -407,14 +407,13 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, return {}; Attribute dstEnc; - if (succeeded( - srcEnc.getDialect() - .getRegisteredInterface() - ->inferReshapeOpNoReorderEncoding( - srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { - return dstEnc; - } - return {}; + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, + /*loc=*/std::nullopt); + assert(succeeded(result)); + return dstEnc; } static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 7a34955a96..bff54d0f9d 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -77,135 +77,6 @@ int64_t getFlatIdx(ArrayRef idx, ArrayRef shape, return flatIdx; } -// Represents the many indices of one element of a tensor with a -// BlockedEncoding. -// -// The purpose of this class is we can say, if two MultiIdx's have the same -// flatFoo values before and after a reshape, then the same GPU thread contains -// the same element (and the reshape is a nop, at least for that element). -struct MultiIdx { - using Vec = SmallVector; - - // Logical index into the tensor. - Vec idx; - - // If the tensor's encoding has e.g. numPerThread = [2,2], then idxInThread - // tells us which of the four elements per thread this is. Same for idxInWarp - // and idxInCTA. - Vec idxInThread; - Vec idxInWarp; - Vec idxInCTA; - - // If the tensor's encoding defines a block of size [x,y,z], the tensor itself - // may be larger than this, comprising multiple blocks. This tells us which - // block we're in. - Vec idxOuter; - - // flatIdx is flattened according to the tensor's logical order (i.e. ignoring - // the encoding). The others are flattened according to the tensor's physical - // encoding. 
- int64_t flatIdx; - int64_t flatIdxInThread; - int64_t flatIdxInWarp; - int64_t flatIdxInCTA; - int64_t flatIdxOuter; -}; - -bool sameFlatIdxs(const MultiIdx &a, const MultiIdx &b) { - return a.flatIdx == b.flatIdx && // - a.flatIdxInThread == b.flatIdxInThread && - a.flatIdxInWarp == b.flatIdxInWarp && - a.flatIdxInCTA == b.flatIdxInCTA && // - a.flatIdxOuter == b.flatIdxOuter; -} - -std::string multiIdxsToString(ArrayRef> idxs) { - std::stringstream ss; - for (const auto &idxPtr : idxs) { - const MultiIdx &idx = *idxPtr; - ss // - << " [" << triton::join(idx.idx, ",") << "] (" << idx.flatIdx << ") " - << "elem=[" << triton::join(idx.idxInThread, ",") << "] (" - << idx.flatIdxInThread << ") " - << "thread=[" << triton::join(idx.idxInWarp, ",") << "] (" - << idx.flatIdxInWarp << ") " - << "warp=[" << triton::join(idx.idxInCTA, ",") << "] (" - << idx.flatIdxInCTA << ") " - << "outer=[" << triton::join(idx.idxOuter, ",") << "] (" - << idx.flatIdxOuter << ")\n"; - } - return ss.str(); -} - -std::vector> getMultiIdxs(ArrayRef shape, - BlockedEncodingAttr enc) { - using Vec = MultiIdx::Vec; - - const unsigned rank = shape.size(); - auto sizePerThread = enc.getSizePerThread(); - auto threadsPerWarp = enc.getThreadsPerWarp(); - auto warpsPerCTA = enc.getWarpsPerCTA(); - auto order = enc.getOrder(); - - Vec numBlocks; - for (int i = 0; i < rank; i++) { - numBlocks.push_back(ceil( - shape[i], sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i])); - } - - Vec idxInThread(rank, 0); - Vec idxInWarp(rank, 0); - Vec idxInCTA(rank, 0); - Vec idxOuter(rank, 0); - - int64_t nElems = product(sizePerThread) * product(threadsPerWarp) * - product(warpsPerCTA) * product(numBlocks); - - // We eventually sort this array, and if the elements are plain MultiIdx - // elements rather than pointers, we have to swap them, which ends up being - // expensive. - std::vector> elems; - elems.reserve(nElems); - - for (int64_t i = 0; i < nElems; i++) { - auto e = std::make_unique(); - e->idxInThread = idxInThread; - e->idxInWarp = idxInWarp; - e->idxInCTA = idxInCTA; - e->idxOuter = idxOuter; - - for (int i = 0; i < rank; i++) { - e->idx.push_back( // - idxInThread[i] + // - idxInWarp[i] * sizePerThread[i] + - idxInCTA[i] * sizePerThread[i] * threadsPerWarp[i] + - idxOuter[i] * sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]); - } - - e->flatIdxInThread = getFlatIdx(e->idxInThread, sizePerThread, order); - e->flatIdxInWarp = getFlatIdx(e->idxInWarp, threadsPerWarp, order); - e->flatIdxInCTA = getFlatIdx(e->idxInCTA, warpsPerCTA, order); - e->flatIdxOuter = getFlatIdx(e->idxOuter, numBlocks, order); - e->flatIdx = getFlatIdx(e->idx, shape, - llvm::to_vector(llvm::reverse(llvm::seq(rank)))); - - elems.push_back(std::move(e)); - - if (advance(idxInThread, sizePerThread, order)) { - if (advance(idxInWarp, threadsPerWarp, order)) { - if (advance(idxInCTA, warpsPerCTA, order)) { - advance(idxOuter, numBlocks, order); - } - } - } - } - llvm::sort(elems, [](const std::unique_ptr &a, - const std::unique_ptr &b) { - return a->flatIdx < b->flatIdx; - }); - return elems; -} - class InferLayoutTest : public ::testing::Test { public: InferLayoutTest() @@ -221,25 +92,12 @@ class InferLayoutTest : public ::testing::Test { /*static*/ MLIRContext InferLayoutTest::ctx; -// The optional outparam couldReshape tells the caller whether the reshape -// worked. You might want this to be a return value instead, but gtest ASSERT -// and FAIL have an implicit `return`, so only work in fns that return void. 
void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, std::optional expectedDstEnc, - std::optional expectSuccess, DialectInferLayoutInterface *inferLayout, - bool longErrors = true, bool *couldReshape = nullptr) { - std::unique_ptr couldReshapeStorage; - if (!couldReshape) { - couldReshapeStorage = std::make_unique(); - couldReshape = couldReshapeStorage.get(); - } - *couldReshape = false; + bool longErrors = true) { MLIRContext *ctx = srcTy.getContext(); - ASSERT_TRUE(expectSuccess || !dstTy.getEncoding()) - << "dstTy shouldn't have an expected encoding if we're expecting the " - "reshape to be impossible!"; // Capture any errors from calling inferReshapeNoOpReorderEncoding, so we can // print them if we expected the reshape to succeed but it failed. @@ -249,29 +107,17 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, { ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); - result = inferLayout->inferReshapeOpNoReorderEncoding( + result = inferLayout->inferReshapeOpEncoding( srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, UnknownLoc::get(ctx)); } - if (!expectSuccess.has_value() && !succeeded(result)) { - // We didn't know whether or not it was supposed to succeed, and it didn't. - // Test passes! - return; - } - - if (expectSuccess.has_value() && !*expectSuccess) { - EXPECT_FALSE(succeeded(result)) - << "Expected reshape to be impossible, but got dst encoding: " - << stringifyLLVMType(inferredEnc); - *couldReshape = true; - return; - } + // We expect the reshape to succeed as long as the inputs have the same + // number of elements + EXPECT_TRUE(succeeded(result)) + << "Expected reshape to succeed, but it didn't! Error(s):\n" + << join(diags, "\n"); - if (!succeeded(result)) { - FAIL() << "Expected reshape to succeed, but it didn't! Error(s):\n" - << join(diags, "\n"); - } if (auto expectedEnc = dstTy.getEncoding()) { EXPECT_EQ(inferredEnc, expectedEnc); } @@ -279,12 +125,14 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, // We know that infer(srcShape, srcEnc, dstShape) => dstEnc. Check that it // works the other way around too: infer(dstShape, dstEnc, srcShape) => // srcEnc. (This is an invariant of the inference function.) + // Even more, we check that the inferred encoding is structurally the same as + // the src encoding, showing that the inference is consistent. { std::vector diags; ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); Attribute inferredSrcEnc; - auto result = inferLayout->inferReshapeOpNoReorderEncoding( + auto result = inferLayout->inferReshapeOpEncoding( dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, UnknownLoc::get(ctx)); EXPECT_TRUE(succeeded(result)) @@ -292,56 +140,40 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, << " " << stringifyLLVMType(inferredEnc) << " -> " << triton::join(srcTy.getShape(), "x") << "failed:\n" << join(diags, "\n"); - if (succeeded(result)) { - EXPECT_EQ(inferredSrcEnc, srcTy.getEncoding()) - << "Inverse encoding inference (" - << triton::join(dstTy.getShape(), "x") << " " - << stringifyLLVMType(inferredEnc) << " -> " - << triton::join(srcTy.getShape(), "x") - << " gave the wrong result. 
Expected " - << stringifyLLVMType(srcTy.getEncoding()) << " but got " - << stringifyLLVMType(inferredSrcEnc) << ".\n"; - } + auto srcLinear = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + auto inferredSrcLinear = toLinearLayout(srcTy.getShape(), inferredSrcEnc); + EXPECT_EQ(inferredSrcLinear, srcLinear) + << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") + << " " << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") + << " gave the wrong result. Expected " << srcLinear->toString() + << " but " + << "got " << inferredSrcLinear->toString() << ".\n"; } - std::vector> srcMultiIdxs = - getMultiIdxs(SmallVector(srcTy.getShape()), - mlir::cast(srcTy.getEncoding())); - - std::vector> dstMultiIdxs = - getMultiIdxs(SmallVector(dstTy.getShape()), - mlir::cast(inferredEnc)); - - if (srcMultiIdxs.size() != dstMultiIdxs.size() || - !llvm::all_of(llvm::zip_equal(srcMultiIdxs, dstMultiIdxs), - [](const auto &pair) { - const auto &[a, b] = pair; - return sameFlatIdxs(*a, *b); - })) { - SCOPED_TRACE(longErrors ? "dst indices:\n" + multiIdxsToString(dstMultiIdxs) - : ""); - SCOPED_TRACE(longErrors ? "src indices:\n" + multiIdxsToString(srcMultiIdxs) - : ""); - ADD_FAILURE() << "Reified indices do not match for encodings:\n" - << " src: [" << triton::join(srcTy.getShape(), "x") << "] " - << stringifyLLVMType(srcTy.getEncoding()) << "\n" - << " dst: [" << triton::join(dstTy.getShape(), "x") << "] " - << stringifyLLVMType(inferredEnc); - } else { - *couldReshape = true; - } + // The funtional characterisation of resize is that, if we have a srcLayout + // and a dstLayout, then the flattened layouts are views of the same data + // when considered as C-contiguous. + auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { + auto ctx = layout.getContext(); + auto linear = *toLinearLayout(shape, layout); + auto dims = standardOutDimNames(ctx, shape.size()); + std::reverse(dims.begin(), dims.end()); + return linear.transposeOuts(dims).reshapeOuts( + {{dims.back(), linear.getTotalOutDimSize()}}); + }; + EXPECT_EQ(makeFlattenedCContig(srcTy.getShape(), srcTy.getEncoding()), + makeFlattenedCContig(dstTy.getShape(), inferredEnc)); } -class InferReshapeOpNoReorderEncodingTest +class InferReshapeOpEncodingTest : public InferLayoutTest, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> {}; -TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { +TEST_P(InferReshapeOpEncodingTest, DoIt) { std::string srcTyStr = expandTyStr(std::get<0>(GetParam())); std::string dstTyStr = expandTyStr(std::get<1>(GetParam())); - bool expectSuccess = std::get<2>(GetParam()); auto src = mlir::parseType(srcTyStr, &ctx); if (!src) @@ -357,7 +189,7 @@ TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { } testReshape(cast(src), cast(dst), - expectedDstEnc, expectSuccess, inferLayout, /*longErrors=*/true); + expectedDstEnc, inferLayout, /*longErrors=*/true); } // A testcase of {a, b, c} means: @@ -368,158 +200,72 @@ TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { // encoding that makes the reshape a nop, and // - if b has an encoding, check that the inferred encoding matches b's. INSTANTIATE_TEST_SUITE_P( - Reshapes, InferReshapeOpNoReorderEncodingTest, - ::testing::ValuesIn(std::vector< - std::tuple>({ + Reshapes, InferReshapeOpEncodingTest, + ::testing::ValuesIn(std::vector>({ // Use raw strings in here so clang-format doesn't try to wrap them. 
{R"(T<128x64xf32, #B<{spt=[1,1], tpw=[1,32], wpc=[1,1], ord=[1,0]}>>)", - R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)", - true}, + R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)"}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - true}, + R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)"}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)", - true}, + R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[2,2], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - "T<128xf32>", false}, + "T<1024xf32>"}, {R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<4x32xf32, #B<{spt=[4,1], tpw=[1,32], wpc=[1,1], ord=[0,1]}>>)", - R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - true}, + R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[0,1]}>>)", - R"(T<16x2x16x2xf32>)", true}, + R"(T<16x2x16x2xf32>)"}, // nop reshape, but the block size is 2x larger than the tensor. 
{R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - true}, + R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)"}, {R"(T<2x4x2x4xf32, #B<{spt=[1,2,2,1], tpw=[1,2,1,2], wpc=[1,2,2,1], ord=[2,1,0,3]}>>)", - R"(T<4x2x2x4xf32>)", false}, + R"(T<4x2x2x4xf32>)"}, {R"(T<1x2x2x4xf32, #B<{spt=[1,32,4,4], tpw=[4,4,16,16], wpc=[8,8,8,1], ord=[0,1,2,3]}>>)", - R"(T<2x2x4x1xf32>)", false}, + R"(T<2x2x4x1xf32>)"}, {R"(T<2x2x2x2xf32, #B<{spt=[2,2,2,2], tpw=[1,1,1,1], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - R"(T<4x4xf32>)", true}, + R"(T<4x4xf32>)"}, {R"(T<16x8xf32, #B<{spt=[1,2], tpw=[2,4], wpc=[2,1], ord=[1,0]}>>)", - R"(T<128xf32>)", true}, + R"(T<128xf32>)"}, {R"(T<16x1x8xf32, #B<{spt=[8,1,1], tpw=[2,1,1], wpc=[1,1,8], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)", false}, + R"(T<128x1xf32>)"}, {R"(T<16x1x8xf32, #B<{spt=[1,1,8], tpw=[2,1,1], wpc=[8,1,1], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)", true}, + R"(T<128x1xf32>)"}, {R"(T<32x32xf32, #B<{spt=[1,2], tpw=[1,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<1024xf32>)", true}, + R"(T<1024xf32>)"}, {R"(T<4x4xf32, #B<{spt=[1,1], tpw=[2,4], wpc=[2,1], ord=[0,1]}>>)", - R"(T<16xf32>)", false}, + R"(T<16xf32>)"}, {R"(T<32xf32, #B<{spt=[2], tpw=[32], wpc=[2], ord=[0]}>>)", - R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)", - true}, + R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)"}, {R"(T<2x1x2xf32, #B<{spt=[2,1,1], tpw=[2,1,2], wpc=[4,1,8], ord=[2,1,0]}>>)", - R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)", - true}, + R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"}, }))); -TEST_F(InferLayoutTest, FuzzReshape) { - const int numTests = 1000; // Increase to get more coverage. - - std::minstd_rand rng(/*seed=*/0); - auto randPow2Vec = [&](int rank, int maxPow2) { - SmallVector ret; - for (int i = 0; i < rank; i++) { - int pow2 = std::uniform_int_distribution(0, maxPow2)(rng); - if (pow2 == maxPow2 && maxPow2 > 0) { - maxPow2--; - } - ret.push_back(1 << pow2); - } - return ret; - }; - - int numSuccess = 0; - for (int i = 0; i < numTests; i++) { - SCOPED_TRACE("iteration " + std::to_string(i)); - int rank = std::uniform_int_distribution(1, 4)(rng); - - SmallVector srcShape( - convertType(randPow2Vec(rank, /*maxPow2=*/4))); - SmallVector dstShape = srcShape; - std::shuffle(dstShape.begin(), dstShape.end(), rng); - - // Optionally merge some dimensions in dst. 
- for (int i = 1; i < dstShape.size(); i++) { - if (std::uniform_real_distribution(0, 1)(rng) > 1.0 / rank) { - dstShape[i - 1] *= dstShape[i]; - dstShape.erase(dstShape.begin() + i); - i--; - } - } - - SmallVector sizePerThread = randPow2Vec(rank, /*maxPow2=*/3); - SmallVector threadsPerWarp = randPow2Vec(rank, /*maxPow2=*/3); - SmallVector warpsPerCTA = randPow2Vec(rank, /*maxPow2=*/3); - - SmallVector order(llvm::to_vector(llvm::seq(rank))); - std::shuffle(order.begin(), order.end(), rng); - - auto ctaLayout = CTALayoutAttr::get( - &ctx, SmallVector(rank, 1), SmallVector(rank, 1), - llvm::to_vector(llvm::reverse(llvm::seq(rank)))); - - auto srcTy = RankedTensorType::get( - srcShape, FloatType::getF32(&ctx), - BlockedEncodingAttr::get(&ctx, sizePerThread, threadsPerWarp, - warpsPerCTA, order, ctaLayout)); - auto dstTy = RankedTensorType::get(dstShape, FloatType::getF32(&ctx)); - - bool couldReshape = false; - testReshape(srcTy, dstTy, /*expectedDstEnc=*/std::nullopt, - /*expectSuccess=*/std::nullopt, inferLayout, - /*longErrors=*/false, &couldReshape); - if (couldReshape) - numSuccess++; - } - - // We don't expect or want 100% success, but if only a tiny fraction of tests - // actually exercise the successful reshape logic, then that gives us bad - // coverage. I'm currently getting 35% success, which seems good enough, - // especially since the successful cases take a lot longer to run because of - // the MultiIdx checks (so we're spending most of our time on successful - // cases, even if they're only 1/3 of the iterations). - // - // Run ctest with --verbose to see this output. For example: - // $ cd python/build/cmake.blah.blah - // $ ninja - // $ $(git rev-parse --show-toplevel)/.venv/bin/ctest --verbose - printf("Fuzz success rate: %d/%d = %.2f%%\n", numSuccess, numTests, - 100.0 * numSuccess / numTests); -} - class AMDMfmaLayoutTest : public ::testing::Test { public: AMDMfmaLayoutTest() { From 48468af3b4bfd9913d325a7fee660ed2961ce953 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:53:20 +0100 Subject: [PATCH 06/14] [LAYOUTS] Enable Slice(Dot) LL conversion (#5400) There's no reason to disable this one. --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 5 +-- python/test/unit/language/test_core.py | 36 +++++++++++++------ .../TritonGPU/LinearLayoutConversionsTest.cpp | 31 ++++++++++++++++ 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index aec87fd89e..e7174f0f9b 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -770,10 +770,7 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); if (!parentLL.has_value()) { - if (mlir::isa(getParent())) - return std::nullopt; - llvm::report_fatal_error( - "Failed to compute parent layout for slice layout."); + return std::nullopt; } // Remove dimension getDim() from the parent layout. 
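In practice this means a `#ttg.slice` layout whose parent is a `#ttg.dot_op` layout is now convertible to a linear layout rather than being special-cased as unsupported. The test changes below express such layouts with small Python helper classes; as a purely illustrative composition (assuming the `SliceLayout`, `DotOperandLayout`, and `MmaLayout` helpers defined in the `test_core.py` diff that follows), with values mirroring one of the layouts added to the test lists:

    # Prints the textual MLIR attribute for a slice-of-dot-operand layout,
    # e.g. #ttg.slice<{dim = 1, parent = #ttg.dot_op<{...}>}>.
    parent = DotOperandLayout(
        parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]),
        op_idx=0, k_width=2)
    print(SliceLayout(dim=1, parent=parent))
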
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 26fcb146e4..053a2551ee 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -162,6 +162,16 @@ def __str__(self): return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" +class SliceLayout: + + def __init__(self, dim, parent): + self.dim = dim + self.parent = parent + + def __str__(self): + return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>" + + class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -199,6 +209,8 @@ def is_layout_applicable(layout) -> bool: common_layouts = [BlockedLayout, SharedLayout] if layout in common_layouts: return True + elif isinstance(layout, SliceLayout): + return is_layout_applicable(layout.parent) elif is_cuda(): mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout if not isinstance(mma_layout, MmaLayout): @@ -2850,8 +2862,11 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): # TODO (lixun): Add MfmaLayout BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], - instr_shape=[16, 8]) + instr_shape=[16, 8]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), ] @@ -5281,17 +5296,12 @@ def kernel(Out): # TODO: backend should be tested separately layouts = [ + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), - BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), @@ -5300,7 +5310,13 @@ def kernel(Out): 
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), - MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), + SliceLayout( + dim=1, + parent=DotOperandLayout(parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]), + op_idx=0, k_width=2)), + SliceLayout( + dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), + op_idx=1, k_width=2)), ] intermediate_layouts = [ diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index cd742f62c9..7850b87ac5 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -733,6 +733,37 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, SliceDot) { + // Slice layout with a DotOperand (MMAv2) as the parent. + auto parentV2 = dot(mma(2, 0, {16, 8}, {1, 1}), /*opIdx=*/0, /*kWidth=*/8); + auto sliceV2 = slice(parentV2, /*dim=*/1); + + EXPECT_EQ(toLinearLayout({16}, sliceV2), + LinearLayout( + { + {S("register"), {{8}}}, + {S("lane"), {{0}, {0}, {1}, {2}, {4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0")})); + + // Slice layout with a DotOperand (MMAv3) as the parent. + auto parentV3 = + dot(mma(3, 0, {16, 16, 8}, {4, 1}), /*opIdx=*/0, /*kWidth=*/2); + auto sliceV3 = slice(parentV3, /*dim=*/0); + + EXPECT_EQ(toLinearLayout({16}, sliceV3), + LinearLayout( + { + {S("register"), {{1}, {8}}}, + {S("lane"), {{2}, {4}, {0}, {0}, {0}}}, + {S("warp"), {{0}, {0}}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); From 84c8ab490dedc3c6079babd61c5a2b62d8c623bc Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 18 Dec 2024 18:52:21 -0500 Subject: [PATCH 07/14] [intel] Implement `filter_layouts` (#3045) By implementing `filter_layouts`, we can add back layouts from other backends to reduce differences from upstream. 
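For context, the mechanism is simply that each parametrized layout list is filtered through `is_layout_applicable` (extended in the diff below), so layouts owned by other backends are skipped at collection time instead of being deleted from the shared lists. A minimal sketch of what such a helper could look like (the exact implementation in `test_core.py` may differ, and the `pytest` usage shown is only illustrative):

    def filter_layouts(layouts):
        # Keep only the layouts the current backend can execute;
        # is_layout_applicable() is the per-backend predicate extended below.
        return [layout for layout in layouts if is_layout_applicable(layout)]

    # Hypothetical usage in a test parametrization:
    # @pytest.mark.parametrize("src_layout", filter_layouts(layouts))
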
Signed-off-by: Whitney Tsang --- python/test/unit/language/test_core.py | 36 ++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index addba58ae8..ead97bb57d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -221,6 +221,8 @@ def is_layout_applicable(layout) -> bool: common_layouts = [BlockedLayout, SharedLayout] if layout in common_layouts: return True + elif isinstance(layout, BlockedLayout) or isinstance(layout, SharedLayout): + return True elif is_cuda(): mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout if not isinstance(mma_layout, MmaLayout): @@ -238,6 +240,9 @@ def is_layout_applicable(layout) -> bool: return isinstance(layout, MfmaLayout) else: return False + elif is_xpu(): + mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout + return isinstance(mma_layout, DpasLayout) else: return True @@ -2692,6 +2697,23 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 32, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(version=1, warps_per_cta=[2, 2]), + WmmaLayout(version=1, warps_per_cta=[4, 1]), + WmmaLayout(version=1, warps_per_cta=[1, 4]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32, @@ -2887,6 +2909,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): # TODO (lixun): Add MfmaLayout BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]) ] @@ -5309,6 +5333,9 @@ def kernel(Out): # TODO: backend should be tested separately layouts = 
[ + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -5317,6 +5344,15 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]) ] From 713b5b82d89788caf3c7c91bd97ec6e0536955cf Mon Sep 17 00:00:00 2001 From: Pavel Chekin Date: Wed, 18 Dec 2024 16:07:24 -0800 Subject: [PATCH 08/14] Use PTI and DLE for benchmarks (#3044) Change `benchmarking_method` to `UPSTREAM_PYTORCH_PROFILER` and use DLE instead of PTDB by default. This is to change the default values, will remove IPEX and PTDB as available options in a separate PR. Fixes #2701, #2592. 
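For reference, the workflow changes below only decide which value lands in the `BENCHMARKING_METHOD` environment variable; the benchmark suite is expected to dispatch on that value roughly as follows. This is a simplified sketch, not the literal code in `benchmark_testing.py`, and the module path in the import is an assumption; only the method names and the environment variable come from the diffs.

    import os

    # These implementations are defined in
    # benchmarks/triton_kernels_benchmark/benchmark_testing.py (touched by the next patch);
    # the import path is assumed here.
    from triton_kernels_benchmark.benchmark_testing import (
        do_bench_ipex, do_bench_elapsed_time, do_bench_upstream_pytorch_profiler)

    method = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
    if method == "PYTORCH_LEGACY_PROFILER_USING_IPEX":
        do_bench = do_bench_ipex
    elif method == "ELAPSED_TIME":
        do_bench = do_bench_elapsed_time
    elif method == "UPSTREAM_PYTORCH_PROFILER":
        do_bench = do_bench_upstream_pytorch_profiler
    else:
        raise NotImplementedError(f"BENCHMARKING_METHOD {method} is not supported")

With the new defaults, a scheduled run or a plain `workflow_dispatch` run therefore benchmarks through the upstream PyTorch profiler path and sets `USE_IPEX=0`.
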
--- .github/actions/setup-pytorch/action.yml | 3 ++- .github/workflows/triton-benchmarks.yml | 21 ++++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/.github/actions/setup-pytorch/action.yml b/.github/actions/setup-pytorch/action.yml index 88b7189ff0..17153cfc0e 100644 --- a/.github/actions/setup-pytorch/action.yml +++ b/.github/actions/setup-pytorch/action.yml @@ -74,8 +74,9 @@ runs: - name: Generate PyTorch cache key shell: bash run: | + ONEAPI_LINK=$(readlink /opt/intel/oneapi || true) ONEAPI_KEY=$(sha256sum /opt/intel/installed.txt 2> /dev/null | cut -d\ -f1 || true) - PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }}$ONEAPI_KEY | sha256sum - | cut -d\ -f1) + PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }}${ONEAPI_KEY}${ONEAPI_LINK} | sha256sum - | cut -d\ -f1) echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV" - name: Load PyTorch from a cache diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml index 659eedee72..df30b48392 100644 --- a/.github/workflows/triton-benchmarks.yml +++ b/.github/workflows/triton-benchmarks.yml @@ -19,7 +19,7 @@ on: - PYTORCH_LEGACY_PROFILER_USING_IPEX - ELAPSED_TIME - UPSTREAM_PYTORCH_PROFILER - default: PYTORCH_LEGACY_PROFILER_USING_IPEX + default: UPSTREAM_PYTORCH_PROFILER run_name: description: Run name type: string @@ -32,6 +32,13 @@ on: description: Use Python built with pyenv type: boolean default: false + oneapi_bundle: + description: oneAPI bundle + type: choice + options: + - PTDB + - DLE + default: DLE schedule: - cron: "5 23 * * *" @@ -46,8 +53,8 @@ permissions: read-all env: PYTHON_VERSION: "3.10" - BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'PYTORCH_LEGACY_PROFILER_USING_IPEX' }} - USE_IPEX: ${{ github.event_name != 'workflow_dispatch' && '1' || inputs.benchmarking_method == 'PYTORCH_LEGACY_PROFILER_USING_IPEX' && '1' || '0' }} + BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }} + USE_IPEX: ${{ (inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER') == 'PYTORCH_LEGACY_PROFILER_USING_IPEX' && '1' || '0' }} TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }} jobs: @@ -66,6 +73,14 @@ jobs: ${{ toJSON(inputs) }} EOF + - name: Use DLE + if: ${{ (github.oneapi_bundle || 'DLE') == 'DLE' }} + shell: bash + run: | + if [[ -e /opt/intel/dle ]]; then + sudo ln -sfT /opt/intel/dle /opt/intel/oneapi + fi + - name: Checkout repository uses: actions/checkout@v4 From c83c0edffe10ba30e9a408fe6f132e3c6a740118 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 19 Dec 2024 01:09:54 +0100 Subject: [PATCH 09/14] Remove workaround for upstream profiler (#2484) Signed-off-by: Anatoly Myachev --- .../benchmark_testing.py | 32 ++++++++--------- .../flash_attention_fwd_benchmark.py | 6 ++-- .../triton_kernels_benchmark/fused_softmax.py | 15 ++------ .../gemm_benchmark.py | 35 ++----------------- .../gemm_postop_addmatrix_benchmark.py | 4 +-- .../gemm_postop_gelu_benchmark.py | 4 +-- .../gemm_preop_exp_benchmark.py | 4 +-- .../gemm_splitk_benchmark.py | 4 +-- .../gemm_streamk_benchmark.py | 5 ++- .../triton_kernels_benchmark/prefix_sums.py | 3 +- 10 files changed, 31 insertions(+), 81 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py 
b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index 9d1020b95d..1e088291fa 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode): def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", - sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument + sync_submitting=True): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -108,7 +108,7 @@ def extract_kernels(funcs): def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", - device="xpu", kernel_name=None): # pylint: disable=unused-argument + device="xpu"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -159,7 +159,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, - return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None): + return_mode="mean", device="xpu", sync_submitting=True): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -178,7 +178,7 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no assert return_mode in ["min", "max", "mean", "median"] import torch - from torch.profiler import profile, ProfilerActivity + from torch.profiler import profile, ProfilerActivity, record_function fn() synchronize() @@ -206,24 +206,24 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no if sync_submitting: synchronize() # record time of `fn` - fn() + with record_function("__profile_kernel_of_func"): + fn() # Record clocks synchronize() - function_events = prof.events() + profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events()) + functions = list(profiling_func_filter) - all_functions = [] - if isinstance(kernel_name, str): - kernel_name = [kernel_name] - for ker_name in kernel_name: - functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop - assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}" - all_functions.append(functions) - # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events) + def extract_kernels(funcs): + kernels = [] + kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs))) + kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs])) + return kernels + kernels = [extract_kernels(func.cpu_children) for func in functions] + assert len(kernels) == n_repeat, "the profiling number not match" # Make the time to the milliseconds. 
- times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)], - dtype=torch.float) + times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index efb4987cb5..132898c023 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -265,8 +265,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='_attn_fwd') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) elif provider == 'xetla': module_name = f'flash_attn_causal_{CAUSAL}'.lower() @@ -281,8 +280,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='gpu::xetla::fmha::FmhaForwardKernel<') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 3fa5983d7b..6782e92d6b 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -131,8 +131,7 @@ def benchmark(M, N, provider): triton_fn = lambda: softmax(x, out) torch_fn = lambda: torch.softmax(x, axis=-1) benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, - kernel_name="softmax_kernel") + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10) elif provider == "torch-jit": _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, @@ -145,17 +144,7 @@ def benchmark(M, N, provider): xetla_fn = lambda: func(x, out, 0) torch_fn = lambda: torch.softmax(x, axis=-1) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") - kernels_name = { - "softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0", - "softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0", - "softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0", - "softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0", - "softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0", - "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0", - "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0", - } - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, - kernel_name=kernels_name[name]) + _, min_ms, max_ms, mean, cv = 
benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10) else: raise NotImplementedError(f"Unsupported provider {provider}") diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 15d0deb5af..535b34dce9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -288,7 +288,7 @@ def benchmark(B, M, N, K, provider): # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method do_bench = do_bench_elapsed_time _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='gemm_kernel') + quantiles=quantiles) elif provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: @@ -301,8 +301,7 @@ def benchmark(B, M, N, K, provider): rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, - kernel_name='matmul_kernel_with_block_pointers') + quantiles=quantiles) elif provider == 'xetla': if B == 1: c = torch.zeros((M, N), device='xpu', dtype=torch.float32) @@ -329,37 +328,9 @@ def xetla_func_with_acc_allocation(): xetla_fn = xetla_func_with_acc_allocation torch_fn = lambda: torch.matmul(a, b).to(torch.float32) - kernels_name = { - 'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row', - 'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row', - 'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row', - 'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row', - 'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row', - 'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row', - 'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row', - 'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row', - 'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row', - 'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row', - 'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row', - 'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row', - 'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row', - 'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row', - 'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row', - 'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row', - 'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row', - 'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row', - 'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row', - 'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row', - 'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row', - 'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row', - 'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row', - 'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row', - 'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run', - } - # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernels_name[name]) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported 
provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 307100dcfe..cefbd5abc9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -266,17 +266,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, d, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index 85bb594ade..68cec3931e 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -268,17 +268,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32)) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 30ed124d44..dd5b57c84f 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -256,17 +256,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to 
torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 06d2d90e1d..c4114c4466 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -159,7 +159,7 @@ def benchmark(M, N, K, provider): rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='_kernel') + quantiles=quantiles) elif provider == 'xetla': c = torch.zeros((M, N), device='xpu', dtype=torch.float32) acc = torch.zeros((M, N), device='xpu', dtype=torch.float32) @@ -172,7 +172,7 @@ def benchmark(M, N, K, provider): # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='split_k_gemm_run') + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index 12f37e9d31..f0743cfe64 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -280,8 +280,7 @@ def benchmark(M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, - kernel_name=['first_wave', 'full_tiles']) + quantiles=quantiles) elif provider == 'xetla': c = torch.zeros((M, N), device='xpu', dtype=torch.float32) acc = torch.zeros((M, N), device='xpu', dtype=torch.float32) @@ -294,7 +293,7 @@ def benchmark(M, N, K, provider): # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='stream_k_gemm_run') + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/prefix_sums.py b/benchmarks/triton_kernels_benchmark/prefix_sums.py index 8f17fb9e9f..bb3d2069f0 100644 --- a/benchmarks/triton_kernels_benchmark/prefix_sums.py +++ b/benchmarks/triton_kernels_benchmark/prefix_sums.py @@ -44,8 +44,7 @@ def benchmark(M, N, AXIS, provider): if provider == 'triton': triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS) - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, - kernel_name='scan_kernel') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') From 635435fc2e56b2a30276302d75df87956b541848 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Wed, 18 Dec 2024 17:20:49 -0800 Subject: 
[PATCH 10/14] [Pipeliner] Add support for pipelining loads with different latencies (#5460) @pawelszczerbuk wrote the code. I just fixed a few things and added a test :) This generalizes the loop pipeliner infrastructure a bit to support loads with different latencies that are pipelined and multibuffered differently, allowing more fine-grained buffer allocation. The feature isn't exposed yet, but the PR also adds an attribute to the TMA load op allowing the user to manually specify the desired latency. --------- Co-authored-by: Pawel Szczerbuk --- .../Transforms/Pipeliner/AssignLatencies.cpp | 21 ++ .../Pipeliner/MatmulLoopPipeline.cpp | 193 ++++++++++-------- .../loop-pipeline-async-latencies.mlir | 144 +++++++++++++ 3 files changed, 277 insertions(+), 81 deletions(-) create mode 100644 test/TritonGPU/loop-pipeline-async-latencies.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp index f274363730..afa22f164a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -183,6 +183,23 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, return loadOpToIndLevel; } +bool hasLatenciesAssigned(scf::ForOp forOp) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("tt_latency")) + return true; + } + return false; +} + +void assignUserProvidedLatencies(scf::ForOp forOp, + DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto latencyAttr = op.getAttr("tt_latency")) { + opLatency[&op] = mlir::cast(latencyAttr).getInt(); + } + } +} + } // namespace // Look for load ops that directly or indirectly feed into dot ops. 
Based @@ -212,6 +229,10 @@ DenseMap assignLatencies(ModuleOp moduleOp, DenseMap opLatency; for (auto forOp : loops) { + if (hasLatenciesAssigned(forOp)) { + assignUserProvidedLatencies(forOp, opLatency); + continue; + } int numStages = getNumStagesOrDefault(forOp); bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); ModuleOp moduleOp = forOp->getParentOfType(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index fc037cd26f..d77c0546d5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -121,7 +121,7 @@ static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) { static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, llvm::MapVector &loadToInfo, - int numStages, int maxClusterId) { + int maxClusterId) { int retCode = -1; OpBuilderWithStage builder(forOp); auto opPair = tt::getStageCluster(loadOp); @@ -234,8 +234,7 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, Value phase, - llvm::MapVector &loadToInfo, - int numStages) { + llvm::MapVector &loadToInfo) { assert(phase && "Phase value is required for TMA async copy."); OpBuilderWithStage builder(forOp); auto [stage, clusterId] = tt::getStageCluster(loadOp); @@ -585,21 +584,28 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { return barrierAlloc; } +struct StageGroup { + Value insertIdx; + Value extractIdx; + Value phase; + bool hasTMALoad = false; +}; struct AsyncLoad { - AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} Operation *loadOp; Value alloc; Value barrier; Operation *waitOp = nullptr; int firstUseStage, firstUseCluster; bool isTMALoad = false; + int numBuffers = 0; }; // Create barriers and wait ops for the async loads. Barriers may be shared by -// multiple loads is the schedule allows it. +// multiple loads if the schedule allows it. static void createTMABarrierAndWait( - scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, - Value extractIdx, Value phase, int numBuffers, SmallVector &barriers, + scf::ForOp &forOp, SmallVector &asyncLoads, + SmallVector &barriers, + const llvm::MapVector &stageGroups, const llvm::MapVector &loadToInfo) { llvm::SmallDenseMap loadToAsyncLoad; for (AsyncLoad &asyncLoad : asyncLoads) { @@ -639,12 +645,15 @@ static void createTMABarrierAndWait( }; addToGroup(&asyncLoad); Operation *nextOp = asyncLoad.loadOp->getNextNode(); + int numBuffers = asyncLoad.numBuffers; while (nextOp) { if (users.count(nextOp) || visited.count(nextOp)) break; if (isa(nextOp)) { auto it = loadToAsyncLoad.find(nextOp); if (it != loadToAsyncLoad.end() && it->second->isTMALoad) { + if (it->second->numBuffers != numBuffers) + break; if (group.size() > 0 && sameStageCluster(group[0]->loadOp, it->second->loadOp)) addToGroup(it->second); @@ -659,6 +668,8 @@ static void createTMABarrierAndWait( // load. 
for (SmallVector &group : loadGroups) { int sizeInBytes = 0; + int numBuffers = group[0]->numBuffers; + const StageGroup &stageGroup = stageGroups.find(numBuffers)->second; for (AsyncLoad *asyncLoad : group) { auto tensorTy = cast(asyncLoad->loadOp->getResult(0).getType()); @@ -682,7 +693,7 @@ static void createTMABarrierAndWait( builder.setInsertionPoint(group[0]->loadOp); Value barrier = builder.createWithStage( loc, stage, cluster, barrierTy, barrierAlloc, - ArrayRef({insertIdx})); + ArrayRef({stageGroup.insertIdx})); Value pred = builder.createWithStage(loc, stage, cluster, 1, 1); Operation *expect = builder.createWithStage( @@ -691,10 +702,10 @@ static void createTMABarrierAndWait( builder.setInsertionPointAfter(group.back()->loadOp); Value barrierViewWait = builder.createWithStage( loc, group[0]->firstUseStage, group[0]->firstUseCluster, barrierTy, - barrierAlloc, ArrayRef({extractIdx})); + barrierAlloc, ArrayRef({stageGroup.extractIdx})); Operation *wait = builder.createWithStage( loc, group[0]->firstUseStage, group[0]->firstUseCluster, - barrierViewWait, phase); + barrierViewWait, stageGroup.phase); // Update the async loads info. for (AsyncLoad *asyncLoad : group) { asyncLoad->barrier = barrier; @@ -855,46 +866,47 @@ static SmallVector createAsyncOps(scf::ForOp &forOp, llvm::MapVector &loadToInfo, SmallVector &barriers, int numStages) { - // Calculate the number of buffers needed for each load. - // TODO pawel: we could do more fine-grained allocation here and - // allocate only the number of buffers that specific loads need. - // Instead, we allocate the maximum number of buffers needed by any load. - int numBuffers = - llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs, - auto &rhs) { - return lhs.distToUse < rhs.distToUse; - })->distToUse; - bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) { - return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; - }); - if (hasMMAV3) { - // For MMAv3, we need an extra buffer as this is assumed in the wgmma - // pipelining post-processing. - numBuffers++; - }; - llvm::MapVector tmaBufferMapping; if (failed(allocTMABuffers(forOp, tmaBufferMapping, numStages))) { llvm_unreachable("TMA pipelining failed"); } + // Each group of loads/allocs with the same number of buffers (and stages) + // will share the indices and barriers. + SmallVector asyncLoads; SmallVector allocs; - bool hasTMALoad = false; + llvm::MapVector stageGroups; + for (auto &[loadOp, info] : loadToInfo) { + AsyncLoad asyncLoad = {.loadOp = loadOp}; + bool isTMALoad = false; + int numBuffers = info.distToUse; + // For MMAv3, we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. 
+ if (info.isMMAv3Shared || info.isMMAv3Registers) { + ++numBuffers; + } + if (isa(loadOp)) { + isTMALoad = true; + asyncLoad.isTMALoad = isTMALoad; + } assert(info.sharedEncoding && "LoadOp shared encoding not defined."); Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); assert(alloc && "Failed to create alloc for the async load."); allocs.push_back(alloc); - asyncLoads.emplace_back(loadOp, alloc); - if (isa(loadOp)) { - hasTMALoad = true; - asyncLoads.back().isTMALoad = true; - } + asyncLoad.alloc = alloc; + auto *firstUse = getFirstUseOfPipelinedLoad(loadOp); auto [firstUseStage, firstUseCluster] = tt::getStageCluster(firstUse); - asyncLoads.back().firstUseStage = firstUseStage; - asyncLoads.back().firstUseCluster = firstUseCluster; + asyncLoad.firstUseStage = firstUseStage; + asyncLoad.firstUseCluster = firstUseCluster; + asyncLoad.numBuffers = numBuffers; + stageGroups.insert({numBuffers, {}}); + if (isTMALoad) { + stageGroups[numBuffers].hasTMALoad = true; + } + asyncLoads.push_back(asyncLoad); } IRRewriter builder(forOp.getContext()); @@ -908,41 +920,34 @@ createAsyncOps(scf::ForOp &forOp, Value minusOne = builder.create(loc, -1, 32); Value zero = builder.create(loc, 0, 32); Value one = builder.create(loc, 1, 32); - Value insertIdx = minusOne; - Value extractIdx = minusOne; - Value phase = Value(); - Value numBuffersVal = - builder.create(loc, numBuffers, 32); SmallVector newOperands; - newOperands.push_back(insertIdx); - newOperands.push_back(extractIdx); - if (hasTMALoad) { - // A single barrier arrival sequence is a "phase" and two phases can - // overlap, provided the phases are differentiated with an alternating - // boolean value. - phase = builder.create(loc, 0, 32); - newOperands.push_back(phase); + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + for (auto [_, stageGroup] : stageGroups) { + newOperands.push_back(minusOne); // insertIdx + newOperands.push_back(minusOne); // extractIdx + if (stageGroup.hasTMALoad) { + // A single barrier arrival sequence is a "phase" and two phases can + // overlap, provided the phases are differentiated with an alternating + // boolean value. + newOperands.push_back(zero); // phase + } } // Also create one counter per TMA buffer. This allows the descriptors to be // updated independently without needing to write duplicate of existing tma // descriptors. + unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size(); for (int i = 0; i < tmaBufferMapping.size(); ++i) { newOperands.push_back(zero); } - unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. scf::ForOp newForOp = replaceForOpWithNewSignature(builder, forOp, newOperands); forOp.erase(); forOp = newForOp; - insertIdx = newForOp.getBody()->getArgument(newOperandIndex); - extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1); - if (phase) { - phase = newForOp.getBody()->getArgument(newOperandIndex + 2); - } + auto tmaCounters = ArrayRef(newForOp.getBody()->getArguments()) - .slice(newOperandIndex + (phase ? 3 : 2)); + .slice(tmaCounterArgsStartIdx); // Update yield op with temporary yield values auto forYield = cast(newForOp.getBody()->getTerminator()); @@ -956,44 +961,70 @@ createAsyncOps(scf::ForOp &forOp, } tmaBufferMapping.clear(); - // FIXME: loads can be in different (stage, cluster) - // Create two counters for the insert and extract indices to avoid creating - // long liverange. 
- builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); - insertIdx = builder.create(loc, insertIdx, one); - Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, - insertIdx, numBuffersVal); - insertIdx = builder.create(loc, cndIns, insertIdx, zero); - - extractIdx = builder.create(loc, extractIdx, one); - Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, - extractIdx, numBuffersVal); - extractIdx = builder.create(loc, cndExt, extractIdx, zero); - if (phase) { - Value nextPhase = builder.create(loc, phase, one); - phase = builder.create(loc, cndExt, phase, nextPhase); + builder.setInsertionPoint(forOp); + loc = forOp.getLoc(); + int argIdx = newOperandIndex; + for (auto &[numBuffers, stageGroup] : stageGroups) { + Value insertIdx = newForOp.getBody()->getArgument(argIdx); + argIdx++; + Value extractIdx = newForOp.getBody()->getArgument(argIdx); + argIdx++; + Value phase = nullptr; + if (stageGroup.hasTMALoad) { + phase = newForOp.getBody()->getArgument(argIdx); + argIdx++; + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + stageGroup.insertIdx = insertIdx; + + extractIdx = builder.create(loc, extractIdx, one); + // Duplicate the constant to keep it from being carried across loops. + numBuffersVal = builder.create(loc, numBuffers, 32); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + stageGroup.extractIdx = extractIdx; + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); + stageGroup.phase = phase; + } } - createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase, - numBuffers, barriers, loadToInfo); + createTMABarrierAndWait(forOp, asyncLoads, barriers, stageGroups, loadToInfo); auto [_, maxClusterId] = tt::getMinMaxCluster(forOp); for (AsyncLoad &asyncLoad : asyncLoads) { + auto [insertIdx, extractIdx, phase, _] = stageGroups[asyncLoad.numBuffers]; if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, - loadToInfo, numStages, maxClusterId); + loadToInfo, maxClusterId); } else { auto descLoad = cast(asyncLoad.loadOp); createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, - loadToInfo, numStages); + loadToInfo); } } - // Patch the yield with the updated counters. - forYield.setOperand(newOperandIndex + -1, insertIdx); - forYield.setOperand(newOperandIndex + 0, extractIdx); - if (phase) { - forYield.setOperand(newOperandIndex + 1, phase); + // Patch the yield with the updated counters. Subtract to account for the loop + // counter. 
+ argIdx = newOperandIndex - 1; + for (auto &[numBuffers, stageGroup] : stageGroups) { + forYield.setOperand(argIdx++, stageGroup.insertIdx); + forYield.setOperand(argIdx++, stageGroup.extractIdx); + if (stageGroup.phase) + forYield.setOperand(argIdx++, stageGroup.phase); } + assert(argIdx + 1 == tmaCounterArgsStartIdx); tt::CoarseSchedule coarseSchedule(numStages); coarseSchedule.deSerialize(forOp); diff --git a/test/TritonGPU/loop-pipeline-async-latencies.mlir b/test/TritonGPU/loop-pipeline-async-latencies.mlir new file mode 100644 index 0000000000..59dc073743 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-async-latencies.mlir @@ -0,0 +1,144 @@ +// RUN: triton-opt %s --tritongpu-loop-scheduling="num-stages=3" --tritongpu-pipeline="num-stages=3" -canonicalize -cse | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: matmul_kernel_tma_persistent +tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.subi %arg3, %c2_i32 : i32 + + %1 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> + %2 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + + // CHECK: [[LHS_BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x64xf16, + // CHECK: [[RHS_BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<4x256x64xf16, + + // CHECK: [[LHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, + // CHECK-NEXT: [[LHS_BAR0:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c0_i32] + // CHECK-NEXT: ttng.init_barrier [[LHS_BAR0]] + // CHECK-NEXT: [[LHS_BAR1:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c1_i32] + // CHECK-NEXT: ttng.init_barrier [[LHS_BAR1]] + + // CHECK: [[RHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<4xi64, + // CHECK-NEXT: [[RHS_BAR0:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c0_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR0]] + // CHECK-NEXT: [[RHS_BAR1:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c1_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR1]] + // CHECK-NEXT: [[RHS_BAR2:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c2_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR2]] + // CHECK-NEXT: [[RHS_BAR3:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c3_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR3]] + + // CHECK: [[MASK0:%.*]] = arith.cmpi sgt, %arg3, %c0_i32 + // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR0]], 32768, [[MASK0]] + // CHECK-NEXT: [[RHS_BUF0:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c0_i32, %c0_i32, 
%c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c0_i32] [[RHS_BUF0]], [[RHS_BAR0]], [[MASK0]] + + // CHECK: [[MASK1:%.*]] = arith.cmpi sgt, %arg3, %c1_i32 + // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR1]], 32768, [[MASK1]] + // CHECK-NEXT: [[RHS_BUF1:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c1_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c1_i32] [[RHS_BUF1]], [[RHS_BAR1]], [[MASK1]] + + // CHECK: [[MASK2:%.*]] = arith.cmpi sgt, %arg3, %c2_i32 + + // CHECK-NEXT: ttng.barrier_expect [[LHS_BAR0]], 16384, [[MASK0]] + // CHECK-NEXT: [[LHS_BUF0:%.*]] = ttg.memdesc_subview [[LHS_BUFFERS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] [[LHS_BUF0]], [[LHS_BAR0]], [[MASK0]] + + // CHECK: ttng.barrier_expect [[RHS_BAR2]], 32768, [[MASK2]] + // CHECK-NEXT: [[RHS_BUF2:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c2_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c2_i32] [[RHS_BUF2]], [[RHS_BAR2]], [[MASK2]] + + %true = arith.constant true + %false = arith.constant false + + // CHECK: scf.for [[I:%.*]] = %c0_i32 to + // CHECK-SAME: iter_args([[ACCUM:%arg[0-9]+]] = %cst + + // CHECK-SAME: [[NEXT_LHS_BUF_IDX:%arg[0-9]+]] = %c0_i32 + // CHECK-SAME: [[LHS_BUF_IDX:%arg[0-9]+]] = %c-1_i32 + // CHECK-SAME: [[LHS_PHASE_ARG:%arg[0-9]+]] = %c0_i32 + + // CHECK-SAME: [[NEXT_RHS_BUF_IDX:%arg[0-9]+]] = %c2_i32 + // CHECK-SAME: [[RHS_BUF_IDX:%arg[0-9]+]] = %c-1_i32 + // CHECK-SAME: [[RHS_PHASE_ARG:%arg[0-9]+]] = %c0_i32 + %3 = scf.for %arg6 = %c0_i32 to %arg3 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { + // CHECK: [[RHS_MAX_ITER:%.*]] = arith.subi %arg3, %c3_i32 + // CHECK-NEXT: [[RHS_MASK:%.*]] = arith.cmpi slt, [[I]], [[RHS_MAX_ITER]] + // CHECK: [[LHS_MAX_ITER:%.*]] = arith.subi %arg3, %c1_i32 + // CHECK-NEXT: [[LHS_MASK:%.*]] = arith.cmpi slt, [[I]], [[LHS_MAX_ITER]] + + // Compute RHS buffer index modulo 4. + // CHECK: [[V0:%.*]] = arith.addi [[RHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c4_i32 + // CHECK-NEXT: [[RHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + + // Compute RHS phase index modulo 4. + // CHECK: [[V0:%.*]] = arith.xori [[RHS_PHASE_ARG]], %c1_i32 + // CHECK-NEXT: [[RHS_PHASE:%.*]] = arith.select [[V1]], [[RHS_PHASE_ARG]], [[V0]] + + // Compute LHS buffer index modulo 2. + // CHECK: [[V0:%.*]] = arith.addi [[LHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c2_i32 + // CHECK-NEXT: [[LHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + + // Compute LHS phase index modulo 2. 
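+ // Note: the "phase" is a single bit that flips only when the buffer index
+ // wraps back to 0 (the select below reuses [[V1]] from the wrap check), so
+ // ttng.wait_barrier can tell completions of the current pass over the
+ // double-buffered LHS ring apart from those of the previous pass.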
+ // CHECK: [[V0:%.*]] = arith.xori [[LHS_PHASE_ARG]], %c1_i32 + // CHECK-NEXT: [[LHS_PHASE:%.*]] = arith.select [[V1]], [[LHS_PHASE_ARG]], [[V0]] + + // CHECK: [[LHS_MBAR:%.*]] = ttg.memdesc_subview [[LHS_BARS]][[[LHS_BUF_IDX]]] + // CHECK-NEXT: ttng.wait_barrier [[LHS_MBAR]], [[LHS_PHASE]] + + // CHECK: [[RHS_MBAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[RHS_BUF_IDX]]] + // CHECK-NEXT: ttng.wait_barrier [[RHS_MBAR]], [[RHS_PHASE]] + + %4 = tt.experimental_descriptor_load %1[%c0_i32, %arg6] {tt_latency = 1 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %6 = tt.experimental_descriptor_load %2[%c0_i32, %arg6] {tt_latency = 3 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #blocked> + %7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem> + %8 = ttg.memdesc_trans %7 {order = array} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + + // CHECK: [[V0:%.*]] = arith.addi [[NEXT_LHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c2_i32 + // CHECK-NEXT: [[NEXT_LHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + // CHECK-NEXT: [[NEXT_LHS_BAR:%.*]] = ttg.memdesc_subview [[LHS_BARS]][[[NEXT_LHS_BUF_IDX]]] + // CHECK-NEXT: ttng.barrier_expect [[NEXT_LHS_BAR]], 16384, [[LHS_MASK]] + + // CHECK-NEXT: [[NEXT_LHS_BUF:%.*]] = ttg.memdesc_subview [[LHS_BUFFERS]][[[NEXT_LHS_BUF_IDX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[NEXT_LHS_IDX:%.*]] = arith.addi [[I]], %c1_i32 + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg0[%c0_i32, [[NEXT_LHS_IDX]]] [[NEXT_LHS_BUF]], [[NEXT_LHS_BAR]], [[LHS_MASK]] + + // CHECK: [[V0:%.*]] = arith.addi [[NEXT_RHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c4_i32 + // CHECK-NEXT: [[NEXT_RHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + // CHECK-NEXT: [[NEXT_RHS_BAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[NEXT_RHS_BUF_IDX]]] + // CHECK-NEXT: ttng.barrier_expect [[NEXT_RHS_BAR]], 32768, [[RHS_MASK]] + + // CHECK-NEXT: [[NEXT_RHS_BUF:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][[[NEXT_RHS_BUF_IDX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[NEXT_RHS_IDX:%.*]] = arith.addi [[I]], %c3_i32 + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, [[NEXT_RHS_IDX]]] [[NEXT_RHS_BUF]], [[NEXT_RHS_BAR]], [[RHS_MASK]] + + %10 = arith.cmpi eq, %arg3, %0 : i32 + scf.if %10 { + %11 = arith.truncf %9 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %12 = ttg.convert_layout %11 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + %13 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> + tt.experimental_descriptor_store %13[%c0_i32, %c0_i32], %12 : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + } + // CHECK: yield %{{.*}}, [[NEXT_LHS_BUF_IDX]], [[LHS_BUF_IDX]], [[LHS_PHASE]], [[NEXT_RHS_BUF_IDX]], [[RHS_BUF_IDX]], [[RHS_PHASE]] + scf.yield %9 : tensor<128x256xf32, #mma> + } {tt.num_stages = 4 : i32} + tt.return +} + +} From fdab3bb9150566e543f6986490a6e1a6932cb06d Mon Sep 17 00:00:00 2001 From: Si Yudong Date: Thu, 19 Dec 2024 10:14:51 +0800 Subject: [PATCH 11/14] Fix `test_gather` (#3010) Make `getStackPointer` as interface of the `TargetInfo` to generalize `getSharedMemoryBase` in gather op. 
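
For illustration, a self-contained C++ sketch of the refactoring pattern (the
names below, StackValue / NvidiaLikeTargetInfo / IntelLikeTargetInfo, are
simplified placeholders rather than the real MLIR and Triton classes): the
stack-pointer lookup becomes a pure virtual hook on the target-info base
class, so one shared getSharedMemoryBase helper can serve every backend and
the backend-private copies deleted below become unnecessary.

    #include <iostream>
    #include <string>

    // Placeholder for an SSA value; the real hook returns mlir::Value.
    struct StackValue { std::string repr; };

    // Simplified stand-in for TargetInfoBase with the new pure virtual hook.
    struct TargetInfoBase {
      virtual StackValue getStackPointer(const std::string &funcName,
                                         bool isKernel, int numArgs) const = 0;
      virtual ~TargetInfoBase() = default;
    };

    // NVIDIA/AMD flavor: kernels address a module-level @global_smem symbol,
    // device functions receive the base as the second-to-last argument.
    struct NvidiaLikeTargetInfo : TargetInfoBase {
      StackValue getStackPointer(const std::string &funcName, bool isKernel,
                                 int numArgs) const override {
        if (isKernel)
          return {"addressof(@global_smem)"};
        return {funcName + ".arg[" + std::to_string(numArgs - 2) + "]"};
      }
    };

    // Intel flavor: the workgroup base pointer is the last function argument
    // (or a null workgroup pointer when no shared memory is allocated).
    struct IntelLikeTargetInfo : TargetInfoBase {
      StackValue getStackPointer(const std::string &funcName, bool /*isKernel*/,
                                 int numArgs) const override {
        return {funcName + ".arg[" + std::to_string(numArgs - 1) + "]"};
      }
    };

    // Shared helper, analogous to LLVM::getSharedMemoryBase: it only talks to
    // the virtual hook, so a single implementation covers every backend.
    StackValue getSharedMemoryBase(const TargetInfoBase &target,
                                   const std::string &funcName, int numArgs,
                                   int allocationOffset) {
      StackValue base =
          target.getStackPointer(funcName, /*isKernel=*/false, numArgs);
      return {"gep(" + base.repr + ", +" + std::to_string(allocationOffset) + ")"};
    }

    int main() {
      NvidiaLikeTargetInfo cuda;
      IntelLikeTargetInfo xpu;
      std::cout << getSharedMemoryBase(cuda, "gather_fn", 6, 1024).repr << "\n";
      std::cout << getSharedMemoryBase(xpu, "gather_fn", 6, 1024).repr << "\n";
    }

Because the hook is virtual, the Intel backend keeps its "last argument or
null workgroup pointer" convention while NVIDIA and AMD keep the
"@global_smem or second-to-last argument" one, which is why the XPU skip in
test_gather can be dropped.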
--- .../TritonGPUToLLVM/TargetInfoBase.h | 3 ++ .../Conversion/TritonGPUToLLVM/Utility.h | 16 ++-------- .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 2 +- python/test/unit/language/test_core.py | 2 -- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 13 ++++++++ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 3 ++ .../ControlFlowOpToLLVM.cpp | 4 +-- .../ConvertLayoutOpToLLVM.cpp | 8 ++--- .../HistogramOpToLLVM.cpp | 4 +-- .../LoadStoreOpToLLVM.cpp | 8 ++--- .../TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp | 4 +-- .../TritonIntelGPUToLLVM/ReduceScanCommon.h | 4 +-- .../lib/TritonIntelGPUToLLVM/TargetInfo.cpp | 10 ++++++ .../lib/TritonIntelGPUToLLVM/TargetInfo.h | 3 ++ .../intel/lib/TritonIntelGPUToLLVM/Utility.h | 32 ------------------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 13 ++++++++ .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 3 ++ 17 files changed, 67 insertions(+), 65 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 87db94f25f..85a0e6fc77 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -91,6 +91,9 @@ class TargetInfoBase { virtual bool supportVectorizedAtomics() const = 0; + virtual Value getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const = 0; + virtual ~TargetInfoBase() {} }; } // namespace mlir::triton diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 2359444f14..999355b77b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -381,19 +381,6 @@ inline bool isKernel(FunctionOpInterface funcOp) { return funcOp.getVisibility() == SymbolTable::Visibility::Public; } -inline Value getStackPointer(RewriterBase &rewriter, - FunctionOpInterface funcOp) { - // See NOTE: [Additional Function Arguments] - if (!isKernel(funcOp)) { - return funcOp.getArgument(funcOp.getNumArguments() - 2); - } - - auto mod = funcOp->getParentOfType(); - auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); - assert(globalBase); - return rewriter.create(funcOp.getLoc(), globalBase); -} - inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, FunctionOpInterface funcOp, Value allocOffset = {}) { @@ -457,7 +444,8 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, .getValue() .getZExtValue(); Value offVal = i32_val(offset); - Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + Value base = + gep(ptrTy, i8_ty, target.getStackPointer(rewriter, func), offVal); return base; } diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index 06e19029eb..31833e994b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -83,7 +83,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter); if (!caller->hasAttr("allocation.offset")) { - auto base = LLVM::getStackPointer(rewriter, caller); + auto base = targetInfo.getStackPointer(rewriter, caller); promotedOperands.push_back(base); } else { auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 
ead97bb57d..b2b373870b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6304,8 +6304,6 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: ([128, 64], [128, 128], 1), ]) def test_gather(src_shape, indices_shape, axis, device): - if is_xpu(): - pytest.skip("Fail on XPU") def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 9a00987900..bdd6e5140e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -426,6 +426,19 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, int TargetInfo::getSharedAddressSpace() const { return 3; } +Value TargetInfo::getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const { + // See NOTE: [Additional Function Arguments] + if (!LLVM::isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 2); + } + + auto mod = funcOp->getParentOfType(); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); + assert(globalBase); + return rewriter.create(funcOp.getLoc(), globalBase); +} + bool TargetInfo::supportVectorizedAtomics() const { // Note: not currently tested or used, but AMD generally supports vectorized // atomics. diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 31fa09e519..4461984897 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -65,6 +65,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool supportVectorizedAtomics() const override; + Value getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const override; + private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, RewriterBase &rewriter, bool useStdErr) const; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp index 398ede09a7..6074a99bc6 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -86,11 +86,11 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter); if (!caller->hasAttr("allocation.offset")) { - auto base = LLVM::intel::getStackPointer(rewriter, caller); + auto base = targetInfo.getStackPointer(rewriter, caller); promotedOperands.push_back(base); return promotedOperands; } - promotedOperands.push_back(LLVM::intel::getSharedMemoryBase( + promotedOperands.push_back(LLVM::getSharedMemoryBase( callOp->getLoc(), rewriter, targetInfo, callOp)); return promotedOperands; } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 5d676eb5fb..6e4ece2d35 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -237,8 +237,8 @@ struct ConvertLayoutOpConversion Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - Value smemBase = 
LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo, - op.getOperation()); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); smemBase = bitcast(smemBase, elemPtrTy); auto shape = dstTy.getShape(); @@ -819,8 +819,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion Type elementType = inVals.front().getType(); auto mod = rewriter.getInsertionPoint()->getParentOfType(); - Value smemBase = LLVM::intel::getSharedMemoryBase( - loc, rewriter, targetInfo, &*rewriter.getInsertionPoint()); + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + &*rewriter.getInsertionPoint()); Type ptrType = smemBase.getType(); int numRows = inVals.size(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp index 4aa8d753bc..94f1bcb82e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp @@ -181,8 +181,8 @@ struct HistogramOpConversion // TODO: we could skip this for cases with num_warps=1 as long as we can // generate the right layout. Currently the warp level histogram generates // data in the default blocked layout. - Value baseSharedMemPtr = LLVM::intel::getSharedMemoryBase( - loc, rewriter, targetInfo, op.getOperation()); + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto dstType = op.getType(); Attribute dstEncoding = dstType.getEncoding(); auto indices = ::intel::emitIndices(op.getLoc(), rewriter, targetInfo, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 8bb7a39e81..518823320c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1500,8 +1500,8 @@ struct AtomicCASOpConversion rewriter.eraseOp(op); return success(); } - Value atomPtr = LLVM::intel::getSharedMemoryBase( - loc, rewriter, targetInfo, op.getOperation()); + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); targetInfo.storeShared(rewriter, loc, atomPtr, ret, mask); createBarrier(rewriter, loc, numCTAs); @@ -1681,8 +1681,8 @@ struct AtomicRMWOpConversion rewriter.eraseOp(op); return success(); } - Value atomPtr = LLVM::intel::getSharedMemoryBase( - loc, rewriter, targetInfo, op.getOperation()); + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with rmwMask = True store the result targetInfo.storeShared(rewriter, loc, atomPtr, ret, rmwMask); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp index 01798b0d9b..5b5256252e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp @@ -81,8 +81,8 @@ struct LocalAllocOpConversion if (!op.isSharedMemoryAlloc()) return failure(); Location loc = op->getLoc(); - Value smemBase = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo, - op.getOperation()); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto resultTy = cast(op.getType()); auto 
typeConverter = getTypeConverter(); auto sharedLayout = diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h index df2e4ed09b..2fbe98a267 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h @@ -148,8 +148,8 @@ class ConvertTritonIntelGPUReduceScanToLLVMPattern }); // Assign base index to each operand in their order in indices std::map indexToBase; - auto basePtr = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo, - op.getOperation()); + auto basePtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); indexToBase[indices[0]] = basePtr; for (unsigned i = 1; i < op.getNumOperands(); ++i) { indexToBase[indices[i]] = diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp index c3be0734ef..4ddddc0f3a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp @@ -304,4 +304,14 @@ bool TargetInfo::supportVectorizedAtomics() const { return true; } +Value TargetInfo::getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const { + auto mod = funcOp->getParentOfType(); + LLVM::LLVMPointerType ptrTy = ptr_ty( + rewriter.getContext(), TritonGEN::TritonGENMemorySpace::kWorkgroup); + if (mod->getAttrOfType("ttg.shared").getInt() == 0) + return rewriter.create(funcOp.getLoc(), ptrTy); + return funcOp.getArgument(funcOp.getNumArguments() - 1); +} + } // namespace mlir::triton::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index acc0c25da3..35f923b240 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -66,6 +66,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool supportVectorizedAtomics() const override; + Value getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const override; + private: }; } // namespace mlir::triton::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h index 0160421dc6..b2b1d0e9a4 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h @@ -83,38 +83,6 @@ Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter); -static Value getStackPointer(PatternRewriter &rewriter, - FunctionOpInterface funcOp) { - auto mod = funcOp->getParentOfType(); - LLVM::LLVMPointerType ptrTy = ptr_ty( - rewriter.getContext(), TritonGEN::TritonGENMemorySpace::kWorkgroup); - if (mod->getAttrOfType("ttg.shared").getInt() == 0) - return rewriter.create(funcOp.getLoc(), ptrTy); - return funcOp.getArgument(funcOp.getNumArguments() - 1); -} - -static Value getSharedMemoryBase(Location loc, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &target, Operation *op) { - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), - target.getSharedAddressSpace()); - FunctionOpInterface func = op->getParentOfType(); - // CI debugging usage here - if (!op->hasAttr("allocation.offset")) { - auto mod = op->getParentOfType(); - llvm::errs() << "op: " << *op << "\n"; - llvm::errs() << "mod:" << mod << "\n"; - 
llvm_unreachable("missing allocation.offset"); - } - size_t offset = cast(op->getAttr("allocation.offset")) - .getValue() - .getZExtValue(); - Value offVal = i32_val(offset); - Value base = - gep(ptrTy, i8_ty, LLVM::intel::getStackPointer(rewriter, func), offVal); - return base; -} - static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) { auto mod = rewriter.getBlock()->getParent()->getParentOfType(); return i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 7c4a9e5b92..f9386c43fd 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -612,6 +612,19 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, int TargetInfo::getSharedAddressSpace() const { return 3; } +Value TargetInfo::getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const { + // See NOTE: [Additional Function Arguments] + if (!LLVM::isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 2); + } + + auto mod = funcOp->getParentOfType(); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); + assert(globalBase); + return rewriter.create(funcOp.getLoc(), globalBase); +} + bool TargetInfo::supportVectorizedAtomics() const { return computeCapability >= 90 && ptxVersion >= 81; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index b891792ec8..5cdfaccd27 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -60,6 +60,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef file, StringRef func, int line) const override; int getSharedAddressSpace() const override; + Value getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) const override; + bool supportVectorizedAtomics() const override; int getPtxVersion() const { return ptxVersion; } From d662e65379fe923c201a61d80d3d65d01e7f8eef Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Thu, 19 Dec 2024 10:17:45 +0800 Subject: [PATCH 12/14] [CommonCodeClean]Clean changes in common code (#2950) Clean changes in common code --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 27 +----------- .../LoadStoreOpToLLVM.cpp | 42 +++++++++---------- 2 files changed, 21 insertions(+), 48 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 39694ccd83..862c6625da 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -304,16 +304,6 @@ SmallVector getOrder(Attribute layout) { } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); - // FIXME: delete if branch for `DpasEncodingAttr` and provide more - // general solution to make `getOrderForDotOperand` function compatible - // with Intel layouts. 
- // More details: - // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 - if (dyn_cast(dotLayout.getParent())) { - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - return order; - } return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); } if (auto sliceLayout = dyn_cast(layout)) { @@ -1093,10 +1083,6 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return amdWmmaParent.getTotalElemsPerThreadForOperand( shape, eltTy, getKWidth(), getOpIdx()); } - if (auto dpasParent = mlir::dyn_cast(mmaParent)) { - return dpasParent.getTotalElemsPerThreadForOperand( - shape, eltTy, getKWidth(), getOpIdx()); - } } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -1159,17 +1145,8 @@ SmallVector DotOperandEncodingAttr::getWarpOrder() const { return {}; } SmallVector DotOperandEncodingAttr::getThreadOrder() const { - // FIXME: delete if branch for `DpasEncodingAttr` and provide more - // general solution to make `getOrderForDotOperand` function compatible - // with Intel layouts. - // More details: - // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 - if (mlir::dyn_cast(getParent())) { - return ::getOrder(*this); - } else { - return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), - /*kMajor*/ true); - } + return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), + /*kMajor*/ true); } LogicalResult DotOperandEncodingAttr::verify( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 518823320c..73eae00cd6 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -526,6 +526,21 @@ struct LoadOpConversion }; auto opIdx = getOpIdx(); + std::optional llEncoding = + cast(encoding).toLinearLayout( + tensorType.getShape()); + assert(llEncoding.has_value() && "invalid dot layout to linear layout"); + LinearEncodingAttr llAttr = + LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); + SmallVector threadOrder = llAttr.getThreadOrder(); + size_t rank = threadOrder.size(); + const bool valueRowMajor = + (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); + assert((valueRowMajor || + (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && + "Only row_major or column_major is allowed"); + const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; + Type eltTy = tensorType.getElementType(); unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); @@ -539,7 +554,7 @@ struct LoadOpConversion SmallVector numReps = dpasLayout.getDPASRepetitions(tensorShape, opIdx); const SmallVector warpsPerCTA = dpasLayout.getWarpsPerCTA(); - SmallVector dpasOrder = triton::gpu::getOrder(dpasLayout); + SmallVector dpasWarpsOrder = triton::gpu::getOrder(dpasLayout); int threadsPerWarp = triton::gpu::getWarpSize(dpasLayout); Value warpId = rewriter.create( @@ -547,7 +562,7 @@ struct LoadOpConversion rewriter.create(loc, /*upperBound=*/nullptr)); SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder); + delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); if (hasDpasLayout) { // A block load with the DPAS layout but without the DotDpasLayout is @@ -557,14 +572,6 @@ struct LoadOpConversion // aligns to the DPAS layout as the DPAS operation output layout // distributes rows across work items. 
- size_t rank = dpasOrder.size(); - const bool valueRowMajor = - (dpasOrder[rank - 2] == 1 && dpasOrder[rank - 1] == 0); - assert((valueRowMajor || - (dpasOrder[rank - 2] == 0 && dpasOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - if (isTransposeRequired) { // TODO: this would likely require a shuffle to match the expected // ordering coming out of the DPAS layout and requires more @@ -675,17 +682,6 @@ struct LoadOpConversion return success(); } - DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value(); - auto dotOrder = dotLayout.getThreadOrder(); - - size_t rank = dotOrder.size(); - const bool valueRowMajor = - (dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0); - assert((valueRowMajor || - (dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA); SmallVector dpasInstShape = isOperandA ? dpasLayout.getDPASInstShapeA() @@ -749,8 +745,8 @@ struct LoadOpConversion offsetBaseY] = getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter); - unsigned tileWidth = elemsPerDPASInst[dotOrder[rank - 2]]; - unsigned tileHeight = elemsPerDPASInst[dotOrder[rank - 1]]; + unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]]; + unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]]; unsigned vBlocks = 1; unsigned numOperandsOuterDimPerLoad = 1; unsigned numOperandsInnerDimPerLoad = 1; From c280ea50d29ae24dc97cd351e3837d6df170dc37 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 19 Dec 2024 02:59:57 +0000 Subject: [PATCH 13/14] Fix Windows build failure Signed-off-by: Whitney Tsang --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index d77c0546d5..99704242ef 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -879,7 +879,8 @@ createAsyncOps(scf::ForOp &forOp, llvm::MapVector stageGroups; for (auto &[loadOp, info] : loadToInfo) { - AsyncLoad asyncLoad = {.loadOp = loadOp}; + AsyncLoad asyncLoad; + asyncLoad.loadOp = loadOp; bool isTMALoad = false; int numBuffers = info.distToUse; // For MMAv3, we need an extra buffer as this is assumed in the wgmma From 13725c19f3510c99f5d68bf5476801d3a4f1b0f6 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 19 Dec 2024 08:53:25 -0500 Subject: [PATCH 14/14] [TEST] Test `DpasLayout` in `test_local_load_store_mma` (#3047) Signed-off-by: Whitney Tsang --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 158cc0ea69..a12d0641e4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5595,6 +5595,8 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # multiple warps on the row MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # small instrN MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # large number of warps + DpasLayout(repeatCount=8, 
systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, + warps_per_cta=[4, 1], rep_cluster=[1, 1]), ] shared_layouts = [