From 1458f1de665bd10e980dda0fae1de1fb932a8d2d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Feb 2024 18:42:21 +0000 Subject: [PATCH] Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)" This reverts commit 4b7a521856ca5fb0fc28edd18591f77fff5a6ba1. Reverted https://github.com/pytorch/pytorch/pull/118935 on behalf of https://github.com/atalman due to Significantly increases build time. Optimization is needed ([comment](https://github.com/pytorch/pytorch/pull/118935#issuecomment-1971723284)) --- .../native/transformers/cuda/attention.cu | 5 +- .../transformers/cuda/attention_backward.cu | 26 +- .../transformers/cuda/flash_attn/alibi.h | 74 -- .../transformers/cuda/flash_attn/block_info.h | 4 +- .../transformers/cuda/flash_attn/dropout.h | 96 -- .../transformers/cuda/flash_attn/flash.h | 30 +- .../cuda/flash_attn/flash_api.cpp | 400 +++----- .../transformers/cuda/flash_attn/flash_api.h | 18 +- .../cuda/flash_attn/flash_bwd_kernel.h | 877 ++++++++++++++++-- .../flash_attn/flash_bwd_launch_template.h | 314 ++++--- .../flash_attn/flash_bwd_preprocess_kernel.h | 377 -------- .../cuda/flash_attn/flash_fwd_kernel.h | 417 +++++---- .../flash_attn/flash_fwd_launch_template.h | 132 ++- .../cuda/flash_attn/kernel_traits.h | 107 ++- .../cuda/flash_attn/kernel_traits_sm90.h | 161 ++++ .../kernels/flash_bwd_hdim128_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim128_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim160_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim160_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim192_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim192_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim224_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim224_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim256_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim256_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim32_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim32_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim64_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim64_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim96_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim96_fp16_sm80.cu | 4 +- .../flash_attn/kernels/generate_kernels.py | 4 +- .../transformers/cuda/flash_attn/mask.h | 213 ----- .../transformers/cuda/flash_attn/philox.cuh | 120 ++- .../transformers/cuda/flash_attn/rotary.h | 152 --- .../transformers/cuda/flash_attn/softmax.h | 259 ++++-- .../cuda/flash_attn/static_switch.h | 43 +- .../transformers/cuda/flash_attn/utils.h | 220 ++++- .../native/transformers/cuda/sdp_utils.cpp | 18 +- .../transformers/hip/flash_attn/flash_api.hip | 18 +- test/test_transformers.py | 258 ++---- 41 files changed, 2301 insertions(+), 2106 deletions(-) delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/mask.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 55de97ad223a78..900defaa763660 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -50,6 +50,7 @@ #include #include #include +#include #include #endif @@ -64,6 +65,7 @@ #include 
#include #include +#include #include #include @@ -850,7 +852,6 @@ _flash_attention_forward( // of the tensor. This is useful for kv cache scenarios but for now // we will not support in this PR. c10::optional seqused_k = c10::nullopt; - c10::optional alibi_slopes = c10::nullopt; // We are going to have two paths: // 1. The standard MHA path for dense tensors @@ -879,7 +880,6 @@ _flash_attention_forward( cumulative_sequence_length_q.value(), cumulative_sequence_length_k.value(), seqused_k, /*seqused_k*/ - alibi_slopes, /*alibi_slopes*/ max_seqlen_batch_q, max_seqlen_batch_k, dropout_p, @@ -905,7 +905,6 @@ _flash_attention_forward( key, value, out, - alibi_slopes, dropout_p, softmax_scale, is_causal, diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index c829f45a6f4a3f..cf8d543f12122c 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -1,4 +1,3 @@ -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -42,8 +41,9 @@ #include #include #endif +namespace at { -namespace at::native { +namespace native { std::tuple _flash_attention_backward( const Tensor& grad_out, @@ -74,21 +74,6 @@ std::tuple _flash_attention_backward( // The kernel computes irregardless we will drop for this functions return Tensor grad_softmax; - // Currently unused args: - c10::optional alibi_slopes{c10::nullopt}; - - bool determinisitic{false}; - auto& ctx = at::globalContext(); - if (ctx.deterministicAlgorithms()) { - if (ctx.deterministicAlgorithmsWarnOnly()) { - TORCH_WARN_ONCE( - "Flash Attention defaults to a non-deterministic algorithm. ", - "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); - } else { - determinisitic = true; - } - } - // We check the whether the cumulative_sequence_length_q is defined // in order to determine whether we are using varlen or dense forward if (cumulative_sequence_length_q.defined()) { @@ -105,7 +90,6 @@ std::tuple _flash_attention_backward( dv, cumulative_sequence_length_q, cumulative_sequence_length_k, - alibi_slopes, max_seqlen_batch_q, max_seqlen_batch_k, dropout_p, @@ -114,7 +98,6 @@ std::tuple _flash_attention_backward( is_causal, -1, /*window_size_left*/ -1, /*window_size_right*/ - determinisitic, philox_seed, philox_offset); return std::make_tuple(dQuery, dKey, dValue); @@ -130,13 +113,11 @@ std::tuple _flash_attention_backward( dq, dk, dv, - alibi_slopes, dropout_p, softmax_scale, is_causal, -1, /*window_size_left*/ -1, /*window_size_right*/ - determinisitic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -649,4 +630,5 @@ std::tuple _scaled_dot_product_e grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); } -} // namespace at::native +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h b/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h deleted file mode 100644 index 311231432c7cfe..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h +++ /dev/null @@ -1,74 +0,0 @@ -#include - -#include - -#include -#include - -#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Alibi { - - const float alibi_slope; - const int 
max_seqlen_k, max_seqlen_q; - - __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) - : alibi_slope(alibi_slope) - , max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) { - }; - - - template - __forceinline__ __device__ void apply_alibi(Tensor &tensor, - const int col_idx_offset_, - const int row_idx_offset, - const int warp_row_stride) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; - } - } - } - } else { // Bias depends on both row_idx and col_idx - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); - } - } - } - } - } - } - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h b/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h index bbaf6978002177..3e05d7e7195e8c 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h @@ -24,12 +24,12 @@ struct BlockInfo { } template - __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template - __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; } diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h b/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h deleted file mode 100644 index 8dc4b0b22bcc9d..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -struct Dropout { - - const unsigned long long seed, offset; - const uint8_t p_dropout_in_uint8_t; - - __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, - const uint8_t p_dropout_in_uint8_t, - const int bid, const int hid, const int tid, const int nheads) - : seed(seed) - , offset(offset + (bid * nheads + hid) * 32 + tid % 32) - , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { - } - - template - __forceinline__ __device__ void apply_dropout(Tensor &tensor_, - int block_row_start, int block_col_start, int block_row_stride) { - // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) - Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout())); - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); - }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
- if (!encode_dropout_in_sign_bit - && (std::is_same::value || std::is_same::value)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } - } - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h index 9ce14cf6489ef2..23fa6584b9b564 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h @@ -5,15 +5,13 @@ #pragma once #include +#include + +#include + +namespace pytorch_flash{ -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif -#include // For at::cuda::philox::unpack -namespace pytorch_flash { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -21,7 +19,7 @@ constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { - using index_t = int64_t; + using index_t = uint32_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; @@ -98,12 +96,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ rotary_sin_ptr; // The indices to index into the KV cache. - int * __restrict__ cache_batch_idx; - - // Paged KV cache - int * __restrict__ block_table; - index_t block_table_batch_stride; - int page_block_size; + int *__restrict__ cache_batch_idx; // The dropout probability (probability of keeping an activation). float p_dropout; @@ -133,9 +126,6 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved; int num_splits; // For split-KV version - - void * __restrict__ alibi_slopes_ptr; - index_t alibi_slopes_batch_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -175,9 +165,6 @@ struct Flash_bwd_params : public Flash_fwd_params { // The pointer to the softmax d sum. 
void *__restrict__ dsoftmax_sum; - - bool deterministic; - index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -185,6 +172,7 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); + } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 8f6f7a9f357dc9..07c9f7e547facd 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -1,5 +1,29 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * ******************************************************************************/ #include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS @@ -26,7 +50,6 @@ #include #include #include -#include #endif @@ -70,11 +93,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, float p_dropout, float softmax_scale, int window_size_left, - int window_size_right, - bool seqlenq_ngroups_swapped=false) { + int window_size_right) { - // Reset the parameters + // Reset the parameters should be equivalent params = {}; + // memset(¶ms, 0, sizeof(params)); params.is_bf16 = q.dtype() == at::kBFloat16; @@ -98,10 +121,6 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); params.o_batch_stride = out.stride(0); - if (seqlenq_ngroups_swapped) { - params.q_batch_stride *= seqlen_q; - params.o_batch_stride *= seqlen_q; - } } params.cu_seqlens_q = static_cast(cu_seqlens_q_d); @@ -140,9 +159,6 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; TORCH_CHECK(p_dropout < 1.f); - #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); - #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. @@ -153,16 +169,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.window_size_left = window_size_left; params.window_size_right = window_size_right; - #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), - "This flash attention build does not support local attention."); - #endif - params.is_seqlens_k_cumulative = true; - - #ifdef FLASHATTENTION_DISABLE_UNEVEN_K - TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); - #endif } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -195,8 +202,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float p_dropout, float softmax_scale, int window_size_left, - int window_size_right, - bool deterministic) { + int window_size_right) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -238,13 +244,11 @@ void set_params_dgrad(Flash_bwd_params ¶ms, // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; - - params.deterministic = deterministic; } void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { - HEADDIM_SWITCH(params.d, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else { @@ -296,62 +300,16 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n return 1; } -void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, - const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, - const int head_size_rounded, const float p_dropout, - const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) { - - // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); - const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; - // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. 
- // In any case we don't expect seqlen_q to be larger than 64 for inference. - const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; - params.num_splits = num_splits; - if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout - if (num_splits < 1) { - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); - } - if (params.num_splits > 1) { - at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); - } -} - -void set_params_alibi(Flash_fwd_params ¶ms, c10::optional &alibi_slopes_, int batch_size, int num_heads){ -#ifdef FLASHATTENTION_DISABLE_ALIBI - TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); - params.alibi_slopes_ptr = nullptr; -#else - if (alibi_slopes_.has_value()) { - auto alibi_slopes = alibi_slopes_.value(); - TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32"); - CHECK_DEVICE(alibi_slopes); - TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); - TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads})); - params.alibi_slopes_ptr = alibi_slopes.data_ptr(); - params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; - } else { - params.alibi_slopes_ptr = nullptr; - } -#endif -} - // return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; std::tuple mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -392,16 +350,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (window_size_left >= seqlen_k) { window_size_left = -1; } - if (window_size_right >= seqlen_k) { window_size_right = -1; } - - // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && 
num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0; at::Tensor temp_q = q; if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; @@ -415,9 +369,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); at::Tensor q_padded, k_padded, v_padded; - q_padded = temp_q; - k_padded = k; - v_padded = v; + q_padded = temp_q; + k_padded = k; + v_padded = v; at::Tensor out; if (out_.has_value()) { @@ -469,17 +423,30 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head window_size_left, window_size_right); - - set_params_splitkv(params, batch_size, num_heads, - head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + params.num_splits = 1; + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + if (params.num_splits > 1) { + at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + } // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function + auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); at::Tensor seed_t, offset_t; if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. @@ -509,8 +476,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head } - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); @@ -536,18 +501,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, + const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - bool is_causal, - int window_size_left, + const bool is_causal, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { + if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; @@ -579,39 +544,17 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const auto sizes = q.sizes(); + const int total_q = sizes[0]; const int batch_size = cu_seqlens_q.numel() - 1; - int num_heads = sizes[1]; + const int num_heads = sizes[1]; const int head_size_og = sizes[2]; const int total_k = k.size(0); const int num_heads_k = k.size(1); - - if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case - if (is_causal) { window_size_right = 0; } - - void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - - // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case - // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); - at::Tensor temp_q = q; - if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; - temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); - max_seqlen_q = ngroups; - num_heads = num_heads_k; - cu_seqlens_q_d = nullptr; - } - - const int total_q = q.sizes()[0]; - TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!") - if (window_size_left >= max_seqlen_k) { window_size_left = -1; } - if (window_size_right >= max_seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, total_q, num_heads, head_size_og); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); @@ -626,7 +569,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } at::Tensor q_padded, k_padded, v_padded; - q_padded = temp_q; + q_padded = q; k_padded = k; v_padded = v; @@ -676,7 +619,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q num_heads, num_heads_k, head_size, head_size_rounded, q_padded, k_padded, v_padded, out, - cu_seqlens_q_d, + cu_seqlens_q.data_ptr(), cu_seqlens_k.data_ptr(), seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, return_softmax ? 
p.data_ptr() : nullptr, @@ -684,16 +627,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q p_dropout, softmax_scale, window_size_left, - window_size_right, - seqlenq_ngroups_swapped); - if (seqlenq_ngroups_swapped) { - // Only apply split-k for decoding - set_params_splitkv(params, batch_size, num_heads, - head_size, max_seqlen_k, max_seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); - } + window_size_right); - // We want to checkpoint and save the RNG state for backward if dropout + // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); @@ -728,33 +664,31 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - - if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - } else { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. - out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - if (seqlenq_ngroups_swapped) { - std::array size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og}; - std::array size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; - out = out.reshape(size_before).transpose(1, 2).reshape(size_after); - q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1}); - } + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; } -void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { FP16_SWITCH(!params.is_bf16, [&] { - HEADDIM_SWITCH(params.d, [&] { - run_mha_bwd_(params, stream); - }); + if (params.d <= 32) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 64) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 96) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 128) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 160) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 192) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 224) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 256) { + run_mha_bwd_(params, stream, configure); + } }); } @@ -768,19 +702,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); - #endif if (is_causal) { 
window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -827,8 +756,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + if (head_size > 192) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -839,9 +768,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); - if (window_size_left >= seqlen_k) { window_size_left = -1; } - if (window_size_right >= seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); @@ -877,6 +803,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv = at::empty_like(v); } + // const at::Tensor& dout_padded = dout; + // bool loop = seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; @@ -890,14 +818,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si at::Tensor dq_accum; at::Tensor dk_accum, dv_accum; if (loop) { - if (!deterministic) { - dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } else { - const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } - // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; @@ -931,11 +854,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si p_dropout, softmax_scale, window_size_left, - window_size_right, - deterministic); - params.dq_accum_split_stride = !deterministic ? 
0 : dq_accum.stride(0); + window_size_right); auto launch = &run_mha_bwd; + // launch(params, stream, /*configure=*/true); at::PhiloxCudaState philox_args; if (is_dropout) { @@ -950,14 +872,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } params.philox_args = philox_args; - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - if (seqlen_q > 0) { - launch(params, stream); + launch(params, stream, /*configure=*/false); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk_expanded.zero_(); - dv_expanded.zero_(); + dk.zero_(); + dv.zero_(); softmax_d.zero_(); } @@ -981,24 +901,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { - - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); - #endif - if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -1012,7 +925,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + TORCH_CHECK(q_dtype == at::kHalf|| q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); @@ -1049,8 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + if (head_size > 192) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1061,9 +974,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); - if (window_size_left >= max_seqlen_k) { window_size_left = -1; } - if (window_size_right >= max_seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); @@ -1098,9 +1008,11 @@ mha_varlen_bwd(const at::Tensor &dout, // 
total_q x num_heads, x head_size TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { - dv = at::empty_like(v); + dv = at::empty_like(k); } + // const at::Tensor& dout_padded = dout; + // bool loop = max_seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; @@ -1121,12 +1033,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally // allowed to do. So we won't have to do any bound checking, and performance should stay the same. - if (!deterministic) { - dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } else { - const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } + dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; @@ -1165,11 +1072,10 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size p_dropout, softmax_scale, window_size_left, - window_size_right, - deterministic); - params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + window_size_right); auto launch = &run_mha_bwd; + // launch(params, stream, /*configure=*/true); at::PhiloxCudaState philox_args; if (is_dropout) { @@ -1184,16 +1090,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } params.philox_args = philox_args; - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - - if (max_seqlen_q > 0) { - launch(params, stream); - } else { - // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk_expanded.zero_(); - dv_expanded.zero_(); - softmax_d.zero_(); - } + launch(params, stream, /*configure=*/false); // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { @@ -1206,20 +1103,18 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size std::tuple mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &seqlens_k_, // batch_size c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional &cache_batch_idx_, // indices to index into the KV cache - c10::optional &block_table_, // batch_size x max_num_blocks_per_seq - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits @@ -1248,41 +1143,25 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - at::Tensor block_table; - const bool paged_KV = block_table_.has_value(); - if (paged_KV) { - TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); - block_table = block_table_.value(); - CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32"); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - } - const auto sizes = q.sizes(); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; const int head_size_og = sizes[3]; - - const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); - const int num_blocks = !paged_KV ? 0 : kcache.size(0); - const int page_block_size = !paged_KV ? 1 : kcache.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; + const int seqlen_k = kcache.size(1); const int num_heads_k = kcache.size(2); - const int batch_size_c = !paged_KV ? 
kcache.size(0) : batch_size; - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + const int batch_size_c = kcache.size(0); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0; if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); @@ -1290,18 +1169,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he num_heads = num_heads_k; } - if (window_size_left >= seqlen_k) { window_size_left = -1; } - if (window_size_right >= seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - if (!paged_KV) { - CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); - } else { - CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - } + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); at::Tensor q_padded, kcache_padded, vcache_padded; if (head_size_og % 8 != 0) { @@ -1440,24 +1310,27 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32"); params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); } - - set_params_splitkv(params, batch_size, num_heads, - head_size, seqlen_k, seqlen_q, - head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts); - - if (paged_KV) { - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. 
+ const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + if (num_splits < 1) { + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + if (params.num_splits > 1) { + at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); } - params.page_block_size = page_block_size; - - - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); auto stream = at::cuda::getCurrentCUDAStream().stream(); - // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, - // or paged KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx + run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value()); if (head_size_og % 8 != 0) { // out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); @@ -1479,7 +1352,6 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } return {out, softmax_lse}; } - } // namespace pytorch_fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h index 2745b28dca29b8..fd15d929e300be 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -12,11 +12,10 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_); @@ -29,14 +28,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, + const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - bool is_causal, - int window_size_left, + const bool is_causal, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_); @@ -52,13 +50,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); @@ -74,16 +70,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h index db817a0657ffcb..9f2dc5ac388d1f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h @@ -1,23 +1,24 @@ /*************************************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2023, Tri Dao. 
******************************************************************************/ #pragma once #include + #include +#include #include #include #include +#include #include #include #include #include -#include -#include -#include +#include namespace pytorch_flash { @@ -65,8 +66,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; - constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; auto t = make_tile(make_layout(Int{}), Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) @@ -76,7 +76,359 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, + Tensor &dP_sum, const int gdP_col_stride, const float scale) { + static_assert(Layout0::rank == 3, "Only support 3D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); + // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) + // The last coordinate is the "page". + Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), + make_layout(get<0>(do_.layout()), + get<2>(do_.layout())))); + Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); + Tensor do_fp32 = pytorch_flash::convert_type(do_reshaped); + Tensor o_fp32 = pytorch_flash::convert_type(o_reshaped); + #pragma unroll + for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(do_reshaped); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + pytorch_flash::SumOp sum_op; + dP_sum_cur = pytorch_flash::Allreduce::run(dP_sum_cur, sum_op) * scale; + if (threadIdx.x % THREADS_PER_ROW == 0) { + dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + // TODO: careful, we're zeroing out dQaccum with type float4, but when + // we do atomicAdds, we use type float. The layouts are different. Check this. + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); + + // Allocate predicate tensors for k + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); + // Set predicates for k bounds + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final + // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, + // so that (dP - dP_sum) is on the same scale. + dot_do_o(tdOrdO, tdOrO, dP_sum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + if (Clear_dQaccum) { + // We're actually not zero'ing out all of dQaccum, but only the part that we're going to + // do atomicAdds on. 
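// A minimal scalar sketch of the dot_do_o reduction invoked just above: per query
// row i it produces dP_sum[i] = p_dropout * sum_d dO[i][d] * O[i][d], kept scaled
// down by p_dropout (as the comment above notes) so that (dP - dP_sum) stays on one
// scale; dQ/dK are rescaled by 1/p_dropout only at the end. Buffer names and the
// row-major layout here are illustrative assumptions, not part of this patch.
#include <cstddef>

void dot_do_o_reference(const float* dO, const float* O, float* dP_sum,
                        std::size_t rows, std::size_t head_dim, float p_dropout) {
  for (std::size_t i = 0; i < rows; ++i) {
    float acc = 0.f;
    for (std::size_t d = 0; d < head_dim; ++d) {
      acc += dO[i * head_dim + d] * O[i * head_dim + d];  // dot(dO_row, O_row)
    }
    dP_sum[i] = acc * p_dropout;  // keep dP_sum on the dropped-down scale
  }
}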
+ Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear_dKVaccum(const Params ¶ms) { + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + Tensor zero = make_fragment_like(tdKgdKaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dQ from dQaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void convert_dQ(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + + Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQ{}); + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + + Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { + acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout; + } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = pytorch_flash::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + pytorch_flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_q. +template +inline __device__ void convert_dKV(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. 
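// A minimal CUDA sketch of the element-wise work convert_dQ above performs: read the
// float dQ accumulator, apply scale_softmax_rp_dropout, and narrow to the output
// dtype. The real kernel stages the result through shared memory with cute tiled
// copies; the flat indexing and fp16 output here are illustrative assumptions.
#include <cuda_fp16.h>

__global__ void convert_dq_sketch(const float* __restrict__ dq_accum,
                                  __half* __restrict__ dq,
                                  int n, float scale_softmax_rp_dropout) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    dq[i] = __float2half(dq_accum[i] * scale_softmax_rp_dropout);
  }
}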
+ const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + + n_block * kBlockN) * params.d_rounded; + + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + + Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdKV{}); + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); + CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); + + Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); + Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { + acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; + } + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { + acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; + } + // Convert acc_dk from fp32 to fp16 + Tensor rdK = 
pytorch_flash::convert_type(acc_dk); + Tensor rdV = pytorch_flash::convert_type(acc_dv); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); + cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); + + Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + pytorch_flash::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + pytorch_flash::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -92,8 +444,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; - constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; const BlockInfo binfo(params, bidb); @@ -117,9 +469,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded - // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. - + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM; const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded @@ -368,7 +718,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdKsQt.data() = tdKsQt.data() + size(sQ); } - if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } + if (!Is_first && !Seq_parallel) { __syncthreads(); } if (Kernel_traits::Is_V_in_regs) { // Clear the smem tiles to account for predicated off loads @@ -406,12 +756,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccScS_row(mi)); - lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; } - // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, - // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply - // with V (which would be zero), we're fine. However, with ALiBi, we might modify these - // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. // Tensor tKrK = make_fragment_like(tKsK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); @@ -445,16 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in auto seeds = at::cuda::philox::unpack(params.philox_args); unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds); - pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; clear(acc_dv); clear(acc_dk); - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); - for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) clear(acc_s); @@ -478,12 +819,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } - - if (Has_alibi) { - alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); - } - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond // actual_seqlen_k, because acc_s would be some finite value for those indices. // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, @@ -520,27 +855,28 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } } - // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. 
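// The scale_apply_exp2 call that follows recomputes the attention probabilities
// rather than loading them: P_ij = exp(softmax_scale * S_ij - LSE_i), evaluated via
// exp2 (scale_softmax_log2 carries a log2(e) factor). A scalar sketch of that
// identity; the function and variable names are illustrative:
#include <cmath>

inline float recompute_prob(float s_ij, float lse_i, float softmax_scale) {
  const float kLog2e = 1.4426950408889634f;  // log2(e)
  // exp2((softmax_scale * s - lse) * log2(e)) == exp(softmax_scale * s - lse)
  return std::exp2(softmax_scale * s_ij * kLog2e - lse_i * kLog2e);
}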
pytorch_flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - if constexpr (Is_dropout) { + if (Is_dropout) { int warp_id = tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - dropout.template apply_dropout( - acc_s, block_row_idx, block_col_idx, AtomLayoutMS + Tensor scores_dropped = make_tensor(scores.data(), pytorch_flash::convert_layout_rowcol_Aregs(scores.layout())); + pytorch_flash::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS ); } // Convert scores from fp32 to fp16/bf16 Tensor rP = !Is_dropout - ? pytorch_flash::convert_type(acc_s) - : pytorch_flash::convert_type_relu(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) - // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + ? pytorch_flash::convert_type(scores) + : pytorch_flash::convert_type_relu(scores); + // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. + Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } @@ -553,7 +889,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout())); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), flash::convert_layout_acc_rowcol(acc_dp.layout())); // #pragma unroll // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { // #pragma unroll @@ -617,9 +953,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Layout p_l = tPrP.layout(); // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } @@ -784,7 +1120,430 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename 
Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + // We'll advance gdKaccum and gdVaccum before the first write. 
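// A small sketch of the causal clamp applied to n_block_max above: the last query
// row of block m_block is (m_block + 1) * kBlockM - 1, so under causal masking only
// key blocks up to ceil_div((m_block + 1) * kBlockM, kBlockN) can contribute. This
// helper is illustrative, not part of the kernel:
inline int causal_n_block_max(int m_block, int kBlockM, int kBlockN, int seqlen_k) {
  const int all_blocks    = (seqlen_k + kBlockN - 1) / kBlockN;                 // cute::ceil_div
  const int causal_blocks = ((m_block + 1) * kBlockM + kBlockN - 1) / kBlockN;  // cute::ceil_div
  return causal_blocks < all_blocks ? causal_blocks : all_blocks;               // std::min
}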
+ const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded + + n_block_max * kBlockN) * params.d_rounded; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + // We assume that params.d == kHeadDim for now + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQdO{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sdO = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutQdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); + // Double buffer for sK + Tensor sV = make_tensor(sK.data() + 2 * size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); + Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); + Tensor sdS = make_tensor(sV.data() + size(sV), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast(sdS.data().get())), + Shape>{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = 
gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + + typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; + auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); + Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) + Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) + Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); + Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); + Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) + Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M_SdP, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + // Partition sP and sdS to match the accumulator partitioning + // This has to be tiled_mma_sdp, not tiled_mma_dkv + auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_QdOt = 
smem_tiled_copy_QdOt.get_thread_slice(tidx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); + auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); + auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + + // TODO: Might need to exit early and write 0 to gdQ. + + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + Tensor tQrQ = make_fragment_like(tQgQ); + pytorch_flash::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + int n_block = n_block_max - 1; + if (n_block % 2 == 1) { + tKsK.data() = tKsK.data() + size(sK); + tSsK.data() = tSsK.data() + size(sK); + tdQsKt.data() = tdQsKt.data() + size(sK); + } + + pytorch_flash::copy( + gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + pytorch_flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + + Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0>(taccScS))::value == 4); + // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. + Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? 
gLSE(row) : 0; + } + + cute::cp_async_fence(); + + Tensor dP_sum = make_fragment_like(lse); + cute::copy(tdOrdO, tdOsdO); + dot_do_o( + tdOrdO, tdOrO, sdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout + ); + __syncthreads(); + #pragma unroll + for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_dq); + + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M_SdP, MMA_N) + clear(acc_s); + pytorch_flash::cp_async_wait<0>(); + __syncthreads(); + + pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would + // be some finite value for those indices. In the end when we multiply with K to get dQ, + // the corresponding values of K would be 0, so the result would still be correct. + if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) { + pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + AtomLayoutMS * 16); + } + // Compute the exponential value. + pytorch_flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + if (Is_dropout) { + int warp_id = tidx / 32; + int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 + static_assert(MMA_N_SdP % 2 == 0); + int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + Tensor scores_dropped = make_tensor(scores.data(), pytorch_flash::convert_layout_rowcol_Aregs(scores.layout())); + pytorch_flash::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS + ); + } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = !Is_dropout + ? pytorch_flash::convert_type(scores) + : pytorch_flash::convert_type_relu(scores); + // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. 
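// The dropout step above never materializes a mask in memory: the keep/drop decision
// for each element of a 16 x 32 tile is a pure function of (seed, offset,
// block_row_idx, block_col_idx), so forward and backward regenerate identical masks.
// A minimal sketch of that idea, with a simple 64-bit mix standing in for the Philox
// generator the kernel actually uses; the threshold comparison and all names here
// are illustrative assumptions:
#include <cstdint>

inline bool keep_element(uint64_t seed, uint64_t offset,
                         uint32_t row, uint32_t col, uint8_t p_keep_u8) {
  uint64_t x = seed ^ (offset + ((uint64_t(row) << 32) | col));
  x ^= x >> 33; x *= 0xff51afd7ed558ccdULL; x ^= x >> 33;  // cheap deterministic mix
  return uint8_t(x) <= p_keep_u8;  // keep iff the pseudo-random byte is under the keep threshold
}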
+ Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + + clear(acc_dp); + pytorch_flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } + } + + Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = pytorch_flash::convert_type(dS_reshaped); + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + + if (n_block > 0) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tSsK.data() = tSsK.data() + sK_offset; + // Advance gK, gV + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + clear(acc_dv); + pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); } + tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded)); + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { atomicAdd(&tdVgdVaccum(i), acc_dv(i)); } + + __syncthreads(); + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + clear(acc_dk); + pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded)); + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); } + + pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); + // Double buffer for sK + tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? 
size(sK) : -size(sK)); + + } + + // Epilogue + + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = pytorch_flash::convert_type(acc_dq); + + Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{}); + + // Partition sdV and sdK to match the accumulator partitioning + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + + __syncthreads(); + + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + pytorch_flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -798,32 +1557,44 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { + const int n_block = blockIdx.x; // The block index for the batch. const int bidb = blockIdx.y; // The block index for the head. const int bidh = blockIdx.z; - // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. 
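// In the seqlen_q-parallel path above, different query-block CTAs update the same
// dK/dV tiles, so partial results are accumulated into float scratch with atomicAdd
// (the loops over acc_dk/acc_dv in compute_dq_dk_dv_1rowblock) and a separate
// convert_dKV pass applies the final scaling and cast. A minimal sketch of the
// accumulation half; buffer and kernel names are illustrative:
__global__ void accumulate_partial_dkv(float* __restrict__ accum,
                                       const float* __restrict__ partial, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(&accum[i], partial[i]);  // several CTAs may land on the same i
  }
}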
- for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); - } + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params ¶ms) { + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + compute_dq_dk_dv_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h index 8644ccd88a69ce..5c65bbd5ced150 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h @@ -1,6 +1,4 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2022, Tri Dao. #pragma once @@ -8,81 +6,58 @@ #include #include -#include #include namespace pytorch_flash { -// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); +template +__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) { + pytorch_flash::compute_dot_do_o(params); +} -// Use a macro to clean up kernel definitions -#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) +template +__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) { + pytorch_flash::clear_dKVaccum(params); +} -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { - #if defined(ARCH_SUPPORTS_FLASH) - pytorch_flash::compute_dq_dk_dv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif +template +__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) { + pytorch_flash::compute_dq_dk_dv(params); } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { - #if defined(ARCH_SUPPORTS_FLASH) +template +__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - pytorch_flash::compute_dq_dk_dv_seqk_parallel(params); + pytorch_flash::compute_dq_dk_dv_seqk_parallel(params); #else - FLASH_UNSUPPORTED_ARCH + printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); #endif } -template -__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { - pytorch_flash::compute_dot_do_o(params); +template +__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) { + pytorch_flash::compute_dq_dk_dv_seqq_parallel(params); } template -__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) { - pytorch_flash::clear_dKVaccum(params); +__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) { + pytorch_flash::convert_dQ(params); } template -__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { - pytorch_flash::convert_dQ(params, nsplits); -} - -template -__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { +__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) { pytorch_flash::convert_dKV(params); } template -void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid_m(num_m_block, params.b, params.h); const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - int gridDimx = num_n_block; - if (params.deterministic) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h); - } - dim3 grid_n(gridDimx, params.b, params.h); + dim3 grid_n(num_n_block, params.b, params.h); - if (!params.deterministic) { - flash_bwd_dot_do_o_kernel<<>>(params); - } else { - flash_bwd_dot_do_o_kernel<<>>(params); - } + flash_bwd_dot_do_o_kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not @@ -91,23 +66,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) const bool is_even_K = params.d == Kernel_traits::kHeadDim; constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); - 
BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); }); @@ -118,19 +91,58 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) C10_CUDA_CHECK(cudaFuncSetAttribute( kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); } - kernel_dq<<>>(params, !params.deterministic ? 1 : gridDimx); + kernel_dq<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + dim3 grid_n(num_n_block, params.b, params.h_k); + flash_bwd_clear_dkvaccum_kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid_m(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_k as well. + const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock; + // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
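// The BOOL_SWITCH macro used above (see static_switch.h in this directory) turns a
// runtime flag into a compile-time template parameter by emitting both
// instantiations and choosing one per call. A hand-written sketch of the same
// pattern; the real macro is variadic and this equivalent is illustrative only:
#include <type_traits>
#include <utility>

template <typename F>
void bool_switch(bool cond, F&& f) {
  if (cond) { std::forward<F>(f)(std::true_type{}); }
  else      { std::forward<F>(f)(std::false_type{}); }
}

// Usage sketch: bool_switch(params.is_causal, [&](auto is_causal_t) {
//   constexpr bool Is_causal = decltype(is_causal_t)::value;
//   /* pick the kernel instantiation templated on Is_causal */
// });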
+ auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + + auto kernel_dkv = &flash_bwd_convert_dkv_kernel; + if (Kernel_traits::kSmemKVSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize)); + } + kernel_dkv<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template -void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { -#ifndef FLASHATTENTION_DISABLE_BACKWARD - run_flash_bwd_seqk_parallel(params, stream); -#endif +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + if (configure) return; + run_flash_bwd_seqk_parallel(params, stream, configure); } template -void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 32; int device; cudaGetDevice(&device); @@ -140,21 +152,21 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } else { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } } else { // 96 KB - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } }); } template -void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 64; int device; cudaGetDevice(&device); @@ -165,41 +177,42 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // Changing AtomLayoutMdQ from 2 to 4 takes the same time - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
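// The hdim-specific launchers in this file choose a tile configuration by querying
// the device's opt-in shared-memory limit, and any kernel whose dynamic shared
// memory exceeds the default 48 KB must opt in via cudaFuncSetAttribute, as
// run_flash_bwd_seqk_parallel does above. A minimal self-contained sketch of that
// pattern; the demo kernel and the 96 KB figure are hypothetical:
#include <cuda_runtime.h>

__global__ void smem_demo_kernel() { extern __shared__ char smem[]; smem[threadIdx.x] = 0; }

int main() {
  int device = 0, max_smem_per_block = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&max_smem_per_block,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  const int smem_bytes = 96 * 1024;            // hypothetical requirement
  if (smem_bytes <= max_smem_per_block) {
    if (smem_bytes >= 48 * 1024) {             // past the default limit: opt in
      cudaFuncSetAttribute(smem_demo_kernel,
                           cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    smem_demo_kernel<<<1, 128, smem_bytes>>>();
  }
  return 0;
}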
if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); // This has a lot of register spilling - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream, configure); } else { // if (params.h == params.h_k) { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); // } else { + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); // } } }); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream, configure); } template -void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 96; int device; cudaGetDevice(&device); @@ -210,22 +223,26 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 116 * 1024) { - if constexpr(!Is_dropout) { // 92KB - run_flash_bwd, Is_dropout>(params, stream); - } else { // 116 KB - // This is faster for dropout since we don't have many registers to spare - run_flash_bwd, Is_dropout>(params, stream); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // if (params.h == params.h_k) { + if (max_smem_per_block >= 116 * 1024) { + if constexpr(!Is_dropout) { // 92KB + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { // 116 KB + // This is faster for dropout since we don't have many registers to spare + run_flash_bwd, Is_dropout>(params, stream, configure); + } + } else { + run_flash_bwd, Is_dropout>(params, stream, configure); } - } else { - run_flash_bwd, Is_dropout>(params, stream); - } + // } else { + // run_flash_bwd_seqq_parallel>(params, stream, configure); + // } }); } template -void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 128; int device; cudaGetDevice(&device); @@ -236,30 +253,35 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // run_flash_bwd>(params, stream); - // This is faster, in the case of 
sequence-parallel bwd (where we need fewer registers). - // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. - // run_flash_bwd>(params, stream); - if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - } else { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout>(params, stream); - } - // run_flash_bwd>(params, stream); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // if (params.h == params.h_k) { + // run_flash_bwd>(params, stream, configure); + // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). + // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. + // run_flash_bwd>(params, stream, configure); + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream, configure); + } + // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel>(params, stream, configure); + // } }); } template -void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 160; int device; cudaGetDevice(&device); @@ -269,17 +291,17 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } else { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } }); } template -void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 192; int device; cudaGetDevice(&device); @@ -289,25 +311,25 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 136 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } else { - run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream, configure); } }); } template -void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) { +void 
run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 224; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_bwd, Is_dropout>(params, stream); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_bwd, Is_dropout>(params, stream, configure); }); } template -void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { constexpr static int Headdim = 256; int device; cudaGetDevice(&device); @@ -317,18 +339,14 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 176 * 1024) { // H100 - run_flash_bwd, Is_dropout>(params, stream); - } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem - run_flash_bwd, Is_dropout>(params, stream); - } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering. - if constexpr (!Is_dropout) { - run_flash_bwd, false>(params, stream); - } + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { // A100, we don't do double buffering to save smem + run_flash_bwd, Is_dropout>(params, stream, configure); } }); } -}; // namespace pytorch_flash +}; // namespace pytorch_fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h deleted file mode 100644 index 7811984b7e61e0..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h +++ /dev/null @@ -1,377 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - -#include -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, - Tensor &dP_sum, const int gdP_col_stride, const float scale) { - static_assert(Layout0::rank == 3, "Only support 3D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); - // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) - // The last coordinate is the "page". 
- Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), - make_layout(get<0>(do_.layout()), - get<2>(do_.layout())))); - Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); - Tensor do_fp32 = pytorch_flash::convert_type(do_reshaped); - Tensor o_fp32 = pytorch_flash::convert_type(o_reshaped); - #pragma unroll - for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { - float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); - #pragma unroll - for (int ni = 1; ni < size<1>(do_reshaped); ni++) { - dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); - } - pytorch_flash::SumOp sum_op; - dP_sum_cur = pytorch_flash::Allreduce::run(dP_sum_cur, sum_op) * scale; - if (threadIdx.x % THREADS_PER_ROW == 0) { - dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void compute_dot_do_o(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; - - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), - Shape>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; - auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); - // TODO: careful, we're zeroing out dQaccum with type float4, but when - // we do atomicAdds, we use type float. The layouts are different. Check this. 
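The deleted dot_do_o / compute_dot_do_o helpers precompute the per-row correction term that the softmax backward needs. With O the forward output and dO its incoming gradient, the standard FlashAttention-2 backward uses (a sketch of the math, not code from this tree):

  D_i = \sum_{d} \mathrm{dO}_{i,d}\, O_{i,d}, \qquad
  \mathrm{dS}_{i,j} = P_{i,j}\left(\mathrm{dP}_{i,j} - D_i\right),

so D (stored as dsoftmax_sum) is computed once per query row by this preprocess kernel and reused by every key block of the main backward loop; the dropout scaling convention for D is spelled out in the surrounding comments.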
- typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); - Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - - Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); - - // Allocate predicate tensors for k - Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); - // Set predicates for k bounds - #pragma unroll - for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} - - Tensor tdOrdO = make_fragment_like(tdOgdO); - Tensor tdOrO = make_fragment_like(tdOgO); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final - // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, - // so that (dP - dP_sum) is on the same scale. - dot_do_o(tdOrdO, tdOrO, dP_sum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); - if (Clear_dQaccum) { - // We're actually not zero'ing out all of dQaccum, but only the part that we're going to - // do atomicAdds on. - Tensor zero = make_fragment_like(tdQgdQaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void clear_dKVaccum(const Params ¶ms) { - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; - - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - - typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); - Tensor zero = make_fragment_like(tdKgdKaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dQ from dQaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_k. 
-template -inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) - + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - - Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdQ{}); - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadQ tiled_mma_dq; - auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); - - Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); - clear(acc_dq); - for (int s = 0; s < nsplits; ++s) { - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); } - tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride; - } - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = pytorch_flash::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - Tensor tdQpdQ 
= make_tensor(make_shape(size<2>(tdQgdQ))); - #pragma unroll - for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_q. -template -inline __device__ void convert_dKV(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded - + n_block * kBlockN) * params.d_rounded; - - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - - Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdKV{}); - Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadKV tiled_mma_dkv; - auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); 
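The convert_dQ / convert_dKV epilogues in this deleted file fold the softmax scale and the dropout correction into the final cast from fp32 accumulators to fp16/bf16. Reading the factors off the code, and assuming from their names that scale_softmax_rp_dropout = softmax_scale / (1 - p) and rp_dropout = 1 / (1 - p) for dropout probability p, the outputs are (a sketch):

  dQ = \frac{\text{softmax\_scale}}{1-p}\,\widehat{dQ}, \qquad
  dK = \frac{\text{softmax\_scale}}{1-p}\,\widehat{dK}, \qquad
  dV = \frac{1}{1-p}\,\widehat{dV},

where the hatted tensors are the fp32 dQaccum / dKaccum / dVaccum scratch buffers that the main backward kernels accumulate into.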
// ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); - - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); - CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); - - Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); - Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); - #pragma unroll - for (int i = 0; i < size(acc_dk); ++i) { - acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; - } - #pragma unroll - for (int i = 0; i < size(acc_dv); ++i) { - acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; - } - // Convert acc_dk from fp32 to fp16 - Tensor rdK = pytorch_flash::convert_type(acc_dk); - Tensor rdV = pytorch_flash::convert_type(acc_dv); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - __syncthreads(); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); - cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); - - Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); -} - -} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h index 0386a07cc64fd6..844ba52a211a47 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h @@ -1,23 +1,23 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2023, Tri Dao. 
******************************************************************************/ #pragma once +#include #include +#include #include #include #include - +#include #include #include #include #include -#include -#include -#include +#include namespace pytorch_flash { @@ -25,7 +25,57 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + pytorch_flash::template reduce_max(scores, scores_max); + pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + pytorch_flash::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + pytorch_flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + pytorch_flash::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P +) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + Layout l = tOrP.layout(); + Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + #pragma unroll + for (int mi = 0; mi < size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -43,19 +93,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - auto seed_offset = at::cuda::philox::unpack(params.philox_args); - pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); - - // Save seed and offset for backward. If we don't have this here, the 0-th thread block might - // exit early and no one saves the rng state. 
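The softmax_rescale_o helper restored in this hunk is the classic online-softmax update: each new block of scores refreshes a running row max m_i and running row sum l_i, and rescales the partial output accumulator so the full attention matrix never has to be materialized. As a sketch of the recurrence (the code folds the softmax scale into the exponent and works in base 2 via exp2f; -inf handling omitted):

  m_i' = \max\bigl(m_i,\ \max_j S^{\mathrm{blk}}_{i,j}\bigr), \qquad
  \tilde P_{i,j} = \exp\bigl(S^{\mathrm{blk}}_{i,j} - m_i'\bigr),
  l_i' = e^{\,m_i - m_i'}\, l_i + \sum_j \tilde P_{i,j}, \qquad
  O_i \leftarrow e^{\,m_i - m_i'}\, O_i + \tilde P_{i,\cdot}\, V^{\mathrm{blk}},

and the epilogue further down emits O_i / l_i together with LSE_i = m_i \cdot \text{softmax\_scale} + \log l_i for the backward pass.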
- if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - if (params.philox_args.captured_) { - *params.seed = std::get<0>(seed_offset); - *params.extragraph_offset = std::get<1>(seed_offset); - } - } - const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -71,6 +108,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + if (params.philox_args.captured_) { + *params.seed = std::get<0>(seeds); + *params.extragraph_offset = std::get<1>(seeds); + } + } const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; @@ -145,6 +191,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -152,6 +200,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); @@ -159,8 +208,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - Tensor tSgS = thr_mma.partition_C(gP); - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // @@ -181,6 +228,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + // // PREDICATES // @@ -223,11 +274,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue + Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + // // Copy rmem to smem + // // copy(tQrQ, tQsQ); + // pytorch_flash::cp_async_wait<0>(); + // __syncthreads(); // // if 
(cute::thread(1, 0)) { print(tQsQ); } // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); // // if (cute::thread0()) { print(sQNoSwizzle); } @@ -257,12 +313,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } - clear(acc_o); + auto seeds = at::cuda::philox::unpack(params.philox_args); + if (params.philox_args.captured_) { + *params.seed = std::get<0>(seeds); + *params.extragraph_offset = std::get<1>(seeds); + } - pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + clear(acc_o); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. @@ -300,9 +360,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // if (cute::thread0()) { print(acc_s); } - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print_tensor(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), pytorch_flash::convert_layout_acc_rowcol(taccScS.layout())); + // pytorch_flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. 
+ // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + pytorch_flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } + } pytorch_flash::cp_async_wait<0>(); __syncthreads(); @@ -317,31 +405,33 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) - : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - - // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(acc_s); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = pytorch_flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + pytorch_flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps ); - cute::copy(rP_drop, tSgS); - tSgS.data() = tSgS.data() + (-kBlockN); + pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); } - - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
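The apply_mask_local calls restored here (and the Mask functor they replace) encode the causal / sliding-window predicate with the diagonal anchored to the bottom-right of the score matrix. Roughly, query row i may attend to key column j only when

  i + \text{seqlen\_k} - \text{seqlen\_q} - \text{window\_left} \;\le\; j \;\le\; i + \text{seqlen\_k} - \text{seqlen\_q} + \text{window\_right}

and j < actual_seqlen_k; causal attention is the special case window_right = 0 with an unbounded left window. (This restates the upstream FlashAttention windowing convention as a sketch; the exact per-thread index handling is what the calls above compute.)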
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration @@ -378,37 +468,58 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::cp_async_fence(); } - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - - softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + pytorch_flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - Tensor rP = pytorch_flash::convert_type(acc_s); + Tensor rP = pytorch_flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + pytorch_flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps ); - cute::copy(rP_drop, tSgS); - tSgS.data() = tSgS.data() + (-kBlockN); + pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); } - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } // Convert acc_o from fp32 to fp16/bf16 Tensor rO = pytorch_flash::convert_type(acc_o); @@ -474,7 +585,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -562,17 +673,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; - const index_t row_offset_k = block_table == nullptr - ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = block_table == nullptr - ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -626,6 +730,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // // PREDICATES // @@ -705,12 +814,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); - auto tKgK_data = tKgK.data(); - auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { pytorch_flash::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { pytorch_flash::copy_w_min_idx( @@ -736,30 +844,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - if (n_block > n_block_copy_min) { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - const int offset_diff = block_table_offset_next - block_table_offset_cur; - tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; - } - } } // Need this before we can read in K again, so that we'll see the updated K values. 
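The block_table arithmetic removed in this hunk is paged KV-cache addressing: K/V rows live in fixed-size pages and block_table maps each logical page of a sequence to a physical block in the cache. A small sketch of the offset computation the removed lines perform (names follow the diff; this helper is illustrative, not a function in the tree):

  // Physical element offset of logical KV token n for query head h, given a
  // per-sequence block_table and pages of page_block_size rows.
  int64_t paged_kv_offset(const int* block_table, int n, int page_block_size,
                          int64_t batch_stride, int64_t row_stride,
                          int64_t head_stride, int h, int h_h_k_ratio) {
    const int page    = n / page_block_size;          // logical page index
    const int in_page = n - page * page_block_size;   // row within the page
    return int64_t(block_table[page]) * batch_stride
         + int64_t(in_page) * row_stride
         + int64_t(h / h_h_k_ratio) * head_stride;    // MQA/GQA head mapping
  }

The restored code on the `+` side drops this indirection and addresses one contiguous K/V region per batch element instead.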
__syncthreads(); - tKgK.data() = tKgK_data; - tVgV.data() = tVgV_data; + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } } // Read Q from gmem to smem, optionally apply rotary embedding. + Tensor tQrQ = make_fragment_like(tQgQ); if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, @@ -810,11 +907,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); - pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; - - const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); - // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -835,15 +927,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gV if (masking_step > 0) { - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; - } + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads @@ -859,9 +943,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); // if (cute::thread0()) { print(acc_s); } - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. 
+ if (!Is_causal && !Is_local) { + if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + pytorch_flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } pytorch_flash::cp_async_wait<0>(); __syncthreads(); @@ -870,15 +966,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (n_block > n_block_min) { // Advance gK - if (block_table == nullptr) { - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -887,17 +975,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) - : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + // Convert scores from fp32 to fp16/bf16 + Tensor rP = pytorch_flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -913,15 +1002,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons pytorch_flash::cp_async_wait<0>(); __syncthreads(); // Advance gV - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; - } + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -934,38 +1015,50 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons __syncthreads(); if (n_block > n_block_min) { // Advance gK - if (block_table == nullptr) { - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); } - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + pytorch_flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - Tensor rP = pytorch_flash::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor rP = pytorch_flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning @@ -1042,7 +1135,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1058,12 +1151,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. 
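In the split-KV path each split writes a partial output O^{(s)} (already normalized by its own row sum, as in the epilogue above) plus a partial log-sum-exp LSE^{(s)}, with -INFINITY marking splits that saw no keys. combine_attn_seqk_parallel then merges the splits exactly via the standard log-sum-exp combination (a sketch):

  \mathrm{LSE}_i = \log \sum_s e^{\mathrm{LSE}^{(s)}_i}, \qquad
  O_i = \sum_s e^{\mathrm{LSE}^{(s)}_i - \mathrm{LSE}_i}\; O^{(s)}_i ,

where empty splits contribute nothing because e^{-\infty} = 0.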
- pytorch_flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + pytorch_flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1072,7 +1165,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - pytorch_flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + pytorch_flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1237,4 +1330,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } } -} // namespace pytorch_flash +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h index fcc99686eb835c..d76eaa4450e4b8 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h @@ -12,40 +12,27 @@ namespace pytorch_flash { -// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); - -// Use a macro to clean up kernel definitions -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
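compute_attn_splitkv decodes the launch grid differently depending on whether the KV sequence is split: with Split, blockIdx.z packs (batch, head) and blockIdx.y is the split index. A sketch of that index arithmetic, assuming the usual bidb = blockIdx.z / num_heads decomposition (only the bidh line is fully visible in the hunk above):

struct GridPoint { int m_block, bidb, bidh, n_split_idx, num_n_splits; };

GridPoint decode_splitkv(int blockIdx_x, int blockIdx_y, int blockIdx_z,
                         int gridDim_y, int num_heads, bool split) {
  GridPoint g;
  g.m_block      = blockIdx_x;                                            // query tile
  g.bidb         = split ? blockIdx_z / num_heads : blockIdx_y;           // batch
  g.bidh         = split ? blockIdx_z - g.bidb * num_heads : blockIdx_z;  // head
  g.n_split_idx  = split ? blockIdx_y : 0;                                // KV split
  g.num_n_splits = split ? gridDim_y : 1;
  return g;
}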
\ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { - #if defined(ARCH_SUPPORTS_FLASH) - static_assert(!(Is_causal && Is_local)); // Enforce constraints - pytorch_flash::compute_attn(params); +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + pytorch_flash::compute_attn(params); #else - FLASH_UNSUPPORTED_ARCH + printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { - #if defined(ARCH_SUPPORTS_FLASH) - pytorch_flash::compute_attn_splitkv(params); +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + pytorch_flash::compute_attn_splitkv(params); #else - FLASH_UNSUPPORTED_ARCH + printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { static_assert(Log_max_splits >= 1); pytorch_flash::combine_attn_seqk_parallel(params); } @@ -65,30 +52,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + // Will only return softmax if dropout, to reduce compilation time. 
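The launchers below lean on BOOL_SWITCH (from static_switch.h) to turn each runtime flag into a compile-time template parameter, which is why the switches nest so deeply. A minimal sketch of that pattern, stripped to its core:

#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                          \
    if (COND) {                                  \
      constexpr static bool CONST_NAME = true;   \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr static bool CONST_NAME = false;  \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// Usage, mirroring the launchers in this file:
//   BOOL_SWITCH(params.is_causal, Is_causal, [&] {
//     run_flash_fwd<Kernel_traits, /*Is_dropout=*/false, Is_causal>(params, stream);
//   });

Each extra switch doubles the number of kernel instantiations, which is exactly the build-time pressure cited as the reason for this revert.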
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); }); @@ -106,24 +90,22 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); }); @@ -136,7 +118,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If headdim is divisible by 64, then we set kBlockM = 8, etc. constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
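The launch sequence visible above follows the standard pattern for kernels that need more than the default 48 KB of dynamic shared memory: opt in with cudaFuncSetAttribute, then launch with the requested smem size. A self-contained sketch with a placeholder kernel:

#include <cuda_runtime.h>

__global__ void my_kernel(int /*params*/) {}  // stand-in for the templated flash kernel

void launch(dim3 grid, int kNThreads, size_t smem_size, cudaStream_t stream) {
  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int)smem_size);
  }
  my_kernel<<<grid, kNThreads, smem_size, stream>>>(0);
}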
8 : 16); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { @@ -170,7 +152,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); @@ -180,7 +162,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower @@ -204,7 +186,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { @@ -230,7 +212,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -267,7 +249,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. 
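The combine grid above is plain ceiling division: kBlockM rows of (batch * heads * seqlen_q) partial results are reduced per block, with kBlockM chosen from the head dimension as in the expression just reverted. A quick arithmetic check with illustrative sizes:

#include <cstdio>

constexpr int combine_kBlockM(int kHeadDim) {
  return kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16);
}

int main() {
  const int b = 2, h = 16, seqlen_q = 1024, kHeadDim = 128;
  const int kBlockM = combine_kBlockM(kHeadDim);                        // 4 for hdim 128
  const int grid_combine = (b * h * seqlen_q + kBlockM - 1) / kBlockM;  // ceil-div = 8192
  std::printf("kBlockM = %d, grid_combine = %d blocks\n", kBlockM, grid_combine);
  return 0;
}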
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -295,7 +277,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -323,7 +305,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -354,7 +336,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. @@ -371,4 +353,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -}; // namespace pytorch_flash +}; // namespace pytorch_fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h index ef1c3b91c94b03..875701e6cf2be9 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once @@ -26,7 +26,7 @@ struct Flash_kernel_traits { #endif using ElementAccum = float; - using index_t = int64_t; + using index_t = uint32_t; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using MMA_Atom_Arch = std::conditional_t< @@ -91,10 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); - // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 - using SmemLayoutVtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
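The hdim224/hdim256 launchers above pick their tile shape by querying the per-block opt-in shared-memory limit. A sketch of that query using the same 112 KB threshold as the diff (the per-device limits mentioned in the comment are illustrative):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int device = 0, max_smem_per_block = 0;
  cudaGetDevice(&device);
  // Largest dynamic smem a block may opt into (roughly 163 KB on A100-class parts).
  cudaDeviceGetAttribute(&max_smem_per_block,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  const int Headdim = 224;
  // Same threshold as above: fp16/bf16 Q tile of 128 rows plus double-buffered K/V of 64 rows.
  const bool use_128x64 = max_smem_per_block >= 2 * Headdim * (128 + 2 * 64);  // 112 KB
  std::printf("opt-in smem/block = %d bytes, 128x64 tile ok: %d\n",
              max_smem_per_block, (int)use_128x64);
  return 0;
}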
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -106,8 +116,10 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; - static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -137,6 +149,15 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, @@ -223,18 +244,26 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); - using SmemLayoutKtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; - // Temporarily disabling this for hdim 256 on sm86 and sm89 - // static_assert(kBlockN >= 64); - static_assert(kBlockN >= 32); + static_assert(kBlockN >= 64); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static constexpr int kPBlockN = 64; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
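The kSmemQSize / kSmemKVSize expressions above are just tile-area times element-size bookkeeping; the Share_Q_K_smem variant reuses the Q buffer for K/V. A back-of-the-envelope check for one typical fp16 configuration (kBlockM = 128, kBlockN = 64, kHeadDim = 64; numbers are illustrative, not a new constraint):

#include <cstdio>

int main() {
  const int kBlockM = 128, kBlockN = 64, kHeadDim = 64, elem_bytes = 2;  // fp16
  const int kSmemQSize  = kBlockM * kHeadDim * elem_bytes;               // 16 KB
  const int kSmemKVSize = 2 * kBlockN * kHeadDim * elem_bytes;           // K + V, 16 KB
  const bool share_q_k_smem = false;
  const int kSmemSize = share_q_k_smem
      ? (kSmemQSize > kSmemKVSize ? kSmemQSize : kSmemKVSize)
      : kSmemQSize + kSmemKVSize;                                        // 32 KB
  std::printf("Q: %d B, KV: %d B, total: %d B\n", kSmemQSize, kSmemKVSize, kSmemSize);
  return 0;
}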
2 : 3); static constexpr int kSwizzlePdS = 3; @@ -245,15 +274,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposed = decltype( - composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom; - using SmemLayoutQdOtransposed = decltype( - composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, @@ -273,12 +317,16 @@ struct Flash_bwd_kernel_traits : public Base { make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; - // Double buffer for sQ - static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) @@ -287,6 +335,9 @@ struct Flash_bwd_kernel_traits : public Base { + (!Is_V_in_regs ? 
kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + + kSmemdSSize + kSmemPSize; + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h new file mode 100644 index 00000000000000..01ea212b452c47 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h @@ -0,0 +1,161 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include + +namespace pytorch_flash{ + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; +} // namespace pytorch_flash +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu index 63a80c4d2062fc..247b359b052199 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim128(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu index 720f54343a4693..54ba9b1d016578 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim128(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu index 04aa184a6f78c1..351df04f7bd8b3 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim160(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu index 979082162997ad..057023e3be16ac 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu +++ 
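The tiled global-memory copies above move 128 bits per thread per instruction, so for fp16/bf16 each load carries 8 elements and the thread layout is (rows x kGmemThreadsPerRow) over the (kBlockM x kBlockKSmem) tile. A quick arithmetic check for a 128-thread, head-dim-64 configuration:

#include <cstdio>

int main() {
  const int kNThreads = 128, kBlockKSmem = 64, elem_bytes = 2;    // fp16/bf16
  const int bytes_per_load = 16;                                  // 128-bit load (cute::uint128_t)
  const int kGmemElemsPerLoad  = bytes_per_load / elem_bytes;     // 8 elements per load
  const int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; // 8 threads per row
  const int rows_per_copy      = kNThreads / kGmemThreadsPerRow;  // 16 rows per copy step
  std::printf("%d elems/load, %d threads/row, %d rows per copy step\n",
              kGmemElemsPerLoad, kGmemThreadsPerRow, rows_per_copy);
  return 0;
}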
b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim160(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu index 76ac4426f0390e..f772b3c75a4d52 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim192(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu index d0a05f597219c4..91deb5f3e88e5a 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim192(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu index 14ce1a9a450fc6..bf11ee849e1bc3 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim224(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu index 259c84cf8cdaaf..59a062829d468b 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim224(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu 
b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu index 1767b60f7908bb..48150fabcd61f5 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim256(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu index 6381904f7b5b72..f24074782bf7da 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim256(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu index bd47a37e7f6e36..8724f83e900719 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim32(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu index ae046260c3706f..aca37f6dfa07e7 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim32(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu index 42314aac9d2a2d..ce1c12768d75bf 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + 
run_mha_bwd_hdim64(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu index 616c784f7524ca..5f901a7b3243f0 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim64(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu index 6eccc4f455ad04..a0dc45eea3c887 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim96(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu index 54e455b81a36db..083828ee67f9b3 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim96(params, stream, configure); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py index ca1fe27f94903e..ee97a6a73cc050 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py @@ -27,8 +27,8 @@ KERNEL_IMPL_TEMPLATE_BWD = """ template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ - run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure); }} """ diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h deleted file mode 100644 index 9cee154fbbd50e..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h +++ /dev/null @@ -1,213 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include - -namespace pytorch_flash { - -using namespace cute; - -template -__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, - const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= max_seqlen_k) { - // Without the "make_coord" we get wrong results - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; - } - } - } - } -} - -template -__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset, - const int max_seqlen_q, const int warp_row_stride, - const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - } - // if (cute::thread0()) { - // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); - // print(tensor(make_coord(i, mi), _)); - // // print(tensor(_, j + nj * size<1, 0>(tensor))); - // } - } - } -} - -template -__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset, - const int max_seqlen_q, const int warp_row_stride) { - // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 - apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, - max_seqlen_q, warp_row_stride, -1, 0); -} - -template -__forceinline__ __device__ void apply_mask_causal_w_idx( - Tensor &tensor, Tensor const &idx_rowcol, - const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) -{ - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 2, "Only support 2D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); - CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); - #pragma unroll - for (int mi = 0; mi < 
size<0>(tensor); ++mi) { - const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); - #pragma unroll - for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { - if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; - } - } - // if (cute::thread0()) { - // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); - // print(tensor(_, make_coord(j, ni))); - // // print(tensor(_, j + ni * size<1, 0>(tensor))); - // } - } -} - -template -struct Mask { - - const int max_seqlen_k, max_seqlen_q; - const int window_size_left, window_size_right; - const float alibi_slope; - - __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, - const int window_size_left, const int window_size_right, - const float alibi_slope=0.f) - : max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) - , window_size_left(window_size_left) - , window_size_right(window_size_right) - , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { - }; - - // Causal_mask: whether this particular iteration needs causal masking - template - __forceinline__ __device__ void apply_mask(Tensor &tensor_, - const int col_idx_offset_, - const int row_idx_offset, - const int warp_row_stride) { - static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); - static_assert(Layout::rank == 3, "Only support 3D Tensor"); - static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); - static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; - // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } - if constexpr (Need_masking) { - // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout())); - // Do we need both row and column indices, or just column incides? 
- static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Col_idx_only) { - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // No causal, no local - if constexpr (Has_alibi) { - tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; - } - if constexpr (!Is_even_MN) { - if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } - } - } - } - } - } else { - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if constexpr (Has_alibi) { - if constexpr (Is_causal) { - tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; - } else { - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); - - } - } - if constexpr (Causal_mask) { - if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - if constexpr (Is_local) { - if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { - // Causal and Local already handles MN masking - if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - } - } - } - } - } - } - }; - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh index bed362bdd0c8ea..472d6b211f052c 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh @@ -11,7 +11,7 @@ struct ull2 { unsigned long long y; }; -__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { +inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { uint2 *res; unsigned long long tmp; asm ("mul.wide.u32 %0, %1, %2;\n\t" @@ -21,7 +21,7 @@ __forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned return *res; } -__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { +inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { constexpr unsigned long kPhiloxSA = 0xD2511F53; constexpr unsigned long kPhiloxSB = 0xCD9E8D57; uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); @@ -30,7 +30,7 @@ __forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint return ret; } -__forceinline__ __device__ uint4 philox(unsigned long long seed, +inline 
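The deleted mask.h above implements local (sliding-window) and causal masking with the rows shifted so the last query aligns with the last key. A host-side reference of that rule, written as a per-element predicate; causal masking is the special case of an effectively infinite left window and a right window of 0:

#include <algorithm>
#include <limits>

bool masked_out(int row, int col, int seqlen_q, int seqlen_k,
                int window_left, int window_right) {
  const int shift = seqlen_k - seqlen_q;                      // align bottom-right corners
  const int left  = std::max(0, row + shift - window_left);
  const int right = std::min(seqlen_k, row + 1 + shift + window_right);
  return col < left || col >= right;                          // true -> score becomes -INFINITY
}

bool causal_masked_out(int row, int col, int seqlen_q, int seqlen_k) {
  // Stand-in for "window_left = infinity": any window wider than seqlen_k behaves the same.
  return masked_out(row, col, seqlen_q, seqlen_k,
                    std::numeric_limits<int>::max() / 2, 0);
}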
__device__ uint4 philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) { constexpr unsigned long kPhilox10A = 0x9E3779B9; @@ -51,3 +51,117 @@ __forceinline__ __device__ uint4 philox(unsigned long long seed, } } // namespace flash + +namespace { + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : STATE(0) + , seed_(seed) + , offset_(offset) + , key(reinterpret_cast(seed)) { + //key.x = (unsigned int)seed; + //key.y = (unsigned int)(seed >> 32); + //counter = make_uint4(0, 0, 0, 0); + //counter.z = (unsigned int)(subsequence); + //counter.w = (unsigned int)(subsequence >> 32); + //STATE = 0; + //incr_n(offset / 4); + + // key = reinterpret_cast(seed); + ull2 * tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); + // } + } + __device__ inline uint4 operator()() { + // // if (STATE == 0) { + // uint4 counter_ = counter; + // uint2 key_ = key; + // // 7-round philox + // #pragma unroll + // for (int i = 0; i < 6; i++) { + // counter_ = pytorch_flash::philox_single_round(counter_, key_); + // key_.x += (kPhilox10A); + // key_.y += (kPhilox10B); + // } + // // output = philox_single_round(counter_, key_); + // uint4 output = pytorch_flash::philox_single_round(counter_, key_); + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // // } + // incr(); + // // } + // // return a float4 directly + // // unsigned long ret; + // // switch(STATE) { + // // case 0: ret = output.x; break; + // // case 1: ret = output.y; break; + // // case 2: ret = output.z; break; + // // case 3: ret = output.w; break; + // //} + // // STATE = (STATE + 1) % 4; + // return output; + return pytorch_flash::philox(seed_, offset_, offset_); + } + +private: + unsigned long long offset_, seed_; + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + // uint4 output; + const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ uint4 incr128 (uint4 ctr) + { + uint4 res; + asm ("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), + "n"(1), "n"(0), "n"(0), "n"(0)); + return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + counter = incr128(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + // static const unsigned long 
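Two pieces of the Philox code above are easy to restate portably: mulhilo32 is a 32x32 -> 64-bit multiply split into its high and low halves (the kernel uses inline PTX for the same thing), and the restored constructor seeds the 128-bit counter from (offset / 4, subsequence) via the ull2 reinterpret. A sketch under those assumptions (field-to-lane ordering is illustrative):

#include <cstdint>

struct HiLo { uint32_t hi, lo; };

HiLo mulhilo32_ref(uint32_t a, uint32_t b) {
  const uint64_t p = (uint64_t)a * (uint64_t)b;
  return HiLo{(uint32_t)(p >> 32), (uint32_t)p};
}

void init_counter(uint64_t subsequence, uint64_t offset, uint32_t counter[4]) {
  const uint64_t c01 = offset / 4;             // fills counter.x / counter.y
  counter[0] = (uint32_t)c01;
  counter[1] = (uint32_t)(c01 >> 32);
  counter[2] = (uint32_t)subsequence;          // fills counter.z / counter.w
  counter[3] = (uint32_t)(subsequence >> 32);
}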
kPhiloxSA = 0xD2511F53; + // static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h b/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h deleted file mode 100644 index 12dc1746c80878..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h +++ /dev/null @@ -1,152 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const 
&identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h index dec2065ec400a6..239a8114b68b7b 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h @@ -1,15 +1,34 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
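The removed rotary helpers above both apply the same per-pair rotation; copy_rotary_interleaved treats even/odd lanes as (real, imag) pairs, while copy_rotary_contiguous pairs element i with element i + rotary_dim/2. A dense PyTorch sketch of the interleaved form, assuming plain tensors rather than CuTe fragments (names here are illustrative):

```python
import torch

def rotary_interleaved(x, cos, sin):
    # x: (..., rotary_dim) with even/odd lanes forming (real, imag) pairs;
    # cos/sin: (..., rotary_dim // 2). Dense sketch of the element-wise math
    # done per fragment by copy_rotary_interleaved; the contiguous variant
    # instead rotates the first half of the rotary dim against the second.
    x_real, x_imag = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_real * cos - x_imag * sin
    out[..., 1::2] = x_real * sin + x_imag * cos
    return out

x = torch.randn(2, 16, 64)        # illustrative (seq, heads, rotary_dim)
angles = torch.randn(2, 16, 32)
out = rotary_interleaved(x, angles.cos(), angles.sin())
```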
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * ******************************************************************************/ #pragma once #include - -#include - -#include - +#include #include #include @@ -20,7 +39,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { +__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); @@ -35,7 +54,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor const &t } template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { +__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll for (int i = 0; i < size(dst); i++){ @@ -44,26 +63,26 @@ __device__ __forceinline__ void quad_allreduce_(Tensor &dst, T } template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { +__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { thread_reduce_(tensor, summary, op); quad_allreduce_(summary, summary, op); } template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ +__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ MaxOp max_op; reduce_(tensor, max, max_op); } -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); + reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. 
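The reductions above feed the exp step that follows: the kernel subtracts the per-row max and evaluates exp via exp2, folding the softmax scale and the log2(e) conversion into one multiplier so the fast exp2 path can be used without overflow. In plain PyTorch terms, on dense (rows, cols) scores (a sketch, not the fragment-level code):

```python
import math
import torch

def scale_apply_exp2(scores, row_max, softmax_scale):
    # exp(softmax_scale * (x - max)) computed as exp2(scale_log2 * (x - max)),
    # with the log2(e) factor folded into the scale, as in the exp2f-based
    # kernel path described above.
    scale_log2 = softmax_scale * math.log2(math.e)
    return torch.exp2((scores - row_max[:, None]) * scale_log2)

scores = torch.randn(4, 128)
row_max = scores.amax(dim=-1)
p = scale_apply_exp2(scores, row_max, softmax_scale=0.125)
ref = torch.exp(0.125 * (scores - row_max[:, None]))
assert torch.allclose(p, ref, atol=1e-6)
```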
template -__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { +inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -85,7 +104,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor &tenso // Apply the exp to all the elements. template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { +inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -115,67 +134,171 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor &ten } } -//////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - - __forceinline__ __device__ Softmax() {}; - - template - __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - if (Is_first) { - pytorch_flash::template reduce_max(scores, row_max); - pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - pytorch_flash::reduce_sum(scores, row_sum); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - pytorch_flash::template reduce_max(scores, row_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = 
row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = !Check_inf - ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - row_sum(mi) *= scores_scale; + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } } - pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - // We don't do the reduce across threads here since we don't need to use the row_sum. - // We do that reduce at the end when we need to normalize the softmax. - pytorch_flash::reduce_sum(scores, row_sum); + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } } - }; + } +} - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT lse = make_fragment_like(row_sum); - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? 
-INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } } - return lse; + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, + unsigned long long seed, unsigned long long offset, + int block_row_start, int block_col_start, + int block_row_stride) { + // tensor has shape (8, MMA_M, MMA_N / 2) + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); }; -}; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
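Stripped of the packed f16x2 details, the per-element decision the comment above describes is: draw one random byte and keep the value iff the byte is at or below an 8-bit threshold derived from the keep probability (assumed here to be encoded as floor(255 * (1 - p)) on the host side). A dense PyTorch sketch:

```python
import torch

def apply_dropout_bytes(x, p_dropout, seed=0):
    # Dense sketch of the keep/drop test made per element above: one random
    # byte per element, kept iff byte <= threshold. The floor(255 * (1 - p))
    # threshold encoding is an assumption about the host-side setup.
    threshold = int((1.0 - p_dropout) * 255)
    g = torch.Generator().manual_seed(seed)
    rnd = torch.randint(0, 256, x.shape, generator=g, dtype=torch.uint8)
    keep = rnd <= threshold
    return torch.where(keep, x, torch.zeros_like(x)), keep

kept, mask = apply_dropout_bytes(torch.randn(4, 8), p_dropout=0.2)
```

The encode_dropout_in_sign_bit branch negates dropped entries instead of zeroing them, which is what later lets the tests recover the dropout mask from the returned debug scores with a simple `>= 0` check.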
+ if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } +} } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h index ca12fa171bf989..4aa8474028868d 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h @@ -14,7 +14,6 @@ /// some_function(...); /// }); /// ``` - #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ @@ -26,46 +25,6 @@ } \ }() -#ifdef FLASHATTENTION_DISABLE_DROPOUT - #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define DROPOUT_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_ALIBI - #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define ALIBI_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_UNEVEN_K - #define EVENK_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - }() -#else - #define EVENK_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_LOCAL - #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define LOCAL_SWITCH BOOL_SWITCH -#endif - #define FP16_SWITCH(COND, ...) \ [&] { \ if (COND) { \ @@ -77,7 +36,7 @@ } \ }() -#define HEADDIM_SWITCH(HEADDIM, ...) \ +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) 
\ [&] { \ if (HEADDIM <= 32) { \ constexpr static int kHeadDim = 32; \ diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h index 2c8add318366ae..fc791b0b2107ea 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h @@ -22,17 +22,16 @@ #include #include + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace pytorch_flash { -//////////////////////////////////////////////////////////////////////////////////////////////////// - template -__forceinline__ __device__ uint32_t relu2(const uint32_t x); +inline __device__ uint32_t relu2(const uint32_t x); template<> -__forceinline__ __device__ uint32_t relu2(const uint32_t x) { +inline __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -50,7 +49,7 @@ __forceinline__ __device__ uint32_t relu2(const uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template<> -__forceinline__ __device__ uint32_t relu2(const uint32_t x) { +inline __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); @@ -63,10 +62,10 @@ __forceinline__ __device__ uint32_t relu2(const uint32_t x) #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template -__forceinline__ __device__ uint32_t convert_relu2(const float2 x); +inline __device__ uint32_t convert_relu2(const float2 x); template<> -__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { +inline __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -75,7 +74,7 @@ __forceinline__ __device__ uint32_t convert_relu2(const float2 } template<> -__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { +inline __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -89,20 +88,20 @@ __forceinline__ __device__ uint32_t convert_relu2(const flo template struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +__device__ inline T operator()(T const & x, T const & y) { return x > y ? 
x : y; } }; template <> struct MaxOp { // This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +__device__ inline T operator()(T const & x, T const & y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -111,7 +110,7 @@ template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template - static __device__ __forceinline__ T run(T x, Operator &op) { + static __device__ inline T run(T x, Operator &op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); @@ -123,7 +122,7 @@ struct Allreduce { template<> struct Allreduce<2> { template -static __device__ __forceinline__ T run(T x, Operator &op) { +static __device__ inline T run(T x, Operator &op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } @@ -135,7 +134,7 @@ template -__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, +inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { @@ -162,9 +161,9 @@ __forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, template -__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, - ThrCopy smem_thr_copy_B) { +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -184,48 +183,42 @@ __forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tC // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to (4, MMA_M, 
MMA_N) if using m16n8k8. +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. template -__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { using X = Underscore; - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - return acc_layout; - } else { - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) -template -__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { - using X = Underscore; - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template -__forceinline__ __device__ auto convert_type(Tensor const &tensor) { +inline __device__ auto convert_type(Tensor const &tensor) { using From_type = typename Engine::value_type; constexpr int numel = decltype(size(tensor))::value; cutlass::NumericArrayConverter convert_op; @@ -237,7 +230,7 @@ __forceinline__ __device__ auto convert_type(Tensor const &tenso //////////////////////////////////////////////////////////////////////////////////////////////////// template -__forceinline__ __device__ void relu_(Tensor &tensor) { +inline __device__ void relu_(Tensor &tensor) { constexpr int numel = decltype(size(tensor))::value; static_assert(numel % 2 == 0); using value_t = typename Engine::value_type; @@ -253,7 +246,7 @@ __forceinline__ __device__ void relu_(Tensor &tensor) { // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction template -__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { +inline __device__ auto convert_type_relu(Tensor const &tensor) { using From_type = typename Engine::value_type; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); @@ -295,7 +288,7 @@ void cp_async_wait() { template -__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, 
Tensor const &predicate_K, const int max_MN=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); @@ -364,7 +357,7 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor -__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, +inline __device__ void copy_w_min_idx(Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_K, const int max_MN=0, const int min_MN=0) { @@ -391,4 +384,137 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == 
size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? 
-sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index e2ea560b6afc6d..421bc83ebed432 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -242,7 +242,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) return true; } -bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89( +bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90( sdp_params const& params, bool debug) { // Flash Attention will raise an error in the backward pass if the head_dim @@ -252,19 +252,11 @@ bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89( auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm86_or_sm89 = check_sm_version(dprops); bool is_head_dim_gt192 = params.query.sym_size(-1) > 192; - bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224; - bool is_dropout = params.dropout > 0.0; - // head_dim size in (192, 224] is not supported on sm86 and sm89 - bool cond1 = is_head_dim_gt192 && is_head_dim_lte224; - // head_dim size > 224 and is_dropout is not supported on sm86 and sm89 - bool cond2 = params.query.sym_size(-1) > 224 && is_dropout; - if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) { + if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt192) { if (debug) { TORCH_WARN( - "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or " - "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].", - "Attempting to run with dropout set to: ", params.dropout, - "and head_dim: ", + "Flash attention currently doesn't support training with head_dim greater than 192 on gpu architectures in the range[sm86, sm89].", + "Attempting to run with head_dim: ", params.query.sym_size(-1), " on a sm ", dprops->major, ".", dprops->minor, " gpu."); } @@ -475,7 +467,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { check_for_attn_mask, check_head_dim_size_flash, check_flash_attention_hardware_support, - check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89, + check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90, check_flash_causal_non_square_seqlens, check_dtypes_low_precision); for (auto& constraint : general_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 24eebee7a75ab5..61999bc706c693 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -106,11 +106,10 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size 
c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -312,14 +311,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, + const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - bool is_causal, - int window_size_left, + const bool is_causal, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -345,13 +343,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { check_gpu_arch(); @@ -634,16 +630,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, + const int window_size_left, int window_size_right, - const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); diff --git a/test/test_transformers.py b/test/test_transformers.py index e752ba1fa41131..af14b06c21048d 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1328,17 +1328,13 @@ class TestSDPAFailureModes(NNTestCase): _do_cuda_non_default_stream = True @onlyCUDA - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, - "Does not support fused SDPA or not SM86+ hardware", - ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, + "Does not support fused SDPA or not SM86+ hardware") @parametrize("head_dim", [193, 204, 256]) - @parametrize("dropout_p", [0.0, 0.2]) - def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float): + def test_flash_backward_failure_sm86plus(self, device, head_dim: int): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype) - # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in - # pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h + # See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in 
pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h size = (2, 2, 4, head_dim) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) @@ -1355,15 +1351,8 @@ def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: q = make_tensor(size, requires_grad=True) k = make_tensor(size, requires_grad=True) v = make_tensor(size, requires_grad=True) - if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0): - self.assertRaises( - RuntimeError, - lambda: torch.nn.functional.scaled_dot_product_attention( - q, k, v, None, dropout_p, False - ), - ) - else: - flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False) + self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, None, 0.0, False)) @onlyCUDA def test_dispatch_fails_no_backend(self, device): @@ -1600,6 +1589,7 @@ def test_nested_fails_on_padding_head_dim(self, device): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device, "Current platform does not support fused SDPA or is an SM80+ device.") @@ -1680,35 +1670,37 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, is_causal=True)) -def _get_block_size_n(device, head_dim, is_dropout, is_causal): +def _get_block_size(device, head_dim, is_causal): # This should match the block sizes in the CUDA kernel + # Mask is only interesting when we are setting dropout + is_dropout = True assert head_dim <= 256 major, minor = torch.cuda.get_device_capability(device) is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) is_sm80 = major == 8 and minor == 0 is_sm90 = major == 9 and minor == 0 if head_dim <= 32: - return 128 + return 128, 128 if head_dim <= 64: - return 128 if not is_dropout else 64 + return (128, 128) if not is_dropout else (128, 64) elif head_dim <= 96: - return 64 + return (64, 64) if (is_sm8x and is_causal) else (128, 64) elif head_dim <= 128: if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 + return (64, 64) if (not is_dropout and is_causal) else (128, 32) else: - return 64 if not is_dropout else 32 + return 128, (64 if not is_dropout else 32) elif head_dim <= 160: if is_sm8x: - return 64 + return (128, 64) if not is_causal else (64, 64) else: - return 32 + return 128, 32 elif head_dim <= 192: - return 64 + return (128, 64) if not is_dropout else (64, 64) elif head_dim <= 224: - return 64 + return (128, 64) if (is_sm80 or is_sm90) else (64, 64) elif head_dim <= 256: - return 64 + return (128, 64) if is_sm80 else (64, 64) def pad_last_dim(input_tensor, alignment_size, slice: bool = False): @@ -1971,114 +1963,7 @@ class TestSDPACudaOnly(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True - # TODO USED FOR TESTING THE SCORES, e.g. 
testing ALIBI we don't need this now - def normalize_flash_attn_S( - self, - attn_unnorm, - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - attn_bias=None, - is_dropout=False, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - scale=None, - ): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k, v: (batch_size, seqlen_k, nheads, head_dim) - key_padding_mask: (batch_size, seqlen_q) - attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) - Output: - softmax_lse: (batch_size, nheads, seqlen_q) - softmax_max: (batch_size, nheads, seqlen_q) - """ - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - if causal: - window_size = (window_size[0], 0) - q, k, v = q.float(), k.float(), v.float() - _, seqlen_q, _, head_dim = q.shape - seqlen_k = k.shape[1] - b = q.shape[0] - from torch.nn.attention.bias import _calculate_scale - scale = _calculate_scale(head_dim, scale) - scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1)) - if key_padding_mask is not None: - scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf")) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = self.construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - q.device, - ) - scores.masked_fill_(local_mask, float("-inf")) - if attn_bias is not None: - scores = scores + attn_bias.to(dtype=scores.dtype) - block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) - scores_block = scores.split(block_size_n, dim=-1) - lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) - lse = torch.logsumexp(lse_block, dim=-1) - # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf - # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. 
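The (removed) normalize_flash_attn_S helper above leans on the fact that a row's log-sum-exp can be assembled block by block, which is exactly what allows the kernel to process the key dimension in tiles. A small self-contained check of that identity, with a hypothetical block size of 64:

```python
import torch

# Block-wise log-sum-exp: logsumexp over per-block logsumexp values equals the
# logsumexp over the full row. Shapes follow the test's (b, h, seqlen_q, seqlen_k).
scores = torch.randn(2, 3, 4, 256)
blocks = scores.split(64, dim=-1)                       # hypothetical block_size_n = 64
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in blocks], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)                # combine per-block results
assert torch.allclose(lse, torch.logsumexp(scores, dim=-1), atol=1e-5)
```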
- lse[lse == float("-inf")] = float("inf") - scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) - cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) - attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) - attn_norm = torch.cat( - [ - a * (torch.exp(m - lse)).unsqueeze(-1) - for a, m in zip(attn_unnorm_block, cummax_block) - ], - dim=-1, - ) - if query_padding_mask is not None: - attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0) - # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - return attn_norm.to(dtype=attn_unnorm.dtype) - - def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): - # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") - row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1) - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - sk = ( - seqlen_k - if key_padding_mask is None - else key_padding_mask.sum(-1).view(-1, 1, 1, 1) - # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - ) - sq = ( - seqlen_q - if query_padding_mask is None - else query_padding_mask.sum(-1).view(-1, 1, 1, 1) - # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - ) - if window_size[0] < 0: - return col_idx > row_idx + sk - sq + window_size[1] - else: - sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - col_idx < row_idx + sk - sq - window_size[0], - ) - - def convert_flash_attn_S_to_softmax( - self, - S, - seqlen_q, - seqlen_k, - query_padding_mask, - key_padding_mask, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - ): + def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False): """FlashAttention stores the S matrix in a different way. 
Arguments: S: (batch_size, nheads, seqlen_q, seqlen_k) @@ -2087,45 +1972,53 @@ def convert_flash_attn_S_to_softmax( """ if TEST_WITH_ROCM: return S - b = S.shape[0] - if causal: - window_size = (window_size[0], 0) - seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] - S_converted = S - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = self.construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - S.device, - ) - local_mask = F.pad( - local_mask, - (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), - value=True, - ) - S_converted = S_converted.masked_fill(local_mask, 0.0) + b, h, seqlen_q, seqlen_k = S.shape + warps_n = 4 + blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal) + nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m + nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n + mmas_n = (blocksize_n + 16 - 1) // 16 + + # Reshape S using PyTorch native functions + S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n) + S_flat = S_flat.permute(0, 1, 2, 4, 3, 5) + S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n)) + S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2) + S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11) + S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) * + warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2)) + if causal: + causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1) + S_converted.masked_fill_(causal_mask, 0.0) # Need to zero out things not in attention_mask in case S was initialized with random values # and some of those values aren't overwritten. - seqlen_q_og = ( - query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded - ) + seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q if query_padding_mask is not None: - query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) - # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0) + if seqlen_q_og < seqlen_q: + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) + else: + query_padding_mask = query_padding_mask[:, :seqlen_q] + q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1) + S_converted = S_converted.masked_fill(q_mask_fill, 0.0) seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) - S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0) - # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) - S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) - S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) - return S_converted[:, :, :seqlen_q, :seqlen_k] + if seqlen_k_og < seqlen_k: + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) + else: + key_padding_mask = key_padding_mask[:, :seqlen_k] + k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1]) + S_converted = S_converted.masked_fill(k_mask_fill, 0.0) + if seqlen_q_og < seqlen_q: + S_converted = 
S_converted[:, :, :seqlen_q_og, :] + else: + S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q)) + if seqlen_k_og < seqlen_k: + S_converted = S_converted[:, :, :, :seqlen_k_og] + else: + S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k)) + return S_converted @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) @@ -2477,29 +2370,28 @@ def test_sdp_choice_with_determinism(self, device, warn_only): query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) with use_deterministic_algorithims(True, warn_only=warn_only): + # Note that this should swith to a testing version with we remove old context manager with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value - @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") - @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") @parametrize("warn_only", [True, False]) - def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel): + def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only): batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) - make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) + make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) - kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" warning_context = ( self.assertWarnsRegex( UserWarning, - f"{kernel_name} defaults to a non-deterministic algorithm.", + "Memory Efficient attention defaults to a non-deterministic algorithm.", ) if warn_only else contextlib.nullcontext() ) with use_deterministic_algorithims(True, warn_only=warn_only): - with sdpa_kernel(backends=[fused_kernel]): + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): with warning_context: torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() @@ -2818,6 +2710,8 @@ def is_power_of_2(n): is_dropout = dropout_p > 0.0 if not is_dropout: + # Problem: We pad sizes in the composite region of the top level SDPA. But we need the + # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) with sdpa_kernel(backends=[SDPBackend.MATH]): @@ -2828,8 +2722,6 @@ def is_power_of_2(n): out_lp_ref = F.scaled_dot_product_attention( query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale) else: - # Problem: We pad sizes in the composite region of the top level SDPA. But we need the - # Debug mask when have dropout. 
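As the comment notes, the dropout path needs the debug mask, so the test pads head_dim up to the 8-element alignment itself before calling the private forward. A simplified sketch of what a helper like pad_last_dim does (the real helper in this file also takes a slice flag, omitted here):

```python
import torch
import torch.nn.functional as F

def pad_last_dim(input_tensor, alignment_size):
    # Simplified sketch: pad head_dim up to a multiple of alignment_size and
    # return the original size so results can be sliced back afterwards.
    og_size = input_tensor.size(-1)
    pad = (-og_size) % alignment_size
    if pad == 0:
        return input_tensor, og_size
    return F.pad(input_tensor, (0, pad)), og_size

q_padded, q_og_size = pad_last_dim(torch.randn(2, 2, 4, 203), 8)
assert q_padded.size(-1) == 208 and q_og_size == 203
```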
So I am going to manualy pad up here when testing dropout q_padded, q_og_size = pad_last_dim(query, 8) k_padded, k_og_size = pad_last_dim(key, 8) v_padded, v_og_size = pad_last_dim(value, 8) @@ -2848,14 +2740,9 @@ def is_power_of_2(n): batch_size, seq_len_k, device=device, dtype=torch.bool) softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, + dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)[:, :, :seq_len_q, :seq_len_k] dropout_mask = softmax_mask >= 0 - # attn_unnorm = softmax_mask.abs() - # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded, - # k_padded, v_padded, query_padding_mask, - # key_padding_mask, None, True, is_causal, scale=scale) - # High Precision Math Reference out_ref = torch.ops.aten._scaled_dot_product_attention_math( query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] @@ -2936,8 +2823,7 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d batch_size, seq_len_k, device=device, dtype=torch.bool) softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, - causal=is_causal)[:, :, :seq_len_q, :seq_len_k] + dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal) dropout_mask = softmax_mask >= 0 return dropout_mask @@ -3292,7 +3178,7 @@ def rand_nt(sequence_list, num_heads, head_dim): key_padding_mask = key_padding_mask.to("cuda") softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal) + dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal) dropout_mask = softmax_mask >= 0 nt_stack = [] for tensor_component in range(batch_size):