diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index b25ebf5aa5..137fecc3ca 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -26,6 +26,7 @@
 import triton
 import triton.language as tl
+from triton.runtime import driver
 
 
 def naive_softmax(x):
@@ -62,7 +63,7 @@ def naive_softmax(x):
 # Compute Kernel
 # --------------
 #
-# Our softmax kernel works as follows: each program loads a row of the input matrix X,
+# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs,
 # normalizes it and writes back the result to the output Y.
 #
 # Note that one important limitation of Triton is that each block must have a
@@ -71,59 +72,103 @@ def naive_softmax(x):
 
 
 @triton.jit
-def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
-    # The rows of the softmax are independent, so we parallelize across those
-    row_idx = tl.program_id(0)
-    # The stride represents how much we need to increase the pointer to advance 1 row
-    row_start_ptr = input_ptr + row_idx * input_row_stride
-    # The block size is the next power of two greater than n_cols, so we can fit each
-    # row in a single block
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    input_ptrs = row_start_ptr + col_offsets
-    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
-    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
-    # Subtract maximum for numerical stability
-    row_minus_max = row - tl.max(row, axis=0)
-    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
-    numerator = tl.exp(row_minus_max)
-    denominator = tl.sum(numerator, axis=0)
-    softmax_output = numerator / denominator
-    # Write back output to DRAM
-    output_row_start_ptr = output_ptr + row_idx * output_row_stride
-    output_ptrs = output_row_start_ptr + col_offsets
-    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
+                   BLOCK_SIZE: tl.constexpr):
+    # starting row of the program
+    row_start = tl.program_id(0)
+    row_step = tl.num_programs(0)
+    for row_idx in tl.range(row_start, n_rows, row_step):
+        # The stride represents how much we need to increase the pointer to advance 1 row
+        row_start_ptr = input_ptr + row_idx * input_row_stride
+        # The block size is the next power of two greater than n_cols, so we can fit each
+        # row in a single block
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
+        mask = col_offsets < n_cols
+        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+        # Subtract maximum for numerical stability
+        row_minus_max = row - tl.max(row, axis=0)
+        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        # Write back output to DRAM
+        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, softmax_output, mask=mask)
 
 
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
+device = torch.xpu.current_device()
+properties = driver.active.utils.get_device_properties(device)
+NUM_SM = properties["multiprocessor_count"]
+SIZE_SMEM = properties["max_shared_mem"]
+WARPS_PER_EU = 8  # TODO: Get from properties
+EU_PER_SM = 8  # TODO: Get from properties
+MAX_NUM_WG = 64  # TODO: Get from properties
+WARP_SIZE = properties["sub_group_sizes"][-1]
+WG_SIZE = properties["max_work_group_size"]
+max_num_warps = WG_SIZE // WARP_SIZE
+target = triton.runtime.driver.active.get_current_target()
+warps_per_sm = WARPS_PER_EU * EU_PER_SM
+max_num_resident_warps = NUM_SM * warps_per_sm
+kernels = {}
+# Possible SLM allocation sizes in KB, stored in bytes
+tg_slm_sizes = [i * 2**10 for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]  # TODO: Get from properties
+
 
 def softmax(x):
+
+    def occupancy(num_warps, size_smem):
+
+        def allocated_slm_size(size_smem):
+            for size in tg_slm_sizes:
+                if size_smem <= size:
+                    return size
+            raise RuntimeError("Exceeded max SLM allocation size")
+
+        num_wg_threads = warps_per_sm // num_warps
+        num_wg_slm = MAX_NUM_WG if size_smem == 0 else SIZE_SMEM // allocated_slm_size(size_smem)
+        num_wg = min(num_wg_threads, num_wg_slm, MAX_NUM_WG)
+        return NUM_SM * num_wg
+
     n_rows, n_cols = x.shape
-    # The block size is the smallest power of two greater than the number of columns in `x`
+    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
     BLOCK_SIZE = triton.next_power_of_2(n_cols)
-    # Another trick we can use is to ask the compiler to use more threads per row by
-    # increasing the number of warps (`num_warps`) over which each row is distributed.
+
+    # Simple heuristic depending on `BLOCK_SIZE`. We aim for 4 elements per thread, as the block size may be almost
+    # twice as large as the row size. This way, we reduce the number of threads performing no work.
+    # As the maximum number of warps is limited by hardware, we need to make sure we do not surpass that limit.
     # You will see in the next tutorial how to auto-tune this value in a more natural
     # way so you don't have to come up with manual heuristics yourself.
-    num_warps = 4
-    if BLOCK_SIZE >= 2048:
-        num_warps = 8
-    if BLOCK_SIZE >= 4096:
-        num_warps = 16
+    num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 4)))
+
     # Allocate output
     y = torch.empty_like(x)
-    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
-    # f the input matrix
-    softmax_kernel[(n_rows, )](
-        y,
-        x,
-        x.stride(0),
-        y.stride(0),
-        n_cols,
-        num_warps=num_warps,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
+
+    # Pre-compile the kernel to get register usage and compute thread occupancy.
+    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
+    if kernel is None:
+        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
                                       threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
+        kernel._init_handles()
+        size_smem = kernel.metadata.shared
+        num_programs = occupancy(num_warps, size_smem)
+        kernels[BLOCK_SIZE] = (kernel, num_programs)
+
+    # We will *not* launch a persistent kernel if the number of rows is lower than the number of programs (not
+    # needed), or if that would imply each program has to process more than 2 rows. Persistent kernels save thread
+    # dispatch overhead, but cannot hide stalling. Over-dispatching helps hide this thanks to work-group level
+    # preemption. That's why, as a heuristic, if each work-group would need to process more than 2 rows, we do not
+    # schedule a persistent kernel.
+    if n_rows < num_programs or n_rows // num_programs > 2:
+        num_programs = n_rows
+
+    # Create a number of persistent programs.
+    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols)
     return y
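
For reviewers: the occupancy and persistent-grid heuristic introduced above can be exercised in plain Python, without an XPU device. The sketch below mirrors the patch's occupancy() helper and the launch decision using assumed hardware numbers (64 multiprocessors, 8 EUs per SM, 8 warps per EU, a 64 work-group cap, 128 KB of SLM); these values are illustrative only and are not taken from any real device query.

    # Standalone sketch of the occupancy / persistent-grid heuristic (illustrative numbers only).
    NUM_SM = 64              # assumed multiprocessor count
    SIZE_SMEM = 128 * 2**10  # assumed shared local memory per SM, in bytes
    WARPS_PER_EU = 8
    EU_PER_SM = 8
    MAX_NUM_WG = 64
    warps_per_sm = WARPS_PER_EU * EU_PER_SM
    tg_slm_sizes = [i * 2**10 for i in [0, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128]]


    def occupancy(num_warps, size_smem):

        def allocated_slm_size(size_smem):
            # Round the kernel's SLM usage up to the next allocatable size.
            for size in tg_slm_sizes:
                if size_smem <= size:
                    return size
            raise RuntimeError("Exceeded max SLM allocation size")

        num_wg_threads = warps_per_sm // num_warps  # work-groups limited by resident warps
        num_wg_slm = MAX_NUM_WG if size_smem == 0 else SIZE_SMEM // allocated_slm_size(size_smem)
        num_wg = min(num_wg_threads, num_wg_slm, MAX_NUM_WG)
        return NUM_SM * num_wg


    # A kernel compiled with 8 warps and no SLM keeps 8 work-groups resident per SM,
    # i.e. 64 * 8 = 512 persistent programs in flight.
    num_programs = occupancy(num_warps=8, size_smem=0)
    print(num_programs)  # 512

    # With 4096 rows, each persistent program would process 4096 // 512 = 8 rows (> 2),
    # so the heuristic falls back to one program per row instead.
    n_rows = 4096
    if n_rows < num_programs or n_rows // num_programs > 2:
        num_programs = n_rows
    print(num_programs)  # 4096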