From a1c57256ad534cc8e045705b080c2f51706ef4ec Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Wed, 20 Nov 2024 10:26:44 +0000
Subject: [PATCH] [XPU][TritonGPUToLLVM] Avoid bank conflicts in sub-group
 transposes

- Store the whole matrix using SIMD block stores for each row, leaving a
  single garbage item at the end of the row so that each row has
  `sub_group_size + 1` elements
- Load each row with vector loads

By introducing this garbage item at the end of each row, we ensure matrix
loads avoid bank conflicts, as the offset between the positions loaded by
work-items `i` and `i+j` is `j * (sub_group_size + 1)` (assuming
`sub_group_size` banks).

Signed-off-by: victor-eds
---
 .../intel/intel-allocate-shared-memory.mlir   |   6 +-
 .../Conversion/intel/sub-group-transpose.mlir | 578 ++++++++++++------
 third_party/intel/lib/Analysis/Allocation.cpp |  12 +-
 .../ConvertLayoutOpToLLVM.cpp                 |  86 +--
 4 files changed, 440 insertions(+), 242 deletions(-)

diff --git a/test/Conversion/intel/intel-allocate-shared-memory.mlir b/test/Conversion/intel/intel-allocate-shared-memory.mlir
index 0aa7990417..5fad77531e 100644
--- a/test/Conversion/intel/intel-allocate-shared-memory.mlir
+++ b/test/Conversion/intel/intel-allocate-shared-memory.mlir
@@ -24,7 +24,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// Check scratch memory configuration for different sub-group transpose-like layout conversions.
// CHECK-LABEL: module attributes
-// CHECK-SAME: triton_gpu.shared = 512 : i32
+// CHECK-SAME: triton_gpu.shared = 544 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
@@ -40,7 +40,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// Check scratch memory configuration for different sub-group transpose-like layout conversions.
// CHECK-LABEL: module attributes
-// CHECK-SAME: triton_gpu.shared = 1024 : i32
+// CHECK-SAME: triton_gpu.shared = 1088 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
@@ -56,7 +56,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// Check scratch memory configuration for different sub-group transpose-like layout conversions.
// CHECK-LABEL: module attributes
-// CHECK-SAME: triton_gpu.shared = 32768 : i32
+// CHECK-SAME: triton_gpu.shared = 34816 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f32(%arg0: tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1>
diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir
index 9387c7dda9..b4a9b242a7 100644
--- a/test/Conversion/intel/sub-group-transpose.mlir
+++ b/test/Conversion/intel/sub-group-transpose.mlir
@@ -10,21 +10,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : f16 to i16
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
- // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16>
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to f16
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
tt.return %0 : tensor<16x16xf16, #blocked1>
@@ -34,21 +38,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_bf16(%arg0: tensor<16x16xbf16, #blocked>) -> tensor<16x16xbf16, #blocked1> {
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : bf16 to i16
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
- // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16>
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xbf16, #blocked> -> tensor<16x16xbf16, #blocked1>
tt.return %0 : tensor<16x16xbf16, #blocked1>
@@ -58,24 +66,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
tt.return %0 : tensor<16x16xf32, #blocked1>
@@ -84,21 +93,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: llvm.func spir_kernelcc @test_i8(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_i8(%arg0: tensor<16x16xi8, #blocked>) -> tensor<16x16xi8, #blocked1> {
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
- // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi8, #blocked> -> tensor<16x16xi8, #blocked1>
tt.return %0 : tensor<16x16xi8, #blocked1>
}
@@ -106,24 +119,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: llvm.func spir_kernelcc @test_i64(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_i64(%arg0: tensor<16x16xi64, #blocked>) -> tensor<16x16xi64, #blocked1> {
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi64, #blocked> -> tensor<16x16xi64, #blocked1>
tt.return %0 : tensor<16x16xi64, #blocked1>
}
@@ -132,24 +146,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_ptr(%arg0: tensor<16x16x!tt.ptr, #blocked>) -> tensor<16x16x!tt.ptr, #blocked1> {
// CHECK-COUNT-16: llvm.ptrtoint %{{.*}} : !llvm.ptr<1> to i64
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64>
// CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked1>
tt.return %0 : tensor<16x16x!tt.ptr, #blocked1>
@@ -159,21 +174,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test_i1(%arg0: tensor<16x16xi1, #blocked>) -> tensor<16x16xi1, #blocked1> {
// CHECK-COUNT-16: llvm.zext %{{.*}} : i1 to i8
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
- // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8>
// CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi1, #blocked> -> tensor<16x16xi1, #blocked1>
tt.return %0 : tensor<16x16xi1, #blocked1>
@@ -191,7 +210,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
tt.return %0 : tensor<32x16xf32, #blocked1>
}
@@ -208,7 +247,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
tt.return %0 : tensor<16x32xf32, #blocked1>
}
@@ -225,7 +284,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
tt.return %0 : tensor<64x64xf32, #blocked1>
}
@@ -242,7 +321,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<64x64x1xf32, #blocked>) -> tensor<64x64x1xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64x1xf32, #blocked> -> tensor<64x64x1xf32, #blocked1>
tt.return %0 : tensor<64x64x1xf32, #blocked1>
}
@@ -258,7 +357,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<64x64xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<64x64xf32, #blocked1>
tt.return %0 : tensor<64x64xf32, #blocked1>
}
@@ -275,7 +394,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<16x16x1xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<16x16x1xf32, #blocked1>
tt.return %0 : tensor<16x16x1xf32, #blocked1>
}
@@ -292,7 +431,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
tt.func @test(%arg0: tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<64x16x4xf32, #blocked1> {
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 15 more stores:
+ // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<64x16x4xf32, #blocked1>
tt.return %0 : tensor<64x16x4xf32, #blocked1>
}
@@ -310,31 +469,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
- // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: Offset is twice as large as before, as we have twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
tt.return %0 : tensor<32x16xf32, #blocked1>
@@ -353,31 +507,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
- // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: Offset is twice as large as before, as we have twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
tt.return %0 : tensor<16x32xf32, #blocked1>
@@ -396,31 +545,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
- // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: Offset is twice as large as before, as we have twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
tt.return %0 : tensor<32x64xf32, #blocked1>
@@ -443,3 +587,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return %0, %1 : tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>
}
}
+
+// -----
+
+// Test transposition with sub-group size 32.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+ // CHECK-LABEL: llvm.func spir_kernelcc @test(
+ // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+ tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> {
+ // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: Offset changes with increased number of columns:
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(1056 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[33]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(33 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: llvm.load %[[VAL_62]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK: %[[VAL_63:.*]] = llvm.getelementptr inbounds %[[VAL_62]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_64:.*]] = llvm.getelementptr inbounds %[[VAL_63]][1024] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i32
+ // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+ %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+ tt.return %0 : tensor<64x64xf32, #blocked1>
+ }
+}
diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp
index 70782aaa36..266615a4fa 100644
--- a/third_party/intel/lib/Analysis/Allocation.cpp
+++ b/third_party/intel/lib/Analysis/Allocation.cpp
@@ -5,6 +5,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "intel/include/Analysis/Utility.h"
+#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 namespace mlir::triton::intel {
 namespace {
@@ -22,7 +23,16 @@ unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) {
     isa<PointerType>(elemTy) ?
            kPtrBitWidth / 8 : std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-    return product(srcTy.getShape()) * bytesPerElement;
+    unsigned numElements = product(srcTy.getShape());
+    Attribute encoding = srcTy.getEncoding();
+    int subGroupSize = product(gpu::getThreadsPerWarp(encoding));
+    assert(numElements % subGroupSize == 0 &&
+           "Sub-group transposable tensors have a number of elements multiple "
+           "of the sub-group size");
+    // Add an element at the end of the row that will not be accessed. This
+    // allows us to avoid bank conflicts.
+    unsigned numMatrixCells = (numElements / subGroupSize) * (subGroupSize + 1);
+    return numMatrixCells * bytesPerElement;
   }
   return invalidSize;
 }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index c20258b263..72d5f7e291 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -726,23 +726,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     rewriter.replaceOp(op, result);
   }
 
-  VectorType
-  getTypeForSubGroupTranspose(ArrayRef<Value> inVals,
-                              ConversionPatternRewriter &rewriter) const {
-    auto elementTy = cast<IntegerType>(inVals.front().getType());
-    return elementTy.getWidth() <= 16 ? vec_ty(elementTy, 16)
-                                      : vec_ty(elementTy, 8);
-  }
-
-  Value wrapInVector(Location loc, VectorType type, ArrayRef<Value> values,
-                     ConversionPatternRewriter &rewriter) const {
-    assert(type.getShape()[0] == values.size() && "Size mismatch");
-    Value res = rewriter.create<LLVM::UndefOp>(loc, type);
-    for (auto [index, val] : llvm::enumerate(values))
-      res = insert_element(res, val, i32_val(index));
-    return res;
-  }
-
   SmallVector<Value>
   unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
                     ConversionPatternRewriter &rewriter) const {
@@ -755,21 +738,30 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return res;
   }
 
+  static unsigned getVecLoadWidth(unsigned threadsPerWarp) {
+    assert(llvm::isPowerOf2_32(threadsPerWarp) &&
+           "Expecting power of 2 sub-group size");
+    constexpr unsigned maxVecWidth = 16;
+    return std::min(maxVecWidth, threadsPerWarp);
+  }
+
   SmallVector<Value>
   performSubGroupTranspose(Location loc, ArrayRef<Value> inVals,
                            ConversionPatternRewriter &rewriter) const {
-    VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter);
+    Type elementType = inVals.front().getType();
     auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
-    unsigned vecWidth = opType.getShape()[0];
     Value smemBase = LLVM::intel::getSharedMemoryBase(
         loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
     Type ptrType = smemBase.getType();
-    int numElements = inVals.size();
+    int numRows = inVals.size();
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    int offset = threadsPerWarp;
+    // Add an element that won't be accessed at the end of the row to avoid
+    // bank conflicts.
+    int rowLength = threadsPerWarp + 1;
     Type offsetType = getTypeConverter()->getIndexType();
+    unsigned offsetBitWidth = offsetType.getIntOrFloatBitWidth();
     Value subGroupId = getValueOrCreateCastToIndexLike(
         rewriter, loc, offsetType,
         rewriter.create<mlir::gpu::SubgroupIdOp>(
@@ -778,38 +770,48 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         rewriter, loc, offsetType,
         rewriter.create<mlir::gpu::LaneIdOp>(loc, /*upper_bound=*/IntegerAttr{}));
-    Value wiStride =
-        rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
-    Value sgStride = rewriter.create<LLVM::ConstantOp>(
-        loc, offsetType, threadsPerWarp * numElements);
-    Value subGroupOffset = mul(sgStride, subGroupId);
-    Type elementType = opType.getElementType();
+    Value subGroupOffset =
+        mul(subGroupId, int_val(offsetBitWidth, rowLength * numRows));
     Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
                                 ValueRange{subGroupOffset}, /*inbounds=*/true);
     Value base = subGroupBasePtr;
     // Store in matrix, transposed.
-    for (ArrayRef<Value> vals = inVals; !vals.empty();
-         vals = vals.drop_front(vecWidth)) {
-      ArrayRef<Value> curr = vals.take_front(vecWidth);
-      Value vec = wrapInVector(loc, opType, curr, rewriter);
-      rewriter.create<TritonGEN::SubGroupBlockWriteOp>(loc, base, vec);
-      base = gep(base.getType(), opType, base, ArrayRef<LLVM::GEPArg>{offset},
+    for (Value val : inVals) {
+      rewriter.create<TritonGEN::SubGroupBlockWriteOp>(loc, base, val);
+      base = gep(base.getType(), elementType, base,
+                 ArrayRef<LLVM::GEPArg>{rowLength},
                  /*inbounds=*/true);
     }
     // Load from matrix, non-transposed.
-    // As per sub-group block semantics, we have stored the elements in a matrix
-    // of `Nxsub_group_size` size, so we need to load back in blocks of
-    // `sub_group_size` (`N/sub_group_size` loads).
-    Value workItemOffset = mul(wiStride, subGroupLocalId);
+
+    // Each work-item will load a row (except for the trailing garbage
+    // element) and go to the next row it needs to handle.
+    int32_t workItemStride = rowLength * threadsPerWarp;
+    Value workItemOffset =
+        mul(subGroupLocalId, int_val(offsetBitWidth, workItemStride));
     Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
                                 ValueRange{workItemOffset}, /*inbounds=*/true);
+    int32_t rowsPerThread = numRows / threadsPerWarp;
+    // We may not be able to load a whole row in a single operation if the
+    // sub-group size exceeds a given threshold (16):
+    unsigned vecLoadWidth = getVecLoadWidth(threadsPerWarp);
     SmallVector<Value> transposedVecs;
-    Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
-    for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
-      transposedVecs.push_back(load(loadTy, workItemBasePtr));
-      workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
-                            ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
+    VectorType vecType = vec_ty(elementType, vecLoadWidth);
+    assert(threadsPerWarp % vecLoadWidth == 0 &&
+           "Column must be loadable with N loads");
+    for (unsigned i = 0; i < rowsPerThread; ++i) {
+      for (unsigned j = 0; j < threadsPerWarp; j += vecLoadWidth) {
+        transposedVecs.push_back(load(vecType, workItemBasePtr));
+        workItemBasePtr = gep(workItemBasePtr.getType(), vecType,
+                              workItemBasePtr, ArrayRef<LLVM::GEPArg>{1},
+                              /*inbounds=*/true);
+      }
+      workItemBasePtr =
+          gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
+              // "Go back" to the first column and increment by the stride.
+              ArrayRef<LLVM::GEPArg>{workItemStride - threadsPerWarp},
+              /*inbounds=*/true);
     }
     return unwrapFromVectors(loc, transposedVecs, rewriter);
   }
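
As a sanity check on the new scratch-size formula, the following standalone
sketch (a hypothetical `numMatrixCells` helper, not part of the patch) mirrors
the padded-row computation from the `Allocation.cpp` hunk above and reproduces
the per-sub-group offsets that appear as `llvm.mlir.constant` values in the
lit tests (544 for 32 rows with sub-group size 16, 1056 for 32 rows with
sub-group size 32):

    #include <cassert>
    #include <cstdio>

    // Rows of `subGroupSize` elements are padded with one extra, never-read
    // cell, so `numElements` elements occupy
    // (numElements / subGroupSize) * (subGroupSize + 1) matrix cells.
    static unsigned numMatrixCells(unsigned numElements, unsigned subGroupSize) {
      assert(numElements % subGroupSize == 0 &&
             "tensor must be a multiple of the sub-group size");
      return (numElements / subGroupSize) * (subGroupSize + 1);
    }

    int main() {
      std::printf("%u\n", numMatrixCells(32 * 16, 16)); // 544
      std::printf("%u\n", numMatrixCells(32 * 32, 32)); // 1056
      return 0;
    }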
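Why the extra cell removes bank conflicts can be checked with plain modular
arithmetic. Assuming `sub_group_size` single-element banks selected by
`index % sub_group_size` (a simplification of the real hardware), a row pitch
equal to the bank count maps the start of every row to bank 0, while a pitch
of `sub_group_size + 1` maps row `i` to bank `i`:

    #include <cstdio>

    // Print the bank holding each row's first element for a given row pitch,
    // modelling `banks` single-element banks selected by (index % banks).
    static void printRowBanks(unsigned banks, unsigned pitch) {
      std::printf("pitch %2u:", pitch);
      for (unsigned row = 0; row < banks; ++row)
        std::printf(" %u", (row * pitch) % banks);
      std::printf("\n");
    }

    int main() {
      printRowBanks(16, 16); // unpadded: all rows start in bank 0 (conflicts)
      printRowBanks(16, 17); // padded: rows start in banks 0, 1, 2, ...
      return 0;
    }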
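The sub-group-size-32 test also exercises the `getVecLoadWidth` cap: rows now
hold 32 payload elements, but vector loads are limited to 16 elements, so each
row is read with two `vector<16xi32>` loads before the pointer advances by
`workItemStride - threadsPerWarp` elements to the next row. A sketch of that
address walk, with plain integers standing in for the generated GEPs (the
constants are meant to match the `[1]` and `[1024]` offsets in the test above):

    #include <cstdio>

    int main() {
      const unsigned threadsPerWarp = 32, numRows = 32;
      const unsigned rowLength = threadsPerWarp + 1;              // 33
      const unsigned vecLoadWidth = threadsPerWarp < 16u ? threadsPerWarp : 16u;
      const unsigned workItemStride = rowLength * threadsPerWarp; // 1056
      unsigned offset = 0; // element offset from the work-item's base pointer
      for (unsigned i = 0; i < numRows / threadsPerWarp; ++i) {
        for (unsigned j = 0; j < threadsPerWarp; j += vecLoadWidth) {
          std::printf("load %u elements at +%u\n", vecLoadWidth, offset);
          offset += vecLoadWidth;                  // gep [1] over vector<16xi32>
        }
        offset += workItemStride - threadsPerWarp; // gep [1024] over i32
      }
      return 0;
    }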