Skip to content

Commit

Permalink
Merge pull request #45 from deepmodeling/devel
Browse files Browse the repository at this point in the history
Merge devel into master
  • Loading branch information
Ericwang6 authored Jun 17, 2022
2 parents 9799844 + 4bd556d commit af4b1f2
Show file tree
Hide file tree
Showing 9 changed files with 626 additions and 425 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -778,4 +778,7 @@ FodyWeavers.xsd
.vscode/**

# acpype cache
*.acpype/
*.acpype/

*/_date.py
*/_version.py
37 changes: 20 additions & 17 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
import xml.etree.ElementTree as ET
from copy import deepcopy
import warnings

import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -1189,10 +1190,11 @@ def __init__(self, hamiltonian):
def registerBondType(self, bond):
typetxt = findAtomTypeTexts(bond, 2)
types = self.ff._findAtomTypes(bond, 2)
self.types.append(types)
self.typetexts.append(typetxt)
self.params["k"].append(float(bond["k"]))
self.params["length"].append(float(bond["length"])) # length := r0
if None not in types:
self.types.append(types)
self.typetexts.append(typetxt)
self.params["k"].append(float(bond["k"]))
self.params["length"].append(float(bond["length"])) # length := r0

@staticmethod
def parseElement(element, hamiltonian):
Expand Down Expand Up @@ -1287,9 +1289,10 @@ def __init__(self, hamiltonian):

def registerAngleType(self, angle):
types = self.ff._findAtomTypes(angle, 3)
self.types.append(types)
self.params["k"].append(float(angle["k"]))
self.params["angle"].append(float(angle["angle"]))
if None not in types:
self.types.append(types)
self.params["k"].append(float(angle["k"]))
self.params["angle"].append(float(angle["angle"]))

@staticmethod
def parseElement(element, hamiltonian):
Expand All @@ -1302,8 +1305,12 @@ def parseElement(element, hamiltonian):
<\HarmonicAngleForce>
"""
generator = HarmonicAngleJaxGenerator(hamiltonian)
hamiltonian.registerGenerator(generator)
existing = [f for f in hamiltonian._forces if isinstance(f, HarmonicAngleJaxGenerator)]
if len(existing) == 0:
generator = HarmonicAngleJaxGenerator(hamiltonian)
hamiltonian.registerGenerator(generator)
else:
generator = existing[0]
for angletype in element.findall("Angle"):
generator.registerAngleType(angletype.attrib)

Expand Down Expand Up @@ -1342,7 +1349,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
n_angles += 1
break
if not ifFound:
print(
warnings.warn(
"No parameter for angle %i - %i - %i" % (idx1, idx2, idx3)
)

Expand Down Expand Up @@ -1994,7 +2001,7 @@ def renderXML(self):

class NonbondJaxGenerator:

SCALETOL = 1e-5
SCALETOL = 1e-3

def __init__(self, hamiltionian, coulomb14scale, lj14scale):

Expand All @@ -2019,7 +2026,8 @@ def __init__(self, hamiltionian, coulomb14scale, lj14scale):
def registerAtom(self, atom):
# use types in nb cards or resname+atomname in residue cards
types = self.ff._findAtomTypes(atom, 1)[0]
self.types.append(types)
if None not in types:
self.types.append(types)

for key in ["sigma", "epsilon", "charge"]:
if key not in self.useAttributeFromResidue:
Expand Down Expand Up @@ -2065,11 +2073,6 @@ def parseElement(element, ff):

generator.n_atoms = len(element.findall("Atom"))

# jax it!
for k in generator.params.keys():
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)

def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
methodMap = {
app.NoCutoff: "NoCutoff",
Expand Down
5 changes: 0 additions & 5 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ def get_LJ_energy(dr_vec, sig, eps, box):
if self.ifPBC:
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_norm = jnp.linalg.norm(dr_vec, axis=1)
if not self.ifNoCut:
msk = dr_norm <= self.r_cut
sig = sig[msk]
eps = eps[msk]
dr_norm = dr_norm[msk]

dr_inv = 1.0 / dr_norm
sig_dr = sig * dr_inv
Expand Down
221 changes: 221 additions & 0 deletions examples/classical/demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classical Force Field in DMFF"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DMFF implements classcial molecular mechanics force fields with the following forms:\n",
"\n",
"$$\\begin{align*}\n",
" V(\\mathbf{R}) &= V_{\\mathrm{bond}} + V_{\\mathrm{angle}} + V_\\mathrm{torsion} + V_\\mathrm{vdW} + V_\\mathrm{Coulomb} \\\\\n",
" &= \\sum_{\\mathrm{bonds}}\\frac{1}{2}k_b(r - r_0)^2 + \\sum_{\\mathrm{angles}}\\frac{1}{2}k_\\theta (\\theta -\\theta_0)^2 + \\sum_{\\mathrm{torsion}}\\sum_{n=1}^4 V_n[1+\\cos(n\\phi - \\phi_s)] \\\\\n",
" &\\quad+ \\sum_{ij}4\\varepsilon_{ij}\\left[\\left(\\frac{\\sigma_{ij}}{r_{ij}}\\right)^{12} - \\left(\\frac{\\sigma_{ij}}{r_{ij}}\\right)^6\\right] + \\sum_{ij}\\frac{q_iq_j}{4\\pi\\varepsilon_0r_{ij}}\n",
"\\end{align*}$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import necessary packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import openmm.app as app\n",
"import openmm.unit as unit\n",
"from dmff import Hamiltonian, NeighborList"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute energy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DMFF uses **OpenMM** to parse input files, including coordinates files, topology specification files. Class `Hamiltonian` inherited from `openmm.ForceField` will be initialized and used to parse force field parameters in XML format. Take parametrzing an organic moleclue with GAFF2 force field as an example.\n",
"\n",
"- `lig_top.xml`: Define bond connections (topology). Not necessary if such information is specified in pdb with `CONNECT` keyword.\n",
"- `gaff-2.11.xml`: GAFF2 force field parameters: bonds, angles, torsions and vdW params\n",
"- `lig-prm.xml`: Atomic charges"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"app.Topology.loadBondDefinitions(\"lig-top.xml\")\n",
"pdb = app.PDBFile(\"lig.pdb\")\n",
"ff = Hamiltonian(\"gaff-2.11.xml\", \"lig-prm.xml\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The method `Hamiltonian.createPotential` will be called to create differentiable potential energy functions for different energy terms. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"potentials = ff.createPotential(\n",
" pdb.topology,\n",
" nonbondedMethod=app.NoCutoff\n",
")\n",
"for pot in potentials:\n",
" print(pot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The force field parameters are stored as a Python dict in the `param` attribute of force generators."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nbparam = ff.getGenerators()[3].params\n",
"nbparam[\"charge\"] # also \"epsilon\", \"sigma\" etc. keys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each generated function will read **coordinates, box, pairs** and force field parameters as inputs. `pairs` is a integer array in which each row specifying atoms condsidered as neighbors within rcut. This can be calculated with `dmff.NeighborList` class which is supported by `jax_md`.\n",
"\n",
"The potential energy function will give energy (a scalar, in kJ/mol) as output:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))\n",
"box = jnp.array([\n",
" [10.0, 0.0, 0.0], \n",
" [0.0, 10.0, 0.0],\n",
" [0.0, 0.0, 10.0]\n",
"])\n",
"nbList = NeighborList(box, rc=4)\n",
"nbList.allocate(positions)\n",
"pairs = nbList.pairs\n",
"nbfunc = potentials[-1]\n",
"energy = nbfunc(positions, box, pairs, ff.getGenerators()[-1].params)\n",
"print(energy)\n",
"print(pairs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also obtain the whole potential energy function and force field parameter set, instead of seperated functions for different energy terms."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efunc = ff.getPotentialFunc()\n",
"params = ff.getParameters()\n",
"totene = efunc(positions, box, pairs, params)\n",
"totene"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute forces and parametric gradients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use `jax.grad` to compute forces and parametric gradients automatically"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pos_grad_func = jax.grad(efunc, argnums=0)\n",
"force = -pos_grad_func(positions, box, pairs, params)\n",
"force.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"param_grad_func = jax.grad(nbfunc, argnums=-1)\n",
"pgrad = param_grad_func(positions, box, pairs, nbparam)\n",
"pgrad[\"charge\"]"
]
}
],
"metadata": {
"interpreter": {
"hash": "44fe82502fda871be637af1aa98d2b3ddaac01204dd30f1519cbec4e95000815"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit af4b1f2

Please sign in to comment.