Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
 into refactor
  • Loading branch information
YuLiu98 committed Aug 13, 2024
2 parents 7f5594e + a15a6f6 commit 34fbc26
Show file tree
Hide file tree
Showing 53 changed files with 611 additions and 428 deletions.
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ OBJS_HSOLVER=diago_cg.o\
diago_david.o\
diago_dav_subspace.o\
diago_bpcg.o\
hsolver.o\
hsolver_pw.o\
hsolver_lcaopw.o\
hsolver_pw_sdft.o\
Expand Down
23 changes: 20 additions & 3 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,15 +468,32 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
if (firstscf)
{
firstscf = false;
hsolver_error = this->phsol->cal_hsolerror(diag_ethr);
hsolver_error = hsolver::cal_hsolve_error(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
diag_ethr,
GlobalV::nelec);

// The error of HSolver is larger than drho,
// so a more precise HSolver should be excuconv_elected.
if (hsolver_error > drho)
{
diag_ethr = this->phsol->reset_diagethr(GlobalV::ofs_running, hsolver_error, drho, diag_ethr);
diag_ethr = hsolver::reset_diag_ethr(GlobalV::ofs_running,
PARAM.inp.basis_type,
PARAM.inp.esolver_type,
GlobalV::precision_flag,
hsolver_error,
drho,
diag_ethr,
GlobalV::nelec);

this->hamilt2density(istep, iter, diag_ethr);

drho = p_chgmix->get_drho(pelec->charge, GlobalV::nelec);
hsolver_error = this->phsol->cal_hsolerror(diag_ethr);

hsolver_error = hsolver::cal_hsolve_error(PARAM.inp.basis_type,
PARAM.inp.esolver_type,
diag_ethr,
GlobalV::nelec);
}
}
// mixing will restart at this->p_chgmix->mixing_restart steps
Expand Down
7 changes: 0 additions & 7 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,6 @@ namespace ModuleESolver
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density", "psig lifetime is expired");
}

// // from HSolverLIP
// this->phsol->solve(this->p_hamilt, // hamilt::Hamilt<T>* pHamilt,
// this->kspw_psi[0], // psi::Psi<T>& psi,
// this->pelec, // elecstate::ElecState<T>* pelec,
// psig.lock().get()[0]); // psi::Psi<T>& transform,


hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt,
this->kspw_psi[0],
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ list(APPEND objects
hsolver_lcaopw.cpp
hsolver_pw_sdft.cpp
diago_iter_assist.cpp

hsolver.cpp
)

if(ENABLE_LCAO)
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
this->device = base_device::get_device_type<Device>(this->ctx);
this->precondition = precondition_in;

this->one = &this->cs.one;
this->zero = &this->cs.zero;
this->neg_one = &this->cs.neg_one;
this->one = &one_;
this->zero = &zero_;
this->neg_one = &neg_one_;

test_david = 2;
// 1: check which function is called and which step is executed
Expand Down
5 changes: 2 additions & 3 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "diagh.h"
#include "module_hsolver/diag_comm_info.h"
#include "module_hsolver/diag_const_nums.h"

namespace hsolver
{
Expand Down Expand Up @@ -197,8 +196,8 @@ class DiagoDavid : public DiagH<T, Device>

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;

const_nums<T> cs;
const T* one = nullptr, * zero = nullptr, * neg_one = nullptr;
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);
};
} // namespace hsolver

Expand Down
157 changes: 157 additions & 0 deletions source/module_hsolver/hsolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include "hsolver.h"

