diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index b6bf5b3bea..12b660f5c2 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -26,7 +26,8 @@ The encoding is characterized by parameters:
   - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
   - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
   - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
-  - `sugGroupSize` Currently only sub group size 16 is supported.
+  - `threadsPerWarp_` AKA threadsPerWarp. The name conflicts with `getThreadsPerWarp` in the DistributedLayout interface,
+    so we use `threadsPerWarp_` here. Currently only 16 is supported.
 
 The values of the matrix is distributed across the threads in the subgroup as row-major order.
   - If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single rows of the matrix.
@@ -34,7 +35,7 @@ The values of the matrix is distributed across the threads in the subgroup as ro
   - If the column size of the matrix is larger than the number of the threads in the subgroup, a single row of the matrix requires multiple value name.
 
 Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
 
 The layout for A operand:
                    K = 16 (K = systolic depth * opsPerChan)
@@ -83,7 +84,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
 v
 Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
 
 The layout for A operand:
                    K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +103,7 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
 The layouts for C and D operands are same as the one of opsPerChan=2.
 
 Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.
 
 The layout for A operand:
                    K = 32 (K = systolic depth * opsPerChan)
@@ -121,15 +122,21 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
 The layouts for C and D operands are same as the one of opsPerChan=2.
 
 The patterns (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
+along the row (resp. col) dimension. The repetitions are clustered to the size of repCluster to optimize memory access.
 
-Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
-The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
-
-             warp[:0]  warp[:1]  warp[:0]  warp[:1]
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16. Its input tensor layouts are defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` imply the following GEMM tiling configuration:
+
+             warp[:,0]  warp[:,1]  warp[:,0]  warp[:,1]
            |----^----|----^----|----^----|----^----|
-  repCluster[1]
+   repCluster[1]
    <--------->
   ┌────┬────┬────┬────┬────┬────┬────┬────┐
   │R0  │R1  │    │    │R4  │R5  │    │    │
@@ -142,25 +149,25 @@ The DPAS repetitions are distributed as follows:
                 -        ^   ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
                 |        |   │R0  │R2  │   │R0  │R1  │    │    │R4  │R5  │    │    │
                 |        |   │    │    │   │    │    │    │    │    │    │    │    │
- warp[0:] < repCluster[0] |  ]────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+ warp[0,:] < repCluster[0] | ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |        |   │R1  │R3  │   │R2  │R3  │    │    │R6  │R7  │    │    │
                 |        |   │    │    │   │    │    │    │    │    │    │    │    │
                 -        v   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
- warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+ warp[1,:] <                 ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 -            ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │R4  │R6  │   │R8  │R9  │    │    │R12 │R13 │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
- warp[0:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+ warp[0,:] <                 ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │R5  │R7  │   │R10 │R11 │    │    │R14 │R15 │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 -            ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
- warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+ warp[1,:] <                 ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 |            │    │    │   │    │    │    │    │    │    │    │    │
                 -            └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
@@ -175,7 +182,7 @@ The DPAS repetitions are distributed as follows:
     "unsigned":$opsPerChannel,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     ArrayRefParameter<"unsigned">:$repCluster,
-    "unsigned":$subGroupSize
+    "unsigned":$threadsPerWarp_
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index baa0e3e347..92857c2c58 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector DpasEncodingAttr::getShapeC() const {
 SmallVector DpasEncodingAttr::getSizePerThread() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector res(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
-  auto threadsPerWar = getSubGroupSize();
+  auto threadsPerWar = getThreadsPerWarp_();
   size_t rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector DpasEncodingAttr::getThreadsPerWarp() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector res(rank, 1);
   auto executionSize = getExecutionSize();
-  auto subGroupSize = getSubGroupSize();
+  auto subGroupSize = getThreadsPerWarp_();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     SmallVector shapeA = getDPASInstShapeA();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp_();
     unsigned opsPerChannel = getOpsPerChannel();
 
     // pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
 
   if (opIdx == 1) {
     auto shapeB = getShapeB();
-    auto subGroupSize = getSubGroupSize();
+    auto subGroupSize = getThreadsPerWarp_();
     auto executionSize = getExecutionSize();
     if (subGroupSize < executionSize) {
       llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector DpasEncodingAttr::getContigPerThread() {
   assert(rank == 2 || rank == 3);
   SmallVector contigPerThread(rank, 1);
 
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto instShapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] -> int
   // a[N][threadsPerWarp]
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
           << "opsPerChan = " << getOpsPerChannel() << ", "
-          << "threadsPerWarp = " << getSubGroupSize() << ", "
+          << "threadsPerWarp = " << getThreadsPerWarp_() << ", "
          << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], "
          << "repCluster = [" << repCluster << "], "
          << "A = [" << rA << "], "
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index c20258b263..5991ddcd82 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
       size_t totalElems = elems.size();
       auto numElemsPerOperand = product(dpasLayout.getDPASInstShapeC()) /
-                                dpasLayout.getSubGroupSize();
+                                dpasLayout.getThreadsPerWarp_();
       Type elemTy =
           this->getTypeConverter()->convertType(srcType.getElementType());
       VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
index ecd8eb1140..7677d648b8 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper {
     Type i16Ty = type::i16Ty(ctx);
     Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);
 
-    unsigned threadsPerWarp = layout.getSubGroupSize();
+    unsigned threadsPerWarp = layout.getThreadsPerWarp_();
     unsigned opsPerChannel = layout.getOpsPerChannel();
     SmallVector shapeC = layout.getDPASInstShapeC();
     unsigned elemNumC = product(shapeC) / threadsPerWarp;
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index 2160b8f17d..44a30a41a4 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -168,7 +168,7 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
       sizePerThreads[rank - 2] / repCluster[rank - 2],
       sizePerThreads[rank - 1] / repCluster[rank - 1]};
 
-  unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+  unsigned rowsPerElem = dpasLayout.getThreadsPerWarp_() / instShapeC[1];
   unsigned colsPerElem = 1;
 
   unsigned repNumber = product(repCluster);
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 914e851b70..940586b698 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -176,10 +176,10 @@ struct DpasOperandPattern final : OpRewritePattern {
     // We want to transpose matrices of N*threads_per_warpxthreads_per_warp
    // shape.
    if ( // X axis condition
-        encoding.getExecutionSize() != encoding.getSubGroupSize() ||
+        encoding.getExecutionSize() != encoding.getThreadsPerWarp_() ||
        // Y axis conditions
        (encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
-                encoding.getSubGroupSize() !=
+                encoding.getThreadsPerWarp_() !=
            0)
      return failure();
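
For out-of-tree users of `DpasEncodingAttr`, the change is a mechanical accessor rename: every `getSubGroupSize()` call becomes `getThreadsPerWarp_()`, since TableGen derives the accessor name from the new `threadsPerWarp_` parameter. A minimal sketch of a downstream helper written against the renamed accessor is shown below; the function name `elemsPerLaneForC` and the namespace alias are hypothetical, while the accessors and the `product()` helper are the ones already used in `Dialect.cpp` above.

```cpp
// Hypothetical downstream helper -- illustration only, not part of this patch.
// Assumes the same TritonIntelGPU dialect headers and helpers as Dialect.cpp.
namespace ttgi = mlir::triton::gpu::intel; // assumed C++ namespace of DpasEncodingAttr

// Number of C-tile elements owned by one lane for a single DPAS instruction,
// mirroring the arithmetic in DpasEncodingAttr::getSizePerThread().
static unsigned elemsPerLaneForC(ttgi::DpasEncodingAttr layout) {
  // Before this patch: layout.getSubGroupSize().
  unsigned threadsPerWarp = layout.getThreadsPerWarp_();
  // One DPAS instruction produces a repeatCount x executionSize C tile,
  // distributed row-major across the sub-group.
  return product(layout.getDPASInstShapeC()) / threadsPerWarp;
}
```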
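The repetition counts quoted for the documentation example (A=8, B=8, C/D=16 DPAS tiles per warp for the [64, 128] = [64, 32] * [32, 128] block with warpsPerCTA = [2, 2]) can be checked with a few lines of arithmetic. The sketch below is a self-contained illustration of that bookkeeping under the stated parameters; it is not repository code, and the distribution formulas (A replicated across the warps along N, B across the warps along M) are assumptions spelled out in the comments that happen to reproduce the numbers in the text.

```cpp
// Standalone arithmetic check of the tiling example above -- not repository code.
#include <cassert>

int main() {
  const unsigned repeatCount = 8, systolicDepth = 8, executionSize = 16,
                 opsPerChan = 2;
  const unsigned warpsPerCTA[2] = {2, 2};
  const unsigned M = 64, N = 128, K = 32; // [64, 128] = [64, 32] * [32, 128]

  // One DPAS instruction: C tile is repeatCount x executionSize,
  // A tile is repeatCount x (systolicDepth * opsPerChan),
  // B tile is (systolicDepth * opsPerChan) x executionSize.
  const unsigned instM = repeatCount;                // 8
  const unsigned instN = executionSize;              // 16
  const unsigned instK = systolicDepth * opsPerChan; // 16

  // Assumption: A is shared by the warps along N, B by the warps along M.
  const unsigned repA = (M / (instM * warpsPerCTA[0])) * (K / instK); // 4 * 2
  const unsigned repB = (K / instK) * (N / (instN * warpsPerCTA[1])); // 2 * 4
  const unsigned repC = (M / (instM * warpsPerCTA[0])) *
                        (N / (instN * warpsPerCTA[1]));               // 4 * 4

  assert(repA == 8 && repB == 8 && repC == 16); // A=8, B=8, C,D=16
  return 0;
}
```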