Skip to content

Commit

Permalink
update exx code format
Browse files Browse the repository at this point in the history
  • Loading branch information
PeizeLin committed Dec 7, 2024
1 parent 0f79c05 commit 59c3b48
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 92 deletions.
173 changes: 86 additions & 87 deletions source/module_ri/Exx_LRI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,38 @@ void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, c
ModuleBase::TITLE("Exx_LRI","init");
ModuleBase::timer::tick("Exx_LRI", "init");

// if(GlobalC::exx_info.info_global.separate_loop)
// {
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::No;
// Hexx_para.mixing_beta = 0;
// }
// else
// {
// if("plain"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Plain;
// else if("pulay"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Pulay;
// else
// throw std::invalid_argument("exx mixing error. exx_separate_loop==false, mixing_mode!=plain or pulay");
// Hexx_para.mixing_beta = GlobalC::CHR.mixing_beta;
// }

this->mpi_comm = mpi_comm_in;
this->p_kv = &kv_in;
this->orb_cutoff_ = orb.cutoffs();
// if(GlobalC::exx_info.info_global.separate_loop)
// {
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::No;
// Hexx_para.mixing_beta = 0;
// }
// else
// {
// if("plain"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Plain;
// else if("pulay"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Pulay;
// else
// throw std::invalid_argument("exx mixing error. exx_separate_loop==false, mixing_mode!=plain or pulay");
// Hexx_para.mixing_beta = GlobalC::CHR.mixing_beta;
// }

this->mpi_comm = mpi_comm_in;
this->p_kv = &kv_in;
this->orb_cutoff_ = orb.cutoffs();

this->lcaos = Exx_Abfs::Construct_Orbs::change_orbs( orb, this->info.kmesh_times );

// #ifdef __MPI
// Exx_Abfs::Util::bcast( this->info.files_abfs, 0, this->mpi_comm );
// #endif
// #ifdef __MPI
// Exx_Abfs::Util::bcast( this->info.files_abfs, 0, this->mpi_comm );
// #endif

const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>>
abfs_same_atom = Exx_Abfs::Construct_Orbs::abfs_same_atom( orb, this->lcaos, this->info.kmesh_times, this->info.pca_threshold );
if(this->info.files_abfs.empty()) {
this->abfs = abfs_same_atom;
} else {
this->abfs = Exx_Abfs::IO::construct_abfs( abfs_same_atom, orb, this->info.files_abfs, this->info.kmesh_times );
}
if(this->info.files_abfs.empty())
{ this->abfs = abfs_same_atom;}
else
{ this->abfs = Exx_Abfs::IO::construct_abfs( abfs_same_atom, orb, this->info.files_abfs, this->info.kmesh_times ); }
Exx_Abfs::Construct_Orbs::print_orbs_size(this->abfs, GlobalV::ofs_running);

auto get_ccp_parameter = [this]() -> std::map<std::string,double>
Expand All @@ -85,15 +84,14 @@ void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, c
throw std::domain_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); break;
}
};
this->abfs_ccp = Conv_Coulomb_Pot_K::cal_orbs_ccp(this->abfs, this->info.ccp_type, get_ccp_parameter(), this->info.ccp_rmesh_times);
this->abfs_ccp = Conv_Coulomb_Pot_K::cal_orbs_ccp(this->abfs, this->info.ccp_type, get_ccp_parameter(), this->info.ccp_rmesh_times);


for( size_t T=0; T!=this->abfs.size(); ++T ) {
GlobalC::exx_info.info_ri.abfs_Lmax = std::max( GlobalC::exx_info.info_ri.abfs_Lmax, static_cast<int>(this->abfs[T].size())-1 );
}
for( size_t T=0; T!=this->abfs.size(); ++T )
{ GlobalC::exx_info.info_ri.abfs_Lmax = std::max( GlobalC::exx_info.info_ri.abfs_Lmax, static_cast<int>(this->abfs[T].size())-1 ); }

this->cv.set_orbitals(
orb,
orb,
this->lcaos, this->abfs, this->abfs_ccp,
this->info.kmesh_times, this->info.ccp_rmesh_times );

Expand All @@ -106,19 +104,17 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
ModuleBase::TITLE("Exx_LRI","cal_exx_ions");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_ions");

// init_radial_table_ions( cal_atom_centres_core(atom_pairs_core_origin), atom_pairs_core_origin );
// init_radial_table_ions( cal_atom_centres_core(atom_pairs_core_origin), atom_pairs_core_origin );

// this->m_abfsabfs.init_radial_table(Rradial);
// this->m_abfslcaos_lcaos.init_radial_table(Rradial);
// this->m_abfsabfs.init_radial_table(Rradial);
// this->m_abfslcaos_lcaos.init_radial_table(Rradial);

