Skip to content

Commit

Permalink
Merge commit '38a11b859fff79ea214256d3f1cfe43d54e36c2c'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Nov 7, 2024
2 parents d7095ce + 38a11b8 commit 7573b11
Show file tree
Hide file tree
Showing 18 changed files with 82 additions and 75 deletions.
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
b74e588e1f460eb48ceb1a30cf8ac870b7537dcc
fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

SmallVector<unsigned> getOrderWithAxisAtBeginning();
SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

Expand Down
3 changes: 3 additions & 0 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
auto dstShapePerCTATile =
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());

assert(srcTy.getRank() == dstTy.getRank() &&
"src and dst must have the same rank");

unsigned rank = dstTy.getRank();
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
Expand Down
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
auto order = triton::gpu::getThreadOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
Expand All @@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto order = triton::gpu::getThreadOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
Expand All @@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
Expand Down
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

SmallVector<unsigned> getParentOrder(Attribute layout) {
SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
return getParentThreadOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
Expand All @@ -46,12 +46,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
getParentThreadOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

auto smemOrder = helper.getOrderWithAxisAtBeginning();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
Expand Down Expand Up @@ -370,7 +370,7 @@ struct ReduceOpConversion
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
Expand Down
7 changes: 3 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
int splitDim = encoding.getOrder().size() - 1;
for (int i = 0; i < encoding.getOrder().size(); i++) {
if (encoding.getOrder()[i] == splitDim)
int splitDim = encoding.getThreadOrder().size() - 1;
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
if (encoding.getThreadOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
Expand Down Expand Up @@ -336,7 +336,6 @@ struct BroadcastOpConversion
unsigned rank = srcTy.getRank();
auto typeConverter = getTypeConverter();
assert(rank == resultTy.getRank());
auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
Expand Down
20 changes: 14 additions & 6 deletions lib/Dialect/Triton/Transforms/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@

namespace mlir::triton {

static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";

namespace {

class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {

int getUnrollFactorOrDefault(scf::ForOp forOp) {
// Use the attribute attached to the loop if it exists otherwise set the
// factor to 1 to suppress the unrolling.
if (auto factor = forOp->getAttrOfType<IntegerAttr>(
mlir::triton::loopUnrollFactorAttrName))
if (auto factor =
forOp->getAttrOfType<IntegerAttr>(loopUnrollFactorAttrName))
return factor.getInt();
return 1;
}

const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
const char *pipelineStagesAttrName = "tt.num_stages";

public:
LoopUnrollPass() = default;
LoopUnrollPass(const LoopUnrollPass &) {}
Expand All @@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
loops.push_back(forOp);
});

auto ctx = getOperation()->getContext();
for (auto loop : loops) {
auto unrollFactor = getUnrollFactorOrDefault(loop);
loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
loop->removeAttr(loopUnrollFactorAttrName);
LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
(void)loopUnrollByFactor(loop, unrollFactor);
auto resultLoops = loopUnrollByFactor(loop, unrollFactor);
// Do not pipeline the epilog loop.
if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) {
(*resultLoops->epilogueLoopOp)
->setAttr(pipelineStagesAttrName,
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
}
}
}
};
Expand Down
71 changes: 33 additions & 38 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}

/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
/* Utility function used by get.*Order methods of SliceEncodingAttr.
* Erase dim and decrease all values larger than dim by 1.
* Example: order = [0, 2, 4, 3, 1], dim = 2
* resOrder = [0, 3, 2, 1]
Expand Down Expand Up @@ -265,29 +265,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
return getWarpOrder(dotLayout.getParent());
}
}
auto order = getOrder(layout);
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
// M-major This is awkward. Since we can choose any warpOrder in Ampere, we
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
// `MMAv2.cpp` to match.
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force warps to be column-major
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
return getMatrixOrder(order.size(), /*rowMajor*/ false);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
// It's quite weird to talk about warp order when that the warps
// are broadcasted along the K dimension
llvm::report_fatal_error(
"DotOperandEncoding::getWarpOrder not implemented");
}
return order;
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getWarpOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
}

