This chapter introduces the basic usage of DMFF. All scripts can be found in the examples/ directory, which also provides Jupyter notebook-based demos.
DMFF uses OpenMM to parse input files, including the coordinate file, the topology specification file, and the force field parameter file. The core class Hamiltonian, which inherits from openmm.app.ForceField, is then initialized, and its createPotential method is called to create differentiable potential energy functions for the different energy terms. Take parametrizing an organic molecule with the GAFF2 force field as an example:
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
potentials = ff.createPotential(pdb.topology)
for k in potentials.dmff_potentials.keys():
    pot = potentials.dmff_potentials[k]
    print(pot)
In this example, lig.pdb is the PDB file containing the atomic coordinates, and lig-top.xml specifies the bond connections within the molecule, which openmm.app needs to generate the molecular topology. This file is not required if the bond connections are already defined in the .pdb file with CONECT records. gaff-2.11.xml contains the GAFF2 force field parameters (bonds, angles, torsions, and vdW), and lig-prm.xml contains the atomic partial charges (GAFF2 requires a user-defined charge assignment process). This XML format is compatible with OpenMM definitions; a detailed description can be found in the OpenMM user guide or in the XML-format force fields section.
If you run this script in examples/classical, you will get the following output:
<function HarmonicBondGenerator.createPotential.<locals>.potential_fn at 0x7fe6c3bd2280>
<function HarmonicAngleGenerator.createPotential.<locals>.potential_fn at 0x7fe6c3bd2670>
<function PeriodicTorsionGenerator.createPotential.<locals>.potential_fn at 0x7fe6c3c4a8b0>
<function NonbondedGenerator.createPotential.<locals>.potential_fn at 0x7fe6c3bd8670>
The force field parameters are stored as a Python dict in the Hamiltonian.
params = ff.getParameters()
nbparam = params['NonbondedForce']
nbparam
{
'sigma': DeviceArray([0.33152124, ...], dtype=float32),
'epsilon': DeviceArray([0.4133792, ...], dtype=float32),
'epsfix': DeviceArray([], dtype=float32),
'sigfix': DeviceArray([], dtype=float32),
'charge': DeviceArray([-0.75401515, ...], dtype=float32),
'coulomb14scale': DeviceArray([0.8333333], dtype=float32),
'lj14scale': DeviceArray([0.5], dtype=float32)
}
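Because the parameters are a nested dict of arrays, they can be handled with JAX's pytree utilities. The following is a minimal sketch of that idea. It does not call DMFF: toy_params is a hypothetical stand-in with made-up values that only mimics the layout printed above.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the dict returned by ff.getParameters();
# the values are made up and only mimic the nested layout shown above.
toy_params = {
    "NonbondedForce": {
        "sigma": jnp.array([0.33, 0.25, 0.11]),
        "epsilon": jnp.array([0.41, 0.06, 0.07]),
        "charge": jnp.array([-0.75, 0.40, 0.35]),
        "coulomb14scale": jnp.array([0.8333333]),
    }
}

# tree_map applies a function to every array leaf and keeps the dict layout,
# which gives a quick overview of the shape of each parameter array.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_params)
print(shapes)

# JAX arrays are immutable, so a modified parameter set is built as a new dict.
# Here every charge is scaled by 0.5 while all other entries are reused.
scaled = {
    **toy_params,
    "NonbondedForce": {
        **toy_params["NonbondedForce"],
        "charge": toy_params["NonbondedForce"]["charge"] * 0.5,
    },
}
print(scaled["NonbondedForce"]["charge"])
```

Building a new dict rather than editing in place keeps the original parameter set available, which is convenient when comparing energies before and after a change.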
Each generated function takes the coordinates, box, pairs, and force field parameters as inputs.
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
box = jnp.array([
[10.0, 0.0, 0.0],
[ 0.0, 10.0, 0.0],
[ 0.0, 0.0, 10.0]
])
nbList = NeighborList(box, rcut=4, cov_map=potentials.meta["cov_map"])
nbList.allocate(positions)
pairs = nbList.pairs
Note that, to take advantage of the automatic differentiation implemented in JAX, the input arrays have to be jax.numpy.ndarray objects; otherwise DMFF will raise an error. pairs is an array listing the atom pairs that are treated as neighbors within the cutoff distance rcut. As shown above, it can be computed with the dmff.NeighborList class, which is supported by jax_md.
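To make the idea of a pair list concrete, here is a brute-force sketch written with jax.numpy only. It is not how dmff.NeighborList or jax_md work internally: the function brute_force_pairs is hypothetical, it assumes an orthorhombic box (diagonal box matrix) with the minimum-image convention, and it scales quadratically with the number of atoms.

```python
import jax.numpy as jnp

def brute_force_pairs(positions, box, rcut):
    """Return an (n_pairs, 2) integer array of atom index pairs closer than rcut.

    Hypothetical helper for illustration; assumes a diagonal box matrix.
    """
    n = positions.shape[0]
    box_lengths = jnp.diag(box)
    # All pairwise displacement vectors, shape (n, n, 3).
    dr = positions[:, None, :] - positions[None, :, :]
    # Minimum-image convention: wrap each component into [-L/2, L/2].
    dr = dr - box_lengths * jnp.round(dr / box_lengths)
    dist = jnp.linalg.norm(dr, axis=-1)
    # Keep each unordered pair once (i < j) and apply the cutoff.
    i_idx, j_idx = jnp.triu_indices(n, k=1)
    mask = dist[i_idx, j_idx] < rcut
    return jnp.stack([i_idx[mask], j_idx[mask]], axis=1)

toy_positions = jnp.array([
    [0.1, 0.1, 0.1],
    [0.3, 0.1, 0.1],
    [9.9, 0.1, 0.1],  # close to atom 0 through the periodic boundary
    [5.0, 5.0, 5.0],  # far from every other atom
])
toy_box = jnp.eye(3) * 10.0
toy_pairs = brute_force_pairs(toy_positions, toy_box, rcut=0.5)
print(toy_pairs)
```

The boolean masking used here produces an array whose size depends on the data, so this sketch is meant to run outside jax.jit.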
The potential energy function returns the energy (a scalar, in kJ/mol) as output:
nbfunc = potentials.dmff_potentials['NonbondedForce']
nbene = nbfunc(positions, box, pairs, params)
print(nbene)
If everything works, you will get -425.40470 as a result. You can also use getPotentialFunc() and getParameters() to obtain the whole potential energy function and the full force field parameter set, instead of separate functions for the individual energy terms.
efunc = potentials.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
Unlike in conventional programming frameworks, explicit definitions of atomic force calculation functions are no longer needed. Instead, the forces can be evaluated automatically with jax.grad.
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
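The same pattern can be tried on a function small enough to check by hand. The sketch below does not use DMFF: toy_energy is a hypothetical single harmonic bond with made-up constants, chosen so that the force from jax.grad can be compared with the analytic result.

```python
import jax
import jax.numpy as jnp

def toy_energy(positions, k, r0):
    """Harmonic bond between atoms 0 and 1: E = 0.5 * k * (r - r0)**2."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2

# Made-up values: two atoms 0.15 apart along x, equilibrium length 0.10.
toy_positions = jnp.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0]])
k, r0 = 1000.0, 0.10

# Forces are the negative gradient of the energy with respect to positions
# (argnums=0 selects the first argument).
toy_forces = -jax.grad(toy_energy, argnums=0)(toy_positions, k, r0)

# Analytic check: dE/dr = k * (r - r0), directed along the bond.
dEdr = k * (0.15 - r0)
analytic = jnp.array([[dEdr, 0.0, 0.0], [-dEdr, 0.0, 0.0]])
print(toy_forces)
print(jnp.allclose(toy_forces, analytic, atol=1e-3))
```

The bond is stretched, so the force pulls atom 1 back toward atom 0 and atom 0 toward atom 1, with equal magnitude and opposite sign.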
Similarly, the derivative of the energy with respect to the force field parameters can also be computed easily.
param_grad_func = jax.grad(nbfunc, argnums=-1)
pgrad = param_grad_func(positions, box, pairs, params)
print(pgrad["NonbondedForce"]["sigma"])
[ 1.12090099e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
7.57040892e+02 1.45521139e+03 0.00000000e+00 0.00000000e+00
6.78143151e+01 5.87935802e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 4.97000516e+02
0.00000000e+00 0.00000000e+00 1.69941295e+01 0.00000000e+00
4.15689683e+02 1.07864961e+02 -1.05927404e+01 -3.34661347e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 1.95772900e+01
9.87994968e+01 0.00000000e+00 0.00000000e+00 0.00000000e+00
5.21105110e+01 0.00000000e+00 0.00000000e+00 0.00000000e+00
3.11528443e+01 5.66372398e+01 6.27484044e+02 0.00000000e+00
1.87121279e+02 0.00000000e+00 0.00000000e+00 1.16098707e+02
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 6.27866628e+01 5.10221034e+01
0.00000000e+00 0.00000000e+00 3.60011535e+01 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00]
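When jax.grad differentiates with respect to a nested dict, the result is a dict with the same layout, which is why the gradient above is indexed like the parameters themselves. The sketch below shows this on a toy problem. It is not DMFF's nonbonded implementation: toy_lj_energy is a hypothetical Lennard-Jones sum with Lorentz-Berthelot combination rules, no periodic boundary, and made-up parameters.

```python
import jax
import jax.numpy as jnp

def toy_lj_energy(positions, pairs, params):
    """Lennard-Jones energy summed over the listed pairs (illustration only)."""
    sigma = params["NonbondedForce"]["sigma"]
    epsilon = params["NonbondedForce"]["epsilon"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[i] - positions[j], axis=1)
    # Lorentz-Berthelot combination rules (an assumption of this sketch).
    sig_ij = 0.5 * (sigma[i] + sigma[j])
    eps_ij = jnp.sqrt(epsilon[i] * epsilon[j])
    x6 = (sig_ij / r) ** 6
    return jnp.sum(4.0 * eps_ij * (x6 ** 2 - x6))

toy_positions = jnp.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0], [0.0, 0.4, 0.0]])
toy_pairs = jnp.array([[0, 1], [0, 2]])
toy_params = {
    "NonbondedForce": {
        "sigma": jnp.array([0.3, 0.3, 0.3]),
        "epsilon": jnp.array([0.5, 0.5, 0.5]),
    }
}

# argnums=2 selects the params dict; the gradient mirrors its structure.
toy_pgrad = jax.grad(toy_lj_energy, argnums=2)(toy_positions, toy_pairs, toy_params)
g_sigma = toy_pgrad["NonbondedForce"]["sigma"]
print(g_sigma)
print(toy_pgrad["NonbondedForce"]["epsilon"])
```

Because each leaf of the gradient has the same shape as the corresponding parameter array, the result can be passed directly to a gradient-based optimizer that operates on the same dict.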