From b2e937833e6c9dfa621023fdcfe6e54cb5086849 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 11 Sep 2024 04:59:37 +0000 Subject: [PATCH 1/7] Decompose large simdblockread to smaller simdblockreads Signed-off-by: Whitney Tsang --- .../lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index e3da9dc1a8..25c2ee0db7 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -282,8 +282,17 @@ class LoadStorePrefetchOpConversion if constexpr (std::is_same_v) { rewriter.restoreInsertionPoint(insertPoint); - TritonGEN::SIMDBlockReadOp simdRead = - rewriter.create(loc, v64i16Ty, base); + VectorType v8i16Ty = VectorType::get(8, i16_ty); + SmallVector values; + Value offset = i32_val(128); + for (int i = 0; i < 8; ++i) { + auto simdRead = + rewriter.create(loc, v8i16Ty, base); + values.push_back(simdRead.getRes()); + base = gep(ptrToSharedMemTy, i16_ty, base, offset); + } + auto simdRead = rewriter.create(loc, v64i16Ty, values); + VectorType v64Ty = VectorType::get(64, elemType); rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty)); From 362e4ba8efacd434e984e6af4fab31e058cf80ec Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 11 Sep 2024 16:19:19 +0000 Subject: [PATCH 2/7] address review comments Signed-off-by: Whitney Tsang --- .../intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp | 4 ++-- .../lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 696a5b4430..89bea2b8e1 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -1263,9 +1263,9 @@ static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) { funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" + std::to_string(ptrTy.getAddressSpace()) + - intel::getTypeMangling(vecTy.getElementType(), true /*isUnsigned*/); + intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true); if constexpr (isWrite) - funcName += intel::getTypeMangling(vecTy, true /*isUnsigned*/); + funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true); return funcName; } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 25c2ee0db7..eb71e500bc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -282,14 +282,17 @@ class LoadStorePrefetchOpConversion if constexpr (std::is_same_v) { rewriter.restoreInsertionPoint(insertPoint); - VectorType v8i16Ty = VectorType::get(8, i16_ty); + const unsigned maxOCLVectorSize = 8; + VectorType v8i16Ty = VectorType::get(maxOCLVectorSize, i16_ty); + auto mod = op->template getParentOfType(); + Value offset = + i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); SmallVector values; - Value offset = i32_val(128); - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < 64 / maxOCLVectorSize; ++i) { auto simdRead = rewriter.create(loc, v8i16Ty, base); values.push_back(simdRead.getRes()); - base = gep(ptrToSharedMemTy, i16_ty, base, offset); + base = gep(ptrToSharedMemTy, v8i16Ty, base, offset); } auto simdRead = rewriter.create(loc, v64i16Ty, values); From cf3c5c322178d46cbdd3e59c44833478b391fe09 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 11 Sep 2024 21:32:55 +0000 Subject: [PATCH 3/7] Update lit test Signed-off-by: Whitney Tsang --- ...tritongpu_to_llvm_intel_advanced_path.mlir | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir index 30eca7568c..bbd8d6c8a7 100644 --- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir +++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir @@ -231,7 +231,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> #dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}> module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { - // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16> + // CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t(!llvm.ptr<3>) -> vector<8xi16> // CHECK-LABEL: @slm_load tt.func public @slm_load(%arg0: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 @@ -239,7 +239,31 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war %c1_i64 = arith.constant 1 : i64 %c64_i64 = arith.constant 64 : i64 %ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 3> - // CHECK: {{.*}} = llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead({{.*}}) {{.*}} : (!llvm.ptr<3>) -> vector<64xi16> + // CHECK-COUNT-2: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[READ1:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[GLUE1:%.*]] = llvm.shufflevector [[READ1]], [[READ2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> + // CHECK: [[GLUE2:%.*]] = llvm.shufflevector [[READ3]], [[READ4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> + // CHECK: [[GLUE3:%.*]] = llvm.shufflevector [[READ5]], [[READ6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> + // CHECK: [[GLUE4:%.*]] = llvm.shufflevector [[READ7]], [[READ8]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> + // CHECK: [[GLUE5:%.*]] = llvm.shufflevector [[GLUE1]], [[GLUE2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xi16> + // CHECK: [[GLUE6:%.*]] = llvm.shufflevector [[GLUE3]], [[GLUE4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xi16> + // CHECK: [[READ:%.*]] = llvm.shufflevector [[GLUE5]], [[GLUE6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xi16> + // CHECK: llvm.bitcast [[READ]] : vector<64xi16> to vector<64xf16> %ld = tt.load %ptr {DotIdx = 0 : i32} : !tt.ptr, 3> tt.return } From ea9edbd1ba5801e3f104080c7281b9c414582ba3 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 12 Sep 2024 02:19:09 +0000 Subject: [PATCH 4/7] Restrict TritonGEN::SIMDBlockReadOp Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 8 ++++++++ test/TritonGEN/tritongen-to-llvm.mlir | 11 ----------- test/TritonGEN/tritongen.mlir | 4 ++-- .../Dialect/TritonGEN/IR/TritonGENOps.td | 2 ++ .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 17 +++++++++++++++++ .../lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp | 6 +----- 6 files changed, 30 insertions(+), 18 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index bebd25643e..e6d4455a66 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -428,3 +428,11 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } + +// ----- + +llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { + // expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}} + %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16> + llvm.return +} diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 469a764de2..89878a6169 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -278,17 +278,6 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b // ----- -// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]} - -llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { - // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) { - // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<64xi16> - %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16> - llvm.return -} - -// ----- - // CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]} llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 30f0ed05bd..a603992c7c 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -156,8 +156,8 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) { // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) { - // CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<64xi16> - triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<64xi16> + // CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16> + triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16> llvm.return } diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 66cde56ea1..16d2fd8f41 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -413,6 +413,8 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">, let assemblyFormat = [{ operands ` ` attr-dict `:` functional-type(operands, results) }]; + + let hasVerifier = 1; } def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">, diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index fb94262c60..351599efa6 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -435,3 +435,20 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// gen.simdblockread +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::SIMDBlockReadOp::verify() { + VectorType vecTy = getRes().getType(); + unsigned numElems = vecTy.getNumElements(); + IntegerType elemTy = cast(vecTy.getElementType()); + + // FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it. + if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 && + (elemTy.getWidth() != 8 || numElems != 16)) + return emitOpError("unsupported vector type"); + + return success(); +} diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 89bea2b8e1..5ad4095bbc 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -1280,11 +1280,7 @@ struct TritonSIMDBlockReadLowering LLVM::LLVMPointerType ptrTy = op.getPtr().getType(); VectorType vecTy = op.getRes().getType(); - // TODO: Remove GenISA lowering after PoC productization is completed. - std::string funcName = "llvm.genx.GenISA.simdBlockRead"; - if (isTySIMDOCLBuiltinAvailable(vecTy)) - funcName = getSIMDBlockManglingName(op, vecTy); - + std::string funcName = getSIMDBlockManglingName(op, vecTy); intel::AttributeList attrs = createFunctionAttributes( {{llvm::Attribute::NoUnwind, std::nullopt}, {llvm::Attribute::WillReturn, std::nullopt}, From 083679197c8e6b19fbd05bb8b2e579c7754128f6 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 12 Sep 2024 16:05:18 +0000 Subject: [PATCH 5/7] address review comments Signed-off-by: Whitney Tsang --- .../intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index eb71e500bc..ac131db8f0 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -282,13 +282,13 @@ class LoadStorePrefetchOpConversion if constexpr (std::is_same_v) { rewriter.restoreInsertionPoint(insertPoint); - const unsigned maxOCLVectorSize = 8; - VectorType v8i16Ty = VectorType::get(maxOCLVectorSize, i16_ty); + constexpr unsigned maxBlockLoadi16Width = 8; + VectorType v8i16Ty = VectorType::get(maxBlockLoadi16Width, i16_ty); auto mod = op->template getParentOfType(); Value offset = i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); SmallVector values; - for (int i = 0; i < 64 / maxOCLVectorSize; ++i) { + for (int i = 0; i < 64 / maxBlockLoadi16Width; ++i) { auto simdRead = rewriter.create(loc, v8i16Ty, base); values.push_back(simdRead.getRes()); From 0cf05a96124a6fb432d717afde5b110fd6d966fa Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 12 Sep 2024 22:56:55 +0000 Subject: [PATCH 6/7] minor changes Signed-off-by: Whitney Tsang --- .../tritongpu_to_llvm_intel_advanced_path.mlir | 15 +++++++-------- .../lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 9 +++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir index bbd8d6c8a7..99b2f5c235 100644 --- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir +++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir @@ -242,20 +242,19 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // CHECK-COUNT-2: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[READ1:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16> + // CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16> // CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> - // CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16> - // CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16> + // CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16> // CHECK: [[GLUE1:%.*]] = llvm.shufflevector [[READ1]], [[READ2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> // CHECK: [[GLUE2:%.*]] = llvm.shufflevector [[READ3]], [[READ4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> // CHECK: [[GLUE3:%.*]] = llvm.shufflevector [[READ5]], [[READ6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16> diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index ac131db8f0..bf14bc4d5c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -283,16 +283,17 @@ class LoadStorePrefetchOpConversion rewriter.restoreInsertionPoint(insertPoint); constexpr unsigned maxBlockLoadi16Width = 8; - VectorType v8i16Ty = VectorType::get(maxBlockLoadi16Width, i16_ty); + VectorType decomposedVecTy = + VectorType::get(maxBlockLoadi16Width, i16_ty); auto mod = op->template getParentOfType(); Value offset = i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); SmallVector values; for (int i = 0; i < 64 / maxBlockLoadi16Width; ++i) { - auto simdRead = - rewriter.create(loc, v8i16Ty, base); + auto simdRead = rewriter.create( + loc, decomposedVecTy, base); values.push_back(simdRead.getRes()); - base = gep(ptrToSharedMemTy, v8i16Ty, base, offset); + base = gep(ptrToSharedMemTy, decomposedVecTy, base, offset); } auto simdRead = rewriter.create(loc, v64i16Ty, values); From 2282a9bb39e3052c6a3f803e0c6070b05fb35a76 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 13 Sep 2024 15:00:32 +0000 Subject: [PATCH 7/7] fix build Signed-off-by: Whitney Tsang --- third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index ca1c327863..a650c0c811 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -301,7 +301,8 @@ class LoadStorePrefetchOpConversion values.push_back(simdRead.getRes()); base = gep(ptrToSharedMemTy, decomposedVecTy, base, offset); } - auto simdRead = rewriter.create(loc, v64i16Ty, values); + auto simdRead = + rewriter.create(loc, v64i16Ty, values); VectorType v64Ty = VectorType::get(64, elemType); rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty));