Commit

fix unit test
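
In short, judging from the diff: the boolean transpose flag on the Intel dot_op layout is renamed from transpose to isTransposed, and the CHECK-LABEL directive moves inside the module, directly above the function it anchors. The rename, as it appears in every hunk below:

    // old spelling (removed lines)
    #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>
    // new spelling (added lines)
    #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>

The CHECK-COUNT directives encode the tiling the pass under test performs: a 16x64 x 64x64 tt.dot lowered to the 8x16x16 target shape yields (16/8) * (64/16) * (64/16) = 32 sub-dots, and reading the 64x64 B operand as 16x16 tiles yields 16 loads, matching the CHECK-COUNT-32 and CHECK-COUNT-16 lines in the hunks.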
Dewei-Wang-sh committed Jul 12, 2024
1 parent f463208 commit b7f500b
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -255,8 +255,9 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.d
// -----

// COM: Test Attention Related Ops
-// CHECK-LABEL: @attn_fwd
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
+// CHECK-LABEL: @attn_fwd
tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
@@ -290,7 +290,7 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
%15 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%17 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
-%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
+%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
%19 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
%20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
%21 = arith.mulf %arg3, %cst : f32
@@ -300,11 +301,11 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
%23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp>
// CHECK: %30 = scf.for
-%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>) : i32 {
+%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>) : i32 {
// CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
// CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
-%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
-%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>> -> tensor<16x64xf32, #warp>
+%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
+%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>> -> tensor<16x64xf32, #warp>

// CHECK-COUNT-2: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32>
@@ -348,12 +349,13 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
// CHECK: scf.yield
%50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
%51 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
-%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
-scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
+%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
+scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
} {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
%26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp>
tt.store %20, %28 : !tt.ptr<tensor<16x64xf32, #warp>>
tt.return
}
}
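
To verify the fix locally, one would run this lit test; a sketch, assuming a built triton-opt with the Intel backend (the authoritative RUN line sits at the top of match-target-size.mlir, outside the hunks shown, so the pass flag below is an assumption):

    # Run just this file through the lit runner:
    llvm-lit -v test/TritonIntelGPU/match-target-size.mlir

    # Or pipe the pass output straight into FileCheck. -split-input-file is
    # needed because the file uses // ----- separators between test cases.
    triton-opt -split-input-file -tritonintelgpu-match-target-size \
        test/TritonIntelGPU/match-target-size.mlir \
      | FileCheck test/TritonIntelGPU/match-target-size.mlir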
