
Commit

Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 (pytorch#11…
Browse files Browse the repository at this point in the history
…8935)"

This reverts commit 4b7a521.

Reverted pytorch#118935 on behalf of https://github.com/atalman because it significantly increases build time; optimization is needed ([comment](pytorch#118935 (comment)))
pytorchmergebot committed Feb 29, 2024
1 parent 96eff4e commit 1458f1d
Showing 41 changed files with 2,301 additions and 2,106 deletions.
5 changes: 2 additions & 3 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -50,6 +50,7 @@
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/narrow_native.h>
#include <ATen/ops/zeros.h>
#endif

@@ -64,6 +65,7 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

@@ -850,7 +852,6 @@ _flash_attention_forward(
// of the tensor. This is useful for kv cache scenarios but for now
// we will not support in this PR.
c10::optional<Tensor> seqused_k = c10::nullopt;
c10::optional<Tensor> alibi_slopes = c10::nullopt;

// We are going to have two paths:
// 1. The standard MHA path for dense tensors
@@ -879,7 +880,6 @@ _flash_attention_forward(
cumulative_sequence_length_q.value(),
cumulative_sequence_length_k.value(),
seqused_k, /*seqused_k*/
alibi_slopes, /*alibi_slopes*/
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -905,7 +905,6 @@
key,
value,
out,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
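The two call-site hunks above drop the `alibi_slopes` argument again, leaving `seqused_k` as the remaining placeholder optional. As a side note, a minimal sketch of that "declare as `c10::nullopt` and forward" pattern is shown below; the callee name and parameter list are hypothetical, and only `c10::optional`, `c10::nullopt`, and `at::Tensor` are the real types.

```cpp
// Minimal sketch of the "unused optional argument" pattern seen in
// _flash_attention_forward above. kernel_entry_sketch is a hypothetical
// stand-in for the flash kernel entry point.
#include <ATen/core/Tensor.h>
#include <c10/util/Optional.h>

void kernel_entry_sketch(
    const at::Tensor& q,
    const c10::optional<at::Tensor>& seqused_k) {
  (void)q;
  (void)seqused_k;  // would be consumed here once kv-cache support lands
}

void caller_sketch(const at::Tensor& q) {
  // Reserved for kv-cache scenarios; not populated yet, so it stays empty.
  c10::optional<at::Tensor> seqused_k = c10::nullopt;
  kernel_entry_sketch(q, /*seqused_k=*/seqused_k);
}
```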
26 changes: 4 additions & 22 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -1,4 +1,3 @@
#include <string_view>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint>
#include <type_traits>
@@ -42,8 +41,9 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#endif
namespace at {

namespace at::native {
namespace native {

std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
const Tensor& grad_out,
@@ -74,21 +74,6 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// The kernel computes irregardless we will drop for this functions return
Tensor grad_softmax;

// Currently unused args:
c10::optional<at::Tensor> alibi_slopes{c10::nullopt};

bool determinisitic{false};
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"Flash Attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
} else {
determinisitic = true;
}
}

// We check the whether the cumulative_sequence_length_q is defined
// in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) {
@@ -105,7 +90,6 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dv,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
alibi_slopes,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -114,7 +98,6 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(dQuery, dKey, dValue);
@@ -130,13 +113,11 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
@@ -649,4 +630,5 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
}

} // namespace at::native
} // namespace native
} // namespace at
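The largest removal in attention_backward.cu is the determinism handling that the 2.5.5 update introduced: it consulted the global ATen context to decide whether to request a deterministic backward kernel or merely warn. Below is a self-contained sketch of that check, assembled only from the calls visible in the removed hunk; the free-standing wrapper function is an assumption for illustration.

```cpp
// Sketch of the determinism check removed by this revert, built from the
// calls visible in the hunk above: at::globalContext(),
// deterministicAlgorithms(), deterministicAlgorithmsWarnOnly(), TORCH_WARN_ONCE.
#include <ATen/Context.h>
#include <c10/util/Exception.h>

bool resolve_flash_bwd_determinism_sketch() {
  bool deterministic = false;
  auto& ctx = at::globalContext();
  if (ctx.deterministicAlgorithms()) {
    if (ctx.deterministicAlgorithmsWarnOnly()) {
      // Warn-only mode: keep the fast, non-deterministic kernel but tell the user.
      TORCH_WARN_ONCE(
          "Flash Attention defaults to a non-deterministic algorithm. ",
          "To explicitly enable determinism call ",
          "torch.use_deterministic_algorithms(True, warn_only=False).");
    } else {
      deterministic = true;  // Strict mode: request the deterministic kernel.
    }
  }
  return deterministic;
}
```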
74 changes: 0 additions & 74 deletions aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h

This file was deleted.

aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h
@@ -24,12 +24,12 @@ struct BlockInfo {
}

template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

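The block_info.h hunks only swap `__forceinline__ __device__` back to `inline __device__`; the offset arithmetic itself is unchanged. The host-side sketch below evaluates that same expression with made-up numbers to show what it selects: a dense batch offset when `sum_s == -1`, and a packed varlen offset otherwise. The helper name and the example strides are assumptions.

```cpp
// Host-side illustration (not part of the commit) of the q_offset / k_offset
// expression above. sum_s == -1 means "dense layout": jump to batch bidb via
// the batch stride; otherwise sum_s is the cumulative token offset of this
// sequence in the packed (varlen) layout.
#include <cstdint>
#include <cstdio>

static std::uint32_t offset_sketch(int sum_s, std::uint32_t batch_stride,
                                   std::uint32_t row_stride, int bidb) {
  return sum_s == -1 ? bidb * batch_stride : std::uint32_t(sum_s) * row_stride;
}

int main() {
  // Dense: batch 2 with a batch stride of 4096 elements -> offset 8192.
  std::printf("dense:  %u\n", offset_sketch(-1, 4096u, 64u, 2));
  // Varlen: this sequence starts at cumulative token 37, row stride 64 -> 2368.
  std::printf("varlen: %u\n", offset_sketch(37, 4096u, 64u, 2));
}
```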
96 changes: 0 additions & 96 deletions aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h

This file was deleted.

30 changes: 9 additions & 21 deletions aten/src/ATen/native/transformers/cuda/flash_attn/flash.h
@@ -5,23 +5,21 @@
#pragma once

#include <cuda.h>
#include <vector>

#include <ATen/cuda/PhiloxUtils.cuh>

namespace pytorch_flash{

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = int64_t;
using index_t = uint32_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
@@ -98,12 +96,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ cache_batch_idx;

// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -133,9 +126,6 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;

int num_splits; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -175,16 +165,14 @@ struct Flash_bwd_params : public Flash_fwd_params {

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);


} // namespace pytorch_flash
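flash.h only declares the per-dtype, per-head-dim launcher templates; separate translation units instantiate them and a runtime switch on the head dimension picks one. The sketch below illustrates that dispatch pattern with a stub launcher; the stub, the helper names, and the head-dim list are assumptions, not code from this file.

```cpp
// Hypothetical dispatch sketch for launcher templates like
// run_mha_fwd_<T, Headdim> declared above: each <dtype, head-dim> pair is a
// separate instantiation, selected at runtime by switching on the head dim.
#include <cstdio>

template <typename T, int Headdim>
void run_mha_fwd_stub() {  // stand-in for run_mha_fwd_<T, Headdim>(params, stream)
  std::printf("launch %d-dim head kernel for %zu-byte dtype\n",
              Headdim, sizeof(T));
}

template <typename T>
void dispatch_by_head_dim(int head_dim) {
  switch (head_dim) {
    case 64:  run_mha_fwd_stub<T, 64>();  break;
    case 128: run_mha_fwd_stub<T, 128>(); break;
    default:  std::printf("unsupported head dim %d in this sketch\n", head_dim);
  }
}

int main() {
  dispatch_by_head_dim<float>(128);
}
```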