namespace hsolver
{

double reset_diag_ethr(std::ofstream& ofs_running,
const std::string basis_type,
const std::string esolver_type,
const std::string precision_flag_in,
const double hsover_error,
const double drho_in,
const double diag_ethr_in,
const double nelec_in)
{

double new_diag_ethr = 0.0;

if (basis_type == "pw" && esolver_type == "ksdft")
{
ofs_running << " Notice: Threshold on eigenvalues was too large.\n";

ModuleBase::WARNING("scf", "Threshold on eigenvalues was too large.");

ofs_running << " hsover_error=" << hsover_error << " > DRHO=" << drho_in << std::endl;
ofs_running << " Origin diag ethr = " << diag_ethr_in << std::endl;

new_diag_ethr = 0.1 * drho_in / nelec_in;

// It is essential for single precision implementation to keep the diag ethr
// value less or equal to the single-precision limit of convergence(0.5e-4).
// modified by denghuilu at 2023-05-15
if (precision_flag_in == "single")
{
new_diag_ethr = std::max(new_diag_ethr, static_cast<double>(0.5e-4));
}
ofs_running << " New diag ethr = " << new_diag_ethr << std::endl;
}
else
{
new_diag_ethr = 0.0;
}

return new_diag_ethr;
};

double cal_hsolve_error(const std::string basis_type,
const std::string esolver_type,
const double diag_ethr_in,
const double nelec_in)
{
if (basis_type == "pw" && esolver_type == "ksdft")
{
return diag_ethr_in * static_cast<double>(std::max(1.0, nelec_in));
}
else
{
return 0.0;
}
};


// double set_diagethr(double diag_ethr_in,
// const int istep,
// const int iter,
// const double drho,
// std::string basis_type,
// std::string esolver_type)
// {
// if (basis_type = "pw" && esolver_type = "ksdft")
// {
// // It is too complex now and should be modified.
// if (iter == 1)
// {
// if (std::abs(diag_ethr_in - 1.0e-2) < 1.0e-6)
// {
// if (GlobalV::init_chg == "file")
// {
// //======================================================
// // if you think that the starting potential is good
// // do not spoil it with a louly first diagonalization:
// // set a strict diag ethr in the input file
// // ()diago_the_init
// //======================================================
// diag_ethr_in = 1.0e-5;
// }
// else
// {
// //=======================================================
// // starting atomic potential is probably far from scf
// // don't waste iterations in the first diagonalization
// //=======================================================
// diag_ethr_in = 1.0e-2;
// }
// }

// if (GlobalV::CALCULATION == "md" || GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax")
// {
// diag_ethr_in = std::max(diag_ethr_in, static_cast<double>(GlobalV::PW_DIAG_THR));
// }
// }
// else
// {
// if (iter == 2)
// {
// diag_ethr_in = 1.e-2;
// }
// diag_ethr_in = std::min(diag_ethr_in,
// static_cast<double>(0.1) * drho
// / std::max(static_cast<double>(1.0), static_cast<double>(GlobalV::nelec)));
// }
// // It is essential for single precision implementation to keep the diag ethr
// // value less or equal to the single-precision limit of convergence(0.5e-4).
// // modified by denghuilu at 2023-05-15
// if (GlobalV::precision_flag == "single")
// {
// diag_ethr_in = std::max(diag_ethr_in, static_cast<double>(0.5e-4));
// }
// }
// else if (basis_type = "pw" && esolver_type = "sdft")
// {
// if (iter == 1)
// {
// if (istep == 0)
// {
// if (GlobalV::init_chg == "file")
// {
// diag_ethr_in = 1.0e-5;
// }
// diag_ethr_in = std::max(diag_ethr_in, GlobalV::PW_DIAG_THR);
// }
// else
// {
// diag_ethr_in = std::max(diag_ethr_in, 1.0e-5);
// }
// }
// else
// {
// if (GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6)
// {
// diag_ethr_in = std::min(diag_ethr_in, 0.1 * drho / std::max(1.0, this->stoiter.KS_ne));
// }
// else
// {
// diag_ethr_in = 0.0;
// }
// }
// }
// else
// {
// diag_ethr_in = 0.0;
// }

// return 0.0;
// };


} // namespace hsolver
67 changes: 17 additions & 50 deletions source/module_hsolver/hsolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,7 @@ class HSolver
psi::Psi<T, Device>& ppsi,
elecstate::ElecState* pes,
const std::string method,
const bool skip_charge = false)
{
return;
}

