Skip to content

Commit

Permalink
Feature: Optimized memory management on DSP (#5361)
Browse files Browse the repository at this point in the history
* Initial commit

* Change memory_op construction

* I finally find this

* Fix template bug

* Fix memory header definition

* Optimize memory op usage

* Update diago_subspace

* No change

* Fix MPI Error

* Make the extra memory usage DSP-hardware-specialized. Add some annotations.

* Reorganize dsp codes

* Fix bug 1

* Fix bug 2

* Finish transporting codes

---------

Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
Critsium-xy and mohanchen authored Nov 5, 2024
1 parent 535c7f9 commit ddf990f
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 7 deletions.
8 changes: 4 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand Down
65 changes: 65 additions & 0 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define DSP_CONNECTOR_H
#ifdef __DSP

#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include "module_hsolver/diag_comm_info.h"

// Base dsp functions
void dspInitHandle(int id);
void dspDestoryHandle(int id);
Expand Down Expand Up @@ -62,5 +66,66 @@ void cgemm_mth_(const char *transa, const char *transb,

//#define zgemm_ zgemm_mt

// The next is dsp utils. It may be moved to other files if this file get too huge

template <typename T>
void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm){

using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;

auto* swap = new T[notconv * nbase_x];
auto* target = new T[notconv * nbase_x];
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
if (base_device::get_current_precision(swap) == "single")
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_COMPLEX,
MPI_SUM,
0,
diag_comm);
}
else
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_DOUBLE_COMPLEX,
MPI_SUM,
0,
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);

if (base_device::get_current_precision(swap) == "single")
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_COMPLEX,
MPI_SUM,
0,
diag_comm);
}
else
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_DOUBLE_COMPLEX,
MPI_SUM,
0,
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
delete[] swap;
delete[] target;
}


#endif
#endif
52 changes: 52 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,5 +346,57 @@ template struct delete_memory_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct delete_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
#endif

#ifdef __DSP

template <typename FPTYPE>
struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE*& arr, const size_t size, const char* record_in)
{
if (arr != nullptr)
{
free_ht(arr);
}
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK);
std::string record_string;
if (record_in != nullptr)
{
record_string = record_in;
}
else
{
record_string = "no_record";
}

if (record_string != "no_record")
{
ModuleBase::Memory::record(record_string, sizeof(FPTYPE) * size);
}
}
};

template <typename FPTYPE>
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
free_ht(arr);
}
};


template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
#endif

} // namespace memory
} // namespace base_device
31 changes: 30 additions & 1 deletion source/module_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,36 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_GPU>
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

#ifdef __DSP

template <typename FPTYPE, typename Device>
struct resize_memory_op_mt
{
/// @brief Allocate memory for a given pointer. Note this op will free the pointer first.
///
/// Input Parameters
/// \param dev : the type of computing device
/// \param size : array size
/// \param record_string : label for memory record
///
/// Output Parameters
/// \param arr : allocated array
void operator()(const Device* dev, FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
};

template <typename FPTYPE, typename Device>
struct delete_memory_op_mt
{
/// @brief free memory for multi-device
///
/// Input Parameters
/// \param dev : the type of computing device
/// \param arr : the input array
void operator()(const Device* dev, FPTYPE* arr);
};

#endif // __DSP

} // end of namespace memory
} // end of namespace base_device

Expand Down Expand Up @@ -233,5 +263,4 @@ using castmem_z2c_d2h_op = base_device::memory::

static base_device::DEVICE_CPU* cpu_ctx = {};
static base_device::DEVICE_GPU* gpu_ctx = {};

#endif // MODULE_DEVICE_MEMORY_H_
1 change: 1 addition & 0 deletions source/module_base/module_device/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace base_device

struct DEVICE_CPU;
struct DEVICE_GPU;
struct DEVICE_DSP;

enum AbacusDevice_t
{
Expand Down
9 changes: 8 additions & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "module_base/timer.h"
#include "module_hsolver/kernels/dngvd_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_base/kernels/dsp/dsp_connector.h"

#include <vector>

Expand Down Expand Up @@ -182,7 +183,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

#ifdef __DSP
gemm_op_mt<T, Device>()
gemm_op_mt<T, Device>() // In order to not coding another whole template, using this method to minimize the code change.
#else
gemm_op<T, Device>()
#endif
Expand Down Expand Up @@ -444,7 +445,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
#ifdef __MPI
if (this->diag_comm.nproc > 1)
{
#ifdef __DSP
// Only on dsp hardware need an extra space to reduce data
dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm);
#else
auto* swap = new T[notconv * this->nbase_x];

syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);

if (std::is_same<T, double>::value)
Expand Down Expand Up @@ -499,6 +505,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
}
}
delete[] swap;
#endif
}
#endif

Expand Down
10 changes: 10 additions & 0 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,22 @@ class Diago_DavSubspace

bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf);

#ifdef __DSP
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
#else
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
#endif
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;

#ifdef __DSP
using resmem_real_op = base_device::memory::resize_memory_op_mt<Real, Device>;
using delmem_real_op = base_device::memory::delete_memory_op_mt<Real, Device>;
#else
using resmem_real_op = base_device::memory::resize_memory_op<Real, Device>;
using delmem_real_op = base_device::memory::delete_memory_op<Real, Device>;
#endif
using setmem_real_op = base_device::memory::set_memory_op<Real, Device>;

using resmem_real_h_op = base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>;
Expand Down
8 changes: 7 additions & 1 deletion source/module_psi/psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ class Psi

bool allocate_inside = true; ///<whether allocate psi inside Psi class

using set_memory_op = base_device::memory::set_memory_op<T, Device>;
#ifdef __DSP
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
using resize_memory_op = base_device::memory::resize_memory_op_mt<T, Device>;
#else
using delete_memory_op = base_device::memory::delete_memory_op<T, Device>;
using resize_memory_op = base_device::memory::resize_memory_op<T, Device>;
#endif
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, Device, Device>;

};

} // end of namespace psi
Expand Down

0 comments on commit ddf990f

Please sign in to comment.