From 6e6855f84c1d5d219febe662327f240ab9485f76 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Wed, 20 Nov 2024 10:10:09 +0800 Subject: [PATCH 01/11] delete fft.cpp --- source/Makefile.Objects | 3 +- source/module_basis/module_pw/CMakeLists.txt | 1 - source/module_basis/module_pw/fft.cpp | 881 ------------------ source/module_basis/module_pw/fft.h | 153 +-- source/module_basis/module_pw/test/Makefile | 1 - .../module_pw/test_serial/CMakeLists.txt | 1 - .../test/charge_extra_test.cpp | 6 - .../test/elecstate_base_test.cpp | 6 - .../module_xc/test/xc3_mock.h | 2 - source/module_hsolver/test/hsolver_pw_sup.h | 3 - 10 files changed, 2 insertions(+), 1055 deletions(-) delete mode 100644 source/module_basis/module_pw/fft.cpp diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 3f76995e7a..62c95ac83b 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -407,8 +407,7 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao.o\ psi_initializer_nao_random.o\ -OBJS_PW=fft.o\ - fft_bundle.o\ +OBJS_PW=fft_bundle.o\ fft_base.o\ fft_cpu.o\ pw_basis.o\ diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 78f9824d8b..ee5154f2c7 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -15,7 +15,6 @@ if (USE_ROCM) endif() list(APPEND objects - fft.cpp pw_basis.cpp pw_basis_k.cpp pw_basis_sup.cpp diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp deleted file mode 100644 index fa94bd6442..0000000000 --- a/source/module_basis/module_pw/fft.cpp +++ /dev/null @@ -1,881 +0,0 @@ -#include "fft.h" - -#include "module_base/memory.h" -#include "module_base/tool_quit.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" - -namespace ModulePW -{ - -FFT::FFT() -{ -} - -FFT::~FFT() -{ - this->clear(); -} -void FFT::clear() -{ - this->cleanFFT(); - if (z_auxg != nullptr) - { - fftw_free(z_auxg); - z_auxg = nullptr; - } - if (z_auxr != nullptr) - { - fftw_free(z_auxr); - z_auxr = nullptr; - } - d_rspace = nullptr; -#if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") - { - if (c_auxr_3d != nullptr) - { - delmem_cd_op()(gpu_ctx, c_auxr_3d); - c_auxr_3d = nullptr; - } - if (z_auxr_3d != nullptr) - { - delmem_zd_op()(gpu_ctx, z_auxr_3d); - z_auxr_3d = nullptr; - } - } -#endif // defined(__CUDA) || defined(__ROCM) -#if defined(__ENABLE_FLOAT_FFTW) - if (this->precision == "single") - { - this->cleanfFFT(); - if (c_auxg != nullptr) - { - fftw_free(c_auxg); - c_auxg = nullptr; - } - if (c_auxr != nullptr) - { - fftw_free(c_auxr); - c_auxr = nullptr; - } - s_rspace = nullptr; - } -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, int nproc_in, - bool gamma_only_in, bool xprime_in, bool mpifft_in) -{ - this->gamma_only = gamma_only_in; - this->xprime = xprime_in; - this->fftnx = this->nx = nx_in; - this->fftny = this->ny = ny_in; - if (this->gamma_only) - { - if (xprime) { - this->fftnx = int(nx / 2) + 1; - } else { - this->fftny = int(ny / 2) + 1; -} - } - this->nz = nz_in; - this->ns = ns_in; - this->lixy = lixy_in; - this->rixy = rixy_in; - this->nplane = nplane_in; - this->nproc = nproc_in; - this->mpifft = mpifft_in; - this->nxy = this->nx * this->ny; - this->fftnxy = this->fftnx * this->fftny; - // this->maxgrids = (this->nz * this->ns > this->nxy * nplane) ? this->nz * this->ns : this->nxy * nplane; - const int nrxx = this->nxy * this->nplane; - const int nsz = this->nz * this->ns; - int maxgrids = (nsz > nrxx) ? nsz : nrxx; - if (!this->mpifft) - { - // z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - // z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - // ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); - // d_rspace = (double*)z_auxg; - // auxr_3d = static_cast *>( - // fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz))); -#if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") - { - resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); - resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); - } -#endif // defined(__CUDA) || defined(__ROCM) -// #if defined(__ENABLE_FLOAT_FFTW) -// if (this->precision == "single") -// { -// c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); -// c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); -// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); -// s_rspace = (float*)c_auxg; -// } -// #endif // defined(__ENABLE_FLOAT_FFTW) - } - else - { - } -} - -void FFT::setupFFT() -{ - unsigned int flag = FFTW_ESTIMATE; - switch (this->fft_mode) - { - case 0: - flag = FFTW_ESTIMATE; - break; - case 1: - flag = FFTW_MEASURE; - break; - case 2: - flag = FFTW_PATIENT; - break; - case 3: - flag = FFTW_EXHAUSTIVE; - break; - default: - break; - } - if (!this->mpifft) - { - this->initplan(flag); -#if defined(__ENABLE_FLOAT_FFTW) - if (this->precision == "single") - { - this->initplanf(flag); - } -#endif // defined(__ENABLE_FLOAT_FFTW) - } -#if defined(__FFTW3_MPI) && defined(__MPI) - else - { - // this->initplan_mpi(); - // if (this->precision == "single") { - // this->initplanf_mpi(); - // } - } -#endif - return; -} - -void FFT ::initplan(const unsigned int& flag) -{ - //--------------------------------------------------------- - // 1 D - Z - //--------------------------------------------------------- - - // fftw_plan_many_dft(int rank, const int *n, int howmany, - // fftw_complex *in, const int *inembed, int istride, int idist, - // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned - //flags); - - this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); - - this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); - - //--------------------------------------------------------- - // 2 D - XY - //--------------------------------------------------------- - // 1D+1D is much faster than 2D FFT! - // in-place fft is better for c2c and out-of-place fft is better for c2r - int* embed = nullptr; - int npy = this->nplane * this->ny; - if (this->xprime) - { - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, nplane, 1, - (fftw_complex*)z_auxr, embed, nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, nplane, 1, - (fftw_complex*)z_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, - embed, npy, 1, flag); - this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, - embed, npy, 1, flag); - } - else - { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - } - } - else - { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, - (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); - this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, - this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); - } - else - { - - this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (ny - rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (ny - rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); - } - } - - //--------------------------------------------------------- - // 3 D - XYZ - //--------------------------------------------------------- - // in-place fft test - // this->plan3dforward = fftw_plan_dft_3d( - // this->nx, this->ny, this->nz, - // reinterpret_cast(auxr_3d), - // reinterpret_cast(auxr_3d), - // FFTW_FORWARD, flag); - // this->plan3dbackward = fftw_plan_dft_3d( - // this->nx, this->ny, this->nz, - // reinterpret_cast(auxr_3d), - // reinterpret_cast(auxr_3d), - // FFTW_BACKWARD, flag); - -#if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") - { -#if defined(__CUDA) - cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C); - cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z); -#elif defined(__ROCM) - hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C); - hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z); -#endif - } -#endif -} - -#if defined(__ENABLE_FLOAT_FFTW) -void FFT ::initplanf(const unsigned int& flag) -{ - //--------------------------------------------------------- - // 1 D - //--------------------------------------------------------- - - // fftw_plan_many_dft(int rank, const int *n, int howmany, - // fftw_complex *in, const int *inembed, int istride, int idist, - // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned - //flags); - - this->planfzfor = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); - - this->planfzbac = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); - //--------------------------------------------------------- - // 2 D - //--------------------------------------------------------- - - int* embed = nullptr; - int npy = this->nplane * this->ny; - if (this->xprime) - { - this->planfyfor = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_FORWARD, flag); - this->planfybac = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planfxr2c = fftwf_plan_many_dft_r2c(1, &this->nx, npy, s_rspace, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, flag); - this->planfxc2r = fftwf_plan_many_dft_c2r(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - s_rspace, embed, npy, 1, flag); - } - else - { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - } - } - else - { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planfyr2c = fftwf_plan_many_dft_r2c(1, &this->ny, this->nplane, s_rspace, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, flag); - this->planfyc2r = fftwf_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, - this->nplane, 1, s_rspace, embed, this->nplane, 1, flag); - } - else - { - this->planfxfor2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planfyfor - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planfybac - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); - } - } -} -#endif // defined(__ENABLE_FLOAT_FFTW) -// void FFT :: initplan_mpi() -// { - -// } - -// void FFT :: initplanf_mpi() -// { - -// } - -void FFT::cleanFFT() -{ - if (planzfor) - { - fftw_destroy_plan(planzfor); - planzfor = nullptr; - } - if (planzbac) - { - fftw_destroy_plan(planzbac); - planzbac = nullptr; - } - if (planxfor1) - { - fftw_destroy_plan(planxfor1); - planxfor1 = nullptr; - } - if (planxbac1) - { - fftw_destroy_plan(planxbac1); - planxbac1 = nullptr; - } - if (planxfor2) - { - fftw_destroy_plan(planxfor2); - planxfor2 = nullptr; - } - if (planxbac2) - { - fftw_destroy_plan(planxbac2); - planxbac2 = nullptr; - } - if (planyfor) - { - fftw_destroy_plan(planyfor); - planyfor = nullptr; - } - if (planybac) - { - fftw_destroy_plan(planybac); - planybac = nullptr; - } - if (planxr2c) - { - fftw_destroy_plan(planxr2c); - planxr2c = nullptr; - } - if (planxc2r) - { - fftw_destroy_plan(planxc2r); - planxc2r = nullptr; - } - if (planyr2c) - { - fftw_destroy_plan(planyr2c); - planyr2c = nullptr; - } - if (planyc2r) - { - fftw_destroy_plan(planyc2r); - planyc2r = nullptr; - } - - // fftw_destroy_plan(this->plan3dforward); - // fftw_destroy_plan(this->plan3dbackward); -#if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") - { -#if defined(__CUDA) - if (c_handle) - { - cufftDestroy(c_handle); - c_handle = {}; - } - if (z_handle) - { - cufftDestroy(z_handle); - z_handle = {}; - } -#elif defined(__ROCM) - if (c_handle) - { - hipfftDestroy(c_handle); - c_handle = {}; - } - if (z_handle) - { - hipfftDestroy(z_handle); - z_handle = {}; - } -#endif - } -#endif -} - -#if defined(__ENABLE_FLOAT_FFTW) -void FFT::cleanfFFT() -{ - if (planfzfor) - { - fftwf_destroy_plan(planfzfor); - planfzfor = NULL; - } - if (planfzbac) - { - fftwf_destroy_plan(planfzbac); - planfzbac = NULL; - } - if (planfxfor1) - { - fftwf_destroy_plan(planfxfor1); - planfxfor1 = NULL; - } - if (planfxbac1) - { - fftwf_destroy_plan(planfxbac1); - planfxbac1 = NULL; - } - if (planfxfor2) - { - fftwf_destroy_plan(planfxfor2); - planfxfor2 = NULL; - } - if (planfxbac2) - { - fftwf_destroy_plan(planfxbac2); - planfxbac2 = NULL; - } - if (planfyfor) - { - fftwf_destroy_plan(planfyfor); - planfyfor = NULL; - } - if (planfybac) - { - fftwf_destroy_plan(planfybac); - planfybac = NULL; - } - if (planfxr2c) - { - fftwf_destroy_plan(planfxr2c); - planfxr2c = NULL; - } - if (planfxc2r) - { - fftwf_destroy_plan(planfxc2r); - planfxc2r = NULL; - } - if (planfyr2c) - { - fftwf_destroy_plan(planfyr2c); - planfyr2c = NULL; - } - if (planfyc2r) - { - fftwf_destroy_plan(planfyc2r); - planfyc2r = NULL; - } - return; -} -#endif // defined(__ENABLE_FLOAT_FFTW) - -template <> -void FFT::fftzfor(std::complex* in, std::complex* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - fftwf_execute_dft(this->planfzfor, (fftwf_complex*)in, (fftwf_complex*)out); -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftzfor(std::complex* in, std::complex* out) const -{ - fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); -} - -template <> -void FFT::fftzbac(std::complex* in, std::complex* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - fftwf_execute_dft(this->planfzbac, (fftwf_complex*)in, (fftwf_complex*)out); -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftzbac(std::complex* in, std::complex* out) const -{ - fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); -} - -template <> -void FFT::fftxyfor(std::complex* in, std::complex* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - int npy = this->nplane * this->ny; - if (this->xprime) - { - fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); - - for (int i = 0; i < this->lixy + 1; ++i) - { - fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - for (int i = rixy; i < this->nx; ++i) - { - fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - } - else - { - for (int i = 0; i < this->nx; ++i) - { - fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - - fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); - fftwf_execute_dft(this->planfxfor2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); - } -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftxyfor(std::complex* in, std::complex* out) const -{ - int npy = this->nplane * this->ny; - if (this->xprime) - { - fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); - - for (int i = 0; i < this->lixy + 1; ++i) - { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - for (int i = rixy; i < this->nx; ++i) - { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - } - else - { - for (int i = 0; i < this->nx; ++i) - { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - - fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); - fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); - } -} - -template <> -void FFT::fftxybac(std::complex* in, std::complex* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - int npy = this->nplane * this->ny; - if (this->xprime) - { - for (int i = 0; i < this->lixy + 1; ++i) - { - fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - for (int i = rixy; i < this->nx; ++i) - { - fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - - fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); - } - else - { - fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); - fftwf_execute_dft(this->planfxbac2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); - - for (int i = 0; i < this->nx; ++i) - { - fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); - } - } -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftxybac(std::complex* in, std::complex* out) const -{ - int npy = this->nplane * this->ny; - if (this->xprime) - { - for (int i = 0; i < this->lixy + 1; ++i) - { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - for (int i = rixy; i < this->nx; ++i) - { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); - } - else - { - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); - fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); - - for (int i = 0; i < this->nx; ++i) - { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); - } - } -} - -template <> -void FFT::fftxyr2c(float* in, std::complex* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - int npy = this->nplane * this->ny; - if (this->xprime) - { - fftwf_execute_dft_r2c(this->planfxr2c, in, (fftwf_complex*)out); - - for (int i = 0; i < this->lixy + 1; ++i) - { - fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&out[i * npy], (fftwf_complex*)&out[i * npy]); - } - } - else - { - for (int i = 0; i < this->nx; ++i) - { - fftwf_execute_dft_r2c(this->planfyr2c, &in[i * npy], (fftwf_complex*)&out[i * npy]); - } - - fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)out, (fftwf_complex*)out); - } -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftxyr2c(double* in, std::complex* out) const -{ - int npy = this->nplane * this->ny; - if (this->xprime) - { - fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); - - for (int i = 0; i < this->lixy + 1; ++i) - { - fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); - } - } - else - { - for (int i = 0; i < this->nx; ++i) - { - fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); - } - - fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); - } -} - -template <> -void FFT::fftxyc2r(std::complex* in, float* out) const -{ -#if defined(__ENABLE_FLOAT_FFTW) - int npy = this->nplane * this->ny; - if (this->xprime) - { - for (int i = 0; i < this->lixy + 1; ++i) - { - fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&in[i * npy]); - } - - fftwf_execute_dft_c2r(this->planfxc2r, (fftwf_complex*)in, out); - } - else - { - fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)in); - - for (int i = 0; i < this->nx; ++i) - { - fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)&in[i * npy], &out[i * npy]); - } - } -#else - ModuleBase::WARNING_QUIT("fft", "Please compile ABACUS using the ENABLE_FLOAT_FFTW flag!"); -#endif // defined(__ENABLE_FLOAT_FFTW) -} - -template <> -void FFT::fftxyc2r(std::complex* in, double* out) const -{ - int npy = this->nplane * this->ny; - if (this->xprime) - { - for (int i = 0; i < this->lixy + 1; ++i) - { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); - } - - fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); - } - else - { - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); - - for (int i = 0; i < this->nx; ++i) - { - fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); - } - } -} - -#if defined(__CUDA) || defined(__ROCM) -template <> -void FFT::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, std::complex* out) const -{ -#if defined(__CUDA) - CHECK_CUFFT(cufftExecC2C(this->c_handle, reinterpret_cast(in), reinterpret_cast(out), - CUFFT_FORWARD)); -#elif defined(__ROCM) - CHECK_CUFFT(hipfftExecC2C(this->c_handle, reinterpret_cast(in), - reinterpret_cast(out), HIPFFT_FORWARD)); -#endif -} -template <> -void FFT::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, - std::complex* out) const -{ -#if defined(__CUDA) - CHECK_CUFFT(cufftExecZ2Z(this->z_handle, reinterpret_cast(in), - reinterpret_cast(out), CUFFT_FORWARD)); -#elif defined(__ROCM) - CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, reinterpret_cast(in), - reinterpret_cast(out), HIPFFT_FORWARD)); -#endif -} - -template <> -void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, - std::complex* out) const -{ -#if defined(__CUDA) - CHECK_CUFFT(cufftExecC2C(this->c_handle, reinterpret_cast(in), reinterpret_cast(out), - CUFFT_INVERSE)); -#elif defined(__ROCM) - CHECK_CUFFT(hipfftExecC2C(this->c_handle, reinterpret_cast(in), - reinterpret_cast(out), HIPFFT_BACKWARD)); -#endif -} -template <> -void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, - std::complex* out) const -{ -#if defined(__CUDA) - CHECK_CUFFT(cufftExecZ2Z(this->z_handle, reinterpret_cast(in), - reinterpret_cast(out), CUFFT_INVERSE)); -#elif defined(__ROCM) - CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, reinterpret_cast(in), - reinterpret_cast(out), HIPFFT_BACKWARD)); -#endif -} -#endif - -template <> -float* FFT::get_rspace_data() const -{ - return this->s_rspace; -} -template <> -double* FFT::get_rspace_data() const -{ - return this->d_rspace; -} - -template <> -std::complex* FFT::get_auxr_data() const -{ - return this->c_auxr; -} -template <> -std::complex* FFT::get_auxr_data() const -{ - return this->z_auxr; -} - -template <> -std::complex* FFT::get_auxg_data() const -{ - return this->c_auxg; -} -template <> -std::complex* FFT::get_auxg_data() const -{ - return this->z_auxg; -} - -#if defined(__CUDA) || defined(__ROCM) -template <> -std::complex* FFT::get_auxr_3d_data() const -{ - return this->c_auxr_3d; -} -template <> -std::complex* FFT::get_auxr_3d_data() const -{ - return this->z_auxr_3d; -} -#endif - -void FFT::set_device(std::string device_) -{ - this->device = std::move(device_); -} - -void FFT::set_precision(std::string precision_) -{ - this->precision = std::move(precision_); -} - -} // namespace ModulePW diff --git a/source/module_basis/module_pw/fft.h b/source/module_basis/module_pw/fft.h index 3581d01d18..8a69863cfc 100644 --- a/source/module_basis/module_pw/fft.h +++ b/source/module_basis/module_pw/fft.h @@ -1,8 +1,3 @@ -#ifndef FFT_H -#define FFT_H - -#include -#include #include "fftw3.h" #if defined(__FFTW3_MPI) && defined(__MPI) @@ -20,154 +15,8 @@ #include #endif -//Temporary: we donot need psi. However some GPU ops are defined in psi, which should be moved into module_base or module_gpu -#include "module_psi/psi.h" -// #ifdef __ENABLE_FLOAT_FFTW -// #include "fftw3f.h" -// #if defined(__FFTW3_MPI) && defined(__MPI) -// #include "fftw3f-mpi.h" -// //#include "fftw3-mpi_mkl.h" -// #endif -// #endif - -namespace ModulePW -{ - -class FFT -{ -public: - - FFT(); - ~FFT(); - void clear(); //reset fft - - // init parameters of fft - void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); - - //init fftw_plans - void setupFFT(); - - //destroy fftw_plans - void cleanFFT(); - -#if defined(__ENABLE_FLOAT_FFTW) - void cleanfFFT(); -#endif // defined(__ENABLE_FLOAT_FFTW) - - template - void fftzfor(std::complex* in, std::complex* out) const; - template - void fftzbac(std::complex* in, std::complex* out) const; - template - void fftxyfor(std::complex* in, std::complex* out) const; - template - void fftxybac(std::complex* in, std::complex* out) const; - template - void fftxyr2c(FPTYPE* in, std::complex* out) const; - template - void fftxyc2r(std::complex* in, FPTYPE* out) const; - - template - void fft3D_forward(const Device* ctx, std::complex* in, std::complex* out) const; - template - void fft3D_backward(const Device* ctx, std::complex* in, std::complex* out) const; - - public: - //init fftw_plans - void initplan(const unsigned int& flag = 0); - // We have not support mpi fftw yet. - // void initplan_mpi(); - //init fftwf_plans -#if defined(__ENABLE_FLOAT_FFTW) - void initplanf(const unsigned int& flag = 0); -#endif // defined(__ENABLE_FLOAT_FFTW) - // void initplanf_mpi(); - -private: - int fftnx=0, fftny=0; - int fftnxy=0; - int ny=0, nx=0, nz=0; - int nxy=0; -public : - bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft - // For gamma_only, true: we use half x; false: we use half y - int lixy=0,rixy=0;// lixy: the left edge of the pw ball in the y direction; rixy: the right edge of the pw ball in the x or y direction - int ns=0; //number of sticks - int nplane=0; //number of x-y planes - int nproc=1; // number of proc. - - template - FPTYPE* get_rspace_data() const; - template - std::complex* get_auxr_data() const; - template - std::complex* get_auxg_data() const; - template - std::complex* get_auxr_3d_data() const; - - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive - - private: - bool gamma_only = false; - bool mpifft = false; // if use mpi fft, only used when define __FFTW3_MPI -//add by A.s 202406 considering that no all people are familiar with fftw3,some comments should be added. - fftw_plan planzfor = NULL;//create a special pointer pointing to the fftw_plan class as a plan for performing FFT - fftw_plan planzbac = NULL; - fftw_plan planxfor1 = NULL; - fftw_plan planxbac1 = NULL; - fftw_plan planxfor2 = NULL; - fftw_plan planxbac2 = NULL; - fftw_plan planyfor = NULL; - fftw_plan planybac = NULL; - fftw_plan planxr2c = NULL; - fftw_plan planxc2r = NULL; - fftw_plan planyr2c = NULL; - fftw_plan planyc2r = NULL; -// fftw_plan plan3dforward; -// fftw_plan plan3dbackward; - -#if defined(__CUDA) - cufftHandle c_handle = {}; - cufftHandle z_handle = {}; -#elif defined(__ROCM) - hipfftHandle c_handle = {}; - hipfftHandle z_handle = {}; -#endif - -#if defined(__ENABLE_FLOAT_FFTW) - fftwf_plan planfzfor = NULL; - fftwf_plan planfzbac = NULL; - fftwf_plan planfxfor1= NULL; - fftwf_plan planfxbac1= NULL; - fftwf_plan planfxfor2= NULL; - fftwf_plan planfxbac2= NULL; - fftwf_plan planfyfor = NULL; - fftwf_plan planfybac = NULL; - fftwf_plan planfxr2c = NULL; - fftwf_plan planfxc2r = NULL; - fftwf_plan planfyr2c = NULL; - fftwf_plan planfyc2r = NULL; -#endif // defined(__ENABLE_FLOAT_FFTW) - mutable std::complex* c_auxr_3d = nullptr; // fft space - mutable std::complex* z_auxr_3d = nullptr; // fft space - - mutable std::complex*c_auxg = nullptr, *c_auxr = nullptr; // fft space, - mutable std::complex*z_auxg = nullptr, *z_auxr = nullptr; // fft space - - mutable float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] - mutable double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] - - std::string device = "cpu"; - std::string precision = "double"; - -public: - void set_device(std::string device_); - void set_precision(std::string precision_); +#include "module_psi/psi.h" -}; -} -#endif diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index 884f0f74c0..d69d92ffbd 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -120,7 +120,6 @@ pw_transform.o\ pw_distributeg.o\ pw_distributeg_method1.o\ pw_distributeg_method2.o\ -fft.o\ pw_basis_k.o\ pw_basis_sup.o\ pw_transform_k.o\ diff --git a/source/module_basis/module_pw/test_serial/CMakeLists.txt b/source/module_basis/module_pw/test_serial/CMakeLists.txt index 028d5b3a0e..1ba8235b48 100644 --- a/source/module_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/module_basis/module_pw/test_serial/CMakeLists.txt @@ -9,7 +9,6 @@ remove_definitions(-D__DEEPKS) add_library( planewave_serial OBJECT - ../fft.cpp ../module_fft/fft_base.cpp ../module_fft/fft_bundle.cpp ../module_fft/fft_cpu.cpp diff --git a/source/module_elecstate/test/charge_extra_test.cpp b/source/module_elecstate/test/charge_extra_test.cpp index 63478bf724..15673e2a5e 100644 --- a/source/module_elecstate/test/charge_extra_test.cpp +++ b/source/module_elecstate/test/charge_extra_test.cpp @@ -64,12 +64,6 @@ PW_Basis::PW_Basis() PW_Basis::~PW_Basis() { } -FFT::FFT() -{ -} -FFT::~FFT() -{ -} FFT_Bundle::~FFT_Bundle(){}; void PW_Basis::initgrids(const double lat0_in, const ModuleBase::Matrix3 latvec_in, const double gridecut) { diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 6115b58a9b..c0da5a82ea 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -50,12 +50,6 @@ ModulePW::PW_Basis::~PW_Basis() ModulePW::PW_Basis_Sup::~PW_Basis_Sup() { } -ModulePW::FFT::FFT() -{ -} -ModulePW::FFT::~FFT() -{ -} ModulePW::FFT_Bundle::~FFT_Bundle(){}; void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h index 8ecccd932d..d613c8215b 100644 --- a/source/module_hamilt_general/module_xc/test/xc3_mock.h +++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h @@ -132,8 +132,6 @@ namespace ModulePW const double factor) const; #endif - FFT::FFT(){}; - FFT::~FFT(){}; void PW_Basis::initgrids(double, ModuleBase::Matrix3, double){}; void PW_Basis::distribute_r(){}; diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index 6b9d872e34..94d9af4686 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -45,9 +45,6 @@ double& PW_Basis_K::getgk2(const int ik, const int igl) const { return this->gk2[igl]; } -FFT::FFT() {} - -FFT::~FFT() {} } // namespace ModulePW From 541dde287f9ce527bb047b48cfe2fe2305ed8e14 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Wed, 20 Nov 2024 11:05:36 +0800 Subject: [PATCH 02/11] update the psi.h --- source/module_basis/module_pw/fft.h | 10 +- .../module_pw/module_fft/fft_base.h | 6 +- .../module_pw/module_fft/fft_bundle.h | 13 ++- .../module_pw/module_fft/fft_cuda.cpp | 2 +- .../module_pw/module_fft/fft_cuda.h | 8 +- .../module_pw/module_fft/fft_rcom.h | 6 +- .../module_fft/kernel/fft_cuda_func.h | 57 +++++++++++ .../module_fft/kernel/fft_rcom_func.h | 53 ++++++++++ source/module_basis/module_pw/pw_basis.h | 2 +- .../module_basis/module_pw/pw_transform.cpp | 1 - .../module_pw/test_serial/pw_basis_k_test.cpp | 1 - .../module_pw/test_serial/pw_basis_test.cpp | 1 - .../kernels/cuda/gemm_selector.cuh | 2 +- source/module_hamilt_pw/hamilt_pwdft/global.h | 99 +------------------ 14 files changed, 135 insertions(+), 126 deletions(-) create mode 100644 source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h create mode 100644 source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h diff --git a/source/module_basis/module_pw/fft.h b/source/module_basis/module_pw/fft.h index 8a69863cfc..05bf7792b9 100644 --- a/source/module_basis/module_pw/fft.h +++ b/source/module_basis/module_pw/fft.h @@ -1,12 +1,7 @@ -#include "fftw3.h" -#if defined(__FFTW3_MPI) && defined(__MPI) -#include -//#include "fftw3-mpi_mkl.h" -#endif #if defined(__CUDA) || defined(__UT_USE_CUDA) -#include "cufft.h" +// #include "cufft.h" #include "cuda_runtime.h" #endif @@ -17,6 +12,3 @@ #include "module_psi/psi.h" - - - diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index c1b105f1fd..394fd9a870 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -1,8 +1,8 @@ -#include -#include -#include "fftw3.h" #ifndef FFT_BASE_H #define FFT_BASE_H + +#include +#include namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 6da2419245..41e1dc8602 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -1,8 +1,15 @@ -#include "fft_base.h" -#include -// #include "module_psi/psi.h" #ifndef FFT_TEMP_H #define FFT_TEMP_H + +#include "fft_base.h" +#include +#include "fft_cpu.h" +#ifdef __CUDA +#include "fft_cuda.h" +#endif +#ifdef __ROCM +#include "fft_rocm.h" +#endif namespace ModulePW { class FFT_Bundle diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.cpp b/source/module_basis/module_pw/module_fft/fft_cuda.cpp index f9fc5df74b..4ac859f9ea 100644 --- a/source/module_basis/module_pw/module_fft/fft_cuda.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cuda.cpp @@ -1,6 +1,6 @@ #include "fft_cuda.h" #include "module_base/module_device/memory_op.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" + namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.h b/source/module_basis/module_pw/module_fft/fft_cuda.h index 90192d24dc..ddd25c9028 100644 --- a/source/module_basis/module_pw/module_fft/fft_cuda.h +++ b/source/module_basis/module_pw/module_fft/fft_cuda.h @@ -1,9 +1,9 @@ -#include "fft_base.h" -#include "cufft.h" -#include "cuda_runtime.h" - #ifndef FFT_CUDA_H #define FFT_CUDA_H + +#include "fft_base.h" +#include "kernel/fft_cuda_func.h" + namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rcom.h index 64ad13329e..38cc4fce2c 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rcom.h @@ -1,8 +1,8 @@ -#include "fft_base.h" -#include -#include #ifndef FFT_ROCM_H #define FFT_ROCM_H + +#include "fft_base.h" +#include "kernel/fft_rcom_func.h" namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h b/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h new file mode 100644 index 0000000000..8d19fcb5fb --- /dev/null +++ b/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h @@ -0,0 +1,57 @@ +#ifndef FFT_CUDA_FUNC_H +#define FFT_CUDA_FUNC_H +#include "cufft.h" +#include "cuda_runtime.h" + +static const char* _cufftGetErrorString(cufftResult_t error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + } + return ""; +} + +#define CHECK_CUFFT(func) \ + { \ + cufftResult_t status = (func); \ + if (status != CUFFT_SUCCESS) \ + { \ + printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _cufftGetErrorString(status), status); \ + } \ + } +#endif // FFT_CUDA_FUNC_H diff --git a/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h b/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h new file mode 100644 index 0000000000..06dd76d2bc --- /dev/null +++ b/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h @@ -0,0 +1,53 @@ +#ifndef FFT_ROCM_FUNC_H +#define FFT_ROCM_FUNC_H +#include +#include +static const char* _hipfftGetErrorString(hipfftResult_t error) +{ + switch (error) + { + case HIPFFT_SUCCESS: + return "HIPFFT_SUCCESS"; + case HIPFFT_INVALID_PLAN: + return "HIPFFT_INVALID_PLAN"; + case HIPFFT_ALLOC_FAILED: + return "HIPFFT_ALLOC_FAILED"; + case HIPFFT_INVALID_TYPE: + return "HIPFFT_INVALID_TYPE"; + case HIPFFT_INVALID_VALUE: + return "HIPFFT_INVALID_VALUE"; + case HIPFFT_INTERNAL_ERROR: + return "HIPFFT_INTERNAL_ERROR"; + case HIPFFT_EXEC_FAILED: + return "HIPFFT_EXEC_FAILED"; + case HIPFFT_SETUP_FAILED: + return "HIPFFT_SETUP_FAILED"; + case HIPFFT_INVALID_SIZE: + return "HIPFFT_INVALID_SIZE"; + case HIPFFT_UNALIGNED_DATA: + return "HIPFFT_UNALIGNED_DATA"; + case HIPFFT_INCOMPLETE_PARAMETER_LIST: + return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; + case HIPFFT_INVALID_DEVICE: + return "HIPFFT_INVALID_DEVICE"; + case HIPFFT_PARSE_ERROR: + return "HIPFFT_PARSE_ERROR"; + case HIPFFT_NO_WORKSPACE: + return "HIPFFT_NO_WORKSPACE"; + case HIPFFT_NOT_IMPLEMENTED: + return "HIPFFT_NOT_IMPLEMENTED"; + case HIPFFT_NOT_SUPPORTED: + return "HIPFFT_NOT_SUPPORTED"; + } + return ""; +} +#define CHECK_CUFFT(func) \ + { \ + hipfftResult_t status = (func); \ + if (status != HIPFFT_SUCCESS) \ + { \ + printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _hipfftGetErrorString(status), status); \ + } \ + } +#endif // FFT_ROCM_FUNC_H \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 00aba50971..9bd48d270f 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -1,11 +1,11 @@ #ifndef PWBASIS_H #define PWBASIS_H +#include "module_base/module_device/memory_op.h" #include "module_base/matrix.h" #include "module_base/matrix3.h" #include "module_base/vector3.h" #include -#include "fft.h" #include "module_fft/fft_bundle.h" #include #ifdef __MPI diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index d8534c7f0a..8e458b2561 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,4 +1,3 @@ -#include "fft.h" #include "module_fft/fft_bundle.h" #include #include "pw_basis.h" diff --git a/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp index e5fac0ef4c..75f352f474 100644 --- a/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -27,7 +27,6 @@ #define private public #include "../pw_basis_k.h" #include "../pw_basis.h" -#include "../fft.h" #undef private #undef protected diff --git a/source/module_basis/module_pw/test_serial/pw_basis_test.cpp b/source/module_basis/module_pw/test_serial/pw_basis_test.cpp index 89a84c43b3..eeff14b8f9 100644 --- a/source/module_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/module_basis/module_pw/test_serial/pw_basis_test.cpp @@ -38,7 +38,6 @@ #define protected public #define private public #include "../pw_basis.h" -#include "../fft.h" #undef private #undef protected diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh index 380a16c842..f52b7a8643 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh @@ -2,7 +2,7 @@ #define GEMM_SELECTOR_H #include "module_cell/unitcell.h" - +#include "cuda_runtime.h" typedef std::function"; } -static const char* _cufftGetErrorString(cufftResult_t error) -{ - switch (error) - { - case CUFFT_SUCCESS: - return "CUFFT_SUCCESS"; - case CUFFT_INVALID_PLAN: - return "CUFFT_INVALID_PLAN"; - case CUFFT_ALLOC_FAILED: - return "CUFFT_ALLOC_FAILED"; - case CUFFT_INVALID_TYPE: - return "CUFFT_INVALID_TYPE"; - case CUFFT_INVALID_VALUE: - return "CUFFT_INVALID_VALUE"; - case CUFFT_INTERNAL_ERROR: - return "CUFFT_INTERNAL_ERROR"; - case CUFFT_EXEC_FAILED: - return "CUFFT_EXEC_FAILED"; - case CUFFT_SETUP_FAILED: - return "CUFFT_SETUP_FAILED"; - case CUFFT_INVALID_SIZE: - return "CUFFT_INVALID_SIZE"; - case CUFFT_UNALIGNED_DATA: - return "CUFFT_UNALIGNED_DATA"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: - return "CUFFT_INCOMPLETE_PARAMETER_LIST"; - case CUFFT_INVALID_DEVICE: - return "CUFFT_INVALID_DEVICE"; - case CUFFT_PARSE_ERROR: - return "CUFFT_PARSE_ERROR"; - case CUFFT_NO_WORKSPACE: - return "CUFFT_NO_WORKSPACE"; - case CUFFT_NOT_IMPLEMENTED: - return "CUFFT_NOT_IMPLEMENTED"; - case CUFFT_LICENSE_ERROR: - return "CUFFT_LICENSE_ERROR"; - case CUFFT_NOT_SUPPORTED: - return "CUFFT_NOT_SUPPORTED"; - } - return ""; -} #define CHECK_CUDA(func) \ { \ @@ -111,15 +70,7 @@ static const char* _cufftGetErrorString(cufftResult_t error) } \ } -#define CHECK_CUFFT(func) \ - { \ - cufftResult_t status = (func); \ - if (status != CUFFT_SUCCESS) \ - { \ - printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _cufftGetErrorString(status), status); \ - } \ - } + #endif // __CUDA #ifdef __ROCM @@ -167,45 +118,6 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error) // return ""; // } -static const char* _hipfftGetErrorString(hipfftResult_t error) -{ - switch (error) - { - case HIPFFT_SUCCESS: - return "HIPFFT_SUCCESS"; - case HIPFFT_INVALID_PLAN: - return "HIPFFT_INVALID_PLAN"; - case HIPFFT_ALLOC_FAILED: - return "HIPFFT_ALLOC_FAILED"; - case HIPFFT_INVALID_TYPE: - return "HIPFFT_INVALID_TYPE"; - case HIPFFT_INVALID_VALUE: - return "HIPFFT_INVALID_VALUE"; - case HIPFFT_INTERNAL_ERROR: - return "HIPFFT_INTERNAL_ERROR"; - case HIPFFT_EXEC_FAILED: - return "HIPFFT_EXEC_FAILED"; - case HIPFFT_SETUP_FAILED: - return "HIPFFT_SETUP_FAILED"; - case HIPFFT_INVALID_SIZE: - return "HIPFFT_INVALID_SIZE"; - case HIPFFT_UNALIGNED_DATA: - return "HIPFFT_UNALIGNED_DATA"; - case HIPFFT_INCOMPLETE_PARAMETER_LIST: - return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; - case HIPFFT_INVALID_DEVICE: - return "HIPFFT_INVALID_DEVICE"; - case HIPFFT_PARSE_ERROR: - return "HIPFFT_PARSE_ERROR"; - case HIPFFT_NO_WORKSPACE: - return "HIPFFT_NO_WORKSPACE"; - case HIPFFT_NOT_IMPLEMENTED: - return "HIPFFT_NOT_IMPLEMENTED"; - case HIPFFT_NOT_SUPPORTED: - return "HIPFFT_NOT_SUPPORTED"; - } - return ""; -} #define CHECK_CUDA(func) \ { \ @@ -237,15 +149,6 @@ static const char* _hipfftGetErrorString(hipfftResult_t error) // }\ // } -#define CHECK_CUFFT(func) \ - { \ - hipfftResult_t status = (func); \ - if (status != HIPFFT_SUCCESS) \ - { \ - printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _hipfftGetErrorString(status), status); \ - } \ - } #endif // __ROCM //========================================================== From f920328c19a04d37c29e3ab370777b4f4254e4b5 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Wed, 20 Nov 2024 11:26:11 +0800 Subject: [PATCH 03/11] update the header file --- source/module_basis/module_pw/module_fft/fft_base.h | 1 + .../module_basis/module_pw/module_fft/fft_bundle.cpp | 2 -- source/module_basis/module_pw/module_fft/fft_bundle.h | 11 ++++------- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 394fd9a870..21ed802807 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -3,6 +3,7 @@ #include #include +#include "module_base/module_device/memory_op.h" namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 31d29c32f9..adf5a1b1b2 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,7 +1,5 @@ #include #include "fft_bundle.h" -#include "fft_cpu.h" -#include "module_base/module_device/device.h" #if defined(__CUDA) #include "fft_cuda.h" #endif diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 41e1dc8602..a89b2d3bdb 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -2,14 +2,11 @@ #define FFT_TEMP_H #include "fft_base.h" -#include #include "fft_cpu.h" -#ifdef __CUDA -#include "fft_cuda.h" -#endif -#ifdef __ROCM -#include "fft_rocm.h" -#endif +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#include + namespace ModulePW { class FFT_Bundle From 18592ce2031c3776d92c072be94725edf7ec49f1 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 10:10:28 +0800 Subject: [PATCH 04/11] add clear func --- .../module_pw/module_fft/fft_cpu.h | 1 - .../module_pw/module_fft/fft_cpu_float.cpp | 28 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 27c7e862a2..8864335d7e 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -33,7 +33,6 @@ class FFT_CPU : public FFT_BASE * @param gamma_only_in whether only gamma point is used. * @param xprime_in whether xprime is used. */ - __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in, diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index c13d47f762..b3e8d7d572 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -267,11 +267,11 @@ void FFT_CPU::setupFFT() } template <> -void FFT_CPU::clearfft(fftw_plan& plan) +void FFT_CPU::clearfft(fftwf_plan& plan) { if (plan) { - fftw_destroy_plan(plan); + fftwf_destroy_plan(plan); plan = nullptr; } } @@ -279,18 +279,18 @@ void FFT_CPU::clearfft(fftw_plan& plan) template <> void FFT_CPU::cleanFFT() { - clearfft(planzfor); - clearfft(planzbac); - clearfft(planxfor1); - clearfft(planxbac1); - clearfft(planxfor2); - clearfft(planxbac2); - clearfft(planyfor); - clearfft(planybac); - clearfft(planxr2c); - clearfft(planxc2r); - clearfft(planyr2c); - clearfft(planyc2r); + clearfft(planfzfor); + clearfft(planfzbac); + clearfft(planfxfor1); + clearfft(planfxbac1); + clearfft(planfxfor2); + clearfft(planfxbac2); + clearfft(planfyfor); + clearfft(planfybac); + clearfft(planfxr2c); + clearfft(planfxc2r); + clearfft(planfyr2c); + clearfft(planfyc2r); } From 3a0c5251c2f20c0eea1d146d0137089521ff1451 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 20:31:20 +0800 Subject: [PATCH 05/11] change the fft makefile --- source/module_basis/module_pw/CMakeLists.txt | 1 - source/module_basis/module_pw/module_fft/fft_base.cpp | 5 +---- source/module_basis/module_pw/module_fft/fft_base.h | 5 +++++ source/module_basis/module_pw/module_fft/fft_cpu.h | 1 + source/module_basis/module_pw/test/Makefile | 10 ++++++---- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index ee5154f2c7..7cff08efd5 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -25,7 +25,6 @@ list(APPEND objects pw_init.cpp pw_transform.cpp pw_transform_k.cpp - module_fft/fft_base.cpp module_fft/fft_bundle.cpp module_fft/fft_cpu.cpp ${FFT_SRC} diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp index 4c91d4d7b4..446b7559d0 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -1,8 +1,5 @@ #include "fft_base.h" namespace ModulePW { -template FFT_BASE::FFT_BASE(); -template FFT_BASE::FFT_BASE(); -template FFT_BASE::~FFT_BASE(); -template FFT_BASE::~FFT_BASE(); + } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 21ed802807..df371c472a 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -165,5 +165,10 @@ class FFT_BASE int ny=0; int nz=0; }; + +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); } #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 8864335d7e..ebd74f22cb 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -43,6 +43,7 @@ class FFT_CPU : public FFT_BASE int nproc_in, bool gamma_only_in, bool xprime_in = true) override; + __attribute__((weak)) void setupFFT() override; diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index d69d92ffbd..df91138107 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -2,7 +2,7 @@ # Please set # e.g. make CXX=mpiicpc or make CXX=icpc #====================================================================== -CXX = mpiicpx +CXX = mpiicpc # mpiicpc: compile intel parallel version # icpc: compile intel sequential version # mpicxx: compile gnu parallel version @@ -94,7 +94,7 @@ endif ##========================== ## GTEST ##========================== -GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread +GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread -w @@ -106,6 +106,7 @@ VPATH=../../../module_base\ ../../../module_base/module_container/ATen/core\ ../../../module_base/module_container/ATen\ ../../../module_parameter\ +../module_fft\ ../\ MATH_OBJS0=matrix.o\ @@ -127,9 +128,10 @@ memory.o\ memory_op.o\ depend_mock.o\ parameter.o\ -fft_base.o\ -fft_bundle.o\ fft_cpu.o\ +fft_cpu_float.o\ +fft_bundle.o\ + OTHER_OBJS0= From d3d50b0039e0fe191fb806c9312023bfdc41f9c0 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 20:41:11 +0800 Subject: [PATCH 06/11] delete fft.h --- source/module_basis/module_pw/fft.h | 14 -------------- .../module_basis/module_pw/module_fft/fft_base.cpp | 5 ----- .../module_pw/test_serial/CMakeLists.txt | 1 - .../module_xc/test/CMakeLists.txt | 2 -- 4 files changed, 22 deletions(-) delete mode 100644 source/module_basis/module_pw/fft.h delete mode 100644 source/module_basis/module_pw/module_fft/fft_base.cpp diff --git a/source/module_basis/module_pw/fft.h b/source/module_basis/module_pw/fft.h deleted file mode 100644 index 05bf7792b9..0000000000 --- a/source/module_basis/module_pw/fft.h +++ /dev/null @@ -1,14 +0,0 @@ - - -#if defined(__CUDA) || defined(__UT_USE_CUDA) -// #include "cufft.h" -#include "cuda_runtime.h" -#endif - -#if defined(__ROCM) || defined(__UT_USE_ROCM) -#include -#include -#endif - - -#include "module_psi/psi.h" diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp deleted file mode 100644 index 446b7559d0..0000000000 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ /dev/null @@ -1,5 +0,0 @@ -#include "fft_base.h" -namespace ModulePW -{ - -} \ No newline at end of file diff --git a/source/module_basis/module_pw/test_serial/CMakeLists.txt b/source/module_basis/module_pw/test_serial/CMakeLists.txt index 1ba8235b48..7e9356d022 100644 --- a/source/module_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/module_basis/module_pw/test_serial/CMakeLists.txt @@ -9,7 +9,6 @@ remove_definitions(-D__DEEPKS) add_library( planewave_serial OBJECT - ../module_fft/fft_base.cpp ../module_fft/fft_bundle.cpp ../module_fft/fft_cpu.cpp ../pw_basis.cpp diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index b93e3a6ddb..0dda934ac6 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -43,7 +43,6 @@ AddTest( ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp ../../../module_base/blas_connector.cpp - ../../../module_basis/module_pw/module_fft/fft_base.cpp ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ${FFT_SRC} @@ -82,7 +81,6 @@ AddTest( ../../../module_base/timer.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp - ../../../module_basis/module_pw/module_fft/fft_base.cpp ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ${FFT_SRC} From b60df0d1d9cf07a4a887eed63bc99cd711e5fd47 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 20:43:33 +0800 Subject: [PATCH 07/11] update the Makefile.Obj --- source/Makefile.Objects | 1 - 1 file changed, 1 deletion(-) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 62c95ac83b..3a4752ff11 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -408,7 +408,6 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao_random.o\ OBJS_PW=fft_bundle.o\ - fft_base.o\ fft_cpu.o\ pw_basis.o\ pw_basis_k.o\ From 128283a3ff14c757550ee7ba8d04c13641d360f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:35:43 +0000 Subject: [PATCH 08/11] [pre-commit.ci lite] apply automatic fixes --- source/module_hamilt_general/module_xc/test/xc3_mock.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h index d613c8215b..6f812f52a0 100644 --- a/source/module_hamilt_general/module_xc/test/xc3_mock.h +++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h @@ -163,7 +163,7 @@ namespace ModuleBase namespace GlobalV { std::string BASIS_TYPE = ""; - bool CAL_STRESS = 0; + bool CAL_STRESS = false; int CAL_FORCE = 0; int NSPIN; int NPOL; From 7ca986b1fd1301ee8701314afac9d45ec065dd08 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Sat, 23 Nov 2024 16:52:49 +0800 Subject: [PATCH 09/11] revert check_func --- .../module_pw/module_fft/fft_base.h | 2 - .../module_pw/module_fft/fft_bundle.cpp | 7 +- .../module_pw/module_fft/fft_bundle.h | 8 +- .../module_pw/module_fft/fft_cuda.cpp | 6 ++ .../module_pw/module_fft/fft_cuda.h | 10 +- .../module_pw/module_fft/fft_rcom.h | 5 +- .../module_fft/kernel/fft_cuda_func.h | 57 ----------- .../module_fft/kernel/fft_rcom_func.h | 53 ---------- .../kernels/cuda/gemm_selector.cuh | 2 +- source/module_hamilt_pw/hamilt_pwdft/global.h | 99 ++++++++++++++++++- 10 files changed, 119 insertions(+), 130 deletions(-) delete mode 100644 source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h delete mode 100644 source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index df371c472a..7711755bc2 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -1,9 +1,7 @@ #ifndef FFT_BASE_H #define FFT_BASE_H - #include #include -#include "module_base/module_device/memory_op.h" namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index adf5a1b1b2..fb274e4f09 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,11 +1,6 @@ #include #include "fft_bundle.h" -#if defined(__CUDA) -#include "fft_cuda.h" -#endif -#if defined(__ROCM) -#include "fft_rcom.h" -#endif + template std::unique_ptr make_unique(Args &&... args) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index a89b2d3bdb..afc8bc6a77 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -1,11 +1,17 @@ #ifndef FFT_TEMP_H #define FFT_TEMP_H +#include #include "fft_base.h" #include "fft_cpu.h" #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" -#include +#if defined(__CUDA) +#include "fft_cuda.h" +#endif +#if defined(__ROCM) +#include "fft_rcom.h" +#endif namespace ModulePW { diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.cpp b/source/module_basis/module_pw/module_fft/fft_cuda.cpp index 4ac859f9ea..db93fb07fb 100644 --- a/source/module_basis/module_pw/module_fft/fft_cuda.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cuda.cpp @@ -1,5 +1,6 @@ #include "fft_cuda.h" #include "module_base/module_device/memory_op.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" namespace ModulePW { @@ -105,4 +106,9 @@ template <> std::complex* FFT_CUDA::get_auxr_3d_data() const {return this->c_auxr_3d;} template <> std::complex* FFT_CUDA::get_auxr_3d_data() const {return this->z_auxr_3d;} + +template FFT_CUDA::FFT_CUDA(); +template FFT_CUDA::~FFT_CUDA(); +template FFT_CUDA::FFT_CUDA(); +template FFT_CUDA::~FFT_CUDA(); }// namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.h b/source/module_basis/module_pw/module_fft/fft_cuda.h index ddd25c9028..4b0e318cb2 100644 --- a/source/module_basis/module_pw/module_fft/fft_cuda.h +++ b/source/module_basis/module_pw/module_fft/fft_cuda.h @@ -1,9 +1,8 @@ #ifndef FFT_CUDA_H #define FFT_CUDA_H - #include "fft_base.h" -#include "kernel/fft_cuda_func.h" - +#include "cufft.h" +#include "cuda_runtime.h" namespace ModulePW { template @@ -62,9 +61,6 @@ class FFT_CUDA : public FFT_BASE std::complex* z_auxr_3d = nullptr; // fft space }; -template FFT_CUDA::FFT_CUDA(); -template FFT_CUDA::~FFT_CUDA(); -template FFT_CUDA::FFT_CUDA(); -template FFT_CUDA::~FFT_CUDA(); + } // namespace ModulePW #endif \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rcom.h index 38cc4fce2c..a0a3ea49f1 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rcom.h @@ -1,8 +1,9 @@ + #ifndef FFT_ROCM_H #define FFT_ROCM_H - #include "fft_base.h" -#include "kernel/fft_rcom_func.h" +#include +#include namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h b/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h deleted file mode 100644 index 8d19fcb5fb..0000000000 --- a/source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef FFT_CUDA_FUNC_H -#define FFT_CUDA_FUNC_H -#include "cufft.h" -#include "cuda_runtime.h" - -static const char* _cufftGetErrorString(cufftResult_t error) -{ - switch (error) - { - case CUFFT_SUCCESS: - return "CUFFT_SUCCESS"; - case CUFFT_INVALID_PLAN: - return "CUFFT_INVALID_PLAN"; - case CUFFT_ALLOC_FAILED: - return "CUFFT_ALLOC_FAILED"; - case CUFFT_INVALID_TYPE: - return "CUFFT_INVALID_TYPE"; - case CUFFT_INVALID_VALUE: - return "CUFFT_INVALID_VALUE"; - case CUFFT_INTERNAL_ERROR: - return "CUFFT_INTERNAL_ERROR"; - case CUFFT_EXEC_FAILED: - return "CUFFT_EXEC_FAILED"; - case CUFFT_SETUP_FAILED: - return "CUFFT_SETUP_FAILED"; - case CUFFT_INVALID_SIZE: - return "CUFFT_INVALID_SIZE"; - case CUFFT_UNALIGNED_DATA: - return "CUFFT_UNALIGNED_DATA"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: - return "CUFFT_INCOMPLETE_PARAMETER_LIST"; - case CUFFT_INVALID_DEVICE: - return "CUFFT_INVALID_DEVICE"; - case CUFFT_PARSE_ERROR: - return "CUFFT_PARSE_ERROR"; - case CUFFT_NO_WORKSPACE: - return "CUFFT_NO_WORKSPACE"; - case CUFFT_NOT_IMPLEMENTED: - return "CUFFT_NOT_IMPLEMENTED"; - case CUFFT_LICENSE_ERROR: - return "CUFFT_LICENSE_ERROR"; - case CUFFT_NOT_SUPPORTED: - return "CUFFT_NOT_SUPPORTED"; - } - return ""; -} - -#define CHECK_CUFFT(func) \ - { \ - cufftResult_t status = (func); \ - if (status != CUFFT_SUCCESS) \ - { \ - printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _cufftGetErrorString(status), status); \ - } \ - } -#endif // FFT_CUDA_FUNC_H diff --git a/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h b/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h deleted file mode 100644 index 06dd76d2bc..0000000000 --- a/source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef FFT_ROCM_FUNC_H -#define FFT_ROCM_FUNC_H -#include -#include -static const char* _hipfftGetErrorString(hipfftResult_t error) -{ - switch (error) - { - case HIPFFT_SUCCESS: - return "HIPFFT_SUCCESS"; - case HIPFFT_INVALID_PLAN: - return "HIPFFT_INVALID_PLAN"; - case HIPFFT_ALLOC_FAILED: - return "HIPFFT_ALLOC_FAILED"; - case HIPFFT_INVALID_TYPE: - return "HIPFFT_INVALID_TYPE"; - case HIPFFT_INVALID_VALUE: - return "HIPFFT_INVALID_VALUE"; - case HIPFFT_INTERNAL_ERROR: - return "HIPFFT_INTERNAL_ERROR"; - case HIPFFT_EXEC_FAILED: - return "HIPFFT_EXEC_FAILED"; - case HIPFFT_SETUP_FAILED: - return "HIPFFT_SETUP_FAILED"; - case HIPFFT_INVALID_SIZE: - return "HIPFFT_INVALID_SIZE"; - case HIPFFT_UNALIGNED_DATA: - return "HIPFFT_UNALIGNED_DATA"; - case HIPFFT_INCOMPLETE_PARAMETER_LIST: - return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; - case HIPFFT_INVALID_DEVICE: - return "HIPFFT_INVALID_DEVICE"; - case HIPFFT_PARSE_ERROR: - return "HIPFFT_PARSE_ERROR"; - case HIPFFT_NO_WORKSPACE: - return "HIPFFT_NO_WORKSPACE"; - case HIPFFT_NOT_IMPLEMENTED: - return "HIPFFT_NOT_IMPLEMENTED"; - case HIPFFT_NOT_SUPPORTED: - return "HIPFFT_NOT_SUPPORTED"; - } - return ""; -} -#define CHECK_CUFFT(func) \ - { \ - hipfftResult_t status = (func); \ - if (status != HIPFFT_SUCCESS) \ - { \ - printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _hipfftGetErrorString(status), status); \ - } \ - } -#endif // FFT_ROCM_FUNC_H \ No newline at end of file diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh index f52b7a8643..380a16c842 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh @@ -2,7 +2,7 @@ #define GEMM_SELECTOR_H #include "module_cell/unitcell.h" -#include "cuda_runtime.h" + typedef std::function"; } +static const char* _cufftGetErrorString(cufftResult_t error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + } + return ""; +} #define CHECK_CUDA(func) \ { \ @@ -70,7 +111,15 @@ static const char* _cublasGetErrorString(cublasStatus_t error) } \ } - +#define CHECK_CUFFT(func) \ + { \ + cufftResult_t status = (func); \ + if (status != CUFFT_SUCCESS) \ + { \ + printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _cufftGetErrorString(status), status); \ + } \ + } #endif // __CUDA #ifdef __ROCM @@ -118,6 +167,45 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error) // return ""; // } +static const char* _hipfftGetErrorString(hipfftResult_t error) +{ + switch (error) + { + case HIPFFT_SUCCESS: + return "HIPFFT_SUCCESS"; + case HIPFFT_INVALID_PLAN: + return "HIPFFT_INVALID_PLAN"; + case HIPFFT_ALLOC_FAILED: + return "HIPFFT_ALLOC_FAILED"; + case HIPFFT_INVALID_TYPE: + return "HIPFFT_INVALID_TYPE"; + case HIPFFT_INVALID_VALUE: + return "HIPFFT_INVALID_VALUE"; + case HIPFFT_INTERNAL_ERROR: + return "HIPFFT_INTERNAL_ERROR"; + case HIPFFT_EXEC_FAILED: + return "HIPFFT_EXEC_FAILED"; + case HIPFFT_SETUP_FAILED: + return "HIPFFT_SETUP_FAILED"; + case HIPFFT_INVALID_SIZE: + return "HIPFFT_INVALID_SIZE"; + case HIPFFT_UNALIGNED_DATA: + return "HIPFFT_UNALIGNED_DATA"; + case HIPFFT_INCOMPLETE_PARAMETER_LIST: + return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; + case HIPFFT_INVALID_DEVICE: + return "HIPFFT_INVALID_DEVICE"; + case HIPFFT_PARSE_ERROR: + return "HIPFFT_PARSE_ERROR"; + case HIPFFT_NO_WORKSPACE: + return "HIPFFT_NO_WORKSPACE"; + case HIPFFT_NOT_IMPLEMENTED: + return "HIPFFT_NOT_IMPLEMENTED"; + case HIPFFT_NOT_SUPPORTED: + return "HIPFFT_NOT_SUPPORTED"; + } + return ""; +} #define CHECK_CUDA(func) \ { \ @@ -149,6 +237,15 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error) // }\ // } +#define CHECK_CUFFT(func) \ + { \ + hipfftResult_t status = (func); \ + if (status != HIPFFT_SUCCESS) \ + { \ + printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _hipfftGetErrorString(status), status); \ + } \ + } #endif // __ROCM //========================================================== From 224e3ae9b142d05065e36b573616c1a4bbd1aa1c Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Mon, 25 Nov 2024 17:33:27 +0800 Subject: [PATCH 10/11] update the header file --- .../module_basis/module_pw/module_fft/fft_bundle.cpp | 10 +++++++++- source/module_basis/module_pw/module_fft/fft_bundle.h | 9 --------- .../module_gint/kernels/cuda/gemm_selector.cuh | 2 +- source/module_hamilt_pw/hamilt_pwdft/global.h | 2 ++ 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index fb274e4f09..fce85049fb 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,6 +1,14 @@ #include #include "fft_bundle.h" - +#include "fft_cpu.h" +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#if defined(__CUDA) +#include "fft_cuda.h" +#endif +#if defined(__ROCM) +#include "fft_rcom.h" +#endif template std::unique_ptr make_unique(Args &&... args) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index afc8bc6a77..6b62f983a9 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -3,15 +3,6 @@ #include #include "fft_base.h" -#include "fft_cpu.h" -#include "module_base/module_device/device.h" -#include "module_base/module_device/memory_op.h" -#if defined(__CUDA) -#include "fft_cuda.h" -#endif -#if defined(__ROCM) -#include "fft_rcom.h" -#endif namespace ModulePW { diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh index 380a16c842..f52b7a8643 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh @@ -2,7 +2,7 @@ #define GEMM_SELECTOR_H #include "module_cell/unitcell.h" - +#include "cuda_runtime.h" typedef std::function Date: Mon, 25 Nov 2024 18:09:43 +0800 Subject: [PATCH 11/11] change fft_cpu.h --- source/module_basis/module_pw/module_fft/fft_base.h | 2 +- .../module_basis/module_pw/module_fft/fft_bundle.cpp | 2 +- source/module_basis/module_pw/module_fft/fft_bundle.h | 2 +- source/module_basis/module_pw/module_fft/fft_cpu.h | 10 +++------- source/module_basis/module_pw/module_fft/fft_cuda.h | 1 + source/module_basis/module_pw/module_fft/fft_rcom.h | 1 + 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 7711755bc2..b64b6f4e00 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -1,7 +1,7 @@ #ifndef FFT_BASE_H #define FFT_BASE_H + #include -#include namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index fce85049fb..b3e06c010e 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,6 +1,6 @@ #include #include "fft_bundle.h" -#include "fft_cpu.h" + #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" #if defined(__CUDA) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 6b62f983a9..71ce5192f3 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -3,7 +3,7 @@ #include #include "fft_base.h" - +#include "fft_cpu.h" namespace ModulePW { class FFT_Bundle diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index ebd74f22cb..c0fe9992eb 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -1,12 +1,8 @@ -#include "fft_base.h" -#include "fftw3.h" - -// #ifdef __ENABLE_FLOAT_FFTW - -// #endif -// #endif #ifndef FFT_CPU_H #define FFT_CPU_H + +#include "fft_base.h" +#include "fftw3.h" namespace ModulePW { template diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.h b/source/module_basis/module_pw/module_fft/fft_cuda.h index 4b0e318cb2..4942ee33f2 100644 --- a/source/module_basis/module_pw/module_fft/fft_cuda.h +++ b/source/module_basis/module_pw/module_fft/fft_cuda.h @@ -1,5 +1,6 @@ #ifndef FFT_CUDA_H #define FFT_CUDA_H + #include "fft_base.h" #include "cufft.h" #include "cuda_runtime.h" diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rcom.h index a0a3ea49f1..1c316b98cd 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rcom.h @@ -1,6 +1,7 @@ #ifndef FFT_ROCM_H #define FFT_ROCM_H + #include "fft_base.h" #include #include