
Commit 31c1333

Merge commit 'f4c48a9233957903e30474bae6443bf3d3a79bf7'
whitneywhtsang committed Sep 18, 2024
2 parents f54d008 + f4c48a9 commit 31c1333
Showing 5 changed files with 101 additions and 34 deletions.
.gitignore: 3 additions & 0 deletions
@@ -65,3 +65,6 @@ docs/sg_execution_times.rst
 
 # Vim
 *.swp
+
+# macOS
+.DS_Store
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp: 7 additions & 0 deletions
@@ -419,6 +419,13 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   int tileRows = 8;
   int tileCols = 8 * tileWidthBytes / elemBitWidth;
 
+  if (shape[colDim] < tileCols || shape[rowDim] < tileRows) {
+    llvm::errs() << "Illegal shared layout; expected shape to be at least ["
+                 << tileRows << ", " << tileCols << "], shape: ["
+                 << shape[rowDim] << ", " << shape[colDim] << "]\n";
+    llvm::report_fatal_error("Illegal shared layout");
+  }
+
   int vec = 8 * 16 / elemBitWidth;
   if (vec != shared.getVec()) {
     llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
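The new guard rejects tensors smaller than one swizzle tile of the leading-offset layout. A minimal sketch of the same arithmetic in Python (the 128-byte tile width and the dtype examples are illustrative assumptions, not values fixed by this diff):

```python
def min_legal_shape(elem_bit_width: int, tile_width_bytes: int = 128):
    """Smallest [rows, cols] a leading-offset shared layout can describe.

    One swizzle tile is 8 rows by (8 * tile_width_bytes / elem_bit_width)
    columns; anything smaller now triggers report_fatal_error.
    """
    tile_rows = 8
    tile_cols = 8 * tile_width_bytes // elem_bit_width
    return [tile_rows, tile_cols]

# With fp16 (16-bit) elements and a 128-byte tile width, the minimum
# shape is 8 x 64, so a 32 x 32 tensor would hit the new fatal error.
assert min_legal_shape(16) == [8, 64]
assert min_legal_shape(8) == [8, 128]  # fp8
```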
python/tutorials/09-persistent-matmul.py: 68 additions & 34 deletions
@@ -360,9 +360,9 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
+def matmul_kernel_device_tma_persistent(workspace_ptr, #
+                                        tiles_per_update: tl.constexpr, #
                                         a_ptr, b_ptr, c_ptr, #
-                                        ready_flag, #
                                         M, N, K, #
                                         BLOCK_SIZE_M: tl.constexpr, #
                                         BLOCK_SIZE_N: tl.constexpr, #
@@ -377,31 +377,32 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    if start_pid == 0:
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                             element_ty=a_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                             element_ty=b_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                             element_ty=c_ptr.dtype.element_ty)
-        tl.atomic_xchg(ready_flag, 1, sem="release")
-    else:
-        flag = tl.full([], 0, tl.int32)
-        while flag != 1:
-            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    TMA_SIZE: tl.constexpr = 128
+    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+    c_desc_ptr = workspace_base + 2 * TMA_SIZE
+
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                         element_ty=a_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                         element_ty=b_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                         element_ty=c_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
         tiles_per_SM += 1
 
     tile_id = start_pid - NUM_SMS
     ki = -1
+    ni = -1
 
     pid_m = 0
     pid_n = 0
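Each program now builds its descriptors in a private slice of a global workspace instead of all programs sharing one descriptor set behind a spin-wait on ready_flag, which removes the cross-SM synchronization entirely. A host-side sketch of the offset arithmetic the kernel assumes (plain Python, values illustrative):

```python
TMA_SIZE = 128  # bytes per TMA descriptor, matching the kernel constexpr

def descriptor_offsets(start_pid: int):
    """Byte offsets of the a/b/c descriptors owned by program start_pid."""
    base = start_pid * 3 * TMA_SIZE
    return base, base + TMA_SIZE, base + 2 * TMA_SIZE

assert descriptor_offsets(0) == (0, 128, 256)
assert descriptor_offsets(1) == (384, 512, 640)
# The host must allocate NUM_SMS * 3 * TMA_SIZE bytes,
# which is exactly what the launcher below does.
```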
@@ -415,6 +416,27 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
+            ni += 1
+
+            # Simulate a grouped gemm
+            if ni == tiles_per_update:
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                                     load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+                                                                     global_size=[M, K],
+                                                                     element_ty=a_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                                     load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+                                                                     global_size=[N, K],
+                                                                     element_ty=b_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                                     load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                                                                     global_size=[M, N],
+                                                                     element_ty=c_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+                ni = 0
+
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
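The new `ni` counter recreates all three descriptors once every `tiles_per_update` output tiles, simulating a grouped GEMM in which each group needs fresh tensor maps. A plain-Python sketch of that cadence (function and values are illustrative, not part of the tutorial):

```python
def update_points(num_output_tiles: int, tiles_per_update: int) -> list[int]:
    """Output-tile indices at which the descriptors get rewritten."""
    points, ni = [], -1
    for tile in range(num_output_tiles):
        ni += 1
        if ni == tiles_per_update:
            points.append(tile)
            ni = 0
    return points

# tiles_per_update=1 rewrites the descriptors for every tile after the
# first (worst case); larger values amortize the update cost.
assert update_points(8, 1) == [1, 2, 3, 4, 5, 6, 7]
assert update_points(8, 4) == [4]
```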
@@ -435,10 +457,11 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
         c = accumulator.to(dtype)
 
         tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
 
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b):
+def matmul_device_tma_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -459,15 +482,15 @@ def matmul_device_tma_persistent(a, b):
     dtype = a.dtype
 
     c = torch.zeros((M, N), device=a.device, dtype=dtype)
-    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
-    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    tma_size = 128
+    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_device_tma_persistent[grid](
-        a_desc, b_desc, c_desc, #
+        workspace, #
+        tiles_per_update, #
         a, b, c, #
-        ready_flag, #
         M, N, K, #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], #
         BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], #
@@ -507,7 +530,7 @@ def torch_matmul(a, b):
     return c
 
 
-def bench(K, dtype, reps=10):
+def bench(K, dtype, tiles_per_update, reps=10):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -535,14 +558,18 @@ def bench(K, dtype, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
         time.sleep(0.01)
-    for _ in range(reps):
-        matmul_device_tma_persistent(a, b)
-    time.sleep(0.01)
+    flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
+    with proton.scope(
+            f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
+        {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+        for _ in range(reps):
+            matmul_device_tma_persistent(a, b, tiles_per_update)
+        time.sleep(0.01)
 
     proton.deactivate(0)
 
 
-def validate(M, N, K, dtype):
+def validate(M, N, K, dtype, tiles_per_update):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
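The device-TMA benchmark now runs inside its own proton.scope, so each tiles_per_update setting is reported as a separate entry. The recorded metrics follow directly from the expressions above: for the default M = N = 8192 with K = 512 in fp16, bytes = 2 * (8192 * 512 + 8192 * 512) = 16,777,216 bytes (16 MiB of inputs) and flops = 2 * 8192 * 8192 * 512 ≈ 68.7 GFLOP per call.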
@@ -552,7 +579,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -586,6 +613,13 @@ def validate(M, N, K, dtype):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
+    parser.add_argument(
+        "--tiles_per_update",
+        type=int,
+        default=1,
+        help=
+        "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel",
+    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
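With the new flag, an invocation such as `python 09-persistent-matmul.py --prec fp8 --K_range 512 8192 --tiles_per_update 4` (illustrative; flag spellings as defined above) benchmarks the device-TMA kernel while refreshing the descriptors every fourth output tile, whereas the default of 1 recreates them for every tile.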

@@ -601,10 +635,10 @@ def validate(M, N, K, dtype):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, 512, dtype)
+    validate(32, 32, 32, dtype, args.tiles_per_update)
+    validate(8192, 8192, 512, dtype, args.tiles_per_update)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
+        bench(K, dtype, args.tiles_per_update)
     proton.finalize()
test/TritonGPU/amd/amd-reorder-instructions.mlir: 20 additions & 0 deletions
@@ -830,3 +830,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+// CHECK-LABEL: anchor_barrier
+// CHECK: gpu.barrier
+// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr<f16>, #blocked>) attributes {noinline = false} {
+    %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable>
+    gpu.barrier
+    %2 = tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
+    %1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable>
+    triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable>
+    triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable>
+    tt.return
+  }
+}
@@ -50,6 +50,9 @@ findEarlyInsertionPoint(Block *block, Operation *move) {
     // Atomics used for global synchronization.
     if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(wop))
       ipnt = bi;
+    // Break at barrier
+    if (isa<gpu::BarrierOp>(wop))
+      ipnt = bi;
     // Break at loops.
     if (isa<scf::ForOp, scf::WhileOp>(wop))
       ipnt = bi;
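Together with the anchor_barrier test above, this makes findEarlyInsertionPoint treat gpu.barrier the same way it already treats global-synchronization atomics and loops: the computed early insertion point cannot float past a barrier, so the pass no longer hoists operations such as a local_alloc across an existing gpu.barrier.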
