From ddf990f8a5a95125c8dfdb9da711fb21850d8cae Mon Sep 17 00:00:00 2001 From: Critsium Date: Mon, 4 Nov 2024 21:29:01 -0500 Subject: [PATCH] Feature: Optimized memory management on DSP (#5361) * Initial commit * Change memory_op construction * I finally find this * Fix template bug * Fix memory header definition * Optimize memory op usage * Update diago_subspace * No change * Fix MPI Error * Make the extra memory usage DSP-hardware-specialized. Add some annotations. * Reorganize dsp codes * Fix bug 1 * Fix bug 2 * Finish transporting codes --------- Co-authored-by: Mohan Chen --- source/module_base/blas_connector.cpp | 8 +-- .../module_base/kernels/dsp/dsp_connector.h | 65 +++++++++++++++++++ .../module_base/module_device/memory_op.cpp | 52 +++++++++++++++ source/module_base/module_device/memory_op.h | 31 ++++++++- source/module_base/module_device/types.h | 1 + source/module_hsolver/diago_dav_subspace.cpp | 9 ++- source/module_hsolver/diago_dav_subspace.h | 10 +++ source/module_psi/psi.h | 8 ++- 8 files changed, 177 insertions(+), 7 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 1de321ca99..30b3b93d40 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mt_(&transb, &transa, &n, &m, &k, + sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mt_(&transb, &transa, &n, &m, &k, + dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, 
const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mt_(&transb, &transa, &n, &m, &k, + cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mt_(&transb, &transa, &n, &m, &k, + zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index a928a0b095..b51c67663e 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -2,6 +2,10 @@ #define DSP_CONNECTOR_H #ifdef __DSP +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#include "module_hsolver/diag_comm_info.h" + // Base dsp functions void dspInitHandle(int id); void dspDestoryHandle(int id); @@ -62,5 +66,66 @@ void cgemm_mth_(const char *transa, const char *transb, //#define zgemm_ zgemm_mt +// The next is dsp utils. 
It may be moved to other files if this file gets too huge
+
+#include <complex>
+#include <vector>
+
+// Map the element type of a reduce buffer to the matching MPI datatype at compile time.
+inline MPI_Datatype dsp_mpi_datatype(const float*) { return MPI_FLOAT; }
+inline MPI_Datatype dsp_mpi_datatype(const double*) { return MPI_DOUBLE; }
+inline MPI_Datatype dsp_mpi_datatype(const std::complex<float>*) { return MPI_COMPLEX; }
+inline MPI_Datatype dsp_mpi_datatype(const std::complex<double>*) { return MPI_DOUBLE_COMPLEX; }
+
+/// @brief Sum one block of the matrices hcc and scc over all ranks of diag_comm.
+///
+/// For each matrix, the notconv * nbase_x elements starting at element nbase * nbase_x are
+/// staged in a host buffer, reduced with MPI_SUM onto rank 0 of diag_comm, and the receive
+/// buffer is then copied back over the same elements.
+///
+/// \tparam T : element type; float, double, std::complex<float> or std::complex<double>
+/// Input Parameters
+/// \param nbase : the reduced block starts at element nbase * nbase_x
+/// \param nbase_x : factor applied to nbase and notconv to obtain offset and length
+/// \param notconv : the reduced block holds notconv * nbase_x elements
+/// \param diag_comm : communicator of the reduction; rank 0 is the root
+///
+/// Input/Output Parameters
+/// \param hcc : first matrix; its block is overwritten with the receive buffer
+/// \param scc : second matrix; handled exactly like hcc
+///
+/// NOTE(review): MPI_Reduce only fills the receive buffer on the root. All other ranks copy
+/// back a zero-initialized buffer - confirm that only rank 0 reads hcc/scc afterwards.
+template <typename T>
+void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm)
+{
+    using syncmem_op
+        = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;
+
+    const int offset = nbase * nbase_x;
+    const int count = notconv * nbase_x;
+
+    // std::vector owns the staging buffers, so both are released on every exit path,
+    // including the case where the second allocation throws.
+    std::vector<T> send(count);
+    std::vector<T> recv(count);
+
+    // hcc and scc go through the identical sequence: stage -> reduce -> copy back.
+    T* const blocks[2] = {hcc + offset, scc + offset};
+    for (T* block : blocks)
+    {
+        syncmem_op()(cpu_ctx, cpu_ctx, send.data(), block, count);
+        MPI_Reduce(send.data(),
+                   recv.data(),
+                   count,
+                   dsp_mpi_datatype(block), // chosen by overload, no runtime string compare
+                   MPI_SUM,
+                   0,
+                   diag_comm);
+        syncmem_op()(cpu_ctx, cpu_ctx, block, recv.data(), count);
+    }
+}
+
+
 #endif
 #endif
\ No newline at end of file
diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp
index 00c4a36ad7..f989924d30 100644
--- a/source/module_base/module_device/memory_op.cpp
+++ b/source/module_base/module_device/memory_op.cpp
@@ -346,5 +346,57 @@ template struct delete_memory_op, base_device::DEVICE_GPU>;
 template struct delete_memory_op, base_device::DEVICE_GPU>;
 #endif
 
+#ifdef __DSP
+
+/// CPU specialization: frees a non-null arr with free_ht, then allocates size elements with malloc_ht.
+template <typename FPTYPE>
+struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
+{
+    /// \param dev : unused in this specialization
+    /// \param arr : released first when non-null, then set to the new buffer
+    /// \param size : number of FPTYPE elements to allocate
+    /// \param record_in : label for ModuleBase::Memory::record; nullptr or "no_record" skips it
+    void operator()(const base_device::DEVICE_CPU* dev, FPTYPE*& arr, const size_t size, const char* record_in)
+    {
+        if (arr != nullptr)
+        {
+            free_ht(arr);
+        }
+        arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK);
+        // NOTE(review): the pointer returned by malloc_ht is used unchecked.
+
+        // A missing label is treated like the explicit "no_record" label.
+        const std::string record_string = (record_in != nullptr) ? record_in : "no_record";
+        if (record_string != "no_record")
+        {
+            ModuleBase::Memory::record(record_string, sizeof(FPTYPE) * size);
+        }
+    }
+};
+
+/// CPU specialization: releases arr with free_ht.
+template <typename FPTYPE>
+struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
+{
+    void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
+    {
+        free_ht(arr);
+    }
+};
+
+// Explicit instantiations of the CPU specializations.
+template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
+template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
+template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
+template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
+template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
+
+template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
+template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
+template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
+template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
+template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
+#endif
+
 } // namespace memory
 } // namespace base_device
