Skip to content

Commit

Permalink
Merge commit 'cfc523d5842238390e036dd9671ba054e4ab6122'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Apr 29, 2024
2 parents e8b232a + cfc523d commit 4f4af05
Show file tree
Hide file tree
Showing 37 changed files with 626 additions and 225 deletions.
10 changes: 0 additions & 10 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerConvertTritonAMDGPUToLLVM();
mlir::triton::registerDecomposeUnsupportedAMDConversions();

// TODO: Uncomment when fixed undefined symbols and
// remove section below
// List of undefined symbols:
// createTritonAMDGPUCoalesce is not defined
// createTritonAMDGPUOptimizeDotOperands is not defined
// createTritonAMDGPUPipeline is not defined
// createTritonAMDGPUPrefetch is not defined

// mlir::registerTritonAMDGPUPasses();

mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPURemoveLayoutConversions();
Expand Down
4 changes: 2 additions & 2 deletions include/triton/Conversion/TritonToTritonGPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
"int32_t", /*default*/"1",
"number of ctas in a cga">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"compute capability">
"int32_t", /*default*/"0",
"optional compute capability for cuda; 0 means missing">
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H

#include <memory>
#include <optional>

namespace mlir {

Expand All @@ -23,9 +24,9 @@ std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUWarpPass();

// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
int numCTAs = 1, int computeCapability = 80);
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass(
int numWarps, int threadsPerWarp = 32, int numCTAs = 1,
std::optional<int> computeCapability = std::nullopt);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUWarpPass(unsigned numWarps);

Expand Down
9 changes: 4 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ def TritonGPU_Dialect : Dialect {
return 1;
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
}
static int getComputeCapability(ModuleOp mod) {
if (!mod->hasAttr("triton_gpu.compute-capability"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
static std::optional<int> getComputeCapability(ModuleOp mod) {
auto attr = mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability");
if (attr) return attr.getInt();
return std::nullopt;
}
void registerTypes();

Expand Down
68 changes: 68 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,72 @@ def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpI
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}

def TTNG_InitBarrierOp : TTNG_Op<"init_barrier", [MemoryEffects<[MemWrite<SharedMemory>]>]> {
let summary = "Initialize a barrier in the given shared memory allocation.";

let description = [{
Initializes a shared memory allocation with mbarrier information.
`alloc` is a descriptor to the shared memory allocation. `count` is the
number of arrives expected by the barrier.

This lowers to PTX mbarrier.init.shared::cta.b64.
}];

let hasVerifier = 1;
let arguments = (ins TT_MemDescType:$alloc,
I32Attr:$count);
let assemblyFormat = "$alloc `,` $count attr-dict `:` type($alloc)";
}

def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", []> {
let summary = "wait until the mbarrier phase completes.";

let description = [{
Blocks the program progress until the mbarrier object in `alloc` completes
its current phase.

This lowers a waitloop using PTX instruction
mbarrier.try_wait.parity.shared.b64.

The barrier behavior is described here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms
}];

let hasVerifier = 1;
let arguments = (ins TT_MemDescType:$alloc,
I32:$phase);
let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)";
}


def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<SharedMemory>]>]> {
let summary = "copy data based on descriptor from global memory to local memory asynchronously";

let description = [{
This operation copies data from global memory to local memory
asynchronously. This is analogue to tt.load except the data are copied to
local memory pointed by the memory descriptor instread of a distributed
tensor. The data copied depends on the global memory descriptor pointed to
by `desc_ptr`.
}];

let hasVerifier = 1;
let arguments = (
ins TT_PtrType:$desc_ptr,
Variadic<I32>:$coord,
TT_MemDescType:$barrier,
TT_MemDescType:$result,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
);

let assemblyFormat = [{
$desc_ptr `[` $coord `]` $result `,` $barrier
oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
attr-dict `:` type($desc_ptr) `,` type($barrier) `->` type($result)
}];
}

#endif
20 changes: 11 additions & 9 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,11 +745,11 @@ class ConvertTritonToTritonGPU
ConvertTritonToTritonGPU() = default;
// constructor with some parameters set explicitly.
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp, int numCTAs,
int computeCapability) {
std::optional<int> computeCapability) {
this->numWarps = numWarps;
this->threadsPerWarp = threadsPerWarp;
this->numCTAs = numCTAs;
this->computeCapability = computeCapability;
this->computeCapability = computeCapability.value_or(0);
}

void runOnOperation() override {
Expand Down Expand Up @@ -783,9 +783,12 @@ class ConvertTritonToTritonGPU
mod->setAttr(AttrNumCTAsName,
IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue())));

mod->setAttr(AttrComputeCapabilityName,
IntegerAttr::get(
i32_ty, llvm::APInt(32, computeCapability.getValue())));
if (std::optional<int> cc = computeCapability.getValue()) {
if (cc.value() != 0) {
mod->setAttr(AttrComputeCapabilityName,
IntegerAttr::get(i32_ty, llvm::APInt(32, cc.value())));
}
}

if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
Expand All @@ -800,10 +803,9 @@ class ConvertTritonToTritonGPU
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
int threadsPerWarp,
int numCTAs,
int computeCapability) {
mlir::triton::createConvertTritonToTritonGPUPass(
int numWarps, int threadsPerWarp, int numCTAs,
std::optional<int> computeCapability) {
return std::make_unique<::ConvertTritonToTritonGPU>(
numWarps, threadsPerWarp, numCTAs, computeCapability);
}
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,10 @@ class TritonGPUOptimizeDotOperandsPass

mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
// TODO: compute capability is CUDA specific; change it to backend agnostic.
std::optional<int> cc =
triton::gpu::TritonGPUDialect::getComputeCapability(m);
if (cc.value_or(0) >= 80)
patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);
Expand Down
29 changes: 22 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
ConvertLayoutOp convertOp,
IRMapping &mapping) {
SetVector<Operation *> opsToRewrite;
// Keep track of yield operands that need to be duplicated.
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
Expand All @@ -843,13 +845,22 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
if (v.getDefiningOp()) {
opsToRewrite.insert(v.getDefiningOp());
if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
unsigned operandIdx = v.cast<OpResult>().getResultNumber();
opsToRewrite.insert(ifOp.thenYield().getOperation());
yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx);
opsToRewrite.insert(ifOp.elseYield().getOperation());
yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx);
}
} else {
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getParentOp());
// We also need to rewrite the yield op.
opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getTerminator());
BlockArgument blockArg = v.cast<BlockArgument>();
Operation *parentOp = blockArg.getOwner()->getParentOp();
if (auto loopOp = cast<LoopLikeOpInterface>(parentOp)) {
opsToRewrite.insert(loopOp.getOperation());
OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg);
auto yieldOp = blockArg.getOwner()->getTerminator();
yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber());
opsToRewrite.insert(yieldOp);
}
}
}
opsToRewrite = multiRootTopologicalSort(opsToRewrite);
Expand Down Expand Up @@ -893,6 +904,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Value oldArg = loopBody.getArgument(m.first + numIndVars);
addRematValue(newForOp.getResult(m.first), layout[oldArg],
newForOp.getResult(m.second));
addRematValue(oldArg, layout[oldArg],
loopBody.getArgument(m.second + numIndVars));
}
continue;
}
Expand Down Expand Up @@ -929,10 +942,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.setInsertionPoint(op);
if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
for (Value operand : yieldOp.getOperands()) {
if (slice.count(operand) == 0)
continue;
yieldOperands.push_back(mapping.lookup(operand));
SmallVector<int> operandsToRewrite = yieldOperandsMap[op];
// Sort so that operands are added in the same order as the new scf
// results/arguments.
std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
for (int operandIdx : operandsToRewrite) {
yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx)));
}
builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
op->erase();
Expand Down
31 changes: 31 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,37 @@ LogicalResult DotWaitOp::inferReturnTypes(
return mlir::success();
}

static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) {
if (!barrierType.getElementType().isInteger(64) ||
barrierType.getShape() != ArrayRef<int64_t>({1}))
return op->emitOpError(
"barrier allocation must be a descriptor of 1xi64 type");
return success();
}

// -- InitBarrierOp --
LogicalResult InitBarrierOp::verify() {
if (failed(verifyBarrierType(*this, getAlloc().getType())))
return failure();
return success();
}

// -- WaitBarrierOp --
LogicalResult WaitBarrierOp::verify() {
if (failed(verifyBarrierType(*this, getAlloc().getType())))
return failure();
return success();
}

// -- AsyncTMACopyGlobalToLocalOp --
LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
if (failed(verifyBarrierType(*this, getBarrier().getType())))
return failure();
if (getCoord().size() < 1 || getCoord().size() > 5)
return emitOpError("TMA copies must have between 1 and 5 coordinates");
return success();
}

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir
4 changes: 3 additions & 1 deletion python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/Passes.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

Expand Down Expand Up @@ -39,7 +40,8 @@ void init_triton_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
createRewriteTensorPointerPass);
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
createConvertTritonToTritonGPUPass, int, int, int, int);
createConvertTritonToTritonGPUPass, int, int, int,
std::optional<int>);
}

void init_triton_passes_ttgpuir(py::module &&m) {
Expand Down
47 changes: 47 additions & 0 deletions python/test/unit/hopper/test_experimental_tma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pytest
import torch
import tempfile

import triton


def test_descriptor_load_ttgir():
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
device = "cuda"
SIZE = 128

ir = f"""
#blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}>
#shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}>
module attributes {{"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i8> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked>
%1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, mutable>
%2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, mutable>
triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2 : <i8>, <1xi64, #shared, mutable> -> <{SIZE}xf32, #shared, mutable>
triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, mutable>
%3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, mutable> -> tensor<{SIZE}xf32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<{SIZE}x!tt.ptr<f32>, #blocked>
%5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>, tensor<{SIZE}xi32, #blocked>
tt.store %5, %3 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>
tt.return
}}
}}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

x = torch.randn(SIZE, dtype=torch.float32, device=device)
desc = np.empty(SIZE, dtype=np.int8)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, 4, desc)
desc = torch.tensor(desc, device=device)
z_tri = torch.empty_like(x)
kernel[(1, 1, 1)](z_tri, desc)
assert torch.equal(x, z_tri)
2 changes: 1 addition & 1 deletion test/Conversion/amd/decompose-unsupported-conversions.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]]>
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/amd/fp_to_fp.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm | FileCheck %s
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s

// CHECK-LABEL: f16_to_f32
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
// CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}v_cvt_f32_f16 {{.*}}: (f16) -> f32
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/amd/load_store.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm | FileCheck %s
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s

#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm | FileCheck %s
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s

#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: wmma_dot_operand
tt.func @wmma_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) {
// 2 CTA * 4 rep * load_per_thread_per_instr
Expand Down
Loading

0 comments on commit 4f4af05

Please sign in to comment.