SmallVector<unsigned> getOrder(Attribute layout) {
Expand All @@ -296,7 +278,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the elements in the MMAv2/V3 lowerings. We choose row-major
// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
Expand Down Expand Up @@ -331,15 +313,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {

llvm::report_fatal_error("Unimplemented usage of getOrder");
return {};
};
}

SmallVector<unsigned> getThreadOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getThreadOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
};
}

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
Expand Down Expand Up @@ -782,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
return warpsPerCTA;
}
SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
auto parentWarpOrder = ::getWarpOrder(getParent());
return eraseOrder(parentWarpOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
Expand All @@ -794,7 +777,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
return threadsPerWarp;
}
SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto parentThreadOrder = ::getThreadOrder(getParent());
return eraseOrder(parentThreadOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
auto sizePerThread = ::getSizePerThread(getParent());
Expand Down Expand Up @@ -1065,7 +1049,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
return warps;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
// FIXME(Lezcano): Preexisting. Do we want to have this path at all?
if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
return ::getWarpOrder(getParent());
}
// It's quite weird to talk about warp order when that the warps
// are broadcasted along the K dimension
llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
Expand Down Expand Up @@ -1637,7 +1628,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
return ::getOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
auto order = ::getOrder(*this);
Expand Down Expand Up @@ -1806,7 +1797,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
return ::getOrder(*this);
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
Expand Down Expand Up @@ -1930,7 +1921,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
auto rank = getWarpsPerCTA().size();
// Hopper (wgmma) uses column-major as this is embeded in the instruction
// For Ampere we can choose either row-major or column-major.
// We choose row-major as the legacy path did so
return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
auto rank = getWarpsPerCTA().size();
Expand All @@ -1954,10 +1949,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
"getThreadsPerWarp not implemented for unknown Mma version ");
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
auto rank = ::getOrder(*this).size();
auto rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
if (isAmpere()) {
res[rank - 2] = 2;
Expand Down Expand Up @@ -2198,11 +2194,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
if (opIdx == 0) {
sizePerThread[rank - 2] = 2;
sizePerThread[rank - 1] = 2 * kWidth;
} else if (opIdx == 1) {
} else {
assert(opIdx == 1);
sizePerThread[rank - 2] = 2 * kWidth;
sizePerThread[rank - 1] = 1;
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
return sizePerThread;
}
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ class TritonGPUReduceDataDuplicationPass
if (largeKWidth) {
return;
}
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto rank = srcThreadOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
if (srcOrder[i] != 0)
sharedOrder.emplace_back(srcOrder[i]);
if (srcThreadOrder[i] != 0)
sharedOrder.emplace_back(srcThreadOrder[i]);
sharedOrder.emplace_back(0);
} else {
sharedOrder = srcOrder;
sharedOrder = srcThreadOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
Expand Down
1 change: 0 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,7 +1801,6 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field)

if constant_field == "value":
print(output, ref)
assert torch.all(output == ref)
else:
assert torch.all(output == 0)
Expand Down
1 change: 1 addition & 0 deletions test/Triton/loop-unroll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
// CHECK: scf.for
// CHECK: tt.load
// CHECK-NOT: tt.load
// CHECK: tt.num_stages = 1 : i32
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
%4 = arith.addf %arg4, %3 : tensor<256xf32>
Expand Down
2 changes: 1 addition & 1 deletion test/lib/Instrumentation/GPUHello.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ bool GpuHello::runOnModule(Module &module) {

PassPluginLibraryInfo getPassPluginInfo() {
const auto callback = [](PassBuilder &pb) {
pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto) {
pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) {
mpm.addPass(GpuHello());
return true;
});
Expand Down
Loading

0 comments on commit 7573b11

Please sign in to comment.