Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update exx code format #5701

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 86 additions & 87 deletions source/module_ri/Exx_LRI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,38 @@ void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, c
ModuleBase::TITLE("Exx_LRI","init");
ModuleBase::timer::tick("Exx_LRI", "init");

// if(GlobalC::exx_info.info_global.separate_loop)
// {
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::No;
// Hexx_para.mixing_beta = 0;
// }
// else
// {
// if("plain"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Plain;
// else if("pulay"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Pulay;
// else
// throw std::invalid_argument("exx mixing error. exx_separate_loop==false, mixing_mode!=plain or pulay");
// Hexx_para.mixing_beta = GlobalC::CHR.mixing_beta;
// }

this->mpi_comm = mpi_comm_in;
this->p_kv = &kv_in;
this->orb_cutoff_ = orb.cutoffs();
// if(GlobalC::exx_info.info_global.separate_loop)
// {
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::No;
// Hexx_para.mixing_beta = 0;
// }
// else
// {
// if("plain"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Plain;
// else if("pulay"==GlobalC::CHR.mixing_mode)
// Hexx_para.mixing_mode = Exx_Abfs::Parallel::Communicate::Hexx::Mixing_Mode::Pulay;
// else
// throw std::invalid_argument("exx mixing error. exx_separate_loop==false, mixing_mode!=plain or pulay");
// Hexx_para.mixing_beta = GlobalC::CHR.mixing_beta;
// }

this->mpi_comm = mpi_comm_in;
this->p_kv = &kv_in;
this->orb_cutoff_ = orb.cutoffs();

this->lcaos = Exx_Abfs::Construct_Orbs::change_orbs( orb, this->info.kmesh_times );

// #ifdef __MPI
// Exx_Abfs::Util::bcast( this->info.files_abfs, 0, this->mpi_comm );
// #endif
// #ifdef __MPI
// Exx_Abfs::Util::bcast( this->info.files_abfs, 0, this->mpi_comm );
// #endif

const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>>
abfs_same_atom = Exx_Abfs::Construct_Orbs::abfs_same_atom( orb, this->lcaos, this->info.kmesh_times, this->info.pca_threshold );
if(this->info.files_abfs.empty()) {
this->abfs = abfs_same_atom;
} else {
this->abfs = Exx_Abfs::IO::construct_abfs( abfs_same_atom, orb, this->info.files_abfs, this->info.kmesh_times );
}
if(this->info.files_abfs.empty())
{ this->abfs = abfs_same_atom;}
else
{ this->abfs = Exx_Abfs::IO::construct_abfs( abfs_same_atom, orb, this->info.files_abfs, this->info.kmesh_times ); }
Exx_Abfs::Construct_Orbs::print_orbs_size(this->abfs, GlobalV::ofs_running);

auto get_ccp_parameter = [this]() -> std::map<std::string,double>
Expand All @@ -85,15 +84,14 @@ void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, c
throw std::domain_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); break;
}
};
this->abfs_ccp = Conv_Coulomb_Pot_K::cal_orbs_ccp(this->abfs, this->info.ccp_type, get_ccp_parameter(), this->info.ccp_rmesh_times);
this->abfs_ccp = Conv_Coulomb_Pot_K::cal_orbs_ccp(this->abfs, this->info.ccp_type, get_ccp_parameter(), this->info.ccp_rmesh_times);


for( size_t T=0; T!=this->abfs.size(); ++T ) {
GlobalC::exx_info.info_ri.abfs_Lmax = std::max( GlobalC::exx_info.info_ri.abfs_Lmax, static_cast<int>(this->abfs[T].size())-1 );
}
for( size_t T=0; T!=this->abfs.size(); ++T )
{ GlobalC::exx_info.info_ri.abfs_Lmax = std::max( GlobalC::exx_info.info_ri.abfs_Lmax, static_cast<int>(this->abfs[T].size())-1 ); }

this->cv.set_orbitals(
orb,
orb,
this->lcaos, this->abfs, this->abfs_ccp,
this->info.kmesh_times, this->info.ccp_rmesh_times );

Expand All @@ -106,19 +104,17 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
ModuleBase::TITLE("Exx_LRI","cal_exx_ions");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_ions");

// init_radial_table_ions( cal_atom_centres_core(atom_pairs_core_origin), atom_pairs_core_origin );
// init_radial_table_ions( cal_atom_centres_core(atom_pairs_core_origin), atom_pairs_core_origin );

// this->m_abfsabfs.init_radial_table(Rradial);
// this->m_abfslcaos_lcaos.init_radial_table(Rradial);
// this->m_abfsabfs.init_radial_table(Rradial);
// this->m_abfslcaos_lcaos.init_radial_table(Rradial);

std::vector<TA> atoms(GlobalC::ucell.nat);
for(int iat=0; iat<GlobalC::ucell.nat; ++iat) {
atoms[iat] = iat;
}
for(int iat=0; iat<GlobalC::ucell.nat; ++iat)
{ atoms[iat] = iat; }
std::map<TA,TatomR> atoms_pos;
for(int iat=0; iat<GlobalC::ucell.nat; ++iat) {
atoms_pos[iat] = RI_Util::Vector3_to_array3( GlobalC::ucell.atoms[ GlobalC::ucell.iat2it[iat] ].tau[ GlobalC::ucell.iat2ia[iat] ] );
}
for(int iat=0; iat<GlobalC::ucell.nat; ++iat)
{ atoms_pos[iat] = RI_Util::Vector3_to_array3( GlobalC::ucell.atoms[ GlobalC::ucell.iat2it[iat] ].tau[ GlobalC::ucell.iat2ia[iat] ] ); }
const std::array<TatomR,Ndim> latvec
= {RI_Util::Vector3_to_array3(GlobalC::ucell.a1),
RI_Util::Vector3_to_array3(GlobalC::ucell.a2),
Expand All @@ -137,7 +133,8 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
list_As_Vs.first, list_As_Vs.second[0],
{{"writable_Vws",true}});
this->cv.Vws = LRI_CV_Tools::get_CVws(Vs);
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Vs_abf(Vs, PARAM.globalv.global_out_dir + "Vs"); }
if (write_cv && GlobalV::MY_RANK == 0)
{ LRI_CV_Tools::write_Vs_abf(Vs, PARAM.globalv.global_out_dir + "Vs"); }
this->exx_lri.set_Vs(std::move(Vs), this->info.V_threshold);

if(PARAM.inp.cal_force || PARAM.inp.cal_stress)
Expand Down Expand Up @@ -166,7 +163,8 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)
{"writable_Cws",true}, {"writable_dCws",true}, {"writable_Vws",false}, {"writable_dVws",false}});
std::map<TA,std::map<TAC,RI::Tensor<Tdata>>> &Cs = std::get<0>(Cs_dCs);
this->cv.Cws = LRI_CV_Tools::get_CVws(Cs);
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Cs_ao(Cs, PARAM.globalv.global_out_dir + "Cs"); }
if (write_cv && GlobalV::MY_RANK == 0)
{ LRI_CV_Tools::write_Cs_ao(Cs, PARAM.globalv.global_out_dir + "Cs"); }
this->exx_lri.set_Cs(std::move(Cs), this->info.C_threshold);

if(PARAM.inp.cal_force || PARAM.inp.cal_stress)
Expand All @@ -185,44 +183,48 @@ void Exx_LRI<Tdata>::cal_exx_ions(const bool write_cv)

template<typename Tdata>
void Exx_LRI<Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot)
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot)
{
ModuleBase::TITLE("Exx_LRI","cal_exx_elec");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");

const std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge(pv);

if(p_symrot)
{ this->exx_lri.set_symmetry(true, p_symrot->get_irreducible_sector()); }
else
{ this->exx_lri.set_symmetry(false, {}); }

this->Hexxs.resize(PARAM.inp.nspin);
this->Eexx = 0;
(p_symrot) ? this->exx_lri.set_symmetry(true, p_symrot->get_irreducible_sector()) : this->exx_lri.set_symmetry(false, {});
for(int is=0; is<PARAM.inp.nspin; ++is)
{
std::string suffix = ((PARAM.inp.cal_force || PARAM.inp.cal_stress) ? std::to_string(is) : "");

this->exx_lri.set_Ds(Ds[is], this->info.dm_threshold, suffix);
this->exx_lri.cal_Hs({ "","",suffix });

if (!p_symrot)
{
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(this->exx_lri.Hs), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
else
{
// reduce but not repeat
auto Hs_a2D = this->exx_lri.post_2D.set_tensors_map2(this->exx_lri.Hs);
// rotate locally without repeat
Hs_a2D = p_symrot->restore_HR(GlobalC::ucell.symm, GlobalC::ucell.atoms, GlobalC::ucell.st, 'H', Hs_a2D);
// cal energy using full Hs without repeat
this->exx_lri.energy = this->exx_lri.post_2D.cal_energy(
this->exx_lri.post_2D.saves["Ds_" + suffix],
this->exx_lri.post_2D.set_tensors_map2(Hs_a2D));
// get repeated full Hs for abacus
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(Hs_a2D), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
this->Eexx += std::real(this->exx_lri.energy);
const std::string suffix = ((PARAM.inp.cal_force || PARAM.inp.cal_stress) ? std::to_string(is) : "");

this->exx_lri.set_Ds(Ds[is], this->info.dm_threshold, suffix);
this->exx_lri.cal_Hs({ "","",suffix });

if (!p_symrot)
{
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(this->exx_lri.Hs), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
else
{
// reduce but not repeat
auto Hs_a2D = this->exx_lri.post_2D.set_tensors_map2(this->exx_lri.Hs);
// rotate locally without repeat
Hs_a2D = p_symrot->restore_HR(GlobalC::ucell.symm, GlobalC::ucell.atoms, GlobalC::ucell.st, 'H', Hs_a2D);
// cal energy using full Hs without repeat
this->exx_lri.energy = this->exx_lri.post_2D.cal_energy(
this->exx_lri.post_2D.saves["Ds_" + suffix],
this->exx_lri.post_2D.set_tensors_map2(Hs_a2D));
// get repeated full Hs for abacus
this->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
this->mpi_comm, std::move(Hs_a2D), std::get<0>(judge[is]), std::get<1>(judge[is]));
}
this->Eexx += std::real(this->exx_lri.energy);
post_process_Hexx(this->Hexxs[is]);
}
this->Eexx = post_process_Eexx(this->Eexx);
Expand All @@ -245,8 +247,8 @@ template<typename Tdata>
double Exx_LRI<Tdata>::post_process_Eexx(const double& Eexx_in) const
{
ModuleBase::TITLE("Exx_LRI","post_process_Eexx");
const double SPIN_multiple = std::map<int, double>{ {1,2}, {2,1}, {4,1} }.at(PARAM.inp.nspin); // why?
const double frac = -SPIN_multiple;
const double SPIN_multiple = std::map<int, double>{ {1,2}, {2,1}, {4,1} }.at(PARAM.inp.nspin); // why?
const double frac = -SPIN_multiple;
return frac * Eexx_in;
}

Expand Down Expand Up @@ -280,8 +282,7 @@ void Exx_LRI<Tdata>::cal_exx_force()
for(std::size_t idim=0; idim<Ndim; ++idim) {
for(const auto &force_item : this->exx_lri.force[idim]) {
this->force_exx(force_item.first, idim) += std::real(force_item.second);
}
}
} }
}

const double SPIN_multiple = std::map<int,double>{{1,2}, {2,1}, {4,1}}.at(PARAM.inp.nspin); // why?
Expand All @@ -304,8 +305,7 @@ void Exx_LRI<Tdata>::cal_exx_stress()
for(std::size_t idim0=0; idim0<Ndim; ++idim0) {
for(std::size_t idim1=0; idim1<Ndim; ++idim1) {
this->stress_exx(idim0,idim1) += std::real(this->exx_lri.stress(idim0,idim1));
}
}
} }
}

const double SPIN_multiple = std::map<int,double>{{1,2}, {2,1}, {4,1}}.at(PARAM.inp.nspin); // why?
Expand All @@ -318,16 +318,15 @@ void Exx_LRI<Tdata>::cal_exx_stress()
template<typename Tdata>
std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
{
std::vector<std::vector<int>> abfs_nchis;
for (const auto& abfs_T : this->abfs)
{
std::vector<int> abfs_nchi_T;
for (const auto& abfs_L : abfs_T) {
abfs_nchi_T.push_back(abfs_L.size());
}
abfs_nchis.push_back(abfs_nchi_T);
}
return abfs_nchis;
std::vector<std::vector<int>> abfs_nchis;
for (const auto& abfs_T : this->abfs)
{
std::vector<int> abfs_nchi_T;
for (const auto& abfs_L : abfs_T)
{ abfs_nchi_T.push_back(abfs_L.size()); }
abfs_nchis.push_back(abfs_nchi_T);
}
return abfs_nchis;
}

#endif
10 changes: 5 additions & 5 deletions source/module_ri/LRI_CV.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LRI_CV<Tdata>::~LRI_CV()

template<typename Tdata>
void LRI_CV<Tdata>::set_orbitals(
const LCAO_Orbitals& orb,
const LCAO_Orbitals& orb,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos_in,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs_in,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs_ccp_in,
Expand All @@ -47,7 +47,7 @@ void LRI_CV<Tdata>::set_orbitals(
ModuleBase::TITLE("LRI_CV", "set_orbitals");
ModuleBase::timer::tick("LRI_CV", "set_orbitals");

this->orb_cutoff_ = orb.cutoffs();
this->orb_cutoff_ = orb.cutoffs();
this->lcaos = lcaos_in;
this->abfs = abfs_in;
this->abfs_ccp = abfs_ccp_in;
Expand Down Expand Up @@ -109,11 +109,11 @@ auto LRI_CV<Tdata>::cal_datas(
if( R_delta.norm()*GlobalC::ucell.lat0 < Rcut )
{
const Tresult Data = func_DPcal_data(it0, it1, R_delta, flags);
// if(Data.norm(std::numeric_limits<double>::max()) > threshold)
// {
// if(Data.norm(std::numeric_limits<double>::max()) > threshold)
// {
#pragma omp critical(LRI_CV_cal_datas)
Datas[list_A0[i0]][list_A1[i1]] = Data;
// }
// }
}
}
}
Expand Down
Loading