diff --git a/CMakeLists.txt b/CMakeLists.txt index 7adeac323a91f..9f95efeeffd8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -742,28 +742,13 @@ if(MSVC) append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS) endif() -# Note for ROCM platform: -# 1. USE_ROCM is always ON until include(cmake/Dependencies.cmake) -# 2. USE_CUDA will become OFF during re-configuration -# Truth Table: -# CUDA 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default -# CUDA 2nd pass: USE_CUDA=True;USE_ROCM=False, FLASH evaluates to ON by default -# ROCM 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default -# ROCM 2nd pass: USE_CUDA=False;USE_ROCM=True, FLASH evaluates to ON by default -# CPU 1st pass: USE_CUDA=False(Cmd Option);USE_ROCM=True, FLASH evaluates to OFF by default -# CPU 2nd pass: USE_CUDA=False(Cmd Option);USE_ROCM=False, FLASH evaluates to OFF by default -# Thus we cannot tell ROCM 2nd pass and CPU 1st pass -# -# The only solution is to include(cmake/Dependencies.cmake), and defer the -# aotriton build decision later. - -include(cmake/Dependencies.cmake) - +# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until +# include(cmake/Dependencies.cmake) cmake_dependent_option( USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) + "USE_CUDA AND NOT MSVC" OFF) # We are currenlty not using alibi attention for Flash # So we disable this feature by default @@ -779,6 +764,8 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA" OFF) +include(cmake/Dependencies.cmake) + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 96b839820efd7..e2ea560b6afc6 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -21,10 +21,6 @@ #include #include -#if USE_ROCM -#include -#endif - /** * Note [SDPA Runtime Dispatch] * SDPA relies on a runtime dispatch mechanism to select the appropriate @@ -186,18 +182,32 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM - auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (debug) { - TORCH_WARN( - "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); - } - return false; + constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-"; + static const char *over_arch = [] { + auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"); + if (rc) { + TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. " + "Later changes to this environment variable with os.environ " + "(or other methods) will not affect SDPA function's behavior."); + } + return rc; + }(); + const char* real_arch = dprops->gcnArchName; + const char* arch = over_arch ? over_arch : real_arch; + if (mi200 != arch) { + if (debug) { + TORCH_WARN( + "Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ", + arch, + ".", + over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "", + over_arch ? real_arch : ""); + } + return false; } #else - auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 9a43404f5d337..24eebee7a75ab 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -59,140 +59,41 @@ #include #include -// AOTriton headers -#include -#include -#include -#include +// OORT headers +#include +#include +#include +#include namespace pytorch_flash { namespace { -void check_gpu_arch(hipStream_t stream) { - auto ret = aotriton::v2::flash::check_gpu(stream); - if (hipSuccess != ret) { - TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") - } -} - -aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) -{ -#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname - CAST_TYPE(kByte, kUInt8); - CAST_TYPE(kUInt16, kUInt16); - CAST_TYPE(kUInt32, kUInt32); - CAST_TYPE(kUInt64, kUInt64); - CAST_TYPE(kChar, kInt8); - CAST_TYPE(kShort, kInt16); - CAST_TYPE(kInt, kInt32); - CAST_TYPE(kLong, kInt64); - CAST_TYPE(kHalf, kFloat16); - CAST_TYPE(kFloat, kFloat32); - CAST_TYPE(kBFloat16, kBFloat16); - return aotriton::DType::kUnknown; -#undef CAST_TYPE -} - -template -struct IntArrayRefCaster { - // std::array cast(IntArrayRef); -}; +c10::once_flag fa_gcn_arch_override_flag; +const char* fa_override_arch = nullptr; -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ static_cast(ref.at(0)) }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)) - }}; +void init_fa_override_arch() { + fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"); + if (fa_override_arch) { + TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. " + "Later changes to this environment variable with os.environ " + "(or other methods) will not affect this backend's behavior."); } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)), - static_cast(ref.at(3)) - }}; - } -}; - - -template -aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) -{ - const auto strides = q.strides(); - int real_rank = strides.size(); - if (real_rank != Rank) { // Lazy convertion of tensor_name - TORCH_CHECK(false, - std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) - + " but is " + std::to_string(real_rank)); - } - return aotriton::TensorView(reinterpret_cast(q.data_ptr()), - IntArrayRefCaster::cast(q.sizes()), - IntArrayRefCaster::cast(strides), - cast_dtype(q.dtype())); } -template // For Output Tensor -class TensorStorageSanitizer { -public: - TensorStorageSanitizer(const at::Tensor& ref, - at::Tensor& to_sanitize) - : ref_(ref), to_sanitize_(to_sanitize) - { - need_sanitize = ref_.strides() != to_sanitize_.strides(); - if (!need_sanitize) - return; - - temp_ = at::empty_like(ref_); - if (COPY_FROM_INPUT) { - temp_.copy_(to_sanitize_); - } - } +void check_gpu_arch() { + auto dprops = at::cuda::getCurrentDeviceProperties(); - ~TensorStorageSanitizer() - { - if (need_sanitize && COPY_BACK) - to_sanitize_.copy_(temp_); - } - - at::Tensor& sanitized_tensor() - { - if (need_sanitize) - return temp_; - return to_sanitize_; + constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-"; + c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch); + if (fa_override_arch) { + TORCH_CHECK(mi200 == fa_override_arch, + "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch); + } else { + TORCH_CHECK(mi200 == dprops->gcnArchName, + "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName)); } -private: - const at::Tensor& ref_; - at::Tensor& to_sanitize_; - at::Tensor temp_; - bool need_sanitize = false; -}; +} } @@ -213,8 +114,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head int window_size_right, const bool return_softmax, c10::optional gen_) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); + check_gpu_arch(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, @@ -306,51 +206,102 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head seed_t = at::empty({}, at::dtype(at::kLong)); offset_t = at::empty({}, at::dtype(at::kLong)); } - } - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); - } } - // Transpose tensors to meet AOTriton's Flash API - at::Tensor q_t = q_padded.permute({0,2,1,3}); - at::Tensor k_t = k_padded.permute({0,2,1,3}); - at::Tensor v_t = v_padded.permute({0,2,1,3}); - at::Tensor output_t = out.permute({0,2,1,3}); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + //reorder tensors and make contiguous + at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous(); + at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous(); + at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous(); + at::Tensor output_t = out.permute({0,2,1,3}).contiguous(); - at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse + at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse - at::Tensor softmax_fa_t; - if (return_softmax) { - softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); - } else { - softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device())); - } + constexpr int BLOCK_M = 16; + constexpr int BLOCK_N = 16; + dim3 grid; + grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M; + grid.y = q_t.sizes()[0] * q_t.sizes()[1]; + grid.z = 1; + dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic + + at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, + at::dtype(q.dtype()).device(q.device())); hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd; - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); +#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \ + do { \ + oort::attn_fwd fwd_opt; \ + err = fwd_opt(grid, block, \ + (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ + softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \ + q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ + k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ + v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ + output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \ + q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ + *(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \ + (FP*)(softmax_fa_t.data_ptr()), \ + stream); \ + } while(0) + + // TODO: Ugly but works + constexpr int kFwdUseCausal = 3; + constexpr int kFwdNoCausal = 1; + int d_head = q_t.sizes()[3]; + constexpr int BM = BLOCK_M; + constexpr int BN = BLOCK_N; + if (q_dtype == at::kHalf) { + if (is_causal) { + if (d_head == 16) + CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true); + else if (d_head == 32) + CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true); + else if (d_head == 64) + CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true); + else if (d_head == 128) + CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true); + } else { + if (d_head == 16) + CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true); + else if (d_head == 32) + CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true); + else if (d_head == 64) + CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true); + else if (d_head == 128) + CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true); + } + } else if (q_dtype == at::kBFloat16) { + if (is_causal) { + if (d_head == 16) + CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true); + else if (d_head == 32) + CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true); + else if (d_head == 64) + CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true); + else if (d_head == 128) + CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true); + } else { + if (d_head == 16) + CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true); + else if (d_head == 32) + CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true); + else if (d_head == 64) + CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true); + else if (d_head == 128) + CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true); + } + } + + //undo reorder tensors + q_padded = q_t.permute({0,2,1,3}).contiguous(); + k_padded = k_t.permute({0,2,1,3}).contiguous(); + v_padded = v_t.permute({0,2,1,3}).contiguous(); + out = output_t.permute({0,2,1,3}).contiguous(); return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; +#undef CALL_FWD } std::tuple @@ -403,10 +354,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); + check_gpu_arch(); bool is_dropout = p_dropout > 0.0; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, @@ -489,12 +440,23 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si // const at::Tensor& dout_padded = dout; + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; auto opts = q.options(); auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + if (loop) { + dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA @@ -506,52 +468,149 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { + if (is_dropout) { if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + philox_args = at::PhiloxCudaState( + philox_seed.data_ptr(), philox_offset.data_ptr(), 0); } } - at::Tensor q_t = q.permute({0,2,1,3}); - at::Tensor k_t = k.permute({0,2,1,3}); - at::Tensor v_t = v.permute({0,2,1,3}); - at::Tensor out_t = out.permute({0,2,1,3}); - at::Tensor dq_t = dq.permute({0,2,1,3}); - at::Tensor dk_t = dk.permute({0,2,1,3}); - at::Tensor dv_t = dv.permute({0,2,1,3}); - at::Tensor dout_t = dout.permute({0,2,1,3}); + //JCG TODO WE GO IN HERE TODO backwards + //reorder tensors and make contiguous + at::Tensor q_t = q.permute({0,2,1,3}).contiguous(); + at::Tensor k_t = k.permute({0,2,1,3}).contiguous(); + at::Tensor v_t = v.permute({0,2,1,3}).contiguous(); + at::Tensor out_t = out.permute({0,2,1,3}).contiguous(); - at::Tensor softmax_lse_cont = softmax_lse.contiguous(); + //reorder tensors and make contiguous + at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous(); + at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous(); + at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous(); + at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous(); + + dim3 block { 64 * 4, 1, 1 }; + + at::Tensor new_do = at::empty_like(dout_t).contiguous(); at::Tensor delta = at::empty_like(softmax_lse).contiguous(); int d_head = head_size_og; hipError_t err; // TODO: Error handling - { - TensorStorageSanitizer dq_s(q_t, dq_t); - TensorStorageSanitizer dk_s(k_t, dk_t); - TensorStorageSanitizer dv_s(v_t, dv_t); - using aotriton::v2::flash::attn_bwd; - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_s.sanitized_tensor(), "dq"), - mk_aotensor(dk_s.sanitized_tensor(), "dk"), - mk_aotensor(dv_s.sanitized_tensor(), "dv"), - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, - is_causal, - stream); +#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \ + do { \ + dim3 pp_grid; \ + pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \ + pp_grid.y = 1; \ + pp_grid.z = 1; \ + oort::bwd_preprocess pre_opt; \ + err = pre_opt(pp_grid, block, \ + (FP*)(out_t.data_ptr()), \ + (FP*)(dout_t.data_ptr()), \ + (FP*)(new_do.data_ptr()), \ + (float*)(delta.data_ptr()), \ + stream); \ + } while (0) + +#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \ + do { \ + if (d_head == 16) \ + CALL_BWD_PP(FP, PP_BLOCK, 16); \ + else if (d_head == 32) \ + CALL_BWD_PP(FP, PP_BLOCK, 32); \ + else if (d_head == 64) \ + CALL_BWD_PP(FP, PP_BLOCK, 64); \ + else if (d_head == 128) \ + CALL_BWD_PP(FP, PP_BLOCK, 128); \ + } while (0) + + if(q_dtype == at::kHalf) { + if (seqlen_q >= 64) + CALL_BWD_PP_DMODEL(__fp16, 16); + else + CALL_BWD_PP_DMODEL(__fp16, 16); + } else if (q_dtype == at::kBFloat16) { + if (seqlen_q >= 64) + CALL_BWD_PP_DMODEL(__bf16, 16); + else + CALL_BWD_PP_DMODEL(__bf16, 16); } +#undef CALL_BWD_PP + +#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \ + do { \ + dim3 grid; \ + grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \ + grid.y = batch_size * num_heads; \ + grid.z = 1; \ + oort::bwd_kernel_dk_dv dk_dv_opt; \ + err = dk_dv_opt(grid, block, \ + (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ + softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \ + (FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \ + (float*)(softmax_lse.data_ptr()), \ + (float*)(delta.data_ptr()), \ + q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ + k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ + v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ + q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ + (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \ + grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \ + oort::bwd_kernel_dq dq_opt; \ + err = dq_opt(grid, block, \ + (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ + softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \ + (FP*)dq_t.data_ptr(), \ + (float*)(softmax_lse.data_ptr()), \ + (float*)(delta.data_ptr()), \ + q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ + k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ + v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ + q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ + (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \ + } while(0) + +#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \ + do { \ + if (p_dropout > 0.0) { \ + CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \ + } else { \ + CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \ + } \ + } while (0) + +#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \ + do { \ + if (d_head == 16) \ + CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \ + else if (d_head == 32) \ + CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \ + else if (d_head == 64) \ + CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \ + else if (d_head == 128) \ + CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \ + } while (0) + + if (q_dtype == at::kHalf) { + if (is_causal) { + CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true); + } else { + CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false); + } + } else if (q_dtype == at::kBFloat16) { + if (is_causal) { + CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true); + } else { + CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false); + } + } + + //undo reorder tensors for returns + dq = dq_t.permute({0,2,1,3}).contiguous(); + dk = dk_t.permute({0,2,1,3}).contiguous(); + dv = dv_t.permute({0,2,1,3}).contiguous(); // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e64243d45ed29..93ec759b4972f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -985,8 +985,7 @@ if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA}) hip_add_library(torch_hip ${Caffe2_HIP_SRCS}) if(USE_FLASH_ATTENTION) - target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) - add_dependencies(torch_hip aotriton_external) + target_link_libraries(torch_hip PRIVATE __caffe2_oort) endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a96075245aedf..892bad591887c 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1335,7 +1335,9 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) + if(USE_FLASH_ATTENTION) + include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake) + endif() if(USE_CUDA) caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake deleted file mode 100644 index ca9725451049c..0000000000000 --- a/cmake/External/aotriton.cmake +++ /dev/null @@ -1,28 +0,0 @@ -if(NOT __AOTRITON_INCLUDED) - set(__AOTRITON_INCLUDED TRUE) - - set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src") - set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build") - set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") - ExternalProject_Add(aotriton_external - GIT_REPOSITORY https://github.com/ROCm/aotriton.git - GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542 - SOURCE_DIR ${__AOTRITON_SOURCE_DIR} - BINARY_DIR ${__AOTRITON_BUILD_DIR} - PREFIX ${__AOTRITON_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DAOTRITON_COMPRESS_KERNEL=OFF - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NO_SHARED=ON - # CONFIGURE_COMMAND "" - # BUILD_COMMAND ${MAKE_COMMAND} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - set(AOTRITON_FOUND TRUE) - add_library(__caffe2_aotriton INTERFACE) - add_dependencies(__caffe2_aotriton aotriton_external) - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a) - target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) -endif() # __AOTRITON_INCLUDED diff --git a/cmake/External/oort.cmake b/cmake/External/oort.cmake new file mode 100644 index 0000000000000..29c9a1005a7fb --- /dev/null +++ b/cmake/External/oort.cmake @@ -0,0 +1,25 @@ +if(NOT __OORT_INCLUDED) + set(__OORT_INCLUDED TRUE) + + set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src") + set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build") + set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") + ExternalProject_Add(oort_external + GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git + GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba + SOURCE_DIR ${__OORT_SOURCE_DIR} + SOURCE_SUBDIR mathaot + BINARY_DIR ${__OORT_BUILD_DIR} + PREFIX ${__OORT_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR} + # CONFIGURE_COMMAND "" + # BUILD_COMMAND ${MAKE_COMMAND} + BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a" + # INSTALL_COMMAND ${MAKE_COMMAND} install + ) + set(OORT_FOUND TRUE) + add_library(__caffe2_oort INTERFACE) + add_dependencies(__caffe2_oort oort_external) + target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a) + target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include) +endif() # __OORT_INCLUDED diff --git a/test/test_transformers.py b/test/test_transformers.py index c60c32c15302c..8b5f4cdffb75e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2396,7 +2396,7 @@ def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bo # Cast up and compare self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) - @skipIfRocm # TODO: Packed QKV + @skipIfRocm # Small matrices @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) @parametrize("is_causal", [True, False]) @@ -2798,6 +2798,16 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, scale: str): + if TEST_WITH_ROCM: + def is_power_of_2(n): + return n & (n - 1) == 0 + if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim): + self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.") + if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16: + self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.") + if head_dim > 128: + self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.") + if isSM8XDevice and head_dim in range(193, 256 + 1): self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: @@ -3418,7 +3428,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) - @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 2a9055597f732..13abf02f1e603 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -31,19 +31,18 @@ SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)) SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)) -def evaluate_gfx_arch_exact(matching_arch): +def evaluate_gfx90a_exact(): if not torch.cuda.is_available(): return False gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) - return arch == matching_arch + return arch == 'gfx90a:sramecc+:xnack-' -GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-')) -GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')) +GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact()) def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') + return evaluate_gfx90a_exact() if TEST_CUDA: return not IS_WINDOWS and SM80OrLater return False