Revert "Add Flash Attention support on ROCM (pytorch#121561)"
This reverts commit a37e22d.

Reverted pytorch#121561 on behalf of https://github.com/huydhn: Sorry for reverting your change, but this needs more work before it can land in fbcode, because https://github.com/ROCm/aotriton is not available there at the moment. We are working to reland this change before the 2.3 release ([comment](pytorch#121561 (comment)))
pytorchmergebot committed Mar 19, 2024
1 parent 88ebdbc commit 764eae9
Showing 9 changed files with 326 additions and 264 deletions.
23 changes: 5 additions & 18 deletions CMakeLists.txt
@@ -742,28 +742,13 @@ if(MSVC)
append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
endif()

# Note for ROCM platform:
# 1. USE_ROCM is always ON until include(cmake/Dependencies.cmake)
# 2. USE_CUDA will become OFF during re-configuration
# Truth Table:
# CUDA 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# CUDA 2nd pass: USE_CUDA=True;USE_ROCM=False, FLASH evaluates to ON by default
# ROCM 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# ROCM 2nd pass: USE_CUDA=False;USE_ROCM=True, FLASH evaluates to ON by default
# CPU 1st pass: USE_CUDA=False(Cmd Option);USE_ROCM=True, FLASH evaluates to ON by default
# CPU 2nd pass: USE_CUDA=False(Cmd Option);USE_ROCM=False, FLASH evaluates to OFF by default
# Thus we cannot tell ROCM 2nd pass and CPU 1st pass apart
#
# The only solution is to include(cmake/Dependencies.cmake), and defer the
# aotriton build decision until later.

include(cmake/Dependencies.cmake)

# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
# include(cmake/Dependencies.cmake)
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM;NOT MSVC" OFF)
"USE_CUDA AND NOT MSVC" OFF)

# We are currently not using alibi attention for Flash
# So we disable this feature by default
@@ -779,6 +764,8 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA" OFF)

include(cmake/Dependencies.cmake)

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
36 changes: 23 additions & 13 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -21,10 +21,6 @@
#include <cmath>
#include <functional>

#if USE_ROCM
#include <aotriton/flash.h>
#endif

/**
* Note [SDPA Runtime Dispatch]
* SDPA relies on a runtime dispatch mechanism to select the appropriate
@@ -186,18 +182,32 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
static const char *over_arch = [] {
auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (rc) {
TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect SDPA function's behavior.");
}
return rc;
}();
const char* real_arch = dprops->gcnArchName;
const char* arch = over_arch ? over_arch : real_arch;
if (mi200 != arch) {
if (debug) {
TORCH_WARN(
"Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
arch,
".",
over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
over_arch ? real_arch : "");
}
return false;
}
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
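The restored #if USE_ROCM branch above leans on a small C++ idiom: the architecture override is read from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE exactly once, through a lambda that initializes a function-local static. Below is a minimal standalone sketch of that pattern (illustration only, not PyTorch code; the helper name cached_arch_override is invented here):

#include <cstdio>
#include <cstdlib>

// Sketch of the "read an environment variable once" pattern used for
// over_arch in the diff above. The lambda runs exactly once, when the
// function-local static is initialized; later changes to the variable in
// the same process are deliberately ignored.
static const char* cached_arch_override() {
  static const char* value = [] {
    const char* rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
    if (rc != nullptr) {
      std::fprintf(stderr, "override read once: %s\n", rc);
    }
    return rc;
  }();
  return value;
}

int main() {
  const char* arch = cached_arch_override();
  std::printf("gcn arch override: %s\n", arch != nullptr ? arch : "(not set)");
  return 0;
}

Because the static is initialized on first use and never re-evaluated, setting the variable later in the same process has no effect, which is exactly what the TORCH_WARN message in the restored code warns about.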
