Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable complex coefficients for SelfHealingOverlap estimator #5291

Merged
merged 11 commits into from
Jan 31, 2025
Next Next commit
include rotated slater-det in sh overlap estimator
jtkrogel committed Nov 22, 2024
commit 67c129a9be024f2b945b354b3f5049bec167a96b
95 changes: 73 additions & 22 deletions src/Estimators/SelfHealingOverlap.cpp
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
#include "SelfHealingOverlap.h"
#include "TrialWaveFunction.h"
#include "QMCWaveFunctions/Fermion/MultiSlaterDetTableMethod.h"
#include "QMCWaveFunctions/Fermion/SlaterDet.h"

#include <iostream>
#include <numeric>
@@ -19,20 +20,53 @@
namespace qmcplusplus
{
SelfHealingOverlap::SelfHealingOverlap(SelfHealingOverlapInput&& inp_, const TrialWaveFunction& wfn, DataLocality dl)
: OperatorEstBase(dl), input_(std::move(inp_))
: OperatorEstBase(dl), input_(std::move(inp_)), wf_type(no_wf)
{
//my_name_ = input_.get_name();

auto& inp = this->input_.input_section_;

use_param_deriv = inp.get<bool>("param_deriv");

auto msd_refvec = wfn.findMSD();
if (msd_refvec.size() != 1)
auto sd_refvec = wfn.findSD();

auto nsd = sd_refvec.size();
auto nmsd = msd_refvec.size();

size_t nparams;
if (nmsd==1 and nsd==0)
{// multi-slater-det wavefunction
wf_type = msd_wf;
const MultiSlaterDetTableMethod& msd = msd_refvec[0];
if (!use_param_deriv)
nparams = msd.getLinearExpansionCoefs().size();
else
nparams = msd.myVars->size();
if(nparams==0)
throw std::runtime_error("SelfHealingOverlap: multidet wavefunction has no parameters.");
}
else if (nmsd==0 and nsd==1)
{// slater-det wavefunction
wf_type = sd_rot_wf;
const SlaterDet& sd = sd_refvec[0];
nparams = sd.myVars.size();
if(nparams==0)
throw std::runtime_error("SelfHealingOverlap: slaterdet wavefunction has no parameters.\n Please check that <rotated_sposet/>'s appear in the input file.");
}
else
{
throw std::runtime_error(
"SelfHealingOverlap requires one and only one multi slater determinant component in the trial wavefunction.");
"SelfHealingOverlap requires a single slater or multi-slater determinant component in the trial wavefunction.");
}

const MultiSlaterDetTableMethod& msd = msd_refvec[0];
const size_t data_size = msd.getLinearExpansionCoefs().size();
#ifndef QMC_COMPLEX
const size_t data_size = nparams;
#else
const size_t data_size = 2*nparams;
#endif
data_.resize(data_size, 0.0);

}


@@ -78,39 +112,56 @@ void SelfHealingOverlap::accumulate(const RefVector<MCPWalker>& walkers,
RealType weight = walker.Weight;
auto& wcs = psi.getOrbitals();

// separate jastrow and fermi wavefunction components
// find jastrow wavefunction components
std::vector<WaveFunctionComponent*> wcs_jastrow;
std::vector<WaveFunctionComponent*> wcs_fermi;
for (auto& wc : wcs)
if (wc->isFermionic())
wcs_fermi.push_back(wc.get());
else
if (!wc->isFermionic())
wcs_jastrow.push_back(wc.get());

// fermionic must have only one component, and must be multideterminant
assert(wcs_fermi.size() == 1);
WaveFunctionComponent& wf = *wcs_fermi[0];
if (!wf.isMultiDet())
throw std::runtime_error("SelfHealingOverlap estimator requires use of multideterminant wavefunction");
auto msd_refvec = psi.findMSD();
MultiSlaterDetTableMethod& msd = msd_refvec[0];

// collect parameter derivatives: (dpsi/dc_i)/psi
msd.calcIndividualDetRatios(det_ratios);
if (wf_type==msd_wf)
{
auto msd_refvec = psi.findMSD();
MultiSlaterDetTableMethod& msd = msd_refvec[0];
// collect parameter derivatives: (dpsi/dc_i)/psi
if (!use_param_deriv)
msd.calcIndividualDetRatios(det_ratios);
else
{
const auto& vars = *msd.myVars;
msd.evaluateDerivativesWF(pset,vars,det_ratios);
}
}
else if(wf_type==sd_rot_wf)
{
auto sd_refvec = psi.findSD();
SlaterDet& sd = sd_refvec[0];
// collect parameter derivatives: (dpsi/dc_i)/psi
sd.evaluateDerivativesWF(pset,sd.myVars,det_ratios);
}
else
throw std::runtime_error("SelfHealingOverlap: impossible branch reached, contact the developers");

// collect jastrow prefactor
WaveFunctionComponent::LogValue Jval = 0.0;
for (auto& wc : wcs_jastrow)
Jval += wc->get_log_value();
auto Jprefactor = std::real(std::exp(-2. * Jval));
auto Jprefactor = std::exp(-2. * Jval);

// accumulate weight (required by all estimators, otherwise inf results)
walkers_weight_ += weight;

// accumulate data
assert(det_ratios.size() == data_.size());
for (int ic = 0; ic < det_ratios.size(); ++ic)
data_[ic] += weight * Jprefactor * std::real(det_ratios[ic]); // only real supported for now
{
#ifndef QMC_COMPLEX
data_[ic] += weight * Jprefactor * det_ratios[ic];
#else
auto value = weight * Jprefactor * std::conj(det_ratios[ic]);
data_[2*ic ] += std::real(value);
data_[2*ic+1] += std::imag(value);
#endif
}
}
}

