Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add methods for QEq #184

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 225 additions & 48 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import jax.numpy as jnp
from ..common.constants import DIELECTRIC
from scipy import constants
from ..common.constants import DIELECTRIC, ENERGY_COEFF
from jax import grad, vmap
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
Expand All @@ -25,15 +26,17 @@
except ImportError:
JAXOPT_OLD = True
import warnings

warnings.warn(
"jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern."
)
except ImportError:
import warnings

warnings.warn("jaxopt not found, QEQ cannot be used.")
import jax

from jax.scipy.special import erf, erfc
from jax.scipy.special import erfc

from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales

Expand Down Expand Up @@ -126,13 +129,13 @@ def E_site(chi, J, q):

@jit_condition()
def E_site2(chi, J, q):
ene = (chi * q + 0.5 * J * q**2) * 96.4869
ene = (chi * q + 0.5 * J * q**2) * ENERGY_COEFF
return jnp.sum(ene)


@jit_condition()
def E_site3(chi, J, q):
ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
ene = chi * q * constants.calorie + J * q**2 * DIELECTRIC * 2 * jnp.pi
return jnp.sum(ene)


Expand Down Expand Up @@ -197,6 +200,22 @@ def etainv_piecewise(eta):
etainv_piecewise = jax.vmap(etainv_piecewise, in_axes=0)


def fn_value_and_proj_grad(func, constraint_matrix, has_aux=False):
def value_and_proj_grad(*arg, **kwargs):
value, grad = jax.value_and_grad(func, has_aux=has_aux)(*arg, **kwargs)
# n * 1
a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1))
# n * 1
b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True)
# 1 * N
delta_grad = jnp.matmul((a / b).T, constraint_matrix)
# N
proj_grad = grad - delta_grad.reshape(-1)
return value, proj_grad

return value_and_proj_grad


