Skip to content

Commit

Permalink
[RAISE-BP] Handle scf.for (#1546)
Browse files Browse the repository at this point in the history
Add support to raise `tt.load`, `tt.store` and `tt.addptr` when these
operations take place in an `scf.for` loop. It includes:
 - rewriting the loop and the yieldOp
 - adding a visitor for triton::MakeTensorPtrOp
 - early cleanup of unused AddptrOps
 - new tests

---------

Signed-off-by: Maxime France-Pillois <[email protected]>
  • Loading branch information
mfrancepillois authored Jul 8, 2024
1 parent 2421a85 commit 77e0c53
Show file tree
Hide file tree
Showing 2 changed files with 654 additions and 4 deletions.
265 changes: 265 additions & 0 deletions test/Triton/raise-block-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,268 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
%3 = tt.load %2 : tensor<128x2x128x!tt.ptr<f32>>
tt.return %3 : tensor<128x2x128xf32>
}

// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[CST_5_i64:%.+]] = arith.constant 5 : i64
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> (tensor<4x256xbf16>, i32) {
// CHECK: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[VAR_arg7_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_3_i32]] : i32
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, i32
// CHECK: }
// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
module {
// Checks that tt.addptr/tt.load/tt.store used inside an scf.for are raised to
// block pointers: the tensor-of-pointers iter_arg is replaced by a scalar i32
// offset iter_arg, and tt.make_tensor_ptr is re-materialised from that offset
// at each use (see the CHECK lines above: the loop result type becomes
// (tensor<4x256xbf16>, i32)).
tt.func @test_addptr_for_accumulation(
%arg0 : !tt.ptr<bf16>,
%arg1 : !tt.ptr<bf16>,
%arg2 : !tt.ptr<bf16>,
%arg3 : i32,
%arg4 : i32
)
{
%0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
// offset = 0, size = 4, stride = 1
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
// offset = [0,0], size = [4,1], stride = [1,0]
%2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [1,0]
%arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32>
%offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32>
// offset = [%arg3,0], size = [4,256], stride = [1,0]
%3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32>
// offset = 0, size = 256, stride = 1
%4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
// offset = [0,0], size = [1,256], stride = [0,1]
%5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,1]
%c5 = arith.constant 5 : i32
%splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32>
// scalar = 5
%scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here?
// offset = [0,0], size = [4,256], stride = [0,5]
%7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here?
// offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>> // Why is the input unknown
%9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%19 = tt.load %9 : tensor<4x256x!tt.ptr<bf16>> // this will be replaced with a memref.copy
%11 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
// Accumulates loads from the %arg1-based pointer; the pointer iter_arg is
// advanced by a splat of 3 every iteration and yielded back into the loop.
%sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr<bf16>>) {
%20 = tt.load %ptr_iter : tensor<4x256x!tt.ptr<bf16>>
%sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16>
// pointer updates
%17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32>
// offset: [3, 0], size = [4, 256], stride [0, 0]
%ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5]
scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr<bf16>>
}
%15 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
tt.store %16, %sum_out : tensor<4x256x!tt.ptr<bf16>>
tt.return
}
}

// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[CST_2_]], [[VAR_arg5_:%.+]] = [[CST_3_]], [[VAR_arg6_:%.+]] = [[CST_1024_i32]], [[VAR_arg7_:%.+]] = [[CST_1024_i32]]) -> (index, index, index, i32, i32) {
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg7_]]] {order = array<i32>} : <tensor<256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg6_]]] {order = array<i32>} : <tensor<256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: tt.store [[VAR_1_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_arg6_]], [[CST_3_i32]] : i32
// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index
// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_3_]] : index
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index
// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index
// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[VAR_9_]] : index to i32
// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_10_]] : i32
// CHECK: scf.yield [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_4_]], [[VAR_11_]] : index, index, index, i32, i32
// CHECK: }
// CHECK: tt.return
// CHECK: }
module {
// Checks that non-pointer iter_args (three `index` values here) interleaved
// with the two tensor-of-pointer iter_args survive the raising: per the CHECK
// lines above, the pointer iter_args are rewritten to scalar i32 offsets while
// the index iter_args are kept, so the loop still carries five values
// (index, index, index, i32, i32).
tt.func @test_addptr_for_more_init_args(
%arg0 : !tt.ptr<bf16>,
%arg1 : !tt.ptr<bf16>
)
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c12 = arith.constant 12 : index
%0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%3 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
%4 = tt.addptr %3, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// Load pointer advances by a constant 3; store pointer advances by a value
// computed from the three index iter_args, exercising mixed updates.
%_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index) {
%5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr<bf16>>
tt.store %ptr_st, %5 : tensor<256x!tt.ptr<bf16>>
%cast3 = arith.index_cast %c3 : index to i32
%6 = tt.splat %cast3 : i32 -> tensor<256xi32>
%ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%arg2_iter = arith.addi %arg2, %c3 : index
%arg3_iter = arith.addi %arg3, %c3 : index
%arg4_iter = arith.addi %arg4, %c3 : index
%7 = arith.addi %arg2_iter, %arg3_iter : index
%8 = arith.addi %7, %arg4_iter : index
%cast8 = arith.index_cast %8 : index to i32
%9 = tt.splat %cast8 : i32 -> tensor<256xi32>
%ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index
}
tt.return
}
}