// virtual void solve(hamilt::Hamilt<T, Device>* phm,
// psi::Psi<T, Device>& ppsi,
// elecstate::ElecState* pes,
// double* out_eigenvalues,
// const std::vector<bool>& is_occupied_in,
// const std::string method,
// const std::string calculation_type_in,
// const std::string basis_type_in,
// const bool use_paw_in,
// const bool use_uspp_in,
// const int rank_in_pool_in,
// const int nproc_in_pool_in,
// const int scf_iter_in,
// const bool need_subspace_in,
// const int diag_iter_max_in,
// const double pw_diag_thr_in,
// const bool skip_charge)
// {
// return;
// }

/// @brief solve function for lcao_in_pw
/// @param phm interface to hamilt
/// @param ppsi reference to psi
/// @param pes interface to elecstate
/// @param transform transformation matrix between lcao and pw
/// @param skip_charge
virtual void solve(hamilt::Hamilt<T, Device>* phm,
psi::Psi<T, Device>& ppsi,
elecstate::ElecState* pes,
psi::Psi<T, Device>& transform,
const bool skip_charge = false)
const bool skip_charge)
{
return;
}
Expand All @@ -80,7 +44,7 @@ class HSolver
const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,
const double pw_diag_thr_in,
const bool skip_charge)
{
return;
Expand All @@ -91,20 +55,23 @@ class HSolver
{
return 0.0;
}
};

// reset diagethr according to drho and hsolver_error
virtual Real reset_diagethr(std::ofstream& ofs_running, const Real hsover_error, const Real drho, Real diag_ethr_in)
{
return 0.0;
}
// reset diagethr according to drho and hsolver_error
double reset_diag_ethr(std::ofstream& ofs_running,
const std::string basis_type,
const std::string esolver_type,
const std::string precision_flag_in,
const double hsover_error,
const double drho_in,
const double diag_ethr_in,
const double nelec_in);

// calculate hsolver_error (for sdft, lcao and lcao-in-pw, we suppose the error is zero)
virtual Real cal_hsolerror(const Real diag_ethr_in)
{
return 0.0;
};

};
// calculate hsolver_error (for sdft, lcao and lcao-in-pw, we suppose the error is zero)
double cal_hsolve_error(const std::string basis_type,
const std::string esolver_type,
const double diag_ethr_in,
const double nelec_in);

} // namespace hsolver
#endif
34 changes: 0 additions & 34 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,13 +638,6 @@ void HSolverPW<T, Device>::output_iterInfo()
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
}

template <typename T, typename Device>
typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::cal_hsolerror(const Real diag_ethr_in)
{
return diag_ethr_in * static_cast<Real>(std::max(1.0, GlobalV::nelec));
}

template <typename T, typename Device>
typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::set_diagethr(Real diag_ethr_in,
const int istep,
Expand Down Expand Up @@ -702,33 +695,6 @@ typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::set_diagethr(Real diag
return diag_ethr_in;
}

template <typename T, typename Device>
typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::reset_diagethr(std::ofstream& ofs_running,
const Real hsover_error,
const Real drho,
Real diag_ethr_in)
{
ofs_running << " Notice: Threshold on eigenvalues was too large.\n";

ModuleBase::WARNING("scf", "Threshold on eigenvalues was too large.");

ofs_running << " hsover_error=" << hsover_error << " > DRHO=" << drho << std::endl;
ofs_running << " Origin diag ethr = " << diag_ethr_in << std::endl;

diag_ethr_in = 0.1 * drho / GlobalV::nelec;

// It is essential for single precision implementation to keep the diag ethr
// value less or equal to the single-precision limit of convergence(0.5e-4).
// modified by denghuilu at 2023-05-15
if (GlobalV::precision_flag == "single")
{
diag_ethr_in = std::max(diag_ethr_in, static_cast<Real>(0.5e-4));
}
ofs_running << " New diag ethr = " << diag_ethr_in << std::endl;

return diag_ethr_in;
}

template class HSolverPW<std::complex<float>, base_device::DEVICE_CPU>;
template class HSolverPW<std::complex<double>, base_device::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
Expand Down
Loading

0 comments on commit 34fbc26

Please sign in to comment.