Skip to content

Commit

Permalink
update pgrad method in qeq
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu committed Jul 23, 2024
1 parent a04bc10 commit 2ec8f88
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 45 deletions.
68 changes: 55 additions & 13 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
except ImportError:
JAXOPT_OLD = True
import warnings

warnings.warn(
"jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern."
)
except ImportError:
import warnings

warnings.warn("jaxopt not found, QEQ cannot be used.")
import jax

from jax.scipy.special import erf, erfc
from jax.scipy.special import erfc

from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales

Expand Down Expand Up @@ -198,6 +200,22 @@ def etainv_piecewise(eta):
etainv_piecewise = jax.vmap(etainv_piecewise, in_axes=0)


def fn_value_and_proj_grad(func, constraint_matrix, has_aux=False):
def value_and_proj_grad(*arg, **kwargs):
value, grad = jax.value_and_grad(func, has_aux=has_aux)(*arg, **kwargs)
# n * 1
a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1))
# n * 1
b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True)
# 1 * N
delta_grad = jnp.matmul((a / b).T, constraint_matrix)
# N
proj_grad = grad - delta_grad.reshape(-1)
return value, proj_grad

return value_and_proj_grad


class ADMPQeqForce:
def __init__(
self,
Expand All @@ -214,6 +232,7 @@ def __init__(
pbc_flag: bool = True,
has_aux=False,
method="root_finding",
pgrad_kwargs={},
):
self.has_aux = has_aux
const_vals = np.array(const_vals)
Expand All @@ -234,6 +253,7 @@ def __init__(
self.constQ = constQ
self.pbc_flag = pbc_flag
self.method = method
self.pgrad_kwargs = pgrad_kwargs

if constQ:
e_constraint = E_constQ
Expand Down Expand Up @@ -294,7 +314,9 @@ def coul_energy(positions, box, pairs, q, mscales):
else:

def get_coul_energy(dr_vec, chrgprod, box):
dr_norm = jnp.linalg.norm(dr_vec + 1e-64, axis=1) # add eta to avoid division by zero
dr_norm = jnp.linalg.norm(
dr_vec + 1e-64, axis=1
) # add eta to avoid division by zero

dr_inv = 1.0 / dr_norm
E = chrgprod * DIELECTRIC * 0.1 * dr_inv
Expand Down Expand Up @@ -488,18 +510,37 @@ def get_energy_mat_inv(
def get_energy_pgrad(
chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
if self.has_aux:
init_q = aux["q"]
else:
init_q = self.init_q
pg = jaxopt.ProjectedGradient(
fun=E_no_constraint,
projection=jaxopt.projection.projection_hyperplane,
tol=1e-2,
# if self.has_aux:
# init_q = aux["q"]
# else:
# init_q = self.init_q

init_q = jnp.zeros_like(self.init_q)
n_atoms = len(init_q)

def const_matrix(n_atoms: int, indices):
n_const = indices.shape[0]
ref_ids = jnp.tile(jnp.arange(n_atoms).reshape(1, -1), [n_const, 1])
mask = jnp.where(jnp.isin(ref_ids, indices), 1, 0)
return mask

# build the constraint matrix based on the const_list
# one at the index of the const_list, and zero otherwise
# n_const * n_atoms
constraint_matrix = const_matrix(n_atoms, self.const_list)
func = fn_value_and_proj_grad(
E_no_constraint,
constraint_matrix,
)
q_0, _ = pg.run(
# tol in LBFGS: norm(grad)
solver = jaxopt.LBFGS(
fun=func,
value_and_grad=True,
tol=1e-3 * n_atoms,
**self.pgrad_kwargs,
)
res = solver.run(
init_q,
hyperparams_proj=(jnp.ones_like(init_q), 0.0),
chi=chi,
J=J,
pos=positions,
Expand All @@ -510,8 +551,9 @@ def get_energy_pgrad(
buffer_scales=buffer_scales,
mscales=mscales,
)
q_0 = res.params
q_0 = jax.lax.stop_gradient(q_0)

energy = E_no_constraint(
q_0,
chi,
Expand Down
3 changes: 2 additions & 1 deletion dmff/generators/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def createPotential(
has_aux = kwargs["has_aux"]

method = kwargs.get("method", "root_finding")

pgrad_kwargs = kwargs.get("pgrad_kwargs", {})
qeq_force = ADMPQeqForce(
init_q,
r_cut,
Expand All @@ -192,6 +192,7 @@ def createPotential(
pbc_flag=(not isNoCut),
has_aux=has_aux,
method=method,
pgrad_kwargs=pgrad_kwargs,
)
qeq_energy = qeq_force.generate_get_energy()

Expand Down
87 changes: 56 additions & 31 deletions tests/test_admp/test_qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,34 @@ def test_qeq_energy():

nblist = NeighborList(box, 0.6, dmfftop.buildCovMat())
pairs = nblist.allocate(pos)

pot = hamilt.createPotential(dmfftop, nonbondedCutoff=0.6*unit.nanometer, nonbondedMethod=app.PME,
ethresh=1e-3, neutral=True, slab=True, constQ=True
)
efunc = pot.getPotentialFunc()
energy = efunc(pos, box, pairs, hamilt.paramset.parameters)
np.testing.assert_almost_equal(energy, -37.84692763, decimal=3)
for method in ["root_finding", "mat_inv", "pgrad"]:
pot = hamilt.createPotential(
dmfftop,
nonbondedCutoff=0.6 * unit.nanometer,
nonbondedMethod=app.PME,
ethresh=1e-3,
neutral=True,
slab=True,
constQ=True,
method=method,
)
efunc = pot.getPotentialFunc()
energy = efunc(pos, box, pairs, hamilt.paramset.parameters)
np.testing.assert_almost_equal(energy, -37.84692763, decimal=3)


def test_qeq_energy_2res():
rc = 0.6
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
res = xml.parseResidues()
charges = [a["charge"] for a in res[0]["particles"]] + [a["charge"] for a in res[1]["particles"]]
charges = [a["charge"] for a in res[0]["particles"]] + [
a["charge"] for a in res[1]["particles"]
]
charges = np.zeros((len(charges),))
types = [a["type"] for a in res[0]["particles"]] + [a["type"] for a in res[1]["particles"]]
types = [a["type"] for a in res[0]["particles"]] + [
a["type"] for a in res[1]["particles"]
]

pdb = app.PDBFile("tests/data/qeq2.pdb")
top = pdb.topology
Expand Down Expand Up @@ -74,33 +85,41 @@ def test_qeq_energy_2res():
const_list[-1].append(ii)
const_val = [0.0, 0.0]

pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME,
ethresh=1e-3, neutral=True, slab=True, constQ=True,
const_list=const_list, const_vals=const_val,
has_aux=True
)
pot = hamilt.createPotential(
dmfftop,
nonbondedCutoff=rc * unit.nanometer,
nonbondedMethod=app.PME,
ethresh=1e-3,
neutral=True,
slab=True,
constQ=True,
const_list=const_list,
const_vals=const_val,
has_aux=True,
)
efunc = pot.getPotentialFunc()
aux = {
"q": jnp.array(charges),
"lagmt": jnp.array([1.0, 1.0])
}
aux = {"q": jnp.array(charges), "lagmt": jnp.array([1.0, 1.0])}
energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(aux)
# print(aux)
np.testing.assert_almost_equal(energy, 4817.295171, decimal=2)

grad = jax.grad(efunc, argnums=0, has_aux=True)
gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(gradient)
# print(gradient)


def _test_qeq_energy_2res_jit():
rc = 0.6
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
res = xml.parseResidues()
charges = [a["charge"] for a in res[0]["particles"]] + [a["charge"] for a in res[1]["particles"]]
charges = [a["charge"] for a in res[0]["particles"]] + [
a["charge"] for a in res[1]["particles"]
]
charges = np.zeros((len(charges),))
types = [a["type"] for a in res[0]["particles"]] + [a["type"] for a in res[1]["particles"]]
types = [a["type"] for a in res[0]["particles"]] + [
a["type"] for a in res[1]["particles"]
]

pdb = app.PDBFile("tests/data/qeq2.pdb")
top = pdb.topology
Expand Down Expand Up @@ -128,17 +147,23 @@ def _test_qeq_energy_2res_jit():
const_list[-1].append(ii)
const_val = [0.0, 0.0]

pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME,
ethresh=1e-3, neutral=True, slab=True, constQ=True,
const_list=const_list, const_vals=const_val,
has_aux=True
)
for method in ["root_finding", "mat_inv", "pgrad"]:
pot = hamilt.createPotential(
dmfftop,
nonbondedCutoff=rc * unit.nanometer,
nonbondedMethod=app.PME,
ethresh=1e-3,
neutral=True,
slab=True,
constQ=True,
const_list=const_list,
const_vals=const_val,
has_aux=True,
method=method,
)
efunc = jax.jit(pot.getPotentialFunc())
grad = jax.jit(jax.grad(efunc, argnums=0, has_aux=True))
aux = {
"q": jnp.array(charges),
"lagmt": jnp.array([1.0, 1.0])
}
aux = {"q": jnp.array(charges), "lagmt": jnp.array([1.0, 1.0])}
print("Start computing energy and force.")
energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(aux)
Expand Down

0 comments on commit 2ec8f88

Please sign in to comment.