Merge OpenAI Triton commit f8b5301 #3069

Merged · 9 commits · Dec 26, 2024
8 changes: 6 additions & 2 deletions .github/workflows/wheels.yml
@@ -50,6 +50,8 @@ jobs:
- name: Build wheels
if: ${{ steps.check-version.outputs.new_commit == 'true' }}
run: |
# Make sure cibuildwheel is updated to latest, this will enable latest python builds
python3 -m pip install cibuildwheel --upgrade --user
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d%H%M%S' --format="%cd")
# Pass MAX_JOBS=4 because, at time of writing, the VM "only" has 32GB
# of RAM and OOMs while building if we give it the default number of
@@ -63,8 +65,10 @@ jobs:
# many_linux_2_28 image comes with GCC 12.2.1, but not clang.
# With this install, it gets clang 16.0.6.
export CIBW_BEFORE_ALL="dnf install clang lld -y";
export CIBW_SKIP="cp{35,36,37}-*"
export CIBW_BUILD="cp3*-manylinux_x86_64"
export CIBW_SKIP="cp{35,36,37,38}-*"
export CIBW_BUILD="cp3{9,10,11,12,13,13t}-manylinux_x86_64"
export CIBW_FREE_THREADED_SUPPORT=1

python3 -m cibuildwheel python --output-dir wheelhouse

- name: Install Azure CLI
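The `CIBW_BUILD` value above is a shell brace pattern; below is a hypothetical sketch (not part of the change) of the concrete build identifiers it expands to, including the free-threaded `cp313t` variant enabled by `CIBW_FREE_THREADED_SUPPORT`.

```python
# Hypothetical illustration only: enumerate the cibuildwheel build identifiers
# selected by CIBW_BUILD="cp3{9,10,11,12,13,13t}-manylinux_x86_64".
versions = ["9", "10", "11", "12", "13", "13t"]
build_ids = [f"cp3{v}-manylinux_x86_64" for v in versions]
print(build_ids)
# ['cp39-manylinux_x86_64', 'cp310-manylinux_x86_64', 'cp311-manylinux_x86_64',
#  'cp312-manylinux_x86_64', 'cp313-manylinux_x86_64', 'cp313t-manylinux_x86_64']
```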
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -146,6 +146,7 @@ Scan/Sort Ops
cumsum
histogram
sort
gather

Atomic Ops
----------
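With `gather` now listed under Scan/Sort Ops above, a minimal usage sketch may help orient readers. The `tl.gather(src, index, axis)` call shape is assumed from the interpreter and test changes later in this PR; the kernel below is illustrative rather than taken from the docs.

```python
import triton
import triton.language as tl


@triton.jit
def gather_rows_kernel(src_ptr, idx_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    # Load an (M, N) source tile and an (M, N) index tile, then gather along
    # axis 0: out[i, j] = src[idx[i, j], j].
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    src = tl.load(src_ptr + offs)
    idx = tl.load(idx_ptr + offs)
    out = tl.gather(src, idx, axis=0)  # assumed signature: gather(src, index, axis)
    tl.store(out_ptr + offs, out)
```

A host-side launch would pass contiguous `(M, N)` tensors with power-of-two `M` and `N`.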
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Patterns.h
@@ -13,7 +13,7 @@ namespace triton::gpu {
/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly.
void decomposeBlockedToDotLayoutConversion(ModuleOp module);

/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
/// Replaces `mfma -> dot_op` with `mfma -> blocked -> dot_op` in the
/// given |module| op, but bypass the decomposition if |shortcutFn| returns
/// true.
using ShortcutFn = std::function<bool(RankedTensorType, RankedTensorType)>;
8 changes: 8 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -789,6 +789,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
(ins "int":$opIdx,
"int":$kWidth)>,

InterfaceMethod<"Return the number of threads per warp for dot operands.",
"SmallVector<unsigned>",
"getThreadsPerWarpForOperand",
(ins "int":$opIdx)>,

InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrderForOperand",
@@ -915,6 +920,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
auto rank = getWarpsPerCTA().size();
@@ -1023,6 +1029,7 @@ Row | warp 0 warp 2
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
static SmallVector<unsigned> getMNKDimPerInstr();

SmallVector<unsigned> getContigPerThread() {
@@ -1141,6 +1148,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
int bitwidth, int kWidth,
int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
33 changes: 26 additions & 7 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2112,6 +2112,13 @@ AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<unsigned>
AMDMfmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
llvm::report_fatal_error(
"getThreadsPerWarpForOperand not implemented for AMDMfmaEncodingAttr");
return {};
}

SmallVector<int64_t>
AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
int kWidth, int opIdx) const {
@@ -2173,6 +2180,13 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<unsigned>
AMDWmmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
llvm::report_fatal_error("getThreadsPerWarpForOperand not implemented for "
"AMDWmmaEncodingAttr");
return {};
}

SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
@@ -2350,6 +2364,15 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
auto threadsPerWarp = getThreadsPerWarp();
auto rank = threadsPerWarp.size();
if (opIdx == 1)
std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]);
return threadsPerWarp;
}

SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int kWidth, int opIdx) const {
@@ -2418,16 +2441,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
}

SmallVector<unsigned> DotOperandEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
auto threadsPerWarp = mma.getThreadsPerWarp();
auto rank = threadsPerWarp.size();
if (getOpIdx() == 1)
std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]);
return threadsPerWarp;
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
return mma.getThreadsPerWarpForOperand(getOpIdx());
}
llvm::report_fatal_error(
"getThreadsPerWarp not implemented for DotOperandEncodingAttr");
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
auto parentLayout = getParent();
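The NVIDIA implementation above is the only concrete one so far: operand B (`opIdx == 1`) transposes the last two dimensions of the warp's thread layout, while the AMD and Intel attributes currently abort with `report_fatal_error`. A small Python sketch of that transposition, with an assumed `[8, 4]` threads-per-warp layout:

```python
# Hypothetical mirror of NvidiaMmaEncodingAttr::getThreadsPerWarpForOperand:
# operand A (op_idx == 0) keeps the layout, operand B (op_idx == 1) swaps the
# last two dimensions.
def threads_per_warp_for_operand(threads_per_warp, op_idx):
    layout = list(threads_per_warp)
    if op_idx == 1:
        layout[-2], layout[-1] = layout[-1], layout[-2]
    return layout


print(threads_per_warp_for_operand([8, 4], 0))  # [8, 4]
print(threads_per_warp_for_operand([8, 4], 1))  # [4, 8]
```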
2 changes: 1 addition & 1 deletion python/setup.py
@@ -737,11 +737,11 @@ def get_git_commit_hash(length=8):
"Intended Audience :: Developers",
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
test_suite="tests",
extras_require={
1 change: 1 addition & 0 deletions python/test/unit/language/test_core.py
@@ -6455,6 +6455,7 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
tl.store(out_ptr + out_offs, out)


@pytest.mark.interpreter
@pytest.mark.parametrize("src_shape, indices_shape, axis", [
([4, 4], [8, 4], 0),
([128, 64], [256, 64], 0),
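The new `@pytest.mark.interpreter` marker opts the gather test into interpreter runs. A hypothetical way to exercise it locally is sketched below; the `TRITON_INTERPRET` environment variable and the `-m interpreter` selection are assumptions about the test setup, not part of this diff.

```python
import os
import subprocess

# Run the interpreter-marked gather tests under the Triton interpreter backend.
env = dict(os.environ, TRITON_INTERPRET="1")
subprocess.run(
    ["pytest", "python/test/unit/language/test_core.py", "-m", "interpreter", "-k", "gather"],
    env=env,
    check=False,
)
```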
3 changes: 0 additions & 3 deletions python/triton/language/semantic.py
@@ -1722,9 +1722,6 @@ def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor
return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))


##


def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if max(1, len(x.shape)) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
15 changes: 8 additions & 7 deletions python/triton/runtime/autotuner.py
@@ -317,8 +317,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
...
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
@@ -382,18 +382,19 @@ def run(self, *args, **kwargs):
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

.. highlight:: python
.. code-block:: python

@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
# smallest power-of-two >= x_size
@triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
...
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
:type values: dict[str, Callable[[dict[str, Any]], Any]]
"""

def decorator(fn):
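The corrected docstring above reflects that the heuristics lambda receives a dict keyed by argument name rather than a positional list. A runnable sketch of the same pattern follows; the kernel body and launch are illustrative and not taken from the PR.

```python
import torch
import triton
import triton.language as tl


@triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
@triton.jit
def double_kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < x_size
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2, mask=mask)


x = torch.arange(10, device='cuda', dtype=torch.float32)
double_kernel[(1,)](x, x.numel())  # BLOCK_SIZE resolves to next_power_of_2(10) == 16
```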
5 changes: 4 additions & 1 deletion python/triton/runtime/interpreter.py
@@ -556,6 +556,9 @@ def create_make_range(self, start, stop):
def create_histogram(self, data, bins):
return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)

def create_gather(self, src, indices, axis):
return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)

# pointer arithmetic

def create_addptr(self, ptr, offset):
@@ -998,7 +1001,7 @@ def _set_attr(input, values, name):


def _patch_lang(fn):
langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]]
langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
for lang in langs:
_patch_builtin(lang, interpreter_builder)
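`create_gather` above is a thin wrapper over `np.take_along_axis`; here is a standalone check of that mapping (shapes chosen arbitrarily):

```python
import numpy as np

src = np.arange(12).reshape(4, 3)
idx = np.array([[3, 0, 1], [2, 2, 0]])

# Gather along axis 0, mirroring create_gather: out[i, j] = src[idx[i, j], j].
out = np.take_along_axis(src, idx, axis=0)
print(out)
# [[9 1 5]
#  [6 7 2]]
```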
@@ -34,34 +34,15 @@ struct DecomposeUnsupportedAMDConversions
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

auto isShortcut =
mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

// Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getSrc().getType();
auto dstType = cvtOp.getType();
auto shortcutFn = [](RankedTensorType srcTy, RankedTensorType dstTy) {
auto srcWmma =
dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(srcType.getEncoding());
dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(srcTy.getEncoding());
auto dstDotOp =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (srcWmma && dstDotOp) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcWmma),
getOrder(srcWmma), numWarps, threadsPerWarp, numCTAs));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstTy.getEncoding());
return !cvtNeedsSharedMemory(srcTy, dstTy) && !(srcWmma && dstDotOp);
};

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, shortcutFn);
// Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of
// WarpsPerCta attribute in mfma layout. The implicit LDS usage of
// cvt(mfma->blocked) op depends on the number of warps per CTA that mfma
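For the AMD pass, the rewrite folds the old explicit `wmma -> blocked -> dot_op` walk into the shortcut predicate handed to `decomposeTensorCoreToDotLayoutConversion`. A Python sketch of the predicate's boolean logic (function and parameter names are placeholders):

```python
# The decomposition is bypassed (shortcut taken) only when the conversion needs
# no shared memory AND it is not a wmma -> dot_op conversion, which still has to
# go through an intermediate blocked layout.
def shortcut_fn(needs_shared_memory: bool, src_is_wmma: bool, dst_is_dot_op: bool) -> bool:
    return (not needs_shared_memory) and not (src_is_wmma and dst_is_dot_op)


assert shortcut_fn(False, False, True)       # cheap conversion, no wmma source: bypass
assert not shortcut_fn(False, True, True)    # wmma -> dot_op: decompose via blocked
assert not shortcut_fn(True, False, False)   # needs shared memory: decompose
```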
@@ -91,6 +91,7 @@ along the row (resp. col) dimension.
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const;
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, OpIdx opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(OpIdx opIdx) const;
SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, OpIdx opIdx) const;

// Forwarder functions for casting unsigned to OpIdx.
7 changes: 7 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -157,6 +157,13 @@ DpasEncodingAttr::getRepOrderForOperand(OpIdx opIdx) const {
return getOrderForDotOperand(unsigned(opIdx), rank, /*kMajor*/ true);
}

SmallVector<unsigned>
DpasEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
llvm::report_fatal_error(
"getThreadsPerWarpForOperand not implemented for DpasEncodingAttr");
return {};
}

SmallVector<unsigned>
DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
size_t rank = shape.size();
@@ -70,14 +70,7 @@ struct DecomposeUnsupportedConversions
: public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase<
DecomposeUnsupportedConversions> {
void runOnOperation() override {
// FIXME [Dot LL]
// Remove the decomposeTensorCoreToDotLayoutConversion class entirely after
// we have enabled the new layout conversion for all the cases.
auto nvidiaShortCutFn = [&](RankedTensorType srcTy,
RankedTensorType dstTy) { return true; };
ModuleOp mod = getOperation();
triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
nvidiaShortCutFn);
triton::gpu::decomposeBlockedToDotLayoutConversion(mod);

mlir::RewritePatternSet patterns(&getContext());