From 77e0c536a264bdeb27fa7652fee606cd88d126bf Mon Sep 17 00:00:00 2001
From: Maxime France-Pillois
Date: Mon, 8 Jul 2024 21:07:38 +0100
Subject: [PATCH] [RAISE-BP] Handle `scf.for` (#1546)

Add support for raising `tt.load`, `tt.store` and `tt.addptr` when these
operations take place in an `scf.for` loop.

It includes:
- rewriting the loop and the yieldOp
- adding a visitor for triton::MakeTensorPtrOp
- early cleanup of unused AddptrOps
- new tests

---------

Signed-off-by: Maxime France-Pillois
---
 test/Triton/raise-block-pointer.mlir          | 265 ++++++++++++
 .../TritonRaiseBlockPointer.cpp               | 393 +++++++++++++++++-
 2 files changed, 654 insertions(+), 4 deletions(-)

diff --git a/test/Triton/raise-block-pointer.mlir b/test/Triton/raise-block-pointer.mlir
index 42f9177f96..efc356d2f0 100644
--- a/test/Triton/raise-block-pointer.mlir
+++ b/test/Triton/raise-block-pointer.mlir
@@ -236,3 +236,268 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128x
   %3 = tt.load %2 : tensor<128x2x128x!tt.ptr>
   tt.return %3 : tensor<128x2x128xf32>
 }
+
+// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
+// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
+// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
+// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
+// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK: [[CST_5_i64:%.+]] = arith.constant 5 : i64
+// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : >
+// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr>
+// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> (tensor<4x256xbf16>, i32) {
+// CHECK: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[VAR_arg7_]], [[CST_0_i32]]] {order = array} : >
+// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr>
+// CHECK: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
+// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_3_i32]] : i32
+// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, i32
+// CHECK: }
+// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : >
+// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr>
+// CHECK: tt.return
+// CHECK: }
+module {
+  tt.func @test_addptr_for_accumulation(
+    %arg0 : !tt.ptr,
+    %arg1 : !tt.ptr,
+    %arg2 : !tt.ptr,
+    %arg3 : i32,
+    %arg4 : i32
+  )
+  {
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
+    // offset = 0, size = 4, stride = 1
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    // offset = [0,0], size = [4,1], stride = [1,0]
+    %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
+    // offset = [0,0], size = [4,256], stride = [1,0]
+    %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32>
+    %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32>
+    // offset = [%arg3,0], size = [4,256], stride = 
[1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c5 = arith.constant 5 : i32 + %splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32> + // scalar = 5 + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? + // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> // Why is the input unknown + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %19 = tt.load %9 : tensor<4x256x!tt.ptr> // this will be replaced with a memref.copy + %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> + %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { + %20 = tt.load %ptr_iter : tensor<4x256x!tt.ptr> + %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> + // pointer updates + %17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32> + // offset: [3, 0], size = [4, 256], stride [0, 0] + %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] + scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> + } + %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> + %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + tt.store %16, %sum_out : tensor<4x256x!tt.ptr> + tt.return + } +} + +// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 +// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32 +// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64 +// CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[CST_2_]], [[VAR_arg5_:%.+]] = [[CST_3_]], [[VAR_arg6_:%.+]] = [[CST_1024_i32]], [[VAR_arg7_:%.+]] = [[CST_1024_i32]]) -> (index, index, index, i32, i32) { +// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg7_]]] {order = array} : > +// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg6_]]] {order = array} : > +// 
CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr> +// CHECK: tt.store [[VAR_1_]], [[VAR_3_]] : !tt.ptr> +// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_arg6_]], [[CST_3_i32]] : i32 +// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index +// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_3_]] : index +// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index +// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[VAR_9_]] : index to i32 +// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_10_]] : i32 +// CHECK: scf.yield [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_4_]], [[VAR_11_]] : index, index, index, i32, i32 +// CHECK: } +// CHECK: tt.return +// CHECK: } +module { + tt.func @test_addptr_for_more_init_args( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c12 = arith.constant 12 : index + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { + %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr_st, %5 : tensor<256x!tt.ptr> + %cast3 = arith.index_cast %c3 : index to i32 + %6 = tt.splat %cast3 : i32 -> tensor<256xi32> + %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> + %arg2_iter = arith.addi %arg2, %c3 : index + %arg3_iter = arith.addi %arg3, %c3 : index + %arg4_iter = arith.addi %arg4, %c3 : index + %7 = arith.addi %arg2_iter, %arg3_iter : index + %8 = arith.addi %7, %arg4_iter : index + %cast8 = arith.index_cast %8 : index to i32 + %9 = tt.splat %cast8 : i32 -> tensor<256xi32> + %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index + } + tt.return + } +} + + +// CHECK: tt.func @test_addptr_for_used_after_update([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 +// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32 +// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64 +// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) { +// CHECK: [[VAR_1_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32 +// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_1_]]] {order = array} : > +// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr> +// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr> +// CHECK: scf.yield [[VAR_1_]] : i32 +// CHECK: } +// CHECK: tt.return +// CHECK: } +module { + tt.func 
@test_addptr_for_used_after_update( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr_iter, %3 : tensor<256x!tt.ptr> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + tt.return + } +} + + + +// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 +// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32 +// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64 +// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) { + +// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg2_]]] {order = array} : > +// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr> +// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr> +// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32 +// CHECK: scf.yield [[VAR_3_]] : i32 +// CHECK: } +// CHECK: tt.return +// CHECK: } +module { + tt.func @test_addptr_for_used_before_update( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr, %3 : tensor<256x!tt.ptr> + %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + tt.return + } +} + +// `triton::ExpandDims` ops on tensor of pointers are currently not supported in for loops. +// Consequently, the pass should fail cleanly. 
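+// The loop body below feeds the loop-carried tensor of pointers into
+// tt.expand_dims/tt.broadcast before updating it, a pattern the pointer
+// analysis cannot model yet, so the test only checks that no
+// tt.make_tensor_ptr is emitted.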
+// CHECK: tt.func @test_fail_addptr_for_expand_ptr([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK-NOT: tt.make_tensor_ptr +module { + tt.func @test_fail_addptr_for_expand_ptr( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + %8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32> + %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32> + %12 = arith.addi %8, %11 : tensor<256x256xi32> + %13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr> -> tensor<256x1x!tt.ptr> + %14 = tt.broadcast %13 : tensor<256x1x!tt.ptr> -> tensor<256x256x!tt.ptr> + %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr> + tt.store %15, %16 : tensor<256x256x!tt.ptr> + %17 = tt.splat %i_c3 : i32 -> tensor<256xi32> + %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + tt.return + } +} diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index 3268f054ba..f91e775193 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Matchers.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "triton-raise-block-pointer" @@ -201,21 +202,377 @@ struct TritonRaiseBlockPointer : triton::intel::impl::TritonRaiseBlockPointerBase< TritonRaiseBlockPointer> { using Base::Base; + using IndexMapSet = std::map>; + SmallVector cleanUp; void runOnOperation() final { - getOperation()->walk([this](Operation *op) { - TypeSwitch(op) + auto moduleOp = getOperation(); + + if (failed(rewriteOp(moduleOp))) { + moduleOp->emitWarning("TritonRaiseToBlockPointer failed"); + } + + for (auto op : cleanUp) { + if (op->getUsers().empty()) + op->erase(); + } + } + + LogicalResult rewriteOp(Operation *rootOp) { + LLVM_DEBUG({ + llvm::dbgs() << "rewriting rootOp\n"; + rootOp->dump(); + }); + + rootOp->walk([&](Operation *op) { + if (op == rootOp) { + return WalkResult::advance(); + } + return TypeSwitch(op) .Case([this](triton::AddPtrOp addptr) { if (failed(rewriteAddPtrOp(addptr))) addptr->emitRemark( "TritonRaiseToBlockPointer: Failed to rewrite"); + return WalkResult::advance(); + }) + .Case([&](auto maketptr) { + if (failed(remapMakeTensorPtrOp(maketptr))) { + maketptr->emitRemark("TritonRaiseToBlockPointer: Failed to " + "rewrite MakeTensorPtrOp"); + } + return WalkResult::advance(); }) .Case([this](auto loadstore) { - if (failed(rewriteLoadStoreOp(loadstore))) + if (failed(rewriteLoadStoreOp(loadstore))) { loadstore->emitRemark( "TritonRaiseToBlockPointer: 
Failed to rewrite"); - }); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case([&](auto forOp) { + if (failed(rewriteForOp(forOp))) { + forOp->emitRemark( + "TritonRaiseToBlockPointer: Failed to rewrite ForOp"); + return WalkResult::interrupt(); + } + return WalkResult::skip(); + }) + .Default([&](auto) { return WalkResult::advance(); }); }); + + return success(); + } + + LogicalResult rewriteForOp(scf::ForOp op) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + llvm::SmallDenseMap initArgIndexMap; + + OpBuilder builder(op); + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = ptrMap.lookupOrNull(arg); + PtrState state; + if (mappedV) { + if (auto makeTensorPtrOp = + mappedV.getDefiningOp()) { + + if (llvm::any_of(op.getRegionIterArgs()[i].getUsers(), + [](Operation *user) { + return isa(user); + })) { + op->emitRemark("TritonRaiseToBlockPointer: ExpandDims Ops in loops " + "are currently not supported"); + return failure(); + } + + if (succeeded(visitOperandMakeTensorPtr( + makeTensorPtrOp, state, op.getLoc(), builder, true))) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + continue; + } + } else if (auto addptrOp = mappedV.getDefiningOp()) { + // We always use tt.addptr for scalar pointers. If the defininig op is + // tt.addptr and we have a non-scalar pointer, something must have + // gone wrong with the pass. + assert(!isa(addptrOp.getResult().getType()) && + "Result type of AddPtrOp must be a tensor!"); + if (succeeded( + visitOperandAddptr(addptrOp, state, op.getLoc(), builder))) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + continue; + } + } + } + // If any of the analysis failed, or init arg is not pointer related or + // prior rewrite has failed. Pass as is + newInitArgs.push_back(arg); + } + + // For each of the PtrState recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto &[i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. + for (auto [j, s] : llvm::enumerate(state.offsets)) { + newInitArgs.push_back(s); + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + newInitArgs.push_back(s); + } + + if (state.getRank() == 0) { + assert(state.scalar && + "The state must have a scalar if its rank is equal to zero"); + // for scalar pointers, the scalar contains the offset and is the only + // relevant state that could be updated by the loop. + newInitArgs.push_back(state.scalar); + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to use the newly created block arg. These block args + // are not created yet. We will translate this mapping later. 
+ knownPtrsTmp.push_back(std::make_pair(i, state)); + levelToBlockArgIndex[level].insert(i); + } + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = builder.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping cloneMap; + cloneMap.map(op.getInductionVar(), iv); + cloneMap.map(op.getInitArgs(), newInitArgs); + cloneMap.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, cloneMap); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. + // Key is converted from init arg index to newly created block arg, and + // Value's PtrState fields are converted from init arg to newly created + // block arg + int cnt = op.getRegionIterArgs().size(); + for (auto &[i, state] : knownPtrsTmp) { + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + if (state.getRank() == 0) { + assert(state.scalar && + "The state must have a scalar if its rank is equal to zero"); + state.scalar = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + // Record the PtrState for this pointer + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs[key] = state; + initArgIndexMap[i] = state; + + // For tensors of pointers, create a tt.make_block_ptr at the beginning of + // the loop body that correspond to this region iter arg. In case it is + // used by tt.load/tt.store in the loop body before pointer updates, this + // will make sure rewriteLoadOp/rewriteStoreOp can use the analysis + // result. E.g., given the following input (%tensor_of_ptr is a block + // arg): + // scf.for (%tensor_of_ptr) { + // %data = tt.load %tensor_of_ptr + // // more operations to update %tensor_of_ptr + // } + // We may produce the following output: + // scf.for (%base_ptr, %stride, %offset) { + // %tensor_of_ptr = tt.make_block_ptr(%base_ptr, %stride, %offset) + // %data = tt.load %tensor_of_ptr + // // more operations to update %offset + // } + // If %tensor_of_ptr is not used (i.e., %tensor_of_ptr is updated before + // used in the original IR), it will simply be removed by + // canonicalization. + + // For scalar pointers, there is no need to create a tts.addptr at the + // beginning of the loop body. We don't lower tt.load and tt.store on + // scalars in this pass; pointer arithmetics can also just use the + // original pointer. + if (state.getRank() != 0) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&newOp.getRegion().front()); + auto maketptrOp = state.createTTMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(key, maketptrOp.getResult()); + } + } + + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto forOp = dyn_cast(bodyOp)) { + forOp->emitRemark( + "TritonRaiseToBlockPointer: nested loops currently not supported"); + return failure(); + } + } + // Update the loop body. 
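+    // Re-running the rewriter on the cloned body raises the tt.addptr /
+    // tt.load / tt.store operations inside the loop against the
+    // tt.make_tensor_ptr values created above for the region iter args.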
+ if (failed(rewriteOp(newOp))) { + newOp->erase(); + op->emitRemark("TritonRaiseToBlockPointer: update loop body failed when " + "rewriting for op"); + return failure(); + } + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + if (failed(rewriteYieldOp(yieldOp, initArgIndexMap))) { + newOp->erase(); + return failure(); + }; + } + + levelToBlockArgIndex.erase(level); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + + llvm::dbgs() << "old for\n"; + op->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + op->replaceAllUsesWith(resultsToReplaceWith); + op->erase(); + + return success(); + } + + LogicalResult + rewriteYieldOp(scf::YieldOp op, + llvm::SmallDenseMap &knownPtrsFor) { + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) { + // no need to rewrite this op + return success(); + } + + OpBuilder builder(op); + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below + // gathers PtrState for those values. + SmallVector initArgState; + for (auto [i, v] : llvm::enumerate(op->getOperands())) { + // If this operand is not rewritten by forOp, skip + auto &thisSet = levelToBlockArgIndex.find(level)->second; + if (thisSet.find(i) == thisSet.end()) + continue; + + auto mappedV = ptrMap.lookupOrNull(v); + if (!mappedV) { + op->emitRemark("Prior rewrite failure lead to yield rewrite failure"); + return failure(); + } + + PtrState state; + LogicalResult ret = failure(); + if (auto makeTPtrOp = mappedV.getDefiningOp()) { + ret = visitOperandMakeTensorPtr(makeTPtrOp, state, op.getLoc(), builder, + true); + } else if (auto addptrOp = mappedV.getDefiningOp()) { + ret = visitOperandAddptr(addptrOp, state, op.getLoc(), builder); + } + if (ret.failed()) { + op->emitRemark("Failed to rewrite yield op"); + return failure(); + } + initArgState.push_back(state); + + // Verify that shape is not updated during the for loop + auto forState = knownPtrsFor[i]; + for (auto i = 0; i < forState.getRank(); ++i) { + if (forState.shape[i] != state.shape[i]) { + // Special case, see comments in addState in dealing with shape/modulo + if (i == 0 && forState.getRank() == 2) { + if (forState.shape[1] == state.shape[0] && + forState.shape[0] == state.shape[1]) { + break; + } + } + op->emitRemark( + "TritonRaiseToBlockPointer: operand's shape/modulo state changed " + "within loop body"); + return failure(); + } + } + } + + SmallVector operands; + for (auto opnd : op->getOperands()) { + auto mappedV = ptrMap.lookupOrNull(opnd); + operands.push_back(mappedV ? mappedV : opnd); + } + + // For each of the PtrState recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. 
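+    // The values are appended in the same order (offsets, then strides, then
+    // the scalar offset for rank-0 pointers) used when rewriteForOp extended
+    // the init args, so each yielded value feeds its matching iter arg.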
+ for (auto state : initArgState) { + for (auto s : state.offsets) { + operands.push_back(s); + } + + for (auto s : state.strides) { + operands.push_back(s); + } + + if (state.getRank() == 0) { + operands.push_back(state.scalar); + } + } + + auto newOp = builder.create(op->getLoc(), operands); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + op->erase(); + return success(); + } + + LogicalResult remapMakeTensorPtrOp(triton::MakeTensorPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (failed(visitOperandMakeTensorPtr(op, state, op.getLoc(), builder))) { + return failure(); + } + + knownPtrs[op.getResult()] = state; + return success(); } LogicalResult rewriteAddPtrOp(triton::AddPtrOp op) { @@ -237,6 +594,32 @@ struct TritonRaiseBlockPointer ptrMap.map(result, mapped); + // AddPtrOps that have been rewritten and no longer used in the code must be + // removed in the pass to avoid type matching issue. + cleanUp.push_back(op); + + return success(); + } + + LogicalResult visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder, + bool addedByPass = false) { + assert(state.isEmpty() && "state is a return argument"); + state.source = makeTPtrOp.getBase(); + + auto resType = cast(makeTPtrOp.getResult().getType()); + auto pointeeType = cast(resType.getPointeeType()); + auto shape = pointeeType.getShape(); + + for (int64_t i = 0; i < pointeeType.getRank(); i++) { + state.sizes.push_back(shape[i]); + } + state.strides = makeTPtrOp.getStrides(); + state.offsets = makeTPtrOp.getOffsets(); + state.shape = makeTPtrOp.getShape(); + state.order = SmallVector(makeTPtrOp.getOrder()); + return success(); } @@ -375,6 +758,8 @@ struct TritonRaiseBlockPointer llvm::SmallDenseMap knownPtrs; IRMapping ptrMap; + IndexMapSet levelToBlockArgIndex; + int level = 0; }; template <>