From 66d2eb90c676e40f5288d4f59f8d2c9fc7a92ad0 Mon Sep 17 00:00:00 2001 From: Haichao Huang <100009402+gust-07@users.noreply.github.com> Date: Fri, 20 Oct 2023 14:46:41 +0800 Subject: [PATCH 1/3] qeq merge (#124) * add qeqforce and QeqQenerator, modify CoulmbGenerator * ethresh modified * add refresh in qeq.py --- dmff/admp/qeq.py | 213 ++++++++++++++++++++++++++++++++ dmff/generators/QeqGenerator.py | 135 ++++++++++++++++++++ dmff/generators/classical.py | 7 +- tests/data/qeq.pdb | 149 ++++++++++++++++++++++ tests/data/qeq.xml | 197 +++++++++++++++++++++++++++++ 5 files changed, 699 insertions(+), 2 deletions(-) create mode 100644 dmff/admp/qeq.py create mode 100644 dmff/generators/QeqGenerator.py create mode 100644 tests/data/qeq.pdb create mode 100644 tests/data/qeq.xml diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py new file mode 100644 index 000000000..0c3e0cf77 --- /dev/null +++ b/dmff/admp/qeq.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +import sys +import absl +import numpy as np +import jax.numpy as jnp +import openmm.app as app +import openmm.unit as unit +from dmff.settings import DO_JIT +from dmff.common.constants import DIELECTRIC +from dmff.common import nblist +from jax_md import space, partition +from jax import grad, value_and_grad, vmap, jit +from jaxopt import OptaxSolver +from itertools import combinations +import jaxopt +import jax +import scipy +import pickle + +from jax.scipy.special import erf, erfc + +from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales + + +jax.config.update("jax_enable_x64", True) + +class ADMPQeqForce: + + def __init__(self, q, lagmt, damp_mod=3, neutral_flag=True, slab_flag=False, constQ=True, pbc_flag = True): + + self.damp_mod = damp_mod + self.neutral_flag = neutral_flag + self.slab_flag = slab_flag + self.constQ = constQ + self.pbc_flag = pbc_flag + self.q = q + self.lagmt = lagmt + return + + def generate_get_energy(self): + # q = self.q + damp_mod = self.damp_mod + neutral_flag = self.neutral_flag + constQ = self.constQ + pbc_flag = self.pbc_flag + # lagmt = self.lagmt + + if eval(constQ) is True: + e_constraint = E_constQ + else: + e_constraint = E_constP + self.e_constraint = e_constraint + + if eval(damp_mod) is False: + e_sr = E_sr0 + e_site = E_site + elif eval(damp_mod) == 2: + e_sr = E_sr2 + e_site = E_site2 + elif eval(damp_mod) == 3: + e_sr = E_sr3 + e_site = E_site3 + + # if pbc_flag is False: + # e_coul = E_CoulNocutoff + # else: + # e_coul = E_coul + def get_energy(positions, box, pairs, q, lagmt, eta, chi, J, const_list, const_vals,pme_generator): + + pos = positions + ds = ds_pairs(pos, box, pairs, pbc_flag) + buffer_scales = pair_buffer_scales(pairs) + kappa = pme_generator.coulforce.kappa + def E_full(q, lagmt, const_vals, chi, J, pos, box, pairs, eta, ds, buffer_scales): + e1 = e_constraint(q, lagmt, const_list, const_vals) + e2 = e_sr(pos*10, box*10 ,pairs , q , eta, ds*10, buffer_scales) + e3 = e_site( chi, J , q) + e4 = pme_generator.coulenergy(pos, box ,pairs, q, pme_generator.mscales_coul) + e5 = E_corr(pos*10, box*10, pairs, q, kappa/10, neutral_flag) + return e1 + e2 + e3 + e4 + e5 + @jit + def E_grads(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales): + n_const = len(const_vals) + q = b_value[:-n_const] + lagmt = b_value[-n_const:] + g1,g2 = grad(E_full,argnums=(0,1))(q, lagmt, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) + g = jnp.concatenate((g1,g2)) + return g + + def Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales): + rf=jaxopt.ScipyRootFinding(optimality_fun=E_grads,method='hybr',jit=False,tol=1e-10) + q0,state1 = rf.run(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) + return q0,state1 + + def get_chgs(): + n_const = len(self.lagmt) + b_value = jnp.concatenate((self.q,self.lagmt)) + q0,state1 = Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) + self.q = q0[:-n_const] + self.lagmt = q0[-n_const:] + return q0,state1 + + q0,state1 = get_chgs() + self.q0 = q0 + self.state1 = state1 + energy = E_full(self.q, self.lagmt, const_vals, chi, J, positions, box, pairs, eta, ds , buffer_scales) + self.e_grads = E_grads(q0, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) + self.e_full = E_full + return energy + + return get_energy + def update_env(self, attr, val): + ''' + Update the environment of the calculator + ''' + setattr(self, attr, val) + self.refresh_calculators() + + + def refresh_calculators(self): + ''' + refresh the energy and force calculators according to the current environment + ''' + # generate the force calculator + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) + return + +def E_constQ(q, lagmt, const_list, const_vals): + constraint = (jnp.sum(q[const_list], axis=1) - const_vals) * lagmt + return np.sum(constraint) +def E_constP(q, lagmt, const_list, const_vals): + constraint = jnp.sum(q[const_list], axis=1) * const_vals + return np.sum(constraint) + +def E_sr(pos, box, pairs, q, eta, ds, buffer_scales ): + return 0 +def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales ): + etasqrt = jnp.sqrt( 2 * ( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2)) + pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC + pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC + e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales + e_sr_self = pre_self * q * q + e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self) + return e_sr +def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales ): + etasqrt = jnp.sqrt( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2 ) + pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC + pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC + e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales + e_sr_self = pre_self * q * q + e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self) + return e_sr + +def E_site(chi, J , q ): + return 0 +def E_site2(chi, J , q ): + ene = (chi * q + 0.5 * J * q **2 ) * 96.4869 + return np.sum(ene) +def E_site3(chi, J , q ): + ene = chi * q *4.184 + J * q **2 *DIELECTRIC * 2 * jnp.pi + return np.sum(ene) + +def E_corr(pos, box, pairs, q, kappa, neutral_flag = True): + # def E_corr(): + V = jnp.linalg.det(box) + pre_corr = 2 * jnp.pi / V * DIELECTRIC + Mz = jnp.sum(q * pos[:,2]) + Q_tot = jnp.sum(q) + Lz = jnp.linalg.norm(box[3]) + e_corr = pre_corr * (Mz **2 - Q_tot * (jnp.sum(q * pos[:,2] **2)) - Q_tot **2 * Lz **2 /12) + if eval(neutral_flag) is True: + # kappa = pme_potential.pme_force.kappa + pre_corr_non = - jnp.pi / (2 * V * kappa **2) * DIELECTRIC + e_corr_non = pre_corr_non * Q_tot **2 + e_corr += e_corr_non + return np.sum( e_corr) + +def E_CoulNocutoff(pos, box, pairs, q, ds): + e = q[pairs[:,0]] * q[pairs[:,1]] /ds * DIELECTRIC + return jnp.sum(e) + +def E_Coul(pos, box, pairs, q, ds): + return 0 + +@jit_condition(static_argnums=(3)) +def ds_pairs(positions, box, pairs, pbc_flag): + pos1 = positions[pairs[:,0].astype(int)] + pos2 = positions[pairs[:,1].astype(int)] + if pbc_flag is False: + dr = pos1 - pos2 + else: + box_inv = jnp.linalg.inv(box) + dpos = pos1 - pos2 + dpos = dpos.dot(box_inv) + dpos -= jnp.floor(dpos+0.5) + dr = dpos.dot(box) + ds = jnp.linalg.norm(dr,axis=1) + return ds + +@jit_condition() +@vmap +def eta_piecewise(eta,ds): + return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4), + (lambda x: jnp.array(erfc( ds / eta)), lambda x:jnp.array(0))) + +@jit_condition() +@vmap +def etainv_piecewise(eta): + return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4), + (lambda x: jnp.array(1/eta), lambda x:jnp.array(0))) + + diff --git a/dmff/generators/QeqGenerator.py b/dmff/generators/QeqGenerator.py new file mode 100644 index 000000000..aba48d0f8 --- /dev/null +++ b/dmff/generators/QeqGenerator.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python + +import openmm.app as app +import openmm.unit as unit +from typing import Tuple +import numpy as np +import jax.numpy as jnp +import jax +from dmff.api.topology import DMFFTopology +from dmff.api.paramset import ParamSet +from dmff.api.xmlio import XMLIO +from dmff.api.hamiltonian import _DMFFGenerators +from dmff.utils import DMFFException, isinstance_jnp +from dmff.admp.qeq import ADMPQeqForce +from dmff.generators.classical import CoulombGenerator +from dmff.admp import qeq + + +class ADMPQeqGenerator: + def __init__(self, ffinfo:dict, paramset: ParamSet): + + self.name = 'ADMPQeqForce' + self.ffinfo = ffinfo + paramset.addField(self.name) + self.key_type = None + keys , params = [], [] + for node in self.ffinfo["Forces"][self.name]["node"]: + attribs = node["attrib"] + + if self.key_type is None and "type" in attribs: + self.key_type = "type" + elif self.key_type is None and "class" in attribs: + self.key_type = "class" + elif self.key_type is not None and f"{self.key_type}" not in attribs: + raise ValueError("Keyword 'class' or 'type' cannot be used together.") + elif self.key_type is not None and f"{self.key_type}" in attribs: + pass + else: + raise ValueError("Cannot find key type for ADMPQeqForce.") + key = attribs[self.key_type] + keys.append(key) + + chi0 = float(attribs["chi"]) + J0 = float(attribs["J"]) + eta0 = float(attribs["eta"]) + + params.append([chi0, J0, eta0]) + + self.keys = keys + chi = jnp.array([i[0] for i in params]) + J = jnp.array([i[1] for i in params]) + eta = jnp.array([i[2] for i in params]) + + paramset.addParameter(chi, "chi", field=self.name) + paramset.addParameter(J, "J", field=self.name) + paramset.addParameter(eta, "eta", field=self.name) + # default params + self._jaxPotential = None + self.damp_mod = self.ffinfo["Forces"][self.name]["meta"]["DampMod"] + self.neutral_flag = self.ffinfo["Forces"][self.name]["meta"]["NeutralFlag"] + self.slab_flag = self.ffinfo["Forces"][self.name]["meta"]["SlabFlag"] + self.constQ = self.ffinfo["Forces"][self.name]["meta"]["ConstQFlag"] + self.pbc_flag = self.ffinfo["Forces"][self.name]["meta"]["PbcFlag"] + + self.pme_generator = CoulombGenerator(ffinfo, paramset) + + def getName(self) -> str: + """ + Returns the name of the force field. + + Returns: + -------- + str + The name of the force field. + """ + return self.name + + def overwrite(self, paramset:ParamSet) -> None: + + node_indices = [ i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "QeqAtom"] + chi = paramset[self.name]["chi"] + J = paramset[self.name]["J"] + eta = paramset[self.name]["eta"] + for nnode, key in enumerate(self.keys): + self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"] = {} + self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"][f"{self.key_type}"] = key + chi0 = chi[nnode] + J0 = J[nnode] + eta0 = eta[nnode] + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["chi"] = str(chi0) + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["J"] = str(J0) + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["eta"] = str(eta0) + + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, charges, const_list, const_vals, map_atomtype): + + n_atoms = topdata._numAtoms + n_residues = topdata._numResidues + + q = jnp.array(charges) + lagmt = np.ones(n_residues) + b_value = jnp.concatenate((q,lagmt)) + qeq_force = ADMPQeqForce(q, lagmt,self.damp_mod, self.neutral_flag, + self.slab_flag, self.constQ, self.pbc_flag) + self.qeq_force = qeq_force + qeq_energy = qeq_force.generate_get_energy() + + self.pme_potential = self.pme_generator.createPotential(topdata, app.PME, nonbondedCutoff ) + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet) -> jnp.ndarray: + + n_atoms = len(positions) + # map_atomtype = np.zeros(n_atoms) + eta = np.array(params[self.name]["eta"])[map_atomtype] + chi = np.array(params[self.name]["chi"])[map_atomtype] + J = np.array(params[self.name]["J"])[map_atomtype] + self.eta = jnp.array(eta) + self.chi = jnp.array(chi) + self.J = jnp.array(J) + # coulenergy = self.pme_generator.coulenergy + # pme_energy = pme_potential(positions, box, pairs, params) + damp_mod = self.damp_mod + neutral_flag = self.neutral_flag + constQ = self.constQ + pbc_flag = self.pbc_flag + + qeq_energy0 = qeq_energy(positions, box, pairs, q, lagmt, + eta, chi, J,const_list, + const_vals, self.pme_generator) + # return pme_energy + qeq_energy0 + return qeq_energy0 + + self._jaxPotential = potential_fn + return potential_fn + +_DMFFGenerators["ADMPQeqForce"] = ADMPQeqGenerator diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index 620d0a738..e6456d94f 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -1054,6 +1054,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for PME mscales_coul = mscales_coul.at[2].set(self.coulomb14scale) + self.mscales_coul = mscales_coul # for qeq calculation # set PBC if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: @@ -1075,7 +1076,8 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, if nonbondedMethod is app.PME: cell = topdata.getPeriodicBoxVectors() box = jnp.array(cell) - self.ethresh = kwargs.get("ethresh", 1e-6) + # self.ethresh = kwargs.get("ethresh", 1e-6) + self.ethresh = kwargs.get("ethresh", 5e-4) #for qeq calculation self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, @@ -1120,7 +1122,8 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, topology_matrix=top_mat if self._use_bcc else None) coulenergy = coulforce.generate_get_energy() - + self.coulforce = coulforce #for qeq calculation + self.coulenergy = coulenergy #for qeq calculation def potential_fn(positions, box, pairs, params): # check whether args passed into potential_fn are jnp.array and differentiable diff --git a/tests/data/qeq.pdb b/tests/data/qeq.pdb new file mode 100644 index 000000000..c835be211 --- /dev/null +++ b/tests/data/qeq.pdb @@ -0,0 +1,149 @@ +HEADER electrode systems +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +REMARK GENERATED BY Haichao +CRYST1 22.116 17.184 200.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C001 CNC X 11 2.460 14.913 10.100 1.00 0.00 C +ATOM 2 C002 CNC X 11 4.920 14.913 10.100 1.00 0.00 C +ATOM 3 C003 CNC X 11 7.380 14.913 10.100 1.00 0.00 C +ATOM 4 C004 CNC X 11 9.840 14.913 10.100 1.00 0.00 C +ATOM 5 C005 CNC X 11 12.300 14.913 10.100 1.00 0.00 C +ATOM 6 C006 CNC X 11 14.760 14.913 10.100 1.00 0.00 C +ATOM 7 C007 CNC X 11 17.220 14.913 10.100 1.00 0.00 C +ATOM 8 C008 CNC X 11 0.000 14.913 10.100 1.00 0.00 C +ATOM 9 C009 CNC X 11 19.680 14.913 10.100 1.00 0.00 C +ATOM 10 C010 CNC X 11 11.070 0.000 10.100 1.00 0.00 C +ATOM 11 C011 CNC X 11 12.300 0.710 10.100 1.00 0.00 C +ATOM 12 C012 CNC X 11 13.530 0.000 10.100 1.00 0.00 C +ATOM 13 C013 CNC X 11 14.760 0.710 10.100 1.00 0.00 C +ATOM 14 C014 CNC X 11 15.990 0.000 10.100 1.00 0.00 C +ATOM 15 C015 CNC X 11 17.220 0.710 10.100 1.00 0.00 C +ATOM 16 C016 CNC X 11 18.450 0.000 10.100 1.00 0.00 C +ATOM 17 C017 CNC X 11 0.000 0.710 10.100 1.00 0.00 C +ATOM 18 C018 CNC X 11 19.680 0.710 10.100 1.00 0.00 C +ATOM 19 C019 CNC X 11 1.230 0.000 10.100 1.00 0.00 C +ATOM 20 C020 CNC X 11 20.910 0.000 10.100 1.00 0.00 C +ATOM 21 C021 CNC X 11 2.460 0.710 10.100 1.00 0.00 C +ATOM 22 C022 CNC X 11 3.690 0.000 10.100 1.00 0.00 C +ATOM 23 C023 CNC X 11 4.920 0.710 10.100 1.00 0.00 C +ATOM 24 C024 CNC X 11 6.150 0.000 10.100 1.00 0.00 C +ATOM 25 C025 CNC X 11 7.380 0.710 10.100 1.00 0.00 C +ATOM 26 C026 CNC X 11 8.610 0.000 10.100 1.00 0.00 C +ATOM 27 C027 CNC X 11 9.840 0.710 10.100 1.00 0.00 C +ATOM 28 C028 CNC X 11 9.840 2.130 10.100 1.00 0.00 C +ATOM 29 C029 CNC X 11 11.070 2.841 10.100 1.00 0.00 C +ATOM 30 C030 CNC X 11 12.300 2.130 10.100 1.00 0.00 C +ATOM 31 C031 CNC X 11 13.530 2.841 10.100 1.00 0.00 C +ATOM 32 C032 CNC X 11 14.760 2.130 10.100 1.00 0.00 C +ATOM 33 C033 CNC X 11 15.990 2.841 10.100 1.00 0.00 C +ATOM 34 C034 CNC X 11 17.220 2.130 10.100 1.00 0.00 C +ATOM 35 C035 CNC X 11 18.450 2.841 10.100 1.00 0.00 C +ATOM 36 C036 CNC X 11 0.000 2.130 10.100 1.00 0.00 C +ATOM 37 C037 CNC X 11 19.680 2.130 10.100 1.00 0.00 C +ATOM 38 C038 CNC X 11 1.230 2.841 10.100 1.00 0.00 C +ATOM 39 C039 CNC X 11 20.910 2.841 10.100 1.00 0.00 C +ATOM 40 C040 CNC X 11 2.460 2.130 10.100 1.00 0.00 C +ATOM 41 C041 CNC X 11 3.690 2.841 10.100 1.00 0.00 C +ATOM 42 C042 CNC X 11 4.920 2.130 10.100 1.00 0.00 C +ATOM 43 C043 CNC X 11 6.150 2.841 10.100 1.00 0.00 C +ATOM 44 C044 CNC X 11 7.380 2.130 10.100 1.00 0.00 C +ATOM 45 C045 CNC X 11 8.610 2.841 10.100 1.00 0.00 C +ATOM 46 C046 CNC X 11 8.610 4.261 10.100 1.00 0.00 C +ATOM 47 C047 CNC X 11 9.840 4.971 10.100 1.00 0.00 C +ATOM 48 C048 CNC X 11 11.070 4.261 10.100 1.00 0.00 C +ATOM 49 C049 CNC X 11 12.300 4.971 10.100 1.00 0.00 C +ATOM 50 C050 CNC X 11 13.530 4.261 10.100 1.00 0.00 C +ATOM 51 C051 CNC X 11 14.760 4.971 10.100 1.00 0.00 C +ATOM 52 C052 CNC X 11 15.990 4.261 10.100 1.00 0.00 C +ATOM 53 C053 CNC X 11 17.220 4.971 10.100 1.00 0.00 C +ATOM 54 C054 CNC X 11 18.450 4.261 10.100 1.00 0.00 C +ATOM 55 C055 CNC X 11 0.000 4.971 10.100 1.00 0.00 C +ATOM 56 C056 CNC X 11 19.680 4.971 10.100 1.00 0.00 C +ATOM 57 C057 CNC X 11 1.230 4.261 10.100 1.00 0.00 C +ATOM 58 C058 CNC X 11 20.910 4.261 10.100 1.00 0.00 C +ATOM 59 C059 CNC X 11 2.460 4.971 10.100 1.00 0.00 C +ATOM 60 C060 CNC X 11 3.690 4.261 10.100 1.00 0.00 C +ATOM 61 C061 CNC X 11 4.920 4.971 10.100 1.00 0.00 C +ATOM 62 C062 CNC X 11 6.150 4.261 10.100 1.00 0.00 C +ATOM 63 C063 CNC X 11 7.380 4.971 10.100 1.00 0.00 C +ATOM 64 C064 CNC X 11 7.380 6.391 10.100 1.00 0.00 C +ATOM 65 C065 CNC X 11 8.610 7.101 10.100 1.00 0.00 C +ATOM 66 C066 CNC X 11 9.840 6.391 10.100 1.00 0.00 C +ATOM 67 C067 CNC X 11 12.300 6.391 10.100 1.00 0.00 C +ATOM 68 C068 CNC X 11 13.530 7.101 10.100 1.00 0.00 C +ATOM 69 C069 CNC X 11 14.760 6.391 10.100 1.00 0.00 C +ATOM 70 C070 CNC X 11 15.990 7.101 10.100 1.00 0.00 C +ATOM 71 C071 CNC X 11 17.220 6.391 10.100 1.00 0.00 C +ATOM 72 C072 CNC X 11 18.450 7.101 10.100 1.00 0.00 C +ATOM 73 C073 CNC X 11 0.000 6.391 10.100 1.00 0.00 C +ATOM 74 C074 CNC X 11 19.680 6.391 10.100 1.00 0.00 C +ATOM 75 C075 CNC X 11 1.230 7.101 10.100 1.00 0.00 C +ATOM 76 C076 CNC X 11 20.910 7.101 10.100 1.00 0.00 C +ATOM 77 C077 CNC X 11 2.460 6.391 10.100 1.00 0.00 C +ATOM 78 C078 CNC X 11 3.690 7.101 10.100 1.00 0.00 C +ATOM 79 C079 CNC X 11 4.920 6.391 10.100 1.00 0.00 C +ATOM 80 C080 CNC X 11 6.150 7.101 10.100 1.00 0.00 C +ATOM 81 C081 CNC X 11 6.150 8.522 10.100 1.00 0.00 C +ATOM 82 C082 CNC X 11 7.380 9.232 10.100 1.00 0.00 C +ATOM 83 C083 CNC X 11 8.610 8.522 10.100 1.00 0.00 C +ATOM 84 C084 CNC X 11 9.840 9.232 10.100 1.00 0.00 C +ATOM 85 C085 CNC X 11 11.070 8.522 10.100 1.00 0.00 C +ATOM 86 C086 CNC X 11 12.300 9.232 10.100 1.00 0.00 C +ATOM 87 C087 CNC X 11 13.530 8.522 10.100 1.00 0.00 C +ATOM 88 C088 CNC X 11 14.760 9.232 10.100 1.00 0.00 C +ATOM 89 C089 CNC X 11 15.990 8.522 10.100 1.00 0.00 C +ATOM 90 C090 CNC X 11 17.220 9.232 10.100 1.00 0.00 C +ATOM 91 C091 CNC X 11 18.450 8.522 10.100 1.00 0.00 C +ATOM 92 C092 CNC X 11 0.000 9.232 10.100 1.00 0.00 C +ATOM 93 C093 CNC X 11 19.680 9.232 10.100 1.00 0.00 C +ATOM 94 C094 CNC X 11 1.230 8.522 10.100 1.00 0.00 C +ATOM 95 C095 CNC X 11 20.910 8.522 10.100 1.00 0.00 C +ATOM 96 C096 CNC X 11 2.460 9.232 10.100 1.00 0.00 C +ATOM 97 C097 CNC X 11 3.690 8.522 10.100 1.00 0.00 C +ATOM 98 C098 CNC X 11 4.920 9.232 10.100 1.00 0.00 C +ATOM 99 C099 CNC X 11 4.920 10.652 10.100 1.00 0.00 C +ATOM 100 C100 CNC X 11 6.150 11.362 10.100 1.00 0.00 C +ATOM 101 C101 CNC X 11 7.380 10.652 10.100 1.00 0.00 C +ATOM 102 C102 CNC X 11 8.610 11.362 10.100 1.00 0.00 C +ATOM 103 C103 CNC X 11 9.840 10.652 10.100 1.00 0.00 C +ATOM 104 C104 CNC X 11 11.070 11.362 10.100 1.00 0.00 C +ATOM 105 C105 CNC X 11 12.300 10.652 10.100 1.00 0.00 C +ATOM 106 C106 CNC X 11 13.530 11.362 10.100 1.00 0.00 C +ATOM 107 C107 CNC X 11 14.760 10.652 10.100 1.00 0.00 C +ATOM 108 C108 CNC X 11 15.990 11.362 10.100 1.00 0.00 C +ATOM 109 C109 CNC X 11 17.220 10.652 10.100 1.00 0.00 C +ATOM 110 C110 CNC X 11 18.450 11.362 10.100 1.00 0.00 C +ATOM 111 C111 CNC X 11 0.000 10.652 10.100 1.00 0.00 C +ATOM 112 C112 CNC X 11 19.680 10.652 10.100 1.00 0.00 C +ATOM 113 C113 CNC X 11 1.230 11.362 10.100 1.00 0.00 C +ATOM 114 C114 CNC X 11 20.910 11.362 10.100 1.00 0.00 C +ATOM 115 C115 CNC X 11 2.460 10.652 10.100 1.00 0.00 C +ATOM 116 C116 CNC X 11 3.690 11.362 10.100 1.00 0.00 C +ATOM 117 C117 CNC X 11 3.690 12.783 10.100 1.00 0.00 C +ATOM 118 C118 CNC X 11 4.920 13.493 10.100 1.00 0.00 C +ATOM 119 C119 CNC X 11 6.150 12.783 10.100 1.00 0.00 C +ATOM 120 C120 CNC X 11 7.380 13.493 10.100 1.00 0.00 C +ATOM 121 C121 CNC X 11 8.610 12.783 10.100 1.00 0.00 C +ATOM 122 C122 CNC X 11 9.840 13.493 10.100 1.00 0.00 C +ATOM 123 C123 CNC X 11 11.070 12.783 10.100 1.00 0.00 C +ATOM 124 C124 CNC X 11 12.300 13.493 10.100 1.00 0.00 C +ATOM 125 C125 CNC X 11 13.530 12.783 10.100 1.00 0.00 C +ATOM 126 C126 CNC X 11 14.760 13.493 10.100 1.00 0.00 C +ATOM 127 C127 CNC X 11 15.990 12.783 10.100 1.00 0.00 C +ATOM 128 C128 CNC X 11 17.220 13.493 10.100 1.00 0.00 C +ATOM 129 C129 CNC X 11 18.450 12.783 10.100 1.00 0.00 C +ATOM 130 C130 CNC X 11 0.000 13.493 10.100 1.00 0.00 C +ATOM 131 C131 CNC X 11 19.680 13.493 10.100 1.00 0.00 C +ATOM 132 C132 CNC X 11 1.230 12.783 10.100 1.00 0.00 C +ATOM 133 C133 CNC X 11 20.910 12.783 10.100 1.00 0.00 C +ATOM 134 C134 CNC X 11 2.460 13.493 10.100 1.00 0.00 C +ATOM 135 C135 CNC X 11 1.230 15.623 10.100 1.00 0.00 C +ATOM 136 C136 CNC X 11 3.690 15.623 10.100 1.00 0.00 C +ATOM 137 C137 CNC X 11 20.910 15.623 10.100 1.00 0.00 C +ATOM 138 C138 CNC X 11 6.150 15.623 10.100 1.00 0.00 C +ATOM 139 C139 CNC X 11 8.610 15.623 10.100 1.00 0.00 C +ATOM 140 C140 CNC X 11 11.070 15.623 10.100 1.00 0.00 C +ATOM 141 C141 CNC X 11 13.530 15.623 10.100 1.00 0.00 C +ATOM 142 C142 CNC X 11 15.990 15.623 10.100 1.00 0.00 C +ATOM 143 C143 CNC X 11 18.450 15.623 10.100 1.00 0.00 C +ATOM 144 N001 CNC X 11 11.070 7.101 10.100 1.00 0.00 N +END diff --git a/tests/data/qeq.xml b/tests/data/qeq.xml new file mode 100644 index 000000000..664609905 --- /dev/null +++ b/tests/data/qeq.xml @@ -0,0 +1,197 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 800480ab745b31d4a926a17ece8b081dea77b06f Mon Sep 17 00:00:00 2001 From: Kuang Yu Date: Sun, 22 Oct 2023 16:40:03 +0800 Subject: [PATCH 2/3] Add frontend for sGNN (#125) * Add sGNN generator fixed a few problems in ADMPPmeGenerator * remove debugging codes --- .github/workflows/ut.yml | 3 +- dmff/api/graph.py | 2 +- dmff/generators/__init__.py | 3 +- dmff/generators/admp.py | 34 +++++--- dmff/generators/ml.py | 74 +++++++++++++++++ dmff/sgnn/gnn.py | 76 +++++++++--------- dmff/sgnn/graph.py | 14 ++++ examples/classical/test_xml.py | 3 +- examples/sgnn/model1.pickle | Bin 17100 -> 26 bytes examples/sgnn/peg.xml | 48 +++++++++++ examples/sgnn/peg4.pdb | 64 +-------------- examples/sgnn/ref_out | 42 ++-------- examples/sgnn/residues.xml | 37 +++++++++ examples/sgnn/run.py | 67 ++++++++------- examples/sgnn/{ => test_backend}/model.pickle | Bin examples/sgnn/test_backend/model1.pickle | Bin 0 -> 17100 bytes examples/sgnn/{ => test_backend}/model1.pth | Bin .../sgnn/{ => test_backend}/mse_testing.xvg | 0 examples/sgnn/test_backend/peg4.pdb | 63 +++++++++++++++ .../sgnn/{ => test_backend}/pth2pickle.py | 0 examples/sgnn/test_backend/ref_out | 37 +++++++++ examples/sgnn/test_backend/run.py | 45 +++++++++++ .../sgnn/{ => test_backend}/set_test.pickle | Bin .../{ => test_backend}/set_test_lowT.pickle | Bin examples/sgnn/{ => test_backend}/test.py | 0 .../sgnn/{ => test_backend}/test_data.xvg | 0 examples/sgnn/{ => test_backend}/train.py | 0 examples/water_fullpol/monopole_nonpol/run.py | 11 +-- .../water_fullpol/monopole_polarizable/run.py | 11 +-- .../water_fullpol/quadrupole_nonpol/run.py | 10 +-- examples/water_fullpol/run.py | 10 +-- tests/data/admp_mono.xml | 34 ++++++++ tests/data/admp_nonpol.xml | 40 +++++++++ tests/data/peg4.pdb | 64 +++++++++++++++ tests/data/peg_sgnn.xml | 48 +++++++++++ tests/data/sgnn_model.pickle | Bin 0 -> 17100 bytes tests/test_admp/test_compute.py | 54 ++++++++++++- tests/test_sgnn/test_energy.py | 51 ++++++++++++ 38 files changed, 737 insertions(+), 208 deletions(-) create mode 100644 dmff/generators/ml.py mode change 100644 => 120000 examples/sgnn/model1.pickle create mode 100644 examples/sgnn/peg.xml mode change 100644 => 120000 examples/sgnn/peg4.pdb create mode 100644 examples/sgnn/residues.xml rename examples/sgnn/{ => test_backend}/model.pickle (100%) create mode 100644 examples/sgnn/test_backend/model1.pickle rename examples/sgnn/{ => test_backend}/model1.pth (100%) rename examples/sgnn/{ => test_backend}/mse_testing.xvg (100%) create mode 100644 examples/sgnn/test_backend/peg4.pdb rename examples/sgnn/{ => test_backend}/pth2pickle.py (100%) create mode 100644 examples/sgnn/test_backend/ref_out create mode 100755 examples/sgnn/test_backend/run.py rename examples/sgnn/{ => test_backend}/set_test.pickle (100%) rename examples/sgnn/{ => test_backend}/set_test_lowT.pickle (100%) rename examples/sgnn/{ => test_backend}/test.py (100%) rename examples/sgnn/{ => test_backend}/test_data.xvg (100%) rename examples/sgnn/{ => test_backend}/train.py (100%) create mode 100644 tests/data/admp_mono.xml create mode 100644 tests/data/admp_nonpol.xml create mode 100644 tests/data/peg4.pdb create mode 100644 tests/data/peg_sgnn.xml create mode 100644 tests/data/sgnn_model.pickle create mode 100644 tests/test_sgnn/test_energy.py diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index 701acf9bb..5b8de3ba1 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -33,4 +33,5 @@ jobs: pytest -vs tests/test_common/test_* pytest -vs tests/test_admp/test_* pytest -vs tests/test_utils.py - pytest -vs tests/test_mbar/test_* + pytest -vs tests/test_mbar/test_* + pytest -vs tests/test_sgnn/test_* diff --git a/dmff/api/graph.py b/dmff/api/graph.py index 6e1ccd6c0..64e9af10a 100644 --- a/dmff/api/graph.py +++ b/dmff/api/graph.py @@ -13,7 +13,7 @@ def matchTemplate(graph, template): if graph.number_of_nodes() != template.number_of_nodes(): - print("Node with different number of nodes.") + # print("Node with different number of nodes.") return False, {}, {} def match_func(n1, n2): diff --git a/dmff/generators/__init__.py b/dmff/generators/__init__.py index 7ddb93293..6f37cf7f0 100644 --- a/dmff/generators/__init__.py +++ b/dmff/generators/__init__.py @@ -1,2 +1,3 @@ from .classical import * -from .admp import * \ No newline at end of file +from .admp import * +from .ml import * diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index bf1cce2e0..cada68fe1 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -822,15 +822,28 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): kzs.append(kz) # record multipoles c0.append(float(attribs["c0"])) - dX.append(float(attribs["dX"])) - dY.append(float(attribs["dY"])) - dZ.append(float(attribs["dZ"])) - qXX.append(float(attribs["qXX"])) - qYY.append(float(attribs["qYY"])) - qZZ.append(float(attribs["qZZ"])) - qXY.append(float(attribs["qXY"])) - qXZ.append(float(attribs["qXZ"])) - qYZ.append(float(attribs["qYZ"])) + if self.lmax >= 1: + dX.append(float(attribs["dX"])) + dY.append(float(attribs["dY"])) + dZ.append(float(attribs["dZ"])) + else: + dX.append(0.0) + dY.append(0.0) + dZ.append(0.0) + if self.lmax >= 2: + qXX.append(float(attribs["qXX"])) + qYY.append(float(attribs["qYY"])) + qZZ.append(float(attribs["qZZ"])) + qXY.append(float(attribs["qXY"])) + qXZ.append(float(attribs["qXZ"])) + qYZ.append(float(attribs["qYZ"])) + else: + qXX.append(0.0) + qYY.append(0.0) + qZZ.append(0.0) + qXY.append(0.0) + qXZ.append(0.0) + qYZ.append(0.0) mask = 1.0 if "mask" in attribs and attribs["mask"].upper() == "TRUE": mask = 0.0 @@ -1146,6 +1159,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof pme_force = ADMPPmeForce(box, axis_types, axis_indices, rc, self.ethresh, self.lmax, self.lpol, lpme, self.step_pol) + self.pme_force = pme_force def potential_fn(positions, box, pairs, params): positions = positions * 10 @@ -1181,4 +1195,4 @@ def getMetaData(self): return self._meta -_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator \ No newline at end of file +_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py new file mode 100644 index 000000000..afec8cee8 --- /dev/null +++ b/dmff/generators/ml.py @@ -0,0 +1,74 @@ +from ..api.topology import DMFFTopology +from ..api.paramset import ParamSet +from ..api.hamiltonian import _DMFFGenerators +from ..utils import DMFFException, isinstance_jnp +from ..utils import jit_condition +import numpy as np +import jax +import jax.numpy as jnp +import openmm.app as app +import openmm.unit as unit +import pickle + +from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb +from ..sgnn.gnn import MolGNNForce, prm_transform_f2i + + +class SGNNGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): + + self.name = "SGNNForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.key_type = None + + self.file = self.ffinfo["Forces"][self.name]["meta"]["file"] + self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"]) + self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"] + + # load ML potential parameters + with open(self.file, 'rb') as ifile: + params = pickle.load(ifile) + + # convert to jnp array + for k in params: + params[k] = jnp.array(params[k]) + # set mask to all true + paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape)) + + # mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params) + # paramset.addParameter(params, "params", field=self.name, mask=mask) + + + def getName(self) -> str: + return self.name + + def overwrite(self, paramset): + # do not use xml to handle ML potentials + # for ML potentials, xml only documents param file path + # so for ML potentials, overwrite function overwrites the file directly + with open(self.file, 'wb') as ofile: + pickle.dump(paramset[self.name], ofile) + return + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): + self.G = from_pdb(self.pdb) + n_atoms = topdata.getNumAtoms() + self.model = MolGNNForce(self.G, nn=self.nn) + n_layers = self.model.n_layers + def potential_fn(positions, box, pairs, params): + # convert unit to angstrom + positions = positions * 10 + box = box * 10 + prms = prm_transform_f2i(params[self.name], n_layers) + return self.model.get_energy(positions, box, prms) + + self._jaxPotential = potential_fn + return potential_fn + + def getJaxPotential(self): + return self._jaxPotential + + +_DMFFGenerators["SGNNForce"] = SGNNGenerator + diff --git a/dmff/sgnn/gnn.py b/dmff/sgnn/gnn.py index 49a5ee985..403d36429 100755 --- a/dmff/sgnn/gnn.py +++ b/dmff/sgnn/gnn.py @@ -13,6 +13,39 @@ from jax import value_and_grad, vmap +def prm_transform_f2i(params, n_layers): + p = {} + for k in params: + p[k] = jnp.array(params[k]) + for i_nn in [0, 1]: + nn_name = 'fc%d' % i_nn + p['%s.weight' % nn_name] = [] + p['%s.bias' % nn_name] = [] + for i_layer in range(n_layers[i_nn]): + k_w = '%s.%d.weight' % (nn_name, i_layer) + k_b = '%s.%d.bias' % (nn_name, i_layer) + p['%s.weight' % nn_name].append(p.pop(k_w, None)) + p['%s.bias' % nn_name].append(p.pop(k_b, None)) + return p + + +def prm_transform_i2f(params, n_layers): + # transform format + p = {} + p['w'] = params['w'] + p['fc_final.weight'] = params['fc_final.weight'] + p['fc_final.bias'] = params['fc_final.bias'] + for i_nn in range(2): + nn_name = 'fc%d' % i_nn + for i_layer in range(n_layers[i_nn]): + p[nn_name + '.%d.weight' % + i_layer] = params[nn_name + '.weight'][i_layer] + p[nn_name + + '.%d.bias' % i_layer] = params[nn_name + + '.bias'][i_layer] + return p + + class MolGNNForce: def __init__(self, @@ -146,6 +179,7 @@ def message_pass(f_in, nb_connect, w, nn): return + def load_params(self, ifn): """ Load the network parameters from saved file @@ -160,32 +194,12 @@ def load_params(self, ifn): for k in params.keys(): params[k] = jnp.array(params[k]) # transform format - keys = list(params.keys()) - for i_nn in [0, 1]: - nn_name = 'fc%d' % i_nn - keys_weight = [] - keys_bias = [] - for k in keys: - if re.search(nn_name + '.[0-9]+.weight', k) is not None: - keys_weight.append(k) - elif re.search(nn_name + '.[0-9]+.bias', k) is not None: - keys_bias.append(k) - if len(keys_weight) != self.n_layers[i_nn] or len( - keys_bias) != self.n_layers[i_nn]: - sys.exit( - 'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!' - ) - params['%s.weight' % nn_name] = [] - params['%s.bias' % nn_name] = [] - for i_layer in range(self.n_layers[i_nn]): - k_w = '%s.%d.weight' % (nn_name, i_layer) - k_b = '%s.%d.bias' % (nn_name, i_layer) - params['%s.weight' % nn_name].append(params.pop(k_w, None)) - params['%s.bias' % nn_name].append(params.pop(k_b, None)) - # params[nn_name] - self.params = params + self.params = prm_transform_f2i(params, self.n_layers) return + + + def save_params(self, ofn): """ Save the network parameters to a pickle file @@ -196,18 +210,8 @@ def save_params(self, ofn): """ # transform format - params = {} - params['w'] = self.params['w'] - params['fc_final.weight'] = self.params['fc_final.weight'] - params['fc_final.bias'] = self.params['fc_final.bias'] - for i_nn in range(2): - nn_name = 'fc%d' % i_nn - for i_layer in range(self.n_layers[i_nn]): - params[nn_name + '.%d.weight' % - i_layer] = self.params[nn_name + '.weight'][i_layer] - params[nn_name + - '.%d.bias' % i_layer] = self.params[nn_name + - '.bias'][i_layer] + params = prm_transform_i2f(self.params, self.n_layers) with open(ofn, 'wb') as ofile: pickle.dump(params, ofile) return + diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py index 93f6a809a..164a41f3f 100755 --- a/dmff/sgnn/graph.py +++ b/dmff/sgnn/graph.py @@ -1219,6 +1219,20 @@ def from_pdb(pdb): return TopGraph(list_atom_elems, bonds, positions=positions, box=box) +# def from_dmff_top(topdata): +# ''' +# Build the sGNN TopGraph object from a DMFFTopology object + +# Parameters +# ---------- +# topdata: DMFFTopology data +# ''' +# list_atom_elems = np.array([a.element for a in topdata.atoms()]) +# bonds = np.array([np.sort([b.atom1.index, b.atom2.index]) for b in topdata.bonds()]) +# n_atoms = len(list_atom_elems) +# return TopGraph(list_atom_elems, bonds, positions=jnp.zeros((n_atoms, 3)), box=jnp.eye(3)*10) + + def validation(): G = from_pdb('peg4.pdb') nn = 1 diff --git a/examples/classical/test_xml.py b/examples/classical/test_xml.py index e84c6d849..c5d594033 100755 --- a/examples/classical/test_xml.py +++ b/examples/classical/test_xml.py @@ -82,4 +82,5 @@ def getEnergyDecomposition(context, forcegroups): print("Nonbonded:", nbE(positions, box, pairs, params)) etotal = pot.getPotentialFunc() - print("Total:", etotal(positions, box, pairs, params)) \ No newline at end of file + print("Total:", etotal(positions, box, pairs, params)) + diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle deleted file mode 100644 index 0c3959cd9d0ef4fac155676861aa6743acb96c99..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17100 zcmYhjc|28J^gmAKAyY_*1~L^=63$*Xr4mVLpj0v>4d|LmrOb2Y$XF7RREWaa>!!>p zB}p1I7fmWEY4~}b&*%Ame)o_2y03HhzH6O(&RXyNUhBP2h=7}$&z?Qo-TZg@c>9Ul z`MPiS-R^F=)6HL;%co<{<=1xP=i}qs$DQEj9pJS$NZ-xJce}n!ym~N``>}?{y}@U zi*v;tCyZP1n9r54;j`h7=1SUgCu-XW{A-)xO08P8%KGp4>)$sUK7X#XYwQw11rJ#c zSH_*^@^&ulkCeM9|y+js8p^ykWX3V8l&dy?n4VR?J5!as(!ZRSUD6+I`$E)JQ- z$G1T+fcBgerYQfCJnLyCU5n1L30H2haA*$%7`q}YKMGyJIdm3Jif)~eM88DUuz!uc zNNAHD+&Oj~KmEz4?{+8A_PLsLs9c+veQ%*5X%DISqnE5w(ZYq^-y~7^*j0K*{~jG? zo0*RZ6-;YuG?Mo3Bsg^!h6>seBNG)o+1p0szEs1!-xHuICy&g3U4{qL>*3;vItT?J z++%M;@YfPJ%(@T_&OuNe=EM21Ps#JaKr1#@`jJY_vKV!_Lxq4MtdG?`%+ za8xiyTN-Hc7!zA9hPRsjl2rp@OquL=HvHWcni(O?s%SVeJ!8|Tdy+V;@zlZZItF#` z%);yGiaS(6Hjg^4&&Fz%Od?jim7Q>^1;4KifrysX5WY~0RQ(pit&#W0lmgrOgJDv< zx&`%6#crm}t;cD_+kE9m zyJaiMx=bnTIHUm~7pMV&TLd7zWFD1n4<*z7PJ_7zwouFJ1hY32#o>|pN%G5Q5k| z)2AF`t9s<9UuG0de00TJbnYBDGh3Z7k-hZs@j1l9PY80l=aC@6e8RtBElO;gi%lnz zuxr&8(x@wmTG!W-w{ca3(Vj^A&W;iO{7hW^CyG8@zlRtsmu6ZXpJ19w3Sq~myQq2O z8Qu-6!4wH&tedwSS#?EJ4^#sqrVF>X#BesXDG>gBmUul!hey?BQ(kU8=X2jl8l5!2 zR+g_uA%_esv$n-6-LpV@K$;LmX`I(oMLexz@J50j{T`G_RWpPjZPPN=PBH{}mUd`x zTL9E;&B&l!CE6qz(y4t>#5H~u8!A{#{(fiC>X{DjOieSLANP#5&5Q-V`IjjFfz9yf z$piAlV?Fv@ngg>BJ~CfE$qM|9U((AWPpQZUF^;c(F3c&Nj6EfH$?WTva6mB;YG)nc zOmqo=zn@m4b^3W;@Qx?&!E7}wkbVcf`U^QK)@oQSeTL|zFM`)rQ*j>-lY|K@b%AU$ z@2LdL4eTNZ?uo#$KkMqxDTdeYds9rVE$C){Mx0`HSxCZijVX}*_(}b#vHOJgS()-b zyGe~729a$u)lltlmARK%2T^cd1&7u@VTE?sQnf7;cn0}47zdj8;8isp-W>*;cLk%1 zya%<2n^;${Bn01D%>kur5qL6rE$(Z%3x$b^)ZZ!s?^>sz+?U@hi{nr_HpG*=u=RVOp zdJsa6>}RH5+=>P(LuvnX5q!~TOTWj6a~jHO;K$-A7%a4xhanx<_1POlqFYfRWF}8P zZ34&j`+KswYX$5WJPaN$1Yl)w5BYRIf|;qc0CIgJ$m#cH^zKIhdu=b8-RI8SJMoC* z*iI$F&zI8TE8;Nq_ZDd0B1x|FYMCn)CeTygmy_yJKDZ*R!g@T)lDN#=yPgj+svk*}fH3U2z>uE9@#Kc)1{B|2i`(J{>8(fah*rWf zRKBeQ=24$;rCI@WytE?S7r&DC1N+hT9FQf&44q}zLhHKXNxvz}rXI{9DK)!j6tO~q zM~~>-kKa*njE{aToQQ8q_^HA-P4aXhAFr{$j+QL-Ak7-BWYDr0SAUsI_Pn`8j8<3! z&p8A?2lX%_Q7xbvkVK4<=itfH3-E&4T-+5Q#0k6l2RD^jgT479W|0;@_!u@(6?;{h z))x#B!7`la+FuaI`N&Q$D8-wN(>P(9*WswLOZ}T2a&%}}EciU%4oki}l8M}9c>Vh- z5*D}shnx4{!KsrWf8R9TnS|G@*h)>BN~CCT!b)6YTFc&=mO%Apm{75vt)za|HD>v! zC}?n0!D-bVb~I85RWE0O^!aL9q4=J(ttM%T$FI9 zHm8qLts7p{MA8!rO+)G5w*%%sQv5K;;0)fGbc>l_kOG%qhfxw10}8&GwDG|ic1Qb7 zeDi%f7OXYq%ovHp)%ANxcA_fE!a5k7kxaHu^r8(7;;i`8-B?^_$~e*p;%wK;4i@gg z`DYiBA$H1QP6~oRJ{SN#DxZSv zXcF$4vJ-AwhB1v|h4t#^X3$ubNZNJf0<5AVDAoCk=pQm68~!K)_FRF|*YddBRRSNk zU5BHxe@RM}HZ9v0Mt6CBqlOwNsGFUc;z3$gY3epTdv6ky_NhJ)` zPgTMpy8=A=YZgZ%kVnrR%p>BwLVDO`5e77bfL(_=3H@r0i>vy`^MKzp^mQGT%U7dM zmidrhq87Atwh+YECe@em?!v3nlg$IypP-YRi@|WoZ16nGkLHU?i00}MD2$edqwXp= zN#!B2*)C1i8koX|3VRf7n@DE3d}en~(4}TpgCsbmoVNEHf%3H&-hMk95dOH1$bOzc zWA+Q-hU`2#!`l}`H$I{^D}G|_r@y#&!w)K%K0qINzopr|S*&bg7;|qCKO_}duyUd+ z=}^y2GV`4pY@CJc%JUOBULRk4M`W9pZTD5}PtCNA7TPi9}#3USF@k`7l~da#;of7u_UUsg2a?L=*k3 z^oTC78HVVo&5Y@rdsNPGPrYjSUe<7#4edrQhNc`Nn|IHK$Key<+6+yi-=F}(g{xq7 z*9J2BDS~Xzox+;6y}$r+9jz5>QE`$2uhvfwcnOYFv&|I)Ya%gy>Q0&)k`&&tCME!u6xaNaryRg6er7R3VNL!~N)4WX!W@sD;d^ z+tkDB9QX=4)<-v8$8Qduu$+ou<+J_ds#gIvPHScNeba>OUDEiva5`I|xgX_i^FaTI zFj{#>BgRVBe|;|iTBjsI>7fICK4mK~-M``c_ul%ki>ACfbSyL=_Nv|9()z3d?`tw!j>-|xwYc|AMFZXbJK+ZVF%w-&u5yM~4t zj%NVhEQZ&fKq}hG;e(OI}YD{$|CnGX5j6lzck|cVqCJ|9C@_65)$o3F^6vwt@2q;_C*@OnSrbD z+g$^0O#XpUg%Z45FAjj6lss?BuglC^&1iGMpU*j4f<@4ILN{~XK>{4EiD6`I8tiRY ziEm35(w8^H$?LPC)I+~@JRTEBe{wkqoxGFgTplE`wh7E%whs%;%s54E{Io>qDzJK| z>DYx8G;}N&cKW5z_wp}DO0_J_kL2LF4TiMW(TI2{*n_Qe03A-H*c6b6MmKEWo~j9H zaC$-iTo0mq7cgYUReid#(V9tZ$|UgzRiIatpNu35!-9R!!A#>MbEHolhYJnqK2#-@ zxBif$FEn`ZhtxU!;U1iq%BPs3?^^7g?k#lh=UZf|*)mkMe!$@61ZtP}kOT^*kU1}! zXvW$<#38kr=HK1O>g$vc!RHOsb+#UjRnKP<*9e33N)LwL$%$cW6d2!<2TXd4De26; zf}RG+bz;XKlW4;%qF38TgT1Qhr)B5a(16X1Vs;I&35=!u6P3t4F@ADbLXs3zXp*Zp z+F1q7WE9@Ufy$VDXq+g>ep_5kYnGJ2;O4(%)B6j&-R|m~m#tH=(fk^W-O=NnkC)<1 z%~9Yi1AltLY6CnzWX7?Zo(lP8UrGL?$!N}<1h&7-q4Uuwxv{x`b!_^>Y#F^qS28bI zWy!~eq1C{VL9~K zEnVhAXCC>;|C+pABtTa^6+zh_iSWbxCHt%9CoSgHk(NG1dN{0xr~l^^6nqt z46<-KG=qdxCc{12P9la*g6xxgG}5!CmKIUuqk;fVzbMEVK@Z-*Cm9GfQRZx#Yen@^ zyUFb7*NC@#D|Cpxqvxjk(`+MkI*)FJX08cd&bSQPWfLGIcsaaM*no1I`CxuhE*+>j z#6A%Tqz8R}na5-%(9r3RneM%(q5OUUQNJ=mS9n&!iaF^-&?}FLeq8_qazpfWa}qo; zJcb!=??{i@9pY~6g)T0uaGk+pV*EXZGC=KXEdAXJpz4R;3<2UD_Ob0_Uqi-)V> zUGOgH6F%42MT~n^AZNcX)%)H?ecj(t?T-Vf=Pzw0{$?-8KKu@8HEwWPScfzDls->x zavWX9-#~8t%piZ{1!<>vviZ|1?wDLBLN1Q#!3Fh?jQKokP#Os$zmEySj0XlVA~i~l zgC7$e;a0fsK9xy3UQc%Tl@oo>bVklr2n3?}sB_sLa>*ecK5fwC%b6H8~C^^H89FHU~zOIdKIJ;C1E##1uEe@pm2gJ3favq|ApK zW7|MIWDajykvG07noPba$?*JJLb2k`8zL(%MNXm*Jr?ST#b;EB(3k~O&ydF9+H>r- z)h1Lv>N&IeKtA2VS4+g#6_fiNc?_92lbn{?0M!s$(XHn_NM=?d9m?Pb&nrCh zg0;isbGIS=D?Ae&n;Xfo-W9NUX$iSmxS!5W5ykZndud|MWL!5Zf`-5eh~X~4z~OVy z$z%Y3^ffj^$_*D%IdaLOgi$r2$6fwqR{K{AYI{doE_ONN=+XY5rx1e^2&J;jYtb) z7bypmI|jAHcFBBZvfW;cQyienO6Sw2+)(@yzl9cXB=F6_B$Sh>W8@wvW5f?t*tm8v z?HX((l~Je3tOKRgN=k}`Ri%&~lV@;l^BpK^i-C1rxgg;f34Pydq5s@&T%3#-9!~R@4M6)bbfkW&}s;a3=UgXx1lDu}>IinQ! zpQwN#!60fSW`oye9wO1}x(V$sf-f^pQ+bJryv5#6p}2b+QBt4J{xZxUS7lR4ThUSW zOmjQY-g6cXNxmnEqRVKPeK80J?1tGl*Fdt-d|tpDK5VgUM(d9C`0L>%`Y}odeZO4C zgSq!fvaCPkJ;(u#Gy~3pqaB#F^axK{*$MXZf1|RSuaZgbF~rHsg#Pv>WW|PSJ}%l-LpfKr%eFk4Zlu#mmAH zdRgx*7?+e2$CKH_Wd1i4%U;49-5Sbx$7Vv~WLc2=I3HZVE$o z{pk0E##s!qM;`1Y{abZ$_hmu6+ZRUGHqXNC*9?fFi3Dbi4zgvv&S-V_JZ^EwqIHji zIGHwDP;r3aBpw_Fk$4df^JXVZjs1?E8=qpv<|Q1Pq3>kNnpn>3X>z>lMv1&Y<;kC9!P2_h-G0i?F1N$7L;cMSUxGKm2(W2QfqBX#r)fgd0RtrI6 zeJU+A?|qV!K~o{{q$IFgsfs%eRlmnWyfz>j{c-672LE|v$Q=?^gDPA}uG zA!>H=S1gH{lZrh}S#BwM^?UD`)W`C~ zd){>L6rGEH$Is%l>RlN2X%@6?ih);SZXiy-?~#64H6hB{5+*LX2TUKtC!42!VvudR+Kc z1SNyQ@T0a0>|6N?Hmtgg_iqod;XbKk_WcN0{;C>V@~7~0f)XIraUUl%CmK5a%5cq; zbQqo%%~tFQLC*wpl4KT6e+_3rtm1erWIZ3a#sy)Kfe{$K^P&cW2@qPNip6pQ*dRWU z7?y@$^rFRNefM#){p1aNFZ95Ct=~W_f{xkiyP6nmk6vO(Z z`pl88rRILVbHV<5IaN_NWw#ueNor60zy!zBRN&M_rpDzkR%8G&m&36!;WiBFtz#y% zB;%iLL9lMK9u)a6hpczvr2nG~Z*#c`&urr{n&M~8I~UXg=U)nO3Z58pL@#PnVXs=# zoq0!@nT6Hl?ze4FxM~J$IdhgozRMuT7B7WT!$6WR76$D~v#IL0qqH+f1WG?2CbPGR z!OX~$w8MZ4vf*1<$bCuLjds*WUdtg-eN$l3R2Xje&*UtP3&&Gy_kqWfJ{Z!R zOfKG$!<|Q3;Njj?@HpfLUGuCEKUcaFAGhhO-GOuD`1n1(R<8qhU-?PT&-+Zyk5!T9 z*UOmCXU{Vgf4>OXw&0ab_{K-BAuB zS^4N~9tTc$yvfKZey|nqqhF1G)xUDuk5?iTNqLepgnTKby!G>mwMiEd^(#RX?LybM zh0tND3AjIyv`L-DhL0>%m$zf9lpuDD4AkA3p$p2~cw(ZI#mqes0>Mw**x9rE>9vG7 z>TN8Ce?&ql-@|J9cK99HrX4_qLM9>qrbNg{tsvf3g~Yc$oWj||L~`H$c zsm2-0Qt1Todg?U9(5LeW6qYHH#!t zO7MJb4|!o{ie~i;=&Cs42PG+Lm7Gb#2c@ZwiZNaPJrW9b`LN9`TPPknPkYAyOfCt+{NpB|)R%_Z} zONI!poiI!%jYUJ$;RvWYnNEq_OrBzBBOJMZ7v-f4c;AwuA>PQ8qaY$j$2Wco%Ty*qi)$NPyEy@# z$!)GJ)6=9AJNiLLK8eNqXUtU&zMz|)>cgSE3cP(?spOQ&QM6n?LLX_2(YYn@7__ty z&RFHMnm??d`97DYr!4{jj3G~RTRew1I)$gzBLLs&R(eD~$z1LI43r*>B_FHK5eFF~ zYI&}lMksaBq}mo5)8vFc9}URomEyR*MiO7>vSjetdv-5GlgPt2>x;TB(k_j#`nG|I zczU-2-aos6)$y+<(?U`8$T5OEAa`#~M2NRsUzXg^T?ZXq5bs#6L2A^kN0tcE6 zcde6QZJQ_hZaNDGj15TdHCayHgG-S1jSK7gqCqf7lQSwmiC*?cvRG*%uP39Md{8yQ z@T%9auA?0Ngr}q5uE~s`N*^)FYQw7e7&2vAH zzweMSdp3p8=MqGBcK`$WN|?AZ1y0(Ua6D7QI0ea1VS>Q}xIOhQs5J5*CukmWEI+}@ z;>o-@;sWF%dxHd?j|Lr)3Y3t{fjP?c=*vxmv5Y+!Zm^W+^u`9#|6HY~>rH?=*`KTA z>L~vIvdQ>plB9Z#8)|c#(82d`|dhH8YJ(a&XEj z8KUKnG4ju+(@m?Y$*I4>F!(!={yn*#+%Yew^ZW(bPH#^Tc{0o-z1zYXnx>N4tUv7L zRnzg+)t#`!J%Eg5En}5ac9G*|zsVF%7L`BP%5c^Ex$6HX$e8<|AY%laanIBf)!NgP%<>o9FLVXQ{SDY7~&i0t&(fMy;(u=-Lo8tt&g zPgA4_+hM?4w)Ys8o+(1U_2UgbyJ)N&N&$~%Lx@}vkISFkBwe+!;PrYZbdMKpy~ONr zW9ndv%Wv7hlZaPLD5!K*fRG9Ot&f{jen%^kq3jhA3e#R zo#n9O&pXEcNj&oJiz0Kj_d=DC7tGQfgk)|sELr&pgAd4%CT%%((m^jU&pi%XeB4n} zJ(NiFSHhyDiqutHmCi^ng?T#mXr4)r zrgHF%_#dMEGo3DNs3nUZ7D2`2*QDsF4(@&{#o0}raWu{gddjk~X!HOi8jKLfVpUL7 z$%Gq!Ht?x} z_bvx;Funs9ngv02%_W+gv6DP1oXT67R1Wc}wz%rUEnw3|@Z9Zqc zhe6YozlgM@1uTqKf{({Uu)r%B+ru|O`0+}b@xBmtbu5PBPwmu5GM--C@R062yqO-! ztz)w1-J)6h1EA*ePB{1|fh1;e;DN+U+_~lw_LkYZ3yGc!Jj5wa+EAeuq9K9NO3Lh7w(@7VEN$i31 zBqn4pv2^mM^ByOIU|k^Ne0M!e7)+*>K56FO<3)o}{#zvS$xrsnvnx1c&XnFr;d}9QR-c-N?u1>{!tQj^+=U{U0Z+7RKBD__hOqC`NgVw?wL_a;2d~!QT zY!Z3&<&_G&5U7vst%M%+41vhfBF0WL4t9PPL9dvZyld*$sKM(sD6TRG?W3BozkHOu z*!T;s!)c3_AJ!-y!`6o*^ooWNU152F?g-yW<3tv-A8$IE zTfAGrTkBG8el*np2XZH%-P>poTXvEjeZGTgPRk|_6i?zFUKu&NBAf|-qe)I5eowxq zj1$Y^7-}LR3IZD^gXN4*L}NuMI{%u3lMbK2{YQR5NQV%HNL?TorWugyX@E3f z0kCxpHC$wPy`ArvmtJlFC-h;{Ze>ooo-k#< zbW#=B(B^};JgV5=Ru|0$f+q1yqyyn~Tpx_MkI`G4xtLih!CByY9{-Rn^}5wru%tN< z`m_pgqUC+^@$_b3bo-$^uA1)Zw?&r;<>XGU5Lqes5c?a_kV%z<#qzSiai~VKp7qrJ z_e!wlvq1MiY3Ojw2J0OQsf7tYiBy!w&*J%{W*YKiTD17+*mi*1jWoKb7t5T5<;T6Zp3*W~m^z=U4 zz3C=dRToRn>YOFZ@)Teo&xY{!h0%)3LfHB#4I^c3@ugn@>V0V-5iK!n$z}`qV0{_9 zC7MAbbpXtg+`*$}CGmZ2!8re%Ln6B`)$^tgQ31&gFxjj_g61$VaC!yIUzJN@#RurK zpZl2rqaWsu@1y9Ou-ft5qX_%0{*ZU-vb-a;C)jUc#yCrHK3?k&MgOWROvTlOxG{1L zd6~4FW0V~Mc{67*DO*_*qiD>QvZ?f5^C@B-(nki{%yD!_G}#@RhezyZkqAoX9foO=#xt!YpHJNLy|ur!U_@9E?FDN< zTxuhga2P;GP9@2qGl6qPA7A$Ow;c)A$jV)fv0m}1ROs&;o?9wTt0OdbY{3<<~0HKK(!9OUe!gS?|r!Ma*L-li>Ek&(Ox{B3bmv%CjdPAq~Xix=cz@7DSz z^)9%%tqz)vR>3yOs~{_E0cj^HsX_8WEMOng&akKTmx~KgG)x{2=ijCuH@|=pmvm;O zTsrdU&cbDTPQ(8A035C~!SzcQfzzRvbcx|^nr^L%5gq_l3xa_SIYzp^WZ|EiC#jda z0hoh2E;l{|J2uXN`>$8QzTr5u^azL7G0k|+@CQWkO*U;W*$;<*XyOrtCYqWt4IGNq zaCC&9#QfxgLo>v2Rb&#R?23T-SLT7^7c-PzvK%|U9DwJ`^I-J)E#k~K1w`LOqsx!a z5c1U&azZ4qckfP|sPTflUb`7g+qy}T7>_A7>roGG zGyL`&L9XUIrO$s8v!}xKgBNe(g6VR2Iqe=6z#3%gy>O3Erp9*{AFeI@VtJ zbMv?`W`ze?7}ZK&eC5%in}JNO-$W#9gg99qi^;>+wM=KfI7!M%g2SZ|?9}ypz?T!Urkd(*7vlYnI^ykj{KSDNF|6f9;^&fN5{)f4YYvOhPWiF9iy?+SV>Lqp)AVx}! z<}~Eu%XW9b=F=qS@Ju-9_!Kk_G?3rYjv$ghk6*A<6h)1{B(m5v@w zxbmDSmdmF(Lmq6w3rpnA_UF#|KLTa+AA!>U&kcir+%UA~&ix+(B_G8#`iDR*66WK( z*Pw<2jn9eO$;+s8_BhG7tO(Ocs&yn)eh9KZ(ik{;^s_QhHi1dGf z_f;LN!B;H1;M`aLO!?%)ana8f_>n!C=lM_s3=+4J%?Zy~ zNBOJ&xagR=ohmL}M&-qx(ks5pvHSRF{jAC) zv=B%mhSodLZpl3Kah?YYgY-b7AR0AhM56P^PMRPd4zDl|zSKXii~aK$zaKWifJzB? z_v$M&rb}_&4Oaj@AAxm~_;@?^yK!zj-iW!ey{LEh#<;pvfwS}eRd9aG!%aHTu={N; znhpNKw^>Iqz+H+asknj7yL)7Ga2>vmkibaab`p)pm<G~10v)_YOm&aoG zWI1vrR%|?o)M@Uqm_2}x|trH zs0)47Rv2?O4t`1%f@#okVlZUP**i}f0$vXhzr{{W-*_U@E^9G0IC_%~6v}hbT4ONY zV;D0ZCDZ=+Q$(aeh<8Unj5dGQN4a@XbiTkRGLn#jN8DY`|Yp|i)8T}!W%zoXHpYF~@BWfyVE zXBo5_(t=Z;k04!8ikA;*vQAUO%?(H`z63ce7W+v=^OHeZ{1V1p;1;*3X{PN!Ohff(%VH@bE67$K5Vc+h<~Ecgh*RgoeS}8)u-|@-JB(^92?; z^FgC-D6>bm8Pc4aNL(Es=V7`g%}g4FL(|Wqtn@tMuze-OMQYOfAKGcwTLaAf*~nO^ z@IxYbUgxOH$2pa8m9?~vW0f>}D1Fk!Za*N*am{k3PepdX=+0g;Sgp@_5cP^ieih(# zt9}5TjxegKy%^^lS`0sTz9X)dA4$pHGiVg7g_%bdW0nZRnJ835UY=ddyBhHm)6cBI z)yt*mrNj~tmMw(!7L9D?#0b1PuI9h~O`3P}X*znZy+bx^Z^xUj-jUj;-stenmHMB` zp*8uLASsrIVjm(d< zMUa<}L<8KH!t%*w5O5KIT?3%OSE`kf3U zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G) zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3 zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK! zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+? z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1 zPxfwXV@}rHMQAtco_`tOlK+s$ z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q# zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2 z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle new file mode 120000 index 000000000..88c340bb9 --- /dev/null +++ b/examples/sgnn/model1.pickle @@ -0,0 +1 @@ +test_backend/model1.pickle \ No newline at end of file diff --git a/examples/sgnn/peg.xml b/examples/sgnn/peg.xml new file mode 100644 index 000000000..d3f41baaf --- /dev/null +++ b/examples/sgnn/peg.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb deleted file mode 100644 index 2c11081d1..000000000 --- a/examples/sgnn/peg4.pdb +++ /dev/null @@ -1,63 +0,0 @@ -REMARK -CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 -ATOM 1 C00 TER 1 -2.962 3.637 -1.170 -ATOM 2 H01 TER 1 -2.608 4.142 -0.296 -ATOM 3 H02 TER 1 -4.032 3.635 -1.171 -ATOM 4 O03 TER 1 -2.484 2.289 -1.168 -ATOM 5 C04 TER 1 -2.961 1.615 0.000 -ATOM 6 H05 TER 1 -2.604 0.606 0.000 -ATOM 7 H06 TER 1 -2.604 2.119 0.874 -ATOM 8 H07 TER 1 -4.031 1.615 0.000 -ATOM 9 C00 INT 2 -2.449 6.384 -3.596 -ATOM 10 H01 INT 2 -2.804 5.879 -4.470 -ATOM 11 H02 INT 2 -1.379 6.386 -3.595 -ATOM 12 O03 INT 2 -2.927 5.710 -2.429 -ATOM 13 C04 INT 2 -2.448 4.362 -2.427 -ATOM 14 H05 INT 2 -2.803 3.856 -3.301 -ATOM 15 H06 INT 2 -1.378 4.364 -2.425 -ATOM 16 C00 INT 3 -2.966 9.857 -4.767 -ATOM 17 H01 INT 3 -2.612 10.363 -3.893 -ATOM 18 H02 INT 3 -4.036 9.855 -4.768 -ATOM 19 O03 INT 3 -2.488 8.509 -4.765 -ATOM 20 C04 INT 3 -2.965 7.835 -3.597 -ATOM 21 H05 INT 3 -2.610 8.340 -2.724 -ATOM 22 H06 INT 3 -4.035 7.833 -3.599 -ATOM 23 C00 TER 4 -2.452 10.582 -6.024 -ATOM 24 H01 TER 4 -2.807 10.077 -6.898 -ATOM 25 H02 TER 4 -1.382 10.584 -6.022 -ATOM 26 O03 TER 4 -2.931 11.930 -6.026 -ATOM 27 C04 TER 4 -2.453 12.604 -7.193 -ATOM 28 H05 TER 4 -2.808 12.099 -8.067 -ATOM 29 H06 TER 4 -2.812 13.613 -7.194 -ATOM 30 H07 TER 4 -1.383 12.606 -7.192 -TER -CONECT 5 6 -CONECT 5 7 -CONECT 5 8 -CONECT 5 4 -CONECT 4 1 -CONECT 1 2 -CONECT 1 3 -CONECT 1 13 -CONECT 13 14 -CONECT 13 15 -CONECT 13 12 -CONECT 12 9 -CONECT 9 10 -CONECT 9 11 -CONECT 9 20 -CONECT 20 21 -CONECT 20 22 -CONECT 20 19 -CONECT 19 16 -CONECT 16 17 -CONECT 16 18 -CONECT 16 23 -CONECT 23 24 -CONECT 23 25 -CONECT 23 26 -CONECT 26 27 -CONECT 27 28 -CONECT 27 29 -CONECT 27 30 -END diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb new file mode 120000 index 000000000..3c2bb15b6 --- /dev/null +++ b/examples/sgnn/peg4.pdb @@ -0,0 +1 @@ +test_backend/peg4.pdb \ No newline at end of file diff --git a/examples/sgnn/ref_out b/examples/sgnn/ref_out index 96bc1e62f..039b75427 100644 --- a/examples/sgnn/ref_out +++ b/examples/sgnn/ref_out @@ -1,37 +1,5 @@ -Energy: -21.588394 -Force -[[ 90.02814 2.0374336 35.38877 ] - [ -98.410095 -1.6865425 -30.066338 ] - [ 48.29245 31.675808 -43.390694 ] - [ 59.717484 -35.94304 50.599678 ] - [ -24.63767 218.36092 168.47194 ] - [ 43.258293 81.24294 -87.22882 ] - [ -67.66767 -17.780457 -5.6038494 ] - [ -22.928284 -302.96246 -123.14815 ] - [ 306.24683 -21.33866 -156.95491 ] - [ -4.715515 13.664352 -23.222527 ] - [-258.61304 -26.577957 85.58963 ] - [ -10.179474 106.21161 64.846924 ] - [-210.20566 -52.107193 58.04005 ] - [ 118.68472 -8.033836 -81.18109 ] - [ 44.02272 -34.508667 46.852356 ] - [-214.84206 115.90286 -227.59117 ] - [ 44.243336 -7.151741 26.06369 ] - [ 87.46674 38.574554 192.17757 ] - [ 27.345726 -58.87986 -44.685863 ] - [ -83.354774 -29.714098 214.93097 ] - [ -71.111305 34.880676 -77.53289 ] - [ 141.12836 49.28147 -97.597305 ] - [-220.25613 -134.58449 -23.567059 ] - [ 75.2593 58.432755 -63.99505 ] - [ 123.56466 -82.0066 94.63971 ] - [ 57.822285 17.07631 -53.788273 ] - [ -73.37115 0.50865555 16.240654 ] - [ 54.86133 97.53715 73.672806 ] - [ -23.997787 -73.92179 -13.749107 ] - [ 62.348286 21.809956 25.78839 ]] -Batched Energies: -[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164 - -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154 - -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443 - -72.8221 -28.33014 ] +-21.588284621154912 +[-21.58828462 -39.79334159 10.03889335 -48.22451239 -32.90970162 + -49.68568287 -47.58035178 -51.73860617 -37.39235277 -35.01933271 + -46.06621902 -31.69327601 -6.86739655 -5.13698524 -27.4031207 + -44.65301991 -52.00357797 3.1734038 -72.79081259 -28.27007722] diff --git a/examples/sgnn/residues.xml b/examples/sgnn/residues.xml new file mode 100644 index 000000000..aa78866eb --- /dev/null +++ b/examples/sgnn/residues.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/sgnn/run.py b/examples/sgnn/run.py index f87c887b5..14c7c1b84 100755 --- a/examples/sgnn/run.py +++ b/examples/sgnn/run.py @@ -1,45 +1,44 @@ #!/usr/bin/env python import sys -import numpy as np +import jax import jax.numpy as jnp -import jax.lax as lax -from jax import vmap, value_and_grad -import dmff -from dmff.sgnn.gnn import MolGNNForce -from dmff.utils import jit_condition -from dmff.sgnn.graph import MAX_VALENCE -from dmff.sgnn.graph import TopGraph, from_pdb +import openmm.app as app +import openmm.unit as unit +from dmff.api import Hamiltonian +from dmff.common import nblist +from jax import value_and_grad import pickle -import re -from collections import OrderedDict -from functools import partial - if __name__ == '__main__': - # params = load_params('benchmark/model1.pickle') - G = from_pdb('peg4.pdb') - model = MolGNNForce(G, nn=1) - model.load_params('model1.pickle') - E = model.get_energy(G.positions, G.box, model.params) + + H = Hamiltonian('peg.xml') + app.Topology.loadBondDefinitions("residues.xml") + pdb = app.PDBFile("peg4.pdb") + rc = 0.6 + # generator stores all force field parameters + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4) + + # construct inputs + positions = jnp.array(pdb.positions._value) + a, b, c = pdb.topology.getPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) + # neighbor list + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) + nbl.allocate(positions) + + + paramset = H.getParameters() + # params = paramset.parameters - with open('set_test_lowT.pickle', 'rb') as ifile: + with open('test_backend/set_test_lowT.pickle', 'rb') as ifile: data = pickle.load(ifile) - # pos = jnp.array(data['positions'][0:100]) - # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1)) - pos = jnp.array(data['positions'][0]) - box = jnp.eye(3) * 50 + # input in nm + pos = jnp.array(data['positions'][0:20]) / 10 + box = jnp.eye(3) * 5 - # energies = model.batch_forward(pos, box, model.params) - E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params) - F = -F - print('Energy:', E) - print('Force') - print(F) + efunc = jax.jit(pots.getPotentialFunc()) + efunc_vmap = jax.vmap(jax.jit(pots.getPotentialFunc()), in_axes=(0, None, None, None), out_axes=0) + print(efunc(pos[0], box, nbl.pairs, paramset)) + print(efunc_vmap(pos, box, nbl.pairs, paramset)) - # test batch processing - pos = jnp.array(data['positions'][:20]) - box = jnp.tile(box, (20, 1, 1)) - E = model.batch_forward(pos, box, model.params) - print('Batched Energies:') - print(E) diff --git a/examples/sgnn/model.pickle b/examples/sgnn/test_backend/model.pickle similarity index 100% rename from examples/sgnn/model.pickle rename to examples/sgnn/test_backend/model.pickle diff --git a/examples/sgnn/test_backend/model1.pickle b/examples/sgnn/test_backend/model1.pickle new file mode 100644 index 0000000000000000000000000000000000000000..0c3959cd9d0ef4fac155676861aa6743acb96c99 GIT binary patch literal 17100 zcmYhjc|28J^gmAKAyY_*1~L^=63$*Xr4mVLpj0v>4d|LmrOb2Y$XF7RREWaa>!!>p zB}p1I7fmWEY4~}b&*%Ame)o_2y03HhzH6O(&RXyNUhBP2h=7}$&z?Qo-TZg@c>9Ul z`MPiS-R^F=)6HL;%co<{<=1xP=i}qs$DQEj9pJS$NZ-xJce}n!ym~N``>}?{y}@U zi*v;tCyZP1n9r54;j`h7=1SUgCu-XW{A-)xO08P8%KGp4>)$sUK7X#XYwQw11rJ#c zSH_*^@^&ulkCeM9|y+js8p^ykWX3V8l&dy?n4VR?J5!as(!ZRSUD6+I`$E)JQ- z$G1T+fcBgerYQfCJnLyCU5n1L30H2haA*$%7`q}YKMGyJIdm3Jif)~eM88DUuz!uc zNNAHD+&Oj~KmEz4?{+8A_PLsLs9c+veQ%*5X%DISqnE5w(ZYq^-y~7^*j0K*{~jG? zo0*RZ6-;YuG?Mo3Bsg^!h6>seBNG)o+1p0szEs1!-xHuICy&g3U4{qL>*3;vItT?J z++%M;@YfPJ%(@T_&OuNe=EM21Ps#JaKr1#@`jJY_vKV!_Lxq4MtdG?`%+ za8xiyTN-Hc7!zA9hPRsjl2rp@OquL=HvHWcni(O?s%SVeJ!8|Tdy+V;@zlZZItF#` z%);yGiaS(6Hjg^4&&Fz%Od?jim7Q>^1;4KifrysX5WY~0RQ(pit&#W0lmgrOgJDv< zx&`%6#crm}t;cD_+kE9m zyJaiMx=bnTIHUm~7pMV&TLd7zWFD1n4<*z7PJ_7zwouFJ1hY32#o>|pN%G5Q5k| z)2AF`t9s<9UuG0de00TJbnYBDGh3Z7k-hZs@j1l9PY80l=aC@6e8RtBElO;gi%lnz zuxr&8(x@wmTG!W-w{ca3(Vj^A&W;iO{7hW^CyG8@zlRtsmu6ZXpJ19w3Sq~myQq2O z8Qu-6!4wH&tedwSS#?EJ4^#sqrVF>X#BesXDG>gBmUul!hey?BQ(kU8=X2jl8l5!2 zR+g_uA%_esv$n-6-LpV@K$;LmX`I(oMLexz@J50j{T`G_RWpPjZPPN=PBH{}mUd`x zTL9E;&B&l!CE6qz(y4t>#5H~u8!A{#{(fiC>X{DjOieSLANP#5&5Q-V`IjjFfz9yf z$piAlV?Fv@ngg>BJ~CfE$qM|9U((AWPpQZUF^;c(F3c&Nj6EfH$?WTva6mB;YG)nc zOmqo=zn@m4b^3W;@Qx?&!E7}wkbVcf`U^QK)@oQSeTL|zFM`)rQ*j>-lY|K@b%AU$ z@2LdL4eTNZ?uo#$KkMqxDTdeYds9rVE$C){Mx0`HSxCZijVX}*_(}b#vHOJgS()-b zyGe~729a$u)lltlmARK%2T^cd1&7u@VTE?sQnf7;cn0}47zdj8;8isp-W>*;cLk%1 zya%<2n^;${Bn01D%>kur5qL6rE$(Z%3x$b^)ZZ!s?^>sz+?U@hi{nr_HpG*=u=RVOp zdJsa6>}RH5+=>P(LuvnX5q!~TOTWj6a~jHO;K$-A7%a4xhanx<_1POlqFYfRWF}8P zZ34&j`+KswYX$5WJPaN$1Yl)w5BYRIf|;qc0CIgJ$m#cH^zKIhdu=b8-RI8SJMoC* z*iI$F&zI8TE8;Nq_ZDd0B1x|FYMCn)CeTygmy_yJKDZ*R!g@T)lDN#=yPgj+svk*}fH3U2z>uE9@#Kc)1{B|2i`(J{>8(fah*rWf zRKBeQ=24$;rCI@WytE?S7r&DC1N+hT9FQf&44q}zLhHKXNxvz}rXI{9DK)!j6tO~q zM~~>-kKa*njE{aToQQ8q_^HA-P4aXhAFr{$j+QL-Ak7-BWYDr0SAUsI_Pn`8j8<3! z&p8A?2lX%_Q7xbvkVK4<=itfH3-E&4T-+5Q#0k6l2RD^jgT479W|0;@_!u@(6?;{h z))x#B!7`la+FuaI`N&Q$D8-wN(>P(9*WswLOZ}T2a&%}}EciU%4oki}l8M}9c>Vh- z5*D}shnx4{!KsrWf8R9TnS|G@*h)>BN~CCT!b)6YTFc&=mO%Apm{75vt)za|HD>v! zC}?n0!D-bVb~I85RWE0O^!aL9q4=J(ttM%T$FI9 zHm8qLts7p{MA8!rO+)G5w*%%sQv5K;;0)fGbc>l_kOG%qhfxw10}8&GwDG|ic1Qb7 zeDi%f7OXYq%ovHp)%ANxcA_fE!a5k7kxaHu^r8(7;;i`8-B?^_$~e*p;%wK;4i@gg z`DYiBA$H1QP6~oRJ{SN#DxZSv zXcF$4vJ-AwhB1v|h4t#^X3$ubNZNJf0<5AVDAoCk=pQm68~!K)_FRF|*YddBRRSNk zU5BHxe@RM}HZ9v0Mt6CBqlOwNsGFUc;z3$gY3epTdv6ky_NhJ)` zPgTMpy8=A=YZgZ%kVnrR%p>BwLVDO`5e77bfL(_=3H@r0i>vy`^MKzp^mQGT%U7dM zmidrhq87Atwh+YECe@em?!v3nlg$IypP-YRi@|WoZ16nGkLHU?i00}MD2$edqwXp= zN#!B2*)C1i8koX|3VRf7n@DE3d}en~(4}TpgCsbmoVNEHf%3H&-hMk95dOH1$bOzc zWA+Q-hU`2#!`l}`H$I{^D}G|_r@y#&!w)K%K0qINzopr|S*&bg7;|qCKO_}duyUd+ z=}^y2GV`4pY@CJc%JUOBULRk4M`W9pZTD5}PtCNA7TPi9}#3USF@k`7l~da#;of7u_UUsg2a?L=*k3 z^oTC78HVVo&5Y@rdsNPGPrYjSUe<7#4edrQhNc`Nn|IHK$Key<+6+yi-=F}(g{xq7 z*9J2BDS~Xzox+;6y}$r+9jz5>QE`$2uhvfwcnOYFv&|I)Ya%gy>Q0&)k`&&tCME!u6xaNaryRg6er7R3VNL!~N)4WX!W@sD;d^ z+tkDB9QX=4)<-v8$8Qduu$+ou<+J_ds#gIvPHScNeba>OUDEiva5`I|xgX_i^FaTI zFj{#>BgRVBe|;|iTBjsI>7fICK4mK~-M``c_ul%ki>ACfbSyL=_Nv|9()z3d?`tw!j>-|xwYc|AMFZXbJK+ZVF%w-&u5yM~4t zj%NVhEQZ&fKq}hG;e(OI}YD{$|CnGX5j6lzck|cVqCJ|9C@_65)$o3F^6vwt@2q;_C*@OnSrbD z+g$^0O#XpUg%Z45FAjj6lss?BuglC^&1iGMpU*j4f<@4ILN{~XK>{4EiD6`I8tiRY ziEm35(w8^H$?LPC)I+~@JRTEBe{wkqoxGFgTplE`wh7E%whs%;%s54E{Io>qDzJK| z>DYx8G;}N&cKW5z_wp}DO0_J_kL2LF4TiMW(TI2{*n_Qe03A-H*c6b6MmKEWo~j9H zaC$-iTo0mq7cgYUReid#(V9tZ$|UgzRiIatpNu35!-9R!!A#>MbEHolhYJnqK2#-@ zxBif$FEn`ZhtxU!;U1iq%BPs3?^^7g?k#lh=UZf|*)mkMe!$@61ZtP}kOT^*kU1}! zXvW$<#38kr=HK1O>g$vc!RHOsb+#UjRnKP<*9e33N)LwL$%$cW6d2!<2TXd4De26; zf}RG+bz;XKlW4;%qF38TgT1Qhr)B5a(16X1Vs;I&35=!u6P3t4F@ADbLXs3zXp*Zp z+F1q7WE9@Ufy$VDXq+g>ep_5kYnGJ2;O4(%)B6j&-R|m~m#tH=(fk^W-O=NnkC)<1 z%~9Yi1AltLY6CnzWX7?Zo(lP8UrGL?$!N}<1h&7-q4Uuwxv{x`b!_^>Y#F^qS28bI zWy!~eq1C{VL9~K zEnVhAXCC>;|C+pABtTa^6+zh_iSWbxCHt%9CoSgHk(NG1dN{0xr~l^^6nqt z46<-KG=qdxCc{12P9la*g6xxgG}5!CmKIUuqk;fVzbMEVK@Z-*Cm9GfQRZx#Yen@^ zyUFb7*NC@#D|Cpxqvxjk(`+MkI*)FJX08cd&bSQPWfLGIcsaaM*no1I`CxuhE*+>j z#6A%Tqz8R}na5-%(9r3RneM%(q5OUUQNJ=mS9n&!iaF^-&?}FLeq8_qazpfWa}qo; zJcb!=??{i@9pY~6g)T0uaGk+pV*EXZGC=KXEdAXJpz4R;3<2UD_Ob0_Uqi-)V> zUGOgH6F%42MT~n^AZNcX)%)H?ecj(t?T-Vf=Pzw0{$?-8KKu@8HEwWPScfzDls->x zavWX9-#~8t%piZ{1!<>vviZ|1?wDLBLN1Q#!3Fh?jQKokP#Os$zmEySj0XlVA~i~l zgC7$e;a0fsK9xy3UQc%Tl@oo>bVklr2n3?}sB_sLa>*ecK5fwC%b6H8~C^^H89FHU~zOIdKIJ;C1E##1uEe@pm2gJ3favq|ApK zW7|MIWDajykvG07noPba$?*JJLb2k`8zL(%MNXm*Jr?ST#b;EB(3k~O&ydF9+H>r- z)h1Lv>N&IeKtA2VS4+g#6_fiNc?_92lbn{?0M!s$(XHn_NM=?d9m?Pb&nrCh zg0;isbGIS=D?Ae&n;Xfo-W9NUX$iSmxS!5W5ykZndud|MWL!5Zf`-5eh~X~4z~OVy z$z%Y3^ffj^$_*D%IdaLOgi$r2$6fwqR{K{AYI{doE_ONN=+XY5rx1e^2&J;jYtb) z7bypmI|jAHcFBBZvfW;cQyienO6Sw2+)(@yzl9cXB=F6_B$Sh>W8@wvW5f?t*tm8v z?HX((l~Je3tOKRgN=k}`Ri%&~lV@;l^BpK^i-C1rxgg;f34Pydq5s@&T%3#-9!~R@4M6)bbfkW&}s;a3=UgXx1lDu}>IinQ! zpQwN#!60fSW`oye9wO1}x(V$sf-f^pQ+bJryv5#6p}2b+QBt4J{xZxUS7lR4ThUSW zOmjQY-g6cXNxmnEqRVKPeK80J?1tGl*Fdt-d|tpDK5VgUM(d9C`0L>%`Y}odeZO4C zgSq!fvaCPkJ;(u#Gy~3pqaB#F^axK{*$MXZf1|RSuaZgbF~rHsg#Pv>WW|PSJ}%l-LpfKr%eFk4Zlu#mmAH zdRgx*7?+e2$CKH_Wd1i4%U;49-5Sbx$7Vv~WLc2=I3HZVE$o z{pk0E##s!qM;`1Y{abZ$_hmu6+ZRUGHqXNC*9?fFi3Dbi4zgvv&S-V_JZ^EwqIHji zIGHwDP;r3aBpw_Fk$4df^JXVZjs1?E8=qpv<|Q1Pq3>kNnpn>3X>z>lMv1&Y<;kC9!P2_h-G0i?F1N$7L;cMSUxGKm2(W2QfqBX#r)fgd0RtrI6 zeJU+A?|qV!K~o{{q$IFgsfs%eRlmnWyfz>j{c-672LE|v$Q=?^gDPA}uG zA!>H=S1gH{lZrh}S#BwM^?UD`)W`C~ zd){>L6rGEH$Is%l>RlN2X%@6?ih);SZXiy-?~#64H6hB{5+*LX2TUKtC!42!VvudR+Kc z1SNyQ@T0a0>|6N?Hmtgg_iqod;XbKk_WcN0{;C>V@~7~0f)XIraUUl%CmK5a%5cq; zbQqo%%~tFQLC*wpl4KT6e+_3rtm1erWIZ3a#sy)Kfe{$K^P&cW2@qPNip6pQ*dRWU z7?y@$^rFRNefM#){p1aNFZ95Ct=~W_f{xkiyP6nmk6vO(Z z`pl88rRILVbHV<5IaN_NWw#ueNor60zy!zBRN&M_rpDzkR%8G&m&36!;WiBFtz#y% zB;%iLL9lMK9u)a6hpczvr2nG~Z*#c`&urr{n&M~8I~UXg=U)nO3Z58pL@#PnVXs=# zoq0!@nT6Hl?ze4FxM~J$IdhgozRMuT7B7WT!$6WR76$D~v#IL0qqH+f1WG?2CbPGR z!OX~$w8MZ4vf*1<$bCuLjds*WUdtg-eN$l3R2Xje&*UtP3&&Gy_kqWfJ{Z!R zOfKG$!<|Q3;Njj?@HpfLUGuCEKUcaFAGhhO-GOuD`1n1(R<8qhU-?PT&-+Zyk5!T9 z*UOmCXU{Vgf4>OXw&0ab_{K-BAuB zS^4N~9tTc$yvfKZey|nqqhF1G)xUDuk5?iTNqLepgnTKby!G>mwMiEd^(#RX?LybM zh0tND3AjIyv`L-DhL0>%m$zf9lpuDD4AkA3p$p2~cw(ZI#mqes0>Mw**x9rE>9vG7 z>TN8Ce?&ql-@|J9cK99HrX4_qLM9>qrbNg{tsvf3g~Yc$oWj||L~`H$c zsm2-0Qt1Todg?U9(5LeW6qYHH#!t zO7MJb4|!o{ie~i;=&Cs42PG+Lm7Gb#2c@ZwiZNaPJrW9b`LN9`TPPknPkYAyOfCt+{NpB|)R%_Z} zONI!poiI!%jYUJ$;RvWYnNEq_OrBzBBOJMZ7v-f4c;AwuA>PQ8qaY$j$2Wco%Ty*qi)$NPyEy@# z$!)GJ)6=9AJNiLLK8eNqXUtU&zMz|)>cgSE3cP(?spOQ&QM6n?LLX_2(YYn@7__ty z&RFHMnm??d`97DYr!4{jj3G~RTRew1I)$gzBLLs&R(eD~$z1LI43r*>B_FHK5eFF~ zYI&}lMksaBq}mo5)8vFc9}URomEyR*MiO7>vSjetdv-5GlgPt2>x;TB(k_j#`nG|I zczU-2-aos6)$y+<(?U`8$T5OEAa`#~M2NRsUzXg^T?ZXq5bs#6L2A^kN0tcE6 zcde6QZJQ_hZaNDGj15TdHCayHgG-S1jSK7gqCqf7lQSwmiC*?cvRG*%uP39Md{8yQ z@T%9auA?0Ngr}q5uE~s`N*^)FYQw7e7&2vAH zzweMSdp3p8=MqGBcK`$WN|?AZ1y0(Ua6D7QI0ea1VS>Q}xIOhQs5J5*CukmWEI+}@ z;>o-@;sWF%dxHd?j|Lr)3Y3t{fjP?c=*vxmv5Y+!Zm^W+^u`9#|6HY~>rH?=*`KTA z>L~vIvdQ>plB9Z#8)|c#(82d`|dhH8YJ(a&XEj z8KUKnG4ju+(@m?Y$*I4>F!(!={yn*#+%Yew^ZW(bPH#^Tc{0o-z1zYXnx>N4tUv7L zRnzg+)t#`!J%Eg5En}5ac9G*|zsVF%7L`BP%5c^Ex$6HX$e8<|AY%laanIBf)!NgP%<>o9FLVXQ{SDY7~&i0t&(fMy;(u=-Lo8tt&g zPgA4_+hM?4w)Ys8o+(1U_2UgbyJ)N&N&$~%Lx@}vkISFkBwe+!;PrYZbdMKpy~ONr zW9ndv%Wv7hlZaPLD5!K*fRG9Ot&f{jen%^kq3jhA3e#R zo#n9O&pXEcNj&oJiz0Kj_d=DC7tGQfgk)|sELr&pgAd4%CT%%((m^jU&pi%XeB4n} zJ(NiFSHhyDiqutHmCi^ng?T#mXr4)r zrgHF%_#dMEGo3DNs3nUZ7D2`2*QDsF4(@&{#o0}raWu{gddjk~X!HOi8jKLfVpUL7 z$%Gq!Ht?x} z_bvx;Funs9ngv02%_W+gv6DP1oXT67R1Wc}wz%rUEnw3|@Z9Zqc zhe6YozlgM@1uTqKf{({Uu)r%B+ru|O`0+}b@xBmtbu5PBPwmu5GM--C@R062yqO-! ztz)w1-J)6h1EA*ePB{1|fh1;e;DN+U+_~lw_LkYZ3yGc!Jj5wa+EAeuq9K9NO3Lh7w(@7VEN$i31 zBqn4pv2^mM^ByOIU|k^Ne0M!e7)+*>K56FO<3)o}{#zvS$xrsnvnx1c&XnFr;d}9QR-c-N?u1>{!tQj^+=U{U0Z+7RKBD__hOqC`NgVw?wL_a;2d~!QT zY!Z3&<&_G&5U7vst%M%+41vhfBF0WL4t9PPL9dvZyld*$sKM(sD6TRG?W3BozkHOu z*!T;s!)c3_AJ!-y!`6o*^ooWNU152F?g-yW<3tv-A8$IE zTfAGrTkBG8el*np2XZH%-P>poTXvEjeZGTgPRk|_6i?zFUKu&NBAf|-qe)I5eowxq zj1$Y^7-}LR3IZD^gXN4*L}NuMI{%u3lMbK2{YQR5NQV%HNL?TorWugyX@E3f z0kCxpHC$wPy`ArvmtJlFC-h;{Ze>ooo-k#< zbW#=B(B^};JgV5=Ru|0$f+q1yqyyn~Tpx_MkI`G4xtLih!CByY9{-Rn^}5wru%tN< z`m_pgqUC+^@$_b3bo-$^uA1)Zw?&r;<>XGU5Lqes5c?a_kV%z<#qzSiai~VKp7qrJ z_e!wlvq1MiY3Ojw2J0OQsf7tYiBy!w&*J%{W*YKiTD17+*mi*1jWoKb7t5T5<;T6Zp3*W~m^z=U4 zz3C=dRToRn>YOFZ@)Teo&xY{!h0%)3LfHB#4I^c3@ugn@>V0V-5iK!n$z}`qV0{_9 zC7MAbbpXtg+`*$}CGmZ2!8re%Ln6B`)$^tgQ31&gFxjj_g61$VaC!yIUzJN@#RurK zpZl2rqaWsu@1y9Ou-ft5qX_%0{*ZU-vb-a;C)jUc#yCrHK3?k&MgOWROvTlOxG{1L zd6~4FW0V~Mc{67*DO*_*qiD>QvZ?f5^C@B-(nki{%yD!_G}#@RhezyZkqAoX9foO=#xt!YpHJNLy|ur!U_@9E?FDN< zTxuhga2P;GP9@2qGl6qPA7A$Ow;c)A$jV)fv0m}1ROs&;o?9wTt0OdbY{3<<~0HKK(!9OUe!gS?|r!Ma*L-li>Ek&(Ox{B3bmv%CjdPAq~Xix=cz@7DSz z^)9%%tqz)vR>3yOs~{_E0cj^HsX_8WEMOng&akKTmx~KgG)x{2=ijCuH@|=pmvm;O zTsrdU&cbDTPQ(8A035C~!SzcQfzzRvbcx|^nr^L%5gq_l3xa_SIYzp^WZ|EiC#jda z0hoh2E;l{|J2uXN`>$8QzTr5u^azL7G0k|+@CQWkO*U;W*$;<*XyOrtCYqWt4IGNq zaCC&9#QfxgLo>v2Rb&#R?23T-SLT7^7c-PzvK%|U9DwJ`^I-J)E#k~K1w`LOqsx!a z5c1U&azZ4qckfP|sPTflUb`7g+qy}T7>_A7>roGG zGyL`&L9XUIrO$s8v!}xKgBNe(g6VR2Iqe=6z#3%gy>O3Erp9*{AFeI@VtJ zbMv?`W`ze?7}ZK&eC5%in}JNO-$W#9gg99qi^;>+wM=KfI7!M%g2SZ|?9}ypz?T!Urkd(*7vlYnI^ykj{KSDNF|6f9;^&fN5{)f4YYvOhPWiF9iy?+SV>Lqp)AVx}! z<}~Eu%XW9b=F=qS@Ju-9_!Kk_G?3rYjv$ghk6*A<6h)1{B(m5v@w zxbmDSmdmF(Lmq6w3rpnA_UF#|KLTa+AA!>U&kcir+%UA~&ix+(B_G8#`iDR*66WK( z*Pw<2jn9eO$;+s8_BhG7tO(Ocs&yn)eh9KZ(ik{;^s_QhHi1dGf z_f;LN!B;H1;M`aLO!?%)ana8f_>n!C=lM_s3=+4J%?Zy~ zNBOJ&xagR=ohmL}M&-qx(ks5pvHSRF{jAC) zv=B%mhSodLZpl3Kah?YYgY-b7AR0AhM56P^PMRPd4zDl|zSKXii~aK$zaKWifJzB? z_v$M&rb}_&4Oaj@AAxm~_;@?^yK!zj-iW!ey{LEh#<;pvfwS}eRd9aG!%aHTu={N; znhpNKw^>Iqz+H+asknj7yL)7Ga2>vmkibaab`p)pm<G~10v)_YOm&aoG zWI1vrR%|?o)M@Uqm_2}x|trH zs0)47Rv2?O4t`1%f@#okVlZUP**i}f0$vXhzr{{W-*_U@E^9G0IC_%~6v}hbT4ONY zV;D0ZCDZ=+Q$(aeh<8Unj5dGQN4a@XbiTkRGLn#jN8DY`|Yp|i)8T}!W%zoXHpYF~@BWfyVE zXBo5_(t=Z;k04!8ikA;*vQAUO%?(H`z63ce7W+v=^OHeZ{1V1p;1;*3X{PN!Ohff(%VH@bE67$K5Vc+h<~Ecgh*RgoeS}8)u-|@-JB(^92?; z^FgC-D6>bm8Pc4aNL(Es=V7`g%}g4FL(|Wqtn@tMuze-OMQYOfAKGcwTLaAf*~nO^ z@IxYbUgxOH$2pa8m9?~vW0f>}D1Fk!Za*N*am{k3PepdX=+0g;Sgp@_5cP^ieih(# zt9}5TjxegKy%^^lS`0sTz9X)dA4$pHGiVg7g_%bdW0nZRnJ835UY=ddyBhHm)6cBI z)yt*mrNj~tmMw(!7L9D?#0b1PuI9h~O`3P}X*znZy+bx^Z^xUj-jUj;-stenmHMB` zp*8uLASsrIVjm(d< zMUa<}L<8KH!t%*w5O5KIT?3%OSE`kf3U zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G) zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3 zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK! zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+? z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1 zPxfwXV@}rHMQAtco_`tOlK+s$ z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q# zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2 z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O literal 0 HcmV?d00001 diff --git a/examples/sgnn/model1.pth b/examples/sgnn/test_backend/model1.pth similarity index 100% rename from examples/sgnn/model1.pth rename to examples/sgnn/test_backend/model1.pth diff --git a/examples/sgnn/mse_testing.xvg b/examples/sgnn/test_backend/mse_testing.xvg similarity index 100% rename from examples/sgnn/mse_testing.xvg rename to examples/sgnn/test_backend/mse_testing.xvg diff --git a/examples/sgnn/test_backend/peg4.pdb b/examples/sgnn/test_backend/peg4.pdb new file mode 100644 index 000000000..2c11081d1 --- /dev/null +++ b/examples/sgnn/test_backend/peg4.pdb @@ -0,0 +1,63 @@ +REMARK +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER 1 -2.962 3.637 -1.170 +ATOM 2 H01 TER 1 -2.608 4.142 -0.296 +ATOM 3 H02 TER 1 -4.032 3.635 -1.171 +ATOM 4 O03 TER 1 -2.484 2.289 -1.168 +ATOM 5 C04 TER 1 -2.961 1.615 0.000 +ATOM 6 H05 TER 1 -2.604 0.606 0.000 +ATOM 7 H06 TER 1 -2.604 2.119 0.874 +ATOM 8 H07 TER 1 -4.031 1.615 0.000 +ATOM 9 C00 INT 2 -2.449 6.384 -3.596 +ATOM 10 H01 INT 2 -2.804 5.879 -4.470 +ATOM 11 H02 INT 2 -1.379 6.386 -3.595 +ATOM 12 O03 INT 2 -2.927 5.710 -2.429 +ATOM 13 C04 INT 2 -2.448 4.362 -2.427 +ATOM 14 H05 INT 2 -2.803 3.856 -3.301 +ATOM 15 H06 INT 2 -1.378 4.364 -2.425 +ATOM 16 C00 INT 3 -2.966 9.857 -4.767 +ATOM 17 H01 INT 3 -2.612 10.363 -3.893 +ATOM 18 H02 INT 3 -4.036 9.855 -4.768 +ATOM 19 O03 INT 3 -2.488 8.509 -4.765 +ATOM 20 C04 INT 3 -2.965 7.835 -3.597 +ATOM 21 H05 INT 3 -2.610 8.340 -2.724 +ATOM 22 H06 INT 3 -4.035 7.833 -3.599 +ATOM 23 C00 TER 4 -2.452 10.582 -6.024 +ATOM 24 H01 TER 4 -2.807 10.077 -6.898 +ATOM 25 H02 TER 4 -1.382 10.584 -6.022 +ATOM 26 O03 TER 4 -2.931 11.930 -6.026 +ATOM 27 C04 TER 4 -2.453 12.604 -7.193 +ATOM 28 H05 TER 4 -2.808 12.099 -8.067 +ATOM 29 H06 TER 4 -2.812 13.613 -7.194 +ATOM 30 H07 TER 4 -1.383 12.606 -7.192 +TER +CONECT 5 6 +CONECT 5 7 +CONECT 5 8 +CONECT 5 4 +CONECT 4 1 +CONECT 1 2 +CONECT 1 3 +CONECT 1 13 +CONECT 13 14 +CONECT 13 15 +CONECT 13 12 +CONECT 12 9 +CONECT 9 10 +CONECT 9 11 +CONECT 9 20 +CONECT 20 21 +CONECT 20 22 +CONECT 20 19 +CONECT 19 16 +CONECT 16 17 +CONECT 16 18 +CONECT 16 23 +CONECT 23 24 +CONECT 23 25 +CONECT 23 26 +CONECT 26 27 +CONECT 27 28 +CONECT 27 29 +CONECT 27 30 +END diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/test_backend/pth2pickle.py similarity index 100% rename from examples/sgnn/pth2pickle.py rename to examples/sgnn/test_backend/pth2pickle.py diff --git a/examples/sgnn/test_backend/ref_out b/examples/sgnn/test_backend/ref_out new file mode 100644 index 000000000..96bc1e62f --- /dev/null +++ b/examples/sgnn/test_backend/ref_out @@ -0,0 +1,37 @@ +Energy: -21.588394 +Force +[[ 90.02814 2.0374336 35.38877 ] + [ -98.410095 -1.6865425 -30.066338 ] + [ 48.29245 31.675808 -43.390694 ] + [ 59.717484 -35.94304 50.599678 ] + [ -24.63767 218.36092 168.47194 ] + [ 43.258293 81.24294 -87.22882 ] + [ -67.66767 -17.780457 -5.6038494 ] + [ -22.928284 -302.96246 -123.14815 ] + [ 306.24683 -21.33866 -156.95491 ] + [ -4.715515 13.664352 -23.222527 ] + [-258.61304 -26.577957 85.58963 ] + [ -10.179474 106.21161 64.846924 ] + [-210.20566 -52.107193 58.04005 ] + [ 118.68472 -8.033836 -81.18109 ] + [ 44.02272 -34.508667 46.852356 ] + [-214.84206 115.90286 -227.59117 ] + [ 44.243336 -7.151741 26.06369 ] + [ 87.46674 38.574554 192.17757 ] + [ 27.345726 -58.87986 -44.685863 ] + [ -83.354774 -29.714098 214.93097 ] + [ -71.111305 34.880676 -77.53289 ] + [ 141.12836 49.28147 -97.597305 ] + [-220.25613 -134.58449 -23.567059 ] + [ 75.2593 58.432755 -63.99505 ] + [ 123.56466 -82.0066 94.63971 ] + [ 57.822285 17.07631 -53.788273 ] + [ -73.37115 0.50865555 16.240654 ] + [ 54.86133 97.53715 73.672806 ] + [ -23.997787 -73.92179 -13.749107 ] + [ 62.348286 21.809956 25.78839 ]] +Batched Energies: +[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164 + -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154 + -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443 + -72.8221 -28.33014 ] diff --git a/examples/sgnn/test_backend/run.py b/examples/sgnn/test_backend/run.py new file mode 100755 index 000000000..f87c887b5 --- /dev/null +++ b/examples/sgnn/test_backend/run.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +import sys +import numpy as np +import jax.numpy as jnp +import jax.lax as lax +from jax import vmap, value_and_grad +import dmff +from dmff.sgnn.gnn import MolGNNForce +from dmff.utils import jit_condition +from dmff.sgnn.graph import MAX_VALENCE +from dmff.sgnn.graph import TopGraph, from_pdb +import pickle +import re +from collections import OrderedDict +from functools import partial + + +if __name__ == '__main__': + # params = load_params('benchmark/model1.pickle') + G = from_pdb('peg4.pdb') + model = MolGNNForce(G, nn=1) + model.load_params('model1.pickle') + E = model.get_energy(G.positions, G.box, model.params) + + with open('set_test_lowT.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + # pos = jnp.array(data['positions'][0:100]) + # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1)) + pos = jnp.array(data['positions'][0]) + box = jnp.eye(3) * 50 + + # energies = model.batch_forward(pos, box, model.params) + E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params) + F = -F + print('Energy:', E) + print('Force') + print(F) + + # test batch processing + pos = jnp.array(data['positions'][:20]) + box = jnp.tile(box, (20, 1, 1)) + E = model.batch_forward(pos, box, model.params) + print('Batched Energies:') + print(E) diff --git a/examples/sgnn/set_test.pickle b/examples/sgnn/test_backend/set_test.pickle similarity index 100% rename from examples/sgnn/set_test.pickle rename to examples/sgnn/test_backend/set_test.pickle diff --git a/examples/sgnn/set_test_lowT.pickle b/examples/sgnn/test_backend/set_test_lowT.pickle similarity index 100% rename from examples/sgnn/set_test_lowT.pickle rename to examples/sgnn/test_backend/set_test_lowT.pickle diff --git a/examples/sgnn/test.py b/examples/sgnn/test_backend/test.py similarity index 100% rename from examples/sgnn/test.py rename to examples/sgnn/test_backend/test.py diff --git a/examples/sgnn/test_data.xvg b/examples/sgnn/test_backend/test_data.xvg similarity index 100% rename from examples/sgnn/test_data.xvg rename to examples/sgnn/test_backend/test_data.xvg diff --git a/examples/sgnn/train.py b/examples/sgnn/test_backend/train.py similarity index 100% rename from examples/sgnn/train.py rename to examples/sgnn/test_backend/train.py diff --git a/examples/water_fullpol/monopole_nonpol/run.py b/examples/water_fullpol/monopole_nonpol/run.py index 3b1b4799f..617e96e76 100755 --- a/examples/water_fullpol/monopole_nonpol/run.py +++ b/examples/water_fullpol/monopole_nonpol/run.py @@ -13,18 +13,19 @@ H = Hamiltonian('forcefield.xml') pdb = app.PDBFile("pair.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce'] + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) + pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params) diff --git a/examples/water_fullpol/monopole_polarizable/run.py b/examples/water_fullpol/monopole_polarizable/run.py index 808ee5801..560cd8da0 100755 --- a/examples/water_fullpol/monopole_polarizable/run.py +++ b/examples/water_fullpol/monopole_polarizable/run.py @@ -13,18 +13,19 @@ H = Hamiltonian('forcefield.xml') pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce'] + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) + pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params) diff --git a/examples/water_fullpol/quadrupole_nonpol/run.py b/examples/water_fullpol/quadrupole_nonpol/run.py index b408792aa..0b6fe6394 100755 --- a/examples/water_fullpol/quadrupole_nonpol/run.py +++ b/examples/water_fullpol/quadrupole_nonpol/run.py @@ -14,20 +14,20 @@ H = Hamiltonian('forcefield.xml') app.Topology.loadBondDefinitions("residues.xml") pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom) + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) pot_disp = pots.dmff_potentials['ADMPDispForce'] pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta["cov_map"]) nbl.allocate(positions) diff --git a/examples/water_fullpol/run.py b/examples/water_fullpol/run.py index dccf92921..f024a8d66 100755 --- a/examples/water_fullpol/run.py +++ b/examples/water_fullpol/run.py @@ -13,18 +13,18 @@ H = Hamiltonian('forcefield.xml') app.Topology.loadBondDefinitions("residues.xml") pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters disp_generator, pme_generator = H.getGenerators() - pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4) + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4) # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) diff --git a/tests/data/admp_mono.xml b/tests/data/admp_mono.xml new file mode 100644 index 000000000..3970ff522 --- /dev/null +++ b/tests/data/admp_mono.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/admp_nonpol.xml b/tests/data/admp_nonpol.xml new file mode 100644 index 000000000..7cc1b4653 --- /dev/null +++ b/tests/data/admp_nonpol.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/peg4.pdb b/tests/data/peg4.pdb new file mode 100644 index 000000000..eee06d7da --- /dev/null +++ b/tests/data/peg4.pdb @@ -0,0 +1,64 @@ +HEADER +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER X 1 47.381 10.286 49.808 0.00 1.00 SYST +ATOM 2 H01 TER X 1 47.251 11.255 50.307 0.00 1.00 SYST +ATOM 3 H02 TER X 1 46.907 9.487 50.425 0.00 1.00 SYST +ATOM 4 O03 TER X 1 48.814 10.202 49.785 0.00 1.00 SYST +ATOM 5 C04 TER X 1 49.336 9.203 50.665 0.00 1.00 SYST +ATOM 6 H05 TER X 1 50.344 9.329 51.054 0.00 1.00 SYST +ATOM 7 H06 TER X 1 48.796 9.176 51.611 0.00 1.00 SYST +ATOM 8 H07 TER X 1 49.296 8.320 50.177 0.00 1.00 SYST +ATOM 9 C00 INT X 2 46.552 8.760 46.601 0.00 1.00 SYST +ATOM 10 H01 INT X 2 46.737 9.609 45.939 0.00 1.00 SYST +ATOM 11 H02 INT X 2 45.532 8.628 46.649 0.00 1.00 SYST +ATOM 12 O03 INT X 2 47.247 8.976 47.799 0.00 1.00 SYST +ATOM 13 C04 INT X 2 46.919 10.250 48.371 0.00 1.00 SYST +ATOM 14 H05 INT X 2 47.190 11.176 47.880 0.00 1.00 SYST +ATOM 15 H06 INT X 2 45.801 10.369 48.307 0.00 1.00 SYST +ATOM 16 C00 INT X 3 46.760 5.982 44.153 0.00 1.00 SYST +ATOM 17 H01 INT X 3 47.759 6.173 43.770 0.00 1.00 SYST +ATOM 18 H02 INT X 3 46.121 5.894 43.168 0.00 1.00 SYST +ATOM 19 O03 INT X 3 46.268 7.098 44.918 0.00 1.00 SYST +ATOM 20 C04 INT X 3 47.139 7.493 45.949 0.00 1.00 SYST +ATOM 21 H05 INT X 3 47.292 6.726 46.769 0.00 1.00 SYST +ATOM 22 H06 INT X 3 48.124 7.662 45.625 0.00 1.00 SYST +ATOM 23 C00 TER X 4 46.610 4.692 44.880 0.00 1.00 SYST +ATOM 24 H01 TER X 4 45.686 4.613 45.520 0.00 1.00 SYST +ATOM 25 H02 TER X 4 47.444 4.603 45.516 0.00 1.00 SYST +ATOM 26 O03 TER X 4 46.501 3.674 43.869 0.00 1.00 SYST +ATOM 27 C04 TER X 4 45.802 2.493 44.226 0.00 1.00 SYST +ATOM 28 H05 TER X 4 45.959 1.651 43.497 0.00 1.00 SYST +ATOM 29 H06 TER X 4 46.125 2.280 45.251 0.00 1.00 SYST +ATOM 30 H07 TER X 4 44.695 2.638 44.209 0.00 1.00 SYST +CONECT 1 2 3 4 13 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 10 11 12 20 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 1 12 14 15 +CONECT 14 13 +CONECT 15 13 +CONECT 16 17 18 19 23 +CONECT 17 16 +CONECT 18 16 +CONECT 19 16 20 +CONECT 20 9 19 21 22 +CONECT 21 20 +CONECT 22 20 +CONECT 23 16 24 25 26 +CONECT 24 23 +CONECT 25 23 +CONECT 26 23 27 +CONECT 27 26 28 29 30 +CONECT 28 27 +CONECT 29 27 +CONECT 30 27 +END diff --git a/tests/data/peg_sgnn.xml b/tests/data/peg_sgnn.xml new file mode 100644 index 000000000..206326d1e --- /dev/null +++ b/tests/data/peg_sgnn.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/sgnn_model.pickle b/tests/data/sgnn_model.pickle new file mode 100644 index 0000000000000000000000000000000000000000..0c3959cd9d0ef4fac155676861aa6743acb96c99 GIT binary patch literal 17100 zcmYhjc|28J^gmAKAyY_*1~L^=63$*Xr4mVLpj0v>4d|LmrOb2Y$XF7RREWaa>!!>p zB}p1I7fmWEY4~}b&*%Ame)o_2y03HhzH6O(&RXyNUhBP2h=7}$&z?Qo-TZg@c>9Ul z`MPiS-R^F=)6HL;%co<{<=1xP=i}qs$DQEj9pJS$NZ-xJce}n!ym~N``>}?{y}@U zi*v;tCyZP1n9r54;j`h7=1SUgCu-XW{A-)xO08P8%KGp4>)$sUK7X#XYwQw11rJ#c zSH_*^@^&ulkCeM9|y+js8p^ykWX3V8l&dy?n4VR?J5!as(!ZRSUD6+I`$E)JQ- z$G1T+fcBgerYQfCJnLyCU5n1L30H2haA*$%7`q}YKMGyJIdm3Jif)~eM88DUuz!uc zNNAHD+&Oj~KmEz4?{+8A_PLsLs9c+veQ%*5X%DISqnE5w(ZYq^-y~7^*j0K*{~jG? zo0*RZ6-;YuG?Mo3Bsg^!h6>seBNG)o+1p0szEs1!-xHuICy&g3U4{qL>*3;vItT?J z++%M;@YfPJ%(@T_&OuNe=EM21Ps#JaKr1#@`jJY_vKV!_Lxq4MtdG?`%+ za8xiyTN-Hc7!zA9hPRsjl2rp@OquL=HvHWcni(O?s%SVeJ!8|Tdy+V;@zlZZItF#` z%);yGiaS(6Hjg^4&&Fz%Od?jim7Q>^1;4KifrysX5WY~0RQ(pit&#W0lmgrOgJDv< zx&`%6#crm}t;cD_+kE9m zyJaiMx=bnTIHUm~7pMV&TLd7zWFD1n4<*z7PJ_7zwouFJ1hY32#o>|pN%G5Q5k| z)2AF`t9s<9UuG0de00TJbnYBDGh3Z7k-hZs@j1l9PY80l=aC@6e8RtBElO;gi%lnz zuxr&8(x@wmTG!W-w{ca3(Vj^A&W;iO{7hW^CyG8@zlRtsmu6ZXpJ19w3Sq~myQq2O z8Qu-6!4wH&tedwSS#?EJ4^#sqrVF>X#BesXDG>gBmUul!hey?BQ(kU8=X2jl8l5!2 zR+g_uA%_esv$n-6-LpV@K$;LmX`I(oMLexz@J50j{T`G_RWpPjZPPN=PBH{}mUd`x zTL9E;&B&l!CE6qz(y4t>#5H~u8!A{#{(fiC>X{DjOieSLANP#5&5Q-V`IjjFfz9yf z$piAlV?Fv@ngg>BJ~CfE$qM|9U((AWPpQZUF^;c(F3c&Nj6EfH$?WTva6mB;YG)nc zOmqo=zn@m4b^3W;@Qx?&!E7}wkbVcf`U^QK)@oQSeTL|zFM`)rQ*j>-lY|K@b%AU$ z@2LdL4eTNZ?uo#$KkMqxDTdeYds9rVE$C){Mx0`HSxCZijVX}*_(}b#vHOJgS()-b zyGe~729a$u)lltlmARK%2T^cd1&7u@VTE?sQnf7;cn0}47zdj8;8isp-W>*;cLk%1 zya%<2n^;${Bn01D%>kur5qL6rE$(Z%3x$b^)ZZ!s?^>sz+?U@hi{nr_HpG*=u=RVOp zdJsa6>}RH5+=>P(LuvnX5q!~TOTWj6a~jHO;K$-A7%a4xhanx<_1POlqFYfRWF}8P zZ34&j`+KswYX$5WJPaN$1Yl)w5BYRIf|;qc0CIgJ$m#cH^zKIhdu=b8-RI8SJMoC* z*iI$F&zI8TE8;Nq_ZDd0B1x|FYMCn)CeTygmy_yJKDZ*R!g@T)lDN#=yPgj+svk*}fH3U2z>uE9@#Kc)1{B|2i`(J{>8(fah*rWf zRKBeQ=24$;rCI@WytE?S7r&DC1N+hT9FQf&44q}zLhHKXNxvz}rXI{9DK)!j6tO~q zM~~>-kKa*njE{aToQQ8q_^HA-P4aXhAFr{$j+QL-Ak7-BWYDr0SAUsI_Pn`8j8<3! z&p8A?2lX%_Q7xbvkVK4<=itfH3-E&4T-+5Q#0k6l2RD^jgT479W|0;@_!u@(6?;{h z))x#B!7`la+FuaI`N&Q$D8-wN(>P(9*WswLOZ}T2a&%}}EciU%4oki}l8M}9c>Vh- z5*D}shnx4{!KsrWf8R9TnS|G@*h)>BN~CCT!b)6YTFc&=mO%Apm{75vt)za|HD>v! zC}?n0!D-bVb~I85RWE0O^!aL9q4=J(ttM%T$FI9 zHm8qLts7p{MA8!rO+)G5w*%%sQv5K;;0)fGbc>l_kOG%qhfxw10}8&GwDG|ic1Qb7 zeDi%f7OXYq%ovHp)%ANxcA_fE!a5k7kxaHu^r8(7;;i`8-B?^_$~e*p;%wK;4i@gg z`DYiBA$H1QP6~oRJ{SN#DxZSv zXcF$4vJ-AwhB1v|h4t#^X3$ubNZNJf0<5AVDAoCk=pQm68~!K)_FRF|*YddBRRSNk zU5BHxe@RM}HZ9v0Mt6CBqlOwNsGFUc;z3$gY3epTdv6ky_NhJ)` zPgTMpy8=A=YZgZ%kVnrR%p>BwLVDO`5e77bfL(_=3H@r0i>vy`^MKzp^mQGT%U7dM zmidrhq87Atwh+YECe@em?!v3nlg$IypP-YRi@|WoZ16nGkLHU?i00}MD2$edqwXp= zN#!B2*)C1i8koX|3VRf7n@DE3d}en~(4}TpgCsbmoVNEHf%3H&-hMk95dOH1$bOzc zWA+Q-hU`2#!`l}`H$I{^D}G|_r@y#&!w)K%K0qINzopr|S*&bg7;|qCKO_}duyUd+ z=}^y2GV`4pY@CJc%JUOBULRk4M`W9pZTD5}PtCNA7TPi9}#3USF@k`7l~da#;of7u_UUsg2a?L=*k3 z^oTC78HVVo&5Y@rdsNPGPrYjSUe<7#4edrQhNc`Nn|IHK$Key<+6+yi-=F}(g{xq7 z*9J2BDS~Xzox+;6y}$r+9jz5>QE`$2uhvfwcnOYFv&|I)Ya%gy>Q0&)k`&&tCME!u6xaNaryRg6er7R3VNL!~N)4WX!W@sD;d^ z+tkDB9QX=4)<-v8$8Qduu$+ou<+J_ds#gIvPHScNeba>OUDEiva5`I|xgX_i^FaTI zFj{#>BgRVBe|;|iTBjsI>7fICK4mK~-M``c_ul%ki>ACfbSyL=_Nv|9()z3d?`tw!j>-|xwYc|AMFZXbJK+ZVF%w-&u5yM~4t zj%NVhEQZ&fKq}hG;e(OI}YD{$|CnGX5j6lzck|cVqCJ|9C@_65)$o3F^6vwt@2q;_C*@OnSrbD z+g$^0O#XpUg%Z45FAjj6lss?BuglC^&1iGMpU*j4f<@4ILN{~XK>{4EiD6`I8tiRY ziEm35(w8^H$?LPC)I+~@JRTEBe{wkqoxGFgTplE`wh7E%whs%;%s54E{Io>qDzJK| z>DYx8G;}N&cKW5z_wp}DO0_J_kL2LF4TiMW(TI2{*n_Qe03A-H*c6b6MmKEWo~j9H zaC$-iTo0mq7cgYUReid#(V9tZ$|UgzRiIatpNu35!-9R!!A#>MbEHolhYJnqK2#-@ zxBif$FEn`ZhtxU!;U1iq%BPs3?^^7g?k#lh=UZf|*)mkMe!$@61ZtP}kOT^*kU1}! zXvW$<#38kr=HK1O>g$vc!RHOsb+#UjRnKP<*9e33N)LwL$%$cW6d2!<2TXd4De26; zf}RG+bz;XKlW4;%qF38TgT1Qhr)B5a(16X1Vs;I&35=!u6P3t4F@ADbLXs3zXp*Zp z+F1q7WE9@Ufy$VDXq+g>ep_5kYnGJ2;O4(%)B6j&-R|m~m#tH=(fk^W-O=NnkC)<1 z%~9Yi1AltLY6CnzWX7?Zo(lP8UrGL?$!N}<1h&7-q4Uuwxv{x`b!_^>Y#F^qS28bI zWy!~eq1C{VL9~K zEnVhAXCC>;|C+pABtTa^6+zh_iSWbxCHt%9CoSgHk(NG1dN{0xr~l^^6nqt z46<-KG=qdxCc{12P9la*g6xxgG}5!CmKIUuqk;fVzbMEVK@Z-*Cm9GfQRZx#Yen@^ zyUFb7*NC@#D|Cpxqvxjk(`+MkI*)FJX08cd&bSQPWfLGIcsaaM*no1I`CxuhE*+>j z#6A%Tqz8R}na5-%(9r3RneM%(q5OUUQNJ=mS9n&!iaF^-&?}FLeq8_qazpfWa}qo; zJcb!=??{i@9pY~6g)T0uaGk+pV*EXZGC=KXEdAXJpz4R;3<2UD_Ob0_Uqi-)V> zUGOgH6F%42MT~n^AZNcX)%)H?ecj(t?T-Vf=Pzw0{$?-8KKu@8HEwWPScfzDls->x zavWX9-#~8t%piZ{1!<>vviZ|1?wDLBLN1Q#!3Fh?jQKokP#Os$zmEySj0XlVA~i~l zgC7$e;a0fsK9xy3UQc%Tl@oo>bVklr2n3?}sB_sLa>*ecK5fwC%b6H8~C^^H89FHU~zOIdKIJ;C1E##1uEe@pm2gJ3favq|ApK zW7|MIWDajykvG07noPba$?*JJLb2k`8zL(%MNXm*Jr?ST#b;EB(3k~O&ydF9+H>r- z)h1Lv>N&IeKtA2VS4+g#6_fiNc?_92lbn{?0M!s$(XHn_NM=?d9m?Pb&nrCh zg0;isbGIS=D?Ae&n;Xfo-W9NUX$iSmxS!5W5ykZndud|MWL!5Zf`-5eh~X~4z~OVy z$z%Y3^ffj^$_*D%IdaLOgi$r2$6fwqR{K{AYI{doE_ONN=+XY5rx1e^2&J;jYtb) z7bypmI|jAHcFBBZvfW;cQyienO6Sw2+)(@yzl9cXB=F6_B$Sh>W8@wvW5f?t*tm8v z?HX((l~Je3tOKRgN=k}`Ri%&~lV@;l^BpK^i-C1rxgg;f34Pydq5s@&T%3#-9!~R@4M6)bbfkW&}s;a3=UgXx1lDu}>IinQ! zpQwN#!60fSW`oye9wO1}x(V$sf-f^pQ+bJryv5#6p}2b+QBt4J{xZxUS7lR4ThUSW zOmjQY-g6cXNxmnEqRVKPeK80J?1tGl*Fdt-d|tpDK5VgUM(d9C`0L>%`Y}odeZO4C zgSq!fvaCPkJ;(u#Gy~3pqaB#F^axK{*$MXZf1|RSuaZgbF~rHsg#Pv>WW|PSJ}%l-LpfKr%eFk4Zlu#mmAH zdRgx*7?+e2$CKH_Wd1i4%U;49-5Sbx$7Vv~WLc2=I3HZVE$o z{pk0E##s!qM;`1Y{abZ$_hmu6+ZRUGHqXNC*9?fFi3Dbi4zgvv&S-V_JZ^EwqIHji zIGHwDP;r3aBpw_Fk$4df^JXVZjs1?E8=qpv<|Q1Pq3>kNnpn>3X>z>lMv1&Y<;kC9!P2_h-G0i?F1N$7L;cMSUxGKm2(W2QfqBX#r)fgd0RtrI6 zeJU+A?|qV!K~o{{q$IFgsfs%eRlmnWyfz>j{c-672LE|v$Q=?^gDPA}uG zA!>H=S1gH{lZrh}S#BwM^?UD`)W`C~ zd){>L6rGEH$Is%l>RlN2X%@6?ih);SZXiy-?~#64H6hB{5+*LX2TUKtC!42!VvudR+Kc z1SNyQ@T0a0>|6N?Hmtgg_iqod;XbKk_WcN0{;C>V@~7~0f)XIraUUl%CmK5a%5cq; zbQqo%%~tFQLC*wpl4KT6e+_3rtm1erWIZ3a#sy)Kfe{$K^P&cW2@qPNip6pQ*dRWU z7?y@$^rFRNefM#){p1aNFZ95Ct=~W_f{xkiyP6nmk6vO(Z z`pl88rRILVbHV<5IaN_NWw#ueNor60zy!zBRN&M_rpDzkR%8G&m&36!;WiBFtz#y% zB;%iLL9lMK9u)a6hpczvr2nG~Z*#c`&urr{n&M~8I~UXg=U)nO3Z58pL@#PnVXs=# zoq0!@nT6Hl?ze4FxM~J$IdhgozRMuT7B7WT!$6WR76$D~v#IL0qqH+f1WG?2CbPGR z!OX~$w8MZ4vf*1<$bCuLjds*WUdtg-eN$l3R2Xje&*UtP3&&Gy_kqWfJ{Z!R zOfKG$!<|Q3;Njj?@HpfLUGuCEKUcaFAGhhO-GOuD`1n1(R<8qhU-?PT&-+Zyk5!T9 z*UOmCXU{Vgf4>OXw&0ab_{K-BAuB zS^4N~9tTc$yvfKZey|nqqhF1G)xUDuk5?iTNqLepgnTKby!G>mwMiEd^(#RX?LybM zh0tND3AjIyv`L-DhL0>%m$zf9lpuDD4AkA3p$p2~cw(ZI#mqes0>Mw**x9rE>9vG7 z>TN8Ce?&ql-@|J9cK99HrX4_qLM9>qrbNg{tsvf3g~Yc$oWj||L~`H$c zsm2-0Qt1Todg?U9(5LeW6qYHH#!t zO7MJb4|!o{ie~i;=&Cs42PG+Lm7Gb#2c@ZwiZNaPJrW9b`LN9`TPPknPkYAyOfCt+{NpB|)R%_Z} zONI!poiI!%jYUJ$;RvWYnNEq_OrBzBBOJMZ7v-f4c;AwuA>PQ8qaY$j$2Wco%Ty*qi)$NPyEy@# z$!)GJ)6=9AJNiLLK8eNqXUtU&zMz|)>cgSE3cP(?spOQ&QM6n?LLX_2(YYn@7__ty z&RFHMnm??d`97DYr!4{jj3G~RTRew1I)$gzBLLs&R(eD~$z1LI43r*>B_FHK5eFF~ zYI&}lMksaBq}mo5)8vFc9}URomEyR*MiO7>vSjetdv-5GlgPt2>x;TB(k_j#`nG|I zczU-2-aos6)$y+<(?U`8$T5OEAa`#~M2NRsUzXg^T?ZXq5bs#6L2A^kN0tcE6 zcde6QZJQ_hZaNDGj15TdHCayHgG-S1jSK7gqCqf7lQSwmiC*?cvRG*%uP39Md{8yQ z@T%9auA?0Ngr}q5uE~s`N*^)FYQw7e7&2vAH zzweMSdp3p8=MqGBcK`$WN|?AZ1y0(Ua6D7QI0ea1VS>Q}xIOhQs5J5*CukmWEI+}@ z;>o-@;sWF%dxHd?j|Lr)3Y3t{fjP?c=*vxmv5Y+!Zm^W+^u`9#|6HY~>rH?=*`KTA z>L~vIvdQ>plB9Z#8)|c#(82d`|dhH8YJ(a&XEj z8KUKnG4ju+(@m?Y$*I4>F!(!={yn*#+%Yew^ZW(bPH#^Tc{0o-z1zYXnx>N4tUv7L zRnzg+)t#`!J%Eg5En}5ac9G*|zsVF%7L`BP%5c^Ex$6HX$e8<|AY%laanIBf)!NgP%<>o9FLVXQ{SDY7~&i0t&(fMy;(u=-Lo8tt&g zPgA4_+hM?4w)Ys8o+(1U_2UgbyJ)N&N&$~%Lx@}vkISFkBwe+!;PrYZbdMKpy~ONr zW9ndv%Wv7hlZaPLD5!K*fRG9Ot&f{jen%^kq3jhA3e#R zo#n9O&pXEcNj&oJiz0Kj_d=DC7tGQfgk)|sELr&pgAd4%CT%%((m^jU&pi%XeB4n} zJ(NiFSHhyDiqutHmCi^ng?T#mXr4)r zrgHF%_#dMEGo3DNs3nUZ7D2`2*QDsF4(@&{#o0}raWu{gddjk~X!HOi8jKLfVpUL7 z$%Gq!Ht?x} z_bvx;Funs9ngv02%_W+gv6DP1oXT67R1Wc}wz%rUEnw3|@Z9Zqc zhe6YozlgM@1uTqKf{({Uu)r%B+ru|O`0+}b@xBmtbu5PBPwmu5GM--C@R062yqO-! ztz)w1-J)6h1EA*ePB{1|fh1;e;DN+U+_~lw_LkYZ3yGc!Jj5wa+EAeuq9K9NO3Lh7w(@7VEN$i31 zBqn4pv2^mM^ByOIU|k^Ne0M!e7)+*>K56FO<3)o}{#zvS$xrsnvnx1c&XnFr;d}9QR-c-N?u1>{!tQj^+=U{U0Z+7RKBD__hOqC`NgVw?wL_a;2d~!QT zY!Z3&<&_G&5U7vst%M%+41vhfBF0WL4t9PPL9dvZyld*$sKM(sD6TRG?W3BozkHOu z*!T;s!)c3_AJ!-y!`6o*^ooWNU152F?g-yW<3tv-A8$IE zTfAGrTkBG8el*np2XZH%-P>poTXvEjeZGTgPRk|_6i?zFUKu&NBAf|-qe)I5eowxq zj1$Y^7-}LR3IZD^gXN4*L}NuMI{%u3lMbK2{YQR5NQV%HNL?TorWugyX@E3f z0kCxpHC$wPy`ArvmtJlFC-h;{Ze>ooo-k#< zbW#=B(B^};JgV5=Ru|0$f+q1yqyyn~Tpx_MkI`G4xtLih!CByY9{-Rn^}5wru%tN< z`m_pgqUC+^@$_b3bo-$^uA1)Zw?&r;<>XGU5Lqes5c?a_kV%z<#qzSiai~VKp7qrJ z_e!wlvq1MiY3Ojw2J0OQsf7tYiBy!w&*J%{W*YKiTD17+*mi*1jWoKb7t5T5<;T6Zp3*W~m^z=U4 zz3C=dRToRn>YOFZ@)Teo&xY{!h0%)3LfHB#4I^c3@ugn@>V0V-5iK!n$z}`qV0{_9 zC7MAbbpXtg+`*$}CGmZ2!8re%Ln6B`)$^tgQ31&gFxjj_g61$VaC!yIUzJN@#RurK zpZl2rqaWsu@1y9Ou-ft5qX_%0{*ZU-vb-a;C)jUc#yCrHK3?k&MgOWROvTlOxG{1L zd6~4FW0V~Mc{67*DO*_*qiD>QvZ?f5^C@B-(nki{%yD!_G}#@RhezyZkqAoX9foO=#xt!YpHJNLy|ur!U_@9E?FDN< zTxuhga2P;GP9@2qGl6qPA7A$Ow;c)A$jV)fv0m}1ROs&;o?9wTt0OdbY{3<<~0HKK(!9OUe!gS?|r!Ma*L-li>Ek&(Ox{B3bmv%CjdPAq~Xix=cz@7DSz z^)9%%tqz)vR>3yOs~{_E0cj^HsX_8WEMOng&akKTmx~KgG)x{2=ijCuH@|=pmvm;O zTsrdU&cbDTPQ(8A035C~!SzcQfzzRvbcx|^nr^L%5gq_l3xa_SIYzp^WZ|EiC#jda z0hoh2E;l{|J2uXN`>$8QzTr5u^azL7G0k|+@CQWkO*U;W*$;<*XyOrtCYqWt4IGNq zaCC&9#QfxgLo>v2Rb&#R?23T-SLT7^7c-PzvK%|U9DwJ`^I-J)E#k~K1w`LOqsx!a z5c1U&azZ4qckfP|sPTflUb`7g+qy}T7>_A7>roGG zGyL`&L9XUIrO$s8v!}xKgBNe(g6VR2Iqe=6z#3%gy>O3Erp9*{AFeI@VtJ zbMv?`W`ze?7}ZK&eC5%in}JNO-$W#9gg99qi^;>+wM=KfI7!M%g2SZ|?9}ypz?T!Urkd(*7vlYnI^ykj{KSDNF|6f9;^&fN5{)f4YYvOhPWiF9iy?+SV>Lqp)AVx}! z<}~Eu%XW9b=F=qS@Ju-9_!Kk_G?3rYjv$ghk6*A<6h)1{B(m5v@w zxbmDSmdmF(Lmq6w3rpnA_UF#|KLTa+AA!>U&kcir+%UA~&ix+(B_G8#`iDR*66WK( z*Pw<2jn9eO$;+s8_BhG7tO(Ocs&yn)eh9KZ(ik{;^s_QhHi1dGf z_f;LN!B;H1;M`aLO!?%)ana8f_>n!C=lM_s3=+4J%?Zy~ zNBOJ&xagR=ohmL}M&-qx(ks5pvHSRF{jAC) zv=B%mhSodLZpl3Kah?YYgY-b7AR0AhM56P^PMRPd4zDl|zSKXii~aK$zaKWifJzB? z_v$M&rb}_&4Oaj@AAxm~_;@?^yK!zj-iW!ey{LEh#<;pvfwS}eRd9aG!%aHTu={N; znhpNKw^>Iqz+H+asknj7yL)7Ga2>vmkibaab`p)pm<G~10v)_YOm&aoG zWI1vrR%|?o)M@Uqm_2}x|trH zs0)47Rv2?O4t`1%f@#okVlZUP**i}f0$vXhzr{{W-*_U@E^9G0IC_%~6v}hbT4ONY zV;D0ZCDZ=+Q$(aeh<8Unj5dGQN4a@XbiTkRGLn#jN8DY`|Yp|i)8T}!W%zoXHpYF~@BWfyVE zXBo5_(t=Z;k04!8ikA;*vQAUO%?(H`z63ce7W+v=^OHeZ{1V1p;1;*3X{PN!Ohff(%VH@bE67$K5Vc+h<~Ecgh*RgoeS}8)u-|@-JB(^92?; z^FgC-D6>bm8Pc4aNL(Es=V7`g%}g4FL(|Wqtn@tMuze-OMQYOfAKGcwTLaAf*~nO^ z@IxYbUgxOH$2pa8m9?~vW0f>}D1Fk!Za*N*am{k3PepdX=+0g;Sgp@_5cP^ieih(# zt9}5TjxegKy%^^lS`0sTz9X)dA4$pHGiVg7g_%bdW0nZRnJ835UY=ddyBhHm)6cBI z)yt*mrNj~tmMw(!7L9D?#0b1PuI9h~O`3P}X*znZy+bx^Z^xUj-jUj;-stenmHMB` zp*8uLASsrIVjm(d< zMUa<}L<8KH!t%*w5O5KIT?3%OSE`kf3U zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G) zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3 zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK! zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+? z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1 zPxfwXV@}rHMQAtco_`tOlK+s$ z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q# zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2 z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O literal 0 HcmV?d00001 diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index 02d81ead4..be4b9d99d 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -24,13 +24,17 @@ def test_init(self): """ rc = 4.0 H = Hamiltonian('tests/data/admp.xml') + H1 = Hamiltonian('tests/data/admp_mono.xml') + H2 = Hamiltonian('tests/data/admp_nonpol.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) - yield potential, H.paramset + yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset def test_ADMPPmeForce(self, pot_prm): - potential, paramset = pot_prm + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -51,7 +55,7 @@ def test_ADMPPmeForce(self, pot_prm): def test_ADMPPmeForce_jit(self, pot_prm): - potential, paramset = pot_prm + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -67,5 +71,47 @@ def test_ADMPPmeForce_jit(self, pot_prm): pot = potential.getPotentialFunc(names=["ADMPPmeForce"]) j_pot_pme = jit(value_and_grad(pot)) energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters) - print(energy) + print('hahahah', energy) np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1) + + + def test_ADMPPmeForce_mono(self, pot_prm): + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + + covalent_map = potential1.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential1.getPotentialFunc(names=["ADMPPmeForce"]) + energy = pot(positions, box, pairs, paramset1) + print(energy) + np.testing.assert_almost_equal(energy, -66.55921382, decimal=2) + + + def test_ADMPPmeForce_nonpol(self, pot_prm): + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + + covalent_map = potential2.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential2.getPotentialFunc(names=["ADMPPmeForce"]) + energy = pot(positions, box, pairs, paramset2) + print(energy) + np.testing.assert_almost_equal(energy, -31.69025446, decimal=2) diff --git a/tests/test_sgnn/test_energy.py b/tests/test_sgnn/test_energy.py new file mode 100644 index 000000000..a771f4508 --- /dev/null +++ b/tests/test_sgnn/test_energy.py @@ -0,0 +1,51 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestADMPAPI: + + """ Test sGNN related generators + """ + + @pytest.fixture(scope='class', name='pot_prm') + def test_init(self): + """load generators from XML file + + Yields: + Tuple: ( + ADMPDispForce, + ADMPPmeForce, # polarized + ) + """ + rc = 4.0 + H = Hamiltonian('tests/data/peg_sgnn.xml') + pdb = app.PDBFile('tests/data/peg4.pdb') + potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + + yield potential, H.paramset + + def test_sGNN_energy(self, pot_prm): + potential, paramset = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/peg4.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + covalent_map = potential.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential.getPotentialFunc(names=["SGNNForce"]) + energy = pot(positions, box, pairs, paramset) + print(energy) + np.testing.assert_almost_equal(energy, -21.81780787, decimal=2) + + From 5acacff429ad104f4edf4ceada2ebd4bef7b0906 Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Sun, 22 Oct 2023 16:41:39 +0800 Subject: [PATCH 3/3] Modified QEQ potential and add JIT support --- dmff/admp/multipole.py | 2 +- dmff/admp/pairwise.py | 4 +- dmff/admp/parser.py | 2 +- dmff/admp/pme.py | 16 +- dmff/admp/qeq.py | 430 ++++++++++++++++++++------------ dmff/generators/QeqGenerator.py | 135 ---------- dmff/generators/__init__.py | 1 + dmff/generators/classical.py | 7 +- dmff/generators/qeq.py | 214 ++++++++++++++++ tests/data/qeq.xml | 44 +--- 10 files changed, 502 insertions(+), 353 deletions(-) delete mode 100644 dmff/generators/QeqGenerator.py create mode 100644 dmff/generators/qeq.py diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py index e863f9e72..1fb8e197f 100644 --- a/dmff/admp/multipole.py +++ b/dmff/admp/multipole.py @@ -1,7 +1,7 @@ from functools import partial import jax.numpy as jnp -from dmff.utils import jit_condition +from ..utils import jit_condition from jax import vmap # This module deals with the transformations and rotations of multipoles diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index e1510ffb3..8fbc1b856 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -1,8 +1,8 @@ from functools import partial import jax.numpy as jnp -from dmff.admp.spatial import v_pbc_shift -from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from .spatial import v_pbc_shift +from ..utils import jit_condition, pair_buffer_scales, regularize_pairs from jax import vmap DIELECTRIC = 1389.35455846 diff --git a/dmff/admp/parser.py b/dmff/admp/parser.py index 44e83a0b3..5a9efc857 100644 --- a/dmff/admp/parser.py +++ b/dmff/admp/parser.py @@ -4,7 +4,7 @@ import warnings from collections import defaultdict import jax.numpy as jnp -from dmff.admp.multipole import convert_cart2harm +from .multipole import convert_cart2harm def read_atom_line(line_full): """ diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 98947d0b2..17cdcd617 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -7,24 +7,24 @@ from jax import grad, value_and_grad, vmap, jit from jax.scipy.special import erf, erfc -from dmff.settings import DO_JIT -from dmff.common.constants import DIELECTRIC -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales -from dmff.admp.settings import POL_CONV, MAX_N_POL -from dmff.admp.recip import generate_pme_recip, Ck_1 -from dmff.admp.multipole import ( +from ..settings import DO_JIT +from ..common.constants import DIELECTRIC +from ..utils import jit_condition, regularize_pairs, pair_buffer_scales +from .settings import POL_CONV, MAX_N_POL +from .recip import generate_pme_recip, Ck_1 +from .multipole import ( C1_c2h, convert_cart2harm, rot_ind_global2local, rot_global2local, rot_local2global ) -from dmff.admp.spatial import ( +from .spatial import ( v_pbc_shift, generate_construct_local_frames, build_quasi_internal ) -from dmff.admp.pairwise import ( +from .pairwise import ( distribute_scalar, distribute_v3, distribute_multipoles, diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index 0c3e0cf77..4de5c0be4 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -1,213 +1,311 @@ -#!/usr/bin/env python -import sys -import absl -import numpy as np +import numpy as np import jax.numpy as jnp -import openmm.app as app -import openmm.unit as unit -from dmff.settings import DO_JIT -from dmff.common.constants import DIELECTRIC -from dmff.common import nblist -from jax_md import space, partition -from jax import grad, value_and_grad, vmap, jit -from jaxopt import OptaxSolver -from itertools import combinations -import jaxopt +from ..common.constants import DIELECTRIC +from jax import grad, vmap +from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce +from typing import Tuple, List +from ..settings import PRECISION + +if PRECISION == "double": + CONST_0 = jnp.array(0, dtype=jnp.float64) + CONST_1 = jnp.array(1, dtype=jnp.float64) +else: + CONST_0 = jnp.array(0, dtype=jnp.float32) + CONST_1 = jnp.array(1, dtype=jnp.float32) + +try: + import jaxopt +except ImportError: + print("jaxopt not found, QEQ cannot be used.") import jax -import scipy -import pickle from jax.scipy.special import erf, erfc from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales -jax.config.update("jax_enable_x64", True) +@jit_condition() +def group_sum(val_list, indices): + max_idx = indices.max() + exceed = jnp.piecewise( + indices, + [indices < max_idx, indices >= max_idx], + [lambda x: CONST_1, lambda x: CONST_0], + ) + return jnp.sum(val_list[indices] * exceed) -class ADMPQeqForce: - def __init__(self, q, lagmt, damp_mod=3, neutral_flag=True, slab_flag=False, constQ=True, pbc_flag = True): - - self.damp_mod = damp_mod - self.neutral_flag = neutral_flag - self.slab_flag = slab_flag - self.constQ = constQ - self.pbc_flag = pbc_flag - self.q = q - self.lagmt = lagmt - return +group_sum_vmap = jax.vmap(group_sum, in_axes=(None, 0)) - def generate_get_energy(self): - # q = self.q - damp_mod = self.damp_mod - neutral_flag = self.neutral_flag - constQ = self.constQ - pbc_flag = self.pbc_flag - # lagmt = self.lagmt - - if eval(constQ) is True: - e_constraint = E_constQ - else: - e_constraint = E_constP - self.e_constraint = e_constraint - if eval(damp_mod) is False: - e_sr = E_sr0 - e_site = E_site - elif eval(damp_mod) == 2: - e_sr = E_sr2 - e_site = E_site2 - elif eval(damp_mod) == 3: - e_sr = E_sr3 - e_site = E_site3 - - # if pbc_flag is False: - # e_coul = E_CoulNocutoff - # else: - # e_coul = E_coul - def get_energy(positions, box, pairs, q, lagmt, eta, chi, J, const_list, const_vals,pme_generator): - - pos = positions - ds = ds_pairs(pos, box, pairs, pbc_flag) - buffer_scales = pair_buffer_scales(pairs) - kappa = pme_generator.coulforce.kappa - def E_full(q, lagmt, const_vals, chi, J, pos, box, pairs, eta, ds, buffer_scales): - e1 = e_constraint(q, lagmt, const_list, const_vals) - e2 = e_sr(pos*10, box*10 ,pairs , q , eta, ds*10, buffer_scales) - e3 = e_site( chi, J , q) - e4 = pme_generator.coulenergy(pos, box ,pairs, q, pme_generator.mscales_coul) - e5 = E_corr(pos*10, box*10, pairs, q, kappa/10, neutral_flag) - return e1 + e2 + e3 + e4 + e5 - @jit - def E_grads(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales): - n_const = len(const_vals) - q = b_value[:-n_const] - lagmt = b_value[-n_const:] - g1,g2 = grad(E_full,argnums=(0,1))(q, lagmt, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) - g = jnp.concatenate((g1,g2)) - return g - - def Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales): - rf=jaxopt.ScipyRootFinding(optimality_fun=E_grads,method='hybr',jit=False,tol=1e-10) - q0,state1 = rf.run(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) - return q0,state1 - - def get_chgs(): - n_const = len(self.lagmt) - b_value = jnp.concatenate((self.q,self.lagmt)) - q0,state1 = Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) - self.q = q0[:-n_const] - self.lagmt = q0[-n_const:] - return q0,state1 - - q0,state1 = get_chgs() - self.q0 = q0 - self.state1 = state1 - energy = E_full(self.q, self.lagmt, const_vals, chi, J, positions, box, pairs, eta, ds , buffer_scales) - self.e_grads = E_grads(q0, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales) - self.e_full = E_full - return energy +# @jit_condition +def padding_consts(const_list, max_idx): + max_length = max([len(i) for i in const_list]) + new_const_list = np.zeros((len(const_list), max_length)) + max_idx + for ncl, cl in enumerate(const_list): + for nitem, item in enumerate(cl): + new_const_list[ncl, nitem] = item + return jnp.array(new_const_list, dtype=int) - return get_energy - def update_env(self, attr, val): - ''' - Update the environment of the calculator - ''' - setattr(self, attr, val) - self.refresh_calculators() - - - def refresh_calculators(self): - ''' - refresh the energy and force calculators according to the current environment - ''' - # generate the force calculator - self.get_energy = self.generate_get_energy() - self.get_forces = value_and_grad(self.get_energy) - return - + +@jit_condition() def E_constQ(q, lagmt, const_list, const_vals): - constraint = (jnp.sum(q[const_list], axis=1) - const_vals) * lagmt - return np.sum(constraint) + constraint = (group_sum_vmap(q, const_list) - const_vals) * lagmt + return jnp.sum(constraint) + + +@jit_condition() def E_constP(q, lagmt, const_list, const_vals): - constraint = jnp.sum(q[const_list], axis=1) * const_vals - return np.sum(constraint) - -def E_sr(pos, box, pairs, q, eta, ds, buffer_scales ): - return 0 -def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales ): - etasqrt = jnp.sqrt( 2 * ( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2)) - pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC - pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC - e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales + constraint = group_sum_vmap(q, const_list) * const_vals + return jnp.sum(constraint) + + +@vmap +@jit_condition() +def mask_to_zero(v, mask): + return jnp.piecewise( + v, [mask < 1e-5, mask >= 1e-5], [lambda x: CONST_0, lambda x: v] + ) + + +@jit_condition() +def E_sr(pos, box, pairs, q, eta, ds, buffer_scales): + return 0.0 + + +@jit_condition() +def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales): + etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2)) + pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC + pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC + e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds * buffer_scales + e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales) e_sr_self = pre_self * q * q e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self) return e_sr -def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales ): - etasqrt = jnp.sqrt( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2 ) - pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC - pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC - e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales + + +@jit_condition() +def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales): + etasqrt = jnp.sqrt(eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2) + epiece = eta_piecewise(etasqrt, ds) + pre_pair = -epiece * DIELECTRIC + pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC + e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds + e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales) e_sr_self = pre_self * q * q e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self) return e_sr -def E_site(chi, J , q ): - return 0 -def E_site2(chi, J , q ): - ene = (chi * q + 0.5 * J * q **2 ) * 96.4869 - return np.sum(ene) -def E_site3(chi, J , q ): - ene = chi * q *4.184 + J * q **2 *DIELECTRIC * 2 * jnp.pi - return np.sum(ene) - -def E_corr(pos, box, pairs, q, kappa, neutral_flag = True): - # def E_corr(): + +@jit_condition() +def E_site(chi, J, q): + return 0.0 + + +@jit_condition() +def E_site2(chi, J, q): + ene = (chi * q + 0.5 * J * q**2) * 96.4869 + return jnp.sum(ene) + + +@jit_condition() +def E_site3(chi, J, q): + ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi + return jnp.sum(ene) + + +@jit_condition(static_argnums=[5]) +def E_corr(pos, box, pairs, q, kappa, neutral_flag=True): + # def E_corr(): V = jnp.linalg.det(box) pre_corr = 2 * jnp.pi / V * DIELECTRIC - Mz = jnp.sum(q * pos[:,2]) + Mz = jnp.sum(q * pos[:, 2]) Q_tot = jnp.sum(q) Lz = jnp.linalg.norm(box[3]) - e_corr = pre_corr * (Mz **2 - Q_tot * (jnp.sum(q * pos[:,2] **2)) - Q_tot **2 * Lz **2 /12) - if eval(neutral_flag) is True: - # kappa = pme_potential.pme_force.kappa - pre_corr_non = - jnp.pi / (2 * V * kappa **2) * DIELECTRIC - e_corr_non = pre_corr_non * Q_tot **2 + e_corr = pre_corr * ( + Mz**2 + - Q_tot * (jnp.sum(q * pos[:, 2] ** 2)) + - jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12 + ) + if neutral_flag: + # kappa = pme_potential.pme_force.kappa + pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC + e_corr_non = pre_corr_non * Q_tot**2 e_corr += e_corr_non - return np.sum( e_corr) + return jnp.sum(e_corr) + +@jit_condition def E_CoulNocutoff(pos, box, pairs, q, ds): - e = q[pairs[:,0]] * q[pairs[:,1]] /ds * DIELECTRIC + e = q[pairs[:, 0]] * q[pairs[:, 1]] / ds * DIELECTRIC return jnp.sum(e) + +@jit_condition def E_Coul(pos, box, pairs, q, ds): - return 0 + return 0.0 -@jit_condition(static_argnums=(3)) + +@jit_condition(static_argnums=[3]) def ds_pairs(positions, box, pairs, pbc_flag): - pos1 = positions[pairs[:,0].astype(int)] - pos2 = positions[pairs[:,1].astype(int)] + pos1 = positions[pairs[:, 0]] + pos2 = positions[pairs[:, 1]] if pbc_flag is False: dr = pos1 - pos2 else: box_inv = jnp.linalg.inv(box) dpos = pos1 - pos2 dpos = dpos.dot(box_inv) - dpos -= jnp.floor(dpos+0.5) + dpos -= jnp.floor(dpos + 0.5) dr = dpos.dot(box) - ds = jnp.linalg.norm(dr,axis=1) + ds = jnp.linalg.norm(dr, axis=1) return ds + @jit_condition() -@vmap -def eta_piecewise(eta,ds): - return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4), - (lambda x: jnp.array(erfc( ds / eta)), lambda x:jnp.array(0))) - +def eta_piecewise(eta, ds): + return jnp.piecewise( + eta, + (eta > 1e-4, eta <= 1e-4), + (lambda x: erfc(ds / x), lambda x: x - x), + ) + + +eta_piecewise = jax.vmap(eta_piecewise, in_axes=(0, 0)) + + @jit_condition() -@vmap def etainv_piecewise(eta): - return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4), - (lambda x: jnp.array(1/eta), lambda x:jnp.array(0))) - + return jnp.piecewise( + eta, + (eta > 1e-4, eta <= 1e-4), + (lambda x: 1 / x, lambda x: x - x), + ) + +etainv_piecewise = jax.vmap(etainv_piecewise, in_axes=0) + + +class ADMPQeqForce: + def __init__( + self, + init_q, + r_cut: float, + kappa: float, + K: Tuple[int, int, int], + damp_mod: int = 3, + const_list: List = [], + const_vals: List = [], + neutral_flag: bool = True, + slab_flag: bool = False, + constQ: bool = True, + pbc_flag: bool = True, + ): + if not isinstance(const_vals, jnp.ndarray): + self.const_vals = jnp.array(const_vals) + else: + self.const_vals = const_vals + assert len(const_list) == len( + const_vals + ), "const_list and const_vals must have the same length" + n_atoms = len(init_q) + self.const_list = padding_consts(const_list, n_atoms) + self.init_q = jnp.array(init_q) + self.init_lagmt = jnp.ones((len(const_list),)) + + self.damp_mod = damp_mod + self.neutral_flag = neutral_flag + self.slab_flag = slab_flag + self.constQ = constQ + self.pbc_flag = pbc_flag + + if constQ: + e_constraint = E_constQ + else: + e_constraint = E_constP + self.e_constraint = e_constraint + + if damp_mod == 1: + self.e_sr = E_sr + self.e_site = E_site + elif damp_mod == 2: + self.e_sr = E_sr2 + self.e_site = E_site2 + elif damp_mod == 3: + self.e_sr = E_sr3 + self.e_site = E_site3 + else: + raise ValueError("damp_mod must be 1, 2 or 3") + + if pbc_flag: + force = CoulombPMEForce(r_cut, kappa, K) + self.kappa = kappa + else: + force = CoulNoCutoffForce() + self.kappa = 1.0 + self.coul_energy = force.generate_get_energy() + + def generate_get_energy(self): + @jit_condition() + def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales): + e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals) + e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales) + e3 = self.e_site(chi, J, q) + e4 = self.coul_energy(pos, box, pairs, q, mscales) + e5 = E_corr( + pos * 10.0, box * 10.0, pairs, q, self.kappa / 10, self.neutral_flag + ) + return e1 + e2 + e3 + e4 + e5 + + grad_E_full = grad(E_full, argnums=(0, 1)) + + @jit_condition() + def E_grads( + b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales + ): + n_const = len(self.const_vals) + q = b_value[:-n_const] + lagmt = b_value[-n_const:] + + g1, g2 = grad_E_full( + q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales + ) + g = jnp.concatenate((g1, g2)) + return g + + def get_energy(positions, box, pairs, mscales, eta, chi, J): + pos = positions + ds = ds_pairs(pos, box, pairs, self.pbc_flag) + buffer_scales = pair_buffer_scales(pairs) + + n_const = len(self.init_lagmt) + b_value = jnp.concatenate((self.init_q, self.init_lagmt)) + rf = jaxopt.ScipyRootFinding( + optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10 + ) + b_0, _ = rf.run( + b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales + ) + b_0 = jax.lax.stop_gradient(b_0) + q_0 = b_0[:-n_const] + lagmt_0 = b_0[-n_const:] + print("Q:", q_0) + print("Lagrange_multi:", lagmt_0) + + energy = E_full( + q_0, + lagmt_0, + chi, + J, + positions, + box, + pairs, + eta, + ds, + buffer_scales, + mscales, + ) + return energy + + return get_energy diff --git a/dmff/generators/QeqGenerator.py b/dmff/generators/QeqGenerator.py deleted file mode 100644 index aba48d0f8..000000000 --- a/dmff/generators/QeqGenerator.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python - -import openmm.app as app -import openmm.unit as unit -from typing import Tuple -import numpy as np -import jax.numpy as jnp -import jax -from dmff.api.topology import DMFFTopology -from dmff.api.paramset import ParamSet -from dmff.api.xmlio import XMLIO -from dmff.api.hamiltonian import _DMFFGenerators -from dmff.utils import DMFFException, isinstance_jnp -from dmff.admp.qeq import ADMPQeqForce -from dmff.generators.classical import CoulombGenerator -from dmff.admp import qeq - - -class ADMPQeqGenerator: - def __init__(self, ffinfo:dict, paramset: ParamSet): - - self.name = 'ADMPQeqForce' - self.ffinfo = ffinfo - paramset.addField(self.name) - self.key_type = None - keys , params = [], [] - for node in self.ffinfo["Forces"][self.name]["node"]: - attribs = node["attrib"] - - if self.key_type is None and "type" in attribs: - self.key_type = "type" - elif self.key_type is None and "class" in attribs: - self.key_type = "class" - elif self.key_type is not None and f"{self.key_type}" not in attribs: - raise ValueError("Keyword 'class' or 'type' cannot be used together.") - elif self.key_type is not None and f"{self.key_type}" in attribs: - pass - else: - raise ValueError("Cannot find key type for ADMPQeqForce.") - key = attribs[self.key_type] - keys.append(key) - - chi0 = float(attribs["chi"]) - J0 = float(attribs["J"]) - eta0 = float(attribs["eta"]) - - params.append([chi0, J0, eta0]) - - self.keys = keys - chi = jnp.array([i[0] for i in params]) - J = jnp.array([i[1] for i in params]) - eta = jnp.array([i[2] for i in params]) - - paramset.addParameter(chi, "chi", field=self.name) - paramset.addParameter(J, "J", field=self.name) - paramset.addParameter(eta, "eta", field=self.name) - # default params - self._jaxPotential = None - self.damp_mod = self.ffinfo["Forces"][self.name]["meta"]["DampMod"] - self.neutral_flag = self.ffinfo["Forces"][self.name]["meta"]["NeutralFlag"] - self.slab_flag = self.ffinfo["Forces"][self.name]["meta"]["SlabFlag"] - self.constQ = self.ffinfo["Forces"][self.name]["meta"]["ConstQFlag"] - self.pbc_flag = self.ffinfo["Forces"][self.name]["meta"]["PbcFlag"] - - self.pme_generator = CoulombGenerator(ffinfo, paramset) - - def getName(self) -> str: - """ - Returns the name of the force field. - - Returns: - -------- - str - The name of the force field. - """ - return self.name - - def overwrite(self, paramset:ParamSet) -> None: - - node_indices = [ i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "QeqAtom"] - chi = paramset[self.name]["chi"] - J = paramset[self.name]["J"] - eta = paramset[self.name]["eta"] - for nnode, key in enumerate(self.keys): - self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"] = {} - self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"][f"{self.key_type}"] = key - chi0 = chi[nnode] - J0 = J[nnode] - eta0 = eta[nnode] - self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["chi"] = str(chi0) - self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["J"] = str(J0) - self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["eta"] = str(eta0) - - - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, charges, const_list, const_vals, map_atomtype): - - n_atoms = topdata._numAtoms - n_residues = topdata._numResidues - - q = jnp.array(charges) - lagmt = np.ones(n_residues) - b_value = jnp.concatenate((q,lagmt)) - qeq_force = ADMPQeqForce(q, lagmt,self.damp_mod, self.neutral_flag, - self.slab_flag, self.constQ, self.pbc_flag) - self.qeq_force = qeq_force - qeq_energy = qeq_force.generate_get_energy() - - self.pme_potential = self.pme_generator.createPotential(topdata, app.PME, nonbondedCutoff ) - def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet) -> jnp.ndarray: - - n_atoms = len(positions) - # map_atomtype = np.zeros(n_atoms) - eta = np.array(params[self.name]["eta"])[map_atomtype] - chi = np.array(params[self.name]["chi"])[map_atomtype] - J = np.array(params[self.name]["J"])[map_atomtype] - self.eta = jnp.array(eta) - self.chi = jnp.array(chi) - self.J = jnp.array(J) - # coulenergy = self.pme_generator.coulenergy - # pme_energy = pme_potential(positions, box, pairs, params) - damp_mod = self.damp_mod - neutral_flag = self.neutral_flag - constQ = self.constQ - pbc_flag = self.pbc_flag - - qeq_energy0 = qeq_energy(positions, box, pairs, q, lagmt, - eta, chi, J,const_list, - const_vals, self.pme_generator) - # return pme_energy + qeq_energy0 - return qeq_energy0 - - self._jaxPotential = potential_fn - return potential_fn - -_DMFFGenerators["ADMPQeqForce"] = ADMPQeqGenerator diff --git a/dmff/generators/__init__.py b/dmff/generators/__init__.py index 6f37cf7f0..163f6c12a 100644 --- a/dmff/generators/__init__.py +++ b/dmff/generators/__init__.py @@ -1,3 +1,4 @@ from .classical import * from .admp import * from .ml import * +from .qeq import * \ No newline at end of file diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index e6456d94f..8e9272d24 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -1030,10 +1030,14 @@ def overwrite(self, paramset): # paramset to ffinfo if self._use_bcc: bcc_now = paramset[self.name]["bcc"] + mask_list = paramset.mask[self.name]["bcc"] nbcc = 0 for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]): if node["name"] == "BondChargeCorrection": + mask = mask_list[nbcc] self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc] + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" nbcc += 1 def createPotential(self, topdata: DMFFTopology, nonbondedMethod, @@ -1076,8 +1080,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, if nonbondedMethod is app.PME: cell = topdata.getPeriodicBoxVectors() box = jnp.array(cell) - # self.ethresh = kwargs.get("ethresh", 1e-6) - self.ethresh = kwargs.get("ethresh", 5e-4) #for qeq calculation + self.ethresh = kwargs.get("ethresh", 1e-5) self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, diff --git a/dmff/generators/qeq.py b/dmff/generators/qeq.py new file mode 100644 index 000000000..afc99e2b1 --- /dev/null +++ b/dmff/generators/qeq.py @@ -0,0 +1,214 @@ +import openmm.app as app +import openmm.unit as unit +from typing import Tuple +import numpy as np +import jax.numpy as jnp +from ..api.topology import DMFFTopology +from ..api.paramset import ParamSet +from ..api.xmlio import XMLIO +from ..api.hamiltonian import _DMFFGenerators +from ..utils import DMFFException, isinstance_jnp +from ..admp.qeq import ADMPQeqForce +from ..generators.classical import CoulombGenerator +from ..admp.qeq import ADMPQeqForce +from ..admp.pme import setup_ewald_parameters + + +class ADMPQeqGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): + self.name = "ADMPQeqForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.coulomb14scale = float( + self.ffinfo["Forces"][self.name]["meta"]["coulomb14scale"]) + + self.key_type = None + keys, params = [], [] + qeq_mask = [] + for node in self.ffinfo["Forces"][self.name]["node"]: + attribs = node["attrib"] + + if self.key_type is None and "type" in attribs: + self.key_type = "type" + elif self.key_type is None and "class" in attribs: + self.key_type = "class" + elif self.key_type is not None and f"{self.key_type}" not in attribs: + raise ValueError("Keyword 'class' or 'type' cannot be used together.") + elif self.key_type is not None and f"{self.key_type}" in attribs: + pass + else: + raise ValueError("Cannot find key type for ADMPQeqForce.") + key = attribs[self.key_type] + keys.append(key) + + chi0 = float(attribs["chi"]) + J0 = float(attribs["J"]) + eta0 = float(attribs["eta"]) + + if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": + qeq_mask.append(0.0) + else: + qeq_mask.append(1.0) + + params.append([chi0, J0, eta0]) + + self.atom_keys = keys + qeq_mask = jnp.array(qeq_mask) + chi = jnp.array([i[0] for i in params]) + J = jnp.array([i[1] for i in params]) + eta = jnp.array([i[2] for i in params]) + + paramset.addParameter(chi, "chi", field=self.name, mask=qeq_mask) + paramset.addParameter(J, "J", field=self.name, mask=qeq_mask) + paramset.addParameter(eta, "eta", field=self.name, mask=qeq_mask) + # default params + self._jaxPotential = None + meta = self.ffinfo["Forces"][self.name]["meta"] + if "DampMod" in meta: + self.damp_mod = int(meta["DampMod"]) + else: + self.damp_mod = 3 + + def getName(self) -> str: + """ + Returns the name of the force field. + + Returns: + -------- + str + The name of the force field. + """ + return self.name + + def overwrite(self, paramset: ParamSet) -> None: + node_indices = [ + i + for i in range(len(self.ffinfo["Forces"][self.name]["node"])) + if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom" + ] + chi = paramset[self.name]["chi"] + J = paramset[self.name]["J"] + eta = paramset[self.name]["eta"] + atom_mask = paramset[self.name]["mask"] + for nidx, idx in enumerate(node_indices): + chi0 = chi[nidx] + J0 = J[nidx] + eta0 = eta[nidx] + mask = atom_mask[nidx] + self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["chi"] = str(chi0) + self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["J"] = str(J0) + self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["eta"] = str(eta0) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["mask"] = "true" + + def _find_atype_key_index(self, atype: str): + for n, i in enumerate(self.atom_keys): + if i == atype: + return n + return None + + def createPotential( + self, + topdata: DMFFTopology, + nonbondedMethod, + nonbondedCutoff, + **kwargs + ): + + methodMap = { + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise DMFFException("Illegal nonbonded method for NonbondedForce") + + # setting for coul force + isNoCut = False + if nonbondedMethod is app.NoCutoff: + isNoCut = True + + mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, + 1.0]) # mscale for PME + mscales_coul = mscales_coul.at[2].set(self.coulomb14scale) + self.mscales_coul = mscales_coul # for qeq calculation + + if unit.is_quantity(nonbondedCutoff): + r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) + else: + r_cut = nonbondedCutoff + + if not isNoCut: + cell = topdata.getPeriodicBoxVectors() + box = jnp.array(cell) + self.ethresh = kwargs.get("ethresh", 1e-5) + self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") + self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) + kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, + box, + self.fourier_spacing, + self.coeff_method) + else: + kappa, K1, K2, K3 = 1.0, 1, 1, 1 + K = (K1, K2, K3) + + neutral_flag = kwargs.get("neutral", True) + slab_flag = kwargs.get("slab", False) + constQ = kwargs.get("constQ", True) + + # top info + n_atoms = topdata.getNumAtoms() + atoms = [a for a in topdata.atoms()] + residues = [r for r in topdata.residues()] + n_residues = len(residues) + init_q = np.array([a.meta["charge"] for a in atoms]) + map_idx = [] + for natom, atom in enumerate(atoms): + atype = atom.meta[self.key_type] + map_idx.append(self._find_atype_key_index(atype)) + map_idx = jnp.array(map_idx) + + if "const_list" in kwargs and "const_vals" in kwargs: + const_list = kwargs["const_list"] + const_vals = kwargs["const_vals"] + else: + const_list = [] + const_vals = [] + for r in residues: + aidx = [a.index for a in r.atoms()] + const_list.append(aidx) + const_vals.append(sum(init_q[aidx])) + + qeq_force = ADMPQeqForce( + init_q, r_cut, kappa, K, damp_mod=self.damp_mod, + const_list=const_list, const_vals=const_vals, + neutral_flag=neutral_flag, slab_flag=slab_flag, + constQ=constQ, pbc_flag=(not isNoCut) + ) + qeq_energy = qeq_force.generate_get_energy() + + mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, + 1.0]) # mscale for PME + mscales_coul = mscales_coul.at[2].set(self.coulomb14scale) + + def potential_fn( + positions: jnp.ndarray, + box: jnp.ndarray, + pairs: jnp.ndarray, + params: ParamSet, + ) -> jnp.ndarray: + # map_atomtype = np.zeros(n_atoms) + eta = params[self.name]["eta"][map_idx] + chi = params[self.name]["chi"][map_idx] + J = params[self.name]["J"][map_idx] + + qeq_energy0 = qeq_energy( + positions, box, pairs, mscales_coul, eta, chi, J + ) + # return pme_energy + qeq_energy0 + return qeq_energy0 + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["ADMPQeqForce"] = ADMPQeqGenerator diff --git a/tests/data/qeq.xml b/tests/data/qeq.xml index 664609905..a017ea93f 100644 --- a/tests/data/qeq.xml +++ b/tests/data/qeq.xml @@ -155,43 +155,11 @@ - - - - - - - - - - - - - - + + + + + +