Skip to content

Commit

Permalink
Refactor: remove do_after_converge
Browse files Browse the repository at this point in the history
  • Loading branch information
YuLiu98 committed Aug 8, 2024
1 parent 1bea224 commit 0ae2e6b
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 114 deletions.
6 changes: 0 additions & 6 deletions source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,6 @@ void ElecState::print_etot(const bool converged,
{FmtTable::Align::LEFT, FmtTable::Align::CENTER});
table << titles << energies_Ry << energies_eV;
GlobalV::ofs_running << table.str() << std::endl;
if (iter_in == 1) // pengfei Li added 2015-1-31
{
this->f_en.etot_old = this->f_en.etot;
}
this->f_en.etot_delta = this->f_en.etot - this->f_en.etot_old;
if (GlobalV::OUT_LEVEL == "ie" || GlobalV::OUT_LEVEL == "m") // xiaohui add 'm' option, 2015-09-16
{
std::vector<double> mag;
Expand Down Expand Up @@ -462,7 +457,6 @@ void ElecState::print_etot(const bool converged,
duration,
6);
}
this->f_en.etot_old = this->f_en.etot;
return;
}

Expand Down
34 changes: 23 additions & 11 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,12 +563,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
}
this->print_iter(iter, drho, dkin, duration, diag_ethr);

// add a energy threshold for SCF convergence
if (this->conv_elec == 0) // only check when density is not converged
{
this->conv_elec = ( iter != 1 && std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < this->scf_ene_thr );
}

// 12) Json, need to be moved to somewhere else
#ifdef __RAPIDJSON
// add Json of scf mag
Expand All @@ -584,11 +578,7 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
if (this->conv_elec)
{
this->niter = iter;
bool stop = this->do_after_converge(iter);
if (stop)
{
break;
}
break;
}

// notice for restart
Expand All @@ -615,6 +605,28 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
return;
};

template <typename T, typename Device>
void ESolver_KS<T, Device>::iter_finish(const int iter)
{
// 1 means Harris-Foulkes functional
// 2 means Kohn-Sham functional
this->pelec->cal_energies(2);

if (iter == 1)
{
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
}
this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old;
this->pelec->f_en.etot_old = this->pelec->f_en.etot;

// add a energy threshold for SCF convergence
if (this->conv_elec == 0) // only check when density is not converged
{
this->conv_elec
= (iter != 1 && std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < this->scf_ene_thr);
}
}

//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
template <typename T, typename Device>
void ESolver_KS<T, Device>::after_scf(const int istep)
Expand Down
9 changes: 3 additions & 6 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,15 @@ class ESolver_KS : public ESolver_FP
virtual void iter_init(const int istep, const int iter) {};

//! Something to do after hamilt2density function in each iter loop.
virtual void iter_finish(const int iter) {};
virtual void iter_finish(const int iter);

//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
virtual void after_scf(const int istep);

//! <Temporary> It should be replaced by a function in Hamilt Class
virtual void update_pot(const int istep, const int iter) {};

//! choose strategy when charge density convergence achieved
virtual bool do_after_converge(int& iter){return true;}

protected:
protected:

// Print the headline on the screen:
// ITER ETOT(eV) EDIFF(eV) DRHO TIME(s)
Expand Down
79 changes: 28 additions & 51 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,13 +916,15 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
//! 3) output exx matrix
//! 4) output charge density and density matrix
//! 5) cal_MW? (why put it here?)
//! 6) calculate the total energy?
//------------------------------------------------------------------------------
template <typename TK, typename TR>
void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)
{
ModuleBase::TITLE("ESolver_KS_LCAO", "iter_finish");

// call iter_finish() of ESolver_KS
ESolver_KS<TK>::iter_finish(iter);

// 1) mix density matrix if mixing_restart + mixing_dmr + not first
// mixing_restart at every iter
if (GlobalV::MIXING_RESTART > 0 && this->p_chgmix->mixing_restart_count > 0 && GlobalV::MIXING_DMR)
Expand Down Expand Up @@ -980,16 +982,30 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)
GlobalC::restart.save_disk("Eexx", 0, 1, &this->pelec->f_en.exx);
}
}
#endif

