Reland [TUTORIAL] persistent softmax kernel (#1495)
Reland 01c3e98, a263360, a5b32a8 and 8ffdec1.

These commits introduce tuning for NVIDIA GPUs. Modify the tutorial for
better tuning on XPU devices:

- Launch the number of programs that maximizes occupancy in a single wave
when the number of rows is at least that high and the minimum number of
rows each program will process is at most 2
- Launch `n_rows` programs otherwise
- Tune `num_warps` depending on `BLOCK_SIZE`, aiming for 4 elements per
work-item
- Drop the `num_stages` argument, as it is currently unused

The occupancy calculation is based on
https://oneapi-src.github.io/oneAPI-samples/Tools/GPU-Occupancy-Calculator/;
a simplified sketch of the resulting heuristic follows.
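
For illustration only, a minimal Python sketch of that heuristic. The default
hardware constants (core count, warps per core, maximum resident work-groups,
sub-group size) are placeholders rather than values queried from the driver,
and SLM limits are ignored; the tutorial code below derives all of these from
the device properties.

    import triton


    def pick_launch_config(n_rows, n_cols, num_sm=64, warps_per_core=64,
                           max_num_wg=64, warp_size=32, max_num_warps=32):
        # Hypothetical helper; the default hardware constants are illustrative.
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        # Aim for ~4 elements per work-item, clamped to the hardware maximum.
        num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (warp_size * 4)))
        # Occupancy in a single wave: resident work-groups per core (limited by
        # threads; SLM is ignored in this sketch) times the number of cores.
        num_programs = num_sm * min(warps_per_core // num_warps, max_num_wg)
        # Do not run persistently if there are fewer rows than programs or if
        # each program would have to process more than 2 rows.
        if n_rows < num_programs or n_rows // num_programs > 2:
            num_programs = n_rows
        return num_programs, num_warps, BLOCK_SIZE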

Closes #1099

---------

Signed-off-by: Victor Perez <[email protected]>
victor-eds authored Jul 3, 2024
1 parent db107db commit 7358f79
127 changes: 86 additions & 41 deletions python/tutorials/02-fused-softmax.py
@@ -26,6 +26,7 @@

import triton
import triton.language as tl
from triton.runtime import driver


def naive_softmax(x):
@@ -62,7 +63,7 @@ def naive_softmax(x):
# Compute Kernel
# --------------
#
# Our softmax kernel works as follows: each program loads a row of the input matrix X,
# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs,
# normalizes it and writes back the result to the output Y.
#
# Note that one important limitation of Triton is that each block must have a
@@ -71,59 +72,103 @@ def naive_softmax(x):


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
BLOCK_SIZE: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)


# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.xpu.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
SIZE_SMEM = properties["max_shared_mem"]
WARPS_PER_EU = 8 # TODO: Get from properties
EU_PER_SM = 8 # TODO: Get from properties
MAX_NUM_WG = 64 # TODO: Get from properties
WARP_SIZE = properties["sub_group_sizes"][-1]
WG_SIZE = properties["max_work_group_size"]
max_num_warps = WG_SIZE // WARP_SIZE
target = triton.runtime.driver.active.get_current_target()
warps_per_sm = WARPS_PER_EU * EU_PER_SM
max_num_resident_warps = NUM_SM * warps_per_sm
kernels = {}
# Possible SLM allocation sizes in KB, converted to bytes
tg_slm_sizes = [i * 2**10 for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]  # TODO: Get from properties


def softmax(x):

def occupancy(num_warps, size_smem):

def allocated_slm_size(size_smem):
for size in tg_slm_sizes:
if size_smem <= size:
return size
raise RuntimeError("Exceeded max SLM allocation size")

num_wg_threads = warps_per_sm // num_warps
num_wg_slm = MAX_NUM_WG if size_smem == 0 else SIZE_SMEM // allocated_slm_size(size_smem)
num_wg = min(num_wg_threads, num_wg_slm, MAX_NUM_WG)
return NUM_SM * num_wg

n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.

    # Simple heuristic depending on `BLOCK_SIZE`. We aim for 4 elements per thread, as the block size may be almost
    # twice as large as the row size. This way, we reduce the number of threads performing no work.
    # As the maximum number of warps is limited by hardware, we need to make sure we do not surpass that limit.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4)))

# Allocate output
y = torch.empty_like(x)
    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of the input matrix
softmax_kernel[(n_rows, )](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)

# pre-compile kernel to get register usage and compute thread occupancy.
kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
if kernel is None:
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
kernel._init_handles()
size_smem = kernel.metadata.shared
num_programs = occupancy(num_warps, size_smem)
kernels[BLOCK_SIZE] = (kernel, num_programs)

    # We will *not* launch a persistent kernel if the number of rows is lower than the number of programs (not
    # needed) or if that would imply each program has to process more than 2 rows. Persistent kernels save thread
    # dispatch overhead but cannot hide stalling; overdispatching helps hide stalls thanks to work-group level
    # preemption. That is why, as a heuristic, we do not schedule a persistent kernel if each work-group would need
    # to process more than 2 rows.
if n_rows < num_programs or n_rows // num_programs > 2:
num_programs = n_rows

# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols)
return y


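For context, and not part of the diff above: a quick way to sanity-check the relanded
tutorial is to compare the `softmax` helper against `torch.softmax` on an XPU device.
The snippet below is a sketch; the test shape is illustrative and it assumes a Triton
build with the XPU backend and an available XPU device.

    import torch

    # Irregular number of columns so that BLOCK_SIZE != n_cols and the masked
    # load/store path is exercised; shape and device are illustrative.
    x = torch.randn(1823, 781, device='xpu')
    y_triton = softmax(x)  # `softmax` is the helper defined in the tutorial above
    y_torch = torch.softmax(x, dim=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)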
