Skip to content

Commit

Permalink
Refactor: make get_S a new esolver (#5515)
Browse files Browse the repository at this point in the history
* Refactor: make get_S a new esolver

* add head files

* add esolver_gets
  • Loading branch information
YuLiu98 authored Nov 19, 2024
1 parent f68dc9f commit da1baf5
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 191 deletions.
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
dpks_cal_e_delta_band.o\
set_matrix_grid.o\
lcao_before_scf.o\
lcao_gets.o\
esolver_gets.o\
lcao_others.o\
lcao_init_after_vc.o\

Expand Down
5 changes: 4 additions & 1 deletion source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@ void Driver::driver_run() {
Relax_Driver rl_driver;
rl_driver.relax_driver(p_esolver);
}
else if (cal_type == "get_S")
{
p_esolver->runner(0, GlobalC::ucell);
}
else
{
//! supported "other" functions:
//! get_pchg(LCAO),
//! test_memory(PW,LCAO),
//! test_neighbour(LCAO),
//! get_S(LCAO),
//! gen_bessel(PW), et al.
const int istep = 0;
p_esolver->others(istep);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ if(ENABLE_LCAO)
dpks_cal_e_delta_band.cpp
set_matrix_grid.cpp
lcao_before_scf.cpp
lcao_gets.cpp
esolver_gets.cpp
lcao_others.cpp
lcao_init_after_vc.cpp
)
Expand Down
48 changes: 35 additions & 13 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
#include "module_base/module_device/device.h"
#include "module_parameter/parameter.h"
#ifdef __LCAO
#include "esolver_ks_lcaopw.h"
#include "esolver_gets.h"
#include "esolver_ks_lcao.h"
#include "esolver_ks_lcao_tddft.h"
#include "esolver_ks_lcaopw.h"
#include "module_lr/esolver_lrtd_lcao.h"
extern "C"
{
Expand Down Expand Up @@ -188,18 +189,39 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
if (PARAM.globalv.gamma_only_local)
{
return new ESolver_KS_LCAO<double, double>();
}
else if (PARAM.inp.nspin < 4)
{
return new ESolver_KS_LCAO<std::complex<double>, double>();
}
else
{
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
}
else if (esolver_type == "ksdft_lcao_tddft")
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<double, double>();
}
else
{
return new ESolver_KS_LCAO<double, double>();
}
}
else if (PARAM.inp.nspin < 4)
{
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<std::complex<double>, double>();
}
else
{
return new ESolver_KS_LCAO<std::complex<double>, double>();
}
}
else
{
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<std::complex<double>, std::complex<double>>();
}
else
{
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
}
}
else if (esolver_type == "ksdft_lcao_tddft")
{
return new ESolver_KS_LCAO_TDDFT();
}
Expand Down
175 changes: 175 additions & 0 deletions source/module_esolver/esolver_gets.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include "esolver_gets.h"

#include "module_base/timer.h"
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_elecstate/elecstate_lcao.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h"
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.h"
#include "module_io/print_info.h"
#include "module_io/write_HS_R.h"

namespace ModuleESolver
{

template <typename TK, typename TR>
ESolver_GetS<TK, TR>::ESolver_GetS()
{
this->classname = "ESolver_GetS";
this->basisname = "LCAO";
}

template <typename TK, typename TR>
ESolver_GetS<TK, TR>::~ESolver_GetS()
{
}

template <typename TK, typename TR>
void ESolver_GetS<TK, TR>::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_GetS", "before_all_runners");
ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");

// 1.1) read pseudopotentials
ucell.read_pseudo(GlobalV::ofs_running);

// 1.2) symmetrize things
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
}

// 1.3) Setup k-points according to symmetry.
this->kv.set(ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

ModuleIO::setup_parameters(ucell, this->kv);

// 2) init ElecState
// autoset nbands in ElecState, it should before basis_init (for Psi 2d division)
if (this->pelec == nullptr)
{
// TK stands for double and complex<double>?
this->pelec = new elecstate::ElecStateLCAO<TK>(&(this->chr), // use which parameter?
&(this->kv),
this->kv.get_nks(),
&(this->GG), // mohan add 2024-04-01
&(this->GK), // mohan add 2024-04-01
this->pw_rho,
this->pw_big);
}

// 3) init LCAO basis
// reading the localized orbitals/projectors
// construct the interpolation tables.
LCAO_domain::init_basis_lcao(this->pv,
inp.onsite_radius,
inp.lcao_ecut,
inp.lcao_dk,
inp.lcao_dr,
inp.lcao_rmax,
ucell,
two_center_bundle_,
orb_);

// 4) initialize the density matrix
// DensityMatrix is allocated here, DMK is also initialized here
// DMR is not initialized here, it will be constructed in each before_scf
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin);

ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");
}

template <>
void ESolver_GetS<double, double>::runner(const int istep, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::WARNING_QUIT("ESolver_GetS<double, double>::runner", "not implemented");
}

template <>
void ESolver_GetS<std::complex<double>, std::complex<double>>::runner(const int istep, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::timer::tick("ESolver_GetS", "runner");

// (1) Find adjacent atoms for each atom.
double search_radius = -1.0;
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
PARAM.inp.out_level,
orb_.get_rcutmax_Phi(),
ucell.infoNL.get_rcutmax_Beta(),
PARAM.globalv.gamma_only_local);