// 4) output charge density and density matrix
bool print = false;
if (this->out_freq_elec && iter % this->out_freq_elec == 0)
if (GlobalC::exx_info.info_global.cal_exx && this->conv_elec)
{
print = true;
if (GlobalC::exx_info.info_ri.real_number)
{
this->conv_elec = this->exd->exx_after_converge(
*this->p_hamilt,
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
this->kv,
iter);
}
else
{
this->conv_elec = this->exc->exx_after_converge(
*this->p_hamilt,
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
this->kv,
iter);
}
}
#endif

if (print)
// 4) output charge density and density matrix
if (this->out_freq_elec && iter % this->out_freq_elec == 0)
{
for (int is = 0; is < GlobalV::NSPIN; is++)
{
Expand Down Expand Up @@ -1055,8 +1071,11 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)
sc.cal_MW(iter, this->p_hamilt);
}

// 6) calculate the total energy.
this->pelec->cal_energies(2);
// 6) use the converged occupation matrix for next MD/Relax SCF calculation
if (GlobalV::dft_plus_u && this->conv_elec)
{
GlobalC::dftu.initialed_locale = true;
}
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -1235,48 +1254,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
}
}

//------------------------------------------------------------------------------
//! the 15th function of ESolver_KS_LCAO: do_after_converge
//! mohan add 2024-05-11
//------------------------------------------------------------------------------
template <typename TK, typename TR>
bool ESolver_KS_LCAO<TK, TR>::do_after_converge(int& iter)
{
ModuleBase::TITLE("ESolver_KS_LCAO", "do_after_converge");

if (GlobalV::dft_plus_u)
{
// use the converged occupation matrix for next MD/Relax SCF calculation
GlobalC::dftu.initialed_locale = true;
}
// FIXME: for developer who want to test restarting DeePKS with same Descriptor/PDM in last MD step
// RUN: " GlobalC::ld.set_init_pdm(true); " can skip the calculation of PDM in the next iter_init

#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx)
{
if (GlobalC::exx_info.info_ri.real_number) {
return this->exd->exx_after_converge(
*this->p_hamilt,
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)
->get_DM(),
this->kv,
iter);
}
else {
return this->exc->exx_after_converge(
*this->p_hamilt,
*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)
->get_DM(),
this->kv,
iter);
}
}
#endif // __EXX

return true;
}

//------------------------------------------------------------------------------
//! the 20th,21th,22th functions of ESolver_KS_LCAO
//! mohan add 2024-05-11
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {

virtual void after_scf(const int istep) override;

virtual bool do_after_converge(int& iter) override;

virtual void others(const int istep) override;

// we will get rid of this class soon, don't use it, mohan 2024-03-28
Expand Down
41 changes: 22 additions & 19 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,13 @@ namespace ModuleESolver
ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density");
}

#ifdef __EXX
template <typename T>
bool ESolver_KS_LIP<T>::do_after_converge(int& iter)
void ESolver_KS_LIP<T>::iter_finish(const int iter)
{
if (GlobalC::exx_info.info_global.cal_exx)
ESolver_KS_PW<T>::iter_finish(iter);

#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx && this->conv_elec)
{
// no separate_loop case
if (!GlobalC::exx_info.info_global.separate_loop)
Expand All @@ -226,24 +228,24 @@ namespace ModuleESolver
// in first scf loop, exx updated once in beginning,
// in second scf loop, exx updated every iter

if (this->two_level_step) {
return true;
} else
if (!this->two_level_step)
{
// update exx and redo scf
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].ncpp.xc_func);
iter = 0;
std::cout << " Entering 2nd SCF, where EXX is updated" << std::endl;
this->two_level_step++;
return false;
this->conv_elec = false;
}
}
// has separate_loop case
// exx converged or get max exx steps
else if (this->two_level_step == GlobalC::exx_info.info_global.hybrid_step
|| (iter == 1 && this->two_level_step != 0)) {
return true;
} else
|| (iter == 1 && this->two_level_step != 0))
{
this->conv_elec = true;
}
else
{
// update exx and redo scf
if (this->two_level_step == 0)
Expand All @@ -252,23 +254,24 @@ namespace ModuleESolver
}

std::cout << " Updating EXX " << std::flush;
timeval t_start; gettimeofday(&t_start, nullptr);
timeval t_start;
gettimeofday(&t_start, nullptr);

this->exx_lip->cal_exx();
iter = 0;
this->two_level_step++;

timeval t_end; gettimeofday(&t_end, nullptr);
std::cout << "and rerun SCF\t"
<< std::setprecision(3) << std::setiosflags(std::ios::scientific)
<< (double)(t_end.tv_sec - t_start.tv_sec) + (double)(t_end.tv_usec - t_start.tv_usec) / 1000000.0
<< std::defaultfloat << " (s)" << std::endl;
return false;
timeval t_end;
gettimeofday(&t_end, nullptr);
std::cout << "and rerun SCF\t" << std::setprecision(3) << std::setiosflags(std::ios::scientific)
<< (double)(t_end.tv_sec - t_start.tv_sec)
+ (double)(t_end.tv_usec - t_start.tv_usec) / 1000000.0
<< std::defaultfloat << " (s)" << std::endl;
this->conv_elec = false;
}
}
else { return true; }
}
#endif
}

template <typename T>
void ESolver_KS_LIP<T>::after_all_runners()
Expand Down
13 changes: 6 additions & 7 deletions source/module_esolver/esolver_ks_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,19 @@ namespace ModuleESolver
virtual void hamilt2density(const int istep, const int iter, const double ethr) override;

void before_all_runners(const Input_para& inp, UnitCell& cell) override;
void iter_init(const int istep, const int iter) override;
void after_all_runners()override;

protected:

virtual void allocate_hsolver() override;
virtual void deallocate_hsolver() override;
virtual void allocate_hamilt() override;
virtual void deallocate_hamilt() override;
virtual void iter_init(const int istep, const int iter) override;
virtual void iter_finish(const int iter) override;
virtual void allocate_hsolver() override;
virtual void deallocate_hsolver() override;
virtual void allocate_hamilt() override;
virtual void deallocate_hamilt() override;

#ifdef __EXX
std::unique_ptr<Exx_Lip<T>> exx_lip;
int two_level_step = 0;
virtual bool do_after_converge(int& iter) override;
#endif

};
Expand Down
14 changes: 3 additions & 11 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ void ESolver_KS_PW<T, Device>::update_pot(const int istep, const int iter)
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::iter_finish(const int iter)
{
// call iter_finish() of ESolver_KS
ESolver_KS<T, Device>::iter_finish(iter);

// liuyu 2023-10-24
// D in uspp need vloc, thus needs update when veff updated
// calculate the effective coefficient matrix for non-local pseudopotential
Expand All @@ -465,18 +468,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(const int iter)
GlobalC::ppcell.cal_effective_D(veff, this->pw_rhod, GlobalC::ucell);
}

// 1 means Harris-Foulkes functional
// 2 means Kohn-Sham functional
const int energy_type = 2;
this->pelec->cal_energies(2);

bool print = false;
if (this->out_freq_elec && iter % this->out_freq_elec == 0)
{
print = true;
}

if (print == true)
{
if (PARAM.inp.out_chg > 0)
{
Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ void ESolver_SDFT_PW::before_scf(const int istep)

void ESolver_SDFT_PW::iter_finish(int iter)
{
this->pelec->cal_energies(2);
// call iter_finish() of ESolver_KS
ESolver_KS<std::complex<double>>::iter_finish(iter);
}

void ESolver_SDFT_PW::after_scf(const int istep)
Expand Down

0 comments on commit 0ae2e6b

Please sign in to comment.