15 changes: 15 additions & 0 deletions src/Estimators/SelfHealingOverlap.h
Original file line number Diff line number Diff line change
@@ -33,13 +33,28 @@ class SelfHealingOverlap : public OperatorEstBase
using ValueType = QMCTraits::ValueType;
using PosType = QMCTraits::PosType;


enum wf_types
{
msd_wf = 0,
sd_rot_wf,
no_wf
};

//data members set only during construction
const SelfHealingOverlapInput input_;

/** @ingroup SelfHealingOverlap mutable data members
*/
Vector<ValueType> det_ratios;

/// wavefunction type
wf_types wf_type;

/// use direct parameter derivative for MSD or not
bool use_param_deriv;


public:
/** Constructor for SelfHealingOverlapInput
*/
7 changes: 5 additions & 2 deletions src/Estimators/SelfHealingOverlapInput.h
Original file line number Diff line number Diff line change
@@ -32,9 +32,12 @@ class SelfHealingOverlapInput
SelfHealingOverlapInputSection()
{
section_name = "SelfHealingOverlap";
attributes = {"type", "name"};
attributes = {"type", "name", "param_deriv"};
strings = {"type", "name"};
default_values = {{"type", std::string("sh_overlap")},{"name", std::string("sh_overlap")}};
bools = {"param_deriv"};
default_values = {{"type", std::string("sh_overlap")},
{"name", std::string("sh_overlap")},
{"param_deriv",false}};
}
// clang-format: on
};
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Fermion/MultiSlaterDetTableMethod.h
Original file line number Diff line number Diff line change
@@ -285,8 +285,10 @@ class MultiSlaterDetTableMethod : public WaveFunctionComponent, public Optimizab
std::shared_ptr<std::vector<ValueType>> C;
/// if true, the CI coefficients are optimized
bool CI_Optimizable;
public:
//optimizable variable is shared with the clones
std::shared_ptr<opt_variables_type> myVars;
private:

/// CSF data set. If nullptr, not using CSF
std::shared_ptr<CSFData> csf_data_;
10 changes: 10 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
#include "Concurrency/Info.hpp"
#include "type_traits/ConvertToReal.h"
#include "NaNguard.h"
#include "Fermion/SlaterDet.h"
#include "Fermion/MultiSlaterDetTableMethod.h"

namespace qmcplusplus
@@ -107,6 +108,15 @@ const SPOSet& TrialWaveFunction::getSPOSet(const std::string& name) const
return *spoit->second;
}

RefVector<SlaterDet> TrialWaveFunction::findSD() const
{
RefVector<SlaterDet> refs;
for (auto& component : Z)
if (auto* comp_ptr = dynamic_cast<SlaterDet*>(component.get()); comp_ptr)
Copy link
Contributor

@PDoakORNL PDoakORNL Jan 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

findSD is seems too brief, this is now part of the public API and it would be nice if it was recognizable as what it is from that. Imagine I want to search the codebase for SlaterDet definitions

auto my_det = twf.findSD()

SD is not a character combination of much specificity.

This smells a bit, dynamic_casts aren't zero cost

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry about that. I'd like to addressed it separately.
findSD/findMSD are too specific. I'd like to change them as mentioned #5291 (comment)

refs.push_back(*comp_ptr);
return refs;
}

RefVector<MultiSlaterDetTableMethod> TrialWaveFunction::findMSD() const
{
RefVector<MultiSlaterDetTableMethod> refs;
4 changes: 4 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@

namespace qmcplusplus
{
class SlaterDet;
class MultiSlaterDetTableMethod;

/** @ingroup MBWfs
@@ -539,6 +540,9 @@ class TrialWaveFunction
/// spomap_ reference accessor
const SPOMap& getSPOMap() const { return *spomap_; }

/// find SD WFCs if exist
RefVector<SlaterDet> findSD() const;

/// find MSD WFCs if exist
RefVector<MultiSlaterDetTableMethod> findMSD() const;