Skip to content

Commit

Permalink
update the rcom to rocm
Browse files Browse the repository at this point in the history
  • Loading branch information
A-006 committed Nov 25, 2024
1 parent 84df4d2 commit 4e12e56
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ if (USE_CUDA)
endif()
if (USE_ROCM)
list (APPEND FFT_SRC
module_fft/fft_rcom.cpp
module_fft/fft_rocm.cpp
)
endif()

Expand Down
6 changes: 3 additions & 3 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#include "fft_rocm.h"
#endif

template<typename FFT_BASE, typename... Args>
Expand Down Expand Up @@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in,
if (device=="gpu")
{
#if defined(__ROCM)
fft_float = make_unique<FFT_RCOM<float>>();
fft_float = make_unique<FFT_ROCM<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = make_unique<FFT_RCOM<double>>();
fft_double = make_unique<FFT_ROCM<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#elif defined(__CUDA)
fft_float = make_unique<FFT_CUDA<float>>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "fft_rcom.h"
#include "fft_rocm.h"
#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
namespace ModulePW
{
template <typename FPTYPE>
void FFT_RCOM<FPTYPE>::initfft(int nx_in,
void FFT_ROCM<FPTYPE>::initfft(int nx_in,
int ny_in,
int nz_in)
{
Expand All @@ -13,20 +13,20 @@ void FFT_RCOM<FPTYPE>::initfft(int nx_in,
this->nz = nz_in;
}
template <>
void FFT_RCOM<float>::setupFFT()
void FFT_ROCM<float>::setupFFT()
{
hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C);
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);

}
template <>
void FFT_RCOM<double>::setupFFT()
void FFT_ROCM<double>::setupFFT()
{
hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z);
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
template <>
void FFT_RCOM<float>::cleanFFT()
void FFT_ROCM<float>::cleanFFT()
{
if (c_handle)
{
Expand All @@ -35,7 +35,7 @@ void FFT_RCOM<float>::cleanFFT()
}
}
template <>
void FFT_RCOM<double>::cleanFFT()
void FFT_ROCM<double>::cleanFFT()
{
if (z_handle)
{
Expand All @@ -44,7 +44,7 @@ void FFT_RCOM<double>::cleanFFT()
}
}
template <>
void FFT_RCOM<float>::clear()
void FFT_ROCM<float>::clear()
{
this->cleanFFT();
if (c_auxr_3d != nullptr)
Expand All @@ -54,7 +54,7 @@ void FFT_RCOM<float>::clear()
}
}
template <>
void FFT_RCOM<double>::clear()
void FFT_ROCM<double>::clear()
{
this->cleanFFT();
if (z_auxr_3d != nullptr)
Expand All @@ -64,7 +64,7 @@ void FFT_RCOM<double>::clear()
}
}
template <>
void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
void FFT_ROCM<float>::fft3D_forward(std::complex<float>* in,
std::complex<float>* out) const
{
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
Expand All @@ -73,7 +73,7 @@ void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
HIPFFT_FORWARD));
}
template <>
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
void FFT_ROCM<double>::fft3D_forward(std::complex<double>* in,
std::complex<double>* out) const
{
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
Expand All @@ -82,7 +82,7 @@ void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
HIPFFT_FORWARD));
}
template <>
void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
void FFT_ROCM<float>::fft3D_backward(std::complex<float>* in,
std::complex<float>* out) const
{
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
Expand All @@ -91,7 +91,7 @@ void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
HIPFFT_BACKWARD));
}
template <>
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
void FFT_ROCM<double>::fft3D_backward(std::complex<double>* in,
std::complex<double>* out) const
{
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
Expand All @@ -100,7 +100,11 @@ void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
HIPFFT_BACKWARD));
}
template <> std::complex<float>*
FFT_RCOM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
FFT_ROCM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
template <> std::complex<double>*
FFT_RCOM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
FFT_ROCM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
template FFT_ROCM<float>::FFT_ROCM();
template FFT_ROCM<float>::~FFT_ROCM();
template FFT_ROCM<double>::FFT_ROCM();
template FFT_ROCM<double>::~FFT_ROCM();
}// namespace ModulePW
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
namespace ModulePW
{
template <typename FPTYPE>
class FFT_RCOM : public FFT_BASE<FPTYPE>
class FFT_ROCM : public FFT_BASE<FPTYPE>
{
public:
FFT_RCOM(){};
~FFT_RCOM(){};
FFT_ROCM(){};
~FFT_ROCM(){};

void setupFFT() override;

Expand Down Expand Up @@ -57,9 +57,5 @@ class FFT_RCOM : public FFT_BASE<FPTYPE>
mutable std::complex<double>* z_auxr_3d = nullptr; // fft space

};
template FFT_RCOM<float>::FFT_RCOM();
template FFT_RCOM<float>::~FFT_RCOM();
template FFT_RCOM<double>::FFT_RCOM();
template FFT_RCOM<double>::~FFT_RCOM();
}// namespace ModulePW
#endif

0 comments on commit 4e12e56

Please sign in to comment.