diff --git a/test/Conversion/intel/intel-allocate-shared-memory.mlir b/test/Conversion/intel/intel-allocate-shared-memory.mlir index 0aa7990417..5fad77531e 100644 --- a/test/Conversion/intel/intel-allocate-shared-memory.mlir +++ b/test/Conversion/intel/intel-allocate-shared-memory.mlir @@ -24,7 +24,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // Check scracth memory configuration for different sub-group transpose-like layout conversions. // CHECK-LABEL: module attributes -// CHECK-SAME: triton_gpu.shared = 512 : i32 +// CHECK-SAME: triton_gpu.shared = 544 : i32 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> { %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1> @@ -40,7 +40,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // Check scracth memory configuration for different sub-group transpose-like layout conversions. // CHECK-LABEL: module attributes -// CHECK-SAME: triton_gpu.shared = 1024 : i32 +// CHECK-SAME: triton_gpu.shared = 1088 : i32 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> { %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> @@ -56,7 +56,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // Check scracth memory configuration for different sub-group transpose-like layout conversions. // CHECK-LABEL: module attributes -// CHECK-SAME: triton_gpu.shared = 32768 : i32 +// CHECK-SAME: triton_gpu.shared = 34816 : i32 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { tt.func @test_f32(%arg0: tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked1> { %0 = triton_gpu.convert_layout %arg0 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index 9387c7dda9..b4a9b242a7 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -10,21 +10,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> { // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f16 to i16 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 - // 
CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to f16 %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1> tt.return %0 : tensor<16x16xf16, #blocked1> @@ -34,21 +38,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_bf16(%arg0: tensor<16x16xbf16, #blocked>) -> tensor<16x16xbf16, #blocked1> { // CHECK-COUNT-16: llvm.bitcast %{{.*}} : bf16 to i16 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 - // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, 
i16 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16 %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xbf16, #blocked> -> tensor<16x16xbf16, #blocked1> tt.return %0 : tensor<16x16xbf16, #blocked1> @@ -58,24 +66,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> { // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> 
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> tt.return %0 : tensor<16x16xf32, #blocked1> @@ -84,21 +93,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_i8( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i8(%arg0: tensor<16x16xi8, #blocked>) -> tensor<16x16xi8, #blocked1> { - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 - // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to 
i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi8, #blocked> -> tensor<16x16xi8, #blocked1> tt.return %0 : tensor<16x16xi8, #blocked1> } @@ -106,24 +119,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_i64( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i64(%arg0: tensor<16x16xi64, #blocked>) -> tensor<16x16xi64, #blocked1> { - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // 
CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi64, #blocked> -> tensor<16x16xi64, #blocked1> tt.return %0 : tensor<16x16xi64, #blocked1> } @@ -132,24 +146,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_ptr(%arg0: tensor<16x16x!tt.ptr, #blocked>) -> tensor<16x16x!tt.ptr, #blocked1> { // CHECK-COUNT-16: llvm.ptrtoint %{{.*}} : !llvm.ptr<1> to i64 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = 
llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> // CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked1> tt.return %0 : tensor<16x16x!tt.ptr, #blocked1> @@ -159,21 +174,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i1(%arg0: tensor<16x16xi1, #blocked>) -> tensor<16x16xi1, #blocked1> { // CHECK-COUNT-16: llvm.zext %{{.*}} : i1 to i8 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 - // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh(%[[VAL_42]] + // COM: 
Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> // CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1 %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi1, #blocked> -> tensor<16x16xi1, #blocked1> tt.return %0 : tensor<16x16xi1, #blocked1> @@ -191,7 +210,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1> tt.return %0 : tensor<32x16xf32, #blocked1> } @@ -208,7 +247,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] 
= llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1> tt.return %0 : tensor<16x32xf32, #blocked1> } @@ -225,7 +284,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, 
#blocked> -> tensor<64x64xf32, #blocked1> tt.return %0 : tensor<64x64xf32, #blocked1> } @@ -242,7 +321,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64x1xf32, #blocked>) -> tensor<64x64x1xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<64x64x1xf32, #blocked> -> tensor<64x64x1xf32, #blocked1> tt.return %0 : tensor<64x64x1xf32, #blocked1> } @@ -258,7 +357,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<64x64xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc 
@_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<64x64xf32, #blocked1> tt.return %0 : tensor<64x64xf32, #blocked1> } @@ -275,7 +394,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<16x16x1xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<16x16x1xf32, #blocked1> tt.return %0 : tensor<16x16x1xf32, #blocked1> } @@ -292,7 +431,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> 
tensor<64x16x4xf32, #blocked1> { - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 15 more stores: + // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<64x16x4xf32, #blocked1> tt.return %0 : tensor<64x16x4xf32, #blocked1> } @@ -310,31 +469,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : 
(!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
- // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: The offset is double the previous one, as there are twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
 // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
 %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
 tt.return %0 : tensor<32x16xf32, #blocked1>
@@ -353,31 +507,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
 tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
 // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
- // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
- // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
- // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
- // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
- // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
- // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
- // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
- // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
- // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
- // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
- // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
- // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
- // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: The offset is double the previous one, as there are twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 31 more stores: + // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1> tt.return %0 : tensor<16x32xf32, #blocked1> @@ -396,31 +545,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> { // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]] - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds 
%[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: The offset is double the previous one, as there are twice as many rows.
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[17]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
 // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
 %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
 tt.return %0 : tensor<32x64xf32, #blocked1>
@@ -443,3 +587,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 tt.return %0, %1 : tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>
 }
}
+
+// -----
+
+// Test transposition with sub-group size 32.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+ // CHECK-LABEL: llvm.func spir_kernelcc @test(
+ // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+ tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> {
+ // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+ // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+ // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64
+ // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+ // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64
+ // COM: The offset changes with the increased number of columns:
+ // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(1056 : i64) : i64
+ // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64
+ // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]]
+ // COM: Check offset:
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[33]
+ // COM: Check there are 31 more stores:
+ // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(
+ // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(1056 : i64) : i64
+ // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64
+ // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+ // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: llvm.load %[[VAL_62]] : !llvm.ptr<3> -> vector<16xi32>
+ // CHECK: %[[VAL_63:.*]] = llvm.getelementptr inbounds %[[VAL_62]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+ // CHECK: %[[VAL_64:.*]] = llvm.getelementptr inbounds %[[VAL_63]][1024] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i32
+ // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+ %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+ tt.return %0 : tensor<64x64xf32, #blocked1>
+ }
+}
diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp
index 70782aaa36..266615a4fa 100644
--- a/third_party/intel/lib/Analysis/Allocation.cpp
+++ b/third_party/intel/lib/Analysis/Allocation.cpp
@@ -5,6 +5,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "intel/include/Analysis/Utility.h"
+#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 namespace mlir::triton::intel {
 namespace {
@@ -22,7 +23,16 @@ unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) {
 isa<PointerType>(elemTy) ?
 kPtrBitWidth / 8 : std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
- return product(srcTy.getShape()) * bytesPerElement;
+ unsigned numElements = product(srcTy.getShape());
+ Attribute encoding = srcTy.getEncoding();
+ int subGroupSize = product(gpu::getThreadsPerWarp(encoding));
+ assert(numElements % subGroupSize == 0 &&
+ "Sub-group transposable tensors have a number of elements that is a "
+ "multiple of the sub-group size");
+ // Add an element at the end of each row that will not be accessed. This
+ // allows us to avoid bank conflicts.
+ unsigned numMatrixCells = (numElements / subGroupSize) * (subGroupSize + 1);
+ return numMatrixCells * bytesPerElement;
 }
 return invalidSize;
}
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index c20258b263..72d5f7e291 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -726,23 +726,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 rewriter.replaceOp(op, result);
 }
- VectorType
- getTypeForSubGroupTranspose(ArrayRef<Value> inVals,
- ConversionPatternRewriter &rewriter) const {
- auto elementTy = cast<IntegerType>(inVals.front().getType());
- return elementTy.getWidth() <= 16 ? vec_ty(elementTy, 16)
- : vec_ty(elementTy, 8);
- }
-
- Value wrapInVector(Location loc, VectorType type, ArrayRef<Value> values,
- ConversionPatternRewriter &rewriter) const {
- assert(type.getShape()[0] == values.size() && "Size mismatch");
- Value res = rewriter.create<LLVM::UndefOp>(loc, type);
- for (auto [index, val] : llvm::enumerate(values))
- res = insert_element(res, val, i32_val(index));
- return res;
- }
-
 SmallVector<Value>
 unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
 ConversionPatternRewriter &rewriter) const {
@@ -755,21 +738,30 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 return res;
 }
+ static unsigned getVecLoadWidth(unsigned threadsPerWarp) {
+ assert(llvm::isPowerOf2_32(threadsPerWarp) &&
+ "Expecting power of 2 sub-group size");
+ constexpr unsigned maxVecWidth = 16;
+ return std::min(maxVecWidth, threadsPerWarp);
+ }
+
 SmallVector<Value>
 performSubGroupTranspose(Location loc, ArrayRef<Value> inVals,
 ConversionPatternRewriter &rewriter) const {
- VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter);
+ Type elementType = inVals.front().getType();
 auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
- unsigned vecWidth = opType.getShape()[0];
 Value smemBase = LLVM::intel::getSharedMemoryBase(
 loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
 Type ptrType = smemBase.getType();
- int numElements = inVals.size();
+ int numRows = inVals.size();
 int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
- int offset = threadsPerWarp;
+ // Add an unused element at the end of each row to avoid bank
+ // conflicts.
+ int rowLength = threadsPerWarp + 1;
 Type offsetType = getTypeConverter()->getIndexType();
+ unsigned offsetBitWidth = offsetType.getIntOrFloatBitWidth();
 Value subGroupId = getValueOrCreateCastToIndexLike(
 rewriter, loc, offsetType,
 rewriter.create<mlir::gpu::SubgroupIdOp>(
@@ -778,38 +770,48 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 rewriter, loc, offsetType,
 rewriter.create<mlir::gpu::LaneIdOp>(loc, /*upper_bound=*/IntegerAttr{}));
- Value wiStride =
- rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
- Value sgStride = rewriter.create<LLVM::ConstantOp>(
- loc, offsetType, threadsPerWarp * numElements);
- Value subGroupOffset = mul(sgStride, subGroupId);
- Type elementType = opType.getElementType();
+ Value subGroupOffset =
+ mul(subGroupId, int_val(offsetBitWidth, rowLength * numRows));
 Value subGroupBasePtr =
 gep(ptrType, elementType, smemBase, ValueRange{subGroupOffset},
 /*inbounds=*/true);
 Value base = subGroupBasePtr;
 // Store in matrix, transposed
- for (ArrayRef<Value> vals = inVals; !vals.empty();
- vals = vals.drop_front(vecWidth)) {
- ArrayRef<Value> curr = vals.take_front(vecWidth);
- Value vec = wrapInVector(loc, opType, curr, rewriter);
- rewriter.create<TritonGEN::SubGroupBlockWriteOp>(loc, base, vec);
- base = gep(base.getType(), opType, base, ArrayRef<LLVM::GEPArg>{offset},
+ for (Value val : inVals) {
+ rewriter.create<TritonGEN::SubGroupBlockWriteOp>(loc, base, val);
+ base = gep(base.getType(), elementType, base,
+ ArrayRef<LLVM::GEPArg>{rowLength},
 /*inbounds=*/true);
 }
 // Load from matrix, non-trasposed.
- // As per sub-group block semantics, we have stored the elements in a matrix
- // of `Nxsub_group_size` size, so we need to load back in blocks of
- // `sub_group_size` (`N/sub_group_size` loads).
- Value workItemOffset = mul(wiStride, subGroupLocalId);
+
+ // Each work-item will load a row (except its trailing padding element) and
+ // then advance to the next row it needs to handle.
+ int32_t workItemStride = rowLength * threadsPerWarp;
+ Value workItemOffset =
+ mul(subGroupLocalId, int_val(offsetBitWidth, workItemStride));
 Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
 ValueRange{workItemOffset}, /*inbounds=*/true);
+ int32_t rowsPerThread = numRows / threadsPerWarp;
+ // We may not be able to load rows in a single operation if the sub-group
+ // size exceeds a given threshold (16):
+ unsigned vecLoadWidth = getVecLoadWidth(threadsPerWarp);
 SmallVector<Value> transposedVecs;
- Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
- for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
- transposedVecs.push_back(load(loadTy, workItemBasePtr));
- workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
- ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
+ VectorType vecType = vec_ty(elementType, vecLoadWidth);
+ assert(threadsPerWarp % vecLoadWidth == 0 &&
+ "Column must be loadable with N loads");
+ for (unsigned i = 0; i < rowsPerThread; ++i) {
+ for (unsigned j = 0; j < threadsPerWarp; j += vecLoadWidth) {
+ transposedVecs.push_back(load(vecType, workItemBasePtr));
+ workItemBasePtr = gep(workItemBasePtr.getType(), vecType,
+ workItemBasePtr, ArrayRef<LLVM::GEPArg>{1},
+ /*inbounds=*/true);
+ }
+ workItemBasePtr =
+ gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
+ // "Go back" to the first column and increment by the stride.
+ ArrayRef<LLVM::GEPArg>{workItemStride - threadsPerWarp},
+ /*inbounds=*/true);
+ }
 return unwrapFromVectors(loc, transposedVecs, rewriter);
 }
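The new shared-memory sizes asserted above can be sanity-checked in isolation. The sketch below is not part of the patch; the helper name `scratchBytes` and the values in `main` are illustrative only. It mirrors the arithmetic of `allocationAnalysisScratchSizeFn`: each sub-group stores `numElements / subGroupSize` rows of `subGroupSize` elements, and every row is padded with one extra, never-accessed element.

```cpp
// Minimal standalone sketch, assuming a sub-group-transposable tensor
// (enforced by the assert, as in the patch).
#include <cassert>
#include <cstdio>

unsigned scratchBytes(unsigned numElements, unsigned subGroupSize,
                      unsigned bytesPerElement) {
  assert(numElements % subGroupSize == 0 &&
         "number of elements must be a multiple of the sub-group size");
  unsigned numRows = numElements / subGroupSize;
  // One padding element per row. Assuming a power-of-two number of
  // shared-memory banks B (here B == subGroupSize == 16), unpadded rows of
  // length 16 would all start in bank 0, so the strided per-lane row loads
  // would conflict. With rowLength == 17, row r starts in bank
  // (17 * r) % 16 == r % 16, i.e. consecutive rows start in distinct banks.
  unsigned numMatrixCells = numRows * (subGroupSize + 1);
  return numMatrixCells * bytesPerElement;
}

int main() {
  // Values match the updated triton_gpu.shared CHECKs in
  // intel-allocate-shared-memory.mlir:
  std::printf("%u\n", scratchBytes(16 * 16, 16, 2));  // 544   (16x16 f16)
  std::printf("%u\n", scratchBytes(16 * 16, 16, 4));  // 1088  (16x16 f32)
  std::printf("%u\n", scratchBytes(128 * 64, 16, 4)); // 34816 (128x64 f32)
  return 0;
}
```

The same per-row padding shows up in the lowering tests: `rowLength * numRows` elements per sub-group yields the `272` (16 rows of 17 elements) and `1056` (32 rows of 33 elements) `llvm.mlir.constant` values asserted in `sub-group-transpose.mlir`.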