diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp index b03c612dd7..61932d1066 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp @@ -49,9 +49,14 @@ struct AllocateSharedMemory IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); }); + int32_t initialSharedMemorySize = 0; + if (IntegerAttr sharedAttr = + mod->getAttrOfType("triton_gpu.shared")) + initialSharedMemorySize = sharedAttr.getInt(); mod->setAttr("triton_gpu.shared", IntegerAttr::get(IntegerType::get(ctx, 32), - allocation.getSharedMemorySize())); + initialSharedMemorySize + + allocation.getSharedMemorySize())); } }; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 7696abd7cf..f23b098393 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -258,6 +258,8 @@ class LoadStorePrefetchOpConversion assert(ptrType.getAddressSpace() == TritonGEN::TritonGENMemorySpace::kWorkgroup && "expecting local space"); + auto elemType = + cast(ptrType.getPointeeType()).getElementType(); MLIRContext *ctx = rewriter.getContext(); Location loc = op.getLoc(); @@ -278,13 +280,12 @@ class LoadStorePrefetchOpConversion base = gep(ptrToSharedMemTy, i16_ty, base, index); if constexpr (std::is_same_v) { - VectorType v64f16Ty = VectorType::get(64, f16_ty); - rewriter.restoreInsertionPoint(insertPoint); TritonGEN::SIMDBlockReadOp simdRead = rewriter.create(loc, v64i16Ty, base); - rewriter.replaceOp(op, simdRead.getRes()); + VectorType v64Ty = VectorType::get(64, elemType); + rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty)); return success(); }