Skip to content

Commit

Permalink
Refactor: update md parameters to apply const param_in
Browse files Browse the repository at this point in the history
  • Loading branch information
YuLiu98 committed Jul 26, 2024
1 parent 7a6fe6d commit f5ca124
Show file tree
Hide file tree
Showing 29 changed files with 220 additions and 206 deletions.
14 changes: 9 additions & 5 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,17 @@ void Driver::driver_run() {
const std::string cal_type = GlobalV::CALCULATION;

//! 4: different types of calculations
if (cal_type == "md") {
Run_MD::md_line(GlobalC::ucell, p_esolver, INPUT.mdp);
} else if (cal_type == "scf" || cal_type == "relax"
|| cal_type == "cell-relax") {
if (cal_type == "md")
{
Run_MD::md_line(GlobalC::ucell, p_esolver, PARAM);
}
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax")
{
Relax_Driver rl_driver;
rl_driver.relax_driver(p_esolver);
} else {
}
else
{
//! supported "other" functions:
//! nscf(PW,LCAO),
//! get_pchg(LCAO),
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_lcao/hamilt_lcaodft/LCAO_set_st.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "module_base/timer.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h" // only for INPUT
#include "module_parameter/parameter.h"

namespace LCAO_domain
{
Expand Down Expand Up @@ -362,7 +362,7 @@ void build_ST_new(ForceStressArrays& fsr,
{
for (int k = 0; k < 3; k++)
{
tau1[k] = tau1[k] - atom1->vel[I1][k] * INPUT.mdp.md_dt / ucell.lat0;
tau1[k] = tau1[k] - atom1->vel[I1][k] * PARAM.mdp.md_dt / ucell.lat0;
}
}

Expand Down
18 changes: 9 additions & 9 deletions source/module_hamilt_lcao/module_tddft/propagator.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "propagator.h"

#include <complex>
#include <iostream>

#include "module_base/lapack_connector.h"
#include "module_base/scalapack_connector.h"
#include "module_io/input.h"
#include "module_parameter/parameter.h"

#include <complex>
#include <iostream>

namespace module_tddft
{
Expand Down Expand Up @@ -97,11 +97,11 @@ void Propagator::compute_propagator_cn2(const int nlocal,

// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// (2) compute Numerator & Denominator by GEADD
// Numerator = Stmp - i*para * Htmp; beta1 = - para = -0.25 * INPUT.mdp.md_dt
// Denominator = Stmp + i*para * Htmp; beta2 = para = 0.25 * INPUT.mdp.md_dt
// Numerator = Stmp - i*para * Htmp; beta1 = - para = -0.25 * PARAM.mdp.md_dt
// Denominator = Stmp + i*para * Htmp; beta2 = para = 0.25 * PARAM.mdp.md_dt
std::complex<double> alpha = {1.0, 0.0};
std::complex<double> beta1 = {0.0, -0.25 * INPUT.mdp.md_dt};
std::complex<double> beta2 = {0.0, 0.25 * INPUT.mdp.md_dt};
std::complex<double> beta1 = {0.0, -0.25 * PARAM.mdp.md_dt};
std::complex<double> beta2 = {0.0, 0.25 * PARAM.mdp.md_dt};

ScalapackConnector::geadd('N',
nlocal,
Expand Down Expand Up @@ -350,7 +350,7 @@ void Propagator::compute_propagator_taylor(const int nlocal,
} // loop ipcol
} // loop iprow

std::complex<double> beta = {0.0, -0.5 * INPUT.mdp.md_dt / tag}; // for ETRS tag=2 , for taylor tag=1
std::complex<double> beta = {0.0, -0.5 * PARAM.mdp.md_dt / tag}; // for ETRS tag=2 , for taylor tag=1

//->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// invert Stmp
Expand Down
4 changes: 1 addition & 3 deletions source/module_io/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ class Input
}

// They will be removed.
int cond_dtbatch;
int cond_dtbatch;
int nche_sto;
double md_tfirst;
MD_para mdp;
int* orbital_corr = nullptr; ///< which correlated orbitals need corrected ;
double* hubbard_u = nullptr; ///< Hubbard Coulomb interaction parameter U(ev)
std::string stru_file; // file contains atomic positions -- xiaohui modify
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ void Input_Conv::Convert()
if (PARAM.inp.calculation == "md" && PARAM.mdp.md_restart) // md restart liuyu add 2023-04-12
{
int istep = 0;
MD_func::current_md_info(GlobalV::MY_RANK, GlobalV::global_readin_dir, istep, INPUT.md_tfirst);
INPUT.md_tfirst *= ModuleBase::Hartree_to_K;
double temperature = 0.0;
MD_func::current_md_info(GlobalV::MY_RANK, GlobalV::global_readin_dir, istep, temperature);
if (PARAM.inp.read_file_dir == "auto")
{
GlobalV::stru_file = INPUT.stru_file = GlobalV::global_stru_dir + "STRU_MD_" + std::to_string(istep);
Expand Down
1 change: 0 additions & 1 deletion source/module_io/input_conv_tmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ void Input_Conv::tmp_convert()
INPUT.stru_file = PARAM.inp.stru_file;
INPUT.cond_dtbatch = PARAM.inp.cond_dtbatch;
INPUT.nche_sto = PARAM.inp.nche_sto;
INPUT.mdp = PARAM.mdp;

const int ntype = PARAM.inp.ntype;
delete[] INPUT.orbital_corr;
Expand Down
26 changes: 14 additions & 12 deletions source/module_io/print_info.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "print_info.h"
#include "module_io/input.h"
#include "../module_base/global_variable.h"
//#include "../module_cell/klist.h"

#include "module_base/global_variable.h"
#include "module_parameter/parameter.h"

Print_Info::Print_Info(){}

Expand Down Expand Up @@ -38,32 +38,34 @@ void Print_Info::setup_parameters(UnitCell &ucell, K_Vectors &kv)

std::cout << " ---------------------------------------------------------" << std::endl;

if(INPUT.mdp.md_type == "fire")
if (PARAM.mdp.md_type == "fire")
{
std::cout << " ENSEMBLE : " << "FIRE" << std::endl;
}
else if(INPUT.mdp.md_type == "nve")
else if (PARAM.mdp.md_type == "nve")
{
std::cout << " ENSEMBLE : " << "NVE" << std::endl;
}
else if(INPUT.mdp.md_type == "nvt")
else if (PARAM.mdp.md_type == "nvt")
{
std::cout << " ENSEMBLE : " << "NVT mode: " << INPUT.mdp.md_thermostat << std::endl;
std::cout << " ENSEMBLE : "
<< "NVT mode: " << PARAM.mdp.md_thermostat << std::endl;
}
else if(INPUT.mdp.md_type == "npt")
else if (PARAM.mdp.md_type == "npt")
{
std::cout << " ENSEMBLE : " << "NPT mode: " << INPUT.mdp.md_pmode << std::endl;
std::cout << " ENSEMBLE : "
<< "NPT mode: " << PARAM.mdp.md_pmode << std::endl;
}
else if(INPUT.mdp.md_type == "langevin")
else if (PARAM.mdp.md_type == "langevin")
{
std::cout << " ENSEMBLE : " << "Langevin" << std::endl;
}
else if(INPUT.mdp.md_type == "msst")
else if (PARAM.mdp.md_type == "msst")
{
std::cout << " ENSEMBLE : " << "MSST" << std::endl;
}

std::cout << " Time interval(fs) : " << INPUT.mdp.md_dt << std::endl;
std::cout << " Time interval(fs) : " << PARAM.mdp.md_dt << std::endl;
}
std::cout << " ---------------------------------------------------------" << std::endl;

Expand Down
12 changes: 12 additions & 0 deletions source/module_io/read_input_item_md.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ void ReadInput::item_md()
{
Input_Item item("md_tlast");
item.annotation = "temperature last";
item.reset_value = [](const Input_Item& item, Parameter& para) {
if (para.mdp.md_tlast < 0)
{
para.input.mdp.md_tlast = para.mdp.md_tfirst;
}
};
read_sync_double(input.mdp.md_tlast);
this->add_item(item);
}
Expand Down Expand Up @@ -296,6 +302,12 @@ void ReadInput::item_md()
{
Input_Item item("md_pcouple");
item.annotation = "whether couple different components: xyz, xy, yz, xz, none";
item.reset_value = [](const Input_Item& item, Parameter& para) {
if (para.mdp.md_pmode == "iso")
{
para.input.mdp.md_pcouple = "xyz";
}
};
read_sync_string(input.mdp.md_pcouple);
this->add_item(item);
}
Expand Down
4 changes: 0 additions & 4 deletions source/module_io/read_input_item_relax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,8 @@ void ReadInput::item_relax()
ModuleBase::WARNING("ReadInput", "both force_thr and force_thr_ev are set, use force_thr");
para.input.force_thr_ev = para.input.force_thr * 13.6058 / 0.529177;
}
para.input.mdp.force_thr = para.input.force_thr; // temperaory
};
sync_double(input.force_thr);
add_double_bcast(input.mdp.force_thr);
this->add_item(item);
}
{
Expand Down Expand Up @@ -295,10 +293,8 @@ void ReadInput::item_relax()
{
para.input.cal_stress = true;
}
para.input.mdp.cal_stress = para.input.cal_stress; // temperaory
};
read_sync_bool(input.cal_stress);
add_bool_bcast(input.mdp.cal_stress);
this->add_item(item);
}
{
Expand Down
25 changes: 13 additions & 12 deletions source/module_md/fire.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
#endif
#include "module_base/timer.h"

FIRE::FIRE(MD_para& MD_para_in, UnitCell& unit_in) : MD_base(MD_para_in, unit_in)
FIRE::FIRE(const Parameter& param_in, UnitCell& unit_in) : MD_base(param_in, unit_in)
{
force_thr = param_in.inp.force_thr;
dt_max = -1.0;
alpha_start = 0.10;
alpha = alpha_start;
Expand Down Expand Up @@ -82,18 +83,18 @@ void FIRE::print_md(std::ofstream& ofs, const bool& cal_stress)

void FIRE::write_restart(const std::string& global_out_dir)
{
if (!mdp.my_rank)
if (!my_rank)
{
std::stringstream ssc;
ssc << global_out_dir << "Restart_md.dat";
std::ofstream file(ssc.str().c_str());

file << step_ + step_rst_ << std::endl;
file << mdp.md_tfirst << std::endl;
file << md_tfirst << std::endl;
file << alpha << std::endl;
file << negative_count << std::endl;
file << dt_max << std::endl;
file << mdp.md_dt << std::endl;
file << md_dt << std::endl;
file.close();
}
#ifdef __MPI
Expand All @@ -108,7 +109,7 @@ void FIRE::restart(const std::string& global_readin_dir)
{
bool ok = true;

if (!mdp.my_rank)
if (!my_rank)
{
std::stringstream ssc;
ssc << global_readin_dir << "Restart_md.dat";
Expand All @@ -121,7 +122,7 @@ void FIRE::restart(const std::string& global_readin_dir)

if (ok)
{
file >> step_rst_ >> mdp.md_tfirst >> alpha >> negative_count >> dt_max >> mdp.md_dt;
file >> step_rst_ >> md_tfirst >> alpha >> negative_count >> dt_max >> md_dt;
file.close();
}
}
Expand All @@ -137,11 +138,11 @@ void FIRE::restart(const std::string& global_readin_dir)

#ifdef __MPI
MPI_Bcast(&step_rst_, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(&mdp.md_tfirst, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Bcast(&md_tfirst, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Bcast(&negative_count, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(&dt_max, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Bcast(&mdp.md_dt, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Bcast(&md_dt, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
#endif

return;
Expand All @@ -163,7 +164,7 @@ void FIRE::check_force(void)
}
}

if (2.0 * max < mdp.force_thr)
if (2.0 * max < force_thr)
{
stop = true;
}
Expand All @@ -181,7 +182,7 @@ void FIRE::check_fire(void)
/// initial dt_max
if (dt_max < 0)
{
dt_max = 2.5 * mdp.md_dt;
dt_max = 2.5 * md_dt;
}

for (int i = 0; i < ucell.nat; ++i)
Expand All @@ -207,13 +208,13 @@ void FIRE::check_fire(void)
negative_count++;
if (negative_count >= n_min)
{
mdp.md_dt = std::min(mdp.md_dt * finc, dt_max);
md_dt = std::min(md_dt * finc, dt_max);
alpha *= f_alpha;
}
}
else
{
mdp.md_dt *= fdec;
md_dt *= fdec;
negative_count = 0;

for (int i = 0; i < ucell.nat; ++i)
Expand Down
4 changes: 2 additions & 2 deletions source/module_md/fire.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
class FIRE : public MD_base
{
public:

FIRE(MD_para& MD_para_in, UnitCell& unit_in);
FIRE(const Parameter& param_in, UnitCell& unit_in);

~FIRE();

Expand Down Expand Up @@ -51,6 +50,7 @@ class FIRE : public MD_base
int n_min; ///< n_min
double dt_max; ///< dt_max
int negative_count; ///< Negative count
double force_thr = 1.0e-3; ///< force convergence threshold in FIRE method
};

#endif
12 changes: 6 additions & 6 deletions source/module_md/langevin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
#include "module_base/parallel_common.h"
#include "module_base/timer.h"

Langevin::Langevin(MD_para& MD_para_in, UnitCell& unit_in) : MD_base(MD_para_in, unit_in)
Langevin::Langevin(const Parameter& param_in, UnitCell& unit_in) : MD_base(param_in, unit_in)
{
/// convert to a.u. unit
assert(ModuleBase::AU_to_FS!=0.0);

mdp.md_damp /= ModuleBase::AU_to_FS;
md_damp = mdp.md_damp / ModuleBase::AU_to_FS;

assert(ucell.nat>0);

Expand Down Expand Up @@ -85,16 +85,16 @@ void Langevin::restart(const std::string& global_readin_dir)

void Langevin::post_force(void)
{
if (mdp.my_rank == 0)
if (my_rank == 0)
{
double t_target = MD_func::target_temp(step_ + step_rst_, mdp.md_nstep, mdp.md_tfirst, mdp.md_tlast);
double t_target = MD_func::target_temp(step_ + step_rst_, mdp.md_nstep, md_tfirst, md_tlast);
ModuleBase::Vector3<double> fictitious_force;
for (int i = 0; i < ucell.nat; ++i)
{
fictitious_force = -allmass[i] * vel[i] / mdp.md_damp;
fictitious_force = -allmass[i] * vel[i] / md_damp;
for (int j = 0; j < 3; ++j)
{
fictitious_force[j] += sqrt(24.0 * t_target * allmass[i] / mdp.md_damp / mdp.md_dt)
fictitious_force[j] += sqrt(24.0 * t_target * allmass[i] / md_damp / md_dt)
* (static_cast<double>(std::rand()) / RAND_MAX - 0.5);
}
total_force[i] = force[i] + fictitious_force;
Expand Down
3 changes: 2 additions & 1 deletion source/module_md/langevin.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class Langevin : public MD_base
{
public:
Langevin(MD_para& MD_para_in, UnitCell& unit_in);
Langevin(const Parameter& param_in, UnitCell& unit_in);

~Langevin();

Expand All @@ -38,6 +38,7 @@ class Langevin : public MD_base
void post_force();

ModuleBase::Vector3<double>* total_force; ///< total force = true force + Langevin fictitious_force
double md_damp; ///< damping factor
};

#endif
Loading

0 comments on commit f5ca124

Please sign in to comment.