From 04d655e3179bafbf2886ff6ec4902181fa241a11 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Mon, 4 Nov 2024 09:43:03 +0000
Subject: [PATCH 1/3] [BACKEND] Simplify and comment warp allocation logic in mmav2 (#5041)

It's not entirely clear to me whether the previous logic was equivalent to
this one, as it was rather obtuse. I think the new one is optimal, but I'm
happy to run benchmarks to make sure we don't regress.
---
 .../TritonGPU/Transforms/AccelerateMatmul.cpp | 46 +++++++++++--------
 test/TritonGPU/accelerate-matmul.mlir         |  6 +--
 2 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 3ddab364d7..4b9082acca 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }
 
-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  // repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice)
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }
 
 SmallVector<unsigned>
diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir
index 420a9d5c2c..accbae971d 100644
--- a/test/TritonGPU/accelerate-matmul.mlir
+++ b/test/TritonGPU/accelerate-matmul.mlir
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----
 
 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size
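The heuristic introduced above is small enough to model outside the pass. Below is a minimal Python sketch of the new `warpsPerTileV2` loop; it is an illustration, not the code the pass runs, with plain integer math standing in for Triton's `ceil` and `product` helpers:

```python
def warps_per_tile_v2(shape, num_warps):
    """Sketch of the mmav2 warp-allocation loop above: balance repM and repN
    to keep register pressure low, breaking ties towards M."""
    shape_per_warp = (16, 8)  # mmav2 instruction tile (M, N)
    warps = [1, 1]
    # repM, repN: how many instruction tiles a single warp would have to cover.
    reps = [(shape[0] + shape_per_warp[0] - 1) // shape_per_warp[0],
            (shape[1] + shape_per_warp[1] - 1) // shape_per_warp[1]]
    while warps[0] * warps[1] < num_warps:
        if reps[0] >= reps[1]:
            warps[0] *= 2
            # Once repM == repN == 1 there is nothing left to split, so the
            # remaining warps are allocated to M (arbitrary choice).
            if reps[0] != 1:
                reps[0] //= 2
        else:
            warps[1] *= 2
            reps[1] //= 2
    return warps


print(warps_per_tile_v2((128, 128), 8))  # [2, 4]
print(warps_per_tile_v2((16, 8), 4))     # [4, 1] - leftover warps pile up on M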
From e82dfd999a238f4d1b370f8f9fcad44502ffb08e Mon Sep 17 00:00:00 2001
From: Gary Geng
Date: Mon, 4 Nov 2024 02:10:44 -0800
Subject: [PATCH 2/3] [BACKEND] Minor Bugfixes for SharedToDotOperand MMAv3 (#5030)

Two bugfixes following https://github.com/triton-lang/triton/pull/5009.

- When `BLOCK_M=64` and `num_warps > 4`, the order of warps for DotOpEncoded
  tensor should be M-major instead of N-major, since WGMMA expects the 4
  warps in each warp group to be stacked along the M dimension.
- Should use `mmaBitwidth` instead of `bitwidth` when calculating `numRep` in
  `SharedToDotOperandMMAv2OrV3`. This was missed in a bad rebase.

@lezcano I encountered these bugs when attempting to locally test the
[DotOp hoisting PR](https://github.com/triton-lang/triton/pull/5003) after
rebasing (they normally would be caught by `test_core.py` but that path was
not yet enabled in the last PR). With these fixes added, I was able to
successfully validate against pytorch.
---
 .../ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
index 6094a91118..8f1fcc1f70 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   int kWidth = encoding.getKWidth();
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());
 
   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
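The warp-order issue described in the commit message is easiest to see with a toy delinearization. The sketch below is illustrative only: it assumes Triton's minor-to-major `order` convention (order[0] varies fastest) and a hypothetical 4x2 `warpsPerCTA` grid; it is not the backend's `delinearize` helper.

```python
def delinearize(linear_id, shape, order):
    """Split a linear warp id into per-dimension coordinates, where order[0]
    is the dimension along which consecutive ids advance first."""
    coords = [0] * len(shape)
    for dim in order:
        coords[dim] = linear_id % shape[dim]
        linear_id //= shape[dim]
    return coords


warps_per_cta = [4, 2]  # hypothetical [M, N] warp grid, num_warps = 8

# M-major (the fix): warps 0-3 of the first warp group stack along M,
# which is how WGMMA expects a warp group to be laid out.
print([delinearize(w, warps_per_cta, [0, 1]) for w in range(4)])
# -> [[0, 0], [1, 0], [2, 0], [3, 0]]

# N-major (the old order): the same four warps form a 2x2 block instead of
# a single column along M.
print([delinearize(w, warps_per_cta, [1, 0]) for w in range(4)])
# -> [[0, 0], [0, 1], [1, 0], [1, 1]]
```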
From 9f939760d2455bb0644698a5b6f3a13aa485abde Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Mon, 4 Nov 2024 19:26:02 +0100
Subject: [PATCH 3/3] Don't use `fork` method if it's not available on the platform (#5051)

Signed-off-by: Anatoly Myachev
---
 python/test/unit/runtime/test_subproc.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py
index 0277792330..334d5d635f 100644
--- a/python/test/unit/runtime/test_subproc.py
+++ b/python/test/unit/runtime/test_subproc.py
@@ -7,6 +7,7 @@
 from triton.compiler import ASTSource
 
 target = triton.runtime.driver.active.get_current_target()
+start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
 
 
 def compile_fn(attrs):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):
 
 def test_compile_in_subproc() -> None:
     config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
-    multiprocessing.set_start_method('fork')
-    proc = multiprocessing.Process(target=compile_fn, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):
 
 def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
     config = AttrsDescriptor.from_hints({0: 16})
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
 
     # stage 2.p
     shutil.rmtree(fresh_triton_cache)
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))
 
     # stage 3.c
     proc.start()
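The pattern the patch adopts is portable beyond this test file: `multiprocessing.get_all_start_methods()` reports what the platform supports ('fork' is not available on Windows), and `get_context()` returns a private context rather than mutating the process-wide default with `set_start_method()`, which may only be called once. A standalone sketch of the same idea, using a hypothetical `work` function:

```python
import multiprocessing


def work(n):
    print(n * n)


if __name__ == '__main__':
    # Prefer 'fork' where the platform offers it, otherwise fall back to
    # 'spawn' (the only start method available on Windows).
    method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
    # get_context() gives an object with the same API as the multiprocessing
    # module, scoped to this start method, leaving the global default alone.
    ctx = multiprocessing.get_context(method)
    proc = ctx.Process(target=work, args=(3, ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0
```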