Skip to content

Commit

Permalink
chore: refactor distance calc
Browse files Browse the repository at this point in the history
  • Loading branch information
Anyang Peng authored and Anyang Peng committed Jan 23, 2024
1 parent f38916e commit 598d265
Showing 1 changed file with 44 additions and 8 deletions.
52 changes: 44 additions & 8 deletions deepmd_pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def distinguish_types(self)->bool:
# this model has no descriptor, thus no type_split.
return

# Since ZBL model is different from other AtomicModels, overwritting it here
def forward_atomic(
self,
extended_coord,
Expand All @@ -104,6 +103,7 @@ def forward_atomic(

nframes, nloc, nnei = nlist.shape
# atype = extended_atype[:, :nloc]
pairwise_dr = self._get_pairwise_dist(extended_coord)

"""
below is the sudo code, need to figure out how the index works.
Expand All @@ -116,24 +116,26 @@ def forward_atomic(
# removing _pair_tab_jloop method, just unwrap here.
cur_table_data --> subtable based on atype.
dr = extended_coord[:, a_loc] - extended_coord[:, a_nei]
pairwise_ene = self._pair_tabulated_inter(cur_table_data, dr)
rr = pairwise_dr[a_loc][a_nei].pow(2).sum().sqrt() # this is the salar distance.
pairwise_ene = self._pair_tabulated_inter(cur_table_data, rr)
atomic_energy[a_loc] += pairwise_ene
return {"atomic_energy": atomic_energy} --> convert to FittingOutputDef
"""

def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor) -> torch.Tensor:
def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, rr: torch.Tensor) -> torch.Tensor:
"""Pairwise tabulated energy.
Parameters
----------
cur_table_data : torch.Tensor
The tabulated cubic spline coefficients for the current atom types.
dr : torch.Tensor
The distance vector between two atoms.
rr : torch.Tensor
The salar distance vector between two atoms.
Returns
-------
Expand All @@ -158,8 +160,6 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor)
nspline = int(self.tab_info[2] + 0.1)
ndata = nspline * 4

r2 = dr[0]**2 + dr[1]**2 + dr[2]**2
rr = torch.sqrt(r2) # scaler distance

uu = (rr - rmin) * hi

Expand All @@ -184,5 +184,41 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor)
ener = etmp * uu + a0
return ener

@staticmethod
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
"""Get pairwise distance `dr`.
Parameters
----------
coords : torch.Tensor
The coordinate of the atoms.
Returns
-------
torch.Tensor
The pairwise distance between the atoms.
Examples
--------
coords = torch.tensor([
[0,0,0],
[1,3,5],
[2,4,6]
])
dist = tensor([[[ 0, 0, 0],
[-1, -3, -5],
[-2, -4, -6]],
[[ 1, 3, 5],
[ 0, 0, 0],
[-1, -1, -1]],
[[ 2, 4, 6],
[ 1, 1, 1],
[ 0, 0, 0]]])
"""
return coords.unsqueeze(1) - coords.unsqueeze(0)

0 comments on commit 598d265

Please sign in to comment.