Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
 into refactor
  • Loading branch information
YuLiu98 committed Nov 23, 2024
2 parents 9776828 + 9cc044e commit d5ea3af
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 154 deletions.
3 changes: 3 additions & 0 deletions source/module_base/vector3.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ template <class T> class Vector3
Vector3(const Vector3<T> &v) : x(v.x), y(v.y), z(v.z){}; // Peize Lin add 2018-07-16
explicit Vector3(const std::array<T,3> &v) :x(v[0]), y(v[1]), z(v[2]){}

template <typename U>
explicit Vector3(const Vector3<U>& other) : x(static_cast<T>(other.x)), y(static_cast<T>(other.y)), z(static_cast<T>(other.z)) {}

Vector3(Vector3<T> &&v) noexcept : x(v.x), y(v.y), z(v.z) {}

/**
Expand Down
4 changes: 0 additions & 4 deletions source/module_hamilt_lcao/module_gint/gint.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,16 @@ class Gint {

//! calculate local potential contribution to the Hamiltonian
//! na_grid: how many atoms on this (i,j,k) grid
//! block_iw: dim is [na_grid], index of wave function for each block
//! block_size: dim is [block_size], number of columns of a band
//! block_index: dim is [na_grid+1], total number of atomic orbitals
//! grid_index: index of grid group, for tracing iat
//! cal_flag: dim is [bxyz][na_grid], whether the atom-grid distance is larger than cutoff
//! psir_ylm: dim is [bxyz][LD_pool]
//! psir_vlbr3: dim is [bxyz][LD_pool]
//! hR: HContainer for storing the <phi_0|V|phi_R> matrix elements

void cal_meshball_vlocal(
const int na_grid,
const int LD_pool,
const int* const block_iw,
const int* const block_size,
const int* const block_index,
const int grid_index,
Expand All @@ -154,7 +151,6 @@ class Gint {
const double* const* const psir_vlbr3,
hamilt::HContainer<double>* hR);


//! in gint_fvl.cpp
//! calculate vl contributuion to force & stress via grid integrals
void gint_kernel_force(const int na_grid,
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_lcao/module_gint/gint_gamma_env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ void Gint_Gamma::cal_env(const double* wfc, double* rho, UnitCell& ucell)
}
const int nbx = this->gridt->nbx;
const int nby = this->gridt->nby;
const int nbz_start = this->gridt->nbzp_start;
const int nbz = this->gridt->nbzp;
const int ncyz = this->ny * this->nplane; // mohan add 2012-03-25
const int bxyz = this->bxyz;
Expand Down
7 changes: 2 additions & 5 deletions source/module_hamilt_lcao/module_gint/gint_k_env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "module_basis/module_ao/ORB_read.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_base/array_pool.h"
#include "module_base/vector3.h"

void Gint_k::cal_env_k(int ik,
const std::complex<double>* psi_k,
Expand All @@ -26,7 +27,6 @@ void Gint_k::cal_env_k(int ik,
}
const int nbx = this->gridt->nbx;
const int nby = this->gridt->nby;
const int nbz_start = this->gridt->nbzp_start;
const int nbz = this->gridt->nbzp;
const int ncyz = this->ny * this->nplane; // mohan add 2012-03-25

Expand Down Expand Up @@ -88,10 +88,7 @@ void Gint_k::cal_env_k(int ik,

// find R by which_unitcell and cal kphase
const int id_ucell = this->gridt->which_unitcell[mcell_index1];
const int Rx = this->gridt->ucell_index2x[id_ucell] + this->gridt->min_ucell_para[0];
const int Ry = this->gridt->ucell_index2y[id_ucell] + this->gridt->min_ucell_para[1];
const int Rz = this->gridt->ucell_index2z[id_ucell] + this->gridt->min_ucell_para[2];
ModuleBase::Vector3<double> R((double)Rx, (double)Ry, (double)Rz);
ModuleBase::Vector3<double> R(this->gridt->get_ucell_coords(id_ucell));
// std::cout << "kvec_d: " << kvec_d[ik].x << " " << kvec_d[ik].y << " " << kvec_d[ik].z << std::endl;
// std::cout << "kvec_c: " << kvec_c[ik].x << " " << kvec_c[ik].y << " " << kvec_c[ik].z << std::endl;
// std::cout << "R: " << R.x << " " << R.y << " " << R.z << std::endl;
Expand Down
17 changes: 4 additions & 13 deletions source/module_hamilt_lcao/module_gint/gint_vl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "module_base/blas_connector.h"
#include "module_base/timer.h"
#include "module_base/array_pool.h"
#include "module_base/vector3.h"
//#include <mkl_cblas.h>

#ifdef _OPENMP
Expand All @@ -22,7 +23,6 @@
void Gint::cal_meshball_vlocal(
const int na_grid, // how many atoms on this (i,j,k) grid
const int LD_pool,
const int*const block_iw, // block_iw[na_grid], index of wave functions for each block
const int*const block_size, // block_size[na_grid], number of columns of a band
const int*const block_index, // block_index[na_grid+1], count total number of atomis orbitals
const int grid_index, // index of grid group, for tracing global atom index
Expand All @@ -41,18 +41,14 @@ void Gint::cal_meshball_vlocal(
const int bcell1 = mcell_index + ia1;
const int iat1 = this->gridt->which_atom[bcell1];
const int id1 = this->gridt->which_unitcell[bcell1];
const int r1x = this->gridt->ucell_index2x[id1];
const int r1y = this->gridt->ucell_index2y[id1];
const int r1z = this->gridt->ucell_index2z[id1];
const ModuleBase::Vector3<int> r1 = this->gridt->get_ucell_coords(id1);

for(int ia2=0; ia2<na_grid; ++ia2)
{
const int bcell2 = mcell_index + ia2;
const int iat2= this->gridt->which_atom[bcell2];
const int id2 = this->gridt->which_unitcell[bcell2];
const int r2x = this->gridt->ucell_index2x[id2];
const int r2y = this->gridt->ucell_index2y[id2];
const int r2z = this->gridt->ucell_index2z[id2];
const ModuleBase::Vector3<int> r2 = this->gridt->get_ucell_coords(id2);

if(iat1<=iat2)
{
Expand All @@ -77,12 +73,7 @@ void Gint::cal_meshball_vlocal(
const int ib_length = last_ib-first_ib;
if(ib_length<=0) { continue; }

// calculate the BaseMatrix of <iat1, iat2, R> atom-pair
const int dRx = r1x - r2x;
const int dRy = r1y - r2y;
const int dRz = r1z - r2z;

const auto tmp_matrix = hR->find_matrix(iat1, iat2, dRx, dRy, dRz);
const auto tmp_matrix = hR->find_matrix(iat1, iat2, r1-r2);
if (tmp_matrix == nullptr)
{
continue;
Expand Down
16 changes: 8 additions & 8 deletions source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) {
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
//and accumulates to the corresponding element in Hamiltonian
this->cal_meshball_vlocal(
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index,
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index,
cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(),
&hRGint_thread);
}
Expand Down Expand Up @@ -158,13 +158,13 @@ void Gint::gint_kernel_dvlocal(Gint_inout* inout) {
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
//and accumulates to the corresponding element in Hamiltonian
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(), &pvdpRx_thread);
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
dpsir_ylm_y.get_ptr_2D(), &pvdpRy_thread);
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
dpsir_ylm_z.get_ptr_2D(), &pvdpRz_thread);
}
#pragma omp critical(gint_k)
Expand Down Expand Up @@ -281,18 +281,18 @@ void Gint::gint_kernel_vlocal_meta(Gint_inout* inout) {
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
//and accumulates to the corresponding element in Hamiltonian
this->cal_meshball_vlocal(
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), &hRGint_thread);
//integrate (d/dx_i psi_mu*vk(r)*dv) * (d/dx_i psi_nu) on grid (x_i=x,y,z)
//and accumulates to the corresponding element in Hamiltonian
this->cal_meshball_vlocal(
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(), dpsix_vlbr3.get_ptr_2D(), &hRGint_thread);
this->cal_meshball_vlocal(
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
dpsir_ylm_y.get_ptr_2D(), dpsiy_vlbr3.get_ptr_2D(), &hRGint_thread);
this->cal_meshball_vlocal(
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
dpsir_ylm_z.get_ptr_2D(), dpsiz_vlbr3.get_ptr_2D(), &hRGint_thread);
}

Expand Down
20 changes: 0 additions & 20 deletions source/module_hamilt_lcao/module_gint/grid_index.h

This file was deleted.

6 changes: 5 additions & 1 deletion source/module_hamilt_lcao/module_gint/grid_meshcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class Grid_MeshCell: public Grid_MeshK
int nbzp_start,nbzp;
// save the position of each meshcell.
std::vector<std::vector<double>> meshcell_pos;

private:
// latvec0 and GT are not used in current code.
// these two variables may be removed in the future.
ModuleBase::Matrix3 meshcell_latvec0;
ModuleBase::Matrix3 meshcell_GT;

Expand Down Expand Up @@ -45,7 +49,7 @@ class Grid_MeshCell: public Grid_MeshK
const int &nbzp_in);

void init_latvec(const UnitCell &ucell);
void init_meshcell_pos(void);
void init_meshcell_pos();

};

Expand Down
42 changes: 17 additions & 25 deletions source/module_hamilt_lcao/module_gint/grid_meshk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,14 @@ int Grid_MeshK::cal_Rindex(const int &u1, const int &u2, const int &u3)const
return (x3 + x2 * this->nu3 + x1 * this->nu2 * this->nu3);
}

void Grid_MeshK::init_ucell_para(void)
ModuleBase::Vector3<int> Grid_MeshK::get_ucell_coords(const int &Rindex)const
{
this->max_ucell_para=std::vector<int>(3,0);
this->max_ucell_para[0]=this->maxu1;
this->max_ucell_para[1]=this->maxu2;
this->max_ucell_para[2]=this->maxu3;

this->min_ucell_para=std::vector<int>(3,0);
this->min_ucell_para[0]=this->minu1;
this->min_ucell_para[1]=this->minu2;
this->min_ucell_para[2]=this->minu3;

this->num_ucell_para=std::vector<int>(4,0);
this->num_ucell_para[0]=this->nu1;
this->num_ucell_para[1]=this->nu2;
this->num_ucell_para[2]=this->nu3;
this->num_ucell_para[3]=this->nutot;
}
const int x = ucell_index2x[Rindex];
const int y = ucell_index2y[Rindex];
const int z = ucell_index2z[Rindex];

return ModuleBase::Vector3<int>(x, y, z);
}

void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dze,const int& nbx, const int& nby, const int& nbz)
{
Expand All @@ -66,8 +55,10 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
this->minu2 = (-dye+1) / nby - 1;
this->minu3 = (-dze+1) / nbz - 1;

if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MaxUnitcell",maxu1,maxu2,maxu3);
if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MinUnitcell",minu1,minu2,minu3);
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MaxUnitcell",maxu1,maxu2,maxu3);
}
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MinUnitcell",minu1,minu2,minu3);
}

//--------------------------------------
// number of unitcell in each direction.
Expand All @@ -77,9 +68,10 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
this->nu3 = maxu3 - minu3 + 1;
this->nutot = nu1 * nu2 * nu3;

init_ucell_para();
if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellNumber",nu1,nu2,nu3);
if(PARAM.inp.out_level != "m") ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellTotal",nutot);
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellNumber",nu1,nu2,nu3);
}
if(PARAM.inp.out_level != "m") { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellTotal",nutot);
}


this->ucell_index2x = std::vector<int>(nutot, 0);
Expand All @@ -97,9 +89,9 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
const int cell = cal_Rindex(i,j,k);
assert(cell<nutot);

this->ucell_index2x[cell] = i-minu1;
this->ucell_index2y[cell] = j-minu2;
this->ucell_index2z[cell] = k-minu3;
this->ucell_index2x[cell] = i;
this->ucell_index2y[cell] = j;
this->ucell_index2z[cell] = k;

}
}
Expand Down
22 changes: 10 additions & 12 deletions source/module_hamilt_lcao/module_gint/grid_meshk.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,23 @@
#define GRID_MESHK_H
#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_base/vector3.h"

class Grid_MeshK
{
public:
Grid_MeshK();
~Grid_MeshK();
// from 1D index to unitcell.
std::vector<int> ucell_index2x;
std::vector<int> ucell_index2y;
std::vector<int> ucell_index2z;

// the unitcell parameters.
std::vector<int> max_ucell_para;
std::vector<int> min_ucell_para;
std::vector<int> num_ucell_para;

// calculate the index of unitcell.
int cal_Rindex(const int& u1, const int& u2, const int& u3)const;

ModuleBase::Vector3<int> get_ucell_coords(const int& Rindex)const;

/// move operator for the next ESolver to directly use its infomation
Grid_MeshK& operator=(Grid_MeshK&& rhs) = default;

protected:
private:
// the max and the min unitcell.
int maxu1;
int maxu2;
Expand All @@ -40,11 +34,15 @@ class Grid_MeshK
int nu3;
int nutot;

// from 1D index to unitcell.
std::vector<int> ucell_index2x;
std::vector<int> ucell_index2y;
std::vector<int> ucell_index2z;

protected:
// calculate the extended unitcell.
void cal_extended_cell(const int &dxe, const int &dye, const int &dze,
const int& nbx, const int& nby, const int& nbz);
// initialize the unitcell parameters.
void init_ucell_para(void);
};

#endif
17 changes: 0 additions & 17 deletions source/module_hamilt_lcao/module_gint/grid_technique.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#ifndef GRID_TECHNIQUE_H
#define GRID_TECHNIQUE_H

#include "grid_index.h"
#include "grid_meshball.h"
#include "module_basis/module_ao/ORB_read.h"
#include "module_basis/module_ao/parallel_orbitals.h"
Expand Down Expand Up @@ -84,9 +83,6 @@ class Grid_Technique : public Grid_MeshBall {
std::vector<std::vector<double>> dpsi_u;
std::vector<std::vector<double>> d2psi_u;

// indexes for nnrg -> orbital index + R index
std::vector<gridIntegral::gridIndex> nnrg_index;

// Determine whether the grid point integration is initialized.
bool init_malloced;

Expand Down Expand Up @@ -132,19 +128,6 @@ class Grid_Technique : public Grid_MeshBall {
// store the information of atom pairs on this processor, used to initialize hcontainer.
// The meaning of ijr can be referred to in the get_ijr_info function in hcontainer.cpp.
std::vector<int> ijr_info;
int maxB1;
int maxB2;
int maxB3;

int minB1;
int minB2;
int minB3;

int nB1;
int nB2;
int nB3;

int nbox;

void cal_max_box_index();
// atoms on meshball
Expand Down
Loading

0 comments on commit d5ea3af

Please sign in to comment.