Skip to content

Commit

Permalink
swap the sizeof() be the first multiplier to avoid overflow of int (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kirk0830 authored Nov 22, 2024
1 parent 8c3def4 commit 20efb88
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
27 changes: 19 additions & 8 deletions source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
if (PARAM.inp.basis_type == "lcao_in_pw")
{
wanf2[0].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol);
const size_t memory_cost = PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;

// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx);

std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
ModuleBase::Memory::record("WF::wanf2", memory_cost);
}
const size_t memory_cost = PARAM.inp.nbands * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;

// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.inp.nbands * (PARAM.globalv.npol * npwx);

std::cout << " MEMORY FOR PSI (MB) : " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
ModuleBase::Memory::record("Psi_PW", memory_cost);
}
else if (PARAM.inp.basis_type != "pw")
Expand All @@ -82,17 +88,22 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
this->wanf2[ik].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol);
}

const size_t memory_cost = nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol) * sizeof(std::complex<double>);
std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
const size_t memory_cost = sizeof(std::complex<double>) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol);

std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
ModuleBase::Memory::record("WF::wanf2", memory_cost);
}
}
else
{
// initial psi rather than evc
psi_out = new psi::Psi<std::complex<double>>(nks2, PARAM.inp.nbands, npwx * PARAM.globalv.npol, ngk);
const size_t memory_cost = nks2 * PARAM.inp.nbands * (PARAM.globalv.npol * npwx) * sizeof(std::complex<double>);
std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl;

// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
const size_t memory_cost = sizeof(std::complex<double>) * nks2 * PARAM.inp.nbands * (PARAM.globalv.npol * npwx);

std::cout << " MEMORY FOR PSI (MB) : " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
ModuleBase::Memory::record("Psi_PW", memory_cost);
}
return psi_out;
Expand Down
4 changes: 2 additions & 2 deletions source/module_psi/psi_initializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(const bool
PARAM.inp.nbands, // because no matter what, the wavefunction finally needed has PARAM.inp.nbands bands
nbasis_actual,
this->pw_wfc_->npwk);
double memory_cost_psi = nks_psi * PARAM.inp.nbands * this->pw_wfc_->npwk_max * PARAM.globalv.npol*
sizeof(std::complex<double>);
double memory_cost_psi = sizeof(std::complex<double>) * nks_psi * PARAM.inp.nbands
* this->pw_wfc_->npwk_max * PARAM.globalv.npol;
#ifdef __MPI
// get the correct memory cost for psi by all-reduce sum
Parallel_Reduce::reduce_all(memory_cost_psi);
Expand Down

0 comments on commit 20efb88

Please sign in to comment.