From eed3d95e81bd39a5bcf32d1b4fe47a76d53a5f70 Mon Sep 17 00:00:00 2001 From: "Wang, Quintin" Date: Wed, 11 Sep 2024 02:38:52 +0000 Subject: [PATCH] fix FlashAttention SLM path --- .../lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp | 11 ++++++++--- .../lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 6 ++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp index b03c612dd7..1c48ca8da5 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp @@ -49,9 +49,14 @@ struct AllocateSharedMemory IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); }); - mod->setAttr("triton_gpu.shared", - IntegerAttr::get(IntegerType::get(ctx, 32), - allocation.getSharedMemorySize())); + int32_t originSharedSize = 0; + if (IntegerAttr sharedAttr = + mod->getAttrOfType("triton_gpu.shared")) + originSharedSize = sharedAttr.getInt(); + mod->setAttr( + "triton_gpu.shared", + IntegerAttr::get(IntegerType::get(ctx, 32), + originSharedSize + allocation.getSharedMemorySize())); } }; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 7696abd7cf..c909a3549f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -258,6 +258,8 @@ class LoadStorePrefetchOpConversion assert(ptrType.getAddressSpace() == TritonGEN::TritonGENMemorySpace::kWorkgroup && "expecting local space"); + auto elemType = + cast(ptrType.getPointeeType()).getElementType(); MLIRContext *ctx = rewriter.getContext(); Location loc = op.getLoc(); @@ -278,13 +280,13 @@ class LoadStorePrefetchOpConversion base = gep(ptrToSharedMemTy, i16_ty, base, index); if constexpr (std::is_same_v) { - VectorType v64f16Ty = VectorType::get(64, f16_ty); + VectorType v64Ty = VectorType::get(64, elemType); rewriter.restoreInsertionPoint(insertPoint); TritonGEN::SIMDBlockReadOp simdRead = rewriter.create(loc, v64i16Ty, base); - rewriter.replaceOp(op, simdRead.getRes()); + rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty)); return success(); }