Skip to content

Commit

Permalink
abort if destroy fails
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 4, 2024
1 parent 8e70b0c commit e510c5e
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 26 deletions.
8 changes: 0 additions & 8 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// cufftResult cufft_rt = cufftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftResult cufft_rt = cufftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -78,8 +76,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// cufftResult cufft_rt = cufftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftResult cufft_rt = cufftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -121,8 +117,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// cufftResult cufft_rt = cufftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftResult cufft_rt = cufftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -186,8 +180,6 @@ auto create_plan(const ExecutionSpace& exec_space,
int istride = 1, ostride = 1;

plan = std::make_unique<PlanType>();
// cufftResult cufft_rt = cufftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftResult cufft_rt = cufftSetStream((*plan).plan(), stream);
Expand Down
14 changes: 12 additions & 2 deletions fft/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#define KOKKOSFFT_CUDA_TYPES_HPP

#include <cufft.h>
#include <Kokkos_Abort.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_asserts.hpp"

#if defined(ENABLE_HOST_AND_DEVICE)
#include "KokkosFFT_FFTW_Types.hpp"
Expand All @@ -25,10 +27,18 @@ using FFTDirectionType = int;

/// \brief A class that wraps cufft for RAII
struct ScopedCufftPlanType {
private:
cufftHandle m_plan;

ScopedCufftPlanType() { cufftCreate(&m_plan); }
~ScopedCufftPlanType() { cufftDestroy(m_plan); }
public:
ScopedCufftPlanType() {
cufftResult cufft_rt = cufftCreate(&m_plan);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");
}
~ScopedCufftPlanType() {
cufftResult cufft_rt = cufftDestroy(m_plan);
if (cufft_rt != CUFFT_SUCCESS) Kokkos::abort("cufftDestroy failed");
}

cufftHandle &plan() { return m_plan; }
};
Expand Down
8 changes: 0 additions & 8 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// hipfftResult hipfft_rt = hipfftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -78,8 +76,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// hipfftResult hipfft_rt = hipfftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -121,8 +117,6 @@ auto create_plan(const ExecutionSpace& exec_space,
using out_value_type = typename OutViewType::non_const_value_type;

plan = std::make_unique<PlanType>();
// hipfftResult hipfft_rt = hipfftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream((*plan).plan(), stream);
Expand Down Expand Up @@ -186,8 +180,6 @@ auto create_plan(const ExecutionSpace& exec_space,
int istride = 1, ostride = 1;

plan = std::make_unique<PlanType>();
// hipfftResult hipfft_rt = hipfftCreate(&(*plan));
// KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream((*plan).plan(), stream);
Expand Down
14 changes: 12 additions & 2 deletions fft/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#define KOKKOSFFT_HIP_TYPES_HPP

#include <hipfft/hipfft.h>
#include <Kokkos_Abort.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_asserts.hpp"

#if defined(ENABLE_HOST_AND_DEVICE)
#include "KokkosFFT_FFTW_Types.hpp"
Expand All @@ -25,10 +27,18 @@ using FFTDirectionType = int;

/// \brief A class that wraps hipfft for RAII
struct ScopedHIPfftPlanType {
private:
hipfftHandle m_plan;

ScopedHIPfftPlanType() { hipfftCreate(&m_plan); }
~ScopedHIPfftPlanType() { hipfftDestroy(m_plan); }
public:
ScopedHIPfftPlanType() {
hipfftResult hipfft_rt = hipfftCreate(&m_plan);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftCreate failed");
}
~ScopedHIPfftPlanType() {
hipfftResult hipfft_rt = hipfftDestroy(m_plan);
if (hipfft_rt != HIPFFT_SUCCESS) Kokkos::abort("hipfftDestroy failed");
}

hipfftHandle &plan() { return m_plan; }
};
Expand Down
6 changes: 3 additions & 3 deletions fft/src/KokkosFFT_ROCM_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ auto create_plan(const ExecutionSpace& exec_space,
);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");
plan->m_is_plan_created = true;
plan->set_is_plan_created();

// Prepare workbuffer and set execution information
status = rocfft_execution_info_create(&(plan->execution_info()));
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");
plan->m_is_info_created = true;
plan->set_is_info_created();

// set stream
// NOTE: The stream must be of type hipStream_t.
Expand All @@ -194,7 +194,7 @@ auto create_plan(const ExecutionSpace& exec_space,
if (workbuffersize > 0) {
plan->allocate_work_buffer(workbuffersize);
status = rocfft_execution_info_set_work_buffer(
plan->execution_info(), (void*)plan->m_buffer.data(), workbuffersize);
plan->execution_info(), (void*)plan->buffer_data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
Expand Down
22 changes: 19 additions & 3 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

#include <complex>
#include <rocfft/rocfft.h>
#include <Kokkos_Abort.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_asserts.hpp"
#if defined(ENABLE_HOST_AND_DEVICE)
#include "KokkosFFT_FFTW_Types.hpp"
#endif
Expand Down Expand Up @@ -37,30 +39,44 @@ using TransformType = FFTWTransformType;
/// \brief A class that wraps rocfft for RAII
template <typename ExecutionSpace, typename T>
struct ScopedRocfftPlanType {
private:
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
rocfft_plan m_plan;
rocfft_execution_info m_execution_info;

using BufferViewType =
Kokkos::View<Kokkos::complex<floating_point_type> *, ExecutionSpace>;

bool m_is_info_created = false;
bool m_is_plan_created = false;
bool m_is_info_created = false;

//! Internal work buffer
BufferViewType m_buffer;

public:
ScopedRocfftPlanType() {}
~ScopedRocfftPlanType() {
if (m_is_info_created) rocfft_execution_info_destroy(m_execution_info);
if (m_is_plan_created) rocfft_plan_destroy(m_plan);
if (m_is_info_created) {
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
}
if (m_is_plan_created) {
rocfft_status status = rocfft_plan_destroy(m_plan);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_destroy failed");
}
}

void set_is_plan_created() { m_is_plan_created = true; }
void set_is_info_created() { m_is_info_created = true; }

void allocate_work_buffer(std::size_t workbuffersize) {
m_buffer = BufferViewType("work buffer", workbuffersize);
}
rocfft_plan &plan() { return m_plan; }
rocfft_execution_info &execution_info() { return m_execution_info; }
auto *buffer_data() { return m_buffer.data(); }
};

// Define fft transform types
Expand Down

0 comments on commit e510c5e

Please sign in to comment.