Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor:Remove GlobalC::ucell in module_lr,module_psi #5691

Merged
merged 8 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
this->kv.ngk.data(),
this->pw_wfc->npwk_max,
&this->sf,
&this->ppcell);
&this->ppcell,
ucell);

this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
? new psi::Psi<T, Device>(this->psi[0])
Expand Down Expand Up @@ -257,7 +258,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)

this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);

this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell);
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell,ucell);
}
if (ucell.ionic_position_updated)
{
Expand Down Expand Up @@ -373,6 +374,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
this->kspw_psi,
this->p_hamilt,
this->ppcell,
ucell,
GlobalV::ofs_running,
this->already_initpsi);

Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void Structure_Factor::setup_structure_factor(const UnitCell* Ucell, const Modul
ModuleBase::TITLE("PW_Basis","setup_structure_factor");
ModuleBase::timer::tick("PW_Basis","setup_struc_factor");
const std::complex<double> ci_tpi = ModuleBase::NEG_IMAG_UNIT * ModuleBase::TWO_PI;

this->ucell = Ucell;
this->strucFac.create(Ucell->ntype, rho_basis->npw);
ModuleBase::Memory::record("SF::strucFac", sizeof(std::complex<double>) * Ucell->ntype*rho_basis->npw);

Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_pw/hamilt_pwdft/structure_factor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Structure_Factor
ModuleBase::Vector3<double> q);

