From dcff74dbeb77d0763e7d84e7581f64157d9b736b Mon Sep 17 00:00:00 2001 From: Yu Liu <77716030+YuLiu98@users.noreply.github.com> Date: Wed, 13 Nov 2024 20:50:41 +0800 Subject: [PATCH] Refactor: move io_npz to ModuleIO (#5475) --- source/Makefile.Objects | 2 +- source/module_esolver/CMakeLists.txt | 1 - source/module_esolver/esolver_ks_lcao.cpp | 18 +- source/module_esolver/esolver_ks_lcao.h | 4 - source/module_esolver/lcao_before_scf.cpp | 5 +- source/module_io/CMakeLists.txt | 1 + .../{module_esolver => module_io}/io_npz.cpp | 212 ++++++++---------- source/module_io/io_npz.h | 23 ++ 8 files changed, 135 insertions(+), 131 deletions(-) rename source/{module_esolver => module_io}/io_npz.cpp (64%) create mode 100644 source/module_io/io_npz.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 63e47a75aa..ba4370a9f9 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -253,7 +253,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\ esolver_ks_lcao_tddft.o\ dpks_cal_e_delta_band.o\ dftu_cal_occup_m.o\ - io_npz.o\ set_matrix_grid.o\ lcao_before_scf.o\ lcao_gets.o\ @@ -557,6 +556,7 @@ OBJS_IO_LCAO=cal_r_overlap_R.o\ output_mulliken.o\ output_sk.o\ output_dmk.o\ + io_npz.o\ OBJS_LCAO=evolve_elec.o\ evolve_psi.o\ diff --git a/source/module_esolver/CMakeLists.txt b/source/module_esolver/CMakeLists.txt index 491bb95268..61968c7734 100644 --- a/source/module_esolver/CMakeLists.txt +++ b/source/module_esolver/CMakeLists.txt @@ -19,7 +19,6 @@ if(ENABLE_LCAO) esolver_ks_lcao.cpp esolver_ks_lcao_tddft.cpp dpks_cal_e_delta_band.cpp - io_npz.cpp set_matrix_grid.cpp dftu_cal_occup_m.cpp lcao_before_scf.cpp diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 78c424bfd0..069ec0e8c7 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -4,9 +4,12 @@ #include "module_base/global_variable.h" #include "module_base/tool_title.h" #include "module_elecstate/module_dm/cal_dm_psi.h" +#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h" #include "module_io/berryphase.h" #include "module_io/cube_io.h" #include "module_io/dos_nao.h" +#include "module_io/io_dmk.h" +#include "module_io/io_npz.h" #include "module_io/nscf_band.h" #include "module_io/output_dmk.h" #include "module_io/output_log.h" @@ -16,11 +19,13 @@ #include "module_io/to_wannier90_lcao.h" #include "module_io/to_wannier90_lcao_in_pw.h" #include "module_io/write_HS.h" +#include "module_io/write_dmr.h" #include "module_io/write_eband_terms.hpp" #include "module_io/write_elecstat_pot.h" #include "module_io/write_istate_info.h" #include "module_io/write_proj_band_lcao.h" #include "module_io/write_vxc.hpp" +#include "module_io/write_wfc_nao.h" #include "module_parameter/parameter.h" //--------------temporary---------------------------- @@ -55,11 +60,6 @@ // #include "module_elecstate/cal_dm.h" //--------------------------------------------------- -#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h" -#include "module_io/io_dmk.h" -#include "module_io/write_dmr.h" -#include "module_io/write_wfc_nao.h" - namespace ModuleESolver { @@ -1105,7 +1105,7 @@ void ESolver_KS_LCAO::after_scf(const int istep) hamilt::HamiltLCAO, double>* p_ham_lcao = dynamic_cast, double>*>(this->p_hamilt); std::string zipname = "output_HR0.npz"; - this->output_mat_npz(zipname, *(p_ham_lcao->getHR())); + ModuleIO::output_mat_npz(GlobalC::ucell, zipname, *(p_ham_lcao->getHR())); if (PARAM.inp.nspin == 2) { @@ -1113,7 +1113,7 @@ void ESolver_KS_LCAO::after_scf(const int istep) hamilt::HamiltLCAO, double>* p_ham_lcao = dynamic_cast, double>*>(this->p_hamilt); zipname = "output_HR1.npz"; - this->output_mat_npz(zipname, *(p_ham_lcao->getHR())); + ModuleIO::output_mat_npz(GlobalC::ucell, zipname, *(p_ham_lcao->getHR())); } } @@ -1123,12 +1123,12 @@ void ESolver_KS_LCAO::after_scf(const int istep) const elecstate::DensityMatrix* dm = dynamic_cast*>(this->pelec)->get_DM(); std::string zipname = "output_DM0.npz"; - this->output_mat_npz(zipname, *(dm->get_DMR_pointer(1))); + ModuleIO::output_mat_npz(GlobalC::ucell, zipname, *(dm->get_DMR_pointer(1))); if (PARAM.inp.nspin == 2) { zipname = "output_DM1.npz"; - this->output_mat_npz(zipname, *(dm->get_DMR_pointer(2))); + ModuleIO::output_mat_npz(GlobalC::ucell, zipname, *(dm->get_DMR_pointer(2))); } } diff --git a/source/module_esolver/esolver_ks_lcao.h b/source/module_esolver/esolver_ks_lcao.h index 75eced339a..761d660fe5 100644 --- a/source/module_esolver/esolver_ks_lcao.h +++ b/source/module_esolver/esolver_ks_lcao.h @@ -95,10 +95,6 @@ class ESolver_KS_LCAO : public ESolver_KS { /// density matrix of H, S, T, r ModuleIO::Output_Mat_Sparse create_Output_Mat_Sparse(int istep); - void read_mat_npz(std::string& zipname, hamilt::HContainer& hR); - void output_mat_npz(std::string& zipname, - const hamilt::HContainer& hR); - /// @brief check if skip the corresponding output in md calculation bool md_skip_out(std::string calculation, int istep, int interval); diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index 72de1c87de..0b12ca5265 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -10,6 +10,7 @@ #include "module_io/berryphase.h" #include "module_io/get_pchg_lcao.h" #include "module_io/get_wf_lcao.h" +#include "module_io/io_npz.h" #include "module_io/to_wannier90_lcao.h" #include "module_io/to_wannier90_lcao_in_pw.h" #include "module_io/write_HS_R.h" @@ -279,11 +280,11 @@ void ESolver_KS_LCAO::before_scf(const int istep) std::string zipname = "output_DM0.npz"; elecstate::DensityMatrix* dm = dynamic_cast*>(this->pelec)->get_DM(); - this->read_mat_npz(zipname, *(dm->get_DMR_pointer(1))); + ModuleIO::read_mat_npz(&(this->pv), GlobalC::ucell, zipname, *(dm->get_DMR_pointer(1))); if (PARAM.inp.nspin == 2) { zipname = "output_DM1.npz"; - this->read_mat_npz(zipname, *(dm->get_DMR_pointer(2))); + ModuleIO::read_mat_npz(&(this->pv), GlobalC::ucell, zipname, *(dm->get_DMR_pointer(2))); } this->pelec->calculate_weights(); diff --git a/source/module_io/CMakeLists.txt b/source/module_io/CMakeLists.txt index 44e94bb057..e9f7d0b982 100644 --- a/source/module_io/CMakeLists.txt +++ b/source/module_io/CMakeLists.txt @@ -66,6 +66,7 @@ if(ENABLE_LCAO) output_sk.cpp output_dmk.cpp output_mulliken.cpp + io_npz.cpp ) list(APPEND objects_advanced unk_overlap_lcao.cpp diff --git a/source/module_esolver/io_npz.cpp b/source/module_io/io_npz.cpp similarity index 64% rename from source/module_esolver/io_npz.cpp rename to source/module_io/io_npz.cpp index ea354438ec..0732445554 100644 --- a/source/module_esolver/io_npz.cpp +++ b/source/module_io/io_npz.cpp @@ -1,19 +1,9 @@ //Deals with io of dm(r)/h(r) in npz format +#include "io_npz.h" + +#include "module_base/element_name.h" #include "module_parameter/parameter.h" -#include "module_esolver/esolver_ks_lcao.h" - -#include "module_base/parallel_reduce.h" -#include "module_cell/module_neighbor/sltk_atom_arrange.h" -#include "module_cell/module_neighbor/sltk_grid_driver.h" -#include "module_hamilt_general/module_xc/xc_functional.h" -#include "module_hamilt_lcao/module_dftu/dftu.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h" -#ifdef __DEEPKS -#include "module_hamilt_lcao/module_deepks/LCAO_deepks.h" //caoyu add 2021-07-26 -#endif -#include "module_base/timer.h" #ifdef __MPI #include @@ -24,17 +14,15 @@ #include "cnpy.h" #endif -#include "module_base/element_name.h" - -namespace ModuleESolver +namespace ModuleIO { -template -void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContainer& hR) +void read_mat_npz(const Parallel_Orbitals* paraV, + const UnitCell& ucell, + std::string& zipname, + hamilt::HContainer& hR) { - ModuleBase::TITLE("LCAO_Hamilt","read_mat_npz"); - - const Parallel_Orbitals* paraV = &(this->pv); + ModuleBase::TITLE("ModuleIO", "read_mat_npz"); #ifdef __USECNPY @@ -56,38 +44,38 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai //check consistency // 1. lattice vectors double* lattice_vector = my_npz["lattice_vectors"].data(); - assert(std::abs(lattice_vector[0] - GlobalC::ucell.lat0 * GlobalC::ucell.a1.x) < 1e-6); - assert(std::abs(lattice_vector[1] - GlobalC::ucell.lat0 * GlobalC::ucell.a1.y) < 1e-6); - assert(std::abs(lattice_vector[2] - GlobalC::ucell.lat0 * GlobalC::ucell.a1.z) < 1e-6); - assert(std::abs(lattice_vector[3] - GlobalC::ucell.lat0 * GlobalC::ucell.a2.x) < 1e-6); - assert(std::abs(lattice_vector[4] - GlobalC::ucell.lat0 * GlobalC::ucell.a2.y) < 1e-6); - assert(std::abs(lattice_vector[5] - GlobalC::ucell.lat0 * GlobalC::ucell.a2.z) < 1e-6); - assert(std::abs(lattice_vector[6] - GlobalC::ucell.lat0 * GlobalC::ucell.a3.x) < 1e-6); - assert(std::abs(lattice_vector[7] - GlobalC::ucell.lat0 * GlobalC::ucell.a3.y) < 1e-6); - assert(std::abs(lattice_vector[8] - GlobalC::ucell.lat0 * GlobalC::ucell.a3.z) < 1e-6); + assert(std::abs(lattice_vector[0] - ucell.lat0 * ucell.a1.x) < 1e-6); + assert(std::abs(lattice_vector[1] - ucell.lat0 * ucell.a1.y) < 1e-6); + assert(std::abs(lattice_vector[2] - ucell.lat0 * ucell.a1.z) < 1e-6); + assert(std::abs(lattice_vector[3] - ucell.lat0 * ucell.a2.x) < 1e-6); + assert(std::abs(lattice_vector[4] - ucell.lat0 * ucell.a2.y) < 1e-6); + assert(std::abs(lattice_vector[5] - ucell.lat0 * ucell.a2.z) < 1e-6); + assert(std::abs(lattice_vector[6] - ucell.lat0 * ucell.a3.x) < 1e-6); + assert(std::abs(lattice_vector[7] - ucell.lat0 * ucell.a3.y) < 1e-6); + assert(std::abs(lattice_vector[8] - ucell.lat0 * ucell.a3.z) < 1e-6); // 2. atoms double* atom_info = my_npz["atom_info"].data(); - for(int iat = 0; iat < GlobalC::ucell.nat; ++iat) + for (int iat = 0; iat < ucell.nat; ++iat) { - const int it = GlobalC::ucell.iat2it[iat]; - const int ia = GlobalC::ucell.iat2ia[iat]; + const int it = ucell.iat2it[iat]; + const int ia = ucell.iat2ia[iat]; //get atomic number (copied from write_vdata_palgrid.cpp) std::string element = ""; - element = GlobalC::ucell.atoms[it].label; - std::string::iterator temp = element.begin(); - while (temp != element.end()) - { - if ((*temp >= '1') && (*temp <= '9')) - { - temp = element.erase(temp); - } - else + element = ucell.atoms[it].label; + std::string::iterator temp = element.begin(); + while (temp != element.end()) + { + if ((*temp >= '1') && (*temp <= '9')) + { + temp = element.erase(temp); + } + else { temp++; } - } + } int z = 0; for(int j=0; j!=ModuleBase::element_name.size(); j++) { @@ -100,24 +88,24 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai assert(atom_info[iat*5] == it); assert(atom_info[iat*5+1] == z); - //I will not be checking the coordinates for now in case the direct coordinates provided in the - //npz file do not fall in the range [0,1); if a protocol is to be set in the future such that - //this could be guaranteed, then the following lines could be uncommented - //assert(std::abs(atom_info[iat*5+2] - GlobalC::ucell.atoms[it].taud[ia].x) < 1e-6); - //assert(std::abs(atom_info[iat*5+3] - GlobalC::ucell.atoms[it].taud[ia].y) < 1e-6); - //assert(std::abs(atom_info[iat*5+4] - GlobalC::ucell.atoms[it].taud[ia].z) < 1e-6); + // I will not be checking the coordinates for now in case the direct coordinates provided in the + // npz file do not fall in the range [0,1); if a protocol is to be set in the future such that + // this could be guaranteed, then the following lines could be uncommented + // assert(std::abs(atom_info[iat*5+2] - ucell.atoms[it].taud[ia].x) < 1e-6); + // assert(std::abs(atom_info[iat*5+3] - ucell.atoms[it].taud[ia].y) < 1e-6); + // assert(std::abs(atom_info[iat*5+4] - ucell.atoms[it].taud[ia].z) < 1e-6); } // 3. orbitals - for(int it = 0; it < GlobalC::ucell.ntype; ++it) + for (int it = 0; it < ucell.ntype; ++it) { std::string filename="orbital_info_"+std::to_string(it); double* orbital_info = my_npz[filename].data(); - for(int iw = 0; iw < GlobalC::ucell.atoms[it].nw; ++iw) + for (int iw = 0; iw < ucell.atoms[it].nw; ++iw) { - assert(orbital_info[iw*3] == GlobalC::ucell.atoms[it].iw2n[iw]); - assert(orbital_info[iw*3+1] == GlobalC::ucell.atoms[it].iw2l[iw]); - const int im = GlobalC::ucell.atoms[it].iw2m[iw]; + assert(orbital_info[iw * 3] == ucell.atoms[it].iw2n[iw]); + assert(orbital_info[iw * 3 + 1] == ucell.atoms[it].iw2l[iw]); + const int im = ucell.atoms[it].iw2m[iw]; const int m = (im % 2 == 0) ? -im/2 : (im+1)/2; assert(orbital_info[iw*3+2] == m); } @@ -148,12 +136,12 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai int Ry = std::stoi(tokens[4]); int Rz = std::stoi(tokens[5]); - int it1 = GlobalC::ucell.iat2it[iat1]; - int it2 = GlobalC::ucell.iat2it[iat2]; + int it1 = ucell.iat2it[iat1]; + int it2 = ucell.iat2it[iat2]; + + assert(arr.shape[0] == ucell.atoms[it1].nw); + assert(arr.shape[1] == ucell.atoms[it2].nw); - assert(arr.shape[0] == GlobalC::ucell.atoms[it1].nw); - assert(arr.shape[1] == GlobalC::ucell.atoms[it2].nw); - //hamilt::AtomPair tmp(iat1,iat2,Rx,Ry,Rz,&serialV); //HR_serial->insert_pair(tmp); hamilt::AtomPair tmp(iat1,iat2,Rx,Ry,Rz,paraV); @@ -193,11 +181,11 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai int Ry = std::stoi(tokens[4]); int Rz = std::stoi(tokens[5]); - int it1 = GlobalC::ucell.iat2it[iat1]; - int it2 = GlobalC::ucell.iat2it[iat2]; + int it1 = ucell.iat2it[iat1]; + int it2 = ucell.iat2it[iat2]; - assert(arr.shape[0] == GlobalC::ucell.atoms[it1].nw); - assert(arr.shape[1] == GlobalC::ucell.atoms[it2].nw); + assert(arr.shape[0] == ucell.atoms[it1].nw); + assert(arr.shape[1] == ucell.atoms[it2].nw); double* submat_read = arr.data(); @@ -257,12 +245,12 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai int Ry = std::stoi(tokens[4]); int Rz = std::stoi(tokens[5]); - int it1 = GlobalC::ucell.iat2it[iat1]; - int it2 = GlobalC::ucell.iat2it[iat2]; + int it1 = ucell.iat2it[iat1]; + int it2 = ucell.iat2it[iat2]; + + assert(arr.shape[0] == ucell.atoms[it1].nw); + assert(arr.shape[1] == ucell.atoms[it2].nw); - assert(arr.shape[0] == GlobalC::ucell.atoms[it1].nw); - assert(arr.shape[1] == GlobalC::ucell.atoms[it2].nw); - hamilt::AtomPair tmp(iat1,iat2,Rx,Ry,Rz,paraV); hR->insert_pair(tmp); // use symmetry : H_{mu,nu,R} = H_{nu,mu,-R} @@ -296,11 +284,11 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai int Ry = std::stoi(tokens[4]); int Rz = std::stoi(tokens[5]); - int it1 = GlobalC::ucell.iat2it[iat1]; - int it2 = GlobalC::ucell.iat2it[iat2]; + int it1 = ucell.iat2it[iat1]; + int it2 = ucell.iat2it[iat2]; - assert(arr.shape[0] == GlobalC::ucell.atoms[it1].nw); - assert(arr.shape[1] == GlobalC::ucell.atoms[it2].nw); + assert(arr.shape[0] == ucell.atoms[it1].nw); + assert(arr.shape[1] == ucell.atoms[it2].nw); double* submat_read = arr.data(); @@ -333,10 +321,9 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai #endif } -template -void ESolver_KS_LCAO::output_mat_npz(std::string& zipname, const hamilt::HContainer& hR) +void output_mat_npz(const UnitCell& ucell, std::string& zipname, const hamilt::HContainer& hR) { - ModuleBase::TITLE("LCAO_Hamilt","output_mat_npz"); + ModuleBase::TITLE("ModuleIO", "output_mat_npz"); #ifdef __USECNPY std::string filename = ""; @@ -348,41 +335,41 @@ void ESolver_KS_LCAO::output_mat_npz(std::string& zipname, const hamilt: filename = "lattice_vectors"; std::vector lattice_vectors; lattice_vectors.resize(9); - lattice_vectors[0] = GlobalC::ucell.lat0 * GlobalC::ucell.a1.x; - lattice_vectors[1] = GlobalC::ucell.lat0 * GlobalC::ucell.a1.y; - lattice_vectors[2] = GlobalC::ucell.lat0 * GlobalC::ucell.a1.z; - lattice_vectors[3] = GlobalC::ucell.lat0 * GlobalC::ucell.a2.x; - lattice_vectors[4] = GlobalC::ucell.lat0 * GlobalC::ucell.a2.y; - lattice_vectors[5] = GlobalC::ucell.lat0 * GlobalC::ucell.a2.z; - lattice_vectors[6] = GlobalC::ucell.lat0 * GlobalC::ucell.a3.x; - lattice_vectors[7] = GlobalC::ucell.lat0 * GlobalC::ucell.a3.y; - lattice_vectors[8] = GlobalC::ucell.lat0 * GlobalC::ucell.a3.z; + lattice_vectors[0] = ucell.lat0 * ucell.a1.x; + lattice_vectors[1] = ucell.lat0 * ucell.a1.y; + lattice_vectors[2] = ucell.lat0 * ucell.a1.z; + lattice_vectors[3] = ucell.lat0 * ucell.a2.x; + lattice_vectors[4] = ucell.lat0 * ucell.a2.y; + lattice_vectors[5] = ucell.lat0 * ucell.a2.z; + lattice_vectors[6] = ucell.lat0 * ucell.a3.x; + lattice_vectors[7] = ucell.lat0 * ucell.a3.y; + lattice_vectors[8] = ucell.lat0 * ucell.a3.z; cnpy::npz_save(zipname,filename,lattice_vectors); // second block: atom info filename = "atom_info"; - double* atom_info = new double[GlobalC::ucell.nat*5]; - for(int iat = 0; iat < GlobalC::ucell.nat; ++iat) + double* atom_info = new double[ucell.nat * 5]; + for (int iat = 0; iat < ucell.nat; ++iat) { - const int it = GlobalC::ucell.iat2it[iat]; - const int ia = GlobalC::ucell.iat2ia[iat]; + const int it = ucell.iat2it[iat]; + const int ia = ucell.iat2ia[iat]; //get atomic number (copied from write_vdata_palgrid.cpp) std::string element = ""; - element = GlobalC::ucell.atoms[it].label; - std::string::iterator temp = element.begin(); - while (temp != element.end()) - { - if ((*temp >= '1') && (*temp <= '9')) - { - temp = element.erase(temp); - } - else + element = ucell.atoms[it].label; + std::string::iterator temp = element.begin(); + while (temp != element.end()) + { + if ((*temp >= '1') && (*temp <= '9')) + { + temp = element.erase(temp); + } + else { temp++; } - } + } int z = 0; for(int j=0; j!=ModuleBase::element_name.size(); j++) { @@ -395,29 +382,29 @@ void ESolver_KS_LCAO::output_mat_npz(std::string& zipname, const hamilt: atom_info[iat*5] = it; atom_info[iat*5+1] = z; - atom_info[iat*5+2] = GlobalC::ucell.atoms[it].taud[ia].x; - atom_info[iat*5+3] = GlobalC::ucell.atoms[it].taud[ia].y; - atom_info[iat*5+4] = GlobalC::ucell.atoms[it].taud[ia].z; + atom_info[iat * 5 + 2] = ucell.atoms[it].taud[ia].x; + atom_info[iat * 5 + 3] = ucell.atoms[it].taud[ia].y; + atom_info[iat * 5 + 4] = ucell.atoms[it].taud[ia].z; } - std::vector shape={(size_t)GlobalC::ucell.nat,5}; + std::vector shape = {(size_t)ucell.nat, 5}; cnpy::npz_save(zipname,filename,atom_info,shape,"a"); delete[] atom_info; //third block: orbital info - for(int it = 0; it < GlobalC::ucell.ntype; ++it) + for (int it = 0; it < ucell.ntype; ++it) { filename="orbital_info_"+std::to_string(it); - double* orbital_info = new double[GlobalC::ucell.atoms[it].nw*3]; - for(int iw = 0; iw < GlobalC::ucell.atoms[it].nw; ++iw) + double* orbital_info = new double[ucell.atoms[it].nw * 3]; + for (int iw = 0; iw < ucell.atoms[it].nw; ++iw) { - orbital_info[iw*3] = GlobalC::ucell.atoms[it].iw2n[iw]; - orbital_info[iw*3+1] = GlobalC::ucell.atoms[it].iw2l[iw]; - const int im = GlobalC::ucell.atoms[it].iw2m[iw]; + orbital_info[iw * 3] = ucell.atoms[it].iw2n[iw]; + orbital_info[iw * 3 + 1] = ucell.atoms[it].iw2l[iw]; + const int im = ucell.atoms[it].iw2m[iw]; const int m = (im % 2 == 0) ? -im/2 : (im+1)/2; orbital_info[iw*3+2] = m; } - shape={(size_t)GlobalC::ucell.atoms[it].nw,3}; + shape = {(size_t)ucell.atoms[it].nw, 3}; cnpy::npz_save(zipname,filename,orbital_info,shape,"a"); } @@ -428,7 +415,7 @@ void ESolver_KS_LCAO::output_mat_npz(std::string& zipname, const hamilt: hamilt::HContainer* HR_serial; Parallel_Orbitals serialV; serialV.set_serial(PARAM.globalv.nlocal, PARAM.globalv.nlocal); - serialV.set_atomic_trace(GlobalC::ucell.get_iat2iwt(), GlobalC::ucell.nat, PARAM.globalv.nlocal); + serialV.set_atomic_trace(ucell.get_iat2iwt(), ucell.nat, PARAM.globalv.nlocal); if(GlobalV::MY_RANK == 0) { HR_serial = new hamilt::HContainer(&serialV); @@ -484,7 +471,4 @@ void ESolver_KS_LCAO::output_mat_npz(std::string& zipname, const hamilt: #endif } -template class ESolver_KS_LCAO; -template class ESolver_KS_LCAO, double>; -template class ESolver_KS_LCAO, std::complex>; -} +} // namespace ModuleIO diff --git a/source/module_io/io_npz.h b/source/module_io/io_npz.h new file mode 100644 index 0000000000..f06caa9331 --- /dev/null +++ b/source/module_io/io_npz.h @@ -0,0 +1,23 @@ +#ifndef NPZ_IO_H +#define NPZ_IO_H + +#include "module_basis/module_ao/parallel_orbitals.h" +#include "module_cell/unitcell.h" +#include "module_hamilt_lcao/module_hcontainer/hcontainer.h" + +#include +#include + +namespace ModuleIO +{ + +void read_mat_npz(const Parallel_Orbitals* paraV, + const UnitCell& ucell, + std::string& zipname, + hamilt::HContainer& hR); + +void output_mat_npz(const UnitCell& ucell, std::string& zipname, const hamilt::HContainer& hR); + +} // namespace ModuleIO + +#endif // NPZ_IO_H