Commit e302ae6

Merge commit '6f5baf6801b44e51b7ba8eedaa619e39c912bef6'

whitneywhtsang committed Dec 11, 2024
2 parents: 2c6a850 + 6f5baf6

Showing 32 changed files with 1,761 additions and 254 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/integration-tests.yml
@@ -408,10 +408,6 @@ jobs:
cd python
ccache --zero-stats
pip install -v -e '.[tests]'
- - name: Clean up after an unsuccessful build
-   if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }}
-   run: |
-     rm -rf ~/.triton
- name: CCache Stats
run: ccache --print-stats
- name: Run lit tests
@@ -477,8 +473,11 @@ jobs:
~/.ccache
key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
- name: Clean up caches
+   # Always cleanup the worker, even if builds or tests failed
+   if: always()
    run: |
-     rm -rf ~/.triton/cache
+     rm -rf ~/.triton
+     rm -rf ~/.ccache
Build-Tests:
needs: Runner-Preparation
if: needs.Runner-Preparation.outputs.matrix-MACOS != ''
10 changes: 4 additions & 6 deletions .github/workflows/integration-tests.yml.in
@@ -402,11 +402,6 @@ jobs:
ccache --zero-stats
pip install -v -e '.[tests]'

- - name: Clean up after an unsuccessful build
-   if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }}
-   run: |
-     rm -rf ~/.triton

- *print-ccache-stats
- *run-lit-tests-step

@@ -442,8 +437,11 @@ jobs:
- *save-build-artifacts-step

- name: Clean up caches
+   # Always cleanup the worker, even if builds or tests failed
+   if: always()
    run: |
-     rm -rf ~/.triton/cache
+     rm -rf ~/.triton
+     rm -rf ~/.ccache

Build-Tests:
needs: Runner-Preparation
10 changes: 5 additions & 5 deletions README.md
@@ -211,10 +211,10 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
- - `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx.
- - `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx when `TRITON_KERNEL_DUMP` is set to 1.
- - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage.
- - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx files when `TRITON_KERNEL_OVERRIDE` is set to 1.
+ - `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
+ - `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
+ - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
+ - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.

**Kernel Override Steps**

@@ -224,7 +224,7 @@ export TRITON_KERNEL_DUMP=1
export TRITON_DUMP_DIR=<dump_dir>
export TRITON_KERNEL_OVERRIDE=1
export TRITON_OVERRIDE_DIR=<override_dir>
- # Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR
+ # Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR
# Step 2: Copy $TRITON_DUMP_DIR/<kernel_hash> to $TRITON_OVERRIDE_DIR
# Step 3: Delete the stages that you do not want to override and modify the stage you do want to override
# Step 4: Run the kernel again to see the overridden result
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -161,6 +161,8 @@ class GatherLoweringHelper {

// Get the shared memory scratch size required by this op.
unsigned getScratchSizeInBytes();
+ // Determine if the gather can be performed completely within a warp.
+ bool isWarpLocal();

private:
triton::GatherOp gatherOp;
4 changes: 2 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1123,8 +1123,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
return idx;
}

- // Emit code to compute the (blockId, warpId, laneId) for the current thread.
- std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
+ // Emit code to compute the (laneId, warpId, blockId) for the current thread.
+ std::tuple</*laneId=*/Value, /*warpId=*/Value, /*blockId=*/Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, bool withCTAOffset,
unsigned threadsPerWarp);
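A minimal standalone sketch of what the reordered tuple conventionally contains, assuming the usual flat-thread-id decomposition; the struct and function names here are ours for illustration, not Triton's:

```cpp
#include <cstdio>

// Illustrative decomposition of a flat thread id, assuming laneId is the
// position within a warp and warpId counts warps within the block.
struct HardwareTuple {
  unsigned laneId, warpId, blockId;
};

HardwareTuple decomposeThreadId(unsigned tid, unsigned blockId,
                                unsigned threadsPerWarp) {
  return {tid % threadsPerWarp, tid / threadsPerWarp, blockId};
}

int main() {
  HardwareTuple t = decomposeThreadId(/*tid=*/70, /*blockId=*/3,
                                      /*threadsPerWarp=*/32);
  std::printf("lane=%u warp=%u block=%u\n", t.laneId, t.warpId, t.blockId);
  // Prints: lane=6 warp=2 block=3
}
```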
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

+ // Return a vector of the standard out dimension names for tensor layouts.
+ // These are "dim0", "dim1", etc.
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
+ // Return an identity mapping from `inDimName` to the standard out dimensions,
+ // with the dimensions sized according to the shape. The bases are sorted
+ // according to `order`, with the most minor dimension first.
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
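To make the identity mapping concrete, here is a self-contained toy sketch, ours rather than the actual LinearLayout API: it enumerates which output dimension each power-of-two basis of the input covers, filling the most minor dimension in `order` first.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 2-D tensor: dim0 has size 4, dim1 has size 8.
  std::vector<unsigned> shape = {4, 8};
  std::vector<unsigned> order = {1, 0}; // dim1 is the most minor dimension
  unsigned bit = 0;
  for (unsigned d : order)
    for (unsigned basis = 1; basis < shape[d]; basis <<= 1)
      std::printf("input basis bit %u -> dim%u (offset %u)\n", bit++, d, basis);
  // Bits 0..2 walk dim1 (size 8), bits 3..4 walk dim0 (size 4): an identity
  // mapping from the flat input dimension onto (dim0, dim1).
}
```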

include/triton/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h
@@ -2,6 +2,8 @@
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_

#include "mlir/Dialect/SCF/IR/SCF.h"
+ #include <optional>
+ #include <utility>
#include <vector>

namespace mlir {
@@ -38,6 +40,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
// Return the minClusterId and maxClusterId for the given ForOp.
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
std::pair<int, int> getStageCluster(Operation *op);
+ std::optional<std::pair<int, int>> maybeGetStageCluster(Operation *op);
void setStageCluster(Operation *op, int stage, int cluster);
} // namespace triton
} // namespace mlir
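Judging from its signature alone, `maybeGetStageCluster` presumably mirrors `getStageCluster` but returns `std::nullopt` when the operation carries no stage/cluster annotation; the implementation is not shown in this excerpt, so treat that reading as an assumption.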
99 changes: 99 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
@@ -0,0 +1,99 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::nvidia_gpu {

constexpr inline int TMA_SIZE_BYTES = 128;
constexpr inline int TMA_ALIGN = 128;

template <typename BuilderT>
mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
mlir::triton::MakeTensorDescOp op,
BuilderT &builder) {
using namespace mlir;
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
auto mkI32Constant = [&](int32_t val) {
return builder.template create<arith::ConstantOp>(
loc, builder.getI32Type(), builder.getI32IntegerAttr(val));
};

auto elemType = op.getBase().getType().getPointeeType();
auto elemSize = elemType.getIntOrFloatBitWidth() / 8;

int32_t contig_dim_size = op.getTensorShape().back();
int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize;
if (contig_dim_size_in_bytes > 128) {
contig_dim_size = 128 / elemSize;
}
llvm::SmallVector<Value> boxDim;
boxDim.push_back(mkI32Constant(contig_dim_size));
for (int k = op.getTensorShape().size() - 2; k >= 0; --k) {
boxDim.push_back(mkI32Constant(op.getTensorShape()[k]));
}

int32_t swizzle_mode;
if (contig_dim_size_in_bytes >= 128) {
swizzle_mode = 3;
} else if (contig_dim_size_in_bytes == 64) {
swizzle_mode = 2;
} else if (contig_dim_size_in_bytes == 32) {
swizzle_mode = 1;
} else {
op->emitError()
<< "contiguous box dimension must be at least 32 bytes but got "
<< contig_dim_size_in_bytes;
return failure();
}

Value elemSizeVal = builder.template create<arith::ConstantOp>(
loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
Value globalStride = builder.template create<arith::MulIOp>(
loc, op.getStrides()[0], elemSizeVal);
// TODO: Workaround for ptxas bug, remove when we update ptxas
Value four = builder.template create<arith::ConstantOp>(
loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
globalStride =
builder.template create<arith::ShRSIOp>(loc, globalStride, four);

int elemTypeEnum;
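// Presumably these values follow CUDA's CUtensorMapDataType encoding
// (0 = UINT8, 1 = UINT16, 2 = UINT32); this is an inference, not stated here.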
switch (elemSize) {
case 1: {
elemTypeEnum = 0;
break;
}
case 2: {
elemTypeEnum = 1;
break;
}
case 4: {
elemTypeEnum = 2;
break;
}
default: {
op->emitError()
<< "Tensor descriptor element type must have size 1, 2, or 4 but got "
<< elemSize;
return failure();
}
}

auto one = mkI32Constant(1);
builder.template create<triton::ExperimentalTensormapCreateOp>(
loc,
/*desc_ptr=*/tmaPtr,
/*global_address=*/op.getBase(),
/*box_dim=*/boxDim,
/*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
/*global_stride=*/ValueRange{globalStride},
/*element_strides=*/ValueRange{one, one},
/*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
/*interleave_layout*/ builder.getI32IntegerAttr(0),
/*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
/*fill_mode=*/builder.getI32IntegerAttr(0));
return success();
}

} // namespace mlir::triton::nvidia_gpu
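The swizzle-mode selection in `createTMADesc` is compact enough to restate as a standalone sketch. The snippet below is our illustration, not part of the header; the numeric values mirror the constants above and presumably correspond to CUDA's CUtensorMapSwizzle encoding (0 = none, 1 = 32B, 2 = 64B, 3 = 128B).

```cpp
#include <cstdio>

// Mirror of the selection logic in createTMADesc: pick a swizzle mode from
// the contiguous box dimension's size in bytes.
int swizzleModeForContigBytes(int contigBytes) {
  if (contigBytes >= 128) return 3; // 128-byte swizzle
  if (contigBytes == 64)  return 2; // 64-byte swizzle
  if (contigBytes == 32)  return 1; // 32-byte swizzle
  return -1;                        // unsupported: must be at least 32 bytes
}

int main() {
  // e.g. fp16 elements (2 bytes) with a contiguous box dim of 64 elements
  // give 128 bytes, which selects mode 3.
  std::printf("swizzle mode: %d\n", swizzleModeForContigBytes(64 * 2));
}
```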
84 changes: 82 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -419,13 +419,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
: gatherOp(gatherOp) {}

unsigned GatherLoweringHelper::getScratchSizeInBytes() {
- // For now, lower the gather op by writing the source tensor to shared memory.
- // TODO(jeff): Leverage locality to avoid using scratch space when possible.
// If the gather is warp-local, no scratch space is needed.
if (isWarpLocal())
return 0;

// Otherwise, performing the gather will require scratch space to communicate
// the source tensor across threads. For now, assume the whole source tensor
// is written back to shared memory.
RankedTensorType srcType = gatherOp.getSrc().getType();
return product(srcType.getShape()) *
ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
}

bool GatherLoweringHelper::isWarpLocal() {
// The gather is warp-local if for each column along the gather axis in the
// source and index tensors, all the elements are owned by the same warp.
RankedTensorType srcType = gatherOp.getSrc().getType();
RankedTensorType idxType = gatherOp.getIndices().getType();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcType.getShape(), srcType.getEncoding());
std::optional<LinearLayout> idxLayout =
toLinearLayout(idxType.getShape(), idxType.getEncoding());

// FIXME: If an unsupported layout was encountered, assume the gather is not
// warp-local.
if (!srcLayout || !idxLayout)
return false;

Builder b(gatherOp.getContext());
StringAttr kBlock = b.getStringAttr("block");
StringAttr kWarp = b.getStringAttr("warp");
StringAttr kLane = b.getStringAttr("lane");
StringAttr kGatherDim =
b.getStringAttr("dim" + std::to_string(gatherOp.getAxis()));

// The tensor layouts must be distributed layouts, where the basis matrix is a
// subpermutation matrix (permutation matrix plus zeros for broadcasting).
// FIXME(jeff): Check this invariant somehow.
//
// We want to know if all elements of a column along the gather axis are
// mapped to the same set of warps, which means the gather can be performed
// entirely within the warp. We need to query
//
// srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp})
//
// However, due to broadcasting, the matrix might not be invertible. Since the
// matrix is a subpermutation matrix (checked below), we can instead query
//
//   srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim})
//
// which implies that changing the warp does not change the gather dimension.
// And since there is no swizzling, this applies to all warps.
if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
!idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim))
return false;

SmallVector<StringAttr> otherDims;
for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
if (dim != gatherOp.getAxis()) {
otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
}
}

// The gather axis `dimN` is invariant to the warp, but the `(block, warp)`
// mapping to all other dimensions must also be the same for both layouts. If
// it is, the warp that owns a particular index element also owns all the
// source elements it could index into.
if (srcLayout->sublayout({kBlock, kWarp}, otherDims) !=
idxLayout->sublayout({kBlock, kWarp}, otherDims))
return false;

// The two constraints above ensure that the data movement needed to perform
// the gather is contained within a warp. The subsequent constraints simplify
// codegen.

// Require that for any given gather column, the threads mapped to the column
// in the index and source tensors are the same. This means we don't need to
// xor shuffle across threads before emitting index shuffles; we push warp
// shuffling to layout conversions.
if (srcLayout->sublayout(kLane, otherDims) !=
idxLayout->sublayout(kLane, otherDims))
return false;

// Otherwise, the source layout has to be invertible. This primarily means
// the codegen path doesn't support broadcasted source layouts.
return srcLayout->isInvertible();
}
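Two illustrations of the logic above, both ours rather than the diff's. First, the scratch-size fallback: by `getScratchSizeInBytes`, a gather whose 128x64 fp16 source tensor is not warp-local needs 128 * 64 * ceil(16/8) = 16384 bytes of shared memory. Second, Triton's `LinearLayout` works over GF(2) basis vectors, and the toy sketch below (hypothetical types, not the real API) checks the key warp-locality condition that no warp or block basis moves along the gather dimension:

```cpp
#include <array>
#include <cstdio>
#include <vector>

// One GF(2) basis vector of an input dimension (e.g. one "warp" bit),
// expressed by its component along each of two output dims (dim0, dim1).
struct Basis {
  std::array<int, 2> out;
};

// Analogue of sublayoutIsZero({kBlock, kWarp}, kGatherDim): true when no
// warp/block basis vector has a component along the gather dimension.
bool sublayoutIsZero(const std::vector<Basis> &warpBlockBases, int gatherDim) {
  for (const Basis &b : warpBlockBases)
    if (b.out[gatherDim] != 0)
      return false; // switching warps would move along the gather axis
  return true;
}

int main() {
  int gatherDim = 0;
  // Hypothetical warp bases that only move along dim1: each gather column
  // (a fixed dim1 coordinate) stays within one warp.
  std::vector<Basis> warpBases = {{{0, 16}}, {{0, 32}}};
  std::printf("warp-local: %s\n",
              sublayoutIsZero(warpBases, gatherDim) ? "yes" : "no");
}
```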

unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
if (shape.empty())
return 0;