std::vector<TA> atoms(GlobalC::ucell.nat);
for(int iat=0; iat<GlobalC::ucell.nat; ++iat) {
atoms[iat] = iat;
}
for(int iat=0; iat<GlobalC::ucell.nat; ++iat)
{ atoms[iat] = iat; }
std::map<TA,TatomR> atoms_pos;
for(int iat=0; iat<GlobalC::ucell.nat; ++iat) {
atoms_pos[iat] = RI_Util::Vector3_to_array3( GlobalC::ucell.atoms[ GlobalC::ucell.iat2it[iat] ].tau[ GlobalC::ucell.iat2ia[iat] ] );
}
for(int iat=0; iat<GlobalC::ucell.nat; ++iat)
{ atoms_pos[iat] = RI_Util::Vector3_to_array3( GlobalC::ucell.atoms[ GlobalC::ucell.iat2it[iat] ].tau[ GlobalC::ucell.iat2ia[iat] ] ); }
const std::array<TatomR,Ndim> latvec
= {RI_Util::Vector3_to_array3(GlobalC::ucell.a1),
RI_Util::Vector3_to_array3(GlobalC::ucell.a2),
Expand All @@ -137,7 +133,8 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
list_As_Vs.first, list_As_Vs.second[0],
{{"writable_Vws",true}});
this->cv.Vws = LRI_CV_Tools::get_CVws(Vs);
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Vs_abf(Vs, PARAM.globalv.global_out_dir + "Vs"); }
if (write_cv && GlobalV::MY_RANK == 0)
{ LRI_CV_Tools::write_Vs_abf(Vs, PARAM.globalv.global_out_dir + "Vs"); }
this->exx_lri.set_Vs(std::move(Vs), this->info.V_threshold);

if(PARAM.inp.cal_force || PARAM.inp.cal_stress)
Expand Down Expand Up @@ -166,7 +163,8 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
{"writable_Cws",true}, {"writable_dCws",true}, {"writable_Vws",false}, {"writable_dVws",false}});
std::map<TA,std::map<TAC,RI::Tensor<Tdata>>> &Cs = std::get<0>(Cs_dCs);
this->cv.Cws = LRI_CV_Tools::get_CVws(Cs);
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Cs_ao(Cs, PARAM.globalv.global_out_dir + "Cs"); }
if (write_cv && GlobalV::MY_RANK == 0)
{ LRI_CV_Tools::write_Cs_ao(Cs, PARAM.globalv.global_out_dir + "Cs"); }
this->exx_lri.set_Cs(std::move(Cs), this->info.C_threshold);

if(PARAM.inp.cal_force || PARAM.inp.cal_stress)
Expand All @@ -185,44 +183,48 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)

template<typename Tdata>
void Exx_LRI<Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot)
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot)
{
ModuleBase::TITLE("Exx_LRI","cal_exx_elec");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");

const std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge(pv);

if(p_symrot)
{ this->exx_lri.set_symmetry(true, p_symrot->get_irreducible_sector()); }
else
{ this->exx_lri.set_symmetry(false, {}); }

this->Hexxs.resize(PARAM.inp.nspin);
this->Eexx = 0;
(p_symrot) ? this->exx_lri.set_symmetry(true, p_symrot->get_irreducible_sector()) : this->exx_lri.set_symmetry(false, {});
for(int is=0; is<PARAM.inp.nspin; ++is)
{
std::string suffix = ((PARAM.inp.cal_force || PARAM.inp.cal_stress) ? std::to_string(is) : "");

this->exx_lri.set_Ds(Ds[is], this->info.dm_threshold, suffix);
this->exx_lri.cal_Hs({ "","",suffix });

if (!p_symrot)
{
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(this->exx_lri.Hs), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
else
{
// reduce but not repeat
auto Hs_a2D = this->exx_lri.post_2D.set_tensors_map2(this->exx_lri.Hs);
// rotate locally without repeat
Hs_a2D = p_symrot->restore_HR(GlobalC::ucell.symm, GlobalC::ucell.atoms, GlobalC::ucell.st, 'H', Hs_a2D);
// cal energy using full Hs without repeat
this->exx_lri.energy = this->exx_lri.post_2D.cal_energy(
this->exx_lri.post_2D.saves["Ds_" + suffix],
this->exx_lri.post_2D.set_tensors_map2(Hs_a2D));
// get repeated full Hs for abacus
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(Hs_a2D), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
this->Eexx += std::real(this->exx_lri.energy);
const std::string suffix = ((PARAM.inp.cal_force || PARAM.inp.cal_stress) ? std::to_string(is) : "");

this->exx_lri.set_Ds(Ds[is], this->info.dm_threshold, suffix);
this->exx_lri.cal_Hs({ "","",suffix });

if (!p_symrot)
{
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(this->exx_lri.Hs), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
else
{
// reduce but not repeat
auto Hs_a2D = this->exx_lri.post_2D.set_tensors_map2(this->exx_lri.Hs);
// rotate locally without repeat
Hs_a2D = p_symrot->restore_HR(GlobalC::ucell.symm, GlobalC::ucell.atoms, GlobalC::ucell.st, 'H', Hs_a2D);
// cal energy using full Hs without repeat
this->exx_lri.energy = this->exx_lri.post_2D.cal_energy(
this->exx_lri.post_2D.saves["Ds_" + suffix],
this->exx_lri.post_2D.set_tensors_map2(Hs_a2D));
// get repeated full Hs for abacus
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(Hs_a2D), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
this->Eexx += std::real(this->exx_lri.energy);
post_process_Hexx(this->Hexxs[is]);
}
this->Eexx = post_process_Eexx(this->Eexx);
Expand All @@ -245,8 +247,8 @@ template<typename Tdata>
double Exx_LRI<Tdata>::post_process_Eexx(const double& Eexx_in) const
{
ModuleBase::TITLE("Exx_LRI","post_process_Eexx");
const double SPIN_multiple = std::map<int, double>{ {1,2}, {2,1}, {4,1} }.at(PARAM.inp.nspin); // why?
const double frac = -SPIN_multiple;
const double SPIN_multiple = std::map<int, double>{ {1,2}, {2,1}, {4,1} }.at(PARAM.inp.nspin); // why?
const double frac = -SPIN_multiple;
return frac * Eexx_in;
}

Expand Down Expand Up @@ -280,8 +282,7 @@ void Exx_LRI<Tdata>::cal_exx_force()
for(std::size_t idim=0; idim<Ndim; ++idim) {
for(const auto &force_item : this->exx_lri.force[idim]) {
this->force_exx(force_item.first, idim) += std::real(force_item.second);
}
}
} }
}

const double SPIN_multiple = std::map<int,double>{{1,2}, {2,1}, {4,1}}.at(PARAM.inp.nspin); // why?
Expand All @@ -304,8 +305,7 @@ void Exx_LRI<Tdata>::cal_exx_stress()
for(std::size_t idim0=0; idim0<Ndim; ++idim0) {
for(std::size_t idim1=0; idim1<Ndim; ++idim1) {
this->stress_exx(idim0,idim1) += std::real(this->exx_lri.stress(idim0,idim1));
}
}
} }
}

const double SPIN_multiple = std::map<int,double>{{1,2}, {2,1}, {4,1}}.at(PARAM.inp.nspin); // why?
Expand All @@ -318,16 +318,15 @@ void Exx_LRI<Tdata>::cal_exx_stress()
template<typename Tdata>
std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
{
std::vector<std::vector<int>> abfs_nchis;
for (const auto& abfs_T : this->abfs)
{
std::vector<int> abfs_nchi_T;
for (const auto& abfs_L : abfs_T) {
abfs_nchi_T.push_back(abfs_L.size());
}
abfs_nchis.push_back(abfs_nchi_T);
}
return abfs_nchis;
std::vector<std::vector<int>> abfs_nchis;
for (const auto& abfs_T : this->abfs)
{
std::vector<int> abfs_nchi_T;
for (const auto& abfs_L : abfs_T)
{ abfs_nchi_T.push_back(abfs_L.size()); }
abfs_nchis.push_back(abfs_nchi_T);
}
return abfs_nchis;
}

#endif
10 changes: 5 additions & 5 deletions source/module_ri/LRI_CV.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LRI_CV<Tdata>::~LRI_CV()

template<typename Tdata>
void LRI_CV<Tdata>::set_orbitals(
const LCAO_Orbitals& orb,
const LCAO_Orbitals& orb,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos_in,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs_in,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs_ccp_in,
Expand All @@ -47,7 +47,7 @@ void LRI_CV<Tdata>::set_orbitals(
ModuleBase::TITLE("LRI_CV", "set_orbitals");
ModuleBase::timer::tick("LRI_CV", "set_orbitals");

this->orb_cutoff_ = orb.cutoffs();
this->orb_cutoff_ = orb.cutoffs();
this->lcaos = lcaos_in;
this->abfs = abfs_in;
this->abfs_ccp = abfs_ccp_in;
Expand Down Expand Up @@ -109,11 +109,11 @@ auto LRI_CV<Tdata>::cal_datas(
if( R_delta.norm()*GlobalC::ucell.lat0 < Rcut )
{
const Tresult Data = func_DPcal_data(it0, it1, R_delta, flags);
// if(Data.norm(std::numeric_limits<double>::max()) > threshold)
// {
// if(Data.norm(std::numeric_limits<double>::max()) > threshold)
// {
#pragma omp critical(LRI_CV_cal_datas)
Datas[list_A0[i0]][list_A1[i1]] = Data;
// }
// }
}
}
}
Expand Down

0 comments on commit 59c3b48

Please sign in to comment.