diff --git a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h index 45ea0a4564..c7d980abbd 100644 --- a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h @@ -30,7 +30,9 @@ std::unique_ptr createTritonIntelGPUDistributeToWarpsPass(); std::unique_ptr createTritonIntelGPURemoveLayoutConversionsPass(); -std::unique_ptr createTritonIntelGPURewriteTensorPointerPass(); +std::unique_ptr createTritonIntelGPURewriteTensorPointerPass( + triton::gpu::intel::DeviceArch arch = + triton::gpu::intel::DeviceArch::UNKNOWN); std::unique_ptr createPrefetchBlockPass(); diff --git a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td index c9e27bb985..3139170d7f 100644 --- a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -151,16 +151,28 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c } def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> { - let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let summary = "Rewrite load/store operations using tensor pointers that cannot be lowered to 2D Block Load/Store intrinsics"; let description = [{ - This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy - semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute - the pointer/mask/other for each load/store. + This pass determines whether a load/store operation can be lowered to a 2D + Block Load/Store intrinsic. If it cannot, it rewrites the load/store + operation to use a legacy pointer and removes the Triton operations that + create and advance the block pointer (that is, `tt.make_tensor_ptr` and + `tt.advance`).
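+    A block pointer is kept only when the 2D block I/O restrictions encoded in
+    this pass hold: the target architecture is PVC, the pointee tensor carries
+    a dot operand layout with a DPAS parent, the fastest-changing stride is 1
+    (row-major access), and the pitch satisfies
+    strides[order[1]] % (64 / elemTypeBitWidth) == 0. For example, with f16
+    elements (16 bits) the pitch must be divisible by 64 / 16 = 4 elements.
+    (Store-side block pointers are currently always rewritten.)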
}]; let constructor = "mlir::triton::gpu::intel::createTritonIntelGPURewriteTensorPointerPass()"; let dependentDialects = ["mlir::triton::TritonDialect"]; + + let options = [ + Option<"deviceArch", "device-architecture", + "mlir::triton::gpu::intel::DeviceArch", /*default*/" mlir::triton::gpu::intel::DeviceArch::PVC", + "device architecture", + "llvm::cl::values(" + "clEnumValN(mlir::triton::gpu::intel::DeviceArch::UNKNOWN, \"UNKNOWN\", \"Unknown arch\"), " + "clEnumValN(mlir::triton::gpu::intel::DeviceArch::ATS, \"ATS\", \"ATS arch\"), " + "clEnumValN(mlir::triton::gpu::intel::DeviceArch::PVC, \"PVC\", \"PVC arch\"))"> + ]; } def TritonIntelGPUPrefetchBlock : Pass<"tritonintelgpu-prefetch-block", "mlir::ModuleOp"> { diff --git a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir index d3ac3e48d6..ca0b1e7a35 100644 --- a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir +++ b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir @@ -1,4 +1,142 @@ -// RUN: triton-opt %s -tritonintelgpu-rewrite-tensor-pointer | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonintelgpu-rewrite-tensor-pointer=device-architecture=PVC | FileCheck %s + +// COM: Case 1: +// COM: Check that operations using block pointers satisfying the following conditions are not rewritten: +// COM: - the block pointer has the "dot" layout attribute (with dpas parent layout) +// COM: - the block pointers is advanced in row major order: strides[order[0]] == 1 +// COM: - the block pointer pitch is divisible by QW: strides[order[1]] % (64 / elemTypeBitWidth) == 0 +// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> +// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}> +module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + // CHECK: @matmul_kernel_with_block_pointers + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c255_i32 = arith.constant 255 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c255_i32 : i32 + %2 = arith.divsi %1, %c256_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c4_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c4_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c4_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, 
%10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.muli %11, %c256_i32 : i32 + %15 = arith.extsi %arg3 : i32 to i64 + %16 = arith.extsi %arg5 : i32 to i64 + %17 = arith.extsi %arg6 : i32 to i64 + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array} : >> + %19 = arith.muli %13, %c256_i32 : i32 + %20 = arith.extsi %arg4 : i32 to i64 + %21 = arith.extsi %arg7 : i32 to i64 + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : >> + %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> + // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> + %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>> + %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>> -> tensor<256x256xf32, #[[DPAS]]> + // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>> -> tensor<256x256xf32, #dpas> + %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : >> + %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : >> + scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>> + } + %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> + %25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, #blocked> + %26 = arith.extsi %arg8 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > + // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr, #[[BLOCKED]]> + tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Case 2: +// COM: Check that operations using block pointers without divisibility attribute are rewritten to use a legacy pointer. 
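+// COM: Here the stride operands %arg6 and %arg7 carry no `tt.divisibility` hint, so the pitch
+// COM: condition strides[order[1]] % (64 / elemTypeBitWidth) == 0 cannot be proven; the pass
+// COM: therefore falls back to a tensor of plain pointers with an explicit mask (see the CHECKs below).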
+// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}> +module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func public @matmul_kernel_with_block_pointers_indivisible(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}) { + // CHECK: @matmul_kernel_with_block_pointers_indivisible + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c255_i32 = arith.constant 255 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c255_i32 : i32 + %2 = arith.divsi %1, %c256_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c4_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c4_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c4_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.muli %11, %c256_i32 : i32 + %15 = arith.extsi %arg3 : i32 to i64 + %16 = arith.extsi %arg5 : i32 to i64 + %17 = arith.extsi %arg6 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array} : >> + %19 = arith.muli %13, %c256_i32 : i32 + %20 = arith.extsi %arg4 : i32 to i64 + %21 = arith.extsi %arg7 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : >> + %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK: tt.load {{.*}}, {{.*}} : tensor<256x32x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>> + // CHECK: tt.load {{.*}}, {{.*}} : tensor<32x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>> + %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>> + %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>> + %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>> -> tensor<256x256xf32, #dpas> + // CHECK-NOT: tt.advance + %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : >> + // CHECK-NOT: tt.advance + %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : >> + scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>> + } + %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> + %25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, 
#blocked> + %26 = arith.extsi %arg8 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > + // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr, #[[BLOCKED]]> + tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Case 3: +// COM: Check that operations using block pointers without a layout attribute are rewritten to use a legacy pointer. tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { %c31_i32 = arith.constant 31 : i32 %c127_i32 = arith.constant 127 : i32 diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp index dcbbabbccb..ebc39612d2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Pass/Pass.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonIntelGPU/IR/Dialect.h" #include "triton/Dialect/TritonIntelGPU/Transforms/Passes.h" @@ -17,6 +18,7 @@ using namespace mlir; namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; namespace ttgi = mlir::triton::gpu::intel; #define GEN_PASS_CLASSES @@ -24,39 +26,108 @@ namespace ttgi = mlir::triton::gpu::intel; namespace { -/// An additional struct to record the meta information of operations -/// with tensor pointers -struct RewritedInfo { -private: - Value base; - SmallVector shape; - SmallVector strides; - SmallVector offsets; - ArrayRef tensorShape; +/// Check if given value is divisible by the divisor. +bool isDivisible(Value value, unsigned divisor) { + // Case 1: Value is defined by a constant operation + if (auto constantOp = value.getDefiningOp()) { + auto integerAttr = dyn_cast(constantOp.getValue()); + return integerAttr && integerAttr.getValue().getZExtValue() % divisor == 0; + } - // A cache to avoid generating the same offset with range - DenseMap cachedOffsetWithRange; + // Case 2: Value is a block argument of the entry block + if (value.getParentBlock()->isEntryBlock() && isa(value)) { + BlockArgument blockArg = cast(value); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto funcOp = dyn_cast(parentOp)) { + auto divisibilityAttr = funcOp.getArgAttrOfType( + blockArg.getArgNumber(), "tt.divisibility"); + return divisibilityAttr && + divisibilityAttr.getValue().getZExtValue() % divisor == 0; + } + } -public: - RewritedInfo() = default; + // Case 3: Value is defined by a sign extension operation + if (auto extSIOp = value.getDefiningOp()) + return isDivisible(extSIOp->getOperand(0), divisor); + + return false; +} + +/// Check if the tensor pointer should be removed. 
The tensor pointer should be +/// removed if: +/// - the device architecture is not PVC +/// - the tensor pointer does not have DpasEncodingAttr +/// - the tensor pointer pitch is not divisible by Qword bitwidth +/// - the tensor pointer is not contiguous on memory +bool shouldRemove(tt::MakeTensorPtrOp &op, ttgi::DeviceArch deviceArch) { + // Non-PVC device should always remove the tensor pointer + if (deviceArch != ttgi::DeviceArch::PVC) + return true; + + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Only keep the tensor pointer with the layout of DpasEncodingAttr + if (!tensorType.getEncoding()) + return true; + auto dotLayout = + dyn_cast(tensorType.getEncoding()); + if (!dotLayout) + return true; + auto dpasLayout = dyn_cast(dotLayout.getParent()); + if (!dpasLayout) + return true; + + TypedValue base = op.getBase(); + Operation::operand_range shape = op.getShape(); + Operation::operand_range strides = op.getStrides(); + Operation::operand_range offsets = op.getOffsets(); + ArrayRef order = op.getOrder(); + ArrayRef tensorShape = tensorType.getShape(); + + // TODO: support column-major tensor + // HW 2D block read instruction has restriction on pitch divisibility + if (strides.size() == 2) { + auto pitch = strides[order[1]]; + // PVC requires pitch to be a multiple of QWord(64 bits). + if (!isDivisible(pitch, 64 / tensorType.getElementTypeBitWidth())) + return true; + } + + // HW 2D block read instruction only supports contiguous accessing. + auto fastChangeStride = strides[order[0]]; + if (auto stride = + dyn_cast(fastChangeStride.getDefiningOp())) { + if (auto strideInt = dyn_cast(stride.getValue())) + return strideInt.getInt() != 1; + } - RewritedInfo(const RewritedInfo &other) = default; + return true; +} + +/// The `RewritedInfo` struct is used to store information about a rewritten +/// tensor pointer. It holds the base pointer, shape, strides, offsets, and +/// encoding of the tensor. This information is used later in the code to handle +/// the rewritten tensor pointer. +struct RewritedInfo { + RewritedInfo() = default; RewritedInfo(Value base, const SmallVector &shape, const SmallVector &strides, const SmallVector &offsets, - const ArrayRef &tensorShape) + const ArrayRef &tensorShape, Attribute layout) : base(base), shape(shape), strides(strides), offsets(offsets), - tensorShape(tensorShape) { + tensorShape(tensorShape), layout(layout) { assert(shape.size() == strides.size() && shape.size() == offsets.size() && - shape.size() == tensorShape.size()); + shape.size() == tensorShape.size() && + "Expecting tensor shape, offsets and strides have the same size"); } unsigned int length() const { return shape.size(); } - Value getOffset(unsigned i) { return offsets[i]; } + Value getOffset(unsigned i) const { return offsets[i]; } - SmallVector getOffsets() { return offsets; } + SmallVector getOffsets() const { return offsets; } void setOffset(unsigned i, Value newOffset) { offsets[i] = newOffset; @@ -68,16 +139,54 @@ struct RewritedInfo { cachedOffsetWithRange.clear(); } - Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + void setEncoding(Attribute newLayout) { layout = newLayout; } + + // Creates a tensor with the values [0, tensorShape[axis]) + offsets[axis] + // broadcasted to N dimensions along axis (i.e. so that + // result[.., i, ...] = offsets[axis] + i). 
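+  //
+  // Concretely, for tensorShape = [256, 32] and axis i = 1 the result is a
+  // 256x32 tensor with result[r, c] = offsets[1] + c for every row r. When a
+  // layout is attached, the intermediate 1-D tensor uses the corresponding
+  // slice of that layout so the expand_dims/broadcast chain below stays
+  // layout-consistent.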
+ Value getExpandedOffsetWithRange(OpBuilder &builder, Location loc, unsigned i) { if (cachedOffsetWithRange.count(i)) - return cachedOffsetWithRange[i]; + return cachedOffsetWithRange.at(i); + + // Ultimately this will look like: + // + // % base = create_range ... : tensor + // %a0 = expand_dims %base : tensor + // %a1 = broadcast %a0 : tensor + // %b0 = expand_dims %a1 : tensor + // %b1 = broadcast %b1 : tensor + // ... + // + // The final result has layout this->layout. When we subtract a dim, + // that's equivalent to taking a sliced layout, so e.g. the layout of + // %a0/%a1 is a slice of %b0/%b1's layout. + size_t rank = tensorShape.size(); + MLIRContext *ctx = loc.getContext(); + + // This code is creating a vector of layout attributes for a tensor. If a + // layout is provided, it sets the layout of each axis based on the layout + // of the previous axis, starting from the last axis and moving towards the + // first. If the current axis is the one to remove, it skips it and moves to + // the previous axis. + SmallVector layouts(rank); + if (layout) { + layouts[rank - 1] = layout; + size_t axisToRemove = rank - 1; + for (int64_t k = rank - 2; k >= 0; --k) { + if (axisToRemove == i) + --axisToRemove; + layouts[k] = + ttg::SliceEncodingAttr::get(ctx, axisToRemove, layouts[k + 1]); + --axisToRemove; + } + } // Add range - auto indexI32RowType = - RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); - auto indexRowType = - RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + auto indexI32RowType = RankedTensorType::get( + {tensorShape[i]}, builder.getI32Type(), layouts[0]); + auto indexRowType = RankedTensorType::get({tensorShape[i]}, + builder.getI64Type(), layouts[0]); Value splatOffset = builder.create(loc, indexRowType, offsets[i]); Value range = builder.create(loc, indexI32RowType, 0, @@ -98,11 +207,12 @@ struct RewritedInfo { Value generatePtr(OpBuilder &builder, const Location &loc) { assert(tensorShape.size() == offsets.size() && - tensorShape.size() == strides.size()); + tensorShape.size() == strides.size() && + "Expecting tensor shape, offsets and strides have the same size"); auto indexTensorType = - RankedTensorType::get(tensorShape, builder.getI64Type()); - auto ptrType = base.getType().cast(); - auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + RankedTensorType::get(tensorShape, builder.getI64Type(), layout); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType, layout); // Generate offsets per dimension Value ptr = builder.create(loc, ptrTensorType, base); @@ -132,7 +242,7 @@ struct RewritedInfo { // Generate mask per dimension auto maskTensorType = - RankedTensorType::get(tensorShape, builder.getI1Type()); + RankedTensorType::get(tensorShape, builder.getI1Type(), layout); Value mask; for (auto i : boundaryCheck.value()) { auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); @@ -168,25 +278,27 @@ struct RewritedInfo { } Value generateOther(OpBuilder &builder, const Location &loc, - const std::optional &padding) { + const std::optional &padding) const { if (!padding.has_value()) return Value(); // Create element attribute - auto elementType = base.getType().cast().getPointeeType(); - auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + auto elementType = cast(base.getType()).getPointeeType(); + auto otherTensorType = + RankedTensorType::get(tensorShape, elementType, layout); // Set zero padding value TypedAttr attr = elementType.isIntOrIndex() - 
? builder.getIntegerAttr(elementType, 0).cast() - : builder.getFloatAttr(elementType, 0).cast(); + ? cast(builder.getIntegerAttr(elementType, 0)) + : cast(builder.getFloatAttr(elementType, 0)); // Float NaN padding case if (padding.value() == tt::PaddingOption::PAD_NAN) { - assert(!elementType.isIntOrIndex()); + assert(!elementType.isIntOrIndex() && + "Expect element type to be non-integer type"); auto apNaN = llvm::APFloat::getNaN( - attr.cast().getValue().getSemantics()); + cast(attr).getValue().getSemantics()); attr = builder.getFloatAttr(elementType, apNaN); } @@ -194,32 +306,48 @@ struct RewritedInfo { Value constant = builder.create(loc, attr); return builder.create(loc, otherTensorType, constant); } -}; +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + Attribute layout; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; +}; } // namespace // TODO: this pass relies on assumptions of how block pointers are created and // on pattern matches that walks the SSA links to find the base/strides. This is // very fragile and to solve we should expose convert Ptr of tensor to a -// structure containins all values and not only offsets. +// structure contains all values and not only offsets. class TritonIntelGPURewriteTensorPointerPass : public TritonIntelGPURewriteTensorPointerBase< TritonIntelGPURewriteTensorPointerPass> { private: DenseMap rewritedInfo; + DenseSet valueToRemove; public: - static bool needRewrite(Operation *op) { + TritonIntelGPURewriteTensorPointerPass(ttgi::DeviceArch arch) { + deviceArch = arch; + } + + static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { return std::any_of(op->getOperands().begin(), op->getOperands().end(), - [](Value operand) { - return tt::isTensorPointerType(operand.getType()); + [&valueToRemove](Value operand) { + return tt::isTensorPointerType(operand.getType()) && + valueToRemove.count(operand); }); } static SmallVector generateNewOperands(const SmallVector &oldOperands, unsigned index, const SmallVector &newValues) { - assert(index < oldOperands.size()); + assert(index < oldOperands.size() && "Index out of range"); SmallVector newOperands; for (int i = 0; i < index; ++i) newOperands.push_back(oldOperands[i]); @@ -232,9 +360,11 @@ class TritonIntelGPURewriteTensorPointerPass Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, tt::MakeTensorPtrOp op, std::stack &eraser) { + if (!valueToRemove.count(op.getResult())) + return nullptr; // Save info for later use - auto ptrType = op.getType().cast(); - auto tensorType = ptrType.getPointeeType().cast(); + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); // Cast I32 offsets into I64 SmallVector i64Offsets; @@ -247,7 +377,7 @@ class TritonIntelGPURewriteTensorPointerPass // Save information rewritedInfo[op.getResult()] = RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, - tensorType.getShape()); + tensorType.getShape(), tensorType.getEncoding()); // Erase the original operation eraser.push(op); @@ -256,12 +386,17 @@ class TritonIntelGPURewriteTensorPointerPass Operation *rewriteAdvanceOp(OpBuilder &builder, tt::AdvanceOp op, std::stack &eraser) { + if (!valueToRemove.count(op.getResult())) + return nullptr; + // Get info from previous results - assert(rewritedInfo.count(op.getPtr())); + assert(rewritedInfo.count(op.getPtr()) && + "Expecting AdvanceOp ptr in rewritedInfo"); auto info = rewritedInfo[op.getPtr()]; // 
Calculate new offsets - assert(info.length() == op.getOffsets().size()); + assert(info.length() == op.getOffsets().size() && + "Expecting AdvanceOp ptr shape and offsets have the same size"); SmallVector newOffsets; for (int i = 0; i < info.length(); ++i) { Value i64Offset = builder.create( @@ -282,15 +417,15 @@ class TritonIntelGPURewriteTensorPointerPass Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, std::stack &eraser) { - assert(isa(op) || isa(op)); - - // We only have to rewrite load/stores with tensor pointers - auto ptr = op->getOperand(0); - if (!tt::isTensorPointerType(ptr.getType())) + assert(isa(op) || + isa(op) && "Expecting LoadOp or StoreOp"); + if (!valueToRemove.count(op->getOperand(0))) return nullptr; // Get info from previous results - assert(rewritedInfo.count(ptr)); + auto ptr = op->getOperand(0); + assert(rewritedInfo.count(ptr) && + "Expecting LoadOp/StoreOp ptr in rewritedInfo"); auto info = rewritedInfo[ptr]; // Load/store with tensor pointers implicitly will check the bound while @@ -299,11 +434,20 @@ class TritonIntelGPURewriteTensorPointerPass // `other` while building IR from Python AST std::optional> boundaryCheck; if (auto loadOp = dyn_cast(op)) { - assert(!loadOp.getMask() && !loadOp.getOther()); + assert(!loadOp.getMask() && !loadOp.getOther() && + "LoadOp with tensor pointer should not have mask and other"); boundaryCheck = loadOp.getBoundaryCheck(); - } else if (auto storeOp = dyn_cast(op)) { - assert(!storeOp.getMask()); + if (auto valueType = + dyn_cast(loadOp.getResult().getType())) + info.setEncoding(valueType.getEncoding()); + } else { + auto storeOp = cast(op); + assert(!storeOp.getMask() && + "StoreOp with tensor pointer should not have mask"); boundaryCheck = storeOp.getBoundaryCheck(); + if (auto valueType = + dyn_cast(storeOp.getValue().getType())) + info.setEncoding(valueType.getEncoding()); } // Generate new `ptr`, `mask` and `other` @@ -333,20 +477,24 @@ class TritonIntelGPURewriteTensorPointerPass Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, std::stack &eraser) { auto thenYieldOp = op.thenYield(); - assert(op.getNumResults() == thenYieldOp.getNumOperands()); + assert(op.getNumResults() == thenYieldOp.getNumOperands() && + "Expecting IfOp results and its thenYieldOp operands have the same " + "number"); SmallVector results = thenYieldOp.getOperands(); // get new result types SmallVector newRetTypes; bool needRewrite = false; for (unsigned i = 0; i < results.size(); ++i) { - if (!tt::isTensorPointerType(results[i].getType())) { + if (!tt::isTensorPointerType(results[i].getType()) || + !valueToRemove.count(results[i])) { newRetTypes.push_back(results[i].getType()); continue; } needRewrite = true; auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); - assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + assert(rewritedInfo.count(makeTensorPtrOp.getResult()) && + "Expecting MakeTensorPtrOp of IfOp result in rewritedInfo"); auto info = rewritedInfo[makeTensorPtrOp.getResult()]; for (unsigned j = 0; j < info.length(); ++j) { newRetTypes.push_back(builder.getI64Type()); @@ -374,15 +522,22 @@ class TritonIntelGPURewriteTensorPointerPass rematerialize(op.elseBlock()); } + // supported nested ops + for (auto &[k, v] : mapping.getValueMap()) + if (valueToRemove.find(k) != valueToRemove.end()) + valueToRemove.insert(v); + // update rewritedInfo unsigned oldResIdx = 0, newResIdx = 0; while (oldResIdx < results.size()) { - if (!tt::isTensorPointerType(results[oldResIdx].getType())) { + if 
(!tt::isTensorPointerType(results[oldResIdx].getType()) || + !valueToRemove.count(results[oldResIdx])) { oldResIdx++; newResIdx++; } else { auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); - assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + assert(rewritedInfo.count(makeTensorPtrOp.getResult()) && + "Expecting MakeTensorPtrOp of IfOp result in rewritedInfo"); auto info = rewritedInfo[makeTensorPtrOp.getResult()]; for (unsigned j = 0; j < info.length(); ++j) { info.setOffset(j, newOp->getResult(newResIdx++)); @@ -405,9 +560,12 @@ class TritonIntelGPURewriteTensorPointerPass ++i, ++oldI) { if (!tt::isTensorPointerType(newIterOperands[i].getType())) continue; + if (!valueToRemove.count(newIterOperands[i])) + continue; // Expand the tensor pointer into offsets - assert(rewritedInfo.count(newIterOperands[i])); + assert(rewritedInfo.count(newIterOperands[i]) && + "Expecting ForOp operands in rewritedInfo"); auto info = rewritedInfo[newIterOperands[i]]; newIterOperands = generateNewOperands(newIterOperands, i, info.getOffsets()); @@ -427,9 +585,11 @@ class TritonIntelGPURewriteTensorPointerPass for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); - if (tt::isTensorPointerType(oldRegionIterArg.getType())) { + if (tt::isTensorPointerType(oldRegionIterArg.getType()) && + valueToRemove.count(oldIterOperands[oldI])) { // Pass rewrited info inside - assert(rewritedInfo.count(oldIterOperands[oldI])); + assert(rewritedInfo.count(oldIterOperands[oldI]) && + "Expecting ForOp operands in rewritedInfo"); auto info = rewritedInfo[oldIterOperands[oldI]]; mapping.map(oldRegionIterArg, oldRegionIterArg); for (unsigned j = 0; j < info.length(); ++j) @@ -446,17 +606,28 @@ class TritonIntelGPURewriteTensorPointerPass builder.setInsertionPointToStart(newForOp.getBody()); for (auto &opInFor : *op.getBody()) { auto *newOp = builder.clone(opInFor, mapping); - for (unsigned i = 0; i < opInFor.getNumResults(); ++i) + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) { + if (valueToRemove.count(opInFor.getResult(i))) + valueToRemove.insert(newOp->getResult(i)); mapping.map(op->getResult(i), newOp->getResult(i)); + } } + // supported nested scf.for ops + for (auto &[k, v] : mapping.getValueMap()) + if (valueToRemove.find(k) != valueToRemove.end()) + valueToRemove.insert(v); + // Replace later usages - assert(op.getNumResults() == op.getInitArgs().size()); + assert(op.getNumResults() == op.getInitArgs().size() && + "Expecting ForOp results and operands have the same number"); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); - if (tt::isTensorPointerType(oldResult.getType())) { + if (tt::isTensorPointerType(oldResult.getType()) && + valueToRemove.count(oldIterOperands[oldI])) { // Pack new offsets into rewrited info - assert(rewritedInfo.count(oldIterOperands[oldI])); + assert(rewritedInfo.count(oldIterOperands[oldI]) && + "Expecting ForOp operands in rewritedInfo"); auto info = rewritedInfo[oldIterOperands[oldI]]; for (unsigned j = 0; j < info.length(); ++j) info.setOffset(j, newForOp.getResult(i + j)); @@ -479,8 +650,11 @@ class TritonIntelGPURewriteTensorPointerPass for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { if (!tt::isTensorPointerType(newOperands[i].getType())) continue; + if (!valueToRemove.count(newOperands[i])) + continue; - assert(rewritedInfo.count(newOperands[i])); + assert(rewritedInfo.count(newOperands[i]) && + 
"Expecting YieldOp operands in rewritedInfo"); auto info = rewritedInfo[newOperands[i]]; newOperands = generateNewOperands(newOperands, i, info.getOffsets()); i += info.length() - 1; @@ -498,7 +672,6 @@ class TritonIntelGPURewriteTensorPointerPass // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers // Rewriting functions return the next operation to visit, if there is no // next one, simply return `nullptr` - std::pair rewrited; if (auto makeTensorPtrOp = dyn_cast(op)) { return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); } else if (auto advanceOp = dyn_cast(op)) { @@ -510,7 +683,7 @@ class TritonIntelGPURewriteTensorPointerPass if (auto ifOp = dyn_cast(op)) { return rewriteIfOp(builder, ifOp, eraser); } - if (!needRewrite(op)) + if (!needRewrite(op, valueToRemove)) return op; if (auto forOp = dyn_cast(op)) { @@ -548,6 +721,35 @@ class TritonIntelGPURewriteTensorPointerPass } void runOnOperation() override { + ModuleOp mod = getOperation(); + + auto markTensorPointerForRemoval = [this](Value val) { + if (tt::isTensorPointerType(val.getType())) { + tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val); + if (shouldRemove(makeTensorPtrOp, deviceArch)) + valueToRemove.insert(val); + } + }; + + mod.walk([&](Operation *op) { + if (llvm::isa(op)) { + markTensorPointerForRemoval(op->getResult(0)); + } else if (llvm::isa(op)) { + markTensorPointerForRemoval(op->getOperand(0)); + } else if (llvm::isa(op)) { + // TODO: Block store should not be removed when 2d store is enabled + auto src = op->getOperand(0); + if (tt::isTensorPointerType(src.getType())) + valueToRemove.insert(src); + } else if (auto forOp = dyn_cast(op)) { + for (auto arg : forOp.getInitArgs()) + markTensorPointerForRemoval(arg); + } else if (auto yieldOp = dyn_cast(op)) { + for (auto operand : yieldOp.getOperands()) + markTensorPointerForRemoval(operand); + } + }); + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because // MLIR does not support one-multiple value mapping. For example, if we use // `ConversionPatternRewriter`, we can not make a type converter, which @@ -565,6 +767,7 @@ class TritonIntelGPURewriteTensorPointerPass // The operation could not be erased during visit, because they may have // later usages, so we erase after visit rewritedInfo.clear(); + valueToRemove.clear(); while (!eraser.empty()) { auto op = eraser.top(); eraser.pop(); @@ -573,6 +776,7 @@ class TritonIntelGPURewriteTensorPointerPass } }; -std::unique_ptr ttgi::createTritonIntelGPURewriteTensorPointerPass() { - return std::make_unique(); +std::unique_ptr +ttgi::createTritonIntelGPURewriteTensorPointerPass(ttgi::DeviceArch arch) { + return std::make_unique(arch); } diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index fe8bb7bc26..610eff9634 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -48,6 +48,14 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { pm.addPass(mlir::triton::gpu::intel:: createTritonIntelGPURemoveLayoutConversionsPass()); }); + m.def( + "add_rewrite_tensor_pointer", + [](mlir::PassManager &pm, mlir::triton::gpu::intel::DeviceArch arch) { + pm.addPass(mlir::triton::gpu::intel:: + createTritonIntelGPURewriteTensorPointerPass(arch)); + }, + py::arg("pm"), + py::arg("arch") = mlir::triton::gpu::intel::DeviceArch::UNKNOWN); } void init_triton_intel_passes_ttnvgpuir(py::module &&m) {