Skip to content

Commit

Permalink
Feature: add interface Gint::psir_func (#5380)
Browse files Browse the repository at this point in the history
  • Loading branch information
PeizeLin authored Nov 7, 2024
1 parent 01beb19 commit 3da2868
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 116 deletions.
63 changes: 29 additions & 34 deletions source/module_base/array_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,51 @@ namespace ModuleBase
class Array_Pool
{
public:
Array_Pool();
Array_Pool(const int nr, const int nc);
Array_Pool() = default;
Array_Pool(const int nr_in, const int nc_in);
Array_Pool(Array_Pool<T>&& other);
Array_Pool& operator=(Array_Pool<T>&& other);
~Array_Pool();
Array_Pool(const Array_Pool<T>& other) = delete;
Array_Pool& operator=(const Array_Pool& other) = delete;

T** get_ptr_2D() const { return ptr_2D; }
T* get_ptr_1D() const { return ptr_1D; }
int get_nr() const { return nr; }
int get_nc() const { return nc; }
T* operator[](const int ir) const { return ptr_2D[ir]; }
T** get_ptr_2D() const { return this->ptr_2D; }
T* get_ptr_1D() const { return this->ptr_1D; }
int get_nr() const { return this->nr; }
int get_nc() const { return this->nc; }
T* operator[](const int ir) const { return this->ptr_2D[ir]; }
private:
T** ptr_2D;
T* ptr_1D;
int nr;
int nc;
T** ptr_2D = nullptr;
T* ptr_1D = nullptr;
int nr = 0;
int nc = 0;
};

template <typename T>
Array_Pool<T>::Array_Pool() : ptr_2D(nullptr), ptr_1D(nullptr), nr(0), nc(0)
Array_Pool<T>::Array_Pool(const int nr_in, const int nc_in) // Attention: uninitialized
: nr(nr_in),
nc(nc_in)
{
}

template <typename T>
Array_Pool<T>::Array_Pool(const int nr, const int nc) // Attention: uninitialized
{
this->nr = nr;
this->nc = nc;
ptr_1D = new T[nr * nc];
ptr_2D = new T*[nr];
this->ptr_1D = new T[nr * nc];
this->ptr_2D = new T*[nr];
for (int ir = 0; ir < nr; ++ir)
ptr_2D[ir] = &ptr_1D[ir * nc];
this->ptr_2D[ir] = &this->ptr_1D[ir * nc];
}

template <typename T>
Array_Pool<T>::~Array_Pool()
{
delete[] ptr_2D;
delete[] ptr_1D;
delete[] this->ptr_2D;
delete[] this->ptr_1D;
}

template <typename T>
Array_Pool<T>::Array_Pool(Array_Pool<T>&& other)
: ptr_2D(other.ptr_2D),
ptr_1D(other.ptr_1D),
nr(other.nr),
nc(other.nc)
{
ptr_2D = other.ptr_2D;
ptr_1D = other.ptr_1D;
nr = other.nr;
nc = other.nc;
other.ptr_2D = nullptr;
other.ptr_1D = nullptr;
other.nr = 0;
Expand All @@ -76,12 +71,12 @@ namespace ModuleBase
{
if (this != &other)
{
delete[] ptr_2D;
delete[] ptr_1D;
ptr_2D = other.ptr_2D;
ptr_1D = other.ptr_1D;
nr = other.nr;
nc = other.nc;
delete[] this->ptr_2D;
delete[] this->ptr_1D;
this->ptr_2D = other.ptr_2D;
this->ptr_1D = other.ptr_1D;
this->nr = other.nr;
this->nc = other.nc;
other.ptr_2D = nullptr;
other.ptr_1D = nullptr;
other.nr = 0;
Expand Down
3 changes: 2 additions & 1 deletion source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#include "module_base/ylm.h"
namespace Gint_Tools{
void cal_psir_ylm(
const Grid_Technique& gt, const int bxyz,
const Grid_Technique& gt,
const int bxyz,
const int na_grid, // number of atoms on this grid
const int grid_index, // 1d index of FFT index (i,j,k)
const double delta_r, // delta_r of the uniform FFT grid
Expand Down
51 changes: 35 additions & 16 deletions source/module_hamilt_lcao/module_gint/gint.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_hamilt_lcao/module_gint/grid_technique.h"
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"

#include <functional>

class Gint {
public:
~Gint();
Expand Down Expand Up @@ -64,6 +67,21 @@ class Gint {
const Grid_Technique* gridt = nullptr;
const UnitCell* ucell;

// psir_ylm_new = psir_func(psir_ylm)
// psir_func==nullptr means psir_ylm_new=psir_ylm
using T_psir_func = std::function<
const ModuleBase::Array_Pool<double>&(
const ModuleBase::Array_Pool<double> &psir_ylm,
const Grid_Technique &gt,
const int grid_index,
const int is,
const std::vector<int> &block_iw,
const std::vector<int> &block_size,
const std::vector<int> &block_index,
const ModuleBase::Array_Pool<bool> &cal_flag)>;
T_psir_func psir_func_1 = nullptr;
T_psir_func psir_func_2 = nullptr;

protected:
// variables related to FFT grid
int nbx;
Expand Down Expand Up @@ -152,17 +170,18 @@ class Gint {
hamilt::HContainer<double>* hR); // HContainer for storing the <phi_0 |
// V | phi_R> matrix element.

void cal_meshball_vlocal_k(int na_grid,
const int LD_pool,
int grid_index,
int* block_size,
int* block_index,
int* block_iw,
bool** cal_flag,
double** psir_ylm,
double** psir_vlbr3,
double* pvpR,
const UnitCell& ucell);
void cal_meshball_vlocal_k(
const int na_grid,
const int LD_pool,
const int grid_index,
const int*const block_size,
const int*const block_index,
const int*const block_iw,
const bool*const*const cal_flag,
const double*const*const psir_ylm,
const double*const*const psir_vlbr3,
double*const pvpR,
const UnitCell &ucell);

//------------------------------------------------------
// in gint_fvl.cpp
Expand Down Expand Up @@ -225,11 +244,11 @@ class Gint {
Gint_inout* inout);

void cal_meshball_rho(const int na_grid,
int* block_index,
int* vindex,
double** psir_ylm,
double** psir_DMR,
double* rho);
const int*const block_index,
const int*const vindex,
const double*const*const psir_ylm,
const double*const*const psir_DMR,
double*const rho);

void gint_kernel_tau(const int na_grid,
const int grid_index,
Expand Down
12 changes: 6 additions & 6 deletions source/module_hamilt_lcao/module_gint/gint_rho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
#include "module_hamilt_pw/hamilt_pwdft/global.h"

void Gint::cal_meshball_rho(const int na_grid,
int* block_index,
int* vindex,
double** psir_ylm,
double** psir_DMR,
double* rho)
const int*const block_index,
const int*const vindex,
const double*const*const psir_ylm,
const double*const*const psir_DMR,
double*const rho)
{
const int inc = 1;
// sum over mu to get density on grid
for (int ib = 0; ib < this->bxyz; ++ib)
{
double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
const double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
const int grid = vindex[ib];
rho[grid] += r;
}
Expand Down
37 changes: 22 additions & 15 deletions source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
const int ncyz = this->ny * this->nplane;
const double delta_r = this->gridt->dr_uniform;

#pragma omp parallel
#pragma omp parallel
{
std::vector<int> block_iw(max_size, 0);
std::vector<int> block_index(max_size+1, 0);
std::vector<int> block_size(max_size, 0);
std::vector<int> vindex(bxyz, 0);
std::vector<int> vindex(this->bxyz, 0);
#pragma omp for
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
{
const int na_grid = this->gridt->how_many_atoms[grid_index];
if (na_grid == 0) {
continue;
Expand All @@ -41,7 +42,7 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
block_size.data(),
cal_flag.get_ptr_2D());

// evaluate psi on grids
// evaluate psi on grids
const int LD_pool = block_index[na_grid];
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
Gint_Tools::cal_psir_ylm(*this->gridt,
Expand All @@ -56,6 +57,11 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {

for (int is = 0; is < inout->nspin_rho; ++is)
{
// psir_ylm_new = psir_func(psir_ylm)
// psir_func==nullptr means psir_ylm_new=psir_ylm
const ModuleBase::Array_Pool<double> &psir_ylm_1 = (!this->psir_func_1) ? psir_ylm : this->psir_func_1(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);
const ModuleBase::Array_Pool<double> &psir_ylm_2 = (!this->psir_func_2) ? psir_ylm : this->psir_func_2(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);

ModuleBase::Array_Pool<double> psir_DM(this->bxyz, LD_pool);
ModuleBase::GlobalFunc::ZEROS(psir_DM.get_ptr_1D(), this->bxyz * LD_pool);

Expand All @@ -68,13 +74,13 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
block_index.data(),
block_size.data(),
cal_flag.get_ptr_2D(),
psir_ylm.get_ptr_2D(),
psir_ylm_1.get_ptr_2D(),
psir_DM.get_ptr_2D(),
this->DMRGint[is],
inout->if_symm);

// do sum_mu g_mu(r)psi_mu(r) to get electron density on grid
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm_2.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
}
}
}
Expand All @@ -90,14 +96,15 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
const double delta_r = this->gridt->dr_uniform;


#pragma omp parallel
#pragma omp parallel
{
std::vector<int> block_iw(max_size, 0);
std::vector<int> block_index(max_size+1, 0);
std::vector<int> block_size(max_size, 0);
std::vector<int> vindex(bxyz, 0);
#pragma omp for
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
{
const int na_grid = this->gridt->how_many_atoms[grid_index];
if (na_grid == 0) {
continue;
Expand All @@ -112,19 +119,19 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
vindex.data());
//prepare block information
ModuleBase::Array_Pool<bool> cal_flag(this->bxyz,max_size);
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
block_iw.data(), block_index.data(), block_size.data(), cal_flag.get_ptr_2D());

//evaluate psi and dpsi on grids
//evaluate psi and dpsi on grids
const int LD_pool = block_index[na_grid];
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_x(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_y(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_z(this->bxyz, LD_pool);

Gint_Tools::cal_dpsir_ylm(*this->gridt,
Gint_Tools::cal_dpsir_ylm(*this->gridt,
this->bxyz, na_grid, grid_index, delta_r,
block_index.data(), block_size.data(),
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
psir_ylm.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(),
Expand All @@ -146,7 +153,7 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
LD_pool,
grid_index, na_grid,
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
cal_flag.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(),
dpsix_DM.get_ptr_2D(),
this->DMRGint[is],
Expand All @@ -166,13 +173,13 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
LD_pool,
grid_index, na_grid,
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
cal_flag.get_ptr_2D(),
dpsir_ylm_z.get_ptr_2D(),
dpsiz_DM.get_ptr_2D(),
this->DMRGint[is],
true);

//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
if(inout->job==Gint_Tools::job_type::tau)
{
this->cal_meshball_tau(
Expand Down
25 changes: 13 additions & 12 deletions source/module_hamilt_lcao/module_gint/gint_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,19 @@ ModuleBase::Array_Pool<double> get_psir_vlbr3(
const double* const* const psir_ylm); // psir_ylm[bxyz][LD_pool]

// sum_nu,R rho_mu,nu(R) psi_nu, for multi-k and gamma point
void mult_psi_DMR(const Grid_Technique& gt,
const int bxyz,
const int LD_pool,
const int& grid_index,
const int& na_grid,
const int* const block_index,
const int* const block_size,
bool** cal_flag,
double** psi,
double** psi_DMR,
const hamilt::HContainer<double>* DM,
const bool if_symm);
void mult_psi_DMR(
const Grid_Technique& gt,
const int bxyz,
const int LD_pool,
const int &grid_index,
const int &na_grid,
const int*const block_index,
const int*const block_size,
const bool*const*const cal_flag,
const double*const*const psi,
double*const*const psi_DMR,
const hamilt::HContainer<double>*const DM,
const bool if_symm);


// pair.first is the first index of the meshcell which is inside atoms ia1 and ia2.
Expand Down
Loading

0 comments on commit 3da2868

Please sign in to comment.