Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: refactor psi init & wfinit class #5533

Merged
merged 4 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,16 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
// does is only to initialize for once...
if (((PARAM.inp.init_wfc == "random") && (istep == 0)) || (PARAM.inp.init_wfc != "random"))
{
this->p_wf_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
this->p_wf_init->initialize_psi(this->psi,
this->kspw_psi,
this->p_hamilt,
GlobalV::ofs_running,
this->already_initpsi);

if (this->already_initpsi == false)
{
this->already_initpsi = true;
}
}
}

Expand Down Expand Up @@ -359,27 +368,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
}
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

//---------------------------------------------------------------------------------------------------------------
//---------------------------------for psi init guess!!!!--------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
{
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
{
//! Update Hamiltonian from other kpoint to the given one
this->p_hamilt->updateHk(ik);

//! Fix the wavefunction to initialize at given kpoint
this->kspw_psi->fix_k(ik);

/// for psi init guess!!!!
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
}
}
//---------------------------------------------------------------------------------------------------------------
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------

hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
PARAM.inp.calculation,
PARAM.inp.basis_type,
Expand All @@ -400,8 +388,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
GlobalV::NPROC_IN_POOL,
skip_charge);

this->init_psi = true;

Symmetry_rho srho;
for (int is = 0; is < PARAM.inp.nspin; is++)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>

psi::Psi<std::complex<double>, Device>* __kspw_psi = nullptr;

bool init_psi = false;
bool already_initpsi = false;

using castmem_2d_d2h_op
= base_device::memory::cast_memory_op<std::complex<double>, T, base_device::DEVICE_CPU, Device>;
Expand Down
27 changes: 0 additions & 27 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,32 +190,6 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;

//---------------------------------------------------------------------------------------------------------------
//---------------------------------for psi init guess!!!!--------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------
if (!PARAM.inp.psi_initializer && PARAM.inp.basis_type == "pw" && this->init_psi == false)
{
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
{
//! Update Hamiltonian from other kpoint to the given one
this->p_hamilt->updateHk(ik);

if (this->kspw_psi->get_nbands() > 0 && GlobalV::MY_STOGROUP == 0)
{
//! Fix the wavefunction to initialize at given kpoint
this->kspw_psi->fix_k(ik);

/// for psi init guess!!!!
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(this->kspw_psi), this->pw_wfc, &this->wf, this->p_hamilt);
}

}
}
//---------------------------------------------------------------------------------------------------------------
//---------------------------------END: for psi init guess!!!!--------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------


// hsolver only exists in this function
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(&this->kv,
this->pw_wfc,
Expand All @@ -242,7 +216,6 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
istep,
iter,
skip_charge);
this->init_psi = true;

// set_diagethr need it
this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne;
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/pw_init_after_vc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void ESolver_KS_PW<T, Device>::init_after_vc(const Input_para& inp, UnitCell& uc
this->pw_wfc->collect_local_pw(inp.erf_ecut,
inp.erf_height,
inp.erf_sigma);
this->init_psi = false;
this->already_initpsi = false;

delete this->pelec;
this->pelec
Expand Down
42 changes: 29 additions & 13 deletions source/module_hamilt_pw/hamilt_pwdft/wfinit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ template <typename T, typename Device>
void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
psi::Psi<T, Device>* kspw_psi,
hamilt::Hamilt<T, Device>* p_hamilt,
std::ofstream& ofs_running)
std::ofstream& ofs_running,
const bool is_already_initpsi)
{
ModuleBase::timer::tick("WFInit", "initialize_psi");

Expand Down Expand Up @@ -254,20 +255,35 @@ void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
else
{
// if (PARAM.inp.basis_type == "pw")
// {
// for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
// {
// //! Update Hamiltonian from other kpoint to the given one
// p_hamilt->updateHk(ik);
//! note: is_already_initpsi will be false in init_after_vc when vc changes.
if (PARAM.inp.basis_type == "pw" && is_already_initpsi == false)
{
for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
{
//! Update Hamiltonian from other kpoint to the given one
p_hamilt->updateHk(ik);

// //! Fix the wavefunction to initialize at given kpoint
// kspw_psi->fix_k(ik);
if (PARAM.inp.esolver_type == "sdft")
{
if (kspw_psi->get_nbands() > 0 && GlobalV::MY_STOGROUP == 0)
{
//! Fix the wavefunction to initialize at given kpoint
kspw_psi->fix_k(ik);

// /// for psi init guess!!!!
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *kspw_psi, this->pw_wfc, this->p_wf, p_hamilt);
// }
// }
/// for psi init guess!!!!
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(kspw_psi), this->pw_wfc, this->p_wf, p_hamilt);
}
}
else
{
//! Fix the wavefunction to initialize at given kpoint
kspw_psi->fix_k(ik);

/// for psi init guess!!!!
hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *(kspw_psi), this->pw_wfc, this->p_wf, p_hamilt);
}
}
}
}

ModuleBase::timer::tick("WFInit", "initialize_psi");
Expand Down
6 changes: 4 additions & 2 deletions source/module_hamilt_pw/hamilt_pwdft/wfinit.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ class WFInit
*
* @param psi store the wavefunction
* @param p_hamilt Hamiltonian operator
* @param ofs_running output stream for running information
* @param ofs_running output stream for running information
* @param is_already_initpsi whether psi has been initialized
*/
void initialize_psi(Psi<std::complex<double>>* psi,
psi::Psi<T, Device>* kspw_psi,
hamilt::Hamilt<T, Device>* p_hamilt,
std::ofstream& ofs_running);
std::ofstream& ofs_running,
const bool is_already_initpsi);

/**
* @brief get the psi_initializer
Expand Down
Loading