Skip to content

Commit

Permalink
Merge branch 'wangxy/v1.0.0-devel' of github.com:deepmodeling/DMFF in…
Browse files Browse the repository at this point in the history
…to wangxy/v1.0.0-devel
  • Loading branch information
KuangYu committed Oct 22, 2023
2 parents 10dbff8 + 5acacff commit 223ac95
Show file tree
Hide file tree
Showing 10 changed files with 860 additions and 14 deletions.
2 changes: 1 addition & 1 deletion dmff/admp/multipole.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

import jax.numpy as jnp
from dmff.utils import jit_condition
from ..utils import jit_condition
from jax import vmap

# This module deals with the transformations and rotations of multipoles
Expand Down
4 changes: 2 additions & 2 deletions dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import partial

import jax.numpy as jnp
from dmff.admp.spatial import v_pbc_shift
from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
from .spatial import v_pbc_shift
from ..utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import vmap

DIELECTRIC = 1389.35455846
Expand Down
2 changes: 1 addition & 1 deletion dmff/admp/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from collections import defaultdict
import jax.numpy as jnp
from dmff.admp.multipole import convert_cart2harm
from .multipole import convert_cart2harm

def read_atom_line(line_full):
"""
Expand Down
16 changes: 8 additions & 8 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@
from jax import grad, value_and_grad, vmap, jit
from jax.scipy.special import erf, erfc

from dmff.settings import DO_JIT
from dmff.common.constants import DIELECTRIC
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.settings import POL_CONV, MAX_N_POL
from dmff.admp.recip import generate_pme_recip, Ck_1
from dmff.admp.multipole import (
from ..settings import DO_JIT
from ..common.constants import DIELECTRIC
from ..utils import jit_condition, regularize_pairs, pair_buffer_scales
from .settings import POL_CONV, MAX_N_POL
from .recip import generate_pme_recip, Ck_1
from .multipole import (
C1_c2h,
convert_cart2harm,
rot_ind_global2local,
rot_global2local,
rot_local2global
)
from dmff.admp.spatial import (
from .spatial import (
v_pbc_shift,
generate_construct_local_frames,
build_quasi_internal
)
from dmff.admp.pairwise import (
from .pairwise import (
distribute_scalar,
distribute_v3,
distribute_multipoles,
Expand Down
311 changes: 311 additions & 0 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
import numpy as np
import jax.numpy as jnp
from ..common.constants import DIELECTRIC
from jax import grad, vmap
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
from ..settings import PRECISION

if PRECISION == "double":
CONST_0 = jnp.array(0, dtype=jnp.float64)
CONST_1 = jnp.array(1, dtype=jnp.float64)
else:
CONST_0 = jnp.array(0, dtype=jnp.float32)
CONST_1 = jnp.array(1, dtype=jnp.float32)

try:
import jaxopt
except ImportError:
print("jaxopt not found, QEQ cannot be used.")
import jax

from jax.scipy.special import erf, erfc

from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales


@jit_condition()
def group_sum(val_list, indices):
max_idx = indices.max()
exceed = jnp.piecewise(
indices,
[indices < max_idx, indices >= max_idx],
[lambda x: CONST_1, lambda x: CONST_0],
)
return jnp.sum(val_list[indices] * exceed)


group_sum_vmap = jax.vmap(group_sum, in_axes=(None, 0))


# @jit_condition
def padding_consts(const_list, max_idx):
max_length = max([len(i) for i in const_list])
new_const_list = np.zeros((len(const_list), max_length)) + max_idx
for ncl, cl in enumerate(const_list):
for nitem, item in enumerate(cl):
new_const_list[ncl, nitem] = item
return jnp.array(new_const_list, dtype=int)


@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
constraint = (group_sum_vmap(q, const_list) - const_vals) * lagmt
return jnp.sum(constraint)


@jit_condition()
def E_constP(q, lagmt, const_list, const_vals):
constraint = group_sum_vmap(q, const_list) * const_vals
return jnp.sum(constraint)


@vmap
@jit_condition()
def mask_to_zero(v, mask):
return jnp.piecewise(
v, [mask < 1e-5, mask >= 1e-5], [lambda x: CONST_0, lambda x: v]
)


@jit_condition()
def E_sr(pos, box, pairs, q, eta, ds, buffer_scales):
return 0.0


@jit_condition()
def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2))
pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds * buffer_scales
e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales)
e_sr_self = pre_self * q * q
e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self)
return e_sr


@jit_condition()
def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
etasqrt = jnp.sqrt(eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2)
epiece = eta_piecewise(etasqrt, ds)
pre_pair = -epiece * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds
e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales)
e_sr_self = pre_self * q * q
e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self)
return e_sr


@jit_condition()
def E_site(chi, J, q):
return 0.0


@jit_condition()
def E_site2(chi, J, q):
ene = (chi * q + 0.5 * J * q**2) * 96.4869
return jnp.sum(ene)


@jit_condition()
def E_site3(chi, J, q):
ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
return jnp.sum(ene)


@jit_condition(static_argnums=[5])
def E_corr(pos, box, pairs, q, kappa, neutral_flag=True):
# def E_corr():
V = jnp.linalg.det(box)
pre_corr = 2 * jnp.pi / V * DIELECTRIC
Mz = jnp.sum(q * pos[:, 2])
Q_tot = jnp.sum(q)
Lz = jnp.linalg.norm(box[3])
e_corr = pre_corr * (
Mz**2
- Q_tot * (jnp.sum(q * pos[:, 2] ** 2))
- jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12
)
if neutral_flag:
# kappa = pme_potential.pme_force.kappa
pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC
e_corr_non = pre_corr_non * Q_tot**2
e_corr += e_corr_non
return jnp.sum(e_corr)


@jit_condition
def E_CoulNocutoff(pos, box, pairs, q, ds):
e = q[pairs[:, 0]] * q[pairs[:, 1]] / ds * DIELECTRIC
return jnp.sum(e)


@jit_condition
def E_Coul(pos, box, pairs, q, ds):
return 0.0


@jit_condition(static_argnums=[3])
def ds_pairs(positions, box, pairs, pbc_flag):
pos1 = positions[pairs[:, 0]]
pos2 = positions[pairs[:, 1]]
if pbc_flag is False:
dr = pos1 - pos2
else:
box_inv = jnp.linalg.inv(box)
dpos = pos1 - pos2
dpos = dpos.dot(box_inv)
dpos -= jnp.floor(dpos + 0.5)
dr = dpos.dot(box)
ds = jnp.linalg.norm(dr, axis=1)
return ds


@jit_condition()
def eta_piecewise(eta, ds):
return jnp.piecewise(
eta,
(eta > 1e-4, eta <= 1e-4),
(lambda x: erfc(ds / x), lambda x: x - x),
)


eta_piecewise = jax.vmap(eta_piecewise, in_axes=(0, 0))


@jit_condition()
def etainv_piecewise(eta):
return jnp.piecewise(
eta,
(eta > 1e-4, eta <= 1e-4),
(lambda x: 1 / x, lambda x: x - x),
)


etainv_piecewise = jax.vmap(etainv_piecewise, in_axes=0)


class ADMPQeqForce:
def __init__(
self,
init_q,
r_cut: float,
kappa: float,
K: Tuple[int, int, int],
damp_mod: int = 3,
const_list: List = [],
const_vals: List = [],
neutral_flag: bool = True,
slab_flag: bool = False,
constQ: bool = True,
pbc_flag: bool = True,
):
if not isinstance(const_vals, jnp.ndarray):
self.const_vals = jnp.array(const_vals)
else:
self.const_vals = const_vals
assert len(const_list) == len(
const_vals
), "const_list and const_vals must have the same length"
n_atoms = len(init_q)
self.const_list = padding_consts(const_list, n_atoms)
self.init_q = jnp.array(init_q)
self.init_lagmt = jnp.ones((len(const_list),))

self.damp_mod = damp_mod
self.neutral_flag = neutral_flag
self.slab_flag = slab_flag
self.constQ = constQ
self.pbc_flag = pbc_flag

if constQ:
e_constraint = E_constQ
else:
e_constraint = E_constP
self.e_constraint = e_constraint

if damp_mod == 1:
self.e_sr = E_sr
self.e_site = E_site
elif damp_mod == 2:
self.e_sr = E_sr2
self.e_site = E_site2
elif damp_mod == 3:
self.e_sr = E_sr3
self.e_site = E_site3
else:
raise ValueError("damp_mod must be 1, 2 or 3")

if pbc_flag:
force = CoulombPMEForce(r_cut, kappa, K)
self.kappa = kappa
else:
force = CoulNoCutoffForce()
self.kappa = 1.0
self.coul_energy = force.generate_get_energy()

def generate_get_energy(self):
@jit_condition()
def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
e3 = self.e_site(chi, J, q)
e4 = self.coul_energy(pos, box, pairs, q, mscales)
e5 = E_corr(
pos * 10.0, box * 10.0, pairs, q, self.kappa / 10, self.neutral_flag
)
return e1 + e2 + e3 + e4 + e5

grad_E_full = grad(E_full, argnums=(0, 1))

@jit_condition()
def E_grads(
b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
n_const = len(self.const_vals)
q = b_value[:-n_const]
lagmt = b_value[-n_const:]

g1, g2 = grad_E_full(
q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
g = jnp.concatenate((g1, g2))
return g

def get_energy(positions, box, pairs, mscales, eta, chi, J):
pos = positions
ds = ds_pairs(pos, box, pairs, self.pbc_flag)
buffer_scales = pair_buffer_scales(pairs)

n_const = len(self.init_lagmt)
b_value = jnp.concatenate((self.init_q, self.init_lagmt))
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
b_0, _ = rf.run(
b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
b_0 = jax.lax.stop_gradient(b_0)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]
print("Q:", q_0)
print("Lagrange_multi:", lagmt_0)

energy = E_full(
q_0,
lagmt_0,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
return energy

return get_energy
1 change: 1 addition & 0 deletions dmff/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .classical import *
from .admp import *
from .ml import *
from .qeq import *
Loading

0 comments on commit 223ac95

Please sign in to comment.