diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2ce6d4ed1b..028e503c08 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -245,9 +245,9 @@ jobs: lit -v "${LIT_TEST_DIR}" - name: Run python tests on CUDA run: | - SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C" - if [ ! -d "${SHARED_LIB_DIR}" ]; then - echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 + INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation" + if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi cd python/test/unit python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py @@ -257,7 +257,7 @@ jobs: TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py # Run hopper/test_flashattention.py separately to avoid out of gpu memory python3 -m pytest -s hopper/test_flashattention.py - TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \ python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py - name: Run interpreter tests if: ${{ matrix.runner[0] == 'h100-runner-set' }} @@ -401,9 +401,9 @@ jobs: lit -v "${LIT_TEST_DIR}" - name: Run python tests on HIP run: | - SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C" - if [ ! -d "${SHARED_LIB_DIR}" ]; then - echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 + INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation" + if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit @@ -412,7 +412,7 @@ --ignore=test_debug.py # TODO: uncomment # pytest --capture=tee-sys -rfs test_debug.py - TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \ pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index ef19bd7626..a868a9984b 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -279,9 +279,9 @@ jobs: - name: Run python tests on CUDA run: | - SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C" - if [ ! -d "${SHARED_LIB_DIR}" ]; then - echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 + INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation" + if [ !
-d "${INSTRUMENTATION_LIB_DIR}" ]; then + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi cd python/test/unit python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py @@ -291,7 +291,7 @@ jobs: TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py # Run hopper/test_flashattention.py separately to avoid out of gpu memory python3 -m pytest -s hopper/test_flashattention.py - TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \ python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py - name: Run interpreter tests @@ -397,9 +397,9 @@ jobs: - name: Run python tests on HIP run: | - SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C" - if [ ! -d "${SHARED_LIB_DIR}" ]; then - echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 + INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation" + if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit @@ -408,7 +408,7 @@ --ignore=test_debug.py # TODO: uncomment # pytest --capture=tee-sys -rfs test_debug.py - TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \ pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 2c907e9f0d..7e10049945 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::registerTritonAMDGPUStreamPipeline(); mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes registry.insert + localStoreOpConversion = nullptr; +}; + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors.
+void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03..0f6fe913df 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1366,11 +1366,11 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target); +void storeDistributedToShared( + MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 3ebccfc801..204dbfe0a7 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. -void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo); + loc, rewriter, targetInfo, llvmOpCount); } struct LocalAllocOpConversion @@ -185,12 +184,15 @@ struct LocalStoreOpConversion public: using ConvertOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -200,24 +202,36 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( 
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6..67954e5dae 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target) { + const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { @@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index bb22489eac..a700a30a0e 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -370,9 +370,11 @@ class RewriteTensorPointerPass } // update rewritedInfo + auto opResults = op.getResults(); unsigned oldResIdx = 0, newResIdx = 0; while (oldResIdx < results.size()) { if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx)); oldResIdx++; newResIdx++; } else { diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 06a15f93fd..6d7300a1e2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1909,8 +1909,6 @@ def where(condition, x, y, _builder=None): def add(x, y, sanitize_overflow: constexpr = True, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) return semantic.add(x, y, sanitize_overflow, _builder) @@ -1918,8 +1916,6 @@ def add(x, y, sanitize_overflow: constexpr = True, _builder=None): def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): x = _unwrap_if_constexpr(x) y = 
_unwrap_if_constexpr(y) - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) return semantic.sub(x, y, sanitize_overflow, _builder) @@ -1927,8 +1923,6 @@ def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) return semantic.mul(x, y, sanitize_overflow, _builder) diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index a372543482..3a62421b7f 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -298,12 +298,12 @@ run_instrumentation_tests() { return fi - SHARED_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/_C) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/_C, build Triton first" + INSTRUMENTATION_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/instrumentation) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/instrumentation, build Triton first" cd $TRITON_PROJ/python/test/unit TRITON_TEST_SUITE=instrumentation \ - TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \ pytest -vvv --device xpu instrumentation/test_gpuhello.py } diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir index 26625c3a0f..eb39dcac01 100644 --- a/test/Triton/rewrite-tensor-pointer.mlir +++ b/test/Triton/rewrite-tensor-pointer.mlir @@ -141,26 +141,30 @@ tt.func public @rewrite_for(%arg0: !tt.ptr, %arg1: !tt.ptr) { // CHECK: tt.num_stages = 3 // ----- -tt.func public @rewrite_if(%arg0: !tt.ptr, %arg1: i1) -> tensor<128x32xf16> { +tt.func public @rewrite_if(%arg0: !tt.ptr, %arg1: i1, %arg2: tensor<128x32xf32>) -> tensor<128x32xf16> { %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 %c1_i64 = arith.constant 1 : i64 %c32_i64 = arith.constant 32 : i64 %c128_i64 = arith.constant 128 : i64 %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> - %1 = scf.if %arg1 -> (!tt.ptr>) { + %1:2 = scf.if %arg1 -> (tensor<128x32xf16>, !tt.ptr>) { %2 = tt.advance %0, [%c32_i32, %c0_i32] : !tt.ptr> - scf.yield %2 : !tt.ptr> + %3 = arith.truncf %arg2 : tensor<128x32xf32> to tensor<128x32xf16> + scf.yield %3, %2 : tensor<128x32xf16>, !tt.ptr> } else { - scf.yield %0 : !tt.ptr> + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + scf.yield %cst, %0 : tensor<128x32xf16>, !tt.ptr> } - %4 = tt.load %1 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> - tt.return %4 : tensor<128x32xf16> + %4 = tt.load %1#1 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + %5 = arith.addf %1#0, %4 : tensor<128x32xf16> + tt.return %5 : tensor<128x32xf16> } // CHECK-LABEL: tt.func public @rewrite_if( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<128x32xf32> // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32 // CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 @@ -168,15 +172,20 @@ tt.func public @rewrite_if(%arg0: !tt.ptr, %arg1: i1) -> tensor<128x32xf16> // CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 // CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 // 
CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG1]] -> (i64, i64) { +// CHECK: %[[IF:.*]]:3 = scf.if %[[ARG1]] -> (tensor<128x32xf16>, i64, i64) { // CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64 // CHECK: %[[ADDI0:.*]] = arith.addi %[[EXTSI0]], %[[EXTSI2]] : i64 // CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 // CHECK: %[[ADDI1:.*]] = arith.addi %[[EXTSI1]], %[[EXTSI3]] : i64 -// CHECK: scf.yield %[[ADDI0]], %[[ADDI1]] : i64, i64 +// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ARG2]] : tensor<128x32xf32> to tensor<128x32xf16> +// CHECK: scf.yield %[[TRUNCF]], %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64 // CHECK: } else { -// CHECK: scf.yield %[[EXTSI0]], %[[EXTSI1]] : i64, i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: scf.yield %[[CST]], %[[EXTSI0]], %[[EXTSI1]] : tensor<128x32xf16>, i64, i64 // CHECK: } +// CHECK: %{{.*}} = tt.splat %[[IF]]#1 : i64 -> tensor<128xi64> +// CHECK: %{{.*}} = tt.splat %[[IF]]#2 : i64 -> tensor<32xi64> +// CHECK: %{{.*}} = arith.addf %[[IF]]#0, %{{.*}} : tensor<128x32xf16> // ----- diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 0000000000..bca502f980 --- /dev/null +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,148 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -check-prefix=INSTR_INSERTION +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 -triton-amdgpu-lower-insert-instruction-sched-hints=variant="iglp0" | FileCheck %s -check-prefix=LOWER_IGLP0 + +#shared0_ex0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#mma0_ex0 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> + +#blocked0_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1_ex1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared0_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#mma0_ex1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> +#dot0_ex1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex1, kWidth = 8}> +#dot1_ex1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex1, kWidth = 8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // LOWER_IGLP0-LABEL: test_instruction_hints_lowering + tt.func @test_instruction_hints_lowering( + %arg0: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>>, + %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>>, + %arg2: tensor<32x32xf16, #mma0_ex0>) { + + %c0_i32 = 
arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 1 : i32 + + scf.for %arg11 = %c0_i32 to %c64_i32 step %c1_i32 iter_args() -> () : i32 { + // LOWER_IGLP0: llvm.add + // LOWER_IGLP0-NEXT: %[[OPT_LEVEL:.*]] = llvm.mlir.constant(0 : i32) : i32 + // LOWER_IGLP0-NEXT: llvm.call_intrinsic "llvm.amdgcn.iglp.opt"(%[[OPT_LEVEL]]) : (i32) -> () + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>> -> tensor<32x32xf16, #mma0_ex0> + scf.yield + } + tt.return + } + + // INSTR_INSERTION-LABEL: @test_llvm_instruction_count + tt.func public @test_llvm_instruction_count( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + + %cst = arith.constant dense<64> : tensor<256x64xi32, #blocked0_ex1> + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked1_ex1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + %21 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %22 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + %23 = arith.addi %21, %19 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> + %24 = arith.addi %22, %20 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> + + %26 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %28 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %29 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %30 = arith.addi %28, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> + %31 = arith.addi %29, %27 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> + %32 = tt.expand_dims %23 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> -> tensor<256x1xi32, #blocked0_ex1> + %33 = tt.expand_dims %24 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> -> tensor<256x1xi32, #blocked2_ex1> + %34 = tt.splat %c64_i32 : i32 -> tensor<256x1xi32, #blocked0_ex1> + %35 = arith.muli %32, %34 : tensor<256x1xi32, #blocked0_ex1> + %36 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked0_ex1> + %37 = tt.addptr %36, %35 : tensor<256x1x!tt.ptr, #blocked0_ex1>, tensor<256x1xi32, #blocked0_ex1> + %38 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> + %39 = tt.expand_dims %38 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> -> tensor<1x64xi32, #blocked0_ex1> + %40 = tt.broadcast %37 : tensor<256x1x!tt.ptr, 
#blocked0_ex1> -> tensor<256x64x!tt.ptr, #blocked0_ex1> + %41 = tt.broadcast %39 : tensor<1x64xi32, #blocked0_ex1> -> tensor<256x64xi32, #blocked0_ex1> + %42 = tt.addptr %40, %41 : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> + + %43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> -> tensor<64x1xi32, #blocked1_ex1> + %45 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1_ex1> + %46 = tt.addptr %45, %44 : tensor<64x1x!tt.ptr, #blocked1_ex1>, tensor<64x1xi32, #blocked1_ex1> + %47 = tt.expand_dims %30 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> -> tensor<1x128xi32, #blocked1_ex1> + %48 = tt.expand_dims %31 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> -> tensor<1x128xi32, #blocked2_ex1> + %49 = tt.splat %c64_i32 : i32 -> tensor<1x128xi32, #blocked1_ex1> + %50 = arith.muli %47, %49 : tensor<1x128xi32, #blocked1_ex1> + %51 = tt.broadcast %46 : tensor<64x1x!tt.ptr, #blocked1_ex1> -> tensor<64x128x!tt.ptr, #blocked1_ex1> + %52 = tt.broadcast %50 : tensor<1x128xi32, #blocked1_ex1> -> tensor<64x128xi32, #blocked1_ex1> + %53 = tt.addptr %51, %52 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> + + %56 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %57 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma0_ex1> + + %cc0_i1 = arith.constant 1 : i1 + %59 = tt.splat %cc0_i1 : i1 -> tensor<256x64xi1, #blocked0_ex1> + %60 = tt.load %42, %59 : tensor<256x64x!tt.ptr, #blocked0_ex1> + %61 = tt.splat %cc0_i1 : i1 -> tensor<64x128xi1, #blocked1_ex1> + %62 = tt.load %53, %61 : tensor<64x128x!tt.ptr, #blocked1_ex1> + + %63 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %60, %63 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %64 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %62, %64 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + %66:5 = scf.for %arg11 = %c0_i32 to %c63_i32 step %c1_i32 iter_args( + %arg12 = %cst_1, + %arg13 = %42, + %arg14 = %53, + %arg16 = %63, + %arg17 = %64) -> ( + tensor<256x128xf32, #mma0_ex1>, + tensor<256x64x!tt.ptr, #blocked0_ex1>, + tensor<64x128x!tt.ptr, #blocked1_ex1>, + !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, + !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>) : i32 { + + %82 = triton_gpu.local_load %arg16 : !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dot0_ex1> + %83 = triton_gpu.local_load %arg17 : !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> tensor<64x128xf16, #dot1_ex1> + + // INSTR_INSERTION: 
amdgpu.instruction_sched_hint + // INSTR_INSERTION-SAME: numDsReadsA = #amdgpu.InstCounter<16, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsWritesA = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<8xf16>> + // INSTR_INSERTION-SAME: numGlobalLoadsA = #amdgpu.InstCounter<8, vector<8xf16>> + // INSTR_INSERTION-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<8xf16>> + // INSTR_INSERTION-SAME: numMMAs = #amdgpu.InstCounter<64, tensor<32x32x8xf16>> + + %84 = tt.dot %82, %83, %arg12 : tensor<256x64xf16, #dot0_ex1> * tensor<64x128xf16, #dot1_ex1> -> tensor<256x128xf32, #mma0_ex1> + %85 = tt.addptr %arg13, %cst : tensor<256x64x!tt.ptr, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1> + %86 = tt.addptr %arg14, %cst_0 : tensor<64x128x!tt.ptr, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1> + %87 = tt.load %85 : tensor<256x64x!tt.ptr, #blocked0_ex1> + %88 = tt.load %86 : tensor<64x128x!tt.ptr, #blocked1_ex1> + %89 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %87, %89 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> + %90 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %88, %90 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + + scf.yield %84, %85, %86, %89, %90 : + tensor<256x128xf32, #mma0_ex1>, + tensor<256x64x!tt.ptr, #blocked0_ex1>, + tensor<64x128x!tt.ptr, #blocked1_ex1>, + !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>, + !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> + } + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 3e9b8a0840..686e5a24e8 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,28 +1,100 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful -// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers +// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands +// in cases where local_alloc is in the loop but its operand is not. +// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers // throughout the computation.
-// CHECK-LABEL: order_load_alloc_local_load -// CHECK: %[[LOAD:.+]] = tt.load -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { - %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> - tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + +// CHECK-LABEL: hoist_q_out_of_the_loop +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] +// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK: scf.for +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %54:1 = scf.for %arg21 
= %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} tt.return } } + + +// ----- +// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both +// local_alloc and its src tensor defining op are in the loop. +// CHECK-LABEL: no_hoist_q_type_reordering +// CHECK: scf.for +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: arith.constant +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load
%73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + // CHECK-LABEL: order_load_alloc_local_load_local_store // CHECK: %[[LOAD:.+]] = tt.load // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 0000000000..bea937da60 --- /dev/null +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,165 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// scf.for loop with two dots should not apply the optimization + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 
: i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: 
%[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, 
#blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 128x128x128 with two dots +// CHECK-LABEL: sink_2nd_load_128x128x128_two_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_two_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x128xf16, #blocked1> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %6 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/lib/Instrumentation/CMakeLists.txt b/test/lib/Instrumentation/CMakeLists.txt index bd416eb5de..90311bb86f 100644 --- a/test/lib/Instrumentation/CMakeLists.txt +++ b/test/lib/Instrumentation/CMakeLists.txt @@ -1,8 +1,8 @@ set(GPU_INSTRUMENTATION_PASSES - GPUHello + GPUInstrumentationTestLib ) -set(GPUHello_SOURCES +set(GPUInstrumentationTestLib_SOURCES GPUHello.cpp ) @@ -20,6 +20,17 @@ foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) LLVMCore "$<$:-undefined dynamic_lookup>" ) + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + # This is set to -fvisibility=hidden in the top level CMake file # which causes the llvmGetPassPluginInfo symbol to be hidden and # an "entry point not found" error. 
Reset it just for this target diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 31a43acd2f..c0aa08421b 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. + }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index d5956cf7a3..c0c18b07e9 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect { }]; let dependentDialects = []; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 4721d14ecb..1454d629b6 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -49,7 +49,27 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { interleave for better instruction level parallelism. 
}]; - let assemblyFormat = [{attr-dict}]; + let arguments = (ins + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins), [{ + auto ctx = $_state.getContext(); + auto type = IntegerType::get(ctx, 32); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type); + build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; } #endif diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index 67ff40d5b9..d6a95b86cd 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -35,9 +35,9 @@ std::unique_ptr> createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); std::unique_ptr> -createInsertInstructionSchedHintsPass(); +createTritonAMDGPUInsertInstructionSchedHintsPass(); std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index ccb2b1898f..7383f1dfae 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -55,18 +55,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul } -def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; - let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; } -def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(\"\")"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"variant", "variant", "std::string", /*default*/"\"default\"", diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 5631d56b24..2ee8a3cf25 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -23,6 +23,9 @@ #include "mlir/IR/DialectImplementation.h" #include 
"mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" + +#include "llvm/ADT/TypeSwitch.h" // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" @@ -44,5 +47,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 7c05ffede2..c0e774788e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -330,6 +331,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int elemsPerLoad = numOfElems / loadsPerThread; assert(numOfElems % loadsPerThread == 0); + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -340,7 +342,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; @@ -357,6 +358,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = mfmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 6cdcddad39..400bcd9f65 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -212,6 +213,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -221,7 +223,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); Value valVec = undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; loadOffset = add(loadOffset, batchOffset); @@ -237,6 +238,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = wmmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index c190711f10..bdc25f0a85 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "Utility.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" using namespace mlir; @@ -263,6 +263,14 @@ struct DotOpMFMAConversionHelper { Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + Type elemtTy = elemTyA; + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(), + maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(), + elemtTy); + rewriter.replaceOp(op, res); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 9f575be082..9ed21fa00d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -22,6 +22,7 @@ */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -325,6 +326,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, Type structTy = LLVM::LLVMStructType::getLiteral( wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 6009156cfc..0a51e6d7c3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,4 +1,5 @@ #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -203,6 +204,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { size_t in_off = 0; @@ -215,7 +217,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, assert(wordNElems * nWords * numVecs == numElems); Value pred = mask ? 
maskElems[vecStart] : int_val(1, 1); - auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); mlir::Attribute zeroAttr = rewriter.getZeroAttr(valueElemTy); @@ -249,6 +250,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 9bed879619..3c30ae7e54 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,5 +1,4 @@ #include "TritonAMDGPUToLLVM/Passes.h" - #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" @@ -7,13 +6,77 @@ #include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { -#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS -#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton using namespace mlir; +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = + cast(op->getAttr(amdgpu::OpIdxAttr::getMnemonic())); + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumGlobalLoadsAAttr(counterAttr); + else + schedHint.setNumGlobalLoadsBAttr(counterAttr); + }); +} + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) { + auto opIdxAttr = + op->getAttrOfType(amdgpu::OpIdxAttr::getMnemonic()); + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + 
else + schedHint.setNumDsWritesBAttr(counterAttr); + }); +} +} // namespace mlir::triton + namespace { // The bitmask that encodes kinds of the instructions from AMD ISA. @@ -52,7 +115,7 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, } // Insert intrinsic that controls the types of instructions that may be -// allowed to cross the intrinsic during instruction scheduling +// allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, int64_t maskValue) { MLIRContext *ctx = rewriter.getContext(); @@ -78,7 +141,7 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { } struct InstructionSchedHintsRewriter - : public OpRewritePattern { + : public OpRewritePattern { InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) : OpRewritePattern(ctx) { @@ -89,13 +152,119 @@ struct InstructionSchedHintsRewriter .Case("default", SchedulingType::NONE) .Case("iglp0", SchedulingType::IGLP0) .Case("iglp1", SchedulingType::IGLP1) + .Case("ck_v3", SchedulingType::CK_V3) .Default(SchedulingType::UNKNOWN); } - enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + enum class SchedulingType : uint32_t { + NONE = 0, + IGLP0, + IGLP1, + CK_V3, + UNKNOWN + }; + + // This is an implementation of CK's V3 pipelining (see + // ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBM to registers) data prefetching. + // see: + // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h + void createCKV3Schedule(PatternRewriter &rewriter, Location loc, + amdgpu::InstructionSchedHint schedHint) const { + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); + + auto mfmaType = cast(schedHint.getNumMMAs().getType()); + const uint32_t nPerXDL = mfmaType.getShape()[1]; + const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32; + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4; + const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ?
8 : 4; + + const auto dsReadAMfmaRate = + (mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle); + const auto dsReadBMfmaRate = + (mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle); + + const auto numDsreadAMfma = + (numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate; + const auto numDsreadBMfma = + (numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate; + + // stage 1 + const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma); + const auto num_mfma_per_issue = + numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + num_mfma_per_issue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1, + 0); + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, + num_mfma_per_issue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMfma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + dsReadAMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, InstructionKindMask::DS_READ, + numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMfma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) { + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ, + dsReadBMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, InstructionKindMask::DS_READ, + numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0); + } + createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0); + } + } LogicalResult - matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, + matchAndRewrite(amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { if (this->schedulingType == SchedulingType::UNKNOWN) { @@ -110,7 +279,8 @@ struct InstructionSchedHintsRewriter // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. 
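For readers unfamiliar with the CK-style interleaving implemented by createCKV3Schedule above, the standalone sketch below re-derives the schedule parameters from one set of hypothetical instruction counts. All numbers (8 ds_reads, 4 ds_writes, 2 buffer_loads per operand, 32 MFMAs, a 32-wide MFMA tile, 16-element ds_reads) are made up for illustration; in the pass they come from the InstCounter attributes attached to the scheduling hint op, and the derived group sizes feed the sched.group.barrier intrinsics emitted by createSchedGroupBarrier.

```cpp
#include <cstdint>
#include <cstdio>

// Standalone sketch of the interleave arithmetic used by createCKV3Schedule.
// All input counts below are hypothetical; in the pass they come from the
// InstCounter attributes on the instruction_sched_hint op.
int main() {
  const uint32_t numDsReadInstA = 8, numDsReadInstB = 8;
  const uint32_t numDsWriteInstA = 4, numDsWriteInstB = 4;
  const uint32_t numBufferLoadInstA = 2, numBufferLoadInstB = 2;
  const uint32_t numMfmaInst = 32;

  const uint32_t nPerXDL = 32;          // second dim of the MFMA tile (assumed)
  const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32;
  const uint32_t dsReadAIssueCycle = 8; // 16-element ds_read assumed
  const uint32_t dsReadBIssueCycle = 8;

  // How many ds_read instructions can be hidden behind one MFMA (ceil-div).
  const uint32_t dsReadAMfmaRate =
      (mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle);
  const uint32_t dsReadBMfmaRate =
      (mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle);

  // MFMA slots needed to cover all ds_reads of A and B (stage 2).
  const uint32_t numDsreadAMfma =
      (numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate;
  const uint32_t numDsreadBMfma =
      (numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate;

  // Remaining MFMAs are distributed across the global loads (stage 1).
  const uint32_t numMfmaStage1 =
      numMfmaInst - (numDsreadAMfma + numDsreadBMfma);
  const uint32_t numMfmaPerIssue =
      numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);
  const uint32_t numDsWritePerIssueA = numDsWriteInstA / numBufferLoadInstA;
  const uint32_t numDsWritePerIssueB = numDsWriteInstB / numBufferLoadInstB;

  std::printf("ds_read-per-MFMA rate A/B: %u/%u\n", dsReadAMfmaRate,
              dsReadBMfmaRate);
  std::printf("stage 1: %u MFMAs, %u per global load, %u/%u ds_writes per "
              "load (A/B)\n",
              numMfmaStage1, numMfmaPerIssue, numDsWritePerIssueA,
              numDsWritePerIssueB);
  std::printf("stage 2: %u (A) + %u (B) MFMA slots for ds_reads\n",
              numDsreadAMfma, numDsreadBMfma);
  return 0;
}
```

With these example counts the sketch yields a ds_read rate of 2 per MFMA, 24 stage-1 MFMAs (6 per global load, 2 ds_writes per load), and 4 + 4 stage-2 MFMA slots, which is the shape of the sched_group_barrier sequence the rewriter emits.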
const bool limitSchedulingRange = - !(schedulingType == SchedulingType::IGLP0 || + !(schedulingType == SchedulingType::NONE || + schedulingType == SchedulingType::IGLP0 || schedulingType == SchedulingType::IGLP1); Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); @@ -128,6 +298,10 @@ struct InstructionSchedHintsRewriter createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); break; } + case SchedulingType::CK_V3: { + createCKV3Schedule(rewriter, loc, instructionSchedHint); + break; + } case SchedulingType::NONE: [[fallthrough]]; default: { @@ -146,11 +320,11 @@ struct InstructionSchedHintsRewriter SchedulingType schedulingType; }; -struct LowerInstructionSchedHints - : public triton::impl::LowerInstructionSchedHintsBase< - LowerInstructionSchedHints> { +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { - explicit LowerInstructionSchedHints(std::string variant) { + explicit TritonAMDGPULowerInstructionSchedHints(std::string variant) { this->variant = variant; } @@ -160,7 +334,7 @@ struct LowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx, this->variant); @@ -172,32 +346,200 @@ struct LowerInstructionSchedHints } }; -struct InsertInstructionSchedHints - : public triton::impl::InsertInstructionSchedHintsBase< - InsertInstructionSchedHints> { +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); - mod->walk([ctx](triton::DotOp dot) { - if (dyn_cast(dot->getParentOp())) { + mod.walk([this, ctx](scf::ForOp forOp) { + triton::DotOp dot = nullptr; + size_t dotCounter = 0; + forOp->walk([&dot, &dotCounter](triton::DotOp op) { + dot = op; + ++dotCounter; + }); + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. 
+ if (dotCounter == 1) { mlir::OpBuilder rewriter(ctx); rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + rewriter.create(dot->getLoc()); + annotateDotUsageOnLoadStore(forOp); } }); } + + template bool isOf(Operation *op) const { + return llvm::isa(op); + } + + template + llvm::SmallVector getUsersOfTypes(Value value) const { + llvm::SmallVector concreteUsers; + for (auto user : value.getUsers()) { + std::vector values = {(isOf(user), ...)}; + if (llvm::any_of(values, [](bool value) { return value; })) + concreteUsers.push_back(user); + } + return concreteUsers; + } + + template + llvm::SmallVector getUsersOfType(Value value) const { + auto users = getUsersOfTypes(value); + llvm::SmallVector concreteUsers; + for (auto user : getUsersOfTypes(value)) { + concreteUsers.push_back(cast(user)); + } + return concreteUsers; + } + + // Go through a single use chain of `convert_layout` and/or `fp_to_fp` Ops to + // get the final value after all conversions + Value rewindUnaryOps(Value value) const { + auto unaryOps = + getUsersOfTypes(value); + while (!unaryOps.empty()) { + assert(unaryOps.size() == 1); + value = unaryOps[0]->getResult(0); + unaryOps = + getUsersOfTypes( + value); + } + return value; + } + + // Given a `scf::ForOp`, the method finds and annotates all Ops which produce + // input values for the `tt.dot` operation. The algorithm handles software + // pipelining. Therefore, we start by tracking `tt.load` ops and unwind the + // data flow by looking up to the yielded values and iteration arguments of a + // given `scf::ForOp` till we find `ttg.local_store` Op. Once a + // `ttg.local_store` Op is found, we need a single yielded-arguments lookup to + // find the corresponding `ttg.local_load` Op from which we have a direct data + // flow path to the target `tt.dot` op. At this point, we can annotate all + // found Ops (i.e., `tt.load`, `ttg.local_store`) with the input argument + // index of the data to `tt.dot` Op. Here is an example of the resulting + // annotated TTGIR: + // + // %13:8 = scf.for %arg11 = %c0_i32 to %0 step %c1_i32 iter_args( + // %arg0 = %cst_1, %arg1 = %in_0, %arg2 = %in_1, %arg3 = %c0_i32, + // %arg4 = %in_2, %arg5 = %in_3, %arg6 = %in_4, %arg7 = %in_5) + // -> (...) : i32 { + // %1 = triton_gpu.local_load %arg4 : {OpIdx = 0} + // %2 = triton_gpu.local_load %arg5 : {OpIdx = 1} + // %3 = tt.dot %1, %2, %arg0 + // %4 = tt.addptr %arg1, %cst + // %5 = tt.addptr %arg2, %cst_0 + // %6 = tt.load %4 : {OpIdx = 0} + // %7 = tt.load %5 : {OpIdx = 1} + // %8 = arith.addi %arg3, %c1_i32 + // %9 = arith.cmpi slt, %8, %c2_i32 + // %10 = arith.select %9, %8, %c0_i32 + // %11 = triton_gpu.memdesc_subview %56[%10, %c0_i32, %c0_i32] + // triton_gpu.local_store %arg6, %11 : {OpIdx = 0} + // %12 = triton_gpu.memdesc_subview %57[%10, %c0_i32, %c0_i32] + // triton_gpu.local_store %arg7, %12 : {OpIdx = 1} + // scf.yield %3, %4, %5, %10, %11, %12, %6, %7 : (...) + // } + // + // Note, this is required for counting issued `llvm` instructions during + // lowering from TTGIR to LLVM dialects to perform advanced instruction + // scheduling. 
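The annotation routine below relies on the getUsersOfTypes/getUsersOfType helpers defined above, whose template parameter lists are hard to read in this rendering. As a side note, here is a minimal standalone C++17 sketch of the same "filter users by a variadic list of op types" pattern; Node, LoadNode and StoreNode are placeholder classes standing in for MLIR operations and are not part of this change.

```cpp
#include <iostream>
#include <vector>

// Placeholder "op" hierarchy; dynamic_cast stands in for llvm::isa.
struct Node { virtual ~Node() = default; };
struct LoadNode : Node {};
struct StoreNode : Node {};
struct AddNode : Node {};

template <typename T> bool isOf(const Node *n) {
  return dynamic_cast<const T *>(n) != nullptr;
}

// C++17 fold expression: true if the node matches any of the listed types,
// analogous to collecting users of several op kinds in the pass above.
template <typename... Ts> bool isAnyOf(const Node *n) {
  return (isOf<Ts>(n) || ...);
}

template <typename... Ts>
std::vector<const Node *> filterUsers(const std::vector<const Node *> &users) {
  std::vector<const Node *> matched;
  for (const Node *user : users)
    if (isAnyOf<Ts...>(user))
      matched.push_back(user);
  return matched;
}

int main() {
  LoadNode load;
  StoreNode store;
  AddNode add;
  std::vector<const Node *> users = {&load, &store, &add};
  // Keep only the "memory" users, as the pass does for local_load/local_store.
  auto memUsers = filterUsers<LoadNode, StoreNode>(users);
  std::cout << memUsers.size() << " matching users\n"; // prints: 2 matching users
  return 0;
}
```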
+ void annotateDotUsageOnLoadStore(scf::ForOp forOp) const { + llvm::SmallVector loadOps; + forOp.walk( + [&loadOps](triton::LoadOp loadOp) { loadOps.push_back(loadOp); }); + + ValueRange yieldedValues = forOp.getYieldedValues(); + auto initArgs = forOp.getRegionIterArgs(); + + MLIRContext *ctx = forOp->getContext(); + mlir::OpBuilder rewriter(ctx); + + for (auto loadOp : loadOps) { + Value loadResult = loadOp.getResult(); + + // Unwind till the first carried loop iteration regarding `tt.load`. + Value loopCarriedLoadValue = loadResult; + bool foundFirstCarriedLoopIteration = false; + while (!foundFirstCarriedLoopIteration) { + auto it = llvm::find(yieldedValues, loopCarriedLoadValue); + if (it != yieldedValues.end()) { + size_t idx = std::distance(yieldedValues.begin(), it); + loopCarriedLoadValue = initArgs[idx]; + } else { + foundFirstCarriedLoopIteration = true; + } + } + + loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + assert(loopCarriedLoadValue.hasOneUse()); + + // Handle pipelining - i.e., `local_store`, `memdesc_subview`, + // `local_load` ops. + triton::gpu::LocalLoadOp localLoadOp = nullptr; + auto loadOpUser = *(loopCarriedLoadValue.user_begin()); + auto localStoreOp = llvm::dyn_cast(loadOpUser); + if (localStoreOp) { + auto subviewOp = localStoreOp.getDst() + .getDefiningOp(); + Value subviewResult = subviewOp.getResult(); + auto it = llvm::find(yieldedValues, subviewResult); + if (it != yieldedValues.end()) { + size_t idx = std::distance(yieldedValues.begin(), it); + Value loopCarriedSubviewValue = initArgs[idx]; + + auto subviewLoadOps = + getUsersOfType(loopCarriedSubviewValue); + assert(subviewLoadOps.size() == 1); + localLoadOp = *subviewLoadOps.begin(); + + loopCarriedLoadValue = localLoadOp.getResult(); + } else { + auto localLoadOps = + getUsersOfType(subviewResult); + assert(localLoadOps.size() == 1); + localLoadOp = *localLoadOps.begin(); + auto it = llvm::find(yieldedValues, localLoadOp.getResult()); + assert(it != yieldedValues.end()); + size_t idx = std::distance(yieldedValues.begin(), it); + loopCarriedLoadValue = initArgs[idx]; + } + loopCarriedLoadValue = rewindUnaryOps(loopCarriedLoadValue); + } + + // Find the corresponding `DotOp`. + auto dots = getUsersOfType(loopCarriedLoadValue); + assert(dots.size() == 1); + + // Find which `DotOp` argument the current `loadOp` belongs to. + auto dotOperands = dots.begin()->getOperands(); + auto it = llvm::find(dotOperands, loopCarriedLoadValue); + assert(it != dotOperands.end()); + size_t opIdx = std::distance(dotOperands.begin(), it); + + // Set `OpIdx` attributes. 
+ auto opIdxAttr = amdgpu::OpIdxAttr::get(ctx, opIdx); + + loadOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + if (localStoreOp) + localStoreOp->setAttr(amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } }; } // namespace namespace mlir::triton { std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant) { + return std::make_unique(variant); } std::unique_ptr> -createInsertInstructionSchedHintsPass() { - return std::make_unique(); +createTritonAMDGPUInsertInstructionSchedHintsPass() { + return std::make_unique(); } } // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 0000000000..6b81dd0ab2 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H + +#include "mlir/IR/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during the conversion/lowering to LLVM to facilitate instruction scheduling +// controls. +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); +void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +} // namespace mlir::triton + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index aa71c92666..b7c35e5532 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -20,6 +21,7 @@ #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -72,8 +74,9 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -193,8 +196,12 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, commonBenefit); + + mlir::triton::BackendCallbacks callbacks; + callbacks.localStoreOpConversion = storeOpConversionCallback; + + mlir::triton::populateMemoryOpToLLVMPattern( + typeConverter, targetInfo, patterns, commonBenefit, callbacks);
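As a usage note for the callback wiring above: since storeOpConversionCallback is a free function with the (LocalStoreOp, size_t, Type) shape, a backend could equally plug in an ad-hoc lambda, for example to inspect the collected counts while tuning. The helper below is a hypothetical sketch, not part of this change; the function name and remark text are illustrative, it presumes the same headers as TritonGPUToLLVM.cpp above, and it assumes the localStoreOpConversion member accepts any callable with that signature.

```cpp
// Hypothetical illustration only: wire a logging lambda into the memory-op
// lowering instead of storeOpConversionCallback.
static void populateMemoryPatternsWithLogging(
    mlir::LLVMTypeConverter &typeConverter,
    const mlir::triton::TargetInfoBase &targetInfo,
    mlir::RewritePatternSet &patterns, mlir::PatternBenefit benefit) {
  mlir::triton::BackendCallbacks callbacks;
  callbacks.localStoreOpConversion = [](mlir::triton::gpu::LocalStoreOp op,
                                        size_t llvmOpCount, mlir::Type vecTy) {
    // Surface the number of emitted LLVM stores and their vector type as a
    // compiler remark; handy when cross-checking the sched-hint counters.
    op.emitRemark() << "lowered to " << llvmOpCount << " stores of " << vecTy;
  };
  mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                              patterns, benefit, callbacks);
}
```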
mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 1f9be3a428..9277940aa5 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -61,6 +61,14 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + class TritonAMDGPUReorderInstructionsPass : public TritonAMDGPUReorderInstructionsBase< TritonAMDGPUReorderInstructionsPass> { @@ -101,19 +109,28 @@ class TritonAMDGPUReorderInstructionsPass kv.first->moveBefore(kv.second); opToMove.clear(); - // Move writing to LDS and reading from LDS right after the loading of a - // tensor from global memory. There are 2 possible patterns depending on - // whether writing to LDS is done using an optional local_alloc argument or - // a local_store instruction: + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in the case where the LDS write is in the + // loop but its operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. // - // 1) %1 = load %ptr + // clang-format off + // + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) // %2 = local_alloc %1 // %3 = local_load %2 // - // 2) %1 = load %ptr + // 2) %1 = some_op ... // %2 = local_alloc // %3 = local_store %1, %2 // %4 = local_load %2 + // + // clang-format on m.walk([&](ttg::LocalLoadOp localLoad) { auto localAlloc = localLoad.getSrc().getDefiningOp(); if (!localAlloc) return; @@ -123,10 +140,15 @@ class TritonAMDGPUReorderInstructionsPass if (localAlloc->getNumOperands() == 1) { if (!localAlloc->hasOneUse()) return; - auto loadOp = localAlloc->getOperand(0).getDefiningOp(); - if (!loadOp) + + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but its src tensor's defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localLoad->moveAfter(localAlloc); return; } @@ -145,10 +167,14 @@ class TritonAMDGPUReorderInstructionsPass if (!isa(localStore)) return; - auto loadOp = localStore->getOperand(0).getDefiningOp(); - if (!loadOp) + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but its src tensor's defining op is + // outside of it.
+ if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localStore->moveAfter(localAlloc); localLoad->moveAfter(localStore); }); @@ -221,6 +247,78 @@ class TritonAMDGPUReorderInstructionsPass dfgop->moveBefore(block, block->begin()); } } + + /** + * Sched-load optimization for matmul kernels with large tile sizes + * The basic idea of sched-load optimization is to sink the 2nd tt.load + * after local_load so that global_load instructions can be interleaved with + * mfma's. This can help hide the issue latency of global_load instructions + * and improve performance on MI300X. + * + * It's assumed that the IR before this optimization has the following + * structure: + * ```mlir + * scf.for .. + * { + * tileA = tt.load a_ptr + * tileB = tt.load b_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * After this optimization, the IR is transformed to + * ```mlir + * scf.for .. + * { + * tileA = tt.load a_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * tileB = tt.load b_ptr <-- 2nd tt.load is sunk here + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * For now, we don't have a perfect heuristic for when this optimization + * should be applied. Therefore, we implement a simple heuristic: the + * optimization is applied when the tile sizes of A and B are large enough, + * i.e. nonKDim >= 128 and kDim >= 64. Also, it is only applied to typical + * matmul kernels, i.e. only two tt.load ops and one dotOp inside the loop. We + * are experimenting with how to better control instruction scheduling and + * enable such optimizations. + */ + m.walk([&](scf::ForOp forOp) -> void { + SetVector loadOps; + triton::DotOp dotOp; + int nDotOps = 0; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) { + nDotOps++; + dotOp = curOp; + } + } + // Only apply the optimization when there are 2 loads and 1 dot in the + // loop + if (loadOps.size() != 2 || nDotOps != 1) + return; + // Only apply the optimization when the tile size is large enough + // 1. nonKDim >= 128 + // 2. kDim >= 64 + auto ldAOp = dyn_cast(loadOps[0]); + auto tileAShape = cast(ldAOp.getType()).getShape(); + auto ldBOp = dyn_cast(loadOps[1]); + auto tileBShape = cast(ldBOp.getType()).getShape(); + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && + tileBShape[1] >= 128)) + return; + // move ldBOp right before tt.dot + loadOps[1]->moveBefore(dotOp); + }); } }; diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 5b5cca5b05..c481878eb6 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -45,11 +45,11 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createConvertBuiltinFuncToLLVMPass()); }); m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { - pm.addPass(createInsertInstructionSchedHintsPass()); + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createLowerInstructionSchedHintsPass(variant)); + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) {