From 51b4c885177c12152c9d0553c70663525a79b989 Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Mon, 20 Jan 2025 21:51:13 +0800 Subject: [PATCH 1/2] Refactor: refactor `npol`, constructors and `int *ngk` of Psi class (#5863) * change npol to private * fix cuda build bug * fix cuda build bug * fix build bug in cuda * remove npol value in psi * fix bug * fix bugs * update constructers * fix bug * fix test bug * remove Constructor 1-1 * fix bug * update psi * remove useless code --- source/module_elecstate/cal_dm.h | 2 +- source/module_elecstate/elecstate_pw.cpp | 2 +- .../module_esolver/esolver_ks_lcao_tddft.cpp | 4 +- source/module_esolver/esolver_ks_lcaopw.cpp | 7 +- source/module_esolver/esolver_ks_pw.cpp | 2 +- source/module_esolver/esolver_sdft_pw.cpp | 11 ++- source/module_esolver/lcao_before_scf.cpp | 2 +- source/module_esolver/lcao_others.cpp | 2 +- source/module_hamilt_general/operator.cpp | 6 +- .../module_deltaspin/cal_mw.cpp | 4 +- .../module_deltaspin/cal_mw_from_lambda.cpp | 8 +-- .../module_dftu/dftu_pw.cpp | 8 +-- .../hamilt_pwdft/onsite_projector.cpp | 4 +- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 4 +- .../module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 14 ++-- .../module_hamilt_pw/hamilt_stodft/sto_wf.h | 2 +- source/module_hsolver/hsolver_lcao.cpp | 2 +- source/module_hsolver/test/diago_mock.h | 2 +- .../module_io/test/read_wfc_to_rho_test.cpp | 2 +- source/module_io/to_wannier90_lcao_in_pw.cpp | 12 +++- source/module_io/write_vxc_lip.hpp | 2 +- .../test/ao_to_mo_test.cpp | 72 +++++++++---------- .../module_lr/dm_trans/test/dm_trans_test.cpp | 48 ++++++------- source/module_psi/psi.cpp | 67 ++++++----------- source/module_psi/psi.h | 7 +- source/module_psi/psi_init.cpp | 8 +-- source/module_psi/psi_init.h | 2 +- .../psi_initializer_atomic_random.cpp | 2 +- .../module_psi/psi_initializer_nao_random.cpp | 2 +- .../test/psi_initializer_unit_test.cpp | 20 +++--- source/module_psi/test/psi_test.cpp | 20 +++--- source/module_ri/exx_lip.hpp | 2 +- 32 files changed, 170 insertions(+), 182 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 56aad08f3c..99815b296b 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -82,7 +82,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr); + psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), wfc.get_nbasis(), true); const std::complex* pwfc = wfc.get_pointer(); std::complex* pwg_wfc = wg_wfc.get_pointer(); #ifdef _OPENMP diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index f241c59db8..944bcddef4 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -271,7 +271,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) { const T one{1, 0}; const T zero{0, 0}; - const int npol = psi.npol; + const int npol = psi.get_npol(); const int npwx = psi.get_nbasis() / npol; const int nbands = psi.get_nbands() * npol; const int nkb = this->ppcell->nkb; diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 84a4ed0c68..acff4f7bf1 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -196,9 +196,9 @@ void ESolver_KS_LCAO_TDDFT::update_pot(UnitCell& ucell, const int istep, const i if (this->psi_laststep == nullptr) { #ifdef __MPI - this->psi_laststep = new psi::Psi>(kv.get_nks(), ncol_nbands, nrow, nullptr); + this->psi_laststep = new psi::Psi>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true); #else - this->psi_laststep = new psi::Psi>(kv.get_nks(), nbands, nlocal, nullptr); + this->psi_laststep = new psi::Psi>(kv.get_nks(), nbands, nlocal, kv.ngk, true); #endif } diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp b/source/module_esolver/esolver_ks_lcaopw.cpp index 08d1043a4a..4649fb07ca 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -93,9 +93,10 @@ namespace ModuleESolver ESolver_KS_PW::before_all_runners(ucell, inp); delete this->psi_local; this->psi_local = new psi::Psi(this->psi->get_nk(), - this->p_psi_init->psi_initer->nbands_start(), - this->psi->get_nbasis(), - this->psi->get_ngk_pointer()); + this->p_psi_init->psi_initer->nbands_start(), + this->psi->get_nbasis(), + this->kv.ngk, + true); #ifdef __EXX if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax" diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index a96d487a5c..cf5e300537 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -212,7 +212,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max); + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(PARAM.inp.pw_seed); this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 667b440916..f5f9292522 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -78,13 +78,20 @@ void ESolver_SDFT_PW::before_all_runners(UnitCell& ucell, const Input // 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}> size_t size = stowf.chi0->size(); this->stowf.shchi - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + this->pw_wfc->npwk_max, + this->kv.ngk, + true); ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T)); if (PARAM.inp.nbands > 0) { this->stowf.chiortho - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + this->pw_wfc->npwk_max, + this->kv.ngk, true); ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T)); } diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index 2637fe41d8..0bf61d6947 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -159,7 +159,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_esolver/lcao_others.cpp b/source/module_esolver/lcao_others.cpp index fc0ab246d3..faca3563f0 100644 --- a/source/module_esolver/lcao_others.cpp +++ b/source/module_esolver/lcao_others.cpp @@ -165,7 +165,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index e9020866e6..08a5ba97cc 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -63,7 +63,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp delete this->hpsi; this->hpsi = new psi::Psi(hpsi_pointer, 1, - nbands / psi_input->npol, + nbands / psi_input->get_npol(), psi_input->get_nbasis(), psi_input->get_nbasis(), true); @@ -86,7 +86,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp default: op->act(nbands, psi_input->get_nbasis(), - psi_input->npol, + psi_input->get_npol(), tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), @@ -105,7 +105,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } ModuleBase::timer::tick("Operator", "hPsi"); - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer); } template diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp index 94c5c74db7..25b2e4e879 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp @@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_CPU>* psi_t = static_cast, base_device::DEVICE_CPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); @@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_GPU>* psi_t = static_cast, base_device::DEVICE_GPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 36baed7bab..d6602e6b11 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -252,7 +252,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -454,7 +454,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); diff --git a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp index cc0c3a6c30..0ae2588625 100644 --- a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp +++ b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp @@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { @@ -88,11 +88,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index f235df15e5..47faf38797 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -165,7 +165,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_); RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_); //rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq); - rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol); + rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol); // For being compatible with present cal_force and cal_stress framework // uncomment the following code block if you want to use the Onsite_Proj_tools if(this->tab_atomic_ == nullptr) @@ -541,7 +541,7 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psi::allocate_chi0() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -207,7 +207,7 @@ void Stochastic_WF::init_com_orbitals() delete[] npwip; } size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -252,7 +252,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -266,7 +266,7 @@ void Stochastic_WF::init_com_orbitals() const int npwx = this->npwx; const int nks = this->nks; size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -284,7 +284,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index a423810544..4afdeb4247 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -30,10 +30,10 @@ class Stochastic_WF int* nchip = nullptr; ///< The number of stochatic orbitals in current process of each k point. int nchip_max = 0; ///< Max number of stochastic orbitals among all k points. int nks = 0; ///< number of k-points - int* ngk = nullptr; ///< ngk in klist int npwx = 0; ///< max ngk[ik] in all processors int nbands_diag = 0; ///< number of bands obtained from diagonalization int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag; + std::vector ngk; ///< ngk in klist public: // Tn(H)|chi> psi::Psi* chiallorder = nullptr; diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 2f9a1b4313..44deac1bbd 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -219,7 +219,7 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, k2d.distribute_hsk(pHamilt, ik_kpar, nrow); /// global index of k point int ik_global = ik + k2d.get_pKpoints()->startk_pool[k2d.get_my_pool()]; - auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, nullptr); + auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true); ModuleBase::Memory::record("HSolverLCAO::psi_pool", nrow * ncol_bands_pool * sizeof(T)); if (ik_global < psi.get_nk() && ik < k2d.get_pKpoints()->nks_pool[k2d.get_my_pool()]) { diff --git a/source/module_hsolver/test/diago_mock.h b/source/module_hsolver/test/diago_mock.h index e63022f43d..85a7750fc5 100644 --- a/source/module_hsolver/test/diago_mock.h +++ b/source/module_hsolver/test/diago_mock.h @@ -214,7 +214,7 @@ class HPsi { Structure_Factor* sf; int* ngk = nullptr; - psi::Psi psitmp(1, nband, npw, ngk); + psi::Psi psitmp(1, nband, npw, npw, true); for(int i=0;i>(nks, nbands, wfcpw->npwk_max, wfcpw->npwk); + psi = new psi::Psi>(nks, nbands, wfcpw->npwk_max, kv->ngk, true); std::complex* ptr = psi->get_pointer(); for (int i = 0; i < nks * nbands * wfcpw->npwk_max; i++) { diff --git a/source/module_io/to_wannier90_lcao_in_pw.cpp b/source/module_io/to_wannier90_lcao_in_pw.cpp index 78f04a8a97..e067671465 100644 --- a/source/module_io/to_wannier90_lcao_in_pw.cpp +++ b/source/module_io/to_wannier90_lcao_in_pw.cpp @@ -52,7 +52,11 @@ void toWannier90_LCAO_IN_PW::calculate( const int nks_psi = (PARAM.inp.calculation == "nscf" && PARAM.inp.mem_saver == 1)? 1 : wfcpw->nks; const int nks_psig = (PARAM.inp.basis_type == "pw")? 1 : nks_psi; const int nbands_actual = this->psi_initer_->nbands_start(); - this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, nbands_actual, wfcpw->npwk_max*PARAM.globalv.npol, wfcpw->npwk); + this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, + nbands_actual, + wfcpw->npwk_max*PARAM.globalv.npol, + kv.ngk, + true); read_nnkp(ucell,kv); if (PARAM.inp.nspin == 2) @@ -117,7 +121,11 @@ psi::Psi>* toWannier90_LCAO_IN_PW::get_unk_from_lcao( { // init int npwx = wfcpw->npwk_max; - psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, num_bands, npwx*PARAM.globalv.npol, kv.ngk.data()); + psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, + num_bands, + npwx*PARAM.globalv.npol, + kv.ngk, + true); unk_inLcao->zero_out(); // Orbital projection to plane wave diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index d57c8f2ccd..1a50c8d00e 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -161,7 +161,7 @@ namespace ModuleIO // psi::Psi hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas()); // vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik)); // } - vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik)); + vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_npol(), &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik)); delete vxcs_op_pw; std::vector vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands()); Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands); diff --git a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp index 5601ad451d..8bcb88b525 100644 --- a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp +++ b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp @@ -64,18 +64,18 @@ TEST_F(AO2MOTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data(), size_v); } @@ -96,18 +96,18 @@ TEST_F(AO2MOTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data>(), size_v); } @@ -137,7 +137,7 @@ TEST_F(AO2MOTest, DoubleParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); @@ -148,12 +148,12 @@ TEST_F(AO2MOTest, DoubleParallel) EXPECT_GE(s.nvirt, pvo.dim0); EXPECT_GE(s.nocc, pvo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), poo.get_local_size(), false); + psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -174,7 +174,7 @@ TEST_F(AO2MOTest, DoubleParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_1(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data(), V_full.at(isk).data(), false, s.naos, s.naos); @@ -182,13 +182,13 @@ TEST_F(AO2MOTest, DoubleParallel) } if (my_rank == 0) { - psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 0), s.nks * s.nocc * s.nvirt); - psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nocc, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } @@ -208,18 +208,18 @@ TEST_F(AO2MOTest, ComplexParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp_1(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(pvv, s.nb, s.nvirt, s.nvirt, pV.blacs_ctxt); - psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), poo.get_local_size(), false); + psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -241,7 +241,7 @@ TEST_F(AO2MOTest, ComplexParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data>(), V_full.at(isk).data>(), false, s.naos, s.naos); @@ -249,13 +249,13 @@ TEST_F(AO2MOTest, ComplexParallel) } if (my_rank == 0) { - psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 0), s.nks * s.nocc * s.nvirt); - psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nocc, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } diff --git a/source/module_lr/dm_trans/test/dm_trans_test.cpp b/source/module_lr/dm_trans/test/dm_trans_test.cpp index 8a40f08c61..acef1e8a40 100644 --- a/source/module_lr/dm_trans/test/dm_trans_test.cpp +++ b/source/module_lr/dm_trans/test/dm_trans_test.cpp @@ -61,18 +61,18 @@ TEST_F(DMTransTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi& X, const LR::MO_TYPE type) { @@ -92,18 +92,18 @@ TEST_F(DMTransTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi>& X, const LR::MO_TYPE type) { @@ -132,18 +132,18 @@ TEST_F(DMTransTest, DoubleParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp_2(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2, true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); @@ -153,9 +153,9 @@ TEST_F(DMTransTest, DoubleParallel) EXPECT_GE(s.nocc, px_vo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); // allocate X_full + psi::Psi X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi& X, psi::Psi& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -182,7 +182,7 @@ TEST_F(DMTransTest, DoubleParallel) // gather C std::vector temp(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); @@ -223,24 +223,24 @@ TEST_F(DMTransTest, ComplexParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp, true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); - psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi>& X, psi::Psi>& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -266,7 +266,7 @@ TEST_F(DMTransTest, ComplexParallel) set_rand(c.get_pointer(), s.nks * pc.get_local_size()); // set c // compare to global matrix std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index a69635dffb..78eb202766 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -32,7 +32,6 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_ template Psi::Psi() { - this->npol = PARAM.globalv.npol; } template @@ -44,41 +43,7 @@ Psi::~Psi() } } -// Constructor 1-1: -template -Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) -{ - assert(nk_in > 0); - assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU - assert(nbs_in > 0); - - this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; - this->allocate_inside = true; - - this->ngk = ngk_in; // modify later - // This function will delete the psi array first(if psi exist), then malloc a new memory for it. - resize_memory_op()(this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); - - this->nk = nk_in; - this->nbands = nbd_in; - this->nbasis = nbs_in; - - this->current_b = 0; - this->current_k = 0; - this->current_nbasis = nbs_in; - this->psi_current = this->psi; - this->psi_bias = 0; - - // Currently only GPU's implementation is supported for device recording! - base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); - base_device::information::record_device_memory(this->ctx, - GlobalV::ofs_device, - "Psi->resize()", - sizeof(T) * nk_in * nbd_in * nbs_in); -} - -// Constructor 1-2: +// Constructor 1: template Psi::Psi(const int nk_in, const int nbd_in, @@ -87,11 +52,10 @@ Psi::Psi(const int nk_in, const bool k_first_in) { assert(nk_in > 0); - assert(nbd_in > 0); + assert(nbd_in >= 0); assert(nbs_in > 0); this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = true; this->ngk = ngk_in.data(); // modify later @@ -129,7 +93,6 @@ Psi::Psi(T* psi_pointer, // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = false; this->ngk = nullptr; @@ -158,10 +121,9 @@ Psi::Psi(const int nk_in, const bool k_first_in) { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. - assert(nk_in == 1); + // assert(nk_in == 1); this->k_first = k_first_in; - this->npol = PARAM.globalv.npol; this->allocate_inside = true; this->ngk = nullptr; @@ -190,8 +152,8 @@ Psi::Psi(const int nk_in, template Psi::Psi(const Psi& psi_in) { + this->ngk = psi_in.ngk; - this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); this->nbasis = psi_in.get_nbasis(); @@ -215,8 +177,8 @@ template template Psi::Psi(const Psi& psi_in) { + this->ngk = psi_in.get_ngk_pointer(); - this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); this->nbasis = psi_in.get_nbasis(); @@ -323,7 +285,7 @@ const int& Psi::get_psi_bias() const template const int& Psi::get_current_ngk() const { - if (this->npol == 1) + if (this->get_npol() == 1) { return this->current_nbasis; } @@ -333,6 +295,19 @@ const int& Psi::get_current_ngk() const } } +template +const int Psi::get_npol() const +{ + if (PARAM.inp.nspin == 4) + { + return 2; + } + else + { + return 1; + } +} + template const int& Psi::get_nk() const { @@ -511,13 +486,13 @@ std::tuple Psi::to_range(const Range& range) const else if (i1 < 0) // [r1, r2] is the range of index1 with length m { const T* p = &this->psi[r1 * (k_first ? this->nbands : this->nk) * this->nbasis]; - int m = (r2 - r1 + 1) * this->npol; + int m = (r2 - r1 + 1) * this->get_npol(); return std::tuple(p, m); } else // [r1, r2] is the range of index2 with length m { const T* p = &this->psi[(i1 * (k_first ? this->nbands : this->nk) + r1) * this->nbasis]; - int m = (r2 - r1 + 1) * this->npol; + int m = (r2 - r1 + 1) * this->get_npol(); return std::tuple(p, m); } } diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index d8a994377a..75e13433ea 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -39,10 +39,7 @@ class Psi // Constructor 0: basic Psi(); - // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); - - // Constructor 1-2: + // Constructor 1: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); // Constructor 2-1: initialize a new psi from the given psi_in @@ -137,7 +134,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; - int npol = 1; + const int get_npol() const; private: T* psi = nullptr; // avoid using C++ STL diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 102e2d4b1a..8ef89dcfdc 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -106,7 +106,7 @@ void PSIInit::initialize_psi(Psi>* psi, if (not_equal) { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) : reinterpret_cast*>(psi_cpu); } @@ -119,7 +119,7 @@ void PSIInit::initialize_psi(Psi>* psi, } else { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = kspw_psi; } } @@ -203,7 +203,7 @@ void PSIInit::initialize_lcao_in_pw(Psi* psi_local, std::ofstream& } } -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx) +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx) { assert(npwx > 0); assert(nks > 0); @@ -215,7 +215,7 @@ void allocate_psi(Psi>*& psi, const int& nks, const int* ng { nks2 = 1; } - psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk); + psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk, true); const size_t memory_cost = sizeof(std::complex) * nks2 * nbands * (PARAM.globalv.npol * npwx); std::cout << " MEMORY FOR PSI (MB) : " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("Psi_PW", memory_cost); diff --git a/source/module_psi/psi_init.h b/source/module_psi/psi_init.h index e112a71a6e..bf93e534d0 100644 --- a/source/module_psi/psi_init.h +++ b/source/module_psi/psi_init.h @@ -86,7 +86,7 @@ class PSIInit }; ///@brief allocate the wavefunction -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx); +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx); } // namespace psi #endif \ No newline at end of file diff --git a/source/module_psi/psi_initializer_atomic_random.cpp b/source/module_psi/psi_initializer_atomic_random.cpp index f7b735f5ed..7e0652c25c 100644 --- a/source/module_psi/psi_initializer_atomic_random.cpp +++ b/source/module_psi/psi_initializer_atomic_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_atomic_random::init_psig(T* psig, const int& ik) psi_initializer_atomic::init_psig(psig, ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/psi_initializer_nao_random.cpp b/source/module_psi/psi_initializer_nao_random.cpp index 4f8b8d940f..ab23c4a163 100644 --- a/source/module_psi/psi_initializer_nao_random.cpp +++ b/source/module_psi/psi_initializer_nao_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_nao_random::init_psig(T* psig, const int& ik) psi_initializer_nao::init_psig(psig, ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/test/psi_initializer_unit_test.cpp b/source/module_psi/test/psi_initializer_unit_test.cpp index fd9dcd497c..b5b4180b2d 100644 --- a/source/module_psi/test/psi_initializer_unit_test.cpp +++ b/source/module_psi/test/psi_initializer_unit_test.cpp @@ -321,7 +321,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(-0.66187696761064307, psi->operator()(0,0,0).real(), 1e-4); delete psi; @@ -340,7 +340,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomic) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -363,7 +363,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -390,7 +390,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -413,7 +413,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -432,7 +432,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNao) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -451,7 +451,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -475,7 +475,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -499,7 +499,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -523,7 +523,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSoDOMAG) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index 0b42df63c7..598cbe21bd 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -8,12 +8,12 @@ class TestPsi : public ::testing::Test const int ink = 2; const int inbands = 4; const int inbasis = 10; - int ngk[4] = {10, 10, 10, 10}; + std::vector ngk = {10, 10, 10, 10}; - const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); + const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, ngk, true); + const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, ngk, true); }; TEST_F(TestPsi, get_val) @@ -98,7 +98,7 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double) EXPECT_EQ(psi_object31->get_psi_bias(), 0); std::vector temp(ink, inbasis); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp.data(), true); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp, true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; @@ -241,10 +241,10 @@ TEST_F(TestPsi, range) TEST_F(TestPsi, band_first) { - const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); + const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, ngk, false); + const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, ngk, false); // set values: cover 4 different cases for (int ib = 0;ib < inbands;++ib) diff --git a/source/module_ri/exx_lip.hpp b/source/module_ri/exx_lip.hpp index 6be31a26b4..5e26446df4 100644 --- a/source/module_ri/exx_lip.hpp +++ b/source/module_ri/exx_lip.hpp @@ -112,7 +112,7 @@ Exx_Lip::Exx_Lip(const Exx_Info::Exx_Info_Lip& info_in, #endif this->k_pack->wf_wg.create(this->k_pack->kv_ptr->get_nks(),PARAM.inp.nbands); - this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk.data(), true); + this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk, true); // this->k_pack->hvec_array = new ModuleBase::ComplexMatrix[this->k_pack->kv_ptr->get_nks()]; // for( int ik=0; ikk_pack->kv_ptr->get_nks(); ++ik) // { From 00981713a278e38c82af4b343502c7386c56a99e Mon Sep 17 00:00:00 2001 From: Qianrui Liu <76200646+Qianruipku@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:58:14 +0800 Subject: [PATCH 2/2] Feature: add para_gemm to do parallel matrix multiply (#5870) * Feature: add para_gemm to do parallel matrix multi Refator: move math_kernel_op to module_base * fix compile * fix compile * try fix pyabacus * add gatherC for para_gemm * add test --- python/pyabacus/CONTRIBUTING.md | 9 +- python/pyabacus/src/ModuleBase/CMakeLists.txt | 1 + python/pyabacus/src/ModuleNAO/CMakeLists.txt | 1 + python/pyabacus/src/hsolver/CMakeLists.txt | 2 +- python/pyabacus/src/hsolver/py_hsolver.cpp | 2 +- source/CMakeLists.txt | 6 +- source/Makefile.Objects | 3 +- source/module_base/CMakeLists.txt | 1 + source/module_base/blas_connector.cpp | 14 +- .../kernels/cuda/math_kernel_op.cu | 29 +- .../kernels/math_kernel_op.cpp | 4 +- .../kernels/math_kernel_op.h | 2 +- .../kernels/rocm/math_kernel_op.hip.cu | 27 +- .../module_base/kernels/test/CMakeLists.txt | 3 +- .../kernels/test/math_kernel_test.cpp | 54 +- source/module_base/para_gemm.cpp | 239 +++++++++ source/module_base/para_gemm.h | 93 ++++ source/module_base/parallel_device.cpp | 122 +++++ source/module_base/parallel_device.h | 127 ++--- .../module_base/test_parallel/CMakeLists.txt | 11 + .../test_parallel/test_para_gemm.cpp | 466 ++++++++++++++++++ source/module_elecstate/elecstate_pw.h | 6 +- source/module_esolver/esolver_ks_lcaopw.cpp | 2 +- source/module_esolver/esolver_ks_pw.cpp | 6 +- source/module_esolver/pw_others.cpp | 2 +- .../module_deltaspin/cal_mw_from_lambda.cpp | 6 +- source/module_hamilt_pw/hamilt_pwdft/forces.h | 4 +- .../hamilt_pwdft/fs_nonlocal_tools.cpp | 2 +- .../hamilt_pwdft/fs_nonlocal_tools.h | 4 +- .../module_hamilt_pw/hamilt_pwdft/hamilt_pw.h | 6 +- .../hamilt_pwdft/nonlocal_maths.hpp | 2 +- .../hamilt_pwdft/onsite_proj_tools.h | 4 +- .../hamilt_pwdft/onsite_projector.cpp | 2 +- .../hamilt_pwdft/onsite_projector.h | 4 +- .../hamilt_pwdft/operator_pw/meta_pw.h | 4 +- .../hamilt_pwdft/operator_pw/nonlocal_pw.h | 6 +- .../hamilt_pwdft/operator_pw/onsite_proj_pw.h | 6 +- .../hamilt_pwdft/stress_func.h | 4 +- .../hamilt_stodft/sto_che.cpp | 2 +- .../module_hamilt_pw/hamilt_stodft/sto_che.h | 4 +- .../hamilt_stodft/sto_iter.cpp | 8 +- .../module_hamilt_pw/hamilt_stodft/sto_iter.h | 2 +- source/module_hsolver/CMakeLists.txt | 1 - source/module_hsolver/diago_bpcg.cpp | 2 +- source/module_hsolver/diago_bpcg.h | 8 +- source/module_hsolver/diago_cg.cpp | 216 ++++---- source/module_hsolver/diago_cg.h | 4 +- source/module_hsolver/diago_dav_subspace.cpp | 135 ++--- source/module_hsolver/diago_david.cpp | 325 ++++++------ source/module_hsolver/diago_iter_assist.cpp | 288 +++++------ source/module_hsolver/hsolver_pw_sdft.cpp | 2 +- .../kernels/test/CMakeLists.txt | 8 +- .../kernels/test/math_dngvd_test.cpp | 8 +- .../kernels/test/perf_math_kernel.cpp | 36 +- source/module_hsolver/test/CMakeLists.txt | 1 - .../module_hsolver/test/diago_bpcg_test.cpp | 2 +- .../operator_casida/operator_lr_diag.h | 4 +- 57 files changed, 1656 insertions(+), 686 deletions(-) rename source/{module_hsolver => module_base}/kernels/cuda/math_kernel_op.cu (97%) rename source/{module_hsolver => module_base}/kernels/math_kernel_op.cpp (99%) rename source/{module_hsolver => module_base}/kernels/math_kernel_op.h (99%) rename source/{module_hsolver => module_base}/kernels/rocm/math_kernel_op.hip.cu (96%) rename source/{module_hsolver => module_base}/kernels/test/math_kernel_test.cpp (93%) create mode 100644 source/module_base/para_gemm.cpp create mode 100644 source/module_base/para_gemm.h create mode 100644 source/module_base/test_parallel/test_para_gemm.cpp diff --git a/python/pyabacus/CONTRIBUTING.md b/python/pyabacus/CONTRIBUTING.md index fbd23ad9ff..b5d7728eae 100644 --- a/python/pyabacus/CONTRIBUTING.md +++ b/python/pyabacus/CONTRIBUTING.md @@ -8,10 +8,13 @@ Welcome to the `pyabacus` project! This document provides guidelines and instruc -- [Project structure](#project-structure) +- [Developer Guide](#developer-guide) + - [Introduction](#introduction) + - [Project Structure](#project-structure) - [Root CMake Configuration](#root-cmake-configuration) - [Module CMake Configuration](#module-cmake-configuration) -- [Development Process](#development-process) + - [Development Process](#development-process) + - [Conclusion](#conclusion) @@ -187,7 +190,7 @@ list(APPEND _diago ${HSOLVER_PATH}/diag_const_nums.cpp ${HSOLVER_PATH}/diago_iter_assist.cpp ${HSOLVER_PATH}/kernels/dngvd_op.cpp - ${HSOLVER_PATH}/kernels/math_kernel_op.cpp + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/kernels/math_op.cpp ${BASE_PATH}/module_device/device.cpp ${BASE_PATH}/module_device/memory_op.cpp diff --git a/python/pyabacus/src/ModuleBase/CMakeLists.txt b/python/pyabacus/src/ModuleBase/CMakeLists.txt index 7ce5fb5e3b..1c2d9a728b 100644 --- a/python/pyabacus/src/ModuleBase/CMakeLists.txt +++ b/python/pyabacus/src/ModuleBase/CMakeLists.txt @@ -1,6 +1,7 @@ list(APPEND pymodule_base ${PROJECT_SOURCE_DIR}/src/ModuleBase/py_base_math.cpp ${BASE_PATH}/kernels/math_op.cpp + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/module_device/memory_op.cpp ${BASE_PATH}/module_device/device.cpp ) diff --git a/python/pyabacus/src/ModuleNAO/CMakeLists.txt b/python/pyabacus/src/ModuleNAO/CMakeLists.txt index c5eb016903..5e86604adc 100644 --- a/python/pyabacus/src/ModuleNAO/CMakeLists.txt +++ b/python/pyabacus/src/ModuleNAO/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND _naos ${NAO_PATH}/two_center_table.cpp # dependency ${ABACUS_SOURCE_DIR}/module_base/kernels/math_op.cpp + ${ABACUS_SOURCE_DIR}/module_base/kernels/math_kernel_op.cpp # ${ABACUS_SOURCE_DIR}/module_psi/kernels/psi_memory_op.cpp ${ABACUS_SOURCE_DIR}/module_base/module_device/memory_op.cpp ${ABACUS_SOURCE_DIR}/module_base/module_device/device.cpp diff --git a/python/pyabacus/src/hsolver/CMakeLists.txt b/python/pyabacus/src/hsolver/CMakeLists.txt index f0f04f97a7..4bd0153b48 100644 --- a/python/pyabacus/src/hsolver/CMakeLists.txt +++ b/python/pyabacus/src/hsolver/CMakeLists.txt @@ -10,8 +10,8 @@ list(APPEND _diago ${HSOLVER_PATH}/kernels/dngvd_op.cpp - ${HSOLVER_PATH}/kernels/math_kernel_op.cpp # dependency + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/kernels/math_op.cpp ${BASE_PATH}/module_device/device.cpp ${BASE_PATH}/module_device/memory_op.cpp diff --git a/python/pyabacus/src/hsolver/py_hsolver.cpp b/python/pyabacus/src/hsolver/py_hsolver.cpp index e791fe9f09..3c4d1c66c4 100644 --- a/python/pyabacus/src/hsolver/py_hsolver.cpp +++ b/python/pyabacus/src/hsolver/py_hsolver.cpp @@ -6,7 +6,7 @@ #include #include "module_hsolver/diago_dav_subspace.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_device/types.h" #include "./py_diago_dav_subspace.hpp" diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 1f4d4a8370..769138b096 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -36,7 +36,6 @@ list(APPEND device_srcs module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp module_basis/module_pw/kernels/pw_op.cpp module_hsolver/kernels/dngvd_op.cpp - module_hsolver/kernels/math_kernel_op.cpp module_elecstate/kernels/elecstate_op.cpp # module_psi/kernels/psi_memory_op.cpp @@ -44,6 +43,7 @@ list(APPEND device_srcs module_base/module_device/device.cpp module_base/module_device/memory_op.cpp + module_base/kernels/math_kernel_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/force_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/stress_op.cpp @@ -64,7 +64,6 @@ if(USE_CUDA) module_hamilt_pw/hamilt_pwdft/kernels/cuda/onsite_op.cu module_basis/module_pw/kernels/cuda/pw_op.cu module_hsolver/kernels/cuda/dngvd_op.cu - module_hsolver/kernels/cuda/math_kernel_op.cu module_elecstate/kernels/cuda/elecstate_op.cu # module_psi/kernels/cuda/memory_op.cu @@ -75,6 +74,7 @@ if(USE_CUDA) module_hamilt_pw/hamilt_pwdft/kernels/cuda/wf_op.cu module_hamilt_pw/hamilt_pwdft/kernels/cuda/vnl_op.cu module_base/kernels/cuda/math_op.cu + module_base/kernels/cuda/math_kernel_op.cu module_hamilt_general/module_xc/kernels/cuda/xc_functional_op.cu ) endif() @@ -89,7 +89,6 @@ if(USE_ROCM) module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu module_basis/module_pw/kernels/rocm/pw_op.hip.cu module_hsolver/kernels/rocm/dngvd_op.hip.cu - module_hsolver/kernels/rocm/math_kernel_op.hip.cu module_elecstate/kernels/rocm/elecstate_op.hip.cu # module_psi/kernels/rocm/memory_op.hip.cu @@ -99,6 +98,7 @@ if(USE_ROCM) module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/wf_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/vnl_op.hip.cu + module_base/kernels/rocm/math_kernel_op.hip.cu module_base/kernels/rocm/math_op.hip.cu module_hamilt_general/module_xc/kernels/rocm/xc_functional_op.hip.cu ) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 5d01dd1839..ad13d75976 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -146,11 +146,13 @@ OBJS_BASE=abfs-vector3_order.o\ math_bspline.o\ math_chebyshev.o\ math_op.o\ + math_kernel_op.o\ mathzone_add1.o\ matrix.o\ matrix3.o\ memory.o\ mymath.o\ + para_gemm.o\ realarray.o\ sph_bessel_recursive-d1.o\ sph_bessel_recursive-d2.o\ @@ -336,7 +338,6 @@ OBJS_HSOLVER=diago_cg.o\ hsolver_lcaopw.o\ hsolver_pw_sdft.o\ diago_iter_assist.o\ - math_kernel_op.o\ dngvd_op.o\ diag_const_nums.o\ diag_hs_para.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index 38c466a2c1..ecbdedcf6a 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -37,6 +37,7 @@ add_library( mymath.cpp opt_CG.cpp opt_DCsrch.cpp + para_gemm.cpp realarray.cpp sph_bessel_recursive-d1.cpp sph_bessel_recursive-d2.cpp diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 14fb76e2ed..b422969ac5 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -10,7 +10,7 @@ #include #include #include "cublas_v2.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_device/memory_op.h" @@ -668,7 +668,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - hsolver::vector_mul_vector_op()(gpu_ctx, dim, result, vector1, vector2); + ModuleBase::vector_mul_vector_op()(gpu_ctx, dim, result, vector1, vector2); #endif } } @@ -688,7 +688,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - hsolver::vector_div_vector_op()(gpu_ctx, dim, result, vector1, vector2); + ModuleBase::vector_div_vector_op()(gpu_ctx, dim, result, vector1, vector2); #endif } } @@ -706,7 +706,7 @@ void vector_add_vector(const int& dim, float *result, const float *vector1, cons } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -724,7 +724,7 @@ void vector_add_vector(const int& dim, double *result, const double *vector1, co } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -742,7 +742,7 @@ void vector_add_vector(const int& dim, std::complex *result, const std::c } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -760,7 +760,7 @@ void vector_add_vector(const int& dim, std::complex *result, const std:: } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } \ No newline at end of file diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_base/kernels/cuda/math_kernel_op.cu similarity index 97% rename from source/module_hsolver/kernels/cuda/math_kernel_op.cu rename to source/module_base/kernels/cuda/math_kernel_op.cu index cd3ac41812..d48862ef33 100644 --- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu +++ b/source/module_base/kernels/cuda/math_kernel_op.cu @@ -1,5 +1,5 @@ #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "module_base/tool_quit.h" @@ -9,7 +9,7 @@ #include #include -namespace hsolver +namespace ModuleBase { const int warp_size = 32; // const unsigned int full_mask = 0xffffffff; @@ -24,7 +24,7 @@ template <> struct GetTypeReal> { using type = double; /**< The return type specialization for std::complex. */ }; -namespace hsolver { +namespace ModuleBase { template struct GetTypeThrust { using type = T; @@ -817,6 +817,27 @@ void scal_op::operator()(const base_device::DEV cublasErrcheck(cublasZscal(cublas_handle, N, (double2*)alpha, (double2*)X, incx)); } +template <> +void gemm_op::operator()(const base_device::DEVICE_GPU* d, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const float* alpha, + const float* a, + const int& lda, + const float* b, + const int& ldb, + const float* beta, + float* c, + const int& ldc) +{ + cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op"); + cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op"); + cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); +} + template <> void gemm_op::operator()(const base_device::DEVICE_GPU* d, const char& transa, @@ -1060,4 +1081,4 @@ template struct vector_div_vector_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif -} // namespace hsolver +} // namespace ModuleBase diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_base/kernels/math_kernel_op.cpp similarity index 99% rename from source/module_hsolver/kernels/math_kernel_op.cpp rename to source/module_base/kernels/math_kernel_op.cpp index db2a12e9db..59a3c2ace8 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_base/kernels/math_kernel_op.cpp @@ -1,9 +1,9 @@ -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include -namespace hsolver +namespace ModuleBase { template diff --git a/source/module_hsolver/kernels/math_kernel_op.h b/source/module_base/kernels/math_kernel_op.h similarity index 99% rename from source/module_hsolver/kernels/math_kernel_op.h rename to source/module_base/kernels/math_kernel_op.h index 0daf0e5718..b525ce8467 100644 --- a/source/module_hsolver/kernels/math_kernel_op.h +++ b/source/module_base/kernels/math_kernel_op.h @@ -17,7 +17,7 @@ #include "cublas_v2.h" #endif //__CUDA || __UT_USE_CUDA -namespace hsolver { +namespace ModuleBase { inline std::complex set_real_tocomplex(const std::complex &x) { return {x.real(), 0.0}; diff --git a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu b/source/module_base/kernels/rocm/math_kernel_op.hip.cu similarity index 96% rename from source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu rename to source/module_base/kernels/rocm/math_kernel_op.hip.cu index 1993ae4c64..5ee0648e11 100644 --- a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu +++ b/source/module_base/kernels/rocm/math_kernel_op.hip.cu @@ -1,5 +1,5 @@ #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "module_base/tool_quit.h" @@ -20,7 +20,7 @@ struct GetTypeReal> { using type = double; /**< The return type specialization for std::complex. */ }; -namespace hsolver { +namespace ModuleBase { template struct GetTypeThrust { @@ -735,6 +735,27 @@ void scal_op::operator()(const base_device::DEV hipblasErrcheck(hipblasZscal(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incx)); } +template <> +void gemm_op::operator()(const base_device::DEVICE_GPU* d, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const float* alpha, + const float* a, + const int& lda, + const float* b, + const int& ldb, + const float* beta, + float* c, + const int& ldc) +{ + hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op"); + hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op"); + hipblasErrcheck(hipblasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); +} + template <> void gemm_op::operator()(const base_device::DEVICE_GPU* d, const char& transa, @@ -968,4 +989,4 @@ template struct vector_div_vector_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif -} // namespace hsolver +} // namespace ModuleBase diff --git a/source/module_base/kernels/test/CMakeLists.txt b/source/module_base/kernels/test/CMakeLists.txt index 960de3b613..1453545d14 100644 --- a/source/module_base/kernels/test/CMakeLists.txt +++ b/source/module_base/kernels/test/CMakeLists.txt @@ -3,6 +3,5 @@ remove_definitions(-D__MPI) AddTest( TARGET Base_Kernels_UTs LIBS parameter ${math_libs} base device - SOURCES math_op_test.cpp + SOURCES math_op_test.cpp math_kernel_test.cpp ) - diff --git a/source/module_hsolver/kernels/test/math_kernel_test.cpp b/source/module_base/kernels/test/math_kernel_test.cpp similarity index 93% rename from source/module_hsolver/kernels/test/math_kernel_test.cpp rename to source/module_base/kernels/test/math_kernel_test.cpp index 0781d54787..caf320ef81 100644 --- a/source/module_hsolver/kernels/test/math_kernel_test.cpp +++ b/source/module_base/kernels/test/math_kernel_test.cpp @@ -1,7 +1,7 @@ #include "module_base/blas_connector.h" #include "module_base/constants.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -51,8 +51,8 @@ class TestModuleHsolverMathKernel : public ::testing::Test { } - using zdot_real_cpu_op = hsolver::dot_real_op, base_device::DEVICE_CPU>; - using zdot_real_gpu_op = hsolver::dot_real_op, base_device::DEVICE_GPU>; + using zdot_real_cpu_op = ModuleBase::dot_real_op, base_device::DEVICE_CPU>; + using zdot_real_gpu_op = ModuleBase::dot_real_op, base_device::DEVICE_GPU>; using resize_memory_op = base_device::memory::resize_memory_op, base_device::DEVICE_GPU>; using delete_memory_op = base_device::memory::delete_memory_op, base_device::DEVICE_GPU>; @@ -72,23 +72,23 @@ class TestModuleHsolverMathKernel : public ::testing::Test // haozhihan add // cpu operator - using vector_div_constant_op_cpu = hsolver::vector_div_constant_op, base_device::DEVICE_CPU>; - using vector_mul_vector_op_cpu = hsolver::vector_mul_vector_op, base_device::DEVICE_CPU>; - using vector_div_vector_op_cpu = hsolver::vector_div_vector_op, base_device::DEVICE_CPU>; + using vector_div_constant_op_cpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_CPU>; + using vector_mul_vector_op_cpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_CPU>; + using vector_div_vector_op_cpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_CPU>; using constantvector_addORsub_constantVector_op_cpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; - using axpy_op_cpu = hsolver::axpy_op, base_device::DEVICE_CPU>; - using scal_op_cpu = hsolver::scal_op; - using gemv_op_cpu = hsolver::gemv_op, base_device::DEVICE_CPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; + using axpy_op_cpu = ModuleBase::axpy_op, base_device::DEVICE_CPU>; + using scal_op_cpu = ModuleBase::scal_op; + using gemv_op_cpu = ModuleBase::gemv_op, base_device::DEVICE_CPU>; // gpu operator - using vector_div_constant_op_gpu = hsolver::vector_div_constant_op, base_device::DEVICE_GPU>; - using vector_mul_vector_op_gpu = hsolver::vector_mul_vector_op, base_device::DEVICE_GPU>; - using vector_div_vector_op_gpu = hsolver::vector_div_vector_op, base_device::DEVICE_GPU>; + using vector_div_constant_op_gpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_GPU>; + using vector_mul_vector_op_gpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_GPU>; + using vector_div_vector_op_gpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_GPU>; using constantvector_addORsub_constantVector_op_gpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; - using axpy_op_gpu = hsolver::axpy_op, base_device::DEVICE_GPU>; - using scal_op_gpu = hsolver::scal_op; - using gemv_op_gpu = hsolver::gemv_op, base_device::DEVICE_GPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; + using axpy_op_gpu = ModuleBase::axpy_op, base_device::DEVICE_GPU>; + using scal_op_gpu = ModuleBase::scal_op; + using gemv_op_gpu = ModuleBase::gemv_op, base_device::DEVICE_GPU>; // haozhihan add std::vector> L = {{-0.65412617, -0.74208893}, @@ -375,9 +375,9 @@ TEST_F(TestModuleHsolverMathKernel, zdot_real_op_gpu) resize_memory_op()(psi_R_dev, psi_R.size()); synchronize_memory_op()(psi_L_dev, psi_L.data(), psi_L.size()); synchronize_memory_op()(psi_R_dev, psi_R.data(), psi_R.size()); - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); double result = zdot_real_gpu_op()(gpu_ctx, dim, psi_L_dev, psi_R_dev, false); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); EXPECT_LT(fabs(result - expected_result), 1e-12); delete_memory_op()(psi_L_dev); delete_memory_op()(psi_R_dev); @@ -537,9 +537,9 @@ TEST_F(TestModuleHsolverMathKernel, axpy_op_gpu) synchronize_memory_op()(Y_axpy_dev, Y_axpy.data(), Y_axpy.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); axpy_op_gpu()(gpu_ctx, dim, &alpha_axpy, X_axpy_dev, 1, Y_axpy_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(Y_axpy.data(), Y_axpy_dev, Y_axpy.size()); @@ -566,9 +566,9 @@ TEST_F(TestModuleHsolverMathKernel, scal_op_gpu) synchronize_memory_op()(X_scal_dev, X_scal.data(), X_scal.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); scal_op_gpu()(gpu_ctx, dim, &alpha_scal, X_scal_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(X_scal.data(), X_scal_dev, X_scal.size()); @@ -599,9 +599,9 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu) synchronize_memory_op()(Y_gemv_dev, Y_gemv.data(), Y_gemv.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); gemv_op_gpu()(gpu_ctx, 'C', 2, 3, &ModuleBase::ONE, A_gemv_dev, 2, X_gemv_dev, 1, &ModuleBase::ONE, Y_gemv_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(Y_gemv.data(), Y_gemv_dev, Y_gemv.size()); @@ -668,7 +668,7 @@ TEST_F(TestModuleHsolverMathKernel, matrixSetToAnother_op_gpu) B.size()); // run - hsolver::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, + ModuleBase::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, n, device_A, LDA, @@ -683,7 +683,7 @@ TEST_F(TestModuleHsolverMathKernel, matrixSetToAnother_op_gpu) B_gpu2cpu.size()); std::vector> B_cpu(8); - hsolver::matrixSetToAnother, base_device::DEVICE_CPU>()(cpu_ctx, + ModuleBase::matrixSetToAnother, base_device::DEVICE_CPU>()(cpu_ctx, n, A.data(), LDA, diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp new file mode 100644 index 0000000000..0908457108 --- /dev/null +++ b/source/module_base/para_gemm.cpp @@ -0,0 +1,239 @@ +#include "para_gemm.h" + +#include "kernels/math_kernel_op.h" +#include "parallel_device.h" +namespace ModuleBase +{ +template +PGemmCN::PGemmCN() +{ +} +template +PGemmCN::~PGemmCN() +{ +} + +template +void PGemmCN::set_dimension( +#ifdef __MPI + MPI_Comm comm_col, + MPI_Comm comm_row, +#endif + const int ncolA_in, + const int LDA_in, + const int ncolB_in, + const int LDB_in, + const int nrow_in, + const int LDC_in, + const bool gatherC_in) +{ +#ifdef __MPI + MPI_Comm_rank(comm_col, &col_rank); + MPI_Comm_size(comm_col, &col_nproc); + if (comm_row != MPI_COMM_NULL) + { + MPI_Comm_rank(comm_row, &row_rank); + MPI_Comm_size(comm_row, &row_nproc); + } + col_world = comm_col; + row_world = comm_row; +#endif + this->LDA = LDA_in; + this->LDB = LDB_in; + this->LDC = LDC_in; + this->ncolA = ncolA_in; + this->ncolB = ncolB_in; + this->nrow = nrow_in; +#ifdef __MPI + this->gatherC = gatherC_in; + requests.resize(col_nproc); + colA_loc.resize(col_nproc); + MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + max_colA = std::max(max_colA, colA_loc[ip]); + } + + if (this->gatherC) + { + colB_loc.resize(col_nproc); + recv_counts.resize(col_nproc); + displs.resize(col_nproc); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + recv_counts[ip] = LDC * colB_loc[ip]; + } + displs[0] = 0; + for (int ip = 1; ip < col_nproc; ip++) + { + displs[ip] = displs[ip - 1] + recv_counts[ip - 1]; + } + size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1]; + } + size_C_local = ncolB * LDC; +#endif +} + +template +void PGemmCN::multiply(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; +#ifdef __MPI + if (col_nproc > 1) + { + std::vector A_tmp(max_colA * LDA); + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) + { + int size = ncolA * LDA; + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); + } + } + + T* C_local = C; + std::vector C_tmp; + if (this->gatherC) + { + C_tmp.resize(size_C_local); + if (std::is_same::value) + { + C_local = nullptr; + resmem_dev_op()(C_local, size_C_local); + } + else + { + C_local = C_tmp.data(); + } + syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + } + + T* Atmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Atmp_device, max_colA * LDA); + } + else + { + Atmp_device = A_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C_local + shift; + if (col_rank == ip) + { + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolA; + } + else + { + int m = colA_loc[ip]; + int size = m * LDA; + MPI_Status status; + Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + m, + ncolB, + nrow, + &alpha, + Atmp_device, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += m; + } + } + + if (this->gatherC) + { + T* Cglobal_cpu = nullptr; + T* Clocal_cpu = C_tmp.data();; + if (std::is_same::value) + { + delmem_dev_op()(Atmp_device); + + syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); + delmem_dev_op()(C_local); + + resmem_dev_op()(Cglobal_cpu, size_C_global); + } + else + { + Cglobal_cpu = C; + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); + } + Parallel_Common::gatherv_data(Clocal_cpu, + size_C_local, + Cglobal_cpu, + recv_counts.data(), + displs.data(), + col_world); + + if (std::is_same::value) + { + syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); + delmem_dev_op()(Cglobal_cpu); + } + } + else + { + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } + } + } + else + { + T real_beta = row_rank == 0 ? beta : 0; +#else + T real_beta = beta; +#endif + ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); +#ifdef __MPI + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } + } +#endif +} + +template class PGemmCN; +template class PGemmCN; +template class PGemmCN, base_device::DEVICE_CPU>; +template class PGemmCN, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class PGemmCN; +template class PGemmCN; +template class PGemmCN, base_device::DEVICE_GPU>; +template class PGemmCN, base_device::DEVICE_GPU>; +#endif + +} // namespace ModuleBase \ No newline at end of file diff --git a/source/module_base/para_gemm.h b/source/module_base/para_gemm.h new file mode 100644 index 0000000000..69ffd6d146 --- /dev/null +++ b/source/module_base/para_gemm.h @@ -0,0 +1,93 @@ +#ifndef PARA_GEMM_H +#define PARA_GEMM_H +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" + +#include +#ifdef __MPI +#include "mpi.h" +#endif + +namespace ModuleBase +{ +/** + * @brief this class is used to perform parallel matrix multiplication + * C = alpha * A^H * B + beta * C + * Here, A and B are local matrices in each proc, + * C can be C_local or C_global, depending on the value of gatherC + * C_local is a local matrix in each proc + * C_global is a global matrix gathered from all procs and all procs have their own C_global matrix with the same + * C_global and C_local have the same LDC, but different column numbers + * values. + */ +template +class PGemmCN +{ + public: + PGemmCN(); + ~PGemmCN(); + + /** + * @brief set the dimension of A, B, and C + * + * @param ncolA number of columns of A, which is a local matrix in each proc + * @param LDA leading dimension of A in each proc + * @param ncolB number of columns of B, which is a local matrix in each proc + * @param LDB leading dimension of B in each proc + * @param nrow number of rows of A or B + * @param LDC leading dimension of C. C can be C_local or C_global + * @param gatherC whether gather C_local to C_global + */ + void set_dimension( +#ifdef __MPI + MPI_Comm comm_col, + MPI_Comm comm_row, +#endif + const int ncolA, + const int LDA, + const int ncolB, + const int LDB, + const int nrow, + const int LDC, + const bool gatherC = true); + + /** + * @brief calculate C = alpha * A^H * B + beta * C + * + */ + void multiply(const T alpha, const T* A, const T* B, const T beta, T* C); +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_NULL; ///< column communicator world + MPI_Comm row_world = MPI_COMM_NULL; ///< row communicator world + + int col_rank = 0; ///< rank in col_world + int col_nproc = 1; ///< number of procs in col_world + int row_rank = 0; ///< rank in row_world + int row_nproc = 1; ///< number of procs in row_world + + std::vector colA_loc; ///< [col_nproc] number of columns of A matrix in each proc + int max_colA = 0; ///< maximum number of columns of A matrix in all procs + std::vector colB_loc; ///<[col_nproc] number of columns of B matrix in each proc + + std::vector requests; ///< MPI request + std::vector recv_counts; ///< receive counts for gathering C_local to C_global + std::vector displs; ///< displacements for gathering C_local to C_global + int size_C_local = 0; ///< size of C_local, which is a local matrix in each proc + int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs + bool gatherC = true; ///< whether gather C_local to C_global +#endif + int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc + int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc + int nrow = 0; ///< number of rows of A or B + int LDA = 0; ///< leading dimension of A in each proc + int LDB = 0; ///< leading dimension of B in each proc + int LDC = 0; ///< leading dimension of C, which can be C_local or C_global + private: + using resmem_dev_op = base_device::memory::resize_memory_op; + using delmem_dev_op = base_device::memory::delete_memory_op; + using syncmem_dev_op = base_device::memory::synchronize_memory_op; + using syncmem_d2h_op = base_device::memory::synchronize_memory_op; + using syncmem_h2d_op = base_device::memory::synchronize_memory_op; +}; +} // namespace ModuleBase +#endif \ No newline at end of file diff --git a/source/module_base/parallel_device.cpp b/source/module_base/parallel_device.cpp index 269a41821e..d7373674d6 100644 --- a/source/module_base/parallel_device.cpp +++ b/source/module_base/parallel_device.cpp @@ -2,6 +2,38 @@ #ifdef __MPI namespace Parallel_Common { +void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, request); +} +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_DOUBLE_COMPLEX, dest, tag, comm, request); +} +void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_FLOAT, dest, tag, comm, request); +} +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_COMPLEX, dest, tag, comm, request); +} +void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, status); +} +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_DOUBLE_COMPLEX, source, tag, comm, status); +} +void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_FLOAT, source, tag, comm, status); +} +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_COMPLEX, source, tag, comm, status); +} void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm) { MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm); @@ -34,5 +66,95 @@ void reduce_data(float* object, const int& n, const MPI_Comm& comm) { MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_FLOAT, MPI_SUM, comm); } +void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_DOUBLE, recvbuf, recvcounts, displs, MPI_DOUBLE, comm); +} +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_DOUBLE_COMPLEX, recvbuf, recvcounts, displs, MPI_DOUBLE_COMPLEX, comm); +} +void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_FLOAT, recvbuf, recvcounts, displs, MPI_FLOAT, comm); +} +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_COMPLEX, recvbuf, recvcounts, displs, MPI_COMPLEX, comm); } + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + T* object_cpu = nullptr; + alloc = false; + + if (tmp_space == nullptr) + { + base_device::memory::resize_memory_op()(object_cpu, n); + alloc = true; + } + else + { + object_cpu = tmp_space; + } + base_device::memory::synchronize_memory_op()(object_cpu, + object, + n); + + return object_cpu; + } + void sync_h2d(T* object, const T* object_cpu, const int& n) + { + base_device::memory::synchronize_memory_op()(object, + object_cpu, + n); + } + void sync_d2h(T* object_cpu, const T* object, const int& n) + { + base_device::memory::synchronize_memory_op()(object_cpu, + object, + n); + } + void del(T* object_cpu) + { + if (alloc) + { + base_device::memory::delete_memory_op()(object_cpu); + } + } +}; + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + return const_cast(object); + } + void sync_h2d(T* object, const T* object_cpu, const int& n) + { + } + void sync_d2h(T* object_cpu, const T* object, const int& n) + { + } + void del(T* object_cpu) + { + } +}; + +template struct object_cpu_point; +template struct object_cpu_point; +template struct object_cpu_point, base_device::DEVICE_CPU>; +template struct object_cpu_point, base_device::DEVICE_GPU>; +template struct object_cpu_point; +template struct object_cpu_point; +template struct object_cpu_point, base_device::DEVICE_CPU>; +template struct object_cpu_point, base_device::DEVICE_GPU>; + +} // namespace Parallel_Common #endif \ No newline at end of file diff --git a/source/module_base/parallel_device.h b/source/module_base/parallel_device.h index 7c41b8f28f..776de4e755 100644 --- a/source/module_base/parallel_device.h +++ b/source/module_base/parallel_device.h @@ -7,6 +7,14 @@ #include namespace Parallel_Common { +void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); void bcast_data(double* object, const int& n, const MPI_Comm& comm); @@ -15,6 +23,50 @@ void reduce_data(std::complex* object, const int& n, const MPI_Comm& com void reduce_data(std::complex* object, const int& n, const MPI_Comm& comm); void reduce_data(double* object, const int& n, const MPI_Comm& comm); void reduce_data(float* object, const int& n, const MPI_Comm& comm); +void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr); + void del(T* object); + void sync_d2h(T* object_cpu, const T* object, const int& n); + void sync_h2d(T* object, const T* object_cpu, const int& n); +}; + +/** + * @brief isend data in Device + * + */ +template +void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* tmp_space = nullptr) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, tmp_space); + o.sync_d2h(object_cpu, object, count); + isend_data(object_cpu, count, dest, tag, comm, request); + o.del(object_cpu); + return; +} + +/** + * @brief recv data in Device + * + */ +template +void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status, T* tmp_space = nullptr) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, tmp_space); + recv_data(object_cpu, count, source, tag, comm, status); + o.sync_h2d(object, object_cpu, count); + o.del(object_cpu); + return; +} /** * @brief bcast data in Device @@ -28,79 +80,28 @@ void reduce_data(float* object, const int& n, const MPI_Comm& comm); * @param tmp_space tmp space in CPU */ template -void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { - const base_device::DEVICE_CPU* cpu_ctx = {}; - T* object_cpu = nullptr; - bool alloc = false; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - if(tmp_space == nullptr) - { - base_device::memory::resize_memory_op()(object_cpu, n); - alloc = true; - } - else - { - object_cpu = tmp_space; - } - base_device::memory::synchronize_memory_op()(object_cpu, object, n); - } - else - { - object_cpu = object; - } - + object_cpu_point o; + T* object_cpu = o.get(object, n, tmp_space); + o.sync_d2h(object_cpu, object, n); bcast_data(object_cpu, n, comm); - - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - base_device::memory::synchronize_memory_op()(object, object_cpu, n); - if(alloc) - { - base_device::memory::delete_memory_op()(object_cpu); - } - } + o.sync_h2d(object, object_cpu, n); + o.del(object_cpu); return; } template -void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { - const base_device::DEVICE_CPU* cpu_ctx = {}; - T* object_cpu = nullptr; - bool alloc = false; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - if(tmp_space == nullptr) - { - base_device::memory::resize_memory_op()(object_cpu, n); - alloc = true; - } - else - { - object_cpu = tmp_space; - } - base_device::memory::synchronize_memory_op()(object_cpu, object, n); - } - else - { - object_cpu = object; - } - + object_cpu_point o; + T* object_cpu = o.get(object, n, tmp_space); + o.sync_d2h(object_cpu, object, n); reduce_data(object_cpu, n, comm); - - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - base_device::memory::synchronize_memory_op()(object, object_cpu, n); - if(alloc) - { - base_device::memory::delete_memory_op()(object_cpu); - } - } + o.sync_h2d(object, object_cpu, n); + o.del(object_cpu); return; } - } diff --git a/source/module_base/test_parallel/CMakeLists.txt b/source/module_base/test_parallel/CMakeLists.txt index f6a2c34c50..5132549f7a 100644 --- a/source/module_base/test_parallel/CMakeLists.txt +++ b/source/module_base/test_parallel/CMakeLists.txt @@ -34,6 +34,17 @@ add_test(NAME base_parallel_reduce_test WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) +AddTest( + TARGET base_para_gemm + LIBS MPI::MPI_CXX ${math_libs} base device parameter + SOURCES test_para_gemm.cpp +) + +add_test(NAME base_para_gemm_parallel + COMMAND mpirun -np 4 ./base_para_gemm + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + AddTest( TARGET parallel_2d_test SOURCES parallel_2d_test.cpp ../parallel_2d.cpp diff --git a/source/module_base/test_parallel/test_para_gemm.cpp b/source/module_base/test_parallel/test_para_gemm.cpp new file mode 100644 index 0000000000..4b6445d057 --- /dev/null +++ b/source/module_base/test_parallel/test_para_gemm.cpp @@ -0,0 +1,466 @@ +#include "../kernels/math_kernel_op.h" +#include "../para_gemm.h" + +#include +#include +#include +#include + +void random_data(std::vector& A_global, + std::vector& B_global, + std::vector& Cref_global, + std::vector& C_global, + double& alpha, + double& beta) +{ + for (auto& val: A_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: B_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: Cref_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + C_global = Cref_global; + + alpha = std::rand() / (RAND_MAX + 1.0); + beta = std::rand() / (RAND_MAX + 1.0); +} +void random_data(std::vector>& A_global, + std::vector>& B_global, + std::vector>& Cref_global, + std::vector>& C_global, + std::complex& alpha, + std::complex& beta) +{ + for (auto& val: A_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: B_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: Cref_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + C_global = Cref_global; + + alpha = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + beta = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); +} +double get_double(std::complex& val) +{ + return val.real() + val.imag(); +} +double get_double(double& val) +{ + return val; +} + +void scatterv_data(const double* sendbuf, + const int* sendcounts, + const int* displs, + double* recvbuf, + const int recvcount, + MPI_Comm comm) +{ + MPI_Scatterv(sendbuf, sendcounts, displs, MPI_DOUBLE, recvbuf, recvcount, MPI_DOUBLE, 0, comm); +} +void scatterv_data(const std::complex* sendbuf, + const int* sendcounts, + const int* displs, + std::complex* recvbuf, + const int recvcount, + MPI_Comm comm) +{ + MPI_Scatterv(sendbuf, sendcounts, displs, MPI_DOUBLE_COMPLEX, recvbuf, recvcount, MPI_DOUBLE_COMPLEX, 0, comm); +} +template +class PgemmTest : public ::testing::Test +{ + protected: + void SetUp() override + { + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + } + void TearDown() override + { + MPI_Comm_free(&col_world); + MPI_Comm_free(&row_world); + } + + public: + void decide_ngroup(const int& willing_ncolgroup, const int& willing_nrowgroup) + { + ncolgroup = willing_ncolgroup; + nrowgroup = willing_nrowgroup; + if (nproc % (ncolgroup * nrowgroup) != 0) + { + ncolgroup = nproc; + nrowgroup = 1; + } + else + { + nrowgroup = nproc / ncolgroup; + } + + MPI_Comm_split(MPI_COMM_WORLD, rank % nrowgroup, rank / nrowgroup, &col_world); + MPI_Comm_split(MPI_COMM_WORLD, rank / nrowgroup, rank % nrowgroup, &row_world); + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_rank(row_world, &rank_row); + MPI_Comm_size(col_world, &nproc_col); + MPI_Comm_size(row_world, &nproc_row); + } + void randomize_initialization() + { + random_data(A_global, B_global, Cref_global, C_global, alpha, beta); + } + + void prepare(const int& ncolA_global, + const int& ncolB_global, + const int& nrow_global, + const int& LDA_global, + const int& LDB_global, + const int& LDC_global) + { + A_global = std::vector(LDA_global * ncolA_global, 0.0); + B_global = std::vector(LDB_global * ncolB_global, 0.0); + C_global = std::vector(LDC_global * ncolB_global, 0.0); + Cref_global = std::vector(LDC_global * ncolB_global, 0.0); + if (rank == 0) + { + + this->randomize_initialization(); + const base_device::DEVICE_CPU* ctx = {}; + char transC = 'C'; + char transN = 'N'; + ModuleBase::gemm_op()(ctx, + transC, + transN, + ncolA_global, + ncolB_global, + nrow_global, + &alpha, + A_global.data(), + LDA_global, + B_global.data(), + LDB_global, + &beta, + Cref_global.data(), + LDC_global); + } + + if (std::is_same::value) + { + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(C_global.data(), C_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(Cref_global.data(), Cref_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + } + else if (std::is_same>::value) + { + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(C_global.data(), C_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(Cref_global.data(), Cref_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + } + + // Broadcast A_global and B_global to all ranks + getncol_and_row(ncolA_global, ncolB_global, nrow_global); + LDA = nrow + 1; + LDB = nrow + 2; + + A_local = std::vector(LDA * ncolA, 0.0); + B_local = std::vector(LDB * ncolB, 0.0); + + scatter_matrix(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global); + } + + void getncol_and_row(const int& ncolA_global, const int& ncolB_global, const int& nrow_global) + { + ncolA = ncolA_global / ncolgroup; + if (ncolA_global % ncolgroup > rank_col) + { + ncolA += 1; + } + ncolB = ncolB_global / ncolgroup; + if (ncolB_global % ncolgroup > rank_col) + { + ncolB += 1; + } + + nrow = nrow_global / nrowgroup; + if (nrow_global % nrowgroup > rank_row) + { + nrow += 1; + } + + ncolA_ip.resize(nproc_col); + ncolB_ip.resize(nproc_col); + nrow_ip.resize(nproc_row); + MPI_Allgather(&ncolA, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, col_world); + MPI_Allgather(&ncolB, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, col_world); + if (row_world != MPI_COMM_NULL) + { + MPI_Allgather(&nrow, 1, MPI_INT, nrow_ip.data(), 1, MPI_INT, row_world); + } + } + + void scatter_matrix(const int& ncolA_global, + const int& ncolB_global, + const int& nrow_global, + const int& LDA_global, + const int& LDB_global) + { + std::vector A_semiglobal(ncolA * LDA_global, 0.0); + std::vector B_semiglobal(ncolB * LDB_global, 0.0); + + // Scatter A_global and B_global to A_semiglobal and B_semiglobal + std::vector sendcounts(nproc_col, 0); + std::vector displs(nproc_col, 0); + for (int i = 0; i < nproc_col; i++) + { + sendcounts[i] = ncolA_ip[i] * LDA_global; + } + displs[0] = 0; + for (int i = 1; i < nproc_col; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + scatterv_data(A_global.data(), + sendcounts.data(), + displs.data(), + A_semiglobal.data(), + ncolA * LDA_global, + col_world); + + for (int i = 0; i < nproc_col; i++) + { + sendcounts[i] = ncolB_ip[i] * LDB_global; + } + displs[0] = 0; + for (int i = 1; i < nproc_col; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + scatterv_data(B_global.data(), + sendcounts.data(), + displs.data(), + B_semiglobal.data(), + ncolB * LDB_global, + col_world); + + // Scatter A_semiglobal and B_semiglobal to A_local and B_local + sendcounts.resize(nproc_row, 0); + displs.resize(nproc_row, 0); + for (int i = 0; i < nproc_row; i++) + { + sendcounts[i] = nrow_ip[i]; + } + displs[0] = 0; + for (int i = 1; i < nproc_row; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + for (int i = 0; i < ncolA; i++) + { + scatterv_data(A_semiglobal.data() + i * LDA_global, + sendcounts.data(), + displs.data(), + A_local.data() + i * LDA, + nrow, + row_world); + } + + for (int i = 0; i < ncolB; i++) + { + scatterv_data(B_semiglobal.data() + i * LDB_global, + sendcounts.data(), + displs.data(), + B_local.data() + i * LDB, + nrow, + row_world); + } + } + + void compare_result(const int& nrowC_global, const int& ncolC_global, const int& LDC_global) + { + for (int i = 0; i < ncolC_global; i++) + { + for (int j = 0; j < nrowC_global; j++) + { + EXPECT_NEAR(get_double(Cref_global[i * LDC_global + j]), + get_double(C_global[i * LDC_global + j]), + 1e-10); + } + } + } + + int rank = 0, nproc = 0; + T alpha = 0, beta = 0; + std::vector A_global, B_global, Cref_global, C_global; + std::vector A_local, B_local; + int ncolA = 0, ncolB = 0, nrow = 0, LDA = 0, LDB = 0; + int ncolgroup = 1, nrowgroup = 1; + int rank_col = 0, rank_row = 0; + int nproc_col = 0, nproc_row = 0; + ModuleBase::PGemmCN pgemm; + MPI_Comm col_world; + MPI_Comm row_world; + std::vector ncolA_ip, ncolB_ip, nrow_ip; +}; + +typedef ::testing::Types> MyTypes; + +TYPED_TEST_SUITE(PgemmTest, MyTypes); + +TYPED_TEST(PgemmTest, even_case) +{ + const int ncolA_global = 16, ncolB_global = 8, nrow_global = 12; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, odd_case) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, odd_case_not_gather) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + std::vector colB_loc(this->nproc_col); + MPI_Allgather(&this->ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, this->col_world); + std::vector displs(this->nproc_col); + displs[0] = 0; + for (int i = 1; i < this->nproc_col; i++) + { + displs[i] = (displs[i - 1] + colB_loc[i - 1]) * LDC_global; + } + int start = displs[this->rank_col]; + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global, + false); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start); + + + + for (int i = 0; i < this->ncolB; i++) + { + for (int j = 0; j < ncolA_global; j++) + { + EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + start + j]), + get_double(this->C_global[i * LDC_global + start + j]), + 1e-10); + } + } +} + +TYPED_TEST(PgemmTest, row_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(1, 4); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, col_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(4, 1); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + MPI_Init(&argc, &argv); + + int RANK, NPROC; + MPI_Comm_rank(MPI_COMM_WORLD, &RANK); + MPI_Comm_size(MPI_COMM_WORLD, &NPROC); + + int result = RUN_ALL_TESTS(); + + MPI_Finalize(); + return result; +} \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw.h b/source/module_elecstate/elecstate_pw.h index 8259d83024..679b9b712c 100644 --- a/source/module_elecstate/elecstate_pw.h +++ b/source/module_elecstate/elecstate_pw.h @@ -7,7 +7,7 @@ #include "module_basis/module_pw/pw_basis_k.h" #include "module_elecstate/kernels/elecstate_op.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/meta_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace elecstate { @@ -98,8 +98,8 @@ class ElecStatePW : public ElecState using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; }; } // namespace elecstate diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp b/source/module_esolver/esolver_ks_lcaopw.cpp index 4649fb07ca..3f7296395b 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -28,7 +28,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_lcaopw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/numerical_basis.h" #include "module_io/numerical_descriptor.h" diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index cf5e300537..2110cd76fc 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -26,7 +26,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_pw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/cube_io.h" #include "module_io/get_pchg_pw.h" @@ -73,7 +73,7 @@ ESolver_KS_PW::ESolver_KS_PW() #if ((defined __CUDA) || (defined __ROCM)) if (this->device == base_device::GpuDevice) { - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); hsolver::createGpuSolverHandle(); container::kernels::createGpuBlasHandle(); container::kernels::createGpuSolverHandle(); @@ -101,7 +101,7 @@ ESolver_KS_PW::~ESolver_KS_PW() if (this->device == base_device::GpuDevice) { #if defined(__CUDA) || defined(__ROCM) - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); hsolver::destroyGpuSolverHandle(); container::kernels::destroyGpuBlasHandle(); container::kernels::destroyGpuSolverHandle(); diff --git a/source/module_esolver/pw_others.cpp b/source/module_esolver/pw_others.cpp index ef32f041e8..0f2be0a998 100644 --- a/source/module_esolver/pw_others.cpp +++ b/source/module_esolver/pw_others.cpp @@ -30,7 +30,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_pw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/numerical_basis.h" #include "module_io/numerical_descriptor.h" diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index d6602e6b11..21b745e17b 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -5,7 +5,7 @@ #include "spin_constrain.h" #include "module_hamilt_pw/hamilt_pwdft/onsite_projector.h" #include "module_base/parallel_reduce.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hsolver/hsolver_lcao.h" #include "module_hsolver/hsolver_pw.h" #include "module_elecstate/elecstate_pw.h" @@ -84,7 +84,7 @@ void spinconstrain::SpinConstrain>::calculate_delta_hcc(std { #if ((defined __CUDA) || (defined __ROCM)) base_device::DEVICE_GPU* ctx = {}; - hsolver::gemm_op, base_device::DEVICE_GPU>()( + ModuleBase::gemm_op, base_device::DEVICE_GPU>()( ctx, transa, transb, @@ -108,7 +108,7 @@ void spinconstrain::SpinConstrain>::calculate_delta_hcc(std else if (PARAM.inp.device == "cpu") { base_device::DEVICE_CPU* ctx = {}; - hsolver::gemm_op, base_device::DEVICE_CPU>()( + ModuleBase::gemm_op, base_device::DEVICE_CPU>()( ctx, transa, transb, diff --git a/source/module_hamilt_pw/hamilt_pwdft/forces.h b/source/module_hamilt_pw/hamilt_pwdft/forces.h index 695520dceb..5472396fcc 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/forces.h +++ b/source/module_hamilt_pw/hamilt_pwdft/forces.h @@ -11,7 +11,7 @@ #include "module_elecstate/elecstate.h" #include "module_hamilt_pw/hamilt_pwdft/VL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/force_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "structure_factor.h" @@ -129,7 +129,7 @@ class Forces base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; private: - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using resmem_complex_op = base_device::memory::resize_memory_op, Device>; using resmem_complex_h_op = base_device::memory::resize_memory_op, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp index 523cb2b504..b2dc156560 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp @@ -295,7 +295,7 @@ void FS_Nonlocal_tools::reduce_pool_becp(const int& npm) #ifdef __MPI if (GlobalV::NPROC_IN_POOL > 1) { - Parallel_Common::reduce_dev(this->ctx, this->becp, size_becp_act, POOL_WORLD); + Parallel_Common::reduce_dev, Device>(this->becp, size_becp_act, POOL_WORLD); } #endif } diff --git a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h index 0cc640f27c..64a76e700d 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h +++ b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include @@ -215,7 +215,7 @@ class FS_Nonlocal_tools std::complex* becp = nullptr; // nbands * nkb /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h index f87dca7745..badeae0db6 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h @@ -6,7 +6,7 @@ #include "module_elecstate/potentials/potential_new.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { @@ -44,8 +44,8 @@ class HamiltPW : public Hamilt T* qq_so = nullptr; Device* ctx = {}; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp b/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp index 79649fab07..292ad80d43 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp +++ b/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h b/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h index 17c7e06491..0376a9709f 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include @@ -178,7 +178,7 @@ class Onsite_Proj_tools std::complex* becp = nullptr; // nbands * nkb /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index 47faf38797..09982d8e06 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -8,7 +8,7 @@ #include "module_base/projgen.h" #include "module_base/blas_connector.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #ifdef __MPI #include "module_base/parallel_reduce.h" #include "module_base/parallel_common.h" diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h index a2bb99354b..b34d8291de 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h @@ -1,7 +1,7 @@ #ifndef MODULEHAMILTPW_ONSITEPROJECTOR_H #define MODULEHAMILTPW_ONSITEPROJECTOR_H #include "module_base/module_device/device.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_hamilt_pw/hamilt_pwdft/radial_proj.h" @@ -130,7 +130,7 @@ namespace projectors bool initialed = false; /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using resmem_complex_op = base_device::memory::resize_memory_op, Device>; using resmem_complex_h_op = base_device::memory::resize_memory_op, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h index 21fc574f5b..133eed1f5b 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h @@ -5,7 +5,7 @@ #include "module_base/matrix.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/meta_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include @@ -81,7 +81,7 @@ class Meta> : public OperatorPW base_device::DEVICE_CPU* cpu_ctx = {}; T *porter = nullptr; using meta_op = meta_pw_op; - using vector_mul_vector_op = hsolver::vector_mul_vector_op; + using vector_mul_vector_op = ModuleBase::vector_mul_vector_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; using setmem_complex_op = base_device::memory::set_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h index 91e760920a..31a98d24c9 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h @@ -5,7 +5,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/nonlocal_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" @@ -85,8 +85,8 @@ class Nonlocal> : public OperatorPW Real * deeq = nullptr; T * deeq_nc = nullptr; // using nonlocal_op = nonlocal_pw_op; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using nonlocal_op = nonlocal_pw_op; using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h index 975967d5c8..b28657d0df 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h @@ -4,7 +4,7 @@ #include "operator_pw.h" #include "module_cell/unitcell.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { @@ -76,8 +76,8 @@ class OnsiteProj> : public OperatorPW Device* ctx = {}; base_device::DEVICE_CPU* cpu_ctx = {}; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/stress_func.h b/source/module_hamilt_pw/hamilt_pwdft/stress_func.h index 878206ad38..20f6a91937 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/stress_func.h +++ b/source/module_hamilt_pw/hamilt_pwdft/stress_func.h @@ -14,7 +14,7 @@ #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" //------------------------------------------------------------------- @@ -241,7 +241,7 @@ class Stress_Func base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; private: - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp index 34e20977eb..9facef1ddf 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp @@ -1,7 +1,7 @@ #include "sto_che.h" #include "module_base/blas_connector.h" #include "module_base/module_device/device.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_container/ATen/kernels/blas.h" template diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.h b/source/module_hamilt_pw/hamilt_stodft/sto_che.h index f241553b66..578e5df0fb 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.h @@ -1,7 +1,7 @@ #ifndef STO_CHE_H #define STO_CHE_H #include "module_base/math_chebyshev.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_container/ATen/kernels/blas.h" template @@ -51,7 +51,7 @@ REAL vTMv(const REAL* v, const REAL* M, const int n) const REAL zero = 0; REAL* y = nullptr; base_device::memory::resize_memory_op()(y, n); - hsolver::gemv_op()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc); + ModuleBase::gemv_op()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc); REAL result = 0; REAL* dot_device = nullptr; base_device::memory::resize_memory_op()(dot_device, 1); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 8ec669febd..bd029a401d 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -7,7 +7,7 @@ #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_parameter/parameter.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_elecstate/kernels/elecstate_op.h" template @@ -78,7 +78,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, char transN = 'N'; // sum(b - hsolver::gemm_op()(ctx, + ModuleBase::gemm_op()(ctx, transC, transN, PARAM.inp.nbands, @@ -95,7 +95,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, Parallel_Reduce::reduce_pool(sum, PARAM.inp.nbands * nchipk); // psi -= psi * sum - hsolver::gemm_op()(ctx, + ModuleBase::gemm_op()(ctx, transN, transN, npw, @@ -406,7 +406,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& const int N = norder; const Real kweight = this->pkv->wk[ik]; - hsolver::gemm_op()(this->ctx, trans, normal, N, N, M, &kweight, vec_all, LDA, vec_all, LDA, &one, spolyv, N); + ModuleBase::gemm_op()(this->ctx, trans, normal, N, N, M, &kweight, vec_all, LDA, vec_all, LDA, &one, spolyv, N); // dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv, &N); } ModuleBase::timer::tick("Stochastic_Iter", "calPn"); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h index 901b1311f3..9953cfcd3b 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h @@ -163,7 +163,7 @@ class Stochastic_Iter using delmem_complex_op = base_device::memory::delete_memory_op; using castmem_d2z_op = base_device::memory::cast_memory_op; using castmem_var_d2h_op = base_device::memory::cast_memory_op; - using gemv_op = hsolver::gemv_op; + using gemv_op = ModuleBase::gemv_op; }; #endif // Eelectrons_Iter diff --git a/source/module_hsolver/CMakeLists.txt b/source/module_hsolver/CMakeLists.txt index 93a708f21d..7f6c8ca4c6 100644 --- a/source/module_hsolver/CMakeLists.txt +++ b/source/module_hsolver/CMakeLists.txt @@ -36,7 +36,6 @@ if(ENABLE_LCAO) if(USE_CUDA) list(APPEND objects - ./kernels/math_kernel_op.cpp ./kernels/dngvd_op.cpp ./kernels/cuda/diag_cusolver.cu diago_cusolver.cpp diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 846bef9ff8..36f77d372d 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -10,7 +10,7 @@ #include "diago_iter_assist.h" #include "module_base/blas_connector.h" #include "module_base/global_function.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hsolver { diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index a80c1406b6..90907de5e9 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -7,7 +7,7 @@ #include "module_base/module_device/types.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hsolver/kernels/dngvd_op.h" #include @@ -343,9 +343,9 @@ class DiagoBPCG // note: these operators use template parameter base_device::Device_* // defined in module_base/module_device/types.h // different from ct_Device! - using calc_grad_with_block_op = hsolver::calc_grad_with_block_op; - using line_minimize_with_block_op = hsolver::line_minimize_with_block_op; - using gemm_op = hsolver::gemm_op; + using calc_grad_with_block_op = ModuleBase::calc_grad_with_block_op; + using line_minimize_with_block_op = ModuleBase::line_minimize_with_block_op; + using gemm_op = ModuleBase::gemm_op; }; diff --git a/source/module_hsolver/diago_cg.cpp b/source/module_hsolver/diago_cg.cpp index 29bdffa977..ea872d6d3e 100644 --- a/source/module_hsolver/diago_cg.cpp +++ b/source/module_hsolver/diago_cg.cpp @@ -226,14 +226,14 @@ void DiagoCG::calc_grad(const ct::Tensor& prec, // } // denghui replace this at 20221106 // TODO: use GPU precondition to initialize CG class - vector_div_vector_op()(ctx_, this->n_basis_, grad.data(), hphi.data(), prec.data()); - vector_div_vector_op()(ctx_, this->n_basis_, pphi.data(), sphi.data(), prec.data()); + ModuleBase::vector_div_vector_op()(ctx_, this->n_basis_, grad.data(), hphi.data(), prec.data()); + ModuleBase::vector_div_vector_op()(ctx_, this->n_basis_, pphi.data(), sphi.data(), prec.data()); // Update lambda ! // (4) - const Real eh = hsolver::dot_real_op()(ctx_, this->n_basis_, sphi.data(), grad.data()); + const Real eh = ModuleBase::dot_real_op()(ctx_, this->n_basis_, sphi.data(), grad.data()); // (5) - const Real es = hsolver::dot_real_op()(ctx_, this->n_basis_, sphi.data(), pphi.data()); + const Real es = ModuleBase::dot_real_op()(ctx_, this->n_basis_, sphi.data(), pphi.data()); const Real lambda = eh / es; // Update g! @@ -247,13 +247,13 @@ void DiagoCG::calc_grad(const ct::Tensor& prec, // grad.data()[i] -= lambda * this->pphi[i]; // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - grad.data(), - grad.data(), - 1.0, - pphi.data(), - (-lambda)); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + grad.data(), + grad.data(), + 1.0, + pphi.data(), + (-lambda)); } template @@ -264,49 +264,49 @@ void DiagoCG::orth_grad(const ct::Tensor& psi, ct::Tensor& lagrange) { this->spsi_func_(grad, scg); // scg = S|grad> - gemv_op()(ctx_, - 'C', - this->n_basis_, - m, - this->one_, - psi.data(), - this->n_basis_, - scg.data(), - 1, - this->zero_, - lagrange.data(), - 1); + ModuleBase::gemv_op()(ctx_, + 'C', + this->n_basis_, + m, + this->one_, + psi.data(), + this->n_basis_, + scg.data(), + 1, + this->zero_, + lagrange.data(), + 1); Parallel_Reduce::reduce_pool(lagrange.data(), m); // (3) orthogonal |g> and |scg> to all states (0~m-1) //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-07 - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange.data(), - 1, - this->one_, - grad.data(), - 1); - - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange.data(), - 1, - this->one_, - scg.data(), - 1); + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange.data(), + 1, + this->one_, + grad.data(), + 1); + + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange.data(), + 1, + this->one_, + scg.data(), + 1); } template @@ -328,7 +328,7 @@ void DiagoCG::calc_gamma_cg(const int& iter, // gg_inter = // Attention : the 'g' in g0 is getted last time gg_inter - = hsolver::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); // b means before + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); // b means before } // (2) Update for g0! @@ -342,11 +342,11 @@ void DiagoCG::calc_gamma_cg(const int& iter, // } // denghui replace this 20221106 // TODO: use GPU precondition instead - vector_mul_vector_op()(ctx_, this->n_basis_, g0.data(), scg.data(), prec.data()); + ModuleBase::vector_mul_vector_op()(ctx_, this->n_basis_, g0.data(), scg.data(), prec.data()); // (3) Update gg_now! // gg_now = < g|P|scg > = < g|g0 > - const Real gg_now = hsolver::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); + const Real gg_now = ModuleBase::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); if (iter == 0) { @@ -370,13 +370,13 @@ void DiagoCG::calc_gamma_cg(const int& iter, // pcg[i] = gamma * pcg[i] + grad.data()[i]; // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - cg.data(), - cg.data(), - gamma, - grad.data(), - 1.0); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + cg.data(), + cg.data(), + gamma, + grad.data(), + 1.0); const Real norma = gamma * cg_norm * sin(theta); T znorma = static_cast(norma * -1); @@ -388,7 +388,7 @@ void DiagoCG::calc_gamma_cg(const int& iter, { pcg[i] -= norma * pphi_m[i]; }*/ - axpy_op()(ctx_, this->n_basis_, &znorma, phi_m.data(), 1, cg.data(), 1); + ModuleBase::axpy_op()(ctx_, this->n_basis_, &znorma, phi_m.data(), 1, cg.data(), 1); } } @@ -404,15 +404,15 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, ct::Tensor& sphi, ct::Tensor& hphi) { - cg_norm = sqrt(hsolver::dot_real_op()(ctx_, this->n_basis_, cg.data(), scg.data())); + cg_norm = sqrt(ModuleBase::dot_real_op()(ctx_, this->n_basis_, cg.data(), scg.data())); if (cg_norm < 1.0e-10) return true; const Real a0 - = hsolver::dot_real_op()(ctx_, this->n_basis_, phi_m.data(), pphi.data()) * 2.0 / cg_norm; + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, phi_m.data(), pphi.data()) * 2.0 / cg_norm; const Real b0 - = hsolver::dot_real_op()(ctx_, this->n_basis_, cg.data(), pphi.data()) / (cg_norm * cg_norm); + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, cg.data(), pphi.data()) / (cg_norm * cg_norm); const Real e0 = eigen; theta = atan(a0 / (e0 - b0)) / 2.0; @@ -438,13 +438,13 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - phi_m.data(), - phi_m.data(), - cost, - cg.data(), - sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + phi_m.data(), + phi_m.data(), + cost, + cg.data(), + sint_norm); if (std::abs(eigen - e0) < ethreshold) { @@ -460,20 +460,20 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - sphi.data(), - sphi.data(), - cost, - scg.data(), - sint_norm); - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - hphi.data(), - hphi.data(), - cost, - pphi.data(), - sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + sphi.data(), + sphi.data(), + cost, + scg.data(), + sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + hphi.data(), + hphi.data(), + cost, + pphi.data(), + sint_norm); return false; } } @@ -496,36 +496,36 @@ void DiagoCG::schmit_orth(const int& m, const ct::Tensor& psi, const //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-6 int inc = 1; - gemv_op()(ctx_, - 'C', - this->n_basis_, - m + 1, - this->one_, - psi.data(), - this->n_basis_, - sphi.data(), - inc, - this->zero_, - lagrange_so.data(), - inc); + ModuleBase::gemv_op()(ctx_, + 'C', + this->n_basis_, + m + 1, + this->one_, + psi.data(), + this->n_basis_, + sphi.data(), + inc, + this->zero_, + lagrange_so.data(), + inc); // be careful , here reduce m+1 Parallel_Reduce::reduce_pool(lagrange_so.data(), m + 1); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-6 - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange_so.data(), - inc, - this->one_, - phi_m.data(), - inc); + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange_so.data(), + inc, + this->one_, + phi_m.data(), + inc); //====================================================================== /*for (int j = 0; j < m; j++) @@ -563,7 +563,7 @@ void DiagoCG::schmit_orth(const int& m, const ct::Tensor& psi, const // { // pphi_m[ig] /= psi_norm; // } - vector_div_constant_op()(ctx_, this->n_basis_, phi_m.data(), phi_m.data(), psi_norm); + ModuleBase::vector_div_constant_op()(ctx_, this->n_basis_, phi_m.data(), phi_m.data(), psi_norm); // ModuleBase::timer::tick("DiagoCG","schmit_orth"); } diff --git a/source/module_hsolver/diago_cg.h b/source/module_hsolver/diago_cg.h index 2741df42d4..9d254ded18 100644 --- a/source/module_hsolver/diago_cg.h +++ b/source/module_hsolver/diago_cg.h @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -126,7 +126,7 @@ class DiagoCG final bool test_exit_cond(const int& ntry, const int& notconv) const; - using dot_real_op = hsolver::dot_real_op; + using dot_real_op = ModuleBase::dot_real_op; const T * one_ = nullptr, * zero_ = nullptr, * neg_one_ = nullptr; }; diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index f7daf229a2..177e68847c 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -5,7 +5,7 @@ #include "module_base/module_device/device.h" #include "module_base/timer.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/kernels/dsp/dsp_connector.h" #include "module_hsolver/diag_hs_para.h" @@ -191,24 +191,25 @@ int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func, setmem_complex_op()(psi_in, 0, n_band * psi_in_dmax); #ifdef __DSP - gemm_op_mt() // In order to not coding another whole template, using this method to minimize the code change. + ModuleBase::gemm_op_mt() // In order to not coding another whole template, using this method to + // minimize the code change. #else - gemm_op() + ModuleBase::gemm_op() #endif - (this->ctx, - 'N', - 'N', - this->dim, - this->n_band, - nbase, - this->one, - this->psi_in_iter, - this->dim, - this->vcc, - this->nbase_x, - this->zero, - psi_in, - psi_in_dmax); + (this->ctx, + 'N', + 'N', + this->dim, + this->n_band, + nbase, + this->one, + this->psi_in_iter, + this->dim, + this->vcc, + this->nbase_x, + this->zero, + psi_in, + psi_in_dmax); if (!this->notconv || (dav_iter == this->iter_nmax)) { @@ -275,9 +276,9 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'N', @@ -308,11 +309,11 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, { syncmem_var_h2d_op()(e_temp_hd, e_temp_cpu.data(), nbase); } - vector_mul_vector_op()(this->ctx, - nbase, - vcc + m * this->nbase_x, - vcc + m * this->nbase_x, - e_temp_hd); + ModuleBase::vector_mul_vector_op()(this->ctx, + nbase, + vcc + m * this->nbase_x, + vcc + m * this->nbase_x, + e_temp_hd); } if(this->device == base_device::GpuDevice) { @@ -320,24 +321,24 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() -#endif - (this->ctx, - 'N', - 'N', - this->dim, - notconv, - nbase, - this->one, - psi_iter, - this->dim, - vcc, - this->nbase_x, - this->one, - psi_iter + nbase * this->dim, - this->dim); + ModuleBase::gemm_op() +#endif + (this->ctx, + 'N', + 'N', + this->dim, + notconv, + nbase, + this->one, + psi_iter, + this->dim, + vcc, + this->nbase_x, + this->one, + psi_iter + nbase * this->dim, + this->dim); // "precondition!!!" std::vector pre(this->dim, 0.0); @@ -353,20 +354,20 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, if (this->device == base_device::GpuDevice) { syncmem_var_h2d_op()(this->d_precondition, pre.data(), this->dim); - vector_div_vector_op()(this->ctx, - this->dim, - psi_iter + (nbase + m) * this->dim, - psi_iter + (nbase + m) * this->dim, - this->d_precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + this->dim, + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, + this->d_precondition); } else #endif { - vector_div_vector_op()(this->ctx, - this->dim, - psi_iter + (nbase + m) * this->dim, - psi_iter + (nbase + m) * this->dim, - pre.data()); + ModuleBase::vector_div_vector_op()(this->ctx, + this->dim, + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, + pre.data()); } } @@ -374,19 +375,19 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, std::vector psi_norm(notconv, 0.0); for (size_t i = 0; i < notconv; i++) { - psi_norm[i] = dot_real_op()(this->ctx, - this->dim, - psi_iter + (nbase + i) * this->dim, - psi_iter + (nbase + i) * this->dim, - true); + psi_norm[i] = ModuleBase::dot_real_op()(this->ctx, + this->dim, + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, + true); assert(psi_norm[i] > 0.0); psi_norm[i] = sqrt(psi_norm[i]); - vector_div_constant_op()(this->ctx, - this->dim, - psi_iter + (nbase + i) * this->dim, - psi_iter + (nbase + i) * this->dim, - psi_norm[i]); + ModuleBase::vector_div_constant_op()(this->ctx, + this->dim, + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, + psi_norm[i]); } // update hpsi[:, nbase:nbase+notconv] @@ -409,9 +410,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem"); #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'C', @@ -429,9 +430,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, this->nbase_x); #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'C', @@ -691,9 +692,9 @@ void Diago_DavSubspace::refresh(const int& dim, ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'N', diff --git a/source/module_hsolver/diago_david.cpp b/source/module_hsolver/diago_david.cpp index 6afaf998b8..21865a4ed1 100644 --- a/source/module_hsolver/diago_david.cpp +++ b/source/module_hsolver/diago_david.cpp @@ -5,7 +5,7 @@ #include "module_base/module_device/device.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" @@ -266,21 +266,20 @@ int DiagoDavid::diag_once(const HPsiFunc& hpsi_func, setmem_complex_op()(psi_in, 0, nband * ld_psi); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - basis, // A dim * nbase - dim, - this->vcc, // B nbase * nband - nbase_x, - this->zero, - psi_in, // C dim * nband - ld_psi - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + basis, // A dim * nbase + dim, + this->vcc, // B nbase * nband + nbase_x, + this->zero, + psi_in, // C dim * nband + ld_psi); if (!this->notconv || (dav_iter == david_maxiter)) { @@ -378,20 +377,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // basis[nbase] = hpsi * vc_ev_vector = hpsi*vcc // basis' = vc_ev_vector' * hpsi' // (dim, notconv) (dim, nbase) (nbase, notconv) - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - hpsi, // A dim * nbase - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - vc_ev_vector, // B nbase * notconv - nbase, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - basis + dim*nbase, // C dim * notconv - dim // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + hpsi, // A dim * nbase + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + vc_ev_vector, // B nbase * notconv + nbase, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + basis + dim * nbase, // C dim * notconv + dim // LDC: if(N) max(1, m) ); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -417,21 +416,21 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, Real* e_temp_gpu = nullptr; resmem_var_op()(e_temp_gpu, nbase); syncmem_var_h2d_op()(e_temp_gpu, e_temp_cpu.data(), nbase); - vector_mul_vector_op()(this->ctx, - nbase, - vc_ev_vector + m * nbase, - vc_ev_vector + m * nbase, - e_temp_gpu); + ModuleBase::vector_mul_vector_op()(this->ctx, + nbase, + vc_ev_vector + m * nbase, + vc_ev_vector + m * nbase, + e_temp_gpu); delmem_var_op()(e_temp_gpu); #endif } else { - vector_mul_vector_op()(this->ctx, - nbase, - vc_ev_vector + m * nbase, - vc_ev_vector + m * nbase, - e_temp_cpu.data()); + ModuleBase::vector_mul_vector_op()(this->ctx, + nbase, + vc_ev_vector + m * nbase, + vc_ev_vector + m * nbase, + e_temp_cpu.data()); } } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -441,20 +440,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // = (H - lambda * S) * psi * vcc // = (H - lambda * S) * psi_new //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - spsi, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - vc_ev_vector, // B - nbase, // LDB: if(N) max(1,k) if(T) max(1,n) - this->one, // belta - basis + dim*nbase, // C dim * notconv - dim // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + spsi, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + vc_ev_vector, // B + nbase, // LDB: if(N) max(1,k) if(T) max(1,n) + this->one, // belta + basis + dim * nbase, // C dim * notconv + dim // LDC: if(N) max(1, m) ); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -469,20 +468,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, if (this->device == base_device::GpuDevice) { #if defined(__CUDA) || defined(__ROCM) - vector_div_vector_op()(this->ctx, - dim, - basis + dim*(nbase + m), - basis + dim*(nbase + m), - this->d_precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + dim, + basis + dim * (nbase + m), + basis + dim * (nbase + m), + this->d_precondition); #endif } else { - vector_div_vector_op()(this->ctx, - dim, - basis + dim*(nbase + m), - basis + dim*(nbase + m), - this->precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + dim, + basis + dim * (nbase + m), + basis + dim * (nbase + m), + this->precondition); } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // for (int ig = 0; ig < dim; ig++) @@ -519,20 +518,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix // calculate the square matrix for future lagranges - gemm_op()(this->ctx, - 'C', - 'N', - nbase, // m: row of A,C - notconv, // n: col of B,C - dim, // k: col of A, row of B - this->one, // alpha - basis, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - &spsi[nbase * dim], // B - dim, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - lagrange, // C - nbase + notconv // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + nbase, // m: row of A,C + notconv, // n: col of B,C + dim, // k: col of A, row of B + this->one, // alpha + basis, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + &spsi[nbase * dim], // B + dim, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + lagrange, // C + nbase + notconv // LDC: if(N) max(1, m) ); for (int m = 0; m < notconv; m++) @@ -593,20 +592,20 @@ void DiagoDavid::cal_elem(const int& dim, ModuleBase::timer::tick("DiagoDavid", "cal_elem"); // hcc[nbase](notconv, nbase + notconv)= basis[nbase]' * hpsi - gemm_op()(this->ctx, - 'C', - 'N', - notconv, - nbase + notconv, - dim, - this->one, - basis + dim*nbase, // basis(:,nbase:) dim * notconv - dim, - hpsi, // dim * (nbase + notconv) - dim, - this->zero, - hcc + nbase, // notconv * (nbase + notconv) - nbase_x); + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + notconv, + nbase + notconv, + dim, + this->one, + basis + dim * nbase, // basis(:,nbase:) dim * notconv + dim, + hpsi, // dim * (nbase + notconv) + dim, + this->zero, + hcc + nbase, // notconv * (nbase + notconv) + nbase_x); // scc[nbase] = basis[nbase]' * spsi // gemm_op()(this->ctx, // 'C', @@ -627,7 +626,7 @@ void DiagoDavid::cal_elem(const int& dim, #ifdef __MPI if (diag_comm.nproc > 1) { - matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); + ModuleBase::matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); // matrixTranspose_op()(this->ctx, nbase_x, nbase_x, scc, scc); auto* swap = new T[notconv * nbase_x]; @@ -657,7 +656,7 @@ void DiagoDavid::cal_elem(const int& dim, // Parallel_Reduce::reduce_complex_double_pool( hcc + nbase * nbase_x, notconv * nbase_x ); // Parallel_Reduce::reduce_complex_double_pool( scc + nbase * nbase_x, notconv * nbase_x ); - matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); + ModuleBase::matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); // matrixTranspose_op()(this->ctx, nbase_x, nbase_x, scc, scc); } #endif @@ -751,39 +750,37 @@ void DiagoDavid::refresh(const int& dim, setmem_complex_op()(basis , 0, nbase_x * dim); // basis(dim, nband) = hpsi(dim, nbase) * vcc(nbase, nband) - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - hpsi, // A dim * nbase - dim, - vcc, // B nbase * nband - nbase_x, - zero, - basis, // C dim * nband - dim - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + hpsi, // A dim * nbase + dim, + vcc, // B nbase * nband + nbase_x, + zero, + basis, // C dim * nband + dim); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // basis[nband] = spsi * vcc - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - spsi, // A dim * nbase - dim, - vcc, // B nbase * nband - nbase_x, - this->zero, - basis + dim*nband, // C dim * nband - dim - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + spsi, // A dim * nbase + dim, + vcc, // B nbase * nband + nbase_x, + this->zero, + basis + dim * nband, // C dim * nband + dim); // hpsi = basis, spsi = basis[nband] syncmem_complex_op()(hpsi, basis, dim * nband); @@ -900,37 +897,37 @@ void DiagoDavid::SchmidtOrth(const int& dim, { // lagrange_m[m - mv_size + 1 - mm_size] // = basis[m - mv_size + 1 - mm_size]' * spsi[m] - gemm_op()(this->ctx, - 'C', - 'N', - mm_size, // m: row of A,C - mm_size, // n: col of B,C - dim, // k: col of A, row of B - this->one, // alpha - basis + dim*(m - mv_size + 1 - mm_size), // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - &spsi[m * dim], // B - dim, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - &lagrange_m[m - mv_size + 1 - mm_size], // C - nband // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + mm_size, // m: row of A,C + mm_size, // n: col of B,C + dim, // k: col of A, row of B + this->one, // alpha + basis + dim * (m - mv_size + 1 - mm_size), // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + &spsi[m * dim], // B + dim, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + &lagrange_m[m - mv_size + 1 - mm_size], // C + nband // LDC: if(N) max(1, m) ); } // calculate other lagranges for this band // lagrange_m[m - mv_size + 1] // = basis[m - mv_size + 1]' * spsi[m] - gemv_op()(this->ctx, - 'C', - dim, - mv_size, - this->one, - basis + dim*(m - mv_size + 1), - dim, - &spsi[m * dim], - 1, - this->zero, - &lagrange_m[m - mv_size + 1], - 1); + ModuleBase::gemv_op()(this->ctx, + 'C', + dim, + mv_size, + this->one, + basis + dim * (m - mv_size + 1), + dim, + &spsi[m * dim], + 1, + this->zero, + &lagrange_m[m - mv_size + 1], + 1); Parallel_Reduce::reduce_pool(lagrange_m, m + 1); @@ -942,21 +939,21 @@ void DiagoDavid::SchmidtOrth(const int& dim, // / psi_m = psi_m - \sum_{i < m} \langle psi(i)|S|psi(m) \rangle psi(i) // psi_m = psi_m - basis * lagrange_m - gemv_op()(this->ctx, - 'N', - dim, - m, - this->neg_one, - basis, - dim, - lagrange_m, - 1, - this->one, - psi_m, - 1); + ModuleBase::gemv_op()(this->ctx, + 'N', + dim, + m, + this->neg_one, + basis, + dim, + lagrange_m, + 1, + this->one, + psi_m, + 1); // psi_norm = psi_norm - lagrange_m ยท lagrange_m - psi_norm -= dot_real_op()(this->ctx, m, lagrange_m, lagrange_m, false); + psi_norm -= ModuleBase::dot_real_op()(this->ctx, m, lagrange_m, lagrange_m, false); // for (int j = 0; j < m; j++) // { @@ -983,7 +980,7 @@ void DiagoDavid::SchmidtOrth(const int& dim, else { // psi_m = psi_m / psi_norm - vector_div_constant_op()(this->ctx, dim, psi_m, psi_m, psi_norm); + ModuleBase::vector_div_constant_op()(this->ctx, dim, psi_m, psi_m, psi_norm); // for (int i = 0; i < npw; i++) // { // psi_m[i] /= psi_norm; diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 5a3acf8e53..ea1f36d900 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -9,7 +9,7 @@ #include "module_base/parallel_reduce.h" #include "module_base/timer.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hsolver { @@ -73,39 +73,39 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt* hpsi_info hpsi_in(&psi, all_bands_range, hphi); pHamilt->ops->hPsi(hpsi_in); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - hphi, - dmax, - &zero, - hcc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + hphi, + dmax, + &zero, + hcc, + nstart); T* sphi = temp; // do sPsi for all bands pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - sphi, - dmax, - &zero, - scc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + sphi, + dmax, + &zero, + scc, + nstart); } if (GlobalV::NPROC_IN_POOL > 1) @@ -121,25 +121,25 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt* const int ld_temp = in_place ? dmax : dmin; { // code block to calculate evc - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - psi.get_pointer(), // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - temp, - ld_temp); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + psi.get_pointer(), // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + temp, + ld_temp); } if (!in_place) { - matrixSetToAnother()(ctx, n_band, temp, ld_temp, evc.get_pointer(), dmax); + ModuleBase::matrixSetToAnother()(ctx, n_band, temp, ld_temp, evc.get_pointer(), dmax); delmem_complex_op()(temp); } delmem_complex_op()(hcc); @@ -222,7 +222,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* pHamilt->ops->hPsi(hpsi_in); // calculate the related elements in hcc - gemv_op()(ctx, 'C', psi_nc, nstart, &one, psi, psi_nc, hpsi, 1, &zero, hcc + i * nstart, 1); + ModuleBase::gemv_op()(ctx, 'C', psi_nc, nstart, &one, psi, psi_nc, hpsi, 1, &zero, hcc + i * nstart, 1); } T* spsi = temp; @@ -232,18 +232,18 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* syncmem_complex_op()(ppsi, psi + i * psi_nc, psi_nc); pHamilt->sPsi(ppsi, spsi, dmin, dmin, 1); - gemv_op()(ctx, - 'C', - psi_nc, - nstart, - &one, - psi, - psi_nc, // nbasis - spsi, - 1, - &zero, - scc + i * nstart, - 1); + ModuleBase::gemv_op()(ctx, + 'C', + psi_nc, + nstart, + &one, + psi, + psi_nc, // nbasis + spsi, + 1, + &zero, + scc + i * nstart, + 1); } delmem_complex_op()(temp); } @@ -264,13 +264,13 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); pHamilt->ops->hPsi(hpsi_in); - gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, hpsi, dmax, &zero, hcc, nstart); + ModuleBase::gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, hpsi, dmax, &zero, hcc, nstart); T* spsi = temp; // do sPsi for all bands pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_nbasis(), psi_temp.get_nbands()); - gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); + ModuleBase::gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); delmem_complex_op()(temp); add_to_hcc(hcc, nstart); @@ -315,20 +315,20 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* // because psi and evc are different here, // I think if psi and evc are the same, // there may be problems, mohan 2011-01-01 - gemm_op()(ctx, - 'N', - 'N', - dmax, - n_band, - nstart, - &one, - psi, // dmax * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - evc.get_pointer(), - dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmax, + n_band, + nstart, + &one, + psi, // dmax * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + evc.get_pointer(), + dmax); } else { @@ -338,20 +338,20 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* // resmem_complex_op()(ctx, evctemp, n_band * dmin, "DiagSub::evctemp"); // setmem_complex_op()(ctx, evctemp, 0, n_band * dmin); - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - psi, // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - evc.get_pointer(), - dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + psi, // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + evc.get_pointer(), + dmax); // matrixSetToAnother()(ctx, n_band, evctemp, dmin, evc.get_pointer(), dmax); @@ -442,39 +442,39 @@ void DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt hpsi_info hpsi_in(&psi, all_bands_range, hphi); pHamilt->ops->hPsi(hpsi_in); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - hphi, - dmax, - &zero, - hcc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + hphi, + dmax, + &zero, + hcc, + nstart); T* sphi = temp; // do sPsi for all bands pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - sphi, - dmax, - &zero, - scc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + sphi, + dmax, + &zero, + scc, + nstart); } if (GlobalV::NPROC_IN_POOL > 1) @@ -509,20 +509,20 @@ void DiagoIterAssist::diag_responce( const T* hcc, DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc); { // code block to calculate tar_mat - gemm_op()(ctx, - 'N', - 'N', - mat_col, - nstart, - nstart, - &one, - mat_in, // mat_col * nstart - mat_col, - vcc, // nstart * nstart - nstart, - &zero, - mat_out, - mat_col); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + mat_col, + nstart, + nstart, + &one, + mat_in, // mat_col * nstart + mat_col, + vcc, // nstart * nstart + nstart, + &zero, + mat_out, + mat_col); } delmem_complex_op()(vcc); @@ -557,21 +557,21 @@ void DiagoIterAssist::diag_subspace_psi(const T* hcc, T* temp = nullptr; resmem_complex_op()(temp, nstart * dmax, "DiagSub::temp"); setmem_complex_op()(temp, 0, nstart * dmax); - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - evc.get_pointer(), // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - temp, - dmin); - matrixSetToAnother()(ctx, n_band, temp, dmin, evc.get_pointer(), dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + evc.get_pointer(), // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + temp, + dmin); + ModuleBase::matrixSetToAnother()(ctx, n_band, temp, dmin, evc.get_pointer(), dmax); delmem_complex_op()(temp); } diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 68075fc111..d03b37b848 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -60,7 +60,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, #ifdef __MPI if (nbands > 0 && PARAM.inp.bndpar > 1) { - Parallel_Common::bcast_dev(this->ctx, &psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); + Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); } #endif diff --git a/source/module_hsolver/kernels/test/CMakeLists.txt b/source/module_hsolver/kernels/test/CMakeLists.txt index c8d1f2cdd9..5fe6bf4a24 100644 --- a/source/module_hsolver/kernels/test/CMakeLists.txt +++ b/source/module_hsolver/kernels/test/CMakeLists.txt @@ -5,13 +5,7 @@ if(USE_CUDA OR USE_ROCM) AddTest( TARGET Hsolver_Kernels_UTs LIBS parameter ${math_libs} base device - SOURCES math_kernel_test.cpp math_dngvd_test.cpp - ) -elseif() - AddTest( - TARGET Hsolver_Kernels_UTs - LIBS parameter ${math_libs} base device - SOURCES math_kernel_test.cpp ../../../module_base/blas_connector.cpp + SOURCES math_dngvd_test.cpp ) endif() diff --git a/source/module_hsolver/kernels/test/math_dngvd_test.cpp b/source/module_hsolver/kernels/test/math_dngvd_test.cpp index a67b18d4be..d8f2376890 100644 --- a/source/module_hsolver/kernels/test/math_dngvd_test.cpp +++ b/source/module_hsolver/kernels/test/math_dngvd_test.cpp @@ -2,7 +2,7 @@ #include "module_base/lapack_connector.h" #include "module_base/module_device/memory_op.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -144,13 +144,13 @@ TEST_F(TestModuleHsolverMathDngvd, transpose_gpu) synchronize_memory_op_C2G_Z()(device_transpose, transpose.data(), transpose.size()); // run - hsolver::createGpuBlasHandle(); - hsolver::matrixTranspose_op, base_device::DEVICE_GPU>()(gpu_ctx, + ModuleBase::createGpuBlasHandle(); + ModuleBase::matrixTranspose_op, base_device::DEVICE_GPU>()(gpu_ctx, 2, 3, device_transpose, device_transpose); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // copy transpose data from GPU to CPU std::vector> transpose_result = { diff --git a/source/module_hsolver/kernels/test/perf_math_kernel.cpp b/source/module_hsolver/kernels/test/perf_math_kernel.cpp index b2b0704a9d..e0a955ccb5 100644 --- a/source/module_hsolver/kernels/test/perf_math_kernel.cpp +++ b/source/module_hsolver/kernels/test/perf_math_kernel.cpp @@ -1,7 +1,7 @@ #include "module_base/blas_connector.h" #include "module_base/constants.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -114,7 +114,7 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture { resize_memory_op_double()(test_dvector_a_gpu, dim_vector); synchronize_memory_op_double()(test_dvector_a_gpu, test_dvector_a, dim_vector); - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM @@ -125,36 +125,36 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture { delete[] result_zvector; delete[] test_dvector_a; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM } // OPs need benchmark // CPU operator - using zdot_real_cpu_op = hsolver::dot_real_op, base_device::DEVICE_CPU>; + using zdot_real_cpu_op = ModuleBase::dot_real_op, base_device::DEVICE_CPU>; - using vector_div_constant_op_cpu = hsolver::vector_div_constant_op, base_device::DEVICE_CPU>; - using vector_mul_vector_op_cpu = hsolver::vector_mul_vector_op, base_device::DEVICE_CPU>; - using vector_div_vector_op_cpu = hsolver::vector_div_vector_op, base_device::DEVICE_CPU>; + using vector_div_constant_op_cpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_CPU>; + using vector_mul_vector_op_cpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_CPU>; + using vector_div_vector_op_cpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_CPU>; using constantvector_addORsub_constantVector_op_cpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; - using axpy_op_cpu = hsolver::axpy_op, base_device::DEVICE_CPU>; - using scal_op_cpu = hsolver::scal_op; - using gemv_op_cpu = hsolver::gemv_op, base_device::DEVICE_CPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; + using axpy_op_cpu = ModuleBase::axpy_op, base_device::DEVICE_CPU>; + using scal_op_cpu = ModuleBase::scal_op; + using gemv_op_cpu = ModuleBase::gemv_op, base_device::DEVICE_CPU>; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM // GPU operator - using zdot_real_gpu_op = hsolver::dot_real_op, base_device::DEVICE_GPU>; + using zdot_real_gpu_op = ModuleBase::dot_real_op, base_device::DEVICE_GPU>; - using vector_div_constant_op_gpu = hsolver::vector_div_constant_op, base_device::DEVICE_GPU>; - using vector_mul_vector_op_gpu = hsolver::vector_mul_vector_op, base_device::DEVICE_GPU>; - using vector_div_vector_op_gpu = hsolver::vector_div_vector_op, base_device::DEVICE_GPU>; + using vector_div_constant_op_gpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_GPU>; + using vector_mul_vector_op_gpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_GPU>; + using vector_div_vector_op_gpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_GPU>; using constantvector_addORsub_constantVector_op_gpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; - using axpy_op_gpu = hsolver::axpy_op, base_device::DEVICE_GPU>; - using scal_op_gpu = hsolver::scal_op; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; + using axpy_op_gpu = ModuleBase::axpy_op, base_device::DEVICE_GPU>; + using scal_op_gpu = ModuleBase::scal_op; #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM }; diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index fdb447a09d..e44171912c 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -114,7 +114,6 @@ if (ENABLE_MPI) TARGET HSolver_LCAO_cusolver LIBS parameter ${math_libs} base psi device SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_scalapack.cpp - ../kernels/math_kernel_op.cpp ../kernels/dngvd_op.cpp ../kernels/cuda/diag_cusolver.cu ) diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 8978334106..e6af8b5b5e 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -144,7 +144,7 @@ class DiagoBPCGPrepare base_device::DEVICE_CPU *ctx = {}; // hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec) - hsolver::gemm_op()( + ModuleBase::gemm_op()( ctx, 'N', 'N', dim, nvec, dim, one_, diff --git a/source/module_lr/operator_casida/operator_lr_diag.h b/source/module_lr/operator_casida/operator_lr_diag.h index 99a61d90df..a739b81991 100644 --- a/source/module_lr/operator_casida/operator_lr_diag.h +++ b/source/module_lr/operator_casida/operator_lr_diag.h @@ -1,6 +1,6 @@ #pragma once #include "module_lr/utils/lr_util.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_general/operator.h" #ifdef __MPI #include "module_base/parallel_common.h" @@ -46,7 +46,7 @@ namespace LR const bool is_first_node = false)const override { ModuleBase::TITLE("OperatorLRDiag", "act"); - hsolver::vector_mul_vector_op()(this->ctx, + ModuleBase::vector_mul_vector_op()(this->ctx, nk * pX.get_local_size(), // local size of particle-hole basis hpsi, psi_in,