\ No newline at end of file
diff --git a/source/module_base/module_device/memory_op.h b/source/module_base/module_device/memory_op.h
index d1b4995937..49ca788d0a 100644
--- a/source/module_base/module_device/memory_op.h
+++ b/source/module_base/module_device/memory_op.h
@@ -146,6 +146,36 @@ struct delete_memory_op
 };
 #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
 
+#ifdef __DSP
+
+template <typename FPTYPE, typename Device>
+struct resize_memory_op_mt
+{
+    /// @brief Allocate memory for a given pointer. Note this op frees a non-null pointer first.
+    ///
+    /// Input Parameters
+    /// \param dev : the type of computing device
+    /// \param size : array size, counted in elements of FPTYPE
+    /// \param record_in : label for memory record; nullptr (the default) records nothing
+    ///
+    /// Input/Output Parameters
+    /// \param arr : previous array (freed when non-null) on input, allocated array on output
+    void operator()(const Device* dev, FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
+};
+
+template <typename FPTYPE, typename Device>
+struct delete_memory_op_mt
+{
+    /// @brief Free memory that was allocated by resize_memory_op_mt
+    ///
+    /// Input Parameters
+    /// \param dev : the type of computing device
+    /// \param arr : the array to free
+    void operator()(const Device* dev, FPTYPE* arr);
+};
+
+#endif // __DSP
+
 } // end of namespace memory
 } // end of namespace base_device
 