class ADMPQeqForce:
def __init__(
self,
Expand All @@ -212,6 +231,8 @@ def __init__(
constQ: bool = True,
pbc_flag: bool = True,
has_aux=False,
method="root_finding",
pgrad_kwargs={},
):
self.has_aux = has_aux
const_vals = np.array(const_vals)
Expand All @@ -231,6 +252,8 @@ def __init__(
self.slab_flag = slab_flag
self.constQ = constQ
self.pbc_flag = pbc_flag
self.method = method
self.pgrad_kwargs = pgrad_kwargs

if constQ:
e_constraint = E_constQ
Expand Down Expand Up @@ -291,7 +314,9 @@ def coul_energy(positions, box, pairs, q, mscales):
else:

def get_coul_energy(dr_vec, chrgprod, box):
dr_norm = jnp.linalg.norm(dr_vec + 1e-64, axis=1) # add eta to avoid division by zero
dr_norm = jnp.linalg.norm(
dr_vec + 1e-64, axis=1
) # add eta to avoid division by zero

dr_inv = 1.0 / dr_norm
E = chrgprod * DIELECTRIC * 0.1 * dr_inv
Expand Down Expand Up @@ -349,57 +374,209 @@ def E_grads(
g = jnp.concatenate((g1, g2))
return g

hess_E_charge = jax.jacfwd(jax.jacrev(E_full))

@jit_condition()
def E_hess(
b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
n_const = len(self.const_vals)
q = b_value[:-n_const]
lagmt = b_value[-n_const:]
h_q = hess_E_charge(
q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
hess = jnp.pad(h_q, (0, 1), constant_values=1.0)
hess = hess.at[-1, -1].set(0.0)
return hess

@jit_condition()
def E_no_constraint(
q, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales
):
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
e3 = self.e_site(chi, J, q)
e4 = self.coul_energy(pos, box, pairs, q, mscales)
if self.slab_flag:
e5 = E_corr(
pos * 10.0, box * 10.0, pairs, q, self.kappa / 10, self.neutral_flag
)
return e2 + e3 + e4 + e5
else:
return e2 + e3 + e4

def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
pos = positions
ds = ds_pairs(pos, box, pairs, self.pbc_flag)
buffer_scales = pair_buffer_scales(pairs)

n_const = len(self.init_lagmt)
if self.has_aux:
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
# if JAXOPT_OLD:
if True:
def get_energy_root_finding(
chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
if self.has_aux:
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
else:
rf = jaxopt.Broyden(fun=E_grads, tol=1e-10)
b_0, _ = rf.run(
b_value,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
b_0 = jax.lax.stop_gradient(b_0)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]

energy = E_full(
q_0,
lagmt_0,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
b_0, _ = rf.run(
b_value,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
b_0 = jax.lax.stop_gradient(b_0)
# return b_0
n_const = len(self.init_lagmt)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]

energy = E_full(
q_0,
lagmt_0,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
if self.has_aux:
aux["q"] = q_0
aux["lagmt"] = lagmt_0
return energy, aux
else:
return energy

def get_energy_mat_inv(
chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
if self.has_aux:
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
hessian = E_hess(
b_value,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
vector = jnp.concatenate([-chi * 4.184, self.const_vals])
# # E_site2
# vector = torch.concat([-chi / KJ2EV, self.const_vals])
b_0 = jnp.linalg.solve(hessian, vector.reshape(-1, 1)).reshape(-1)
b_0 = jax.lax.stop_gradient(b_0)

n_const = len(self.init_lagmt)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]

energy = E_full(
q_0,
lagmt_0,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
if self.has_aux:
aux["q"] = q_0
aux["lagmt"] = lagmt_0
return energy, aux
else:
return energy

def get_energy_pgrad(
chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
# if self.has_aux:
# init_q = aux["q"]
# else:
# init_q = self.init_q

init_q = jnp.zeros_like(self.init_q)
n_atoms = len(init_q)

def const_matrix(n_atoms: int, indices):
n_const = indices.shape[0]
ref_ids = jnp.tile(jnp.arange(n_atoms).reshape(1, -1), [n_const, 1])
mask = jnp.where(jnp.isin(ref_ids, indices), 1, 0)
return mask

# build the constraint matrix based on the const_list
# one at the index of the const_list, and zero otherwise
# n_const * n_atoms
constraint_matrix = const_matrix(n_atoms, self.const_list)
func = fn_value_and_proj_grad(
E_no_constraint,
constraint_matrix,
)
# tol in LBFGS: norm(grad)
solver = jaxopt.LBFGS(
fun=func,
value_and_grad=True,
tol=1e-3 * n_atoms,
**self.pgrad_kwargs,
)
res = solver.run(
init_q,
chi=chi,
J=J,
pos=positions,
box=box,
pairs=pairs,
eta=eta,
ds=ds,
buffer_scales=buffer_scales,
mscales=mscales,
)
q_0 = res.params
q_0 = jax.lax.stop_gradient(q_0)

energy = E_no_constraint(
q_0,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
if self.has_aux:
aux["q"] = q_0
return energy, aux
else:
return energy

opt_func = locals().get("get_energy_%s" % self.method)
if opt_func is None:
raise ValueError(f"method {self.method} is not supported")
return opt_func(
chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
if self.has_aux:
aux["q"] = q_0
aux["lagmt"] = lagmt_0
return energy, aux
else:
return energy

return get_energy
14 changes: 12 additions & 2 deletions dmff/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import numpy as np
from scipy import constants

DIELECTRIC = 1389.35455846
SQRT_PI = np.sqrt(np.pi)

SQRT_PI = np.sqrt(np.pi)

J2EV = constants.physical_constants["joule-electron volt relationship"][0]
# from kJ/mol to eV/particle
ENERGY_COEFF = J2EV * constants.kilo / constants.Avogadro

# vacuum electric permittivity in eV^-1 * angstrom^-1
EPSILON = constants.epsilon_0 / constants.elementary_charge * constants.angstrom
# DIELECTRIC = 1389.35455846
DIELECTRIC = 1 / (4 * np.pi * EPSILON) / ENERGY_COEFF
4 changes: 4 additions & 0 deletions dmff/generators/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def createPotential(
if "has_aux" in kwargs:
has_aux = kwargs["has_aux"]

method = kwargs.get("method", "root_finding")
pgrad_kwargs = kwargs.get("pgrad_kwargs", {})
qeq_force = ADMPQeqForce(
init_q,
r_cut,
Expand All @@ -189,6 +191,8 @@ def createPotential(
constQ=constQ,
pbc_flag=(not isNoCut),
has_aux=has_aux,
method=method,
pgrad_kwargs=pgrad_kwargs,
)
qeq_energy = qeq_force.generate_get_energy()

Expand Down
Loading
Loading