From d4b1681801e7ba044ce7b802accdf10607f620ce Mon Sep 17 00:00:00 2001 From: Jiaxing Ding <61589029+Paran0idy@users.noreply.github.com> Date: Mon, 27 May 2024 22:58:27 +0800 Subject: [PATCH 1/7] [DOCS] improve Triton Linear layout doc (#4005) - Correct some index errors. --- include/triton/Tools/LinearLayout.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 99c54bcdee..fb26802415 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -187,9 +187,9 @@ namespace mlir::triton { // // where // -// - a is a vector [a0...aM], and ai is a scalar in some field 𝔽 (for +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for // example, ai might be a real number), and -// - each Bj is a vector [b0j, b1j, ..., bNj] of N scalars in 𝔽. +// - each Bj is a vector [b1j, b2j, ..., bNj] of N scalars in 𝔽. // // We can also write this as a matrix-vector product Ba, where // @@ -201,8 +201,8 @@ namespace mlir::triton { // B = | B1, B2, ..., BM| // | ↓ ↓ ↓ | // -// |b11, b12, ..., b0M| -// |b21, b22, ..., b1M| +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| // = | ↓ ↓ ↓ | // |bN1, bN2, ..., bNM|. // From 513f38c47a901981dbf265f5da74b41803ef0d64 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 28 May 2024 06:08:57 -0700 Subject: [PATCH 2/7] [FRONTEND] Fix wrong livein set in loop codegen (#4018) When processing loops we were incorrectly setting all the local defs as liveins. Since part of the code assumes that only livein variables can be loop-carried dependencies, this caused a mismatch in the logic. Change it to enforce that a local def that is not a livein cannot be loop carried. --- python/test/unit/language/test_core.py | 34 ++++++++++++++++++++++++ python/triton/compiler/code_generator.py | 3 +++ 2 files changed, 37 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4d990a3d0c..f6604c3de3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5357,3 +5357,37 @@ def kernel(X): except AssertionError: print("Failing ptx:\n", k.asm["ptx"]) raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # re-use the temp variable and make sure to check that it isn't creating incorrect IR.
+ temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ef5570eea3..6903052ca2 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -991,6 +991,9 @@ def visit_For(self, node): self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} for i, name in enumerate(names): self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) self.visit_compound_statement(node.body) From b847042d0a2956ae02e9689f8d038c9db6988573 Mon Sep 17 00:00:00 2001 From: Finlay Date: Tue, 28 May 2024 17:00:59 +0100 Subject: [PATCH 3/7] Remove redundant options from passes (#4015) The TritonGPUPipeline pass has unused pass options and the TritonGPUAccelerateMatmul pass option could instead be read from the module attributes, where the data already exists. The goal is to reduce redundancy. --------- Signed-off-by: Finlay Marno --- .../Dialect/TritonGPU/Transforms/Passes.td | 17 +------------- .../Dialect/TritonGPU/Transforms/Utility.h | 3 +++ .../TritonGPU/Transforms/AccelerateMatmul.cpp | 2 ++ lib/Dialect/TritonGPU/Transforms/Utility.cpp | 21 ++++++++++++++++++ python/src/passes.cc | 6 ++--- test/TritonGPU/accelerate-matmul.mlir | 22 +++++++------------ test/TritonGPU/loop-pipeline-hopper.mlir | 2 +- .../pipeline-hopper-remove-wait.mlir | 2 +- third_party/nvidia/backend/compiler.py | 4 ++-- 9 files changed, 41 insertions(+), 38 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index d98f918fde..fdceb2cfe4 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -19,16 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let options = [ Option<"numStages", "num-stages", "int32_t", /*default*/"3", - "number of pipeline stages">, - Option<"numWarps", "num-warps", - "int32_t", /*default*/"4", - "number of warps per block">, - Option<"numCTAs", "num-ctas", - "int32_t", /*default*/"1", - "number of CTAs per CGA">, - Option<"computeCapability", "compute-capability", - "int32_t", /*default*/"80", - "device compute capability"> + "number of pipeline stages"> ]; } @@ -68,12 +59,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::triton::TritonDialect"]; - - let options = [ - Option<"computeCapability", "compute-capability", - "int32_t", /*default*/"80", - "device compute capability"> - ]; } def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { diff --git 
a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index aa51bc5c4b..114c181425 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -169,6 +169,9 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, // operand and single result. bool isPureUnaryInlineAsm(Operation *op); +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 821c8ab9c5..df84c4e628 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -387,6 +387,8 @@ class TritonGPUAccelerateMatmulPass MLIRContext *context = &getContext(); ModuleOp m = getOperation(); + auto computeCapability = getNVIDIAComputeCapability(m); + mlir::RewritePatternSet patterns(context); patterns.add(context, computeCapability); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index cc1818ca7b..1d6152417e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -8,6 +8,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -818,6 +819,26 @@ bool isPureUnaryInlineAsm(Operation *op) { inlineAsmOp.getPure(); } +int getNVIDIAComputeCapability(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/python/src/passes.cc b/python/src/passes.cc index ae48846104..513e811d28 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -49,11 +49,9 @@ void init_triton_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); ADD_PASS_WRAPPER_0("add_optimize_thread_locality", createTritonGPUOptimizeThreadLocality); - ADD_PASS_OPTION_WRAPPER_4("add_pipeline", createTritonGPUPipeline, int, int, - int, int); + ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int); ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); - ADD_PASS_OPTION_WRAPPER_1("add_accelerate_matmul", - createTritonGPUAccelerateMatmul, int); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_reorder_instructions", createTritonGPUReorderInstructions); 
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 6536fb3f9d..8c4e85aa08 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,6 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=89 | FileCheck %s --check-prefix=CHECK-89 -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FileCheck %s --check-prefix=CHECK-80 +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> // CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> @@ -49,24 +47,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-80: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-80-LABEL: chained_dot + // CHECK-LABEL: chained_dot tt.func public @chained_dot( %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> - // CHECK-80: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - // CHECK-80: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> %r = tt.dot %c, %arg2, %cst_1 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> @@ -75,18 +73,18 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-89: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, 
versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> +// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-89-LABEL: fp8_dot + // CHECK-LABEL: fp8_dot tt.func public @fp8_dot( %arg0: tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> - // CHECK-89: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> tt.return %d : tensor<64x64xf32, #blocked> @@ -97,8 +95,6 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : // CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> -// CHECK-80-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -// CHECK-80-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -112,7 +108,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %1 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> %2 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> - // CHECK-80: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> %4 = triton_gpu.convert_layout %3 : tensor<16x16xf32, 
#blocked1> -> tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> @@ -122,7 +117,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %9 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> %10 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> - // CHECK-80: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> tt.print ": " {hex = false} : %12 : tensor<2x16x16xf32, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index b3dc9d8834..5c9cac004c 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck --dump-input-context=50 %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index a3e002321f..1e3d4d9670 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline -canonicalize | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index a9f389ad61..10c5f3e83b 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -168,13 +168,13 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) - passes.ttgpuir.add_accelerate_matmul(pm, capability) + passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: passes.ttgpuir.add_combine_tensor_select_and_if(pm) - passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) + passes.ttgpuir.add_pipeline(pm, opt.num_stages) if capability // 10 <= 8: passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) From d3fb1dc1fa6caa52d8ee578ea72b98fd0dc02693 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 28 May 2024 18:38:46 +0200 Subject: [PATCH 4/7] [AMD] Move MFMA 
shortcut check to not compute scratch buffer shape if it is not needed (#3790) This PR: - moves the shortcut check earlier, so that the scratch buffer shape is not computed when it is not needed - raises the priority of AMD-specific conversions over common ones to eliminate uncertainty about which pattern is applied - adds a regression test for the MFMA to Dot Op shortcut --- lib/Analysis/Allocation.cpp | 9 ++-- lib/Analysis/Utility.cpp | 6 ++- test/Conversion/amd/mfma-shortcut.mlir | 27 ++++++++++ .../ConvertLayoutOpToLLVM.cpp | 2 - .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 50 +++++++++++-------- 5 files changed, 65 insertions(+), 29 deletions(-) create mode 100644 test/Conversion/amd/mfma-shortcut.mlir diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 1e6e38749f..a129cb1947 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -68,6 +68,9 @@ SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { return convertType(getShapePerCTA(srcTy)); } + if (isMfmaToDotShortcut(srcTy, dstTy)) + return {}; + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { if (mlir::isa(dstLayout)) { @@ -111,11 +114,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - if (mlir::isa(srcLayout) && - mlir::dyn_cast(srcLayout).getIsTransposed() && - mlir::isa(dstLayout)) - if (isMfmaToDotShortcut(srcTy, dstTy)) - return {}; + assert(!isMfmaToDotShortcut(srcTy, dstTy)); auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); unsigned srcContigPerThread = diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 1851d9b6fe..689e83b5ac 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -581,8 +581,10 @@ bool supportMMA(Value value, int version) { bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); - auto mfmaLayout = cast(srcLayout); - auto dotOperandLayout = cast(dstLayout); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is // improved. In addition, we can enable this shortcut for regular MFMA // layout when opIdx == 1.
diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir new file mode 100644 index 0000000000..83c9e535d8 --- /dev/null +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: shortcut_mfma16 + tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: no_shortcut_mfma16 + tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK: store + // CHECK: load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index edebbbe12c..953b01dab0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -168,7 +168,5 @@ void populateConvertLayoutOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, - patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index c61dd5b815..6625b8a120 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -145,48 +145,58 @@ struct ConvertTritonAMDGPUToLLVM OpBuilder::InsertPoint indexInsertPoint; RewritePatternSet patterns(context); - int benefit = patternBenefitPrioritizeOverLLVMConversions; - auto populatePatterns1 = [&](auto populateFunc) { + int commonBenefit = patternBenefitPrioritizeOverLLVMConversions; + // Make benefit for AMD specific patterns higher so they apply before common + // patterns + int AMDBenefit = commonBenefit + 1; + auto populatePatterns1 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, allocation, benefit); }; - auto populatePatterns5 = [&](auto populateFunc) { + auto populatePatterns5 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, benefit); }; - auto populatePatterns6 = [&](auto populateFunc) { + auto populatePatterns6 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, allocation, targetInfo, benefit); }; - auto 
populatePatterns7 = [&](auto populateFunc) { + auto populatePatterns7 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, targetInfo, benefit); }; AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, patterns, numWarps, - axisInfoAnalysis, benefit); + axisInfoAnalysis, AMDBenefit); + mlir::triton::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, benefit); - populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns); + axisInfoAnalysis, AMDBenefit); + populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns, AMDBenefit); AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, - numWarps, axisInfoAnalysis, benefit); - populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns); - populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns); - populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns); - populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns); + numWarps, axisInfoAnalysis, + AMDBenefit); + populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns, + commonBenefit); + populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, + commonBenefit); mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, benefit); + patterns, commonBenefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, - patterns, benefit); + patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + targetInfo, commonBenefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, - benefit); + commonBenefit); mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); + targetInfo, commonBenefit); + AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. @@ -200,7 +210,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + targetInfo, commonBenefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } From 47f7d45c7b0a882da57ae46a88fc4836e2160a78 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 28 May 2024 18:05:15 +0100 Subject: [PATCH 5/7] [AMD] Replace wave with warp where possible. (#3978) This is a follow up PR of #3832 `wave` has been replaced with `warp` for the consistency between GPUs. Unfortunately there are still remaining use of `wave` in the code as below, although I've tried to minimize it. ## Referencing AMD features (HIP API or AMDGPU) third_party/amd/backend/include/hip/: third_party/amd/backend/include/roctracer/: third_party/amd/backend/include/has/*: - Cannot completely replace waves because the definition comes from outside e.g., __AMDGCN_WAVEFRONT_SIZE, hsa_wavefront_info_t - Mixing up `warp` and `wave` together in the same place could be even worse. 
## Using amdgpu compiler option third_party/amd/backend/compiler.py: python/tutorials/03-matrix-multiplication.py: python/tutorials/06-fused-attention.py: - `waves_per_eu` which is supposed to mapped to a CLANG attribute `amdgpu-waves-per-eu` - It is AMD only option and makes better sense to keep --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 18 ++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 6 +- .../SharedToDotOperandMFMA.cpp | 74 +++++++++---------- .../TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp | 12 +-- 4 files changed, 55 insertions(+), 55 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 05f3378dc2..1099e9bf69 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -804,7 +804,7 @@ It is characterized by the following parameters: - 1.0: gfx908, i.e. MI100 - 2.0: gfx90a: i.e. MI200, MI210, MI250 - 3.0: gfx940, gfx941, gfx942: MI300 -- `warpsPerCTA` indicates the wave layout in the workgroup. +- `warpsPerCTA` indicates the warp layout in the block. - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout without going to LDS. This is used in the case of chained dot (E.g. Flash-Attention kernel). @@ -813,7 +813,7 @@ Example 1: Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. The data will be distributed between threads as follows: - wave 0 wave 1 + warp 0 warp 1 -----------------/\-------------- -----------------/\-------------- [ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] [ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] @@ -852,7 +852,7 @@ Example 2: Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. The data will be distributed between threads as follows: - wave 0 wave 1 + warp 0 warp 1 -----------------/\------------- ------------------/\--------------- [ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] [ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] @@ -877,13 +877,13 @@ The data will be distributed between threads as follows(note that each element i Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. 
The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): -M N -> wave 0 wave 2 +M N -> warp 0 warp 2 | --------------------------/\-------------------------- ------------------------------/\------------------------------ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] - wave 1 wave 3 + warp 1 warp 3 --------------------------/\-------------------------- ------------------------------/\------------------------------ [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] @@ -933,15 +933,15 @@ def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encod let description = [{ An encoding for tensors that have been produced by WMMA instructions, available on RDNA 3. -A `warpsPerCTA` parameter characterizes data distribution between waves. +A `warpsPerCTA` parameter characterizes data distribution between warps. An important limitation of WMMA for layout is a shape for tiles proccessed -by a single wave. It is [16, 16]. +by a single warp. It is [16, 16]. This encoding assumes specific access to matrix elements by threads. Example: Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. - wave 0 [16, 16] wave 1 [16, 16] wave 2 [16, 16] + warp 0 [16, 16] warp 1 [16, 16] warp 2 [16, 16] -----------/\---------- -----------/\---------- -----------/\---------- [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] @@ -951,7 +951,7 @@ Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] - wave 3 [16, 16] wave 4 [16, 16] wave 5 [16, 16] + warp 3 [16, 16] warp 4 [16, 16] warp 5 [16, 16] -----------/\---------- -----------/\---------- -----------/\---------- [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 74ae61c06d..86c9f8241c 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1565,10 +1565,10 @@ AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { unsigned nDim = getNDim(); assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); - constexpr int waveSize = 64; // MFMA is used on wave64 architectures only + constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps. 
int kGroups = -1; if (mDim == nDim) - kGroups = waveSize / mDim; + kGroups = warpSize / mDim; if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) kGroups = 1; int64_t kDim = kWidth * kGroups; @@ -1605,7 +1605,7 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - constexpr int waveSize = 64; + constexpr int warpSize = 64; auto rep = getMFMARepForOperands(shape, kWidth, opIdx); return rep[0] * rep[1] * rep[2] * kWidth; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 5e1067884d..59682558d9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -36,29 +36,29 @@ namespace SharedToDotOperandMFMA { * @brief This function maps particular load of mfma dot operand to element * indexes(row, col) * - * Whole tensor is broken into "blocks" of waves along "non-K" axis. - * One block could be processed by multiple waves. - * One wave works on a piece of tensor size elemsPerInstr[0] x K. + * Whole tensor is broken into "blocks" of warps along "non-K" axis. + * One block could be processed by multiple warps. + * One warp works on a piece of tensor size elemsPerInstr[0] x K. * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x * elemsPerInstr[1]. * * Total offset of element is a sum of following values: - * 1. Offset of wave-block in tensor - * 2. Offset of wave inside one wave-block - * 3. Offset of tile in one wave + * 1. Offset of warp-block in tensor + * 2. Offset of warp inside one warp-block + * 3. Offset of tile in one warp * 4. Offset of one lane data in a tile * 5. Offset of particular element of tensor processed by one lane * * This function computes these offsets for axies independently * Note that this function returns the offsets of elements in the first - * wave-block. The offsets of elements in later wave-blocks can be computed + * warp-block. The offsets of elements in later warp-blocks can be computed * by adding a constant stride to the xor-ed offsets of elements in the - * first wave-block. + * first warp-block. 
* * @param rewriter * @param loc * @param elemsPerInstr operand tile shape consumed by one MFMA instruction - * @param waveId id component of 2d wave grid along non-K axis + * @param warpId id component of 2d warp grid along non-K axis * @param laneId lane id in warp [0..63] * @param numOfElems number of elements accessed by thread per repetition * @param reps number of instructions repetition to fully cover dot operand @@ -71,7 +71,7 @@ namespace SharedToDotOperandMFMA { */ llvm::SmallVector> computeTensorElemMappingInBlock( ConversionPatternRewriter &rewriter, Location loc, - const ArrayRef &elemsPerInstr, Value waveId, Value laneId, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, unsigned iKDim) { auto numM = reps[1]; @@ -82,7 +82,7 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value _0 = i32_val(0); Value _32 = i32_val(32); Value nonKDim = i32_val(iNonKDim); - Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0])); + Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); auto rank = smemOffsets.size(); @@ -95,12 +95,12 @@ llvm::SmallVector> computeTensorElemMappingInBlock( if (iNonKDim == 32) laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); else { - // In this configuration wave contains 16 copies of same data + // In this configuration warp contains 16 copies of same data if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { laneHOffset = i32_val(0); } else { assert(iKDim * iNonKDim / numOfElems == 64 && - "seems no all threads in wave contain unique elements"); + "seems no all threads in warp contain unique elements"); laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); } } @@ -110,7 +110,7 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value elemHOffset = i32_val(loadId * loadVecSize); Value sliceVOffset = - add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset); + add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); Value row = add(sliceVOffset, smemOffsets[rank - 2]); @@ -131,7 +131,7 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { // @param loc // @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA // instruction -// @param waveId wave id for the "non K" axis +// @param warpId warp id for the "non K" axis // @param laneId lane id in warp [0..63] // @param warpsPerBlock number of warps per horizontal axis // @param numOfElems number of elements accessed by threads per repetition @@ -139,7 +139,7 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { // @param cSwizzleOffset llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, - const ArrayRef &elemsPerInstr, Value waveId, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int warpsPerBlock, int numOfElems, ArrayRef reps, Value cSwizzleOffset) { auto numK = reps[1]; @@ -150,7 +150,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, auto iNonKDim = elemsPerInstr[1]; int lineSize = warpsPerBlock * iNonKDim * numN; Value _nonKDim = i32_val(iNonKDim); - Value waveOffset = mul(waveId, i32_val(iNonKDim)); + Value warpOffset = mul(warpId, i32_val(iNonKDim)); Value colOffset = urem(laneId, _nonKDim); for (int block = 0; block < numN; ++block) { @@ -158,15 +158,15 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, for (int tile = 0; tile < 
numK; ++tile) { Value tileOffset = i32_val(tile * iKDim * lineSize); for (int elem = 0; elem < numOfElems; ++elem) { - // halfOffset is an offset related to wrapping of wave in the tile. + // halfOffset is an offset related to wrapping of warp in the tile. // for example, mfma 32 case (mapping of tensor elements to lane ids in - // wave): + // warp): // // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 - // 32 33 34 35 ... 63 <- at this point wave is wrapping + // 32 33 34 35 ... 63 <- at this point warp is wrapping // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 @@ -179,7 +179,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, Value rowOffset = add(i32_val(elem * lineSize), halfOffset); Value elemOffset = add(rowOffset, colOffset); Value offset = - add(add(add(waveOffset, blockOffset), tileOffset), elemOffset); + add(add(add(warpOffset, blockOffset), tileOffset), elemOffset); offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } @@ -240,30 +240,30 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, numRepK = numReps[kDimIdx + 1]; } - unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout); - assert(iWaveSize == 64); - Value waveSize = i32_val(iWaveSize); - Value linearWaveId = udiv(thread, waveSize); - Value lane = urem(thread, waveSize); + unsigned iWarpSize = triton::gpu::getWarpSize(mfmaLayout); + assert(iWarpSize == 64); + Value warpSize = i32_val(iWarpSize); + Value linearWarpId = udiv(thread, warpSize); + Value lane = urem(thread, warpSize); - Value spatialWaveId = AMD::getWarpIdInBlock( - rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK, + Value spatialWarpId = AMD::getWarpIdInBlock( + rewriter, loc, linearWarpId, warpsPerCTA, mfmaInstrNonK, shape[nonKDimIdx], nonKDimIdx, triton::gpu::getOrder(mfmaLayout)); - // number of duplicates of elements in wave + // number of duplicates of elements in warp // In case of 64x4 x 4x4 multiplication, 4x4 B operand is duplicated 16 times int numSubBlocks = 1; if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4) numSubBlocks = 16; // numOfElemsPerThreadPerMfmaInstr - int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize; + int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWarpSize; assert(numOfElems >= 1); unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch = rank == 3 ? 
std::min(shape[0], warpsPerCTA[0]) : 1; - Value waveIdInBatch = urem(linearWaveId, i32_val(warpsPerBatch)); + Value warpIdInBatch = urem(linearWarpId, i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; @@ -282,7 +282,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, SmallVector elemsPerInstr{mfmaInstrK, mfmaInstrNonK}; SmallVector reps{numReps[0], numReps[2], numReps[1]}; offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, reps, cSwizzleOffset); } else { llvm_unreachable( @@ -294,7 +294,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, "col major operand B should be handled in the normal path"); } else { offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, cSwizzleOffset); } } @@ -309,13 +309,13 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, if (opIdx == 0) { offsets = AMD::computeOffsetsAType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK); } else { assert(opIdx == 1); offsets = AMD::computeOffsetsBType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK); } smemBase = AMD::computeBasePtr(rewriter, loc, smemObj); @@ -331,10 +331,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), - add(waveIdInBatch, i32_val(b * warpsPerBatch))); + add(warpIdInBatch, i32_val(b * warpsPerBatch))); for (int nonK = 0; nonK < numRepNonK; ++nonK) { int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerBlockNonK; - Value waveBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]); + Value warpBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]); for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index f8288024d6..46f91902cb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -98,10 +98,10 @@ struct DotOpMFMAConversionHelper { "numSubBlocks in not pow 2!"); if (numSubBlocks == 1) return acc; - constexpr int waveSize = 64; - int subBlockSize = waveSize / numSubBlocks; + constexpr int warpSize = 64; + int subBlockSize = warpSize / numSubBlocks; Value laneId = getThreadId(); - laneId = and_(laneId, i32_val(waveSize - 1)); + laneId = and_(laneId, i32_val(warpSize - 1)); auto vecTy = dyn_cast(acc.getType()); auto elemType = vecTy.getElementType(); assert(elemType.getIntOrFloatBitWidth() == 32); @@ -111,7 +111,7 @@ struct DotOpMFMAConversionHelper { accScalar[i] = extract_element(elemType, acc, i32_val(i)); if (reduceSubBlocks) { - while (subBlockSize < waveSize) { + while (subBlockSize < warpSize) { for (int i = 0; i < numScalars; ++i) { Value 
other_acc = shuffleXor(loc, rewriter, accScalar[i], subBlockSize); @@ -151,9 +151,9 @@ struct DotOpMFMAConversionHelper { /// @brief Zeroes out redundant values in all sub-blocks except first one /// - /// Every wave in mfma 4x4 layout holds only 4 unique values(scalar or + /// Every warp in mfma 4x4 layout holds only 4 unique values(scalar or /// vectors) in blocks of 4 consecutive threads, There are 16 copies of these - /// 4 values across all threads of the wave. Need to zero out 15 copies to use + /// 4 values across all threads of the warp. Need to zero out 15 copies to use /// accumulator between dot operations. /// @param numSubBlocks /// @param acc From 706174da3e91f4cff2be960a9462042176023846 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 28 May 2024 13:22:25 -0400 Subject: [PATCH 6/7] [CI] Add macos build test (#3994) --- .github/workflows/integration-tests.yml | 117 ++++++++++++++++++++- .github/workflows/integration-tests.yml.in | 57 +++++++++- CMakeLists.txt | 4 +- cmake/json-version.txt | 1 + python/setup.py | 2 + third_party/proton/CMakeLists.txt | 4 +- 6 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 cmake/json-version.txt diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 6b69c563ff..3911f2a396 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -32,6 +32,7 @@ jobs: outputs: matrix-CUDA: ${{ steps.set-matrix.outputs.matrix-CUDA }} matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + matrix-MACOS: ${{ steps.set-matrix.outputs.matrix-MACOS }} steps: - name: Decide pre-submit integration test enablement # Always enable integration tests for pre-submit pull requests. @@ -106,9 +107,11 @@ jobs: if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["self-hosted", "A100"], ["self-hosted", "H100"]]' echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' + echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' echo '::set-output name=matrix-HIP::["ubuntu-latest"]' + echo '::set-output name=matrix-MACOS::[["macos-latest"]]' fi pre-commit: name: pre-commit (code formatting) @@ -165,6 +168,7 @@ jobs: echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT + echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT shell: bash - name: Cache build dependencies @@ -176,7 +180,8 @@ jobs: ~/.triton/llvm ~/.triton/nvidia ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} + ~/.triton/json + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }}-json-${{ steps.cache-key.outputs.json }} - # Cache ~/.triton/cache because the vast majority of unit test time is # spent compiling. 
Triton won't (well, should not) use these cached files # if something internal to Triton changes, because Triton's internal @@ -301,6 +306,7 @@ jobs: echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT + echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT shell: bash - name: Cache build dependencies @@ -312,7 +318,8 @@ jobs: ~/.triton/llvm ~/.triton/nvidia ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} + ~/.triton/json + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }}-json-${{ steps.cache-key.outputs.json }} - # Cache ~/.triton/cache because the vast majority of unit test time is # spent compiling. Triton won't (well, should not) use these cached files # if something internal to Triton changes, because Triton's internal @@ -398,6 +405,112 @@ jobs: ls -alh ~/.triton du -sh ~/.triton/** + mkdir -p ~/.cache/ccache + ls -alh ~/.cache/ccache + du -sh ~/.cache/ccache + Build-Tests: + needs: Runner-Preparation + if: needs.Runner-Preparation.outputs.matrix-MACOS != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: "true" + - name: Install brew dependencies + run: | + brew update + brew install ccache llvm + - name: Compute cache keys + id: cache-key + run: | + echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT + echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT + echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT + echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT + echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT + shell: bash + - name: Cache build dependencies + uses: actions/cache@v4 + with: + # Note that we cannot use environment variables here given there is + # no shell to interpret them in the paths. + path: | + ~/.triton/llvm + ~/.triton/nvidia + ~/.triton/pybind11 + ~/.triton/json + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }}-json-${{ steps.cache-key.outputs.json }} + - # Cache ~/.triton/cache because the vast majority of unit test time is + # spent compiling. Triton won't (well, should not) use these cached files + # if something internal to Triton changes, because Triton's internal + # source code is part of the cache key. + # + # Similarly, cache ~/.cache/ccache to speed up compilation. + # + # On branch `main` we always start from an empty cache, i.e. we skip the + # "restore" step. This is to prevent the caches from accumulating stale + # files over time. + name: Restore cache of ccache and Triton compilation artifacts + if: github.event_name != 'push' + uses: actions/cache/restore@v4 + with: + path: | + ~/.triton/cache + ~/.cache/ccache + # Restore the most recent cache entry. 
+ restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + # We expect this cache key never to hit and for us to fall back + # unconditionally to the restore-key, so it doesn't actually matter + # what we put here (so long as it doesn't hit an existing key). + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directory + run: | + mkdir -p ~/.triton + ls -alh ~/.triton + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> $GITHUB_PATH + echo "/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m venv ~/.venv + source ~/.venv/bin/activate + python3 -m pip install --upgrade pip + python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit + - name: Install Triton + env: + TRITON_BUILD_WITH_CCACHE: "true" + TRITON_BUILD_WITH_O1: "true" + # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 + # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories + MAX_JOBS: 3 + run: | + source ~/.venv/bin/activate + echo "PATH is '$PATH'" + cd python + python3 -m pip install --no-build-isolation . + - # If we're on branch `main`, save the ccache Triton compilation artifacts + # to the cache so they can be used by other (non-main) CI runs. + # + # (It wouldn't be a problem to save the cache on every run, because github + # evicts cache entries LRU, but maybe this saves a bit of time in CI.) + name: Save ccache and Triton compilation artifacts to cache + if: github.ref == 'refs/heads/main' + uses: actions/cache/save@v4 + with: + path: ~/.triton/cache ~/.cache/ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + ls -alh ~/.triton + du -sh ~/.triton/** + mkdir -p ~/.cache/ccache ls -alh ~/.cache/ccache du -sh ~/.cache/ccache diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 01ad066441..3b0a4639d1 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -35,6 +35,7 @@ jobs: outputs: matrix-CUDA: ${{ steps.set-matrix.outputs.matrix-CUDA }} matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + matrix-MACOS: ${{ steps.set-matrix.outputs.matrix-MACOS }} steps: - name: Decide pre-submit integration test enablement # Always enable integration tests for pre-submit pull requests. 
@@ -114,9 +115,11 @@ jobs: if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["self-hosted", "A100"], ["self-hosted", "H100"]]' echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' + echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' echo '::set-output name=matrix-HIP::["ubuntu-latest"]' + echo '::set-output name=matrix-MACOS::[["macos-latest"]]' fi pre-commit: @@ -162,6 +165,7 @@ jobs: run: | git diff + Integration-Tests: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-CUDA != '' @@ -186,6 +190,7 @@ jobs: echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT + echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT shell: bash @@ -199,7 +204,8 @@ jobs: ~/.triton/llvm ~/.triton/nvidia ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} + ~/.triton/json + key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }}-json-${{ steps.cache-key.outputs.json }} # Cache ~/.triton/cache because the vast majority of unit test time is # spent compiling. Triton won't (well, should not) use these cached files @@ -384,3 +390,52 @@ jobs: - *run-cpp-unittests-step - *save-build-artifacts-step - *inspect-cache-directories-step + + Build-Tests: + needs: Runner-Preparation + if: needs.Runner-Preparation.outputs.matrix-MACOS != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: "true" + - name: Install brew dependencies + run: | + brew update + brew install ccache llvm + + - *compute-cache-keys-step + - *cache-build-dependencies-step + - *restore-build-artifacts-step + - *inspect-cache-directory-step + + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> $GITHUB_PATH + echo "/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m venv ~/.venv + source ~/.venv/bin/activate + python3 -m pip install --upgrade pip + python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit + - name: Install Triton + env: + TRITON_BUILD_WITH_CCACHE: "true" + TRITON_BUILD_WITH_O1: "true" + # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 + # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories + MAX_JOBS: 3 + run: | + source ~/.venv/bin/activate + echo "PATH is '$PATH'" + cd python + python3 -m pip install --no-build-isolation . 
+ + - *save-build-artifacts-step + - *inspect-cache-directories-step diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f53a26026..8766e2302e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,8 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # Customized release build type with assertions: TritonRelBuildWithAsserts set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") +set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -264,7 +266,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) # Check if the platform is MacOS if(APPLE) - set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") + set(PYTHON_LDFLAGS "-undefined dynamic_lookup") endif() target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) diff --git a/cmake/json-version.txt b/cmake/json-version.txt new file mode 100644 index 0000000000..c294f65bf3 --- /dev/null +++ b/cmake/json-version.txt @@ -0,0 +1 @@ +v3.11.3 diff --git a/python/setup.py b/python/setup.py index 09b5769ed8..39d58e1338 100644 --- a/python/setup.py +++ b/python/setup.py @@ -99,6 +99,8 @@ def get_build_type(): return "RelWithDebInfo" elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"): return "TritonRelBuildWithAsserts" + elif check_env_flag("TRITON_BUILD_WITH_O1"): + return "TritonBuildWithO1" else: # TODO: change to release when stable enough return "TritonRelBuildWithAsserts" diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index 1518f2f206..ab25b41fd6 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -35,7 +35,9 @@ endif() # Check if the platform is MacOS if(APPLE) - set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + # Other platforms build with -flto, but we found that this adds significant overhead to our macos CI without providing a major benefit. + set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup") endif() include_directories(${CUPTI_INCLUDE_DIR}) From 100e2aaca903ed99564242f933198a6c221d3b50 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 28 May 2024 22:32:18 +0200 Subject: [PATCH 7/7] [AMD][WMMA] Support dot3d (#3674) This PR enables support of 3d dot for RDNA GPUs. 
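
For orientation, the kind of kernel this unlocks is a batched tl.dot over rank-3 blocks. The sketch below is illustrative only: the kernel name, shapes, grid, and launch parameters are hypothetical rather than part of this patch, and the block shapes simply mirror the 2x16x32 by 2x32x16 case covered by the new wmma_dot3d lit test.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dot3d_kernel(A, B, C, BATCH: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
        # Build (BATCH, M, K) and (BATCH, K, N) pointer blocks for contiguous row-major inputs.
        offs_b = tl.arange(0, BATCH)
        offs_m = tl.arange(0, M)
        offs_n = tl.arange(0, N)
        offs_k = tl.arange(0, K)
        a_ptrs = A + offs_b[:, None, None] * M * K + offs_m[None, :, None] * K + offs_k[None, None, :]
        b_ptrs = B + offs_b[:, None, None] * K * N + offs_k[None, :, None] * N + offs_n[None, None, :]
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Rank-3 dot: one (M, K) x (K, N) product per batch slice.
        c = tl.dot(a, b)
        c_ptrs = C + offs_b[:, None, None] * M * N + offs_m[None, :, None] * N + offs_n[None, None, :]
        tl.store(c_ptrs, c)

    # Hypothetical launch: fp16 inputs, fp32 accumulator output, one program for the whole block.
    a = torch.randn(2, 16, 32, dtype=torch.float16, device="cuda")
    b = torch.randn(2, 32, 16, dtype=torch.float16, device="cuda")
    c = torch.empty(2, 16, 16, dtype=torch.float32, device="cuda")
    dot3d_kernel[(1, )](a, b, c, BATCH=2, M=16, N=16, K=32)
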
--- .../Conversion/TritonGPUToLLVM/Utility.h | 85 ++++++++++---- lib/Analysis/Utility.cpp | 8 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 75 +++++++++---- python/test/unit/language/test_core.py | 5 + .../amd/tritongpu_wmma_dot_to_llvm.mlir | 30 +++++ .../SharedToDotOperandMFMA.cpp | 2 + .../SharedToDotOperandWMMA.cpp | 90 +++++++++------ .../TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp | 106 ++++++++++-------- 9 files changed, 272 insertions(+), 131 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 186e667756..2c973e0021 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -908,14 +908,21 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, SmallVector> &offsets, - unsigned ctaOffsetX, unsigned ctaOffsetY) { + unsigned ctaBatchOffset, unsigned ctaOffsetX, + unsigned ctaOffsetY) { const unsigned elemsPerThreadPerGroup = 8; auto warpSize = getWarpSize(wmmaLayout); assert(warpSize == 32); auto shapePerCta = getShapePerCTATile(wmmaLayout); + auto rank = shapePerCta.size(); + assert(rank == 2 || rank == 3); + SmallVector elemOffset(rank, 0); + if (rank == 3) + elemOffset[0] = ctaBatchOffset; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { - offsets.push_back( - {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]}); + elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + offsets.push_back(elemOffset); } } @@ -925,9 +932,11 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, RankedTensorType type) { auto shape = type.getShape(); auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); - assert(_warpsPerCTA.size() == 2); - SmallVector warpsPerCTA = {i32_val(_warpsPerCTA[0]), - i32_val(_warpsPerCTA[1])}; + auto rank = _warpsPerCTA.size(); + assert(rank == 2 || rank == 3); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); Value threadId = getThreadId(rewriter, loc); @@ -940,20 +949,34 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getWarpOrder(wmmaLayout)); - if (shape[0] >= mnkDim[0]) { - assert(shape[0] % mnkDim[0] == 0); - multiDimWarpId[0] = - urem(multiDimWarpId[0], i32_val(ceil(shape[0], mnkDim[0]))); + if (shape[rank - 2] >= mnkDim[0]) { + assert(shape[rank - 2] % mnkDim[0] == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mnkDim[0]))); } - if (shape[1] >= mnkDim[1]) { - assert(shape[1] % mnkDim[1] == 0); - multiDimWarpId[1] = - urem(multiDimWarpId[1], i32_val(ceil(shape[1], mnkDim[1]))); + if (shape[rank - 1] >= mnkDim[1]) { + assert(shape[rank - 1] % mnkDim[1] == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], mnkDim[1]))); } - Value offWarp0 = mul(multiDimWarpId[0], i32_val(mnkDim[0])); - Value offWarp1 = mul(multiDimWarpId[1], i32_val(mnkDim[1])); - return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0), - add(laneId, offWarp1)}; + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0])); + Value offWarp1 = mul(multiDimWarpId[rank - 
1], i32_val(mnkDim[1])); + + SmallVector multiDimBase(rank); + + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + multiDimBase[rank - 1] = add(laneId, offWarp1); + + // TODO: It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; } inline SmallVector> @@ -964,17 +987,31 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); - SmallVector numWarpsPerDim(2); + auto rank = tensorShape.size(); + assert(rank == 2 || rank == 3); + + SmallVector numWarpsPerDim(rank, 1); auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); - for (unsigned d = 0; d < 2; ++d) { + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mnkDim[0]; + shapePerWarp[rank - 1] = mnkDim[1]; + for (unsigned d = 0; d < rank; ++d) { unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); - numWarpsPerDim[d] = ceil(inPerWarp, mnkDim[d]); + numWarpsPerDim[d] = ceil(inPerWarp, shapePerWarp[d]); } - for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) { - for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) { - emitWmmaOffsetForCTA(wmmaLayout, offsets, i, j); + unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1; + unsigned repM = numWarpsPerDim[rank - 2]; + unsigned repN = numWarpsPerDim[rank - 1]; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < repM; ++i) { + for (unsigned j = 0; j < repN; ++j) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j); + } } } return offsets; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 689e83b5ac..32cc43c9d5 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -469,6 +469,7 @@ bool supportMFMA(triton::DotOp op) { auto bShape = bTy.getShape(); auto rank = aShape.size(); + assert(bShape.size() == rank); auto M = aShape[rank - 2]; auto N = bShape[rank - 1]; auto K = aShape[rank - 1]; @@ -521,8 +522,11 @@ bool supportWMMA(triton::DotOp op) { auto aShape = aTy.getShape(); auto bShape = bTy.getShape(); - assert(aShape[1] == bShape[0]); - if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1])) + auto rank = aShape.size(); + assert(bShape.size() == rank); + assert(aShape[rank - 1] == bShape[rank - 2]); + if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1], + aShape[rank - 1])) return false; return true; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index ca7367d15b..a80158a463 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -590,7 +590,7 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } else if (auto wmmaLayout = dyn_cast(layout)) { - emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0], + emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 
86c9f8241c..2d5f3e9755 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -804,16 +804,22 @@ SmallVector AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of wmma layout"); + assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); SmallVector elemsPerThread(rank); auto mnkDim = getMNKDimPerWMMAInstr(); auto elemsPerThreadPerTile = getSizePerThread(); auto warpsPerCTA = getWarpsPerCTA(); - return {ceil(shape[0], mnkDim[0] * warpsPerCTA[0]) * - elemsPerThreadPerTile[0], - ceil(shape[1], mnkDim[1] * warpsPerCTA[1]) * - elemsPerThreadPerTile[1]}; + + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = + ceil(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) * + elemsPerThreadPerTile[rank - 2]; + elemsPerThread[rank - 1] = + ceil(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) * + elemsPerThreadPerTile[rank - 1]; + return elemsPerThread; } unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, @@ -1605,9 +1611,8 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - constexpr int warpSize = 64; auto rep = getMFMARepForOperands(shape, kWidth, opIdx); - return rep[0] * rep[1] * rep[2] * kWidth; + return product(rep) * kWidth; } SmallVector @@ -1646,8 +1651,14 @@ AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, SmallVector AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + auto mnkDim = getMNKDimPerWMMAInstr(); - return {mnkDim[0] * getWarpsPerCTA()[0], mnkDim[1] * getWarpsPerCTA()[1]}; + shapePerCTATile[rank - 2] *= mnkDim[0]; + shapePerCTATile[rank - 1] *= mnkDim[1]; + return shapePerCTATile; } SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); @@ -1668,28 +1679,43 @@ SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); } SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { - return {getMNKDimPerWMMAInstr()[0] / getSizePerThread()[0], - getMNKDimPerWMMAInstr()[1] / getSizePerThread()[1]}; + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerWMMAInstr(); + threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; + threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; + return threads; } SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { - return {8, 1}; + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + sizePerThread[rank - 2] = 8; + sizePerThread[rank - 1] = 1; + return sizePerThread; } SmallVector AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); if (opIdx == 0) { - return {1, 16}; + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = 16; } else if (opIdx == 1) { - return {16, 1}; + sizePerThread[rank - 2] = 16; + sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } + return sizePerThread; } SmallVector AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const { auto 
parentShapePerCTA = getShapePerCTATile(shape); + auto rank = shape.size(); + assert(rank = 2); if (opIdx == 0) { return {parentShapePerCTA[0], static_cast(shape[1])}; } else if (opIdx == 1) { @@ -1702,7 +1728,7 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx); - return rep[0] * rep[1] * kWidth; + return product(rep) * kWidth; } SmallVector @@ -1715,16 +1741,25 @@ AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, Type elemType, int kWidth, int opIdx) const { auto operandTileShape = getWMMAElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; if (opIdx == 0) - return {std::max(1, operandShape[0] / - (operandTileShape[0] * warpsPerCTA[0])), - std::max(1, operandShape[1] / operandTileShape[1])}; + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; else { assert(opIdx == 1); - return {std::max(1, operandShape[0] / operandTileShape[0]), - std::max(1, operandShape[1] / - (operandTileShape[1] * warpsPerCTA[1]))}; + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; } } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f6604c3de3..5972c93d7f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3262,6 +3262,11 @@ def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests input_precision = "ieee" + if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index b42edaea3d..16925a54a5 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -57,3 +57,33 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> +#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 1, 4]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_dot_operand3d + tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) { + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = 
triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma_dot3d + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) { + // CHECK-COUNT-32: llvm.extractvalue %arg0 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-32: llvm.extractvalue %arg1 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-8: llvm.extractvalue %arg2 + // CHECK-COUNT-8: llvm.insertelement + // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma> + // CHECK-COUNT-8: llvm.extractelement + // CHECK-COUNT-8: llvm.insertvalue + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 59682558d9..7c05ffede2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -212,6 +212,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto sharedLayout = cast(aTensorTy.getEncoding()); auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); auto elemTy = aTensorTy.getElementType(); auto kWidth = encoding.getKWidth(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 3e2ec71db3..950e2926a1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -69,7 +69,9 @@ llvm::SmallVector> computeTensorElemMappingInBlock( const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { - auto numK = reps[1]; + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); @@ -77,6 +79,8 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value nonKDim = i32_val(iNonKDim); Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); + auto rank = smemOffsets.size(); + for (int tile = 0; tile < numK; ++tile) { Value tileVOffset = _0; Value tileHOffset = i32_val(tile * elemsPerInstr[1]); @@ -92,8 +96,8 @@ llvm::SmallVector> computeTensorElemMappingInBlock( add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); - Value row = add(sliceVOffset, smemOffsets[0]); - Value col = add(sliceHOffset, smemOffsets[1]); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); + Value col = add(sliceHOffset, smemOffsets[rank - 1]); 
mapping[loadsPerThread * tile + loadId] = {row, col}; } @@ -107,61 +111,68 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread) { assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); - int kDimIdx = opIdx == 0 ? 1 : 0; - int nonKDimIdx = opIdx == 0 ? 0 : 1; + auto rank = smemObj.getStrides().size(); + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; auto wmmaLayout = cast(encoding.getParent()); - auto nonKDim = wmmaLayout.getMNKDimPerWMMAInstr()[nonKDimIdx]; - assert(nonKDim == 16); + assert(wmmaLayout.getMNKDimPerWMMAInstr()[nonKDimIdx] == 16); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); auto sharedLayout = cast(aTensorTy.getEncoding()); auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); auto elemTy = aTensorTy.getElementType(); int kWidth = encoding.getKWidth(); auto elemsPerInstr = wmmaLayout.getWMMAElemsPerInstrForOperands(); - auto wmmaInstrK = elemsPerInstr[kDimIdx]; + auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0]; + auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; + assert(wmmaInstrNonK == 16); auto numReps = wmmaLayout.getWMMARepForOperands(shape, elemTy, kWidth, opIdx); - auto numRepNonK = numReps[nonKDimIdx]; - auto numRepK = numReps[kDimIdx]; - - unsigned iWarpSize = triton::gpu::getWarpSize(wmmaLayout); - unsigned iNumLanes = iWarpSize / 2; - assert(iWarpSize == 32); - Value warpSize = i32_val(iWarpSize); + auto numRepNonK = numReps[opIdx == 0 ? 1 : 2]; + auto numRepK = numReps[opIdx == 0 ? 2 : 1]; + auto repB = numReps[0]; + + unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); + unsigned iNumLanes = iWaveSize / 2; + assert(iWaveSize == 32); + Value waveSize = i32_val(iWaveSize); Value numLanes = i32_val(iNumLanes); - Value linearWarpId = udiv(thread, warpSize); + Value linearWaveId = udiv(thread, waveSize); Value lane = urem(thread, numLanes); // share elem between two threads - unsigned numElemsPerThreadPerRep = - wmmaLayout.getMNKDimPerWMMAInstr()[kDimIdx]; + unsigned numElemsPerThreadPerRep = wmmaInstrK; - Value warp = udiv(thread, warpSize); - unsigned int maxNumWarps = shape[nonKDimIdx] / elemsPerInstr[nonKDimIdx]; + Value warp = udiv(thread, waveSize); + unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); + int warpsPerBatch = + rank == 3 ? 
std::min(shape[0], warpsPerCTA[0]) : 1; + Value waveIdInBatch = urem(linearWaveId, i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; SmallVector offsets; Value smemBase; Value spatialWarpId = AMD::getWarpIdInBlock( - rewriter, loc, linearWarpId, warpsPerCTA, elemsPerInstr[0], + rewriter, loc, linearWaveId, warpsPerCTA, elemsPerInstr[0], shape[nonKDimIdx], nonKDimIdx, triton::gpu::getOrder(wmmaLayout)); if (opIdx == 0) { offsets = AMD::computeOffsetsAType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, - numReps, smemObj, sharedLayout, nonKDim, wmmaInstrK); + numReps, smemObj, sharedLayout, wmmaInstrNonK, wmmaInstrK); } else { assert(opIdx == 1); offsets = AMD::computeOffsetsBType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, - numReps, smemObj, sharedLayout, nonKDim, wmmaInstrK); + numReps, smemObj, sharedLayout, wmmaInstrNonK, wmmaInstrK); } smemBase = AMD::computeBasePtr(rewriter, loc, smemObj); @@ -171,19 +182,26 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); - for (int nonK = 0; nonK < numRepNonK; ++nonK) { - for (int k = 0; k < numRepK; ++k) { - auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); - Value valVec = undef(vecTy); - for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); - Value loadOffset = offsets[nonK * loadsPerThread * numRepK + - k * loadsPerThread + loadId]; - Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); - Value loadedValue = load(loadVecTy, loadAddress); - for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { - Value elemVal = extract_element(elemTy, loadedValue, i32_val(elemId)); - loadedValues.push_back(elemVal); + for (int b = 0; b < repB; ++b) { + int operandSize = shape[rank - 1] * shape[rank - 2]; + Value batchOffset = mul(i32_val(operandSize), + add(waveIdInBatch, i32_val(b * warpsPerBatch))); + for (int nonK = 0; nonK < numRepNonK; ++nonK) { + for (int k = 0; k < numRepK; ++k) { + auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); + Value valVec = undef(vecTy); + for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); + Value loadOffset = offsets[nonK * loadsPerThread * numRepK + + k * loadsPerThread + loadId]; + loadOffset = add(loadOffset, batchOffset); + Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = load(loadVecTy, loadAddress); + for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { + Value elemVal = + extract_element(elemTy, loadedValue, i32_val(elemId)); + loadedValues.push_back(elemVal); + } } } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index d8b537be68..3066fe4ce0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -45,32 +45,38 @@ enum class WMMAInstrType : uint8_t { NOT_APPLICABLE, }; -using ValueTable = std::map, Value>; +using ValueTable = std::map, Value>; -ValueTable getValuesFromDotOperandLayoutStruct( - ConversionPatternRewriter &rewriter, const LLVMTypeConverter 
*typeConverter, - Value value, int n0, int n1, int kWidth, Type type, Location loc) { +ValueTable +getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Value value, int batch, int n0, int n1, + int kWidth, Type type, Location loc) { auto elems = unpackLLElements(loc, value, rewriter); ValueTable vals; - for (int i = 0; i < n0; i++) { - for (int j = 0; j < n1; j++) { - Type elemTy = typeConverter->convertType(type); - Type ty = vec_ty(elemTy, kWidth); - Value rawElems = undef(ty); - for (int k = 0; k < kWidth; ++k) { - rawElems = insert_element(ty, rawElems, - elems[kWidth * (n1 * i + j) + k], i32_val(k)); - } + for (int b = 0; b < batch; b++) { + for (int i = 0; i < n0; i++) { + for (int j = 0; j < n1; j++) { + Type elemTy = typeConverter->convertType(type); + Type ty = vec_ty(elemTy, kWidth); + Value rawElems = undef(ty); + for (int k = 0; k < kWidth; ++k) { + rawElems = insert_element( + ty, rawElems, + elems[n0 * n1 * kWidth * b + kWidth * (n1 * i + j) + k], + i32_val(k)); + } - Value convertedElems; - if (type.isBF16() || type.isF16()) { - convertedElems = rawElems; - } else { - convertedElems = bitcast( - rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / - i32_ty.getIntOrFloatBitWidth())); + Value convertedElems; + if (type.isBF16() || type.isF16()) { + convertedElems = rawElems; + } else { + convertedElems = bitcast( + rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / + i32_ty.getIntOrFloatBitWidth())); + } + vals[{b, i, j}] = convertedElems; } - vals[{i, j}] = convertedElems; } } return vals; @@ -172,52 +178,56 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto repB = wmmaLayout.getWMMARepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1); - assert(repA[1] == repB[0]); + assert(repA[2] == repB[1]); Value loadedA = adaptor.getA(); Value loadedB = adaptor.getB(); Value loadedC = adaptor.getC(); - auto numRepM = repA[0]; - auto numRepN = repB[1]; - auto numRepK = repA[1]; + auto numRepM = repA[1]; + auto numRepN = repB[2]; + auto numRepK = repA[2]; + auto numRepB = repA[0]; ValueTable ha = getValuesFromDotOperandLayoutStruct( - rewriter, typeConverter, loadedA, numRepM, numRepK, kWidth, + rewriter, typeConverter, loadedA, numRepB, numRepM, numRepK, kWidth, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( - rewriter, typeConverter, loadedB, numRepN, numRepK, kWidth, + rewriter, typeConverter, loadedB, numRepB, numRepN, numRepK, kWidth, aTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout); - // TODO get rid of magic numbers - unsigned vgprElemWidth = 32; + constexpr unsigned vgprElemBitWidth = 32; unsigned paddedOutputElemSize = - vgprElemWidth / dstElemTy.getIntOrFloatBitWidth(); + vgprElemBitWidth / dstElemTy.getIntOrFloatBitWidth(); // compute number of output elements that each thread holds for one WMMA // instruction. 
auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize; auto dElemsToStorePerThread = mnkDim[0] * mnkDim[1] / warpSize; auto vecTy = vec_ty(dstElemTy, elemsPerVec); - for (int m = 0; m < numRepM; ++m) { - for (int n = 0; n < numRepN; ++n) { - Value acc = undef(vecTy); - for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - acc = insert_element(vecTy, acc, - fc[m * numRepN * dElemsToStorePerThread + - n * dElemsToStorePerThread + v], - i32_val(v * paddedOutputElemSize)); - } - for (size_t k = 0; k < numRepK; k++) { - acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{m, k}], - hb[{n, k}], acc, aTensorTy.getElementType(), - bTensorTy.getElementType()); - } - for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - fc[m * numRepN * dElemsToStorePerThread + n * dElemsToStorePerThread + - v] = - extract_element(dstElemTy, acc, i32_val(v * paddedOutputElemSize)); + for (int b = 0; b < numRepB; ++b) { + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + auto batchOffIdx = b * numRepM * numRepN * dElemsToStorePerThread; + auto mRepOffId = m * numRepN * dElemsToStorePerThread; + auto nRepOffId = n * dElemsToStorePerThread; + auto fcThreadOffIdx = batchOffIdx + mRepOffId + nRepOffId; + + Value acc = undef(vecTy); + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + acc = insert_element(vecTy, acc, fc[fcThreadOffIdx + v], + i32_val(v * paddedOutputElemSize)); + } + for (size_t k = 0; k < numRepK; k++) { + acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{b, m, k}], + hb[{b, n, k}], acc, aTensorTy.getElementType(), + bTensorTy.getElementType()); + } + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + fc[fcThreadOffIdx + v] = extract_element( + dstElemTy, acc, i32_val(v * paddedOutputElemSize)); + } } } }
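
A quick way to sanity check the flattened fcThreadOffIdx addressing introduced above is to enumerate it for small rep counts and confirm that the per-(b, m, n) chunks are contiguous and disjoint. The rep counts below are hypothetical; dElems stands in for dElemsToStorePerThread, which is 16 * 16 / 32 = 8 given the 16x16x16 WMMA tile and the wave32 size asserted earlier in the patch.

    # Hypothetical rep counts; dElems corresponds to dElemsToStorePerThread above.
    numRepB, numRepM, numRepN, dElems = 2, 2, 3, 8

    seen = set()
    for b in range(numRepB):
        for m in range(numRepM):
            for n in range(numRepN):
                base = b * numRepM * numRepN * dElems + m * numRepN * dElems + n * dElems
                chunk = set(range(base, base + dElems))
                assert not (seen & chunk), "accumulator chunks overlap"
                seen |= chunk
    # Every accumulator element is covered exactly once.
    assert seen == set(range(numRepB * numRepM * numRepN * dElems))
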