Skip to content

Commit

Permalink
fix FlashAttention SLM path
Browse files Browse the repository at this point in the history
  • Loading branch information
quintinwang5 committed Sep 11, 2024
1 parent 5297206 commit eed3d95
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,14 @@ struct AllocateSharedMemory
IntegerAttr::get(IntegerType::get(ctx, 32), offset));
});
});
mod->setAttr("triton_gpu.shared",
IntegerAttr::get(IntegerType::get(ctx, 32),
allocation.getSharedMemorySize()));
int32_t originSharedSize = 0;
if (IntegerAttr sharedAttr =
mod->getAttrOfType<IntegerAttr>("triton_gpu.shared"))
originSharedSize = sharedAttr.getInt();
mod->setAttr(
"triton_gpu.shared",
IntegerAttr::get(IntegerType::get(ctx, 32),
originSharedSize + allocation.getSharedMemorySize()));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ class LoadStorePrefetchOpConversion
assert(ptrType.getAddressSpace() ==
TritonGEN::TritonGENMemorySpace::kWorkgroup &&
"expecting local space");
auto elemType =
cast<RankedTensorType>(ptrType.getPointeeType()).getElementType();

MLIRContext *ctx = rewriter.getContext();
Location loc = op.getLoc();
Expand All @@ -278,13 +280,13 @@ class LoadStorePrefetchOpConversion
base = gep(ptrToSharedMemTy, i16_ty, base, index);

if constexpr (std::is_same_v<OpType, LoadOp>) {
VectorType v64f16Ty = VectorType::get(64, f16_ty);
VectorType v64Ty = VectorType::get(64, elemType);

rewriter.restoreInsertionPoint(insertPoint);

TritonGEN::SIMDBlockReadOp simdRead =
rewriter.create<TritonGEN::SIMDBlockReadOp>(loc, v64i16Ty, base);
rewriter.replaceOp(op, simdRead.getRes());
rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty));

return success();
}
Expand Down

0 comments on commit eed3d95

Please sign in to comment.