From 9cb16756d3b75ef8ed4c87900e1e4780e9db653d Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 16:55:58 -0500 Subject: [PATCH 01/11] Add flash_attention_benchmark --- .gitignore | 4 ++++ benchmarks/flash_attention_bench/run.py | 23 +++++++++++++++++++ .../operators/flash_attention/operator.py | 5 +++- 3 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 benchmarks/flash_attention_bench/run.py diff --git a/.gitignore b/.gitignore index ecfa253b..64e77890 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,8 @@ __pycache__/ .ipynb_checkpoints/ .idea *.egg-info/ +<<<<<<< HEAD torch_compile_debug/ +======= +build/ +>>>>>>> 1a642f1 (Add flash_attention_benchmark) diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py new file mode 100644 index 00000000..092b55c0 --- /dev/null +++ b/benchmarks/flash_attention_bench/run.py @@ -0,0 +1,23 @@ +""" +Run flash_attention benchmark with a single input shape: +BATCH: 4 + +SEQ_LEN: 16384 + +Print tflops metrics +""" + +import tritonbench +from tritonbench.utils.parser import get_parser + + +def run(): + args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal"] + flash_attn_op = tritonbench.load_opbench_by_name("flash_attention") + parser = get_parser() + args, extra_args = parser.parse_known_args(args) + flash_attn_bench = flash_attn_op(args, extra_args) + flash_attn_bench.run() + +if __name__ == "__main__": + run() diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 5d3bcc46..4b9ebe51 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -149,6 +149,7 @@ def parse_op_args(args: List[str]): action="store_true", help="enable causal (always true on backward)", ) + parser.add_argument("--additional-inputs", action="store_true", help="enable additional inputs") return parser.parse_args(args) @@ -172,6 +173,7 @@ def __init__( # Because Triton-Flash-V2 does not support backward with non-causal if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD: self.causal = True + self.additional_inputs = args.additional_inputs self.sm_scale = 1.3 @register_benchmark() @@ -489,7 +491,8 @@ def get_ctx_vals(): yield (BATCH, H, N_CTX, D_HEAD) ctx_vals = get_ctx_vals() - shapes = self.__additional_example_input(ctx_vals) + if self.additional_inputs: + shapes = self.__additional_example_input(ctx_vals) requires_grad = True for shape in shapes: BATCH, H, N_CTX, D_HEAD = shape From 82a471687375dd0d8e651631aa3bd7d9a016455f Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 15:16:38 -0800 Subject: [PATCH 02/11] Add flash attention bench --- benchmarks/flash_attention_bench/run.py | 3 ++- tritonbench/operators/flash_attention/operator.py | 13 ++++++++++--- tunableop_results0.csv | 0 3 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 tunableop_results0.csv diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py index 092b55c0..b377a9f3 100644 --- a/benchmarks/flash_attention_bench/run.py +++ b/benchmarks/flash_attention_bench/run.py @@ -12,12 +12,13 @@ def run(): - args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal"] + args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", 
"--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"] flash_attn_op = tritonbench.load_opbench_by_name("flash_attention") parser = get_parser() args, extra_args = parser.parse_known_args(args) flash_attn_bench = flash_attn_op(args, extra_args) flash_attn_bench.run() + print(flash_attn_bench.output) if __name__ == "__main__": run() diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 4b9ebe51..aa43b9c3 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -141,8 +141,8 @@ def parse_op_args(args: List[str]): parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=4, help="Batch size") - parser.add_argument("--seq-len", type=int, default=11, help="Batch size") - parser.add_argument("--n-heads", type=int, default=48, help="Number of heads") + parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length") + parser.add_argument("--n-heads", type=int, default=None, help="Number of heads") parser.add_argument("--d-head", type=int, default=64, help="specify head dimension") parser.add_argument( "--causal", @@ -479,12 +479,17 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable: return fn def get_input_iter(self) -> Generator: + import math D_HEAD = self.D_HEAD BATCH = self.BATCH H = self.H + seq_len_log2 = int(math.log2(self.SEQ_LEN)) def get_ctx_vals(): - for i in range(self.SEQ_LEN, 15): + if self.H: + yield (BATCH, self.H, self.SEQ_LEN, self.D_HEAD) + return + for i in range(seq_len_log2, 15): N_CTX = 2**i # BATCH = 16384 // N_CTX # H = 2048 // D_HEAD @@ -493,6 +498,8 @@ def get_ctx_vals(): ctx_vals = get_ctx_vals() if self.additional_inputs: shapes = self.__additional_example_input(ctx_vals) + else: + shapes = ctx_vals requires_grad = True for shape in shapes: BATCH, H, N_CTX, D_HEAD = shape diff --git a/tunableop_results0.csv b/tunableop_results0.csv new file mode 100644 index 00000000..e69de29b From dfbf006eb1facb8b64855502442c685d4cdb9e15 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 15:17:08 -0800 Subject: [PATCH 03/11] Enable triton flash_attention --- tunableop_results0.csv | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tunableop_results0.csv diff --git a/tunableop_results0.csv b/tunableop_results0.csv deleted file mode 100644 index e69de29b..00000000 From cd6390950321c82760a670c13b08df51f06502b3 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 18:30:56 -0500 Subject: [PATCH 04/11] Add fused attention vanilla --- .../kernels/triton_fused_attention_vanilla.py | 505 ++++++++++++++++++ .../operators/flash_attention/operator.py | 10 +- 2 files changed, 512 insertions(+), 3 deletions(-) create mode 100644 tritonbench/kernels/triton_fused_attention_vanilla.py diff --git a/tritonbench/kernels/triton_fused_attention_vanilla.py b/tritonbench/kernels/triton_fused_attention_vanilla.py new file mode 100644 index 00000000..e67cf7d1 --- /dev/null +++ b/tritonbench/kernels/triton_fused_attention_vanilla.py @@ -0,0 +1,505 @@ +import torch +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, 
offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) +@triton.jit +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # initialize offsets + 
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. 
+ curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. 
+ dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + 
_attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index aa43b9c3..7290247e 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -57,6 +57,10 @@ from tritonbench.kernels.triton_fused_attention import ( attention_opt as triton_tutorial_FA2_opt, ) +from tritonbench.kernels.triton_fused_attention_vanilla import ( + attention as triton_tutorial_FA2, +) + # [Optional] flash_attn v2 try: @@ -94,7 +98,7 @@ from .test_fmha_utils import permute_qkv HAS_XFORMERS = True -except (ImportError, IOError, AttributeError): +except (ImportError, IOError, AttributeError, TypeError): HAS_XFORMERS = False # [Optional] colfax cutlass backend @@ -253,8 +257,8 @@ def triton_tutorial_flash_v2( v: torch.Tensor, ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe - return lambda: triton_tutorial_FA2_opt( - q, k, v, self.causal, self.sm_scale, "base" + return lambda: triton_tutorial_FA2( + q, k, v, self.causal, self.sm_scale ) @register_benchmark(enabled=HAS_CUDA_124) From d98e0c2b7ed36d566ae9c722ef29ac9ea2d71085 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 18:34:10 -0500 Subject: [PATCH 05/11] Change how we detect device --- tritonbench/kernels/triton_fused_attention_vanilla.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tritonbench/kernels/triton_fused_attention_vanilla.py b/tritonbench/kernels/triton_fused_attention_vanilla.py index e67cf7d1..1f75db13 100644 --- a/tritonbench/kernels/triton_fused_attention_vanilla.py +++ b/tritonbench/kernels/triton_fused_attention_vanilla.py @@ -4,7 +4,9 @@ def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" + # use pytorch to detect current device + # return triton.runtime.driver.active.get_current_target().backend == "hip" + return bool(torch.version.hip) @triton.jit From 59f8a7e22a5812dd24fae879074b0e8d23f29477 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 18:40:03 -0500 Subject: [PATCH 06/11] Add fused attention that will work on v2.1.0 --- .../kernels/triton_fused_attention_vanilla.py | 680 ++++++------------ 1 file changed, 237 insertions(+), 443 deletions(-) diff --git a/tritonbench/kernels/triton_fused_attention_vanilla.py b/tritonbench/kernels/triton_fused_attention_vanilla.py index 1f75db13..fd15c174 100644 --- a/tritonbench/kernels/triton_fused_attention_vanilla.py +++ b/tritonbench/kernels/triton_fused_attention_vanilla.py @@ -1,425 +1,233 @@ -import torch -import triton -import triton.language as tl - - -def is_hip(): - # use pytorch to detect current device - # return triton.runtime.driver.active.get_current_target().backend == "hip" - return bool(torch.version.hip) +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) 
-@triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, # - K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # - N_CTX: tl.constexpr, fp8_v: tl.constexpr): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False - else: - lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - # -- update output accumulator -- - acc = acc * alpha[:, None] - # update acc - v = tl.load(V_block_ptr) - if fp8_v: - p = p.to(tl.float8e5) - else: - p = p.to(tl.float16) - acc = tl.dot(p, v, acc) - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - return acc, l_i, m_i +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. Goucher for simplified vector math +""" +import torch -# We don't run auto-tuning every time to keep the tutorial fast. Keeping -# the code below and commenting out the equivalent parameters is convenient for -# re-tuning. 
-configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ - for BM in [64, 128]\ - for BN in [32, 64]\ - for s in ([1] if is_hip() else [3, 4, 7])\ - for w in [4, 8]\ -] +import triton +import triton.language as tl -def keep(conf): - BLOCK_M = conf.kwargs["BLOCK_M"] - BLOCK_N = conf.kwargs["BLOCK_N"] - if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: - return False - return True +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) -@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit -def _attn_fwd(Q, K, V, sm_scale, M, Out, # - stride_qz, stride_qh, stride_qm, stride_qk, # - stride_kz, stride_kh, stride_kn, stride_kk, # - stride_vz, stride_vh, stride_vk, stride_vn, # - stride_oz, stride_oh, stride_om, stride_on, # - Z, H, N_CTX, # - HEAD_DIM: tl.constexpr, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr # - ): - tl.static_assert(BLOCK_N <= HEAD_DIM) +def _fwd_kernel( + Q, K, V, sm_scale, + L, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): start_m = tl.program_id(0) off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh - - # block pointers + qvk_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=v_order, + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) ) K_block_ptr = tl.make_block_ptr( base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), - order=(0, 1), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) ) - O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout q = tl.load(Q_block_ptr) - # stage 1: off-band - # For causal 
= True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # - ) - # stage 2: on-band - if STAGE & 2: - # barrier makes it easier for compielr to schedule the - # two loops independently - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # - ) - # epilogue - m_i += tl.math.log2(l_i) + q = (q * qk_scale).to(tl.float16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.float16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(tl.float16)) @triton.jit -def _attn_bwd_preprocess(O, DO, # - Delta, # - Z, H, N_CTX, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # - ): +def _bwd_preprocess( + Out, DO, + Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_hz = tl.program_id(1) - off_n = tl.arange(0, HEAD_DIM) + off_n = tl.arange(0, D_HEAD) # load - o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) - do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute delta = tl.sum(o * do, axis=1) # write-back - tl.store(Delta + off_hz * N_CTX + off_m, delta) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - HEAD_DIM: tl.constexpr, # - # Filled in by the wrapper. 
- start_n, start_m, num_steps, # - MASK: tl.constexpr): - offs_m = start_m + tl.arange(0, BLOCK_M1) - offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, HEAD_DIM) - qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - curr_m = start_m - step_m = BLOCK_M1 - for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) - # Load m before computing qk to reduce pipeline stall. - offs_m = curr_m + tl.arange(0, BLOCK_M1) - m = tl.load(M + offs_m) - qkT = tl.dot(k, qT) - pT = tl.math.exp2(qkT - m[None, :]) - # Autoregressive masking. - if MASK: - mask = (offs_m[None, :] >= offs_n[:, None]) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs) - # Compute dV. - ppT = pT - ppT = ppT.to(tl.float16) - dv += tl.dot(ppT, do) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)).to(tl.float32) - dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(tl.float16) - dk += tl.dot(dsT, tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_tok - do_ptrs += step_m * stride_tok - return dk, dv + tl.store(Delta + off_m, delta) -# the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq(dq, q, K, V, # - do, m, D, - # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - # Filled in by the wrapper. - start_m, start_n, num_steps, # - MASK: tl.constexpr): - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - for blk_idx in range(num_steps): - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) - qk = tl.dot(q, kT) - p = tl.math.exp2(qk - m) - # Autoregressive masking. - if MASK: - offs_n = curr_n + tl.arange(0, BLOCK_N2) - mask = (offs_m[:, None] >= offs_n[None, :]) - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - dp = tl.dot(do, vT).to(tl.float32) - ds = p * (dp - Di[:, None]) - ds = ds.to(tl.float16) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - dq += tl.dot(ds, tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_tok - vT_ptrs += step_n * stride_tok - return dq - - -@triton.jit -def _attn_bwd(Q, K, V, sm_scale, # - DO, # - DQ, DK, DV, # - M, D, - # shared by Q/K/V/DO. 
- stride_z, stride_h, stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - BLK_SLICE_FACTOR: tl.constexpr, # - HEAD_DIM: tl.constexpr): - LN2: tl.constexpr = 0.6931471824645996 # = ln(2) - - bhid = tl.program_id(2) - off_chz = (bhid * N_CTX).to(tl.int64) - adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) - pid = tl.program_id(0) - +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + qk_scale = sm_scale * 1.44269504 # offset pointers for batch/head - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DK += adj - DV += adj - M += off_chz - D += off_chz - - # load scales - offs_k = tl.arange(0, HEAD_DIM) - - start_n = pid * BLOCK_N1 - start_m = start_n - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - offs_n = start_n + tl.arange(0, BLOCK_N1) - - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - - num_steps = BLOCK_N1 // MASK_BLOCK_M1 - - dk, dv = _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=True # - ) - - start_m += num_steps * MASK_BLOCK_M1 - num_steps = (N_CTX - start_m) // BLOCK_M1 - - # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv( # - dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=False # - ) - - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) - - # Write back dK. - dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) - - # THIS BLOCK DOES DQ: - start_m = pid * BLOCK_M2 - end_n = start_m + BLOCK_M2 - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - offs_m = start_m + tl.arange(0, BLOCK_M2) - - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - - m = tl.load(M + offs_m) - m = m[:, None] - - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _attn_bwd_dq, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. 
- num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # - MASK=True # - ) - end_n -= num_steps * MASK_BLOCK_N2 - # stage 2 - num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * BLOCK_N2, num_steps, # - MASK=False # - ) - # Write back dQ. - dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - dq *= LN2 - tl.store(dq_ptrs, dq) + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + if CAUSAL: + lo = start_n * BLOCK_M + else: + lo = 0 + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + if CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(Q.dtype.element_ty), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): @@ -427,80 +235,66 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, 
v, causal, sm_scale): # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. - HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) - stage = 3 if causal else 1 - extra_kern_args = {} - # Tuning for AMD target - if is_hip(): - waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 - extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} - - grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - _attn_fwd[grid]( - q, k, v, sm_scale, M, o, # - q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - k.stride(0), k.stride(1), k.stride(2), k.stride(3), # - v.stride(0), v.stride(1), v.stride(2), v.stride(3), # - o.stride(0), o.stride(1), o.stride(2), o.stride(3), # - q.shape[0], q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args) - - ctx.save_for_backward(q, k, v, o, M) + BLOCK_M = 128 + BLOCK_N = 64 + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=causal, + num_warps=num_warps, + num_stages=4) + + ctx.save_for_backward(q, k, v, o, L) ctx.grid = grid ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K + ctx.BLOCK_DMODEL = Lk ctx.causal = causal return o @staticmethod def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - dq = torch.empty_like(q) + BLOCK = 128 + q, k, v, o, L = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty_like(k) dv = torch.empty_like(v) - BATCH, N_HEAD, N_CTX = q.shape[:3] - PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 - assert N_CTX % PRE_BLOCK == 0 - pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) - delta = torch.empty_like(M) - _attn_bwd_preprocess[pre_grid]( - o, do, # - delta, # - BATCH, N_HEAD, N_CTX, # - BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + delta = torch.empty_like(L) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, + delta, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # - M, delta, # - q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - N_HEAD, N_CTX, # - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES # + _bwd_kernel[(ctx.grid[1],)]( + q, k, 
v, ctx.sm_scale, + o, do, + dq, dk, dv, + L, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + CAUSAL=ctx.causal, + num_stages=1, ) - return dq, dk, dv, None, None From d0442271525c0dcd5595dc0236dd51c6a55a662a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 29 Nov 2024 18:41:54 -0500 Subject: [PATCH 07/11] Add precision --- benchmarks/flash_attention_bench/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py index b377a9f3..bf4f4203 100644 --- a/benchmarks/flash_attention_bench/run.py +++ b/benchmarks/flash_attention_bench/run.py @@ -12,7 +12,7 @@ def run(): - args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"] + args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops", "--precision", "fp16"] flash_attn_op = tritonbench.load_opbench_by_name("flash_attention") parser = get_parser() args, extra_args = parser.parse_known_args(args) From a562c2fa2d905ffa03951ee23360dbfaa08b70dc Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 4 Dec 2024 09:51:00 -0500 Subject: [PATCH 08/11] Add flash attention and gemm benchmarks --- benchmarks/flash_attention_bench/run.py | 2 +- benchmarks/gemm_bench/run.py | 25 ++ .../kernels/triton_fused_attention_vanilla.py | 301 ------------------ .../operators/flash_attention/operator.py | 7 +- 4 files changed, 28 insertions(+), 307 deletions(-) create mode 100644 benchmarks/gemm_bench/run.py delete mode 100644 tritonbench/kernels/triton_fused_attention_vanilla.py diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py index bf4f4203..1d17f40a 100644 --- a/benchmarks/flash_attention_bench/run.py +++ b/benchmarks/flash_attention_bench/run.py @@ -12,7 +12,7 @@ def run(): - args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops", "--precision", "fp16"] + args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "fp16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"] flash_attn_op = tritonbench.load_opbench_by_name("flash_attention") parser = get_parser() args, extra_args = parser.parse_known_args(args) diff --git a/benchmarks/gemm_bench/run.py b/benchmarks/gemm_bench/run.py new file mode 100644 index 00000000..5c1b2589 --- /dev/null +++ b/benchmarks/gemm_bench/run.py @@ -0,0 +1,25 @@ +""" +Run gemm benchmark with a single input shape: + +M = 4096 +N = 4096 +K = 4096 + +Print tflops metrics +""" + +import tritonbench +from tritonbench.utils.parser import get_parser + + +def run(): + args = ["--m", "4096", "--n", "4096", "--k", "4096", "--precision", "fp16", "--only", "triton_tutorial_matmul", "--metrics", "tflops"] + gemm_op = tritonbench.load_opbench_by_name("gemm") + parser = get_parser() + args, extra_args = parser.parse_known_args(args) + gemm_bench = gemm_op(args, extra_args) + gemm_bench.run() + 
print(gemm_bench.output) + +if __name__ == "__main__": + run() diff --git a/tritonbench/kernels/triton_fused_attention_vanilla.py b/tritonbench/kernels/triton_fused_attention_vanilla.py deleted file mode 100644 index fd15c174..00000000 --- a/tritonbench/kernels/triton_fused_attention_vanilla.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) - -Extra Credits: -- Original flash attention paper (https://arxiv.org/abs/2205.14135) -- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) -- Adam P. Goucher for simplified vector math - -""" -import torch - -import triton -import triton.language as tl - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qvk_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(tl.float16) - # loop over k, v and update accumulator - lo = 0 - hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX - for start_n in range(lo, hi, BLOCK_N): - # -- load k, v -- - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_CAUSAL: - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk += tl.dot(q, k) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc += tl.dot(p.to(tl.float16), v) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - # write back l and m - acc = acc / l_i[:, None] - l_ptrs = L + off_hz * N_CTX + offs_m - tl.store(l_ptrs, 
m_i + tl.math.log2(l_i)) - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(tl.float16)) - - -@triton.jit -def _bwd_preprocess( - Out, DO, - Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_n = tl.arange(0, D_HEAD) - # load - o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - # compute - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(Delta + off_m, delta) - - -@triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, - DQ, DK, DV, - L, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, - num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, -): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - qk_scale = sm_scale * 1.44269504 - # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh - for start_n in range(0, num_block): - if CAUSAL: - lo = start_n * BLOCK_M - else: - lo = 0 - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - l_ptrs = L + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - if CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) - else: - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, tl.trans(k)) - qk *= qk_scale - l_i = tl.load(l_ptrs + offs_m_curr) - p = tl.math.exp2(qk - l_i[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) - # compute dq - dq = 
tl.load(dq_ptrs) - dq += tl.dot(ds.to(Q.dtype.element_ty), k) - tl.store(dq_ptrs, dq) - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - - -empty = torch.empty(128, device="cuda") - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q) - BLOCK_M = 128 - BLOCK_N = 64 - grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, k, v, sm_scale, - L, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - IS_CAUSAL=causal, - num_warps=num_warps, - num_stages=4) - - ctx.save_for_backward(q, k, v, o, L) - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - BLOCK = 128 - q, k, v, o, L = ctx.saved_tensors - do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - delta = torch.empty_like(L) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( - o, do, - delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do, - dq, dk, dv, - L, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - CAUSAL=ctx.causal, - num_stages=1, - ) - return dq, dk, dv, None, None - - -attention = _attention.apply diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 7290247e..10f05ebe 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -57,9 +57,6 @@ from tritonbench.kernels.triton_fused_attention import ( attention_opt as triton_tutorial_FA2_opt, ) -from tritonbench.kernels.triton_fused_attention_vanilla import ( - attention as triton_tutorial_FA2, -) # [Optional] flash_attn v2 @@ -257,8 +254,8 @@ def triton_tutorial_flash_v2( v: torch.Tensor, ) -> Callable: # base: do not enable TMA/WarpSpec/CompPipe - return lambda: triton_tutorial_FA2( - q, k, v, self.causal, self.sm_scale + return lambda: triton_tutorial_FA2_opt( + q, k, v, self.causal, self.sm_scale, "base" ) @register_benchmark(enabled=HAS_CUDA_124) From c3e679206e031080070c723e2d85851f0a3685d7 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 4 Dec 2024 11:16:30 -0500 Subject: [PATCH 09/11] Add format --- .gitignore | 3 --- benchmarks/flash_attention_bench/run.py | 20 ++++++++++++++++++- benchmarks/gemm_bench/run.py | 16 ++++++++++++++- 
.../operators/flash_attention/operator.py | 5 ++++- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 64e77890..fc0098f4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,8 +14,5 @@ __pycache__/ .ipynb_checkpoints/ .idea *.egg-info/ -<<<<<<< HEAD torch_compile_debug/ -======= build/ ->>>>>>> 1a642f1 (Add flash_attention_benchmark) diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py index 1d17f40a..ac1df18d 100644 --- a/benchmarks/flash_attention_bench/run.py +++ b/benchmarks/flash_attention_bench/run.py @@ -12,7 +12,24 @@ def run(): - args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "fp16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"] + args = [ + "--batch", + "4", + "--seq-len", + "16384", + "--n-heads", + "32", + "--d-head", + "64", + "--precision", + "fp16", + "--bwd", + "--only", + "triton_tutorial_flash_v2", + "--causal", + "--metrics", + "tflops", + ] flash_attn_op = tritonbench.load_opbench_by_name("flash_attention") parser = get_parser() args, extra_args = parser.parse_known_args(args) @@ -20,5 +37,6 @@ def run(): flash_attn_bench.run() print(flash_attn_bench.output) + if __name__ == "__main__": run() diff --git a/benchmarks/gemm_bench/run.py b/benchmarks/gemm_bench/run.py index 5c1b2589..7dcc1dad 100644 --- a/benchmarks/gemm_bench/run.py +++ b/benchmarks/gemm_bench/run.py @@ -13,7 +13,20 @@ def run(): - args = ["--m", "4096", "--n", "4096", "--k", "4096", "--precision", "fp16", "--only", "triton_tutorial_matmul", "--metrics", "tflops"] + args = [ + "--m", + "4096", + "--n", + "4096", + "--k", + "4096", + "--precision", + "fp16", + "--only", + "triton_tutorial_matmul", + "--metrics", + "tflops", + ] gemm_op = tritonbench.load_opbench_by_name("gemm") parser = get_parser() args, extra_args = parser.parse_known_args(args) @@ -21,5 +34,6 @@ def run(): gemm_bench.run() print(gemm_bench.output) + if __name__ == "__main__": run() diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 10f05ebe..dcc2daca 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -150,7 +150,9 @@ def parse_op_args(args: List[str]): action="store_true", help="enable causal (always true on backward)", ) - parser.add_argument("--additional-inputs", action="store_true", help="enable additional inputs") + parser.add_argument( + "--additional-inputs", action="store_true", help="enable additional inputs" + ) return parser.parse_args(args) @@ -481,6 +483,7 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable: def get_input_iter(self) -> Generator: import math + D_HEAD = self.D_HEAD BATCH = self.BATCH H = self.H From c26a51fc226cfbb051a3a230dc2b28e9f5c6be62 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 4 Dec 2024 18:30:59 -0500 Subject: [PATCH 10/11] Add option to do native sdpa --- tritonbench/operators/flash_attention/operator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index dcc2daca..ade5d515 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -34,6 +34,7 @@ import argparse import math import os +from contextlib import nullcontext from itertools import chain import torch @@ -150,6 +151,7 @@ def parse_op_args(args: 
List[str]):
         action="store_true",
         help="enable causal (always true on backward)",
     )
+    parser.add_argument("--native-sdpa", action="store_true", help="Use SDPA native choice.")
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
     return parser.parse_args(args)
@@ -172,6 +174,7 @@ def __init__(
         self.D_HEAD = args.d_head
         self.N_CTX = None
         self.causal = args.causal
+        self.native_sdpa = args.native_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
@@ -206,7 +209,9 @@ def sdpa(
         v: torch.Tensor,
     ) -> Callable:
         def sdpa_flash_attention(q, k, v):
-            with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
+            cxt = nullcontext() if self.native_sdpa else \
+                sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            with cxt:
                 return sdpa(
                     q,
                     k,

From ba41fb4037b730348545c93a513512f3eba168bb Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 4 Dec 2024 18:44:10 -0500
Subject: [PATCH 11/11] Add gemm operator

---
 .../operators/flash_attention/operator.py | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index ade5d515..5cc4d223 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -143,15 +143,17 @@ def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, default=4, help="Batch size")
-    parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length")
-    parser.add_argument("--n-heads", type=int, default=None, help="Number of heads")
+    parser.add_argument("--seq-len", type=int, default=None, help="Sequence length")
+    parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument(
         "--causal",
         action="store_true",
         help="enable causal (always true on backward)",
     )
-    parser.add_argument("--native-sdpa", action="store_true", help="Use SDPA native choice.")
+    parser.add_argument(
+        "--native-sdpa", action="store_true", help="Use SDPA native choice."
+    )
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
     return parser.parse_args(args)
@@ -209,8 +211,11 @@ def sdpa(
         v: torch.Tensor,
     ) -> Callable:
         def sdpa_flash_attention(q, k, v):
-            cxt = nullcontext() if self.native_sdpa else \
-                sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            cxt = (
+                nullcontext()
+                if self.native_sdpa
+                else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            )
             with cxt:
                 return sdpa(
                     q,
@@ -487,18 +492,16 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         return fn
 
     def get_input_iter(self) -> Generator:
-        import math
-
         D_HEAD = self.D_HEAD
         BATCH = self.BATCH
         H = self.H
-        seq_len_log2 = int(math.log2(self.SEQ_LEN))
+        SEQ_LEN_LOG2 = 7
 
         def get_ctx_vals():
-            if self.H:
+            if self.SEQ_LEN:
                 yield (BATCH, self.H, self.SEQ_LEN, self.D_HEAD)
                 return
-            for i in range(SEQ_LEN_LOG2, 15):
+            for i in range(SEQ_LEN_LOG2, 15):
                 N_CTX = 2**i
                 # BATCH = 16384 // N_CTX
                 # H = 2048 // D_HEAD
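
Note (illustrative only, not a file in this patch series): the sketch below shows how the --native-sdpa flag added in PATCH 10 might be exercised through a standalone runner, following the same pattern as benchmarks/flash_attention_bench/run.py from this series. It assumes the sdpa benchmark is selectable by its method name via --only, and the input shape values are arbitrary examples; only tritonbench entry points already shown in the patches are used.

"""
Illustrative sketch: run the flash_attention operator's sdpa benchmark with
the native SDPA backend choice (--native-sdpa) and report tflops.
Assumptions: the benchmark is named "sdpa" for --only, and the shape values
below are example inputs, not values taken from the patches.
"""

import tritonbench
from tritonbench.utils.parser import get_parser


def run():
    args = [
        "--batch", "4",
        "--seq-len", "16384",
        "--n-heads", "32",
        "--d-head", "64",
        "--precision", "fp16",
        "--only", "sdpa",
        "--native-sdpa",
        "--metrics", "tflops",
    ]
    # Load the flash_attention operator benchmark and run it with the args above,
    # mirroring the runner scripts added earlier in this series.
    flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
    parser = get_parser()
    args, extra_args = parser.parse_known_args(args)
    flash_attn_bench = flash_attn_op(args, extra_args)
    flash_attn_bench.run()
    print(flash_attn_bench.output)


if __name__ == "__main__":
    run()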