Skip to content

Commit

Permalink
Merge pull request #5241 from ye-luo/cleanup-myVars
Browse files Browse the repository at this point in the history
Relocate myVars in SPOSet/WFC base classes to OptimizableObject
  • Loading branch information
prckent authored Nov 26, 2024
2 parents 802b0e9 + a4940e1 commit 45b8457
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 58 deletions.
5 changes: 0 additions & 5 deletions src/QMCWaveFunctions/Fermion/SlaterDet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,9 @@ void SlaterDet::extractOptimizableObjectRefs(UniqueOptObjRefs& opt_obj_refs)

void SlaterDet::checkOutVariables(const opt_variables_type& active)
{
myVars.clear();
if (isOptimizable())
for (int i = 0; i < Dets.size(); i++)
{
Dets[i]->checkOutVariables(active);
myVars.insertFrom(Dets[i]->myVars);
}
myVars.getIndex(active);
}

PsiValue SlaterDet::ratioGrad(ParticleSet& P, int iat, GradType& grad_iat)
Expand Down
18 changes: 0 additions & 18 deletions src/QMCWaveFunctions/Fermion/SlaterDet.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,31 +265,13 @@ class SlaterDet : public WaveFunctionComponent
Vector<ValueType>& dlogpsi,
Vector<ValueType>& dhpsioverpsi) override
{
// First zero out values, since each determinant only adds on
// its contribution (i.e. +=) , rather than setting the value
// (i.e. =)
for (int k = 0; k < myVars.size(); ++k)
{
int kk = myVars.where(k);
if (kk >= 0)
dlogpsi[kk] = dhpsioverpsi[kk] = 0.0;
}
// Now add on contribution from each determinant to the derivatives
for (int i = 0; i < Dets.size(); i++)
Dets[i]->evaluateDerivatives(P, active, dlogpsi, dhpsioverpsi);
}

