Rewrite RewriteTensorPointer pass to support 2D block load (#958)
This is the first PR split out from
#941
This PR focuses on rewriting the `RewriteTensorPointer` pass so that
`tt.load` operations using the tensor pointer pattern can be kept in our
compilation pipeline rather than being rewritten to legacy loads.
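As a quick illustration (a hand-written sketch, not code from this PR; `%base`, `%szM`, `%szK`, `%strideM`, `%offM`, `%offsets`, and `%mask` stand for values computed elsewhere), this is the tensor-pointer load form the pass now preserves when it can be lowered to a 2D block load:

```mlir
%ptr = tt.make_tensor_ptr %base, [%szM, %szK], [%strideM, %c1_i64], [%offM, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16>>
%val = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16>>
```

and this is the legacy form such loads are otherwise rewritten into, built from a tensor of plain pointers plus a boundary mask:

```mlir
%splat = tt.splat %base : !tt.ptr<f16> -> tensor<256x32x!tt.ptr<f16>>
%ptrs = tt.addptr %splat, %offsets : tensor<256x32x!tt.ptr<f16>>, tensor<256x32xi32>
%val = tt.load %ptrs, %mask : tensor<256x32x!tt.ptr<f16>>
```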

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
3 people authored Apr 29, 2024
1 parent 4f4af05 commit c8ac581
Showing 5 changed files with 444 additions and 80 deletions.
4 changes: 3 additions & 1 deletion include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h
@@ -30,7 +30,9 @@ std::unique_ptr<Pass> createTritonIntelGPUDistributeToWarpsPass();

std::unique_ptr<Pass> createTritonIntelGPURemoveLayoutConversionsPass();

-std::unique_ptr<Pass> createTritonIntelGPURewriteTensorPointerPass();
+std::unique_ptr<Pass> createTritonIntelGPURewriteTensorPointerPass(
+    triton::gpu::intel::DeviceArch arch =
+        triton::gpu::intel::DeviceArch::UNKNOWN);

std::unique_ptr<Pass> createPrefetchBlockPass();

20 changes: 16 additions & 4 deletions include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -151,16 +151,28 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c
}

def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> {
let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
let summary = "Rewrite load/store operations using tensor pointers that cannot be lowered to 2D Block Load/Store intrinsics";
let description = [{
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
the pointer/mask/other for each load/store.
This pass determines whether a load/store operation can be lowered to 2D
Block Load/Store intrinsic. If it cannot, it replaces the load/store
operation with a legacy pointer and removes the Triton operations that
create and advance the block pointer (that is `tt.make_tensor_tr` and
`tt.advance`).
}];

let constructor = "mlir::triton::gpu::intel::createTritonIntelGPURewriteTensorPointerPass()";

let dependentDialects = ["mlir::triton::TritonDialect"];

+let options = [
+  Option<"deviceArch", "device-architecture",
+    "mlir::triton::gpu::intel::DeviceArch", /*default*/"mlir::triton::gpu::intel::DeviceArch::PVC",
+    "device architecture",
+    "llvm::cl::values("
+    "clEnumValN(mlir::triton::gpu::intel::DeviceArch::UNKNOWN, \"UNKNOWN\", \"Unknown arch\"), "
+    "clEnumValN(mlir::triton::gpu::intel::DeviceArch::ATS, \"ATS\", \"ATS arch\"), "
+    "clEnumValN(mlir::triton::gpu::intel::DeviceArch::PVC, \"PVC\", \"PVC arch\"))">
+];
}

def TritonIntelGPUPrefetchBlock : Pass<"tritonintelgpu-prefetch-block", "mlir::ModuleOp"> {
140 changes: 139 additions & 1 deletion test/TritonIntelGPU/rewrite-tensor-pointer.mlir
@@ -1,4 +1,142 @@
-// RUN: triton-opt %s -tritonintelgpu-rewrite-tensor-pointer | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritonintelgpu-rewrite-tensor-pointer=device-architecture=PVC | FileCheck %s

// COM: Case 1:
// COM: Check that operations using block pointers satisfying the following conditions are not rewritten:
// COM: - the block pointer has the "dot" layout attribute (with dpas parent layout)
// COM: - the block pointer is advanced in row-major order: strides[order[0]] == 1
// COM: - the block pointer pitch is divisible by QW: strides[order[1]] % (64 / elemTypeBitWidth) == 0
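// COM: For example, with f16 elements (elemTypeBitWidth = 16) the pitch condition becomes
// COM: strides[order[1]] % 4 == 0, i.e. the row pitch must span a whole number of 64-bit quad-words (QW).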
// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
// CHECK: @matmul_kernel_with_block_pointers
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%c255_i32 = arith.constant 255 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c255_i32 : i32
%2 = arith.divsi %1, %c256_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c4_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c4_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.minsi %8, %c4_i32 : i32
%10 = arith.remsi %0, %9 : i32
%11 = arith.addi %7, %10 : i32
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c256_i32 : i32
%15 = arith.extsi %arg3 : i32 to i64
%16 = arith.extsi %arg5 : i32 to i64
%17 = arith.extsi %arg6 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>>>
%18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
%19 = arith.muli %13, %c256_i32 : i32
%20 = arith.extsi %arg4 : i32 to i64
%21 = arith.extsi %arg7 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>>>
%30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>> -> tensor<256x256xf32, #dpas>
%31 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
%32 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
}
%24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
%25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, #blocked>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked>>
// CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr<f16>, #[[BLOCKED]]>
tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked>>
tt.return
}
}

// -----

// COM: Case 2:
// COM: Check that operations using block pointers without a divisibility attribute are rewritten to use a legacy pointer.
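// COM: Here the stride arguments %arg6 and %arg7 lack the tt.divisibility attribute, so the
// COM: pitch divisibility condition from Case 1 cannot be guaranteed.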
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func public @matmul_kernel_with_block_pointers_indivisible(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}) {
// CHECK: @matmul_kernel_with_block_pointers_indivisible
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%c255_i32 = arith.constant 255 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c255_i32 : i32
%2 = arith.divsi %1, %c256_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c4_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c4_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.minsi %8, %c4_i32 : i32
%10 = arith.remsi %0, %9 : i32
%11 = arith.addi %7, %10 : i32
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c256_i32 : i32
%15 = arith.extsi %arg3 : i32 to i64
%16 = arith.extsi %arg5 : i32 to i64
%17 = arith.extsi %arg6 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
%19 = arith.muli %13, %c256_i32 : i32
%20 = arith.extsi %arg4 : i32 to i64
%21 = arith.extsi %arg7 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>) : i32 {
// CHECK: tt.load {{.*}}, {{.*}} : tensor<256x32x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]]}>>
// CHECK: tt.load {{.*}}, {{.*}} : tensor<32x256x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]]}>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
%30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>> -> tensor<256x256xf32, #dpas>
// CHECK-NOT: tt.advance
%31 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>
// CHECK-NOT: tt.advance
%32 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas}>>>
}
%24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
%25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, #blocked>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK-NOT: tt.make_tensor_ptr
%27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked>>
// CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr<f16>, #[[BLOCKED]]>
tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked>>
tt.return
}
}

// -----

// COM: Case 3:
// COM: Check that operations using block pointers without a layout attribute are rewritten to use a legacy pointer.
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%c31_i32 = arith.constant 31 : i32
%c127_i32 = arith.constant 127 : i32