@@ -233,5 +263,4 @@ using castmem_z2c_d2h_op = base_device::memory::
 static base_device::DEVICE_CPU* cpu_ctx = {};
 static base_device::DEVICE_GPU* gpu_ctx = {};
 
-
 #endif // MODULE_DEVICE_MEMORY_H_
\ No newline at end of file
diff --git a/source/module_base/module_device/types.h b/source/module_base/module_device/types.h
index 153b6ab8ca..81413006f4 100644
--- a/source/module_base/module_device/types.h
+++ b/source/module_base/module_device/types.h
@@ -6,6 +6,7 @@ namespace base_device
 
 struct DEVICE_CPU;
 struct DEVICE_GPU;
+struct DEVICE_DSP;
 
 enum AbacusDevice_t
 {
diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp
index c1724eb136..d11d7093f1 100644
--- a/source/module_hsolver/diago_dav_subspace.cpp
+++ b/source/module_hsolver/diago_dav_subspace.cpp
@@ -6,6 +6,7 @@
 #include "module_base/timer.h"
 #include "module_hsolver/kernels/dngvd_op.h"
 #include "module_hsolver/kernels/math_kernel_op.h"
+#include "module_base/kernels/dsp/dsp_connector.h"
 
 #include 
 
@@ -182,7 +183,7 @@ int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func,
     setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);
 
 #ifdef __DSP
-    gemm_op_mt()
+    gemm_op_mt() // Used here to avoid writing another whole template and to keep the code change minimal.
#else gemm_op() #endif @@ -444,7 +445,12 @@ void Diago_DavSubspace::cal_elem(const int& dim, #ifdef __MPI if (this->diag_comm.nproc > 1) { +#ifdef __DSP + // Only on dsp hardware need an extra space to reduce data + dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); +#else auto* swap = new T[notconv * this->nbase_x]; + syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x); if (std::is_same::value) @@ -499,6 +505,7 @@ void Diago_DavSubspace::cal_elem(const int& dim, } } delete[] swap; +#endif } #endif diff --git a/source/module_hsolver/diago_dav_subspace.h b/source/module_hsolver/diago_dav_subspace.h index dc1e6c79a6..d93c252602 100644 --- a/source/module_hsolver/diago_dav_subspace.h +++ b/source/module_hsolver/diago_dav_subspace.h @@ -139,12 +139,22 @@ class Diago_DavSubspace bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf); +#ifdef __DSP + using resmem_complex_op = base_device::memory::resize_memory_op_mt; + using delmem_complex_op = base_device::memory::delete_memory_op_mt; +#else using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; +#endif using setmem_complex_op = base_device::memory::set_memory_op; +#ifdef __DSP + using resmem_real_op = base_device::memory::resize_memory_op_mt; + using delmem_real_op = base_device::memory::delete_memory_op_mt; +#else using resmem_real_op = base_device::memory::resize_memory_op; using delmem_real_op = base_device::memory::delete_memory_op; +#endif using setmem_real_op = base_device::memory::set_memory_op; using resmem_real_h_op = base_device::memory::resize_memory_op; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 39956321fd..283c641204 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -143,10 +143,16 @@ class Psi bool allocate_inside = true; ///; +#ifdef __DSP + using delete_memory_op = 
base_device::memory::delete_memory_op_mt; + using resize_memory_op = base_device::memory::resize_memory_op_mt; +#else using delete_memory_op = base_device::memory::delete_memory_op; using resize_memory_op = base_device::memory::resize_memory_op; +#endif + using set_memory_op = base_device::memory::set_memory_op; using synchronize_memory_op = base_device::memory::synchronize_memory_op; + }; } // end of namespace psi