private:
const UnitCell* ucell;
std::complex<float> * c_eigts1 = nullptr, * c_eigts2 = nullptr, * c_eigts3 = nullptr;
std::complex<double> * z_eigts1 = nullptr, * z_eigts2 = nullptr, * z_eigts3 = nullptr;
const ModulePW::PW_Basis* rho_basis = nullptr;
Expand Down
47 changes: 25 additions & 22 deletions source/module_hamilt_pw/hamilt_pwdft/structure_factor_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ std::complex<double>* Structure_Factor::get_sk(const int ik,
const ModulePW::PW_Basis_K* wfc_basis) const
{
ModuleBase::timer::tick("Structure_Factor", "get_sk");
const double arg = (wfc_basis->kvec_c[ik] * GlobalC::ucell.atoms[it].tau[ia]) * ModuleBase::TWO_PI;
const double arg = (wfc_basis->kvec_c[ik] * ucell->atoms[it].tau[ia]) * ModuleBase::TWO_PI;
const std::complex<double> kphase = std::complex<double>(cos(arg), -sin(arg));
const int npw = wfc_basis->npwk[ik];
std::complex<double> *sk = new std::complex<double>[npw];
Expand All @@ -26,19 +26,22 @@ std::complex<double>* Structure_Factor::get_sk(const int ik,
const int ixy = wfc_basis->is2fftixy[is];
int ix = ixy / wfc_basis->fftny;
int iy = ixy % wfc_basis->fftny;
if (ix >= int(nx / 2) + 1) {
if (ix >= int(nx / 2) + 1)
{
ix -= nx;
}
if (iy >= int(ny / 2) + 1) {
}
if (iy >= int(ny / 2) + 1)
{
iy -= ny;
}
if (iz >= int(nz / 2) + 1) {
}
if (iz >= int(nz / 2) + 1)
{
iz -= nz;
}
}
ix += this->rho_basis->nx;
iy += this->rho_basis->ny;
iz += this->rho_basis->nz;
const int iat = GlobalC::ucell.itia2iat(it, ia);
const int iat = ucell->itia2iat(it, ia);
sk[igl] = kphase * this->eigts1(iat, ix) * this->eigts2(iat, iy) * this->eigts3(iat, iz);
}
ModuleBase::timer::tick("Structure_Factor", "get_sk");
Expand Down Expand Up @@ -66,33 +69,33 @@ void Structure_Factor::get_sk(Device* ctx,

int iat = 0, _npw = wfc_basis->npwk[ik], eigts1_nc = this->eigts1.nc, eigts2_nc = this->eigts2.nc,
eigts3_nc = this->eigts3.nc;
int *igl2isz = nullptr, *is2fftixy = nullptr, *atom_na = nullptr, *h_atom_na = new int[GlobalC::ucell.ntype];
FPTYPE *atom_tau = nullptr, *h_atom_tau = new FPTYPE[GlobalC::ucell.nat * 3], *kvec = wfc_basis->get_kvec_c_data<FPTYPE>();
int *igl2isz = nullptr, *is2fftixy = nullptr, *atom_na = nullptr, *h_atom_na = new int[ucell->ntype];
FPTYPE *atom_tau = nullptr, *h_atom_tau = new FPTYPE[ucell->nat * 3], *kvec = wfc_basis->get_kvec_c_data<FPTYPE>();
std::complex<FPTYPE> *eigts1 = this->get_eigts1_data<FPTYPE>(), *eigts2 = this->get_eigts2_data<FPTYPE>(),
*eigts3 = this->get_eigts3_data<FPTYPE>();
for (int it = 0; it < GlobalC::ucell.ntype; it++)
for (int it = 0; it < ucell->ntype; it++)
{
h_atom_na[it] = GlobalC::ucell.atoms[it].na;
h_atom_na[it] = ucell->atoms[it].na;
}
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < ucell->nat; iat++)
{
int it = GlobalC::ucell.iat2it[iat];
int ia = GlobalC::ucell.iat2ia[iat];
auto *tau = reinterpret_cast<double *>(GlobalC::ucell.atoms[it].tau.data());
int it = ucell->iat2it[iat];
int ia = ucell->iat2ia[iat];
auto *tau = reinterpret_cast<double *>(ucell->atoms[it].tau.data());
h_atom_tau[iat * 3 + 0] = static_cast<FPTYPE>(tau[ia * 3 + 0]);
h_atom_tau[iat * 3 + 1] = static_cast<FPTYPE>(tau[ia * 3 + 1]);
h_atom_tau[iat * 3 + 2] = static_cast<FPTYPE>(tau[ia * 3 + 2]);
}
if (device == base_device::GpuDevice)
{
resmem_int_op()(ctx, atom_na, GlobalC::ucell.ntype);
syncmem_int_op()(ctx, cpu_ctx, atom_na, h_atom_na, GlobalC::ucell.ntype);
resmem_int_op()(ctx, atom_na, ucell->ntype);
syncmem_int_op()(ctx, cpu_ctx, atom_na, h_atom_na, ucell->ntype);

resmem_var_op()(ctx, atom_tau, GlobalC::ucell.nat * 3);
syncmem_var_op()(ctx, cpu_ctx, atom_tau, h_atom_tau, GlobalC::ucell.nat * 3);
resmem_var_op()(ctx, atom_tau, ucell->nat * 3);
syncmem_var_op()(ctx, cpu_ctx, atom_tau, h_atom_tau, ucell->nat * 3);

igl2isz = wfc_basis->d_igl2isz_k;
is2fftixy = wfc_basis->d_is2fftixy;
Expand All @@ -107,7 +110,7 @@ void Structure_Factor::get_sk(Device* ctx,

cal_sk_op()(ctx,
ik,
GlobalC::ucell.ntype,
ucell->ntype,
wfc_basis->nx,
wfc_basis->ny,
wfc_basis->nz,
Expand Down Expand Up @@ -152,7 +155,7 @@ std::complex<double>* Structure_Factor::get_skq(int ik,
for (int ig = 0; ig < npw; ig++)
{
ModuleBase::Vector3<double> qkq = wfc_basis->getgpluskcar(ik, ig) + q;
double arg = (qkq * GlobalC::ucell.atoms[it].tau[ia]) * ModuleBase::TWO_PI;
double arg = (qkq * ucell->atoms[it].tau[ia]) * ModuleBase::TWO_PI;
skq[ig] = std::complex<double>(cos(arg), -sin(arg));
}

Expand Down
2 changes: 2 additions & 0 deletions source/module_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ void diago_PAO_in_pw_k2(
wavefunc* p_wf,
const ModuleBase::realArray& tab_at,
const int& lmaxkb,
const UnitCell& ucell,
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* phm_in) {
for (int i = 0; i < wvf.size(); i++) {
wvf.get_pointer()[i] = std::complex<float>((float)i + 1, 0);
Expand All @@ -207,6 +208,7 @@ void diago_PAO_in_pw_k2(
wavefunc* p_wf,
const ModuleBase::realArray& tab_at,
const int& lmaxkb,
const UnitCell& ucell,
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* phm_in) {
for (int i = 0; i < wvf.size(); i++) {
wvf.get_pointer()[i] = std::complex<double>((double)i + 1, 0);
Expand Down
10 changes: 5 additions & 5 deletions source/module_lr/esolver_lrtd_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
// necessary steps in ESolver_KS::before_all_runners : symmetry and k-points
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
GlobalC::ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
}
this->kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
Expand Down Expand Up @@ -318,12 +318,12 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
this->init_pot(chg_gs);

// search adjacent atoms and init Gint
std::cout << "ucell.infoNL.get_rcutmax_Beta(): " << GlobalC::ucell.infoNL.get_rcutmax_Beta() << std::endl;
std::cout << "ucell.infoNL.get_rcutmax_Beta(): " << ucell.infoNL.get_rcutmax_Beta() << std::endl;
double search_radius = -1.0;
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
PARAM.inp.out_level,
orb.get_rcutmax_Phi(),
GlobalC::ucell.infoNL.get_rcutmax_Beta(),
ucell.infoNL.get_rcutmax_Beta(),
PARAM.globalv.gamma_only_local);
atom_arrange::search(PARAM.inp.search_pbc,
GlobalV::ofs_running,
Expand All @@ -341,7 +341,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
std::vector<std::vector<double>> dpsi_u;
std::vector<std::vector<double>> d2psi_u;

Gint_Tools::init_orb(dr_uniform, rcuts, GlobalC::ucell, orb, psi_u, dpsi_u, d2psi_u);
Gint_Tools::init_orb(dr_uniform, rcuts, ucell, orb, psi_u, dpsi_u, d2psi_u);
this->gt_.set_pbc_grid(this->pw_rho->nx,
this->pw_rho->ny,
this->pw_rho->nz,
Expand All @@ -357,7 +357,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
this->pw_rho->ny,
this->pw_rho->nplane,
this->pw_rho->startz_current,
GlobalC::ucell,
ucell,
GlobalC::GridD,
dr_uniform,
rcuts,
Expand Down
2 changes: 1 addition & 1 deletion source/module_lr/operator_casida/operator_lr_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ namespace LR
for (auto cell : this->BvK_cells)
{
std::complex<double> frac = RI::Global_Func::convert<std::complex<double>>(std::exp(
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (this->kv.kvec_c.at(ik) * (RI_Util::array3_to_Vector3(cell) * GlobalC::ucell.latvec))));
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (this->kv.kvec_c.at(ik) * (RI_Util::array3_to_Vector3(cell) * ucell.latvec))));
for (int it1 = 0;it1 < ucell.ntype;++it1)
for (int ia1 = 0; ia1 < ucell.atoms[it1].na; ++ia1)
for (int it2 = 0;it2 < ucell.ntype;++it2)
Expand Down
12 changes: 6 additions & 6 deletions source/module_lr/operator_casida/operator_lr_hxc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ namespace LR

// 3. v_hxc = f_hxc * rho_trans
ModuleBase::matrix vr_hxc(1, nrxx); //grid
this->pot.lock()->cal_v_eff(rho_trans, GlobalC::ucell, vr_hxc, ispin_ks);
this->pot.lock()->cal_v_eff(rho_trans, ucell, vr_hxc, ispin_ks);
LR_Util::_deallocate_2order_nested_ptr(rho_trans, 1);

// 4. V^{Hxc}_{\mu,\nu}=\int{dr} \phi_\mu(r) v_{Hxc}(r) \phi_\mu(r)
Gint_inout inout_vlocal(vr_hxc.c, 0, Gint_Tools::job_type::vlocal);
this->gint->get_hRGint()->set_zero();
this->gint->cal_gint(&inout_vlocal);
this->hR->set_zero(); // clear hR for each bands
this->gint->transfer_pvpR(&*this->hR, &GlobalC::ucell); //grid to 2d block
this->gint->transfer_pvpR(&*this->hR, &ucell); //grid to 2d block
ModuleBase::timer::tick("OperatorLRHxc", "grid_calculation");
}

Expand All @@ -88,7 +88,7 @@ namespace LR

elecstate::DensityMatrix<std::complex<double>, double> DM_trans_real_imag(&pmat, 1, kv.kvec_d, kv.get_nks() / nspin);
DM_trans_real_imag.init_DMR(*this->hR);
hamilt::HContainer<double> HR_real_imag(GlobalC::ucell, &this->pmat);
hamilt::HContainer<double> HR_real_imag(ucell, &this->pmat);
LR_Util::initialize_HR<std::complex<double>, double>(HR_real_imag, ucell, gd, orb_cutoff_);

auto dmR_to_hR = [&, this](const char& type) -> void
Expand All @@ -111,7 +111,7 @@ namespace LR

// 3. v_hxc = f_hxc * rho_trans
ModuleBase::matrix vr_hxc(1, nrxx); //grid
this->pot.lock()->cal_v_eff(rho_trans, GlobalC::ucell, vr_hxc, ispin_ks);
this->pot.lock()->cal_v_eff(rho_trans, ucell, vr_hxc, ispin_ks);
// print_grid_nonzero(vr_hxc.c, this->poticab->nrxx, 10, "vr_hxc");

LR_Util::_deallocate_2order_nested_ptr(rho_trans, 1);
Expand All @@ -123,9 +123,9 @@ namespace LR

// LR_Util::print_HR(*this->gint->get_hRGint(), this->ucell.nat, "VR(grid)");
HR_real_imag.set_zero();
this->gint->transfer_pvpR(&HR_real_imag, &GlobalC::ucell, &GlobalC::GridD);
this->gint->transfer_pvpR(&HR_real_imag, &ucell, &GlobalC::GridD);
// LR_Util::print_HR(HR_real_imag, this->ucell.nat, "VR(real, 2d)");
LR_Util::set_HR_real_imag_part(HR_real_imag, *this->hR, GlobalC::ucell.nat, type);
LR_Util::set_HR_real_imag_part(HR_real_imag, *this->hR, ucell.nat, type);
};
this->hR->set_zero();
dmR_to_hR('R'); //real
Expand Down
15 changes: 11 additions & 4 deletions source/module_psi/psi_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
const int* ngk,
const int npwx,
Structure_Factor* p_sf,
pseudopot_cell_vnl* p_ppcell)
pseudopot_cell_vnl* p_ppcell,
const UnitCell& ucell)
{
// allocate memory for std::complex<double> datatype psi
// New psi initializer in ABACUS, Developer's note:
Expand Down Expand Up @@ -126,7 +127,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
// however, init_at_1 does not actually initialize the psi, instead, it is a
// function to calculate a interpolate table saving overlap intergral or say
// Spherical Bessel Transform of atomic orbitals.
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at);
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at);
// similarly, wfcinit not really initialize any wavefunction, instead, it initialize
// the mapping from ixy, the 1d flattened index of point on fft grid (x, y) plane,
// to the index of "stick", composed of grid points.
Expand All @@ -135,15 +136,18 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
}

template <typename T, typename Device>
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
void PSIInit<T, Device>::make_table(const int nks,
Structure_Factor* p_sf,
pseudopot_cell_vnl* p_ppcell,
const UnitCell& ucell)
{
if (this->use_psiinitializer)
{
} // do not need to do anything because the interpolate table is unchanged
else // old initialization method, used in EXX calculation
{
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
}
}

Expand All @@ -152,6 +156,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
psi::Psi<T, Device>* kspw_psi,
hamilt::Hamilt<T, Device>* p_hamilt,
const pseudopot_cell_vnl& nlpp,
const UnitCell& ucell,
std::ofstream& ofs_running,
const bool is_already_initpsi)
{
Expand Down Expand Up @@ -278,6 +283,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
&this->wf_old,
nlpp.tab_at,
nlpp.lmaxkb,
ucell,
p_hamilt);
}
}
Expand All @@ -294,6 +300,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
&this->wf_old,
nlpp.tab_at,
nlpp.lmaxkb,
ucell,
p_hamilt);
}
}
Expand Down
9 changes: 7 additions & 2 deletions source/module_psi/psi_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ class PSIInit
const int* ngk, //< number of G-vectors in the current pool
const int npwx, //< max number of plane waves of all pools
Structure_Factor* p_sf, //< structure factor
pseudopot_cell_vnl* p_ppcell); //< nonlocal pseudopotential
pseudopot_cell_vnl* p_ppcell, //< nonlocal pseudopotential
const UnitCell& ucell); //< unit cell

// make interpolate table
void make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell);
void make_table(const int nks,
Structure_Factor* p_sf,
pseudopot_cell_vnl* p_ppcell,
const UnitCell& ucell);

//------------------------ only for psi_initializer --------------------
/**
Expand All @@ -54,6 +58,7 @@ class PSIInit
psi::Psi<T, Device>* kspw_psi,
hamilt::Hamilt<T, Device>* p_hamilt,
const pseudopot_cell_vnl& nlpp,
const UnitCell& ucell,
std::ofstream& ofs_running,
const bool is_already_initpsi);

Expand Down
Loading
Loading