Skip to content

Commit

Permalink
Merge pull request #4902 from ye-luo/fix-Tmove
Browse files Browse the repository at this point in the history
Fix T-move in batched DMC driver.
  • Loading branch information
prckent authored Jan 17, 2024
2 parents 5bf03ee + 0f1e7fe commit fc64a42
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 32 deletions.
15 changes: 10 additions & 5 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,17 @@ DMCBatched::DMCBatched(const ProjectData& project_data,
std::move(pop),
"DMCBatched::",
comm,
"DMCBatched",
std::bind(&DMCBatched::setNonLocalMoveHandler, this, _1)),
"DMCBatched"),
dmcdriver_input_(input),
dmc_timers_("DMCBatched::")
{}

DMCBatched::~DMCBatched() = default;

void DMCBatched::setNonLocalMoveHandler(QMCHamiltonian& golden_hamiltonian)
void DMCBatched::setNonLocalMoveHandler(QMCHamiltonian& hamiltonian)
{
golden_hamiltonian.setNonLocalMoves(dmcdriver_input_.get_non_local_move(), qmcdriver_input_.get_tau(),
dmcdriver_input_.get_alpha(), dmcdriver_input_.get_gamma());
hamiltonian.setNonLocalMoves(dmcdriver_input_.get_non_local_move(), qmcdriver_input_.get_tau(),
dmcdriver_input_.get_alpha(), dmcdriver_input_.get_gamma());
}

template<CoordsType CT>
Expand Down Expand Up @@ -474,6 +473,12 @@ bool DMCBatched::run()
for (int step = 0; step < qmcdriver_input_.get_max_steps(); ++step)
{
ScopedTimer local_timer(timers_.run_steps_timer);

// ensure all the live walkers carry the up-to-date T-move settings.
// Such info should be removed from each NLPP eventually and be kept in the driver.
for (UPtr<QMCHamiltonian>& ham : population_.get_hamiltonians())
setNonLocalMoveHandler(*ham);

dmc_state.step = step;
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, std::ref(step_contexts_),
std::ref(crowds_));
Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/DMC/DMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DMCBatched : public QMCDriverNew

QMCRunType getRunType() override { return QMCRunType::DMC_BATCH; }

void setNonLocalMoveHandler(QMCHamiltonian& golden_hamiltonian);
void setNonLocalMoveHandler(QMCHamiltonian& hamiltonian);

private:
const DMCDriverInput dmcdriver_input_;
Expand Down
20 changes: 4 additions & 16 deletions src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ QMCDriverNew::QMCDriverNew(const ProjectData& project_data,
MCPopulation&& population,
const std::string timer_prefix,
Communicate* comm,
const std::string& QMC_driver_type,
SetNonLocalMoveHandler snlm_handler)
const std::string& QMC_driver_type)
: MPIObjectBase(comm),
qmcdriver_input_(std::move(input)),
QMCType(QMC_driver_type),
Expand All @@ -61,8 +60,7 @@ QMCDriverNew::QMCDriverNew(const ProjectData& project_data,
driver_scope_timer_(createGlobalTimer(QMC_driver_type, timer_level_coarse)),
driver_scope_profiler_(qmcdriver_input_.get_scoped_profiling()),
project_data_(project_data),
walker_configs_ref_(wc),
setNonLocalMoveHandler_(snlm_handler)
walker_configs_ref_(wc)
{
// This is done so that the application level input structures reflect the actual input to the code.
// While the actual simulation objects still take singular input structures at construction.
Expand Down Expand Up @@ -193,8 +191,8 @@ void QMCDriverNew::initializeQMC(const AdjustedWalkerCounts& awc)
*/
void QMCDriverNew::setStatus(const std::string& aname, const std::string& h5name, bool append)
{
app_log() << "\n========================================================="
<< "\n Start " << QMCType << "\n File Root " << get_root_name();
app_log() << "\n=========================================================" << "\n Start " << QMCType
<< "\n File Root " << get_root_name();
app_log() << "\n=========================================================" << std::endl;

if (h5name.size())
Expand Down Expand Up @@ -281,14 +279,6 @@ void QMCDriverNew::makeLocalWalkers(IndexType nwalkers, RealType reserve)
for (int i = 0; i < num_walkers_to_kill; ++i)
population_.killLastWalker();
}

// \todo: this could be what is breaking spawned walkers
for (UPtr<QMCHamiltonian>& ham : population_.get_hamiltonians())
setNonLocalMoveHandler_(*ham);

// For the dead ones too. Since this should be on construction but...
for (UPtr<QMCHamiltonian>& ham : population_.get_dead_hamiltonians())
setNonLocalMoveHandler_(*ham);
}

/** Creates Random Number generators for crowds and step contexts
Expand Down Expand Up @@ -395,8 +385,6 @@ std::ostream& operator<<(std::ostream& o_stream, const QMCDriverNew& qmcd)
return o_stream;
}

void QMCDriverNew::defaultSetNonLocalMoveHandler(QMCHamiltonian& ham) {}

QMCDriverNew::AdjustedWalkerCounts QMCDriverNew::adjustGlobalWalkerCount(Communicate& comm,
const IndexType current_configs,
const IndexType requested_total_walkers,
Expand Down
13 changes: 3 additions & 10 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
MCPopulation&& population,
const std::string timer_prefix,
Communicate* comm,
const std::string& QMC_driver_type,
SetNonLocalMoveHandler = &QMCDriverNew::defaultSetNonLocalMoveHandler);
const std::string& QMC_driver_type);

///Move Constructor
QMCDriverNew(QMCDriverNew&&) = default;
Expand Down Expand Up @@ -186,7 +185,7 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
void setStatus(const std::string& aname, const std::string& h5name, bool append) override;

void add_H_and_Psi(QMCHamiltonian* h, TrialWaveFunction* psi) override{};
void add_H_and_Psi(QMCHamiltonian* h, TrialWaveFunction* psi) override {};

void createRngsStepContexts(int num_crowds);

Expand Down Expand Up @@ -231,9 +230,7 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
void process(xmlNodePtr cur) override = 0;

static void initialLogEvaluation(int crowd_id,
UPtrVector<Crowd>& crowds,
UPtrVector<ContextForSteps>& step_context);
static void initialLogEvaluation(int crowd_id, UPtrVector<Crowd>& crowds, UPtrVector<ContextForSteps>& step_context);


/** should be set in input don't see a reason to set individually
Expand Down Expand Up @@ -466,10 +463,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
private:
friend std::ostream& operator<<(std::ostream& o_stream, const QMCDriverNew& qmcd);

SetNonLocalMoveHandler setNonLocalMoveHandler_;

static void defaultSetNonLocalMoveHandler(QMCHamiltonian& gold_ham);

friend class qmcplusplus::testing::VMCBatchedTest;
friend class qmcplusplus::testing::DMCBatchedTest;
friend class qmcplusplus::testing::QMCDriverNewTestWrapper;
Expand Down

0 comments on commit fc64a42

Please sign in to comment.