// CHECK: tt.func @test_addptr_for_used_after_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) {
// CHECK: [[VAR_1_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_1_]]] {order = array<i32>} : <tensor<256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: scf.yield [[VAR_1_]] : i32
// CHECK: }
// CHECK: tt.return
// CHECK: }
module {
// Checks the case where the pointer iter_arg is advanced at the TOP of the
// loop body and the load/store use the *updated* pointer. Per the CHECK lines
// above, after raising, the i32 offset is incremented first and
// tt.make_tensor_ptr is then built from the incremented value.
tt.func @test_addptr_for_used_after_update(
%arg0 : !tt.ptr<bf16>
)
{
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
// Advance first, then load/store through the advanced pointer.
%4 = tt.splat %i_c3 : i32 -> tensor<256xi32>
%ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr<bf16>>
tt.store %ptr_iter, %3 : tensor<256x!tt.ptr<bf16>>
scf.yield %ptr_iter : tensor<256x!tt.ptr<bf16>>
}
tt.return
}
}



// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_i32]]) -> (i32) {

// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_arg2_]]] {order = array<i32>} : <tensor<256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr<tensor<256xbf16>>
// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_i32]] : i32
// CHECK: scf.yield [[VAR_3_]] : i32
// CHECK: }
// CHECK: tt.return
// CHECK: }
module {
// Mirror of @test_addptr_for_used_after_update: the pointer iter_arg is used
// by the load/store BEFORE being advanced. Per the CHECK lines above, after
// raising, tt.make_tensor_ptr is built from the incoming i32 offset and the
// increment happens afterwards, feeding the scf.yield.
tt.func @test_addptr_for_used_before_update(
%arg0 : !tt.ptr<bf16>
)
{
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
// Load/store through the incoming pointer, then advance it.
%3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr<bf16>>
tt.store %ptr, %3 : tensor<256x!tt.ptr<bf16>>
%4 = tt.splat %i_c3 : i32 -> tensor<256xi32>
%ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
scf.yield %ptr_iter : tensor<256x!tt.ptr<bf16>>
}
tt.return
}
}

// `triton::ExpandDims` ops on tensor of pointers are currently not supported in for loops.
// Consequently, the pass should fail cleanly.
// CHECK: tt.func @test_fail_addptr_for_expand_ptr([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
// CHECK-NOT: tt.make_tensor_ptr
module {
// Negative test: tt.expand_dims applied to a tensor of pointers inside the
// loop body is not supported by the raising pass (see the comment above the
// CHECK lines). The pass must fail cleanly, leaving the IR untouched — hence
// the CHECK-NOT on tt.make_tensor_ptr.
tt.func @test_fail_addptr_for_expand_ptr(
%arg0 : !tt.ptr<bf16>
)
{
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
%_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
%6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
%8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32>
%9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32>
%10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
%11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32>
%12 = arith.addi %8, %11 : tensor<256x256xi32>
// Unsupported: expand_dims on the pointer tensor iter_arg itself.
%13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr<bf16>> -> tensor<256x1x!tt.ptr<bf16>>
%14 = tt.broadcast %13 : tensor<256x1x!tt.ptr<bf16>> -> tensor<256x256x!tt.ptr<bf16>>
%15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr<bf16>>, tensor<256x256xi32>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr<bf16>>
tt.store %15, %16 : tensor<256x256x!tt.ptr<bf16>>
%17 = tt.splat %i_c3 : i32 -> tensor<256xi32>
%ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
scf.yield %ptr_iter : tensor<256x!tt.ptr<bf16>>
}
tt.return
}
}
Loading

0 comments on commit 77e0c53

Please sign in to comment.