atom_arrange::search(PARAM.inp.search_pbc,
GlobalV::ofs_running,
GlobalC::GridD,
ucell,
search_radius,
PARAM.inp.test_atom_input);

this->RA.for_2d(this->pv, PARAM.globalv.gamma_only_local, orb_.cutoffs());

if (this->p_hamilt == nullptr)
{
this->p_hamilt
= new hamilt::HamiltLCAO<std::complex<double>, std::complex<double>>(&this->pv,
this->kv,
*(two_center_bundle_.overlap_orb),
orb_.cutoffs());
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(this->p_hamilt->ops)
->contributeHR();
}

const std::string fn = PARAM.globalv.global_out_dir + "SR.csr";
std::cout << " The file is saved in " << fn << std::endl;
ModuleIO::output_SR(pv, GlobalC::GridD, this->p_hamilt, fn);

ModuleBase::timer::tick("ESolver_GetS", "runner");
}

template <>
void ESolver_GetS<std::complex<double>, double>::runner(const int istep, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::timer::tick("ESolver_GetS", "runner");

// (1) Find adjacent atoms for each atom.
double search_radius = -1.0;
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
PARAM.inp.out_level,
orb_.get_rcutmax_Phi(),
ucell.infoNL.get_rcutmax_Beta(),
PARAM.globalv.gamma_only_local);

atom_arrange::search(PARAM.inp.search_pbc,
GlobalV::ofs_running,
GlobalC::GridD,
ucell,
search_radius,
PARAM.inp.test_atom_input);

this->RA.for_2d(this->pv, PARAM.globalv.gamma_only_local, orb_.cutoffs());

if (this->p_hamilt == nullptr)
{
this->p_hamilt = new hamilt::HamiltLCAO<std::complex<double>, double>(&this->pv,
this->kv,
*(two_center_bundle_.overlap_orb),
orb_.cutoffs());
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(this->p_hamilt->ops)->contributeHR();
}

const std::string fn = PARAM.globalv.global_out_dir + "SR.csr";
std::cout << " The file is saved in " << fn << std::endl;
ModuleIO::output_SR(pv, GlobalC::GridD, this->p_hamilt, fn);

ModuleBase::timer::tick("ESolver_GetS", "runner");
}

template class ESolver_GetS<double, double>;
template class ESolver_GetS<std::complex<double>, double>;
template class ESolver_GetS<std::complex<double>, std::complex<double>>;

} // namespace ModuleESolver
55 changes: 55 additions & 0 deletions source/module_esolver/esolver_gets.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef ESOLVER_GETS_H
#define ESOLVER_GETS_H

#include "module_basis/module_nao/two_center_bundle.h"
#include "module_cell/unitcell.h"
#include "module_esolver/esolver_ks.h"
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
#include "module_hamilt_lcao/module_gint/gint_k.h"

#include <memory>

namespace ModuleESolver
{
template <typename TK, typename TR>
class ESolver_GetS : public ESolver_KS<TK>
{
public:
ESolver_GetS();
~ESolver_GetS();

void before_all_runners(const Input_para& inp, UnitCell& ucell) override;

void after_all_runners() {};

void runner(const int istep, UnitCell& ucell) override;

//! calculate total energy of a given system
double cal_energy() {};

//! calcualte forces for the atoms in the given cell
void cal_force(ModuleBase::matrix& force) {};

//! calcualte stress of given cell
void cal_stress(ModuleBase::matrix& stress) {};

protected:
// we will get rid of this class soon, don't use it, mohan 2024-03-28
Record_adj RA;

// 2d block - cyclic distribution info
Parallel_Orbitals pv;

// used for k-dependent grid integration.
Gint_k GK;

// used for gamma only algorithms.
Gint_Gamma GG;

TwoCenterBundle two_center_bundle_;

// // temporary introduced during removing GlobalC::ORB
LCAO_Orbitals orb_;
};
} // namespace ModuleESolver
#endif
32 changes: 1 addition & 31 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
ModuleBase::TITLE("ESolver_KS_LCAO", "before_all_runners");
ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners");

// 1) calculate overlap matrix S
if (PARAM.inp.calculation == "get_S")
{
// 1.1) read pseudopotentials
ucell.read_pseudo(GlobalV::ofs_running);

// 1.2) symmetrize things
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
}

// 1.3) Setup k-points according to symmetry.
this->kv.set(ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

ModuleIO::setup_parameters(ucell, this->kv);
}
else
{
// 1) else, call before_all_runners() in ESolver_KS
ESolver_KS<TK>::before_all_runners(inp, ucell);
} // end ifnot get_S
ESolver_KS<TK>::before_all_runners(inp, ucell);

// 2) init ElecState
// autoset nbands in ElecState, it should before basis_init (for Psi 2d division)
Expand Down Expand Up @@ -179,13 +156,6 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// DMR is not initialized here, it will be constructed in each before_scf
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin);

// this function should be removed outside of the function in near future
if (PARAM.inp.calculation == "get_S")
{
ModuleBase::timer::tick("ESolver_KS_LCAO", "init");
return;
}

// 5) initialize Hamilt in LCAO
// * allocate H and S matrices according to computational resources
// * set the 'trace' between local H/S and global H/S
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {

void after_all_runners() override;

void get_S();

protected:
virtual void before_scf(const int istep) override;

Expand Down
Loading

0 comments on commit da1baf5

Please sign in to comment.