void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& active, Vector<ValueType>& dlogpsi) override
{
// First zero out values, since each determinant only adds on
// its contribution (i.e. +=) , rather than setting the value
// (i.e. =)
for (int k = 0; k < myVars.size(); ++k)
{
int kk = myVars.where(k);
if (kk >= 0)
dlogpsi[kk] = 0.0;
}
// Now add on contribution from each determinant to the derivatives
for (int i = 0; i < Dets.size(); i++)
Dets[i]->evaluateDerivativesWF(P, active, dlogpsi);
Expand Down
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Jastrow/JeeIOrbitalSoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class JeeIOrbitalSoA : public WaveFunctionComponent
std::map<std::string, std::unique_ptr<FT>> J3Unique;
//YYYY
std::map<FT*, int> J3UniqueIndex;
///optimizable variables extracted from functors
opt_variables_type myVars;

/// the cutoff for e-I pairs
std::vector<valT> Ion_cutoff;
Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/Jastrow/TwoBodyJastrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class TwoBodyJastrow : public WaveFunctionComponent
std::map<std::string, std::unique_ptr<FT>> J2Unique;
///Container for \f$F[ig*NumGroups+jg]\f$. treat every pointer as a reference.
std::vector<FT*> F;
///optimizable variables extracted from functors
opt_variables_type myVars;

/// e-e table ID
const int my_table_ID_;
// helper for compute J2 Chiesa KE correction
Expand Down
8 changes: 3 additions & 5 deletions src/QMCWaveFunctions/OptimizableFunctorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ namespace qmcplusplus
struct OptimizableFunctorBase : public OptimizableObject
{
///typedef for real values
using real_type = optimize::VariableSet::real_type;
///typedef for variableset: this is going to be replaced
using opt_variables_type = optimize::VariableSet;
using real_type = opt_variables_type::real_type;
///expose OptimizableObject::myVars for direct access by a few consumers. Should clean up the consumers.
using OptimizableObject::myVars;
///maximum cutoff
real_type cutoff_radius = 0.0;
///set of variables to be optimized
opt_variables_type myVars;
///default constructor
inline OptimizableFunctorBase(const std::string& name = "") : OptimizableObject(name) {}
///virtual destrutor
Expand Down
15 changes: 9 additions & 6 deletions src/QMCWaveFunctions/OptimizableObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ using opt_variables_type = optimize::VariableSet;

class OptimizableObject
{
public:
OptimizableObject(const std::string& name) : name_(name) {}

const std::string& getName() const { return name_; }
bool isOptimized() const { return is_optimized_; }

private:
/** Name of the optimizable object
*/
Expand All @@ -39,7 +33,16 @@ class OptimizableObject
*/
bool is_optimized_ = false;

protected:
///optimizable variables in use
opt_variables_type myVars;

public:
OptimizableObject(const std::string& name) : name_(name) {}

const std::string& getName() const { return name_; }
bool isOptimized() const { return is_optimized_; }

/** check in variational parameters to the global list of parameters used by the optimizer.
* @param active a super set of optimizable variables
*
Expand Down
10 changes: 6 additions & 4 deletions src/QMCWaveFunctions/RotatedSPOs.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ namespace qmcplusplus
class RotatedSPOs;
namespace testing
{
std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot);
std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot);
const opt_variables_type& getMyVars(RotatedSPOs& rot);
const std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot);
const std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot);
} // namespace testing

class RotatedSPOs : public SPOSet, public OptimizableObject
Expand Down Expand Up @@ -479,8 +480,9 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
/// Use global rotation or history list
bool use_global_rot_ = true;

friend std::vector<ValueType>& testing::getMyVarsFull(RotatedSPOs& rot);
friend std::vector<std::vector<ValueType>>& testing::getHistoryParams(RotatedSPOs& rot);
friend const opt_variables_type& testing::getMyVars(RotatedSPOs& rot);
friend const std::vector<ValueType>& testing::getMyVarsFull(RotatedSPOs& rot);
friend const std::vector<std::vector<ValueType>>& testing::getHistoryParams(RotatedSPOs& rot);
};


Expand Down
11 changes: 0 additions & 11 deletions src/QMCWaveFunctions/SPOSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ namespace qmcplusplus
{
class ResourceCollection;

class SPOSet;
namespace testing
{
opt_variables_type& getMyVars(SPOSet& spo);
}


/** base class for Single-particle orbital sets
*
* SPOSet stands for S(ingle)P(article)O(rbital)Set which contains
Expand Down Expand Up @@ -581,10 +574,6 @@ class SPOSet : public QMCTraits
const std::string my_name_;
///number of Single-particle orbitals
IndexType OrbitalSetSize;
/// Optimizable variables
opt_variables_type myVars;

friend opt_variables_type& testing::getMyVars(SPOSet& spo);
};

using SPOSetPtr = SPOSet*;
Expand Down
2 changes: 0 additions & 2 deletions src/QMCWaveFunctions/WaveFunctionComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ class WaveFunctionComponent : public QMCTraits

/** current update mode */
int UpdateMode;
///list of variables this WaveFunctionComponent handles
opt_variables_type myVars;
///Bytes in WFBuffer
size_t Bytes_in_WFBuffer;

Expand Down
14 changes: 7 additions & 7 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,9 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]")

namespace testing
{
opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; }
std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; }
std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
const opt_variables_type& getMyVars(RotatedSPOs& rot) { return rot.myVars; }
const std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; }
const std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
} // namespace testing

// Test using global rotation
Expand Down Expand Up @@ -775,14 +775,14 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
vs2.readFromHDF("rot_vp.h5", hin);
rot2.readVariationalParameters(hin);

opt_variables_type& var = testing::getMyVars(rot2);
auto& var = testing::getMyVars(rot2);
for (size_t i = 0; i < vs.size(); i++)
CHECK(var[i] == Approx(vs[i]));

//add extra parameters for full set
vs_values.push_back(0.0);
vs_values.push_back(0.0);
std::vector<SPOSet::ValueType>& full_var = testing::getMyVarsFull(rot2);
auto& full_var = testing::getMyVarsFull(rot2);
for (size_t i = 0; i < full_var.size(); i++)
CHECK(full_var[i] == ValueApprox(vs_values[i]));
}
Expand Down Expand Up @@ -827,11 +827,11 @@ TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
vs2.readFromHDF("rot_vp_hist.h5", hin);
rot2.readVariationalParameters(hin);

opt_variables_type& var = testing::getMyVars(rot2);
auto& var = testing::getMyVars(rot2);
for (size_t i = 0; i < var.size(); i++)
CHECK(var[i] == Approx(vs[i]));

auto hist = testing::getHistoryParams(rot2);
const auto hist = testing::getHistoryParams(rot2);
REQUIRE(hist.size() == 1);
REQUIRE(hist[0].size() == 4);
}
Expand Down

0 comments on commit 45b8457

Please sign in to comment.