
Commit

Merge commit '9f939760d2455bb0644698a5b6f3a13aa485abde'
whitneywhtsang committed Nov 7, 2024
2 parents d96a80e + 9f93976 commit 6da6e84
Showing 4 changed files with 39 additions and 32 deletions.
46 changes: 26 additions & 20 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }
 
-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  //   repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice).
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }
 
 SmallVector<unsigned, 2>
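The comment block added to warpsPerTileV2 carries the reasoning behind the change: instead of greedily growing the M dimension, the new loop keeps repM and repN balanced to minimize register pressure, breaking ties toward M since an lhs tile carries 4 elements versus 2 for the rhs. A minimal Python sketch of that balancing loop, assuming the mma v2 instruction shape [16, 8]; the function name and the example inputs are made up for illustration, not taken from the sources:

import math

def warps_per_tile_v2_sketch(shape, num_warps):
    # instruction tile of mma v2: 16 (M) x 8 (N)
    shape_per_warp = [16, 8]
    warps = [1, 1]
    # repM, repN: how many instruction tiles a single warp would currently cover
    reps = [math.ceil(shape[0] / shape_per_warp[0]),
            math.ceil(shape[1] / shape_per_warp[1])]
    while warps[0] * warps[1] < num_warps:
        if reps[0] >= reps[1]:
            warps[0] *= 2
            # once repM hits 1, keep assigning the leftover warps to M anyway
            if reps[0] != 1:
                reps[0] //= 2
        else:
            warps[1] *= 2
            reps[1] //= 2
    return warps

print(warps_per_tile_v2_sketch([128, 256], 8))   # [2, 4] for this made-up shape

The CHECK updates in test/TritonGPU/accelerate-matmul.mlir further down reflect the same shift: expected warpsPerCTA values move from M-heavy splits such as [4, 2] and [8, 4] to [2, 4] and [4, 8].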
13 changes: 7 additions & 6 deletions python/test/unit/runtime/test_subproc.py
@@ -7,6 +7,7 @@
 from triton.compiler import ASTSource
 
 target = triton.runtime.driver.active.get_current_target()
+start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
 
 
 def compile_fn(attrs):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):
 
 def test_compile_in_subproc() -> None:
     config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
-    multiprocessing.set_start_method('fork')
-    proc = multiprocessing.Process(target=compile_fn, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):
 
 def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
     config = AttrsDescriptor.from_hints({0: 16})
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
Expand Down Expand Up @@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:

# stage 2.p
shutil.rmtree(fresh_triton_cache)
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))
mp_ctx = multiprocessing.get_context(start_method)
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))

# stage 3.c
proc.start()
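The test changes all follow one pattern: drop the one-shot global multiprocessing.set_start_method('fork'), which can only be called once per interpreter and assumes 'fork' is available, and instead build a context from the start_method detected at import time. A small self-contained sketch of that idiom, with a placeholder worker standing in for the real compile helpers:

import multiprocessing

start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'


def work(x):
    # stand-in for compile_fn / compile_fn_dot in the real tests
    return x * x


if __name__ == '__main__':
    mp_ctx = multiprocessing.get_context(start_method)  # local context, no global state
    proc = mp_ctx.Process(target=work, args=(3, ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0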
6 changes: 3 additions & 3 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----
 
 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
 
   int kWidth = encoding.getKWidth();
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());
 
   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));
 
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
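The last hunk also changes which order drives the warp-id unpacking: delinearize now receives the layout's warpOrder rather than the generic getOrder(mmaLayout), so the resulting warp coordinates follow the order in which warps are actually laid out. A rough Python sketch of what a delinearization like this computes and why the order matters; it illustrates the idea only and is not the MLIR helper itself:

def delinearize(linear_id, shape, order):
    # Unpack a linear warp id into per-dimension coordinates,
    # walking dimensions in `order` (order[0] varies fastest).
    coords = [0] * len(shape)
    rest = linear_id
    for dim in order:
        coords[dim] = rest % shape[dim]
        rest //= shape[dim]
    return coords


# 8 warps arranged as warpsPerCTA = [2, 4]; the chosen order changes the result:
print(delinearize(5, [2, 4], order=[1, 0]))   # [1, 1]
print(delinearize(5, [2, 4], order=[0, 1]))   # [1, 2]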
