diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 74ea99b588..cfc00926dd 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout); unsigned getNumCTAs(Attribute layout); +// Return the order that represents that the batch is in row-major or +// column-major order for a batch of matrices of shape [*, m, n] with +// len(shape) == rank. +SmallVector getMatrixOrder(unsigned rank, bool rowMajor); + +// Return the order that represents that the dot operand is in kMajor +// (contiguous in the inner dimension) or it's contiguous on the outer +// dimension. +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kMajor); + bool isExpensiveCat(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index c32f12bc70..52604771aa 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -238,6 +238,19 @@ static SmallVector eraseOrder(ArrayRef order, return resOrder; } +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + assert(rank >= 2); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kMajor) { // kMajor: if true, the matrix is fastest-running on k, @@ -247,15 +260,8 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, // batch (if rank == 3) is always the slowest running dimension assert(rank == 2 || rank == 3); assert(opIdx == 0 || opIdx == 1); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - // If opIdx is 1 and kMajor is true, the order is [0, 1] - // (resp. [1, 2, 0] if rank == 3) - // Same if opIdx is 0 and kMajor is false - if (bool(opIdx) == kMajor) { - std::swap(order[0], order[1]); - } - return order; + auto rowMajor = bool(opIdx) != kMajor; + return getMatrixOrder(rank, rowMajor); } SmallVector getWarpOrder(Attribute layout) { @@ -265,20 +271,21 @@ SmallVector getWarpOrder(Attribute layout) { } } auto order = getOrder(layout); - // FIXME: This mmaLayout if should just return - // getOrderForDotOperand(0, order.size(), kMajor=false) - // as mma has the same order as DotOperand(opIdx=0) + // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's + // M-major This is awkward. Since we can choose any warpOrder in Ampere, we + // should probably choose M-major and change `LinearLayoutConversion.cpp` and + // `MMAv2.cpp` to match. if (auto mmaLayout = dyn_cast(layout)) { if (mmaLayout.isHopper()) { - // Hopper MMA instructions force a warp order of [0, 1]. 
See docs: + // Hopper MMA instructions force warps to be column-major // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 - auto it = std::find(order.begin(), order.end(), 0); - order.erase(it); - order.insert(order.begin(), 0); + return getMatrixOrder(order.size(), /*rowMajor*/ false); } } else if (auto dotOpLayout = dyn_cast(layout)) { - order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(), - /*kMajor*/ false); + // It's quite weird to talk about warp order when that the warps + // are broadcasted along the K dimension + llvm::report_fatal_error( + "DotOperandEncoding::getWarpOrder not implemented"); } return order; } @@ -288,11 +295,11 @@ SmallVector getOrder(Attribute layout) { return llvm::to_vector(blockedLayout.getOrder()); } if (auto mmaLayout = dyn_cast(layout)) { + // Order doesn't really matter. We just have to be consistent when unpacking + // the elements in the MMAv2/V3 lowerings. We choose row-major auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - return order; + return getMatrixOrder(rank, /*rowMajor*/ true); } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); @@ -434,7 +441,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto wmmaLayout = dyn_cast(layout)) warpsPerCTA = wmmaLayout.getWarpsPerCTA(); else if (auto dotLayout = dyn_cast(layout)) - return getNumWarpsPerCTA(dotLayout.getParent()); + warpsPerCTA = dotLayout.getWarpsPerCTA(); else if (auto sharedLayout = dyn_cast(layout)) llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); else @@ -2176,25 +2183,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( ArrayRef shape, int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); + auto shapePerCTATile = getShapePerCTATile(shape); + auto rank = shapePerCTATile.size(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; // 4 threads * 2 subtiles - unsigned kWidthTile = kWidth * 2 * 4; - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], kWidthTile}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], - kWidthTile}; - } else if (opIdx == 1) { - if (rank == 2) - return {kWidthTile, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], kWidthTile, - parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } + shapePerCTATile[kDim] = kWidth * 2 * 4; + return shapePerCTATile; } SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index d1b8b03428..9901f246ba 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -42,6 +42,17 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { return ret; } +// TODO Have order be a mandatory argument of standardOutDimNames. 
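To make the new ordering conventions concrete, here is a small Python sketch that mirrors the semantics of `getMatrixOrder` and `getOrderForDotOperand` above (and of the `permuteDimNames` helper added below). The Python names are illustrative only, not part of the change:

```python
def get_matrix_order(rank: int, row_major: bool) -> list[int]:
    # Fastest-varying dimension first: [rank-1, rank-2, ..., 0] is row-major
    # for a batch of matrices of shape [*, m, n].
    order = list(range(rank - 1, -1, -1))
    if not row_major:
        order[0], order[1] = order[1], order[0]
    return order


def get_order_for_dot_operand(op_idx: int, rank: int, k_major: bool) -> list[int]:
    # A (op_idx == 0) has shape [*, m, k]; B (op_idx == 1) has shape [*, k, n].
    # kMajor means K is the contiguous (fastest-running) dimension.
    row_major = bool(op_idx) != k_major
    return get_matrix_order(rank, row_major)


assert get_matrix_order(3, row_major=True) == [2, 1, 0]
assert get_matrix_order(3, row_major=False) == [1, 2, 0]
assert get_order_for_dot_operand(0, 2, k_major=True) == [1, 0]  # A, K contiguous
assert get_order_for_dot_operand(1, 2, k_major=True) == [0, 1]  # B, K contiguous
```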
+SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); + SmallVector ret; + for (unsigned i : order) { + ret.push_back(names[i]); + } + return ret; +} + void assertIsRegisterLayout(const LinearLayout &layout) { assert(layout.getNumInDims() > 0); MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); @@ -282,15 +293,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector dimNames = standardOutDimNames(ctx, rank); + auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma)); + // By using `reverse(dimNames)` below, we set the order to be row-major + assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - ctaLayout *= identityND( - S("warp"), mma.getWarpsPerCTA(), - llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); + ArrayRef(orderedDimNames).take_front(2)); + assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); + // FIXME(Lezcano). identityND should not have an `order` param as it's + // redundant with the order of the out dims. + ctaLayout *= + identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } @@ -323,10 +338,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), S("register"), S("dim1")); - // Expand the `warp` dimension according to warpsPerCTA. - // - // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but - // this really does seem to be correct. + // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major. + // Since the warpOrder needs to be M-major, we need to transpose the out + // dimensions AND transpose the order + // FIXME(Lezcano). identityND should not have an `order` param as it's + // redundant. The order is already given by the order of the + // out dims, and if it has an order, it shouldn't change the + // order of the out dims. + assert(getWarpOrder(mma) == SmallVector({0, 1})); ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, {S("dim0"), S("dim1")}) .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); @@ -844,18 +863,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { LinearLayout ampereDotToLinearLayout(ArrayRef shape, DotOperandEncodingAttr dot) { - // TODO,BE. Implement ampereMMA in terms of this one + // Note that, even though MMAv2 looks similar to this layout, they are just + // the same at a register and lane level. The warps treatment is different! int rank = shape.size(); auto mma = cast(dot.getParent()); int kWidth = dot.getKWidth(); bool isA = dot.getOpIdx() == 0; - assert(mma.isAmpere()); assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + assert(mma.isAmpere()); MLIRContext *ctx = mma.getContext(); - SmallVector dimNames = standardOutDimNames(ctx, rank); + // A and B have kMajor order + assert(getOrder(dot) == + getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); + + auto kMajorDims = + permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot)); // Implement A. 
For B transpose in the end std::vector> registers; @@ -882,24 +907,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, } registers.push_back({i, 0}); - if (!isA) { - for (auto &r : registers) { - std::swap(r[0], r[1]); + LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}}, + ArrayRef(kMajorDims).take_front(2)); + + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {0, 1} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In particular, for A and B we need to broadcast along K + + assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); + auto warpsPerCTAMma = mma.getWarpsPerCTA(); + std::vector> warps; + if (isA) { + for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + warps.push_back({0, 0}); + } + for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + warps.push_back({0, i}); + } + } else { + for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + warps.push_back({0, i}); } - for (auto &l : lanes) { - std::swap(l[0], l[1]); + for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + warps.push_back({0, 0}); + } + } + if (rank == 3) { + for (auto &w : warps) { + w.push_back(0); } } - LinearLayout ctaLayout( - {{S("register"), registers}, {S("lane"), lanes}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - auto order = dot.getCTAOrder(); - assert(order[0] == rank - 1 && order[1] == rank - 2); - ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); + ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims); - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); } std::optional @@ -908,7 +960,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { if (auto mfmaLayout = llvm::dyn_cast(parent)) { return mfmaDotToLinearLayout(*this, shape); } else if (auto mma = mlir::dyn_cast(parent)) { - if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) { + if (mma.isAmpere()) { return ampereDotToLinearLayout(shape, *this); } } else if (auto dpasLayout = diff --git a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp index 16183b1af4..203fe01ba6 100644 --- a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -1,4 +1,5 @@ #include "mlir/IR/Dominance.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -14,8 +15,52 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -// Return true if the select could be merged into the If without breaking SSA -// rules. +/// The user of select maybe inside either the ThenRegion or ElseRegion of +/// the scf.if. So, canonicalize user of select in scf.if first. 
+static void canonicalizeSelectUsersInSCFIf(ModuleOp input) { + llvm::MapVector, SmallVector> + usersNeedreplaced; + input.walk([&](arith::SelectOp selectOp) { + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + Value trueVal = selectOp.getOperand(1); + Value falseVal = selectOp.getOperand(2); + Value resVal = selectOp.getResult(); + for (auto *condUser : condition.getUsers()) { + if (!llvm::isa(condUser)) + continue; + scf::IfOp ifOp = llvm::cast(condUser); + for (auto *resUser : resVal.getUsers()) { + if (ifOp->isProperAncestor(resUser)) { + if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) != + nullptr) { + // The user is inside the ThenRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back( + resUser); + } else { + // The user is inside the ElseRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back( + resUser); + } + } + } + } + }); + + // Replace the operand of user. + for (auto [replacedSrcAndDst, users] : + llvm::make_early_inc_range(usersNeedreplaced)) { + Value srcVal = replacedSrcAndDst.first; + Value dstVal = replacedSrcAndDst.second; + for (Operation *user : llvm::make_early_inc_range(users)) { + srcVal.replaceUsesWithIf( + dstVal, [&](OpOperand &use) { return use.getOwner() == user; }); + } + } +} + +/// Return true if the select could be merged into the If without breaking SSA +/// rules. static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, DominanceInfo &dom) { // If needs to be dominated by the select. @@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - DominanceInfo dom(m); + canonicalizeSelectUsersInSCFIf(m); // Go over the arith.select ops, look if there is an if // with the same condition. + DominanceInfo dom(m); llvm::MapVector> selectToIf; m.walk([&](arith::SelectOp selectOp) { // Look if there is an if in the same block, with the same condition. 
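Looking back at `ampereDotToLinearLayout` in LinearLayoutConversions.cpp above: the warp bases it builds are what encode the broadcast of A and B along K. A small Python sketch of that construction, with the function name and tuple encoding purely illustrative and the bases expressed over the kMajor out dims:

```python
def dot_operand_warp_bases(warps_per_cta_mma, is_a):
    # Bases are over the kMajor out dims (K first, then M for A / N for B).
    # No warp ever steps along K, so the K coordinate is always 0.
    # Warps that tile the MMA result along the dimension this operand does not
    # own (N for A, M for B) get the zero basis: they all read the same data.
    warps_m, warps_n = warps_per_cta_mma
    bases = []
    if is_a:
        i = 1
        while i < warps_n:       # broadcast across the N warps
            bases.append((0, 0))
            i *= 2
        i = 1
        while i < warps_m:       # step along M
            bases.append((0, i))
            i *= 2
    else:
        i = 1
        while i < warps_n:       # step along N
            bases.append((0, i))
            i *= 2
        i = 1
        while i < warps_m:       # broadcast across the M warps
            bases.append((0, 0))
            i *= 2
    return bases


# With warpsPerCTA = [2, 2] on the MMA result (the example in the comment above):
assert dot_operand_warp_bases([2, 2], is_a=True) == [(0, 0), (0, 1)]   # opA
assert dot_operand_warp_bases([2, 2], is_a=False) == [(0, 1), (0, 0)]  # opB
```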
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 459ca40cf9..605a896a8b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5,7 +5,7 @@ from typing import Optional import math import textwrap -import tempfile +import pathlib import numpy as np import pytest @@ -1776,47 +1776,34 @@ def kernel(X, Y, Z, N: tl.constexpr): @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant(dtype_str, num_ctas, device): +def test_store_constant(num_ctas, dtype_str, constant_field, device): check_type_supported(dtype_str, device) - """Tests that boolean True is stored as 1""" @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - output = GENERATE_TEST_HERE + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False tl.store(output_ptr + offsets, output, mask=mask) - triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) - - assert torch.all(output == ref) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant_default_dtype(num_ctas, device): - """Tests that boolean True is stored as 1""" - - @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - value = 1 - output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) - tl.store(output_ptr + offsets, output, mask=mask) - block_size = 128 - ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) - output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) - assert torch.all(output == ref) + if constant_field == "value": + print(output, ref) + assert torch.all(output == ref) + else: + assert torch.all(output == 0) def test_load_store_same_ptr(device): @@ -2606,7 +2593,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. 
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_scan_layouts(M, N, src_layout, axis, device): +def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path): ir = f""" #blocked = {src_layout} @@ -2639,10 +2626,10 @@ def test_scan_layouts(M, N, src_layout, axis, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + rs = RandomState(17) x = rs.randint(-100, 100, (M, N)).astype('int32') @@ -2679,7 +2666,7 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) @pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path: pathlib.Path): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") @@ -2773,10 +2760,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) @@ -2806,7 +2792,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce @pytest.mark.parametrize("M", [32, 64, 128, 256]) @pytest.mark.parametrize("src_layout", layouts) -def test_store_op(M, src_layout, device): +def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): ir = f""" #src = {src_layout} @@ -2827,10 +2813,9 @@ def test_store_op(M, src_layout, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - store_kernel = triton.compile(f.name) + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, 1)).astype('float32') @@ -2857,7 +2842,7 @@ def test_store_op(M, src_layout, device): @pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) -def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): ir = f""" #dst = {dst_layout} @@ -2877,10 +2862,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = 
triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, )).astype('int32') @@ -2918,7 +2902,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) -def test_chain_reduce(M, N, src_layout, op, device, first_axis): +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): op_str = "" if op == "sum": @@ -2959,10 +2943,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') @@ -5297,7 +5280,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): if str(src_layout) == str(dst_layout): pytest.xfail("Do not convert same layout") if is_hip() or is_xpu(): @@ -5366,10 +5349,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5422,7 +5405,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device): +def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): if is_hip() or is_xpu(): pytest.xfail("test_mma2mma is not supported in HIP/XPU") @@ -5479,10 +5462,10 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convertmma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a45cb3f888..a5f381dc9f 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,7 +1,7 @@ import importlib.util import itertools import shutil -import tempfile +import pathlib import pytest import torch @@ -129,17 +129,15 @@ def test_combine_fn_change(): seen_keys.add(key) -def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: - f.write(('# extra line\n' * 
num_extra_lines) + code) - f.flush() - spec = importlib.util.spec_from_file_location("module.name", f.name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) +def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) return module -def test_changed_line_numbers_invalidate_cache(): +def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path): from textwrap import dedent code = dedent(""" import triton @@ -147,10 +145,12 @@ def test_changed_line_numbers_invalidate_cache(): def test_kernel(i): i = i + 1 """) - orig_mod = write_and_load_module(code, 0) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) orig_cache_key = orig_mod.test_kernel.cache_key - updated_mod = write_and_load_module(code, 1) + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) updated_cache_key = updated_mod.test_kernel.cache_key assert orig_cache_key != updated_cache_key diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6e12280737..fc31959d27 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1253,7 +1253,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): val = cast(val, elt_ty, builder) # Build IR - if not mask: + if mask is None: return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") @@ -1308,7 +1308,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, if val is not None: val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) val = cast(val, ptr.type.scalar.element_ty, builder) - if not mask: + if mask is None: mask_ir = builder.get_int1(True) mask_ty = tl.int1 if ptr.type.is_block(): diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 1464d489bc..49a8bb32c4 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -20,13 +20,13 @@ """ import argparse -import time import torch import triton import triton.language as tl import triton.tools.experimental_descriptor import triton.profiler as proton +from contextlib import contextmanager if torch.cuda.is_available(): from triton._C.libtriton import nvidia @@ -48,6 +48,8 @@ def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + if "tiles_per_update" in args: + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -541,7 +543,24 @@ def torch_matmul(a, b): return c -def bench(K, dtype, tiles_per_update, reps=10): +@contextmanager +def proton_context(): + proton.activate(0) + try: + yield + finally: + proton.deactivate(0) + + +def bench_fn(reps, warmup_reps, fn, *args): + for _ in range(warmup_reps): + fn(*args) + with proton_context(): + for _ in range(reps): + fn(*args) + + +def bench(K, dtype, tiles_per_update, reps=1000, 
warmup_reps=10000): M = 8192 N = 8192 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) @@ -549,33 +568,15 @@ def bench(K, dtype, tiles_per_update, reps=10): b = b.T.contiguous() - proton.activate(0) - if cublas is not None: - for _ in range(reps): - cublas_matmul(a, b) - time.sleep(0.01) + bench_fn(reps, warmup_reps, cublas_matmul, a, b) if dtype == torch.float16: - for _ in range(reps): - torch_matmul(a, b) - time.sleep(0.01) - for _ in range(reps): - matmul(a, b.T) - time.sleep(0.01) - for _ in range(reps): - matmul_persistent(a, b.T) - time.sleep(0.01) + bench_fn(reps, warmup_reps, torch_matmul, a, b) + bench_fn(reps, warmup_reps, matmul, a, b.T) + bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): - for _ in range(reps): - matmul_tma_persistent(a, b) - time.sleep(0.01) - with proton.scope( - f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"): - for _ in range(reps): - matmul_device_tma_persistent(a, b, tiles_per_update) - time.sleep(0.01) - - proton.deactivate(0) + bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) + bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update) def validate(M, N, K, dtype, tiles_per_update): diff --git a/test/TritonGPU/combine-select-if.mlir b/test/TritonGPU/combine-select-if.mlir index 62a9474dcb..f00b971235 100644 --- a/test/TritonGPU/combine-select-if.mlir +++ b/test/TritonGPU/combine-select-if.mlir @@ -1,46 +1,77 @@ // RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @select_if_combine - tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr, #blocked>, %cnd: i1) attributes {noinline = false} { - // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> - %cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked> - // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> - %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked> - // CHECK-NOT: arith.select - %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked> - // CHECK: %[[IF_RES:.*]] = scf.if - scf.if %cnd { - tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr, #blocked> - // CHECK: scf.yield %[[CST0]] - } - // CHECK: else - // CHECK: scf.yield %[[CST1]] - // CHECK: tt.store %{{.*}}, %[[IF_RES]] - tt.store %dst_ptr, %sel : tensor<64x!tt.ptr, #blocked> - tt.return +tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr>, %cnd: i1) { + // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> + // CHECK-NOT: arith.select + %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> + // CHECK: %[[R:.+]] = scf.if %{{.*}} + // CHECK: tt.store %{{.*}}, %{{.*}} + // CHECK: scf.yield %[[CST0]] + // CHECK: } else { + // CHECK: scf.yield %[[CST1]] + // CHECK: } + scf.if %cnd { + tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr> } + // CHECK: tt.store %{{.*}}, %[[R]] + tt.store %dst_ptr, %sel : tensor<64x!tt.ptr> + tt.return } // ----- - // CHECK-LABEL: @if_multiple_sel tt.func 
@if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){ -// CHECK-NOT: select -// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } else { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } -// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 + // CHECK-NOT: arith.select %0 = arith.select %arg0, %arg1, %arg2 : i32 %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } else { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } %2 = scf.if %arg0 -> (i32) { %3 = arith.subi %arg1, %arg2 : i32 scf.yield %3 : i32 } else { scf.yield %arg1 : i32 } + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 tt.return %0, %1, %2 : i32, f32, i32 } + +// ----- +// CHECK-LABEL: tt.func @users_in_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32 +tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) { + // CHECK: %[[CST:.*]] = arith.constant 8 : i32 + %c8_i32 = arith.constant 8 : i32 + // CHECK-NOT: arith.select + %0 = arith.select %arg0, %arg1, %arg2 : i32 + %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) { + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32 + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32 + // CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32 + // CHECK: } else { + // CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32 + // CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32 + // CHECK: } + %2:2 = scf.if %arg0 -> (i32, i32) { + %3 = arith.muli %0, %arg2 : i32 + %4 = arith.addi %0, %c8_i32 : i32 + scf.yield %3, %4 : i32, i32 + } else { + %3 = arith.subi %0, %c8_i32 : i32 + scf.yield %arg1, %3 : i32, i32 + } + // CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32 + tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32 +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index a7395f86dc..6dbb0435e2 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" + // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 7da8083cfb..c3a69a5f9a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_triton_library(TritonAMDGPUTransforms MfmaGroup.cpp DEPENDS + TritonAMDGPUIR TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index b03fb0989d..508f03227c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -121,19 +121,15 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( } if (dot.getOpIdx() == 1) { - // there are kWidth * 2 elems packed as bf16x2 int elemsInTile = dot.getKWidth(); - // n0 and n1 are unrolled in the legacy path - // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no - // sense IMO + // n0 is unrolled in the legacy path, which makes no sense n0 *= 2; - n1 *= 2; for (auto b = 0; b < batch; ++b) - for (auto j = 0; j < n1 / elemsInTile; ++j) - for (auto i = 0; i < n0; ++i) - for (auto k = 0; k < elemsInTile; ++k) { - vals[{b, i, elemsInTile * j + k}] = elems[offset++]; - } + for (auto i = 0; i < n0; ++i) + for (auto j = 0; j < n1; ++j) { + vals[{b, i, 2 * j}] = elems[offset++]; + vals[{b, i, 2 * j + 1}] = elems[offset++]; + } return vals; } } diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 713572c4fc..dd26ecbbfc 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -1,23 +1,24 @@ import json import triton.profiler as proton -import tempfile import pathlib -def test_profile(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0]) - proton.activate() - proton.deactivate() - proton.finalize() - assert session_id0 == 0 +def test_profile(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert session_id0 == 0 + assert temp_file0.exists() - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id1 = proton.start(f.name.split(".")[0]) - proton.activate(session_id1) - proton.deactivate(session_id1) - proton.finalize(session_id1) - assert session_id1 == session_id0 + 1 + temp_file1 = tmp_path / "test_profile1.hatchet" + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + proton.activate(session_id1) + proton.deactivate(session_id1) + proton.finalize(session_id1) + assert session_id1 == session_id0 + 1 + assert temp_file1.exists() session_id2 = proton.start("test") proton.activate(session_id2) @@ -28,19 +29,16 @@ def test_profile(): pathlib.Path("test.hatchet").unlink() -def test_profile_decorator(): - f = tempfile.NamedTemporaryFile(delete=True) - name = f.name.split(".")[0] +def test_profile_decorator(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_profile_decorator.hatchet" - @proton.profile(name=name) + @proton.profile(name=str(temp_file.with_suffix(""))) def foo0(a, b): return a + b foo0(1, 2) proton.finalize() - assert pathlib.Path(f.name).exists() - - f.close() + assert temp_file.exists() @proton.profile def foo1(a, b): @@ -48,126 +46,130 @@ def foo1(a, b): foo1(1, 2) proton.finalize() - assert pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet").exists() + default_file = pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet") + assert default_file.exists() + default_file.unlink() -def test_scope(): +def test_scope(tmp_path: pathlib.Path): # Scope can be annotated even when profiling is off with proton.scope("test"): pass - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test"): - pass + temp_file = tmp_path / "test_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test"): + pass - @proton.scope("test") - def foo(): - pass + 
@proton.scope("test") + def foo(): + pass - foo() + foo() - proton.enter_scope("test") - proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.enter_scope("test") + proton.exit_scope() + proton.finalize() + assert temp_file.exists() -def test_hook(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0], hook="triton") - proton.activate(session_id0) - proton.deactivate(session_id0) - proton.finalize(None) - assert pathlib.Path(f.name).exists() +def test_hook(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_hook.hatchet" + session_id0 = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.activate(session_id0) + proton.deactivate(session_id0) + proton.finalize(None) + assert temp_file.exists() -def test_scope_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - # Test different scope creation methods - with proton.scope("test0", {"a": 1.0}): - pass +def test_scope_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + with proton.scope("test0", {"a": 1.0}): + pass - @proton.scope("test1", {"a": 1.0}) - def foo(): - pass + @proton.scope("test1", {"a": 1.0}) + def foo(): + pass - foo() + foo() - # After deactivation, the metrics should be ignored - proton.deactivate(session_id) - proton.enter_scope("test2", metrics={"a": 1.0}) - proton.exit_scope() + # After deactivation, the metrics should be ignored + proton.deactivate(session_id) + proton.enter_scope("test2", metrics={"a": 1.0}) + proton.exit_scope() - # Metrics should be recorded again after reactivation - proton.activate(session_id) - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + # Metrics should be recorded again after reactivation + proton.activate(session_id) + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 3 - for child in data[0]["children"]: - if child["frame"]["name"] == "test3": - assert child["metrics"]["a"] == 2.0 - - -def test_scope_properties(): - with open("test.hatchet", "w+") as f: - proton.start(f.name.split(".")[0]) - # Test different scope creation methods - # Different from metrics, properties could be str - with proton.scope("test0", properties={"a": "1"}): - pass + assert len(data[0]["children"]) == 3 + for child in data[0]["children"]: + if child["frame"]["name"] == "test3": + assert child["metrics"]["a"] == 2.0 + + +def test_scope_properties(tmp_path: pathlib.Path): + temp_file = tmp_path / "test.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + # Different from metrics, properties could be str + with proton.scope("test0", properties={"a": "1"}): + pass - @proton.scope("test1", properties={"a": "1"}) - def foo(): - pass + @proton.scope("test1", properties={"a": "1"}) + def foo(): + pass - foo() + foo() - # Properties do not aggregate - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + # Properties do not aggregate + 
proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: data = json.load(f) - for child in data[0]["children"]: - if child["frame"]["name"] == "test2": - assert child["metrics"]["a"] == 1.0 - elif child["frame"]["name"] == "test0": - assert child["metrics"]["a"] == "1" + for child in data[0]["children"]: + if child["frame"]["name"] == "test2": + assert child["metrics"]["a"] == 1.0 + elif child["frame"]["name"] == "test0": + assert child["metrics"]["a"] == "1" -def test_throw(): +def test_throw(tmp_path: pathlib.Path): # Catch an exception thrown by c++ session_id = 100 - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - activate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.activate(session_id + 1) - except Exception as e: - activate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in activate_error - - deactivate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id + 1) - except Exception as e: - deactivate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error + temp_file = tmp_path / "test_throw.hatchet" + activate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index fa3331c024..e24ff9224d 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -1,7 +1,7 @@ import pytest import subprocess -import tempfile import json +import pathlib def test_help(): @@ -11,21 +11,22 @@ def test_help(): @pytest.mark.parametrize("mode", ["script", "python", "pytest"]) -def test_exec(mode): +def test_exec(mode, tmp_path: pathlib.Path): file_path = __file__ helper_file = file_path.replace("test_cmd.py", "helper.py") - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - name = f.name.split(".")[0] - if mode == "script": - ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) - elif mode == "python": - ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], - stdout=subprocess.DEVNULL) - elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], - stdout=subprocess.DEVNULL) - assert ret == 0 + temp_file = tmp_path / "test_exec.hatchet" + name = str(temp_file.with_suffix("")) + if mode == "script": + ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + elif mode == "python": + ret = 
subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) + elif mode == "pytest": + ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) + assert ret == 0 + with temp_file.open() as f: data = json.load(f, ) - kernels = data[0]["children"] - assert len(kernels) == 2 - assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" + kernels = data[0]["children"] + assert len(kernels) == 2 + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" diff --git a/third_party/proton/test/test_lib.py b/third_party/proton/test/test_lib.py index 0380268c04..f149d7965d 100644 --- a/third_party/proton/test/test_lib.py +++ b/third_party/proton/test/test_lib.py @@ -1,6 +1,6 @@ -import triton._C.libproton.proton as libproton -import tempfile import pathlib + +import triton._C.libproton.proton as libproton from triton.profiler.profile import _select_backend @@ -25,22 +25,22 @@ def test_op(): libproton.exit_op(id0, "zero") -def test_session(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - libproton.deactivate(session_id) - libproton.activate(session_id) - libproton.finalize(session_id, "hatchet") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() - - -def test_add_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - id1 = libproton.record_scope() - libproton.enter_scope(id1, "one") - libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) - libproton.exit_scope(id1, "one") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() +def test_session(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_session.hatchet" + session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + libproton.deactivate(session_id) + libproton.activate(session_id) + libproton.finalize(session_id, "hatchet") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def test_add_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_add_metrics.hatchet" + libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) + libproton.exit_scope(id1, "one") + libproton.finalize_all("hatchet") + assert temp_file.exists() diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 13cb9bd99c..01bcaf3be0 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -1,10 +1,10 @@ import torch import triton import triton.profiler as proton -import tempfile import json import pytest from typing import NamedTuple +import pathlib import triton.language as tl @@ -14,30 +14,31 @@ def is_hip(): @pytest.mark.parametrize("context", ["shadow", "python"]) -def test_torch(context): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], context=context) - proton.enter_scope("test") - torch.ones((2, 2), device="cuda") - proton.exit_scope() - proton.finalize() +def test_torch(context, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_torch.hatchet" + 
proton.start(str(temp_file.with_suffix("")), context=context) + proton.enter_scope("test") + torch.ones((2, 2), device="cuda") + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: data = json.load(f) - if context == "shadow": - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test" - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 - elif context == "python": - assert len(data[0]["children"]) == 1 - # The last frame is the torch kernel - prev_frame = data - curr_frame = data[0]["children"] - while len(curr_frame) > 0: - prev_frame = curr_frame - curr_frame = curr_frame[0]["children"] - assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] - - -def test_triton(): + if context == "shadow": + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + elif context == "python": + assert len(data[0]["children"]) == 1 + # The last frame is the torch kernel + prev_frame = data + curr_frame = data[0]["children"] + while len(curr_frame) > 0: + prev_frame = curr_frame + curr_frame = curr_frame[0]["children"] + assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] + + +def test_triton(tmp_path: pathlib.Path): @triton.jit def foo(x, y): @@ -45,23 +46,24 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test0"): - with proton.scope("test1"): - foo[(1, )](x, y) - with proton.scope("test2"): + temp_file = tmp_path / "test_triton.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0"): + with proton.scope("test1"): foo[(1, )](x, y) - proton.finalize() + with proton.scope("test2"): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 2 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert len(data[0]["children"][0]["children"]) == 1 - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" - assert data[0]["children"][1]["frame"]["name"] == "test2" + assert len(data[0]["children"]) == 2 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert len(data[0]["children"][0]["children"]) == 1 + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" + assert data[0]["children"][1]["frame"]["name"] == "test2" -def test_cudagraph(): +def test_cudagraph(tmp_path: pathlib.Path): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -75,46 +77,47 @@ def fn(): c = a + b foo[(1, )](a, b, c) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], context="shadow") + temp_file = tmp_path / "test_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow") - # warmup - # four kernels - fn() + # warmup + # four kernels + fn() - # no kernels - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(10): - fn() + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() - proton.enter_scope("test") - g.replay() - g.reset() - torch.cuda.synchronize() - proton.exit_scope() - proton.finalize() + proton.enter_scope("test") + g.replay() + g.reset() + torch.cuda.synchronize() + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: data = json.load(f) 
- # CUDA/HIP graph may also invoke additional kernels to reset outputs - # {torch.ones, add, foo, test} - assert len(data[0]["children"]) >= 4 - # find the test frame - test_frame = None - for child in data[0]["children"]: - if child["frame"]["name"] == "test": - test_frame = child - break - assert test_frame is not None - # {torch.ones, add, foo} - if is_hip(): - assert len(test_frame["children"]) >= 2 - else: - assert len(test_frame["children"]) >= 3 - assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 - - -def test_metrics(): + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test_frame["children"]) >= 2 + else: + assert len(test_frame["children"]) >= 3 + assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_metrics(tmp_path: pathlib.Path): @triton.jit def foo(x, y): @@ -122,18 +125,19 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.finalize() + temp_file = tmp_path / "test_metrics.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["metrics"]["foo"] == 1.0 -def test_metrics_ignore(): +def test_metrics_ignore(tmp_path: pathlib.Path): @triton.jit def foo(x, y): @@ -141,36 +145,38 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.activate(session_id) - proton.finalize() + temp_file = tmp_path / "test_metrics_ignore.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.deactivate(session_id) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.activate(session_id) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 0 - - -def test_scope_backward(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("ones1"): - a = torch.ones((100, 100), device="cuda", requires_grad=True) - with proton.scope("plus"): - a2 = a * a * a - with proton.scope("ones2"): - loss = torch.ones_like(a2) - - # Backward triggers two kernels in a single scope - with proton.scope("backward"): - a2.backward(loss) - proton.finalize() + assert len(data[0]["children"]) == 0 + + +def test_scope_backward(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_backward.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("ones1"): + a = torch.ones((100, 100), device="cuda", requires_grad=True) + with 
proton.scope("plus"): + a2 = a * a * a + with proton.scope("ones2"): + loss = torch.ones_like(a2) + + # Backward triggers two kernels in a single scope + with proton.scope("backward"): + a2.backward(loss) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 4 + assert len(data[0]["children"]) == 4 -def test_hook(): +def test_hook(tmp_path: pathlib.Path): def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): # get arg's element size @@ -187,20 +193,21 @@ def foo(x, size: tl.constexpr, y): x = torch.tensor([2], device="cuda", dtype=torch.float32) y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton") - with proton.scope("test0"): - foo[(1, )](x, 1, y, num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" - assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" + assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 -def test_pcsampling(): +def test_pcsampling(tmp_path: pathlib.Path): if is_hip(): pytest.skip("HIP backend does not support pc sampling") @@ -214,37 +221,39 @@ def foo(x, y, size: tl.constexpr): for _ in range(1000): tl.store(y + offs, tl.load(x + offs)) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton", backend="cupti_pcsampling") - with proton.scope("init"): - x = torch.ones((1024, ), device="cuda", dtype=torch.float32) - y = torch.zeros_like(x) - with proton.scope("test"): - foo[(1, )](x, y, x.size()[0], num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_pcsampling.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti_pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - init_frame = data[0]["children"][0] - test_frame = data[0]["children"][1] - # With line mapping - assert "foo" in test_frame["children"][0]["frame"]["name"] - assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 - assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] - # Without line mapping - assert "elementwise" in init_frame["children"][0]["frame"]["name"] - assert init_frame["children"][0]["metrics"]["num_samples"] > 0 - - -def test_deactivate(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0], hook="triton") - proton.deactivate(session_id) - torch.randn((10, 10), device="cuda") - proton.activate(session_id) - torch.zeros((10, 
10), device="cuda") - proton.deactivate(session_id) - proton.finalize() + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_deactivate.hatchet" + session_id = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - # Root shouldn't have device id - assert "device_id" not in data[0]["metrics"] - assert len(data[0]["children"]) == 1 - assert "device_id" in data[0]["children"][0]["metrics"] + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"] diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index d4c15bbad0..d662537ed7 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -555,14 +555,14 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {2, 0}, {4, 0}, {32, 0}, + {64, 0}, {0, 8}, {0, 16}, - {0, 32}, - {64, 0}}}, + {0, 32}}}, {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, { S("warp"), - {}, + {{0, 0}, {0, 0}}, }, {S("block"), {}}, }, @@ -582,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, { S("warp"), - {}, + {{0, 0}, {0, 0}}, }, {S("block"), {}}, }, {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { + EXPECT_EQ( + toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false);
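
Note (not part of the patch): the refactored proton tests above all share one pattern — build a `*.hatchet` path under pytest's `tmp_path`, pass the suffix-less stem to `proton.start`, run kernels inside `proton.scope`, call `proton.finalize`, then read the JSON back from the `.hatchet` file. The following is a minimal illustrative sketch of that pattern only; the kernel `foo`, the scope name `"demo"`, the test name, and the `import triton.profiler as proton` alias are assumptions chosen for demonstration, not taken from this diff.

    # Sketch of the tmp_path-based proton test pattern used in the tests above.
    import json
    import pathlib

    import torch
    import triton
    import triton.language as tl
    import triton.profiler as proton  # assumption: proton is importable under this alias


    @triton.jit
    def foo(x, y):
        # Trivial kernel used only so the profiler has something to record.
        tl.store(y, tl.load(x))


    def test_profile_roundtrip(tmp_path: pathlib.Path):
        temp_file = tmp_path / "test_profile_roundtrip.hatchet"
        # The profiler appends the .hatchet suffix itself, so start() receives the stem.
        proton.start(str(temp_file.with_suffix("")))
        with proton.scope("demo"):
            x = torch.tensor([2], device="cuda")
            y = torch.zeros_like(x)
            foo[(1, )](x, y)
        proton.finalize()
        # The finalized profile is plain JSON and can be inspected directly.
        with temp_file.open() as f:
            data = json.load(f)
        assert data[0]["children"][0]["frame"]["name"] == "demo"

Compared with the previous `tempfile.NamedTemporaryFile` approach, `tmp_path` lets pytest own the file's lifetime, so the profile can be reopened after `proton.finalize()` without racing the temporary file's deletion.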