diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 3ddab364d7..4b9082acca 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }
 
-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  // repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice)
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }
 
 SmallVector<unsigned, 2>
diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py
index 0277792330..334d5d635f 100644
--- a/python/test/unit/runtime/test_subproc.py
+++ b/python/test/unit/runtime/test_subproc.py
@@ -7,6 +7,7 @@
 from triton.compiler import ASTSource
 
 target = triton.runtime.driver.active.get_current_target()
+start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
 
 
 def compile_fn(attrs):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):
 
 def test_compile_in_subproc() -> None:
     config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
-    multiprocessing.set_start_method('fork')
-    proc = multiprocessing.Process(target=compile_fn, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):
 
 def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
     config = AttrsDescriptor.from_hints({0: 16})
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
 
     # stage 2.p
     shutil.rmtree(fresh_triton_cache)
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))
 
     # stage 3.c
     proc.start()
diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir
index 420a9d5c2c..accbae971d 100644
--- a/test/TritonGPU/accelerate-matmul.mlir
+++ b/test/TritonGPU/accelerate-matmul.mlir
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----
 
 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
index 6094a91118..8f1fcc1f70 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   int kWidth = encoding.getKWidth();
 
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());
 
   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));
 
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
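For reference, the warp-allocation heuristic this patch introduces in warpsPerTileV2 can be summarised outside the pass. The sketch below is illustrative Python, not part of the patch; the names balanced_warps_per_tile and cdiv are made up for the example, and num_warps is assumed to be a power of two, as it is for Triton kernels.

def cdiv(a, b):
    # ceiling division, standing in for ceil() from triton/Dialect/Triton/IR/Utility.h
    return (a + b - 1) // b


def balanced_warps_per_tile(shape, num_warps):
    # The mma v2 instruction tile is 16x8 (M, N), so repM = M/16 and repN = N/8.
    reps = [cdiv(shape[0], 16), cdiv(shape[1], 8)]  # [repM, repN]
    warps = [1, 1]
    while warps[0] * warps[1] < num_warps:
        if reps[0] >= reps[1]:
            # Registers scale roughly as repM * 4 * repK + repN * 2 * repK
            # (plus the accumulator, which is warp-shape independent), so the
            # loop keeps repM and repN balanced and breaks ties towards M.
            warps[0] *= 2
            if reps[0] != 1:  # once repM hits 1, just park the extra warps on M
                reps[0] //= 2
        else:
            warps[1] *= 2
            reps[1] //= 2
    return warps


print(balanced_warps_per_tile((128, 128), 8))  # [2, 4]
print(balanced_warps_per_tile((128, 64), 8))   # [4, 2]

Compared with the removed do/while loop, which kept doubling along M until the tile was exhausted, this tends to split warps across both dimensions, which is why the warpsPerCTA expectations in accelerate-matmul.mlir change from shapes like [8, 4] to [4, 8].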
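The test_subproc.py changes all follow one pattern: instead of multiprocessing.set_start_method('fork'), which may only be called once per interpreter and fails outright where fork is unavailable, each test builds its own context with multiprocessing.get_context and falls back to 'spawn'. A minimal standalone sketch of that pattern follows; the worker function is hypothetical, standing in for the compile_fn helpers in the real tests.

import multiprocessing


def worker(n):
    # stand-in for compile_fn / compile_fn_dot in test_subproc.py
    assert n == 21


if __name__ == '__main__':
    # Prefer 'fork' where the platform offers it, otherwise fall back to
    # 'spawn' (e.g. on Windows). get_context returns an independent context,
    # so each test can pick a start method without mutating global state the
    # way set_start_method does (it raises if called a second time).
    start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
    ctx = multiprocessing.get_context(start_method)
    proc = ctx.Process(target=worker, args=(21, ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0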