From b9a794d252f06dd70cbe281bbebd66594268677b Mon Sep 17 00:00:00 2001 From: KuangYu Date: Thu, 22 Feb 2024 11:47:35 +0800 Subject: [PATCH] Add support for small (ABn-type) molecules for sGNN --- dmff/sgnn/gnn.py | 36 +++++++--- dmff/sgnn/graph.py | 168 +++++++++++++++++++++++++++++++++------------ 2 files changed, 152 insertions(+), 52 deletions(-) mode change 100755 => 100644 dmff/sgnn/gnn.py mode change 100755 => 100644 dmff/sgnn/graph.py diff --git a/dmff/sgnn/gnn.py b/dmff/sgnn/gnn.py old mode 100755 new mode 100644 index d05b848be..e953d4394 --- a/dmff/sgnn/gnn.py +++ b/dmff/sgnn/gnn.py @@ -8,7 +8,8 @@ import jax.nn.initializers import jax.numpy as jnp import numpy as np -from .graph import MAX_VALENCE, TopGraph, from_pdb +from .graph import TopGraph, from_pdb +from .graph import MAX_VALENCE, ATYPE_INDEX, FSCALE_BOND, FSCALE_ANGLE from ..utils import jit_condition from jax import value_and_grad, vmap @@ -55,7 +56,12 @@ def __init__(self, nn=1, sigma=162.13039087945623, mu=117.41975505778706, - seed=12345): + seed=12345, + max_valence=MAX_VALENCE, + atype_index=ATYPE_INDEX, + fscale_bond=FSCALE_BOND, + fscale_angle=FSCALE_ANGLE + ): """ Constructor for MolGNNForce Parameters @@ -77,15 +83,25 @@ def __init__(self, mu: float, optional a constant shift the final total energy would be ${(E_{NN} + \mu) * \sigma} - seed: int: optional + seed: int, optional random seed used in network initialization default = 12345 - + max_valence: int, optional + Maximal valence number for all atoms inside the graph, use the value in graph.py by default + atype_index: dict, optional + A dictionary that assign index to each relevant element: e.g., {'H': 0, 'C': 1, 'O': 2}, use the ATYPE_INDEX in graph.py by default + fscale_bond: float, optional + The scaling factor for bond features, use value in graph.py by default + fscale_angle: float, optional + The scaling factor for angle features, use value in graph.py by default """ self.nn = nn self.G = G self.G.get_all_subgraphs(nn, 
typify=True) - self.G.prepare_subgraph_feature_calc() + self.G.prepare_subgraph_feature_calc(max_valence=max_valence, + atype_index=atype_index, + fscale_bond=fscale_bond, + fscale_angle=fscale_angle) params = OrderedDict() key = jax.random.PRNGKey(seed) params['w'] = jax.random.uniform(key) @@ -151,14 +167,14 @@ def message_pass(f_in, nb_connect, w, nn): if nn == 0: return f_in[0] elif nn == 1: - nb_connect0 = nb_connect[0:MAX_VALENCE - 1] - nb_connect1 = nb_connect[MAX_VALENCE - 1:2 * - (MAX_VALENCE - 1)] + nb_connect0 = nb_connect[0:max_valence - 1] + nb_connect1 = nb_connect[max_valence - 1:2 * + (max_valence - 1)] nb0 = jnp.sum(nb_connect0) nb1 = jnp.sum(nb_connect1) f = f_in[0] * (1 - jnp.heaviside(nb0, 0)*w - jnp.heaviside(nb1, 0)*w) + \ - w * nb_connect0.dot(f_in[1:MAX_VALENCE, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \ - w * nb_connect1.dot(f_in[MAX_VALENCE:2*MAX_VALENCE-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + w * nb_connect0.dot(f_in[1:max_valence, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \ + w * nb_connect1.dot(f_in[max_valence:2*max_valence-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) return f features = fc0(features, params) diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py old mode 100755 new mode 100644 index 1d85d92e2..e6c240e68 --- a/dmff/sgnn/graph.py +++ b/dmff/sgnn/graph.py @@ -37,15 +37,18 @@ # 'S': 4 # } ATYPE_INDEX = {'H': 0, 'C': 1, 'O': 2} +# ATYPE_INDEX = {'H': 0, 'B':1, 'C': 2, 'N': 3 , 'O': 4, 'F': 5, 'P': 6, 'S':7 } N_ATYPES = len(ATYPE_INDEX.keys()) # used to compute equilibrium bond lengths -COVALENT_RADIUS = {'H': 0.31, 'C': 0.76, 'N': 0.71, 'O': 0.66, 'S': 1.05} +# COVALENT_RADIUS = {'H': 0.31, 'C': 0.76, 'N': 0.71, 'O': 0.66, 'S': 1.05, } +COVALENT_RADIUS = {'H': 0.31, 'B': 0.84, 'C': 0.76, 'N': 0.71, 'O': 0.66, 'F': 0.57, 'S': 
1.05, 'P': 1.07 } # scaling parameters for feature calculations FSCALE_BOND = 10.0 FSCALE_ANGLE = 5.0 +# not worth raising MAX_VALENCE to 6 just for PF6 MAX_VALENCE = 4 MAX_ANGLES_PER_SITE = MAX_VALENCE * (MAX_VALENCE - 1) // 2 MAX_DIHEDS_PER_BOND = (MAX_VALENCE - 1)**2 @@ -53,9 +56,9 @@ # dimension of bond features DIM_BOND_FEATURES_GEOM = { 'bonds': 2 * MAX_VALENCE - 1, - 'angles0': MAX_VALENCE * (MAX_VALENCE - 1) // 2, - 'angles1': MAX_VALENCE * (MAX_VALENCE - 1) // 2, - 'diheds': (MAX_VALENCE - 1)**2 + 'angles0': MAX_ANGLES_PER_SITE, + 'angles1': MAX_ANGLES_PER_SITE, + 'diheds': MAX_DIHEDS_PER_BOND } DIM_BOND_FEATURES_GEOM_TOT = np.sum( [DIM_BOND_FEATURES_GEOM[k] for k in DIM_BOND_FEATURES_GEOM.keys()]) @@ -93,8 +96,9 @@ def __init__(self, list_atom_elems, bonds, positions=None, box=None): self._get_valences() self.set_internal_coords_indices() self.box = box + if box is not None: - self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) + self.box_inv = jnp.linalg.inv(box) else: self.box_inv = None return @@ -109,7 +113,7 @@ def set_box(self, box): 3 * 3: the box array, pbc vectors arranged in rows ''' self.box = box - self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) + self.box_inv = jnp.linalg.inv(box) if hasattr(self, 'subgraphs'): self._propagate_attr('box') self._propagate_attr('box_inv') @@ -206,11 +210,27 @@ def get_all_subgraphs(self, self.n_subgraphs = len(self.subgraphs) if typify: self.typify_all_subgraphs() - if typify and id_chiral: + # if it is a symmetric octahedral (tetragonal bipyramid) molecule such as PF6, add extra labels + self.is_AB6 = False + self.is_small = False + if set(self.valences) == {1, self.n_atoms-1} \ + and np.sum(np.array(self.valences) == 1) == self.n_atoms-1: + self.is_small = True + if self.n_atoms == 7 and len(set(self.list_atom_elems)) == 2: + self.is_AB6 = True + # for small molecules (ABn like), only conduct one neighbor search + if self.is_small: + nn = 0 + self.nn = nn + if self.is_AB6: + for g in self.subgraphs: + 
g._add_tetrahedral_bipyramid_labels() + if id_chiral: + for g in self.subgraphs: + g._add_chirality_labels() for g in self.subgraphs: - g._add_chirality_labels() # create permutation groups, and canonical orders for atoms - g.get_canonical_orders_wt_permutation_grps() + g.get_canonical_orders_wt_permutation_grps(is_AB6=self.is_AB6) return def _update_subgraph_positions(self): @@ -346,6 +366,26 @@ def _add_chirality_labels(self, verbose=False): self.atom_types[k] += 'R' return + def _add_tetrahedral_bipyramid_labels(self, verbose=False): + ''' + This subroutine is specific to PF6-like octahedral (tetragonal bipyramid) molecules. + It uses the initial positions to label the special F atom located opposite to + the central bond + ''' + pos = self.positions + # find the central atom and the central bond + valences = np.array(self.valences) + i = np.where(valences == 6)[0][0] + j = np.where(valences[0:2] == 1)[0][0] + vec_ij = pos[j] - pos[i] + for k in range(2, 8): + vec_ik = pos[k] - pos[i] + if np.dot(vec_ij, vec_ik) / np.linalg.norm(vec_ij) / np.linalg.norm(vec_ik) < -0.7: + break + # label the opposite F + self.atom_types[k] += '1' + return + def set_internal_coords_indices(self): ''' This method go over the graph and search for all bonds, angles, diheds @@ -375,6 +415,7 @@ def set_internal_coords_indices(self): angles.append([j, i, k]) self.angles = np.array(angles) + def get_a0(indices_angles): a0 = np.zeros(len(indices_angles)) for ia, (j, i, k) in enumerate(indices_angles): @@ -393,6 +434,16 @@ def get_a0(indices_angles): cos_a0 = np.cos(120.00 / 180 * np.pi) elif valence == 4: cos_a0 = np.cos(109.45 / 180 * np.pi) # 109.5 degree + # special treatment for octahedral (tetragonal bipyramid) structures like PF6 + elif valence == 6: + vec_ij = self.positions[j] - self.positions[i] + vec_ik = self.positions[k] - self.positions[i] + # if the two atoms are in opposite positions + if np.dot(vec_ij, vec_ik) / np.linalg.norm(vec_ij) / np.linalg.norm(vec_ik) < -0.7: + cos_a0 = -1.0 + # otherwise, it is 
in perpendicular position + else: + cos_a0 = 0.0 a0[ia] = cos_a0 return a0 @@ -426,7 +477,7 @@ def calc_internal_coords_features(positions, box): All these variables should be "static" throughout NVE/NVT/NPT simulations ''' - box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) + box_inv = jnp.linalg.inv(box) @jit_condition(static_argnums=()) @partial(vmap, in_axes=(0, None, 0), out_axes=(0)) @@ -435,7 +486,7 @@ def _calc_bond_features(idx, pos, b0): pos1 = pos[idx[1]] dr = pbc_shift(pos1 - pos0, box, box_inv) blength = jnp.linalg.norm(dr) - return (blength - b0) * FSCALE_BOND + return (blength - b0) * self.fscale_bond @jit_condition(static_argnums=()) @partial(vmap, in_axes=(0, None, 0), out_axes=(0)) @@ -448,7 +499,7 @@ def _calc_angle_features(idx, pos, cos_a0): n_ij = jnp.linalg.norm(r_ij) n_ik = jnp.linalg.norm(r_ik) cos_a = jnp.dot(r_ij, r_ik) / n_ij / n_ik - return (cos_a - cos_a0) * FSCALE_ANGLE + return (cos_a - cos_a0) * self.fscale_angle @jit_condition(static_argnums=()) @partial(vmap, in_axes=(0, None), out_axes=(0)) @@ -469,7 +520,10 @@ def _calc_dihed_features(idx, pos): fb = _calc_bond_features(self.bonds, positions, self.b0) fa = _calc_angle_features(self.angles, positions, self.cos_a0) - fd = _calc_dihed_features(self.diheds, positions) + if len(self.diheds) == 0: + fd = jnp.array([0.0]) + else: + fd = _calc_dihed_features(self.diheds, positions) return fb, fa, fd @@ -477,7 +531,12 @@ def _calc_dihed_features(idx, pos): return - def prepare_subgraph_feature_calc(self): + def prepare_subgraph_feature_calc(self, + max_valence=MAX_VALENCE, + atype_index=ATYPE_INDEX, + fscale_bond=FSCALE_BOND, + fscale_angle=FSCALE_ANGLE + ): ''' Preparing the feature calculation. Specifically, find out the indices mapping between feature elements and ICs @@ -498,10 +557,30 @@ def prepare_subgraph_feature_calc(self): pos (Na*3), box (3*3) -> features (Ntot*7*n_features) The calculator for the Graph features. 
''' + + # system dependent dimension parameters + self.atype_index = atype_index + self.n_atypes = len(atype_index.keys()) + self.fscale_bond = fscale_bond + self.fscale_angle = fscale_angle + self.max_valence = max_valence + self.max_angles_per_site = max_valence * (max_valence - 1) // 2 + self.max_diheds_per_bond = (max_valence - 1)**2 + self.dim_bond_features_geom = { + 'bonds': 2 * max_valence - 1, + 'angles0': self.max_angles_per_site, + 'angles1': self.max_angles_per_site, + 'diheds': self.max_diheds_per_bond + } + self.dim_bond_features_geom_tot = np.sum( + [self.dim_bond_features_geom[k] for k in self.dim_bond_features_geom.keys()]) + self.dim_bond_features_atypes = max_valence * 2 * self.n_atypes + + for g in self.subgraphs: g.prepare_graph_feature_calc() - self.n_features_atypes = DIM_BOND_FEATURES_ATYPES - self.n_features_geom = DIM_BOND_FEATURES_GEOM_TOT + self.n_features_atypes = self.dim_bond_features_atypes + self.n_features_geom = self.dim_bond_features_geom_tot self.n_features = self.n_features_atypes + self.n_features_geom # concatenate permutations @@ -527,6 +606,8 @@ def prepare_subgraph_feature_calc(self): jnp.tile(g.nb_connect[kb], (g.n_sym_perm, 1)) for g in self.subgraphs ]) + else: + self.nb_connect = None self.map_subgraph_perm = jnp.concatenate([ jnp.full((self.subgraphs[ig].n_sym_perm), ig, dtype=int) for ig in range(self.n_subgraphs) @@ -684,7 +765,7 @@ def add_neighbors(self): self.n_atoms = n_atoms return - def get_canonical_orders_wt_permutation_grps(self): + def get_canonical_orders_wt_permutation_grps(self, is_AB6=False): ''' This function sets up all the canonical orders for the atoms, based on existing atom typification (atom_types) information and the connection topology. 
@@ -787,10 +868,10 @@ def prepare_bond_feature_atypes(self, bond, map_order): # elements elem_i = self.list_atom_elems[i] elem_j = self.list_atom_elems[j] - fi = np.zeros(N_ATYPES) - fj = np.zeros(N_ATYPES) - fi[ATYPE_INDEX[elem_i]] = 1 - fj[ATYPE_INDEX[elem_j]] = 1 + fi = np.zeros(self.parent.n_atypes) + fj = np.zeros(self.parent.n_atypes) + fi[self.parent.atype_index[elem_i]] = 1 + fj[self.parent.atype_index[elem_j]] = 1 # neighbour atoms indices_n0 = np.array(np.where(self.connectivity[i] == 1)[0]) indices_n1 = np.array(np.where(self.connectivity[j] == 1)[0]) @@ -802,18 +883,18 @@ def prepare_bond_feature_atypes(self, bond, map_order): nn0 = len(indices_n0) nn1 = len(indices_n1) # features of the neighbour atom types - f_n0 = np.zeros(N_ATYPES * (MAX_VALENCE - 1)) - f_n1 = np.zeros(N_ATYPES * (MAX_VALENCE - 1)) + f_n0 = np.zeros(self.parent.n_atypes * (self.parent.max_valence - 1)) + f_n1 = np.zeros(self.parent.n_atypes * (self.parent.max_valence - 1)) for ii, i in enumerate(indices_n0): - tmp = np.zeros(N_ATYPES) + tmp = np.zeros(self.parent.n_atypes) elem = self.list_atom_elems[i] - tmp[ATYPE_INDEX[elem]] = 1 - f_n0[ii * N_ATYPES:ii * N_ATYPES + N_ATYPES] = tmp + tmp[self.parent.atype_index[elem]] = 1 + f_n0[ii * self.parent.n_atypes:ii * self.parent.n_atypes + self.parent.n_atypes] = tmp for ii, i in enumerate(indices_n1): - tmp = np.zeros(N_ATYPES) + tmp = np.zeros(self.parent.n_atypes) elem = self.list_atom_elems[i] - tmp[ATYPE_INDEX[elem]] = 1 - f_n1[ii * N_ATYPES:ii * N_ATYPES + N_ATYPES] = tmp + tmp[self.parent.atype_index[elem]] = 1 + f_n1[ii * self.parent.n_atypes:ii * self.parent.n_atypes + self.parent.n_atypes] = tmp return np.array(np.concatenate((fi, fj, f_n0, f_n1))) def prepare_bond_feature_calc_indices(self, @@ -850,8 +931,8 @@ def prepare_bond_feature_calc_indices(self, nn0 = len(indices_n0) nn1 = len(indices_n1) # padding neighbours - indices_atoms_n0 = -np.ones(MAX_VALENCE - 1, dtype=int) - indices_atoms_n1 = -np.ones(MAX_VALENCE - 1, 
dtype=int) + indices_atoms_n0 = -np.ones(self.parent.max_valence - 1, dtype=int) + indices_atoms_n1 = -np.ones(self.parent.max_valence - 1, dtype=int) indices_atoms_n0[:nn0] = indices_n0 indices_atoms_n1[:nn1] = indices_n1 @@ -927,9 +1008,12 @@ def prepare_bond_feature_calc_indices(self, indices['diheds'] = [] for d in indices_diheds: p = np.array([self.map_sub2parent[i] for i in d]) - match = np.where( - np.all(G.diheds == p, axis=1) + - np.all(G.diheds == p[::-1], axis=1))[0] + if self.parent.is_small: + match = np.array([]) + else: + match = np.where( + np.all(G.diheds == p, axis=1) + + np.all(G.diheds == p[::-1], axis=1))[0] if len(match) == 0: indices['diheds'].append(-1) else: @@ -981,8 +1065,8 @@ def prepare_graph_feature_calc(self): self.nb_connect['nb_bonds_0'] = jnp.array([1., 1., 0.]) ''' - self.n_bond_features_atypes = DIM_BOND_FEATURES_ATYPES - self.n_bond_features_geom = DIM_BOND_FEATURES_GEOM_TOT + self.n_bond_features_atypes = self.parent.dim_bond_features_atypes + self.n_bond_features_geom = self.parent.dim_bond_features_geom_tot self.n_bond_features = self.n_bond_features_atypes + self.n_bond_features_geom # assume the first bond is always the central bond center_bond = self.bonds[0] # should always be (0, 1) @@ -1099,17 +1183,17 @@ def prepare_graph_feature_calc(self): elif self.nn == 1: keys = ['center', 'nb_bonds_0', 'nb_bonds_1'] self.nb_connect = {} - self.nb_connect['nb_bonds_0'] = np.zeros(MAX_VALENCE - 1) - self.nb_connect['nb_bonds_1'] = np.zeros(MAX_VALENCE - 1) + self.nb_connect['nb_bonds_0'] = np.zeros(self.parent.max_valence - 1) + self.nb_connect['nb_bonds_1'] = np.zeros(self.parent.max_valence - 1) nb_list = { 'center': 1, - 'nb_bonds_0': MAX_VALENCE - 1, - 'nb_bonds_1': MAX_VALENCE - 1 + 'nb_bonds_0': self.parent.max_valence - 1, + 'nb_bonds_1': self.parent.max_valence - 1 } for kb in keys: # deal with the atype features feature_atypes[kb] = np.zeros( - (self.n_sym_perm, nb_list[kb], DIM_BOND_FEATURES_ATYPES)) + (self.n_sym_perm, 
nb_list[kb], self.parent.dim_bond_features_atypes)) nb = len(self.feature_atypes[kb][0]) if nb > 0: feature_atypes[kb][:, 0:nb, :] = np.array( @@ -1119,7 +1203,7 @@ def prepare_graph_feature_calc(self): feature_indices[kb] = {} for kf in ['bonds', 'angles0', 'angles1', 'diheds']: feature_indices[kb][kf] = -np.ones( - (self.n_sym_perm, nb_list[kb], DIM_BOND_FEATURES_GEOM[kf]), + (self.n_sym_perm, nb_list[kb], self.parent.dim_bond_features_geom[kf]), dtype=int) if nb > 0: feature_indices[kb][kf][:, 0:nb, :] = np.array([[