Skip to content

Commit

Permalink
Refactor: remove GlobalC::ucell in esolver (#5569)
Browse files Browse the repository at this point in the history
* Refactor: remove GlobalC::ucell in esolver

* [pre-commit.ci lite] apply automatic fixes

* update next_direct

* Refactor: put ucell as the first parameter

* rename cell to ucell

* update unitests

* update opt_TN.hpp

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
YuLiu98 and pre-commit-ci-lite[bot] authored Nov 24, 2024
1 parent c9f7973 commit 4ac1e8a
Show file tree
Hide file tree
Showing 46 changed files with 617 additions and 584 deletions.
2 changes: 1 addition & 1 deletion source/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void Driver::atomic_world()
//--------------------------------------------------

// where the actual stuff is done
this->driver_run();
this->driver_run(GlobalC::ucell);

ModuleBase::timer::finish(GlobalV::ofs_running);
ModuleBase::Memory::print_all(GlobalV::ofs_running);
Expand Down
4 changes: 3 additions & 1 deletion source/driver.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef DRIVER_H
#define DRIVER_H

#include "module_cell/unitcell.h"

class Driver
{
public:
Expand Down Expand Up @@ -34,7 +36,7 @@ class Driver
void atomic_world();

// the actual calculations
void driver_run();
void driver_run(UnitCell& ucell);
};

#endif
25 changes: 12 additions & 13 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
* the configuration-changing subroutine takes force and stress and updates the
* configuration
*/
void Driver::driver_run() {
void Driver::driver_run(UnitCell& ucell)
{
ModuleBase::TITLE("Driver", "driver_line");
ModuleBase::timer::tick("Driver", "driver_line");

Expand All @@ -39,37 +40,35 @@ void Driver::driver_run() {
#endif

// the life of ucell should begin here, mohan 2024-05-12
// delete ucell as a GlobalC in near future
GlobalC::ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
Check_Atomic_Stru::check_atomic_stru(GlobalC::ucell,
PARAM.inp.min_dist_coef);
ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
Check_Atomic_Stru::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);

//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, GlobalC::ucell);
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell);

//! 3: initialize Esolver and fill json-structure
p_esolver->before_all_runners(PARAM.inp, GlobalC::ucell);
p_esolver->before_all_runners(ucell, PARAM.inp);

// this Json part should be moved to before_all_runners, mohan 2024-05-12
#ifdef __RAPIDJSON
Json::gen_stru_wrapper(&GlobalC::ucell);
Json::gen_stru_wrapper(&ucell);
#endif

const std::string cal_type = PARAM.inp.calculation;

//! 4: different types of calculations
if (cal_type == "md")
{
Run_MD::md_line(GlobalC::ucell, p_esolver, PARAM);
Run_MD::md_line(ucell, p_esolver, PARAM);
}
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax" || cal_type == "nscf")
{
Relax_Driver rl_driver;
rl_driver.relax_driver(p_esolver);
rl_driver.relax_driver(p_esolver, ucell);
}
else if (cal_type == "get_S")
{
p_esolver->runner(0, GlobalC::ucell);
p_esolver->runner(ucell, 0);
}
else
{
Expand All @@ -79,11 +78,11 @@ void Driver::driver_run() {
//! test_neighbour(LCAO),
//! gen_bessel(PW), et al.
const int istep = 0;
p_esolver->others(istep);
p_esolver->others(ucell, istep);
}

//! 5: clean up esolver
p_esolver->after_all_runners();
p_esolver->after_all_runners(ucell);

ModuleESolver::clean_esolver(p_esolver);

Expand Down
18 changes: 15 additions & 3 deletions source/module_base/opt_TN.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#ifndef OPT_TN_H
#define OPT_TN_H

#include <limits>
#include "opt_CG.h"

#include "./opt_CG.h"
#include <limits>

namespace ModuleBase
{
Expand All @@ -25,7 +25,7 @@ class Opt_TN
{
this->mach_prec_ = std::numeric_limits<double>::epsilon(); // get machine precise
}
~Opt_TN(){};
~Opt_TN() {};

/**
* @brief Allocate the space for the arrays in cg_.
Expand Down Expand Up @@ -54,7 +54,9 @@ class Opt_TN
{
this->iter_ = 0;
if (nx_new != 0)
{
this->nx_ = nx_new;
}
this->cg_.refresh(nx_new);
}

Expand Down Expand Up @@ -167,17 +169,23 @@ void Opt_TN::next_direct(double* px,
epsilon = this->get_epsilon(px, cg_direct);
// epsilon = 1e-9;
for (int i = 0; i < this->nx_; ++i)
{
temp_x[i] = px[i] + epsilon * cg_direct[i];
}
(t->*p_calGradient)(temp_x, temp_gradient);
for (int i = 0; i < this->nx_; ++i)
{
temp_Hcgd[i] = (temp_gradient[i] - pgradient[i]) / epsilon;
}

// get CG step length and update rdirect
cg_alpha = cg_.step_length(temp_Hcgd, cg_direct, cg_ifPD);
if (cg_ifPD == -1) // Hessian is not positive definite, and cgiter = 1.
{
for (int i = 0; i < this->nx_; ++i)
{
rdirect[i] += cg_alpha * cg_direct[i];
}
flag = -1;
break;
}
Expand All @@ -188,14 +196,18 @@ void Opt_TN::next_direct(double* px,
}

for (int i = 0; i < this->nx_; ++i)
{
rdirect[i] += cg_alpha * cg_direct[i];
}

// store residuals used in truncated conditions
last_residual = curr_residual;
curr_residual = cg_.get_residual();
cg_iter = cg_.get_iter();
if (cg_iter == 1)
{
init_residual = curr_residual;
}

// check truncated conditions
// if (curr_residual < 1e-12)
Expand Down
3 changes: 3 additions & 0 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) {
this->atoms = new Atom[this->ntype]; // atom species.
this->set_atom_flag = true;

this->symm.epsilon = PARAM.inp.symmetry_prec;
this->symm.epsilon_input = PARAM.inp.symmetry_prec;

bool ok = true;
bool ok2 = true;

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
p_esolver->before_all_runners(inp, ucell);
p_esolver->runner(0, ucell); // scf-only
p_esolver->before_all_runners(ucell, inp);
p_esolver->runner(ucell, 0); // scf-only
// force and stress is not needed currently,
// they will be supported after the analytical gradient
// of LR-TDDFT is implemented.
Expand Down
12 changes: 6 additions & 6 deletions source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ class ESolver
}

//! initialize the energy solver by using input parameters and cell modules
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) = 0;
virtual void before_all_runners(UnitCell& ucell, const Input_para& inp) = 0;

//! run energy solver
virtual void runner(const int istep, UnitCell& cell) = 0;
virtual void runner(UnitCell& cell, const int istep) = 0;

//! perform post processing calculations
virtual void after_all_runners(){};
virtual void after_all_runners(UnitCell& ucell){};

//! deal with exx and other calculation than scf/md/relax/cell-relax:
//! such as nscf, get_wf and get_pchg
virtual void others(const int istep){};
virtual void others(UnitCell& ucell, const int istep) {};

//! calculate total energy of a given system
virtual double cal_energy() = 0;

//! calcualte forces for the atoms in the given cell
virtual void cal_force(ModuleBase::matrix& force) = 0;
virtual void cal_force(UnitCell& ucell, ModuleBase::matrix& force) = 0;

//! calcualte stress of given cell
virtual void cal_stress(ModuleBase::matrix& stress) = 0;
virtual void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) = 0;

bool conv_esolver = true; // whether esolver is converged

Expand Down
10 changes: 5 additions & 5 deletions source/module_esolver/esolver_dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
namespace ModuleESolver
{

void ESolver_DP::before_all_runners(const Input_para& inp, UnitCell& ucell)
void ESolver_DP::before_all_runners(UnitCell& ucell, const Input_para& inp)
{
ucell_ = &ucell;
dp_potential = 0;
Expand All @@ -57,7 +57,7 @@ void ESolver_DP::before_all_runners(const Input_para& inp, UnitCell& ucell)
#endif
}

void ESolver_DP::runner(const int istep, UnitCell& ucell)
void ESolver_DP::runner(UnitCell& ucell, const int istep)
{
ModuleBase::TITLE("ESolver_DP", "runner");
ModuleBase::timer::tick("ESolver_DP", "runner");
Expand Down Expand Up @@ -127,13 +127,13 @@ double ESolver_DP::cal_energy()
return dp_potential;
}

void ESolver_DP::cal_force(ModuleBase::matrix& force)
void ESolver_DP::cal_force(UnitCell& ucell, ModuleBase::matrix& force)
{
force = dp_force;
ModuleIO::print_force(GlobalV::ofs_running, *ucell_, "TOTAL-FORCE (eV/Angstrom)", force, false);
}

void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
void ESolver_DP::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
{
stress = dp_virial;

Expand All @@ -148,7 +148,7 @@ void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
ModuleIO::print_stress("TOTAL-STRESS", stress, true, false);
}

void ESolver_DP::after_all_runners()
void ESolver_DP::after_all_runners(UnitCell& ucell)
{
GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl;
GlobalV::ofs_running << std::setprecision(16);
Expand Down
10 changes: 5 additions & 5 deletions source/module_esolver/esolver_dp.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class ESolver_DP : public ESolver
* @param inp input parameters
* @param cell unitcell information
*/
void before_all_runners(const Input_para& inp, UnitCell& cell) override;
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;

/**
* @brief Run the DP solver for a given ion/md step and unit cell
*
* @param istep the current ion/md step
* @param cell unitcell information
*/
void runner(const int istep, UnitCell& cell) override;
void runner(UnitCell& cell, const int istep) override;

/**
* @brief get the total energy without ion kinetic energy
Expand All @@ -59,21 +59,21 @@ class ESolver_DP : public ESolver
*
* @param force the computed atomic forces
*/
void cal_force(ModuleBase::matrix& force) override;
void cal_force(UnitCell& ucell, ModuleBase::matrix& force) override;

/**
* @brief get the computed lattice virials
*
* @param stress the computed lattice virials
*/
void cal_stress(ModuleBase::matrix& stress) override;
void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) override;

/**
* @brief Prints the final total energy of the DP model to the output file
*
* This function prints the final total energy of the DP model in eV to the output file along with some formatting.
*/
void after_all_runners() override;
void after_all_runners(UnitCell& ucell) override;

private:
/**
Expand Down
Loading

0 comments on commit 4ac1e8a

Please sign in to comment.