Skip to content

Commit

Permalink
Refactor: remove template for get_S (#5593)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuLiu98 authored Nov 25, 2024
1 parent 7e9d081 commit df94f88
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 73 deletions.
6 changes: 3 additions & 3 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<double, double>();
ModuleBase::WARNING_QUIT("ESolver", "get_S is not implemented for gamma_only");
}
else
{
Expand All @@ -202,7 +202,7 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<std::complex<double>, double>();
return new ESolver_GetS();
}
else
{
Expand All @@ -213,7 +213,7 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
{
if (PARAM.inp.calculation == "get_S")
{
return new ESolver_GetS<std::complex<double>, std::complex<double>>();
ModuleBase::WARNING_QUIT("ESolver", "get_S is not implemented for npsin=4");
}
else
{
Expand Down
81 changes: 13 additions & 68 deletions source/module_esolver/esolver_gets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
namespace ModuleESolver
{

template <typename TK, typename TR>
ESolver_GetS<TK, TR>::ESolver_GetS()
ESolver_GetS::ESolver_GetS()
{
this->classname = "ESolver_GetS";
this->basisname = "LCAO";
}

template <typename TK, typename TR>
ESolver_GetS<TK, TR>::~ESolver_GetS()
ESolver_GetS::~ESolver_GetS()
{
}

template <typename TK, typename TR>
void ESolver_GetS<TK, TR>::before_all_runners(UnitCell& ucell, const Input_para& inp)
void ESolver_GetS::before_all_runners(UnitCell& ucell, const Input_para& inp)
{
ModuleBase::TITLE("ESolver_GetS", "before_all_runners");
ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");
Expand All @@ -51,13 +48,13 @@ void ESolver_GetS<TK, TR>::before_all_runners(UnitCell& ucell, const Input_para&
if (this->pelec == nullptr)
{
// TK stands for double and complex<double>?
this->pelec = new elecstate::ElecStateLCAO<TK>(&(this->chr), // use which parameter?
&(this->kv),
this->kv.get_nks(),
&(this->GG), // mohan add 2024-04-01
&(this->GK), // mohan add 2024-04-01
this->pw_rho,
this->pw_big);
this->pelec = new elecstate::ElecStateLCAO<std::complex<double>>(&(this->chr), // use which parameter?
&(this->kv),
this->kv.get_nks(),
&(this->GG), // mohan add 2024-04-01
&(this->GK), // mohan add 2024-04-01
this->pw_rho,
this->pw_big);
}

// 3) init LCAO basis
Expand All @@ -76,61 +73,13 @@ void ESolver_GetS<TK, TR>::before_all_runners(UnitCell& ucell, const Input_para&
// 4) initialize the density matrix
// DensityMatrix is allocated here, DMK is also initialized here
// DMR is not initialized here, it will be constructed in each before_scf
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin);
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
->init_DM(&this->kv, &(this->pv), inp.nspin);

ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");
}

template <>
void ESolver_GetS<double, double>::runner(UnitCell& ucell, const int istep)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::WARNING_QUIT("ESolver_GetS<double, double>::runner", "not implemented");
}

template <>
void ESolver_GetS<std::complex<double>, std::complex<double>>::runner(UnitCell& ucell, const int istep)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::timer::tick("ESolver_GetS", "runner");

// (1) Find adjacent atoms for each atom.
double search_radius = -1.0;
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
PARAM.inp.out_level,
orb_.get_rcutmax_Phi(),
ucell.infoNL.get_rcutmax_Beta(),
PARAM.globalv.gamma_only_local);

atom_arrange::search(PARAM.inp.search_pbc,
GlobalV::ofs_running,
GlobalC::GridD,
ucell,
search_radius,
PARAM.inp.test_atom_input);

this->RA.for_2d(this->pv, PARAM.globalv.gamma_only_local, orb_.cutoffs());

if (this->p_hamilt == nullptr)
{
this->p_hamilt
= new hamilt::HamiltLCAO<std::complex<double>, std::complex<double>>(&this->pv,
this->kv,
*(two_center_bundle_.overlap_orb),
orb_.cutoffs());
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(this->p_hamilt->ops)
->contributeHR();
}

const std::string fn = PARAM.globalv.global_out_dir + "SR.csr";
std::cout << " The file is saved in " << fn << std::endl;
ModuleIO::output_SR(pv, GlobalC::GridD, this->p_hamilt, fn);

ModuleBase::timer::tick("ESolver_GetS", "runner");
}

template <>
void ESolver_GetS<std::complex<double>, double>::runner(UnitCell& ucell, const int istep)
void ESolver_GetS::runner(UnitCell& ucell, const int istep)
{
ModuleBase::TITLE("ESolver_GetS", "runner");
ModuleBase::timer::tick("ESolver_GetS", "runner");
Expand Down Expand Up @@ -168,8 +117,4 @@ void ESolver_GetS<std::complex<double>, double>::runner(UnitCell& ucell, const i
ModuleBase::timer::tick("ESolver_GetS", "runner");
}

template class ESolver_GetS<double, double>;
template class ESolver_GetS<std::complex<double>, double>;
template class ESolver_GetS<std::complex<double>, std::complex<double>>;

} // namespace ModuleESolver
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_gets.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

namespace ModuleESolver
{
template <typename TK, typename TR>
class ESolver_GetS : public ESolver_KS<TK>

class ESolver_GetS : public ESolver_KS<std::complex<double>>
{
public:
ESolver_GetS();
Expand Down

0 comments on commit df94f88

Please sign in to comment.