
Commit

Merge commit '73df068b8e24d68f7afe776e798db12a75ba9271'
whitneywhtsang committed Nov 6, 2024
2 parents 555d666 + 73df068 commit d96a80e
Showing 17 changed files with 676 additions and 515 deletions.
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);

unsigned getNumCTAs(Attribute layout);

// Return the order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank, where each m-by-n matrix is stored in row-major or
// column-major order.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order for a dot operand that is either kMajor (contiguous in
// the inner, K dimension) or contiguous in the outer dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
74 changes: 34 additions & 40 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -238,6 +238,19 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
return resOrder;
}

SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
// Return the order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank, where each m-by-n matrix is stored in row-major or
// column-major order.
assert(rank >= 2);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
if (!rowMajor) {
std::swap(order[0], order[1]);
}
return order;
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
@@ -247,15 +260,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
// If opIdx is 1 and kMajor is true, the order is [0, 1]
// (resp. [1, 2, 0] if rank == 3)
// Same if opIdx is 0 and kMajor is false
if (bool(opIdx) == kMajor) {
std::swap(order[0], order[1]);
}
return order;
auto rowMajor = bool(opIdx) != kMajor;
return getMatrixOrder(rank, rowMajor);
}
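
Below is a self-contained sketch (not part of the patch) that mirrors the two helpers above, using std::vector instead of llvm::SmallVector so the resulting orders can be checked in isolation; main() and the expected values are illustrative:

#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

using Order = std::vector<unsigned>;

// Same body as getMatrixOrder above.
Order getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  Order order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // [rank-1, ..., 1, 0]
  if (!rowMajor)
    std::swap(order[0], order[1]); // make the outer dimension fastest
  return order;
}

// Same body as getOrderForDotOperand above.
Order getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kMajor) {
  assert(rank == 2 || rank == 3);
  assert(opIdx == 0 || opIdx == 1);
  bool rowMajor = bool(opIdx) != kMajor;
  return getMatrixOrder(rank, rowMajor);
}

int main() {
  assert(getMatrixOrder(2, /*rowMajor=*/true) == Order({1, 0}));
  assert(getMatrixOrder(2, /*rowMajor=*/false) == Order({0, 1}));
  assert(getMatrixOrder(3, /*rowMajor=*/true) == Order({2, 1, 0}));
  assert(getMatrixOrder(3, /*rowMajor=*/false) == Order({1, 2, 0}));
  // Matches the removed comment below: opIdx == 1 && kMajor gives [0, 1]
  // (resp. [1, 2, 0] for rank == 3); same for opIdx == 0 && !kMajor.
  assert(getOrderForDotOperand(0, 2, /*kMajor=*/true) == Order({1, 0}));
  assert(getOrderForDotOperand(1, 2, /*kMajor=*/true) == Order({0, 1}));
  assert(getOrderForDotOperand(0, 3, /*kMajor=*/false) == Order({1, 2, 0}));
  assert(getOrderForDotOperand(1, 3, /*kMajor=*/false) == Order({2, 1, 0}));
  return 0;
}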

SmallVector<unsigned> getWarpOrder(Attribute layout) {
@@ -265,20 +271,21 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
}
}
auto order = getOrder(layout);
// FIXME: This mmaLayout if should just return
// getOrderForDotOperand(0, order.size(), kMajor=false)
// as mma has the same order as DotOperand(opIdx=0)
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
// M-major. This is awkward. Since we can choose any warpOrder in Ampere, we
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
// `MMAv2.cpp` to match.
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force a warp order of [0, 1]. See docs:
// Hopper MMA instructions force warps to be column-major
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
auto it = std::find(order.begin(), order.end(), 0);
order.erase(it);
order.insert(order.begin(), 0);
return getMatrixOrder(order.size(), /*rowMajor*/ false);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
/*kMajor*/ false);
// It's quite weird to talk about warp order when the warps
// are broadcast along the K dimension
llvm::report_fatal_error(
"DotOperandEncoding::getWarpOrder not implemented");
}
return order;
}
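
As a hedged summary of the new behavior, written as hypothetical checks (ampereMma, hopperMma, and dotOpLayout are assumed rank-2 layout attributes, not variables from the patch):

// Ampere falls through to getOrder(mma), which is now row-major:
assert(getWarpOrder(ampereMma) == getMatrixOrder(2, /*rowMajor=*/true));  // {1, 0}, N-major
// Hopper wgmma forces column-major warps:
assert(getWarpOrder(hopperMma) == getMatrixOrder(2, /*rowMajor=*/false)); // {0, 1}, M-major
// Dot operands no longer report a warp order at all:
// getWarpOrder(dotOpLayout) -> fatal error, since warps are broadcast along K.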
@@ -288,11 +295,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the elements in the MMAv2/V3 lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
return order;
return getMatrixOrder(rank, /*rowMajor*/ true);
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
@@ -434,7 +441,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
warpsPerCTA = wmmaLayout.getWarpsPerCTA();
else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
return getNumWarpsPerCTA(dotLayout.getParent());
warpsPerCTA = dotLayout.getWarpsPerCTA();
else if (auto sharedLayout = dyn_cast<SharedEncodingAttr>(layout))
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
else
@@ -2176,25 +2183,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTATile = getShapePerCTATile(shape);
auto rank = parentShapePerCTATile.size();
auto shapePerCTATile = getShapePerCTATile(shape);
auto rank = shapePerCTATile.size();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
// 4 threads * 2 subtiles
unsigned kWidthTile = kWidth * 2 * 4;
if (opIdx == 0) {
if (rank == 2)
return {parentShapePerCTATile[rank - 2], kWidthTile};
else
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
kWidthTile};
} else if (opIdx == 1) {
if (rank == 2)
return {kWidthTile, parentShapePerCTATile[rank - 1]};
else
return {parentShapePerCTATile[0], kWidthTile,
parentShapePerCTATile[rank - 1]};
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
shapePerCTATile[kDim] = kWidth * 2 * 4;
return shapePerCTATile;
}
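
A worked example with invented numbers (rank == 2), tracing the simplified code above:

// Suppose getShapePerCTATile(shape) == {64, 32} and kWidth == 2, so kWidth * 2 * 4 == 16.
// opIdx == 0 (operand A): kDim = rank - 1 = 1  ->  result {64, 16}
// opIdx == 1 (operand B): kDim = rank - 2 = 0  ->  result {16, 32}
// Only the K extent is overwritten; the other dimensions keep the parent tile shape.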
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
104 changes: 78 additions & 26 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -42,6 +42,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
return ret;
}

// TODO Have order be a mandatory argument of standardOutDimNames.
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
const SmallVector<unsigned> &order) {
assert(names.size() == order.size());
SmallVector<StringAttr> ret;
for (unsigned i : order) {
ret.push_back(names[i]);
}
return ret;
}
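
For example (assuming a hypothetical MLIRContext *ctx), permuting the standard dim names by a row-major order simply reverses them:

auto names = standardOutDimNames(ctx, 3);          // [dim0, dim1, dim2]
auto reversed = permuteDimNames(names, {2, 1, 0}); // [dim2, dim1, dim0]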

void assertIsRegisterLayout(const LinearLayout &layout) {
assert(layout.getNumInDims() > 0);
MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
@@ -282,15 +293,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

ctaLayout *= identityND(
S("warp"), mma.getWarpsPerCTA(),
llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant with the order of the out dims.
ctaLayout *=
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
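
A hedged decode of the ctaLayout bases above (the out dims are ordered [dim1, dim0], i.e. row-major), which should line up with the familiar m16n8 accumulator fragment:

// register bit 0 -> {1, 0}: +1 along N (the second element of a packed pair)
// register bit 1 -> {0, 8}: +8 along M (the second half of the 16 rows)
// lane bits 0-1  -> {2, 0}, {4, 0}: lane % 4 selects one of four column pairs
// lane bits 2-4  -> {0, 1}, {0, 2}, {0, 4}: lane / 4 selects one of eight rows
// Warps are then tiled on top via identityND using the asserted mma warp order.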
@@ -323,10 +338,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
S("register"), S("dim1"));

// Expand the `warp` dimension according to warpsPerCTA.
//
// It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
// this really does seem to be correct.
// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant. The order is already given by the order of the
// out dims, and if it has an order, it shouldn't change the
// order of the out dims.
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
@@ -844,18 +863,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// TODO,BE. Implement ampereMMA in terms of this one
// Note that, even though MMAv2 looks similar to this layout, they only
// coincide at the register and lane level. The warp treatment is different!
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert(mma.isAmpere());
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
// A and B have kMajor order
assert(getOrder(dot) ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
@@ -882,24 +907,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
}
registers.push_back({i, 0});

if (!isA) {
for (auto &r : registers) {
std::swap(r[0], r[1]);
LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
ArrayRef(kMajorDims).take_front(2));

// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {0, 1}
// Assume that C is tiled into 2x2 warp tiles. Since warpOrder = {1, 0}, C is
// owned by the warps as follows:
// C: 0 | 1
//    - | -
//    2 | 3
// In order to be able to compute C, we need the following warp tiling of
// A and B:
// A: 0 1 | 0 1    B: 0 2 | 1 3
//    - - | - -       - - | - -
//    2 3 | 2 3       0 2 | 1 3
// In particular, for A and B we need to broadcast along K.

assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, i});
}
for (auto &l : lanes) {
std::swap(l[0], l[1]);
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
for (auto &w : warps) {
w.push_back(0);
}
}

LinearLayout ctaLayout(
{{S("register"), registers}, {S("lane"), lanes}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

auto order = dot.getCTAOrder();
assert(order[0] == rank - 1 && order[1] == rank - 2);
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}
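
Tracing the two loops above for the warpsPerCTAMma = {2, 2} example from the comments (a sketch; bases are written in the kMajor out-dim order, i.e. [K, M] for A and [K, N] for B):

// A (opIdx == 0): warps = {{0, 0}, {0, 1}}
//   warp bit 0 -> {0, 0}: broadcast along K (warps 0 and 1 hold the same A tile)
//   warp bit 1 -> {0, 1}: step along M
// B (opIdx == 1): warps = {{0, 1}, {0, 0}}
//   warp bit 0 -> {0, 1}: step along N
//   warp bit 1 -> {0, 0}: broadcast along K (warps 0 and 2 hold the same B tile)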

std::optional<LinearLayout>
@@ -908,7 +960,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
if (mma.isAmpere()) {
return ampereDotToLinearLayout(shape, *this);
}
} else if (auto dpasLayout =
52 changes: 49 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp
@@ -1,4 +1,5 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,8 +15,52 @@ namespace gpu {
#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// Return true if the select could be merged into the If without breaking SSA
// rules.
/// The user of the select may be inside either the ThenRegion or the
/// ElseRegion of the scf.if, so canonicalize those users of the select first.
static void canonicalizeSelectUsersInSCFIf(ModuleOp input) {
llvm::MapVector<std::pair<Value, Value>, SmallVector<Operation *>>
usersNeedreplaced;
input.walk([&](arith::SelectOp selectOp) {
auto *parentBlock = selectOp->getBlock();
Value condition = selectOp.getOperand(0);
Value trueVal = selectOp.getOperand(1);
Value falseVal = selectOp.getOperand(2);
Value resVal = selectOp.getResult();
for (auto *condUser : condition.getUsers()) {
if (!llvm::isa<scf::IfOp>(condUser))
continue;
scf::IfOp ifOp = llvm::cast<scf::IfOp>(condUser);
for (auto *resUser : resVal.getUsers()) {
if (ifOp->isProperAncestor(resUser)) {
if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) !=
nullptr) {
// The user is inside the ThenRegion of the scf.if.
usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back(
resUser);
} else {
// The user is inside the ElseRegion of the scf.if.
usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back(
resUser);
}
}
}
}
});

// Replace the operand of user.
for (auto [replacedSrcAndDst, users] :
llvm::make_early_inc_range(usersNeedreplaced)) {
Value srcVal = replacedSrcAndDst.first;
Value dstVal = replacedSrcAndDst.second;
for (Operation *user : llvm::make_early_inc_range(users)) {
srcVal.replaceUsesWithIf(
dstVal, [&](OpOperand &use) { return use.getOwner() == user; });
}
}
}
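
An illustrative IR fragment (hypothetical, not taken from the test suite) showing the effect of this canonicalization when the select result is used inside an scf.if on the same condition:

// before
%sel = arith.select %cond, %a, %b : tensor<16xf32>
%r = scf.if %cond -> (tensor<16xf32>) {
  %u = arith.addf %sel, %sel : tensor<16xf32>
  scf.yield %u : tensor<16xf32>
} else {
  scf.yield %sel : tensor<16xf32>
}
// after: the use in the then-region sees %a and the use in the else-region sees %b,
// so uses nested inside the if no longer block the select/if merge below.
%sel = arith.select %cond, %a, %b : tensor<16xf32>
%r = scf.if %cond -> (tensor<16xf32>) {
  %u = arith.addf %a, %a : tensor<16xf32>
  scf.yield %u : tensor<16xf32>
} else {
  scf.yield %b : tensor<16xf32>
}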

/// Return true if the select could be merged into the If without breaking SSA
/// rules.
static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp,
DominanceInfo &dom) {
// If needs to be dominated by the select.
@@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
DominanceInfo dom(m);
canonicalizeSelectUsersInSCFIf(m);

// Go over the arith.select ops, look if there is an if
// with the same condition.
DominanceInfo dom(m);
llvm::MapVector<scf::IfOp, SmallVector<arith::SelectOp>> selectToIf;
m.walk([&](arith::SelectOp selectOp) {
// Look if there is an if in the same block, with the same condition.

