From 598d26551f58a650b7cb95ddc116250bff7e9a1a Mon Sep 17 00:00:00 2001 From: Anyang Peng Date: Tue, 23 Jan 2024 12:19:13 +0800 Subject: [PATCH] chore: refactor distance calc --- deepmd_pt/model/model/pair_tab.py | 52 ++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/deepmd_pt/model/model/pair_tab.py b/deepmd_pt/model/model/pair_tab.py index aac7396..a628f0d 100644 --- a/deepmd_pt/model/model/pair_tab.py +++ b/deepmd_pt/model/model/pair_tab.py @@ -89,7 +89,6 @@ def distinguish_types(self)->bool: # this model has no descriptor, thus no type_split. return - # Since ZBL model is different from other AtomicModels, overwritting it here def forward_atomic( self, extended_coord, @@ -104,6 +103,7 @@ def forward_atomic( nframes, nloc, nnei = nlist.shape # atype = extended_atype[:, :nloc] + pairwise_dr = self._get_pairwise_dist(extended_coord) """ below is the sudo code, need to figure out how the index works. @@ -116,15 +116,17 @@ def forward_atomic( # removing _pair_tab_jloop method, just unwrap here. cur_table_data --> subtable based on atype. - dr = extended_coord[:, a_loc] - extended_coord[:, a_nei] - pairwise_ene = self._pair_tabulated_inter(cur_table_data, dr) + + rr = pairwise_dr[a_loc][a_nei].pow(2).sum().sqrt() # this is the salar distance. + + pairwise_ene = self._pair_tabulated_inter(cur_table_data, rr) atomic_energy[a_loc] += pairwise_ene return {"atomic_energy": atomic_energy} --> convert to FittingOutputDef """ - def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor) -> torch.Tensor: + def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, rr: torch.Tensor) -> torch.Tensor: """Pairwise tabulated energy. Parameters @@ -132,8 +134,8 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor) cur_table_data : torch.Tensor The tabulated cubic spline coefficients for the current atom types. - dr : torch.Tensor - The distance vector between two atoms. + rr : torch.Tensor + The salar distance vector between two atoms. Returns ------- @@ -158,8 +160,6 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor) nspline = int(self.tab_info[2] + 0.1) ndata = nspline * 4 - r2 = dr[0]**2 + dr[1]**2 + dr[2]**2 - rr = torch.sqrt(r2) # scaler distance uu = (rr - rmin) * hi @@ -184,5 +184,41 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, dr: torch.Tensor) ener = etmp * uu + a0 return ener + @staticmethod + def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: + """Get pairwise distance `dr`. + + Parameters + ---------- + coords : torch.Tensor + The coordinate of the atoms. + + Returns + ------- + torch.Tensor + The pairwise distance between the atoms. + + Examples + -------- + + coords = torch.tensor([ + [0,0,0], + [1,3,5], + [2,4,6] + ]) + + dist = tensor([[[ 0, 0, 0], + [-1, -3, -5], + [-2, -4, -6]], + + [[ 1, 3, 5], + [ 0, 0, 0], + [-1, -1, -1]], + + [[ 2, 4, 6], + [ 1, 1, 1], + [ 0, 0, 0]]]) + """ + return coords.unsqueeze(1) - coords.unsqueeze(0)