Skip to content

Commit

Permalink
Merge pull request #20 from KuangYu/master
Browse files Browse the repository at this point in the history
Merge devel to master
  • Loading branch information
KuangYu authored May 5, 2022
2 parents 7fbcf9c + 31f1b87 commit df21c94
Show file tree
Hide file tree
Showing 150 changed files with 64,439 additions and 501 deletions.
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@

# temporary
err
out
sub.sh
*.npy

### C++ ###
# Prerequisites
*.d
Expand Down Expand Up @@ -463,7 +470,6 @@ StyleCopReport.xml
*.ilk
*.meta
*.iobj
*.pdb
*.ipdb
*.pgc
*.pgd
Expand Down Expand Up @@ -769,3 +775,4 @@ FodyWeavers.xsd

### VisualStudio Patch ###
# Additional files built by Visual Studio
.vscode/**
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,27 @@
# DMFF
Differentiable Molecular Force Field

**DMFF** (**D**ifferentiable **M**olecular **F**orce **F**ield) is a Jax-based python package that provides a full differentiable implementation of molecular force field models. This project aims to establish an extensible codebase to minimize the efforts in force field parameterization, and to ease the force and virial tensor evaluations for advanced complicated potentials (e.g., polarizable models with geometry-dependent atomic parameters). Currently, this project mainly focuses on the molecular systems such as: water, biological macromolecules (peptides, proteins, nucleic acids), organic polymers, and small organic molecules (organic electrolyte, drug-like molecules) etc. We support both the conventional point charge models (OPLS and AMBER like) and multipolar polarizable models (AMOEBA and MPID like). The entire project is backed by the XLA technique in JAX, thus can be "jitted" and run in GPU devices much more efficiently compared to normal python codes.

The behavior of organic molecular systems (e.g., protein folding, polymer structure, etc.) is often determined by a complex effect of many different types of interactions. The existing organic molecular force fields are mainly empirically fitted and their performance relies heavily on error cancellation. Therefore, the transferability and the prediction power of these force fields are insufficient. For new molecules, the parameter fitting process requires essential manual intervention and can be quite cumbersome. In order to automate the parametrization process and increase the robustness of the model, it is necessary to apply modern AI techniques in conventional force field development. This project serves for this purpose by utilizing the automatic differentiable programming technique to develop a codebase, which allows a more convenient incorporation of modern AI optimization techniques. It also helps the realization of many exciting functions including (but not limited to): hybrid machine learning/force field models and parameter optimization based on trajectory.

## User Guide

+ [1. Introduction](user_guide/introduction.md)
+ [2. Installation](user_guide/installation.md)
+ [3. Compute energy and forces](user_guide/compute.md)
+ [4. Compute gradients with auto differentiable framework](user_guide/auto_diff.md)
+ [5. Theories](user_guide/theory.md)
+ [6. Introduction to force field xml files](user_guide/xml_spec.md)

## Developer Guide
+ [1. Introduction](dev_guide/introduction.md)
+ [2. Architecture](dev_guide/arch.md)
+ [3. Convention](dev_guide/convention.md)

## Modules
+ [1. ADMP](modules/admp.md)


## Support and Contribution

Please visit our repository on [GitHub](https://github.com/deepmodeling/DMFF) for the library source code. Any issues or bugs may be reported at our issue tracker. All contributions to DMFF are welcomed via pull requests!
147 changes: 41 additions & 106 deletions dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax.numpy as jnp
from jax import vmap, value_and_grad
from dmff.utils import jit_condition
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import pbc_shift
from dmff.admp.pme import setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
Expand All @@ -14,17 +14,24 @@ class ADMPDispPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''

def __init__(self, box, covalent_map, rc, ethresh, pmax):
def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True):
self.covalent_map = covalent_map
self.rc = rc
self.ethresh = ethresh
self.pmax = pmax
# Need a different function for dispersion ??? Need tests
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
self.kappa = kappa
self.K1 = K1
self.K2 = K2
self.K3 = K3
self.lpme = lpme
if lpme:
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
self.kappa = kappa
self.K1 = K1
self.K2 = K2
self.K3 = K3
else:
self.kappa = 0.0
self.K1 = 0
self.K2 = 0
self.K3 = 0
self.pme_order = 6
# setup calculators
self.refresh_calculators()
Expand All @@ -36,7 +43,7 @@ def get_energy(positions, box, pairs, c_list, mScales):
return energy_disp_pme(positions, box, pairs,
c_list, mScales, self.covalent_map,
self.kappa, self.K1, self.K2, self.K3, self.pmax,
self.d6_recip, self.d8_recip, self.d10_recip)
self.d6_recip, self.d8_recip, self.d10_recip, lpme=self.lpme)
return get_energy


Expand Down Expand Up @@ -70,7 +77,7 @@ def refresh_calculators(self):
def energy_disp_pme(positions, box, pairs,
c_list, mScales, covalent_map,
kappa, K1, K2, K3, pmax,
recip_fn6, recip_fn8, recip_fn10):
recip_fn6, recip_fn8, recip_fn10, lpme=True):
'''
Top level wrapper for dispersion pme
Expand All @@ -95,22 +102,29 @@ def energy_disp_pme(positions, box, pairs,
int: max K for reciprocal calculations
pmax:
int array: maximal exponents (p) to compute, e.g., (6, 8, 10)
lpme:
bool: whether do pme or not, useful when doing cluster calculations
Output:
energy: total dispersion pme energy
'''

ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)
if lpme is False:
kappa = 0

ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
if pmax >= 8:
ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
if pmax >= 10:
ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)

ene_self = disp_pme_self(c_list, kappa, pmax)
if lpme:
ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
if pmax >= 8:
ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
if pmax >= 10:
ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
ene_self = disp_pme_self(c_list, kappa, pmax)
return ene_real + ene_recip + ene_self

return ene_real + ene_recip + ene_self
else:
return ene_real


def disp_pme_real(positions, box, pairs,
Expand Down Expand Up @@ -144,24 +158,26 @@ def disp_pme_real(positions, box, pairs,
'''

# expand pairwise parameters
pairs = pairs[pairs[:, 0] < pairs[:, 1]]
# pairs = pairs[pairs[:, 0] < pairs[:, 1]]
pairs = regularize_pairs(pairs)

box_inv = jnp.linalg.inv(box)

ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
# ri = positions[pairs[:, 0]]
# rj = positions[pairs[:, 1]]
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
mscales = distribute_scalar(mScales, nbonds-1)
# mscales = mScales[nbonds-1]

buffer_scales = pair_buffer_scales(pairs)
mscales = mscales * buffer_scales

ci = distribute_dispcoeff(c_list, pairs[:, 0])
cj = distribute_dispcoeff(c_list, pairs[:, 1])
# ci = c_list[pairs[:, 0], :]
# cj = c_list[pairs[:, 1], :]

ene_real = jnp.sum(disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax))
ene_real = jnp.sum(
disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax)
* buffer_scales
)

return jnp.sum(ene_real)

Expand Down Expand Up @@ -193,6 +209,7 @@ def disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax):
dr = ri - rj
dr = pbc_shift(dr, box, box_inv)
dr2 = jnp.dot(dr, dr)

x2 = kappa * kappa * dr2
g = g_p(x2, pmax)
dr6 = dr2 * dr2 * dr2
Expand Down Expand Up @@ -269,85 +286,3 @@ def disp_pme_self(c_list, kappa, pmax):
return E


# def validation(pdb):
# xml = 'mpidwater.xml'
# pdbinfo = read_pdb(pdb)
# serials = pdbinfo['serials']
# names = pdbinfo['names']
# resNames = pdbinfo['resNames']
# resSeqs = pdbinfo['resSeqs']
# positions = pdbinfo['positions']
# box = pdbinfo['box'] # a, b, c, α, β, γ
# charges = pdbinfo['charges']
# positions = jnp.asarray(positions)
# lx, ly, lz, _, _, _ = box
# box = jnp.eye(3)*jnp.array([lx, ly, lz])

# mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
# pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
# dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])

# rc = 4 # in Angstrom
# ethresh = 1e-4

# n_atoms = len(serials)

# atomTemplate, residueTemplate = read_xml(xml)
# atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)

# covalent_map = assemble_covalent(residueDicts, n_atoms)
# displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
# neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
# nbr = neighbor_list_fn.allocate(positions)
# pairs = nbr.idx.T

# pmax = 10
# kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
# kappa = 0.657065221219616

# # construct the C list
# c_list = np.zeros((3,n_atoms))
# nmol=int(n_atoms/3)
# for i in range(nmol):
# a = i*3
# b = i*3+1
# c = i*3+2
# c_list[0][a]=37.19677405
# c_list[0][b]=7.6111103
# c_list[0][c]=7.6111103
# c_list[1][a]=85.26810658
# c_list[1][b]=11.90220148
# c_list[1][c]=11.90220148
# c_list[2][a]=134.44874488
# c_list[2][b]=15.05074749
# c_list[2][c]=15.05074749
# c_list = jnp.array(c_list.T)


# # Finish data preparation
# # -------------------------------------------------------------------------------------
# # pme_order = 6
# # d6_recip = generate_pme_recip(Ck_6, kappa, True, pme_order, K1, K2, K3, 0)
# # d8_recip = generate_pme_recip(Ck_8, kappa, True, pme_order, K1, K2, K3, 0)
# # d10_recip = generate_pme_recip(Ck_10, kappa, True, pme_order, K1, K2, K3, 0)
# # disp_pme_recip_fns = [d6_recip, d8_recip, d10_recip]
# # energy_force_disp_pme = value_and_grad(energy_disp_pme)
# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
# # print('ok')
# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
# # print(e)

# disp_pme_force = ADMPDispPmeForce(box, covalent_map, rc, ethresh, pmax)
# disp_pme_force.update_env('kappa', 0.657065221219616)

# print(c_list[:4])
# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
# print('ok')
# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
# print(E)
# return


# # below is the validation code
# if __name__ == '__main__':
# validation(sys.argv[1])
Loading

0 comments on commit df21c94

Please sign in to comment.