Merge commit '48468af3b4bfd9913d325a7fee660ed2961ce953'
whitneywhtsang committed Dec 19, 2024
2 parents c83c0ed + 48468af commit 8d5a01b
Showing 22 changed files with 582 additions and 537 deletions.
4 changes: 0 additions & 4 deletions include/triton/Conversion/TritonGPUToLLVM/Patterns.h
@@ -13,10 +13,6 @@ namespace triton::gpu {
/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly.
void decomposeBlockedToDotLayoutConversion(ModuleOp module);

/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given
/// |module| op.
void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);

/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
/// given |module| op, but bypass the decomposition if |shortcutFn| returns
/// true.
18 changes: 12 additions & 6 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -55,13 +55,19 @@ class DialectInferLayoutInterface

// Tries to compute the encoding for the result of a reshape operation that
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape. Note that this is not always possible (in
// which case you'd need to choose a different layout for the input to the
// reshape).
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fall back to using LinearLayouts).
// In the future we'll always use LinearLayouts.
virtual LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
// different
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
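Aside (not part of the commit): a minimal sketch of how a caller might combine the two interface hooks declared above, mirroring the ReshapeOp::verify change further down in this commit. The helper name checkReshapeEncoding and its signature are hypothetical; only the two interface calls are taken from the patch.

// Hypothetical helper (not in the patch): infer the reshape result encoding
// and compare it structurally against the encoding already present in the IR.
static LogicalResult checkReshapeEncoding(RankedTensorType srcTy,
                                          RankedTensorType dstTy,
                                          Location loc) {
  Attribute srcEnc = srcTy.getEncoding();
  Attribute dstEnc = dstTy.getEncoding();
  auto *iface = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
  Attribute inferredDstEnc;
  if (failed(iface->inferReshapeOpEncoding(srcTy.getShape(), srcEnc,
                                           dstTy.getShape(), inferredDstEnc,
                                           loc)))
    return failure();
  // The two attributes may print differently; equality is structural.
  return iface->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
                                      loc);
}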
26 changes: 0 additions & 26 deletions lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -18,32 +18,6 @@ static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {

namespace mlir::triton::gpu {

void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) {
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
module.walk([&](triton::SplatOp splatOp) -> void {
auto dstType = cast<RankedTensorType>(splatOp.getType());
auto shared = dyn_cast_or_null<triton::gpu::SharedEncodingAttr>(
dstType.getEncoding());
if (shared) {
OpBuilder builder(splatOp);
SmallVector<unsigned, 4> sizePerThread(dstType.getRank(), 1);
auto newType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
module.getContext(), dstType.getShape(), sizePerThread,
getOrder(shared), numWarps, threadsPerWarp, numCTAs));
auto newSplat = builder.create<triton::SplatOp>(splatOp.getLoc(), newType,
splatOp.getSrc());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
splatOp.getLoc(), dstType, newSplat.getResult());
splatOp.replaceAllUsesWith(newConvert.getResult());
splatOp.erase();
}
});
}

void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
ShortcutFn shortcutFn) {
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
30 changes: 14 additions & 16 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
@@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() {
"encodings, or (b) neither does.");
}

if (srcEnc && !getAllowReorder()) {
Attribute inferredDstEnc;
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
dstTy.getShape(), inferredDstEnc,
getLoc())
.failed()) {
return emitError("This reshape is impossible without reordering, but "
"reordering is not allowed. Try choosing a different "
"encoding for the input tensor (or allow reordering).");
}
if (inferredDstEnc != dstEnc) {
return emitError("Expected result encoding ")
<< inferredDstEnc << " but was " << dstEnc;
}
if (!srcEnc || getAllowReorder()) {
return success();
}

return success();
// Check that we can infer the dst encoding from the src encoding
// and that the inferred dst encoding is the same as the given dst encoding
Attribute inferredDstEnc;
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
inferredDstEnc, getLoc());
assert(succeeded(result));
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
getLoc());
}

//-- FpToFpOp --
119 changes: 80 additions & 39 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1653,11 +1653,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// We can relax this assert by calling toLinearLayout rather than
// getLinearLayout
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
auto ll = getLinearLayout();
// When broadcasting the layout, the shape changes; otherwise the shape is
// the same as the shape of the tensor.
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor.
// We choose the former for BC.
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
}

@@ -2681,8 +2682,8 @@ struct TritonGPUInferLayoutInterface
// contains elements [a,b,c,d] before the reshape, it contains those same
// elements after the reshape, they're just "renamed".
//
// A dst encoding that satisfies this property does not exist for all inputs.
// Here are some positive and negative examples.
// Using legacy layouts, a dst encoding that satisfies this property may not
// exist. Here are some positive and negative examples.
//
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2696,17 +2697,19 @@
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
// contain the same elements as before.
//
// With linear layouts, we can always find a dst encoding that satisfies
// this property. See inferReshapeOpEncoding.
//
// Users of this function require that it is symmetrical: if
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
// srcEnc.
LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return emitOptionalError(
loc, "Non-reordering reshape only supports BlockedEncoding");
return failure();
}

// Nop reshape; we can always infer an encoding.
@@ -2739,9 +2742,7 @@
// to handle CTASplitNum.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return emitOptionalError(
loc, "Non-reordering reshape does not currently support multi-CTA "
"layouts other than the default layout.");
return failure();
}

// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2751,12 +2752,7 @@
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return emitOptionalError(loc,
"Can't do a non-reordering reshape because "
"the size of dimension ",
dim, " (", srcShape[dim], ")",
" is not divisible by ", name, "[", dim, "]",
" = ", subblock[dim]);
return failure();
}
}
return success();
@@ -2781,11 +2777,7 @@
// physical order, with `a` being the most major.
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return emitOptionalError(loc,
"Cannot do a non-reordering reshape given "
"this src encoding order. Dimensions [",
join(srcDims),
"] must be physically consecutive.");
return failure();
}
}

@@ -2832,11 +2824,7 @@
// Check that more-minor dims all have 1 in shapeRemaining.
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Must use "
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
"more-minor dimensions before more major-dims can use them.");
return failure();
}
}

@@ -2851,13 +2839,7 @@
// only if we're the most-major dimension of the chunk and in all
// future chunks, only this most-major dim has a non-1 size.
if (shapeRemaining[i] == 0 && i != 0) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Block "
"size in dimension ",
dim,
" is larger than the shape that dimension, but this is only "
"allowed for the most-major dimension of a reshape chunk");
return failure();
}
}
return success();
@@ -2947,6 +2929,65 @@
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
}
return success();
}

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
return result;
}

// If the legacy encoding failed, use LinearLayouts.
// Once LinearLayouts are more widely used, we can remove
// inferReshapeOpLegacyEncoding and simply use LLs.
auto *ctx = getContext();
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
if (!src) {
return emitOptionalError(loc,
"src encoding does not support linear layout");
}

if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src->transposeOuts(srcOutDims)
.reshapeOuts(newOutDims)
.transposeOuts(standardOutDimNames(ctx, newRank));
dstEnc = LinearEncodingAttr::get(ctx, dst);
return success();
}

LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
std::optional<Location> loc) const override {
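Aside (not part of the commit): a self-contained sketch of the "OK: 32x4 sizePerThread=[4,4] -> 128, dst sizePerThread=[16]" example from the doc comment above inferReshapeOpLegacyEncoding. It makes the simplifying assumption that we look only at a single per-thread register block and ignore warp/CTA replication.

#include <cassert>
#include <cstdint>
#include <set>

// Illustrative sketch only (not part of the patch). It spells out the
// positive example from the inferReshapeOpLegacyEncoding doc comment:
// reshaping a 32x4 tensor with sizePerThread=[4,4] to a 128-element tensor
// with sizePerThread=[16]. We model one per-thread register block and check
// that its row-major flat indices form the same set before and after the
// reshape, i.e. the reshape is a "nop" for that thread.
int main() {
  const int64_t cols = 4;
  for (int64_t baseRow = 0; baseRow < 32; baseRow += 4) {
    std::set<int64_t> srcElems;
    // Source: the 4x4 block of elements starting at row `baseRow`.
    for (int64_t r = baseRow; r < baseRow + 4; ++r)
      for (int64_t c = 0; c < cols; ++c)
        srcElems.insert(r * cols + c); // row-major flat index
    std::set<int64_t> dstElems;
    // Destination: 16 consecutive elements of the flattened 128-vector.
    for (int64_t k = 0; k < 16; ++k)
      dstElems.insert(baseRow * cols + k);
    assert(srcElems == dstElems); // same elements, merely "renamed"
  }
  return 0;
}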