Skip to content

Commit

Permalink
Refactor: remove pdiagh pointer and redundant code (#4520)
Browse files Browse the repository at this point in the history
* rafactor: remove pdiagh pointer and redundant code

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Cstandardlib and pre-commit-ci-lite[bot] authored Jun 29, 2024
1 parent cf211d2 commit dc03a50
Showing 1 changed file with 1 addition and 161 deletions.
162 changes: 1 addition & 161 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,140 +35,7 @@ HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pw
template <typename T, typename Device>
void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi)
{
if (this->method == "cg")
{
// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete reinterpret_cast<DiagoCG<T, Device>*>(this->pdiagh);
// }
// else
// {
// return;
// }
// }

// this->pdiagh = new DiagoCG<T, Device>(precondition.data());

// // warp the subspace_func into a lambda function
// auto ngk_pointer = psi.get_ngk_pointer();
// auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// // psi_in should be a 2D tensor:
// // psi_in.shape() = [nbands, nbasis]
// const auto ndim = psi_in.shape().ndim();
// REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2");
// // Convert a Tensor object to a psi::Psi object
// auto psi_in_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
// 1,
// psi_in.shape().dim_size(0),
// psi_in.shape().dim_size(1),
// ngk_pointer);
// auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
// 1,
// psi_out.shape().dim_size(0),
// psi_out.shape().dim_size(1),
// ngk_pointer);
// auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
// ct::DeviceType::CpuDevice,
// ct::TensorShape({psi_in.shape().dim_size(0)}));

// DiagoIterAssist<T, Device>::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
// };
// this->pdiagh = new DiagoCG<T, Device>(GlobalV::BASIS_TYPE,
// GlobalV::CALCULATION,
// DiagoIterAssist<T, Device>::need_subspace,
// subspace_func,
// DiagoIterAssist<T, Device>::PW_DIAG_THR,
// DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
// GlobalV::NPROC_IN_POOL);
// this->pdiagh->method = this->method;
}
else if (this->method == "dav")
{
// #ifdef __MPI
// const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #else
// const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #endif

// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete (DiagoDavid<T, Device>*)this->pdiagh;

// this->pdiagh = new DiagoDavid<T, Device>(precondition.data(),
// GlobalV::PW_DIAG_NDIM,
// GlobalV::use_paw,
// comm_info);

// this->pdiagh->method = this->method;
// }
// }
// else
// {
// this->pdiagh
// = new DiagoDavid<T, Device>(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw,
// comm_info);

// this->pdiagh->method = this->method;
// }
}
else if (this->method == "dav_subspace")
{
// #ifdef __MPI
// const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #else
// const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #endif
// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete (Diago_DavSubspace<T, Device>*)this->pdiagh;

// this->pdiagh = new Diago_DavSubspace<T, Device>(precondition.data(),
// GlobalV::PW_DIAG_NDIM,
// DiagoIterAssist<T, Device>::PW_DIAG_THR,
// DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
// DiagoIterAssist<T, Device>::need_subspace,
// comm_info);

// this->pdiagh->method = this->method;
// }
// }
// else
// {
// this->pdiagh = new Diago_DavSubspace<T, Device>(precondition.data(),
// GlobalV::PW_DIAG_NDIM,
// DiagoIterAssist<T, Device>::PW_DIAG_THR,
// DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
// DiagoIterAssist<T, Device>::need_subspace,
// comm_info);
// this->pdiagh->method = this->method;
// }
}
else if (this->method == "bpcg")
{
// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete (DiagoBPCG<T, Device>*)this->pdiagh;
// this->pdiagh = new DiagoBPCG<T, Device>(precondition.data());
// this->pdiagh->method = this->method;
// reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi);
// }
// }
// else
// {
// this->pdiagh = new DiagoBPCG<T, Device>(precondition.data());
// this->pdiagh->method = this->method;
// reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh)->init_iter(psi);
// }
}
else
if (this->method != "cg" && this->method != "dav" && this->method != "dav_subspace" && this->method != "bpcg")
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}
Expand Down Expand Up @@ -656,29 +523,6 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt, // ESolver_
template <typename T, typename Device>
void HSolverPW<T, Device>::endDiagh()
{
// DiagoCG would keep 9*nbasis memory in cache during loop-k
// it should be deleted before calculating charge
// if (this->method == "cg")
// {
// delete reinterpret_cast<DiagoCG<T, Device>*>(this->pdiagh);
// this->pdiagh = nullptr;
// }
// if (this->method == "dav")
// {
// delete reinterpret_cast<DiagoDavid<T, Device>*>(this->pdiagh);
// this->pdiagh = nullptr;
// }
// if (this->method == "dav_subspace")
// {
// delete reinterpret_cast<Diago_DavSubspace<T, Device>*>(this->pdiagh);
// this->pdiagh = nullptr;
// }
// if (this->method == "bpcg")
// {
// delete reinterpret_cast<DiagoBPCG<T, Device>*>(this->pdiagh);
// this->pdiagh = nullptr;
// }

// in PW base, average iteration steps for each band and k-point should be printing
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
{
Expand Down Expand Up @@ -841,7 +685,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);

bool scf;
if (GlobalV::CALCULATION == "nscf")
{
Expand Down Expand Up @@ -893,12 +736,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace
.diag(hpsi_func, subspace_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));

this->pdiagh = nullptr;
}
else if (this->method == "bpcg")
{
// this->pdiagh->diag(hm, psi, eigenvalue);
DiagoBPCG<T, Device> bpcg(precondition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
Expand Down

0 comments on commit dc03a50

Please sign in to comment.