Skip to content

Commit

Permalink
revert check_func
Browse files Browse the repository at this point in the history
  • Loading branch information
A-006 committed Nov 23, 2024
1 parent 128283a commit 7ca986b
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 130 deletions.
2 changes: 0 additions & 2 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#ifndef FFT_BASE_H
#define FFT_BASE_H

#include <complex>
#include <string>
#include "module_base/module_device/memory_op.h"
namespace ModulePW
{
template <typename FPTYPE>
Expand Down
7 changes: 1 addition & 6 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#include <cassert>
#include "fft_bundle.h"
#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#endif


template<typename FFT_BASE, typename... Args>
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
Expand Down
8 changes: 7 additions & 1 deletion source/module_basis/module_pw/module_fft/fft_bundle.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#ifndef FFT_TEMP_H
#define FFT_TEMP_H

#include <memory>
#include "fft_base.h"
#include "fft_cpu.h"
#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include <memory>
#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#endif

namespace ModulePW
{
Expand Down
6 changes: 6 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "fft_cuda.h"
#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"

namespace ModulePW
{
Expand Down Expand Up @@ -105,4 +106,9 @@ template <> std::complex<float>*
FFT_CUDA<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
template <> std::complex<double>*
FFT_CUDA<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}

template FFT_CUDA<float>::FFT_CUDA();
template FFT_CUDA<float>::~FFT_CUDA();
template FFT_CUDA<double>::FFT_CUDA();
template FFT_CUDA<double>::~FFT_CUDA();
}// namespace ModulePW
10 changes: 3 additions & 7 deletions source/module_basis/module_pw/module_fft/fft_cuda.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#ifndef FFT_CUDA_H
#define FFT_CUDA_H

#include "fft_base.h"
#include "kernel/fft_cuda_func.h"

#include "cufft.h"
#include "cuda_runtime.h"
namespace ModulePW
{
template <typename FPTYPE>
Expand Down Expand Up @@ -62,9 +61,6 @@ class FFT_CUDA : public FFT_BASE<FPTYPE>
std::complex<double>* z_auxr_3d = nullptr; // fft space

};
template FFT_CUDA<float>::FFT_CUDA();
template FFT_CUDA<float>::~FFT_CUDA();
template FFT_CUDA<double>::FFT_CUDA();
template FFT_CUDA<double>::~FFT_CUDA();

} // namespace ModulePW
#endif
5 changes: 3 additions & 2 deletions source/module_basis/module_pw/module_fft/fft_rcom.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

#ifndef FFT_ROCM_H
#define FFT_ROCM_H

#include "fft_base.h"
#include "kernel/fft_rcom_func.h"
#include <hipfft/hipfft.h>
#include <hip/hip_runtime.h>
namespace ModulePW
{
template <typename FPTYPE>
Expand Down
57 changes: 0 additions & 57 deletions source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h

This file was deleted.

53 changes: 0 additions & 53 deletions source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define GEMM_SELECTOR_H

#include "module_cell/unitcell.h"
#include "cuda_runtime.h"

typedef std::function<void(int,
int,
int*,
Expand Down
99 changes: 98 additions & 1 deletion source/module_hamilt_pw/hamilt_pwdft/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,47 @@ static const char* _cublasGetErrorString(cublasStatus_t error)
return "<unknown>";
}

static const char* _cufftGetErrorString(cufftResult_t error)
{
switch (error)
{
case CUFFT_SUCCESS:
return "CUFFT_SUCCESS";
case CUFFT_INVALID_PLAN:
return "CUFFT_INVALID_PLAN";
case CUFFT_ALLOC_FAILED:
return "CUFFT_ALLOC_FAILED";
case CUFFT_INVALID_TYPE:
return "CUFFT_INVALID_TYPE";
case CUFFT_INVALID_VALUE:
return "CUFFT_INVALID_VALUE";
case CUFFT_INTERNAL_ERROR:
return "CUFFT_INTERNAL_ERROR";
case CUFFT_EXEC_FAILED:
return "CUFFT_EXEC_FAILED";
case CUFFT_SETUP_FAILED:
return "CUFFT_SETUP_FAILED";
case CUFFT_INVALID_SIZE:
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
case CUFFT_NOT_SUPPORTED:
return "CUFFT_NOT_SUPPORTED";
}
return "<unknown>";
}

#define CHECK_CUDA(func) \
{ \
Expand Down Expand Up @@ -70,7 +111,15 @@ static const char* _cublasGetErrorString(cublasStatus_t error)
} \
}


#define CHECK_CUFFT(func) \
{ \
cufftResult_t status = (func); \
if (status != CUFFT_SUCCESS) \
{ \
printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
_cufftGetErrorString(status), status); \
} \
}
#endif // __CUDA

#ifdef __ROCM
Expand Down Expand Up @@ -118,6 +167,45 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error)
// return "<unknown>";
// }

static const char* _hipfftGetErrorString(hipfftResult_t error)
{
switch (error)
{
case HIPFFT_SUCCESS:
return "HIPFFT_SUCCESS";
case HIPFFT_INVALID_PLAN:
return "HIPFFT_INVALID_PLAN";
case HIPFFT_ALLOC_FAILED:
return "HIPFFT_ALLOC_FAILED";
case HIPFFT_INVALID_TYPE:
return "HIPFFT_INVALID_TYPE";
case HIPFFT_INVALID_VALUE:
return "HIPFFT_INVALID_VALUE";
case HIPFFT_INTERNAL_ERROR:
return "HIPFFT_INTERNAL_ERROR";
case HIPFFT_EXEC_FAILED:
return "HIPFFT_EXEC_FAILED";
case HIPFFT_SETUP_FAILED:
return "HIPFFT_SETUP_FAILED";
case HIPFFT_INVALID_SIZE:
return "HIPFFT_INVALID_SIZE";
case HIPFFT_UNALIGNED_DATA:
return "HIPFFT_UNALIGNED_DATA";
case HIPFFT_INCOMPLETE_PARAMETER_LIST:
return "HIPFFT_INCOMPLETE_PARAMETER_LIST";
case HIPFFT_INVALID_DEVICE:
return "HIPFFT_INVALID_DEVICE";
case HIPFFT_PARSE_ERROR:
return "HIPFFT_PARSE_ERROR";
case HIPFFT_NO_WORKSPACE:
return "HIPFFT_NO_WORKSPACE";
case HIPFFT_NOT_IMPLEMENTED:
return "HIPFFT_NOT_IMPLEMENTED";
case HIPFFT_NOT_SUPPORTED:
return "HIPFFT_NOT_SUPPORTED";
}
return "<unknown>";
}

#define CHECK_CUDA(func) \
{ \
Expand Down Expand Up @@ -149,6 +237,15 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error)
// }\
// }

#define CHECK_CUFFT(func) \
{ \
hipfftResult_t status = (func); \
if (status != HIPFFT_SUCCESS) \
{ \
printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
_hipfftGetErrorString(status), status); \
} \
}
#endif // __ROCM

//==========================================================
Expand Down

0 comments on commit 7ca986b

Please sign in to comment.