Skip to content

Commit

Permalink
Add support for small (ABn-type) molecules for sGNN
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangYu committed Feb 22, 2024
1 parent 2126ba4 commit b9a794d
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 52 deletions.
36 changes: 26 additions & 10 deletions dmff/sgnn/gnn.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import jax.nn.initializers
import jax.numpy as jnp
import numpy as np
from .graph import MAX_VALENCE, TopGraph, from_pdb
from .graph import TopGraph, from_pdb
from .graph import MAX_VALENCE, ATYPE_INDEX, FSCALE_BOND, FSCALE_ANGLE
from ..utils import jit_condition
from jax import value_and_grad, vmap

Expand Down Expand Up @@ -55,7 +56,12 @@ def __init__(self,
nn=1,
sigma=162.13039087945623,
mu=117.41975505778706,
seed=12345):
seed=12345,
max_valence=MAX_VALENCE,
atype_index=ATYPE_INDEX,
fscale_bond=FSCALE_BOND,
fscale_angle=FSCALE_ANGLE
):
""" Constructor for MolGNNForce
Parameters
Expand All @@ -77,15 +83,25 @@ def __init__(self,
mu: float, optional
a constant shift
the final total energy would be ${(E_{NN} + \mu) * \sigma}
seed: int: optional
seed: int, optional
random seed used in network initialization
default = 12345
max_valence: int, optional
Maximal valence number for all atoms inside the graph, use the value in graph.py by default
atype_index: dict, optional
A dictionary that assign index to each relevant element: e.g., {'H': 0, 'C': 1, 'O': 2}, use the ATYPE_INDEX in graph.py by default
fscale_bond: float, optional
The scaling factor for bond features, use value in graph.py by default
fscale_angle: float, optional
The scaling factor for angle features, use value in graph.py by default
"""
self.nn = nn
self.G = G
self.G.get_all_subgraphs(nn, typify=True)
self.G.prepare_subgraph_feature_calc()
self.G.prepare_subgraph_feature_calc(max_valence=max_valence,
atype_index=atype_index,
fscale_bond=fscale_bond,
fscale_angle=fscale_angle)
params = OrderedDict()
key = jax.random.PRNGKey(seed)
params['w'] = jax.random.uniform(key)
Expand Down Expand Up @@ -151,14 +167,14 @@ def message_pass(f_in, nb_connect, w, nn):
if nn == 0:
return f_in[0]
elif nn == 1:
nb_connect0 = nb_connect[0:MAX_VALENCE - 1]
nb_connect1 = nb_connect[MAX_VALENCE - 1:2 *
(MAX_VALENCE - 1)]
nb_connect0 = nb_connect[0:max_valence - 1]
nb_connect1 = nb_connect[max_valence - 1:2 *
(max_valence - 1)]
nb0 = jnp.sum(nb_connect0)
nb1 = jnp.sum(nb_connect1)
f = f_in[0] * (1 - jnp.heaviside(nb0, 0)*w - jnp.heaviside(nb1, 0)*w) + \
w * nb_connect0.dot(f_in[1:MAX_VALENCE, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \
w * nb_connect1.dot(f_in[MAX_VALENCE:2*MAX_VALENCE-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x])
w * nb_connect0.dot(f_in[1:max_valence, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \
w * nb_connect1.dot(f_in[max_valence:2*max_valence-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x])
return f

features = fc0(features, params)
Expand Down
Loading

0 comments on commit b9a794d

Please sign in to comment.