t%*w5O5KIT?3%OSE`kf3U
zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E
z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V
zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya
z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb
z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G)
zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3
zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK!
zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b
zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh
z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv
zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs
zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X
zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp
zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq
zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU
zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c
z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+?
z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1
zPxfwXV@}rHMQAtco_`tOlK+s$
z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a
z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX
zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q#
zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C
zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj
z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX
zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2
z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH
zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT
z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n
z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S
z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O
diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle
new file mode 120000
index 000000000..88c340bb9
--- /dev/null
+++ b/examples/sgnn/model1.pickle
@@ -0,0 +1 @@
+test_backend/model1.pickle
\ No newline at end of file
diff --git a/examples/sgnn/peg.xml b/examples/sgnn/peg.xml
new file mode 100644
index 000000000..d3f41baaf
--- /dev/null
+++ b/examples/sgnn/peg.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb
deleted file mode 100644
index 2c11081d1..000000000
--- a/examples/sgnn/peg4.pdb
+++ /dev/null
@@ -1,63 +0,0 @@
-REMARK
-CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
-ATOM 1 C00 TER 1 -2.962 3.637 -1.170
-ATOM 2 H01 TER 1 -2.608 4.142 -0.296
-ATOM 3 H02 TER 1 -4.032 3.635 -1.171
-ATOM 4 O03 TER 1 -2.484 2.289 -1.168
-ATOM 5 C04 TER 1 -2.961 1.615 0.000
-ATOM 6 H05 TER 1 -2.604 0.606 0.000
-ATOM 7 H06 TER 1 -2.604 2.119 0.874
-ATOM 8 H07 TER 1 -4.031 1.615 0.000
-ATOM 9 C00 INT 2 -2.449 6.384 -3.596
-ATOM 10 H01 INT 2 -2.804 5.879 -4.470
-ATOM 11 H02 INT 2 -1.379 6.386 -3.595
-ATOM 12 O03 INT 2 -2.927 5.710 -2.429
-ATOM 13 C04 INT 2 -2.448 4.362 -2.427
-ATOM 14 H05 INT 2 -2.803 3.856 -3.301
-ATOM 15 H06 INT 2 -1.378 4.364 -2.425
-ATOM 16 C00 INT 3 -2.966 9.857 -4.767
-ATOM 17 H01 INT 3 -2.612 10.363 -3.893
-ATOM 18 H02 INT 3 -4.036 9.855 -4.768
-ATOM 19 O03 INT 3 -2.488 8.509 -4.765
-ATOM 20 C04 INT 3 -2.965 7.835 -3.597
-ATOM 21 H05 INT 3 -2.610 8.340 -2.724
-ATOM 22 H06 INT 3 -4.035 7.833 -3.599
-ATOM 23 C00 TER 4 -2.452 10.582 -6.024
-ATOM 24 H01 TER 4 -2.807 10.077 -6.898
-ATOM 25 H02 TER 4 -1.382 10.584 -6.022
-ATOM 26 O03 TER 4 -2.931 11.930 -6.026
-ATOM 27 C04 TER 4 -2.453 12.604 -7.193
-ATOM 28 H05 TER 4 -2.808 12.099 -8.067
-ATOM 29 H06 TER 4 -2.812 13.613 -7.194
-ATOM 30 H07 TER 4 -1.383 12.606 -7.192
-TER
-CONECT 5 6
-CONECT 5 7
-CONECT 5 8
-CONECT 5 4
-CONECT 4 1
-CONECT 1 2
-CONECT 1 3
-CONECT 1 13
-CONECT 13 14
-CONECT 13 15
-CONECT 13 12
-CONECT 12 9
-CONECT 9 10
-CONECT 9 11
-CONECT 9 20
-CONECT 20 21
-CONECT 20 22
-CONECT 20 19
-CONECT 19 16
-CONECT 16 17
-CONECT 16 18
-CONECT 16 23
-CONECT 23 24
-CONECT 23 25
-CONECT 23 26
-CONECT 26 27
-CONECT 27 28
-CONECT 27 29
-CONECT 27 30
-END
diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb
new file mode 120000
index 000000000..3c2bb15b6
--- /dev/null
+++ b/examples/sgnn/peg4.pdb
@@ -0,0 +1 @@
+test_backend/peg4.pdb
\ No newline at end of file
diff --git a/examples/sgnn/ref_out b/examples/sgnn/ref_out
index 96bc1e62f..039b75427 100644
--- a/examples/sgnn/ref_out
+++ b/examples/sgnn/ref_out
@@ -1,37 +1,5 @@
-Energy: -21.588394
-Force
-[[ 90.02814 2.0374336 35.38877 ]
- [ -98.410095 -1.6865425 -30.066338 ]
- [ 48.29245 31.675808 -43.390694 ]
- [ 59.717484 -35.94304 50.599678 ]
- [ -24.63767 218.36092 168.47194 ]
- [ 43.258293 81.24294 -87.22882 ]
- [ -67.66767 -17.780457 -5.6038494 ]
- [ -22.928284 -302.96246 -123.14815 ]
- [ 306.24683 -21.33866 -156.95491 ]
- [ -4.715515 13.664352 -23.222527 ]
- [-258.61304 -26.577957 85.58963 ]
- [ -10.179474 106.21161 64.846924 ]
- [-210.20566 -52.107193 58.04005 ]
- [ 118.68472 -8.033836 -81.18109 ]
- [ 44.02272 -34.508667 46.852356 ]
- [-214.84206 115.90286 -227.59117 ]
- [ 44.243336 -7.151741 26.06369 ]
- [ 87.46674 38.574554 192.17757 ]
- [ 27.345726 -58.87986 -44.685863 ]
- [ -83.354774 -29.714098 214.93097 ]
- [ -71.111305 34.880676 -77.53289 ]
- [ 141.12836 49.28147 -97.597305 ]
- [-220.25613 -134.58449 -23.567059 ]
- [ 75.2593 58.432755 -63.99505 ]
- [ 123.56466 -82.0066 94.63971 ]
- [ 57.822285 17.07631 -53.788273 ]
- [ -73.37115 0.50865555 16.240654 ]
- [ 54.86133 97.53715 73.672806 ]
- [ -23.997787 -73.92179 -13.749107 ]
- [ 62.348286 21.809956 25.78839 ]]
-Batched Energies:
-[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164
- -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154
- -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443
- -72.8221 -28.33014 ]
+-21.588284621154912
+[-21.58828462 -39.79334159 10.03889335 -48.22451239 -32.90970162
+ -49.68568287 -47.58035178 -51.73860617 -37.39235277 -35.01933271
+ -46.06621902 -31.69327601 -6.86739655 -5.13698524 -27.4031207
+ -44.65301991 -52.00357797 3.1734038 -72.79081259 -28.27007722]
diff --git a/examples/sgnn/residues.xml b/examples/sgnn/residues.xml
new file mode 100644
index 000000000..aa78866eb
--- /dev/null
+++ b/examples/sgnn/residues.xml
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/sgnn/run.py b/examples/sgnn/run.py
index f87c887b5..14c7c1b84 100755
--- a/examples/sgnn/run.py
+++ b/examples/sgnn/run.py
@@ -1,45 +1,44 @@
#!/usr/bin/env python
import sys
-import numpy as np
+import jax
import jax.numpy as jnp
-import jax.lax as lax
-from jax import vmap, value_and_grad
-import dmff
-from dmff.sgnn.gnn import MolGNNForce
-from dmff.utils import jit_condition
-from dmff.sgnn.graph import MAX_VALENCE
-from dmff.sgnn.graph import TopGraph, from_pdb
+import openmm.app as app
+import openmm.unit as unit
+from dmff.api import Hamiltonian
+from dmff.common import nblist
+from jax import value_and_grad
import pickle
-import re
-from collections import OrderedDict
-from functools import partial
-
if __name__ == '__main__':
- # params = load_params('benchmark/model1.pickle')
- G = from_pdb('peg4.pdb')
- model = MolGNNForce(G, nn=1)
- model.load_params('model1.pickle')
- E = model.get_energy(G.positions, G.box, model.params)
+
+ H = Hamiltonian('peg.xml')
+ app.Topology.loadBondDefinitions("residues.xml")
+ pdb = app.PDBFile("peg4.pdb")
+ rc = 0.6
+ # generator stores all force field parameters
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4)
+
+ # construct inputs
+ positions = jnp.array(pdb.positions._value)
+ a, b, c = pdb.topology.getPeriodicBoxVectors()
+ box = jnp.array([a._value, b._value, c._value])
+ # neighbor list
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
+ nbl.allocate(positions)
+
+
+ paramset = H.getParameters()
+ # params = paramset.parameters
- with open('set_test_lowT.pickle', 'rb') as ifile:
+ with open('test_backend/set_test_lowT.pickle', 'rb') as ifile:
data = pickle.load(ifile)
- # pos = jnp.array(data['positions'][0:100])
- # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1))
- pos = jnp.array(data['positions'][0])
- box = jnp.eye(3) * 50
+ # input in nm
+ pos = jnp.array(data['positions'][0:20]) / 10
+ box = jnp.eye(3) * 5
- # energies = model.batch_forward(pos, box, model.params)
- E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params)
- F = -F
- print('Energy:', E)
- print('Force')
- print(F)
+ efunc = jax.jit(pots.getPotentialFunc())
+ efunc_vmap = jax.vmap(jax.jit(pots.getPotentialFunc()), in_axes=(0, None, None, None), out_axes=0)
+ print(efunc(pos[0], box, nbl.pairs, paramset))
+ print(efunc_vmap(pos, box, nbl.pairs, paramset))
- # test batch processing
- pos = jnp.array(data['positions'][:20])
- box = jnp.tile(box, (20, 1, 1))
- E = model.batch_forward(pos, box, model.params)
- print('Batched Energies:')
- print(E)
diff --git a/examples/sgnn/model.pickle b/examples/sgnn/test_backend/model.pickle
similarity index 100%
rename from examples/sgnn/model.pickle
rename to examples/sgnn/test_backend/model.pickle
diff --git a/examples/sgnn/test_backend/model1.pickle b/examples/sgnn/test_backend/model1.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..0c3959cd9d0ef4fac155676861aa6743acb96c99
GIT binary patch
literal 17100
zcmYhjc|28J^gmAKAyY_*1~L^=63$*Xr4mVLpj0v>4d|LmrOb2Y$XF7RREWaa>!!>p
zB}p1I7fmWEY4~}b&*%Ame)o_2y03HhzH6O(&RXyNUhBP2h=7}$&z?Qo-TZg@c>9Ul
z`MPiS-R^F=)6HL;%co<{<=1xP=i}qs$DQEj9pJS$NZ-xJce}n!ym~N``>}?{y}@U
zi*v;tCyZP1n9r54;j`h7=1SUgCu-XW{A-)xO08P8%KGp4>)$sUK7X#XYwQw11rJ#c
zSH_*^@^&ulkCeM9|y+js8p^ykWX3V8l&dy?n4VR?J5!as(!ZRSUD6+I`$E)JQ-
z$G1T+fcBgerYQfCJnLyCU5n1L30H2haA*$%7`q}YKMGyJIdm3Jif)~eM88DUuz!uc
zNNAHD+&Oj~KmEz4?{+8A_PLsLs9c+veQ%*5X%DISqnE5w(ZYq^-y~7^*j0K*{~jG?
zo0*RZ6-;YuG?Mo3Bsg^!h6>seBNG)o+1p0szEs1!-xHuICy&g3U4{qL>*3;vItT?J
z++%M;@YfPJ%(@T_&OuNe=EM21Ps#JaKr1#@`jJY_vKV!_Lxq4MtdG?`%+
za8xiyTN-Hc7!zA9hPRsjl2rp@OquL=HvHWcni(O?s%SVeJ!8|Tdy+V;@zlZZItF#`
z%);yGiaS(6Hjg^4&&Fz%Od?jim7Q>^1;4KifrysX5WY~0RQ(pitW0lmgrOgJDv<
zx&`%6#crm}t;cD_+kE9m
zyJaiMx=bnTIHUm~7pMV&TLd7zWFD1n4<*z7PJ_7zwouFJ1hY32#o>|pN%G5Q5k|
z)2AF`t9s<9UuG0de00TJbnYBDGh3Z7k-hZs@j1l9PY80l=aC@6e8RtBElO;gi%lnz
zuxr&8(x@wmTG!W-w{ca3(Vj^A&W;iO{7hW^CyG8@zlRtsmu6ZXpJ19w3Sq~myQq2O
z8Qu-6!4wH&tedwSS#?EJ4^#sqrVF>X#BesXDG>gBmUul!hey?BQ(kU8=X2jl8l5!2
zR+g_uA%_esv$n-6-LpV@K$;LmX`I(oMLexz@J50j{T`G_RWpPjZPPN=PBH{}mUd`x
zTL9E;&B&l!CE6qz(y4t>#5H~u8!A{#{(fiC>X{DjOieSLANP#5&5Q-V`IjjFfz9yf
z$piAlV?Fv@ngg>BJ~CfE$qM|9U((AWPpQZUF^;c(F3c&Nj6EfH$?WTva6mB;YG)nc
zOmqo=zn@m4b^3W;@Qx?&!E7}wkbVcf`U^QK)@oQSeTL|zFM`)rQ*j>-lY|K@b%AU$
z@2LdL4eTNZ?uo#$KkMqxDTdeYds9rVE$C){Mx0`HSxCZijVX}*_(}b#vHOJgS()-b
zyGe~729a$u)lltlmARK%2T^cd1&7u@VTE?sQnf7;cn0}47zdj8;8isp-W>*;cLk%1
zya%<2n^;${Bn01D%>kur5qL6rE$(Z%3x$b^)ZZ!s?^>sz+?U@hi{nr_HpG*=u=RVOp
zdJsa6>}RH5+=>P(LuvnX5q!~TOTWj6a~jHO;K$-A7%a4xhanx<_1POlqFYfRWF}8P
zZ34&j`+KswYX$5WJPaN$1Yl)w5BYRIf|;qc0CIgJ$m#cH^zKIhdu=b8-RI8SJMoC*
z*iI$F&zI8TE8;Nq_ZDd0B1x|FYMCn)CeTygmy_yJKDZ*R!g@T)lDN#=yPgj+svk*}fH3U2z>uE9@#Kc)1{B|2i`(J{>8(fah*rWf
zRKBeQ=24$;rCI@WytE?S7r&DC1N+hT9FQf&44q}zLhHKXNxvz}rXI{9DK)!j6tO~q
zM~~>-kKa*njE{aToQQ8q_^HA-P4aXhAFr{$j+QL-Ak7-BWYDr0SAUsI_Pn`8j8<3!
z&p8A?2lX%_Q7xbvkVK4<=itfH3-E&4T-+5Q#0k6l2RD^jgT479W|0;@_!u@(6?;{h
z))x#B!7`la+FuaI`N&Q$D8-wN(>P(9*WswLOZ}T2a&%}}EciU%4oki}l8M}9c>Vh-
z5*D}shnx4{!KsrWf8R9TnS|G@*h)>BN~CCT!b)6YTFc&=mO%Apm{75vt)za|HD>v!
zC}?n0!D-bVb~I85RWE0O^!aL9q4=J(ttM%T$FI9
zHm8qLts7p{MA8!rO+)G5w*%%sQv5K;;0)fGbc>l_kOG%qhfxw10}8&GwDG|ic1Qb7
zeDi%f7OXYq%ovHp)%ANxcA_fE!a5k7kxaHu^r8(7;;i`8-B?^_$~e*p;%wK;4i@gg
z`DYiBA$H1QP6~oRJ{SN#DxZSv
zXcF$4vJ-AwhB1v|h4t#^X3$ubNZNJf0<5AVDAoCk=pQm68~!K)_FRF|*YddBRRSNk
zU5BHxe@RM}HZ9v0Mt6CBqlOwNsGFUc;z3$gY3epTdv6ky_NhJ)`
zPgTMpy8=A=YZgZ%kVnrR%p>BwLVDO`5e77bfL(_=3H@r0i>vy`^MKzp^mQGT%U7dM
zmidrhq87Atwh+YECe@em?!v3nlg$IypP-YRi@|WoZ16nGkLHU?i00}MD2$edqwXp=
zN#!B2*)C1i8koX|3VRf7n@DE3d}en~(4}TpgCsbmoVNEHf%3H&-hMk95dOH1$bOzc
zWA+Q-hU`2#!`l}`H$I{^D}G|_r@y#&!w)K%K0qINzopr|S*&bg7;|qCKO_}duyUd+
z=}^y2GV`4pY@CJc%JUOBULRk4M`W9pZTD5}PtCNA7TPi9}#3USF@k`7l~da#;of7u_UUsg2a?L=*k3
z^oTC78HVVo&5Y@rdsNPGPrYjSUe<7#4edrQhNc`Nn|IHK$Key<+6+yi-=F}(g{xq7
z*9J2BDS~Xzox+;6y}$r+9jz5>QE`$2uhvfwcnOYFv&|I)Ya%gy>Q0&)k`&&tCME!u6xaNaryRg6er7R3VNL!~N)4WX!W@sD;d^
z+tkDB9QX=4)<-v8$8Qduu$+ou<+J_ds#gIvPHScNeba>OUDEiva5`I|xgX_i^FaTI
zFj{#>BgRVBe|;|iTBjsI>7fICK4mK~-M``c_ul%ki>ACfbSyL=_Nv|9()z3d?`tw!j>-|xwYc|AMFZXbJK+ZVF%w-&u5yM~4t
zj%NVhEQZ&fKq}hG;e(OI}YD{$|CnGX5j6lzck|cVqCJ|9C@_65)$o3F^6vwt@2q;_C*@OnSrbD
z+g$^0O#XpUg%Z45FAjj6lss?BuglC^&1iGMpU*j4f<@4ILN{~XK>{4EiD6`I8tiRY
ziEm35(w8^H$?LPC)I+~@JRTEBe{wkqoxGFgTplE`wh7E%whs%;%s54E{Io>qDzJK|
z>DYx8G;}N&cKW5z_wp}DO0_J_kL2LF4TiMW(TI2{*n_Qe03A-H*c6b6MmKEWo~j9H
zaC$-iTo0mq7cgYUReid#(V9tZ$|UgzRiIatpNu35!-9R!!A#>MbEHolhYJnqK2#-@
zxBif$FEn`ZhtxU!;U1iq%BPs3?^^7g?k#lh=UZf|*)mkMe!$@61ZtP}kOT^*kU1}!
zXvW$<#38kr=HK1O>g$vc!RHOsb+#UjRnKP<*9e33N)LwL$%$cW6d2!<2TXd4De26;
zf}RG+bz;XKlW4;%qF38TgT1Qhr)B5a(16X1Vs;I&35=!u6P3t4F@ADbLXs3zXp*Zp
z+F1q7WE9@Ufy$VDXq+g>ep_5kYnGJ2;O4(%)B6j&-R|m~m#tH=(fk^W-O=NnkC)<1
z%~9Yi1AltLY6CnzWX7?Zo(lP8UrGL?$!N}<1h&7-q4Uuwxv{x`b!_^>Y#F^qS28bI
zWy!~eq1C{VL9~K
zEnVhAXCC>;|C+pABtTa^6+zh_iSWbxCHt%9CoSgHk(NG1dN{0xr~l^^6nqt
z46<-KG=qdxCc{12P9la*g6xxgG}5!CmKIUuqk;fVzbMEVK@Z-*Cm9GfQRZx#Yen@^
zyUFb7*NC@#D|Cpxqvxjk(`+MkI*)FJX08cd&bSQPWfLGIcsaaM*no1I`CxuhE*+>j
z#6A%Tqz8R}na5-%(9r3RneM%(q5OUUQNJ=mS9n&!iaF^-&?}FLeq8_qazpfWa}qo;
zJcb!=??{i@9pY~6g)T0uaGk+pV*EXZGC=KXEdAXJpz4R;3<2UD_Ob0_Uqi-)V>
zUGOgH6F%42MT~n^AZNcX)%)H?ecj(t?T-Vf=Pzw0{$?-8KKu@8HEwWPScfzDls->x
zavWX9-#~8t%piZ{1!<>vviZ|1?wDLBLN1Q#!3Fh?jQKokP#Os$zmEySj0XlVA~i~l
zgC7$e;a0fsK9xy3UQc%Tl@oo>bVklr2n3?}sB_sLa>*ecK5fwC%b6H8~C^^H89FHU~zOIdKIJ;C1E##1uEe@pm2gJ3favq|ApK
zW7|MIWDajykvG07noPba$?*JJLb2k`8zL(%MNXm*Jr?ST#b;EB(3k~O&ydF9+H>r-
z)h1Lv>N&IeKtA2VS4+g#6_fiNc?_92lbn{?0M!s$(XHn_NM=?d9m?Pb&nrCh
zg0;isbGIS=D?Ae&n;Xfo-W9NUX$iSmxS!5W5ykZndud|MWL!5Zf`-5eh~X~4z~OVy
z$z%Y3^ffj^$_*D%IdaLOgi$r2$6fwqR{K{AYI{doE_ONN=+XY5rx1e^2&J;jYtb)
z7bypmI|jAHcFBBZvfW;cQyienO6Sw2+)(@yzl9cXB=F6_B$Sh>W8@wvW5f?t*tm8v
z?HX((l~Je3tOKRgN=k}`Ri%&~lV@;l^BpK^i-C1rxgg;f34Pydq5s@&T%3#-9!~R@4M6)bbfkW&}s;a3=UgXx1lDu}>IinQ!
zpQwN#!60fSW`oye9wO1}x(V$sf-f^pQ+bJryv5#6p}2b+QBt4J{xZxUS7lR4ThUSW
zOmjQY-g6cXNxmnEqRVKPeK80J?1tGl*Fdt-d|tpDK5VgUM(d9C`0L>%`Y}odeZO4C
zgSq!fvaCPkJ;(u#Gy~3pqaB#F^axK{*$MXZf1|RSuaZgbF~rHsg#Pv>WW|PSJ}%l-LpfKr%eFk4Zlu#mmAH
zdRgx*7?+e2$CKH_Wd1i4%U;49-5Sbx$7Vv~WLc2=I3HZVE$o
z{pk0E##s!qM;`1Y{abZ$_hmu6+ZRUGHqXNC*9?fFi3Dbi4zgvv&S-V_JZ^EwqIHji
zIGHwDP;r3aBpw_Fk$4df^JXVZjs1?E8=qpv<|Q1Pq3>kNnpn>3X>z>lMv1&Y<;kC9!P2_h-G0i?F1N$7L;cMSUxGKm2(W2QfqBX#r)fgd0RtrI6
zeJU+A?|qV!K~o{{q$IFgsfs%eRlmnWyfz>j{c-672LE|v$Q=?^gDPA}uG
zA!>H=S1gH{lZrh}S#BwM^?UD`)W`C~
zd){>L6rGEH$Is%l>RlN2X%@6?ih);SZXiy-?~#64H6hB{5+*LX2TUKtC!42!VvudR+Kc
z1SNyQ@T0a0>|6N?Hmtgg_iqod;XbKk_WcN0{;C>V@~7~0f)XIraUUl%CmK5a%5cq;
zbQqo%%~tFQLC*wpl4KT6e+_3rtm1erWIZ3a#sy)Kfe{$K^P&cW2@qPNip6pQ*dRWU
z7?y@$^rFRNefM#){p1aNFZ95Ct=~W_f{xkiyP6nmk6vO(Z
z`pl88rRILVbHV<5IaN_NWw#ueNor60zy!zBRN&M_rpDzkR%8G&m&36!;WiBFtz#y%
zB;%iLL9lMK9u)a6hpczvr2nG~Z*#c`&urr{n&M~8I~UXg=U)nO3Z58pL@#PnVXs=#
zoq0!@nT6Hl?ze4FxM~J$IdhgozRMuT7B7WT!$6WR76$D~v#IL0qqH+f1WG?2CbPGR
z!OX~$w8MZ4vf*1<$bCuLjds*WUdtg-eN$l3R2Xje&*UtP3&&Gy_kqWfJ{Z!R
zOfKG$!<|Q3;Njj?@HpfLUGuCEKUcaFAGhhO-GOuD`1n1(R<8qhU-?PT&-+Zyk5!T9
z*UOmCXU{Vgf4>OXw&0ab_{K-BAuB
zS^4N~9tTc$yvfKZey|nqqhF1G)xUDuk5?iTNqLepgnTKby!G>mwMiEd^(#RX?LybM
zh0tND3AjIyv`L-DhL0>%m$zf9lpuDD4AkA3p$p2~cw(ZI#mqes0>Mw**x9rE>9vG7
z>TN8Ce?&ql-@|J9cK99HrX4_qLM9>qrbNg{tsvf3g~Yc$oWj||L~`H$c
zsm2-0Qt1Todg?U9(5LeW6qYHH#!t
zO7MJb4|!o{ie~i;=&Cs42PG+Lm7Gb#2c@ZwiZNaPJrW9b`LN9`TPPknPkYAyOfCt+{NpB|)R%_Z}
zONI!poiI!%jYUJ$;RvWYnNEq_OrBzBBOJMZ7v-f4c;AwuA>PQ8qaY$j$2Wco%Ty*qi)$NPyEy@#
z$!)GJ)6=9AJNiLLK8eNqXUtU&zMz|)>cgSE3cP(?spOQ&QM6n?LLX_2(YYn@7__ty
z&RFHMnm??d`97DYr!4{jj3G~RTRew1I)$gzBLLs&R(eD~$z1LI43r*>B_FHK5eFF~
zYI&}lMksaBq}mo5)8vFc9}URomEyR*MiO7>vSjetdv-5GlgPt2>x;TB(k_j#`nG|I
zczU-2-aos6)$y+<(?U`8$T5OEAa`#~M2NRsUzXg^T?ZXq5bs#6L2A^kN0tcE6
zcde6QZJQ_hZaNDGj15TdHCayHgG-S1jSK7gqCqf7lQSwmiC*?cvRG*%uP39Md{8yQ
z@T%9auA?0Ngr}q5uE~s`N*^)FYQw7e7&2vAH
zzweMSdp3p8=MqGBcK`$WN|?AZ1y0(Ua6D7QI0ea1VS>Q}xIOhQs5J5*CukmWEI+}@
z;>o-@;sWF%dxHd?j|Lr)3Y3t{fjP?c=*vxmv5Y+!Zm^W+^u`9#|6HY~>rH?=*`KTA
z>L~vIvdQ>plB9Z#8)|c#(82d`|dhH8YJ(a&XEj
z8KUKnG4ju+(@m?Y$*I4>F!(!={yn*#+%Yew^ZW(bPH#^Tc{0o-z1zYXnx>N4tUv7L
zRnzg+)t#`!J%Eg5En}5ac9G*|zsVF%7L`BP%5c^Ex$6HX$e8<|AY%laanIBf)!NgP%<>o9FLVXQ{SDY7~&i0t&(fMy;(u=-Lo8tt&g
zPgA4_+hM?4w)Ys8o+(1U_2UgbyJ)N&N&$~%Lx@}vkISFkBwe+!;PrYZbdMKpy~ONr
zW{qci*G6Iudx*JR69TB>9ndv%Wv7hlZaPLD5!K*fRG9Ot&f{jen%^kq3jhA3e#R
zo#n9O&pXEcNj&oJiz0Kj_d=DC7tGQfgk)|sELr&pgAd4%CT%%((m^jU&pi%XeB4n}
zJ(NiFSHhyDiqutHmCi^ng?T#mXr4)r
zrgHF%_#dMEGo3DNs3nUZ7D2`2*QDsF4(@&{#o0}raWu{gddjk~X!HOi8jKLfVpUL7
z$%Gq!Ht?x}
z_bvx;Funs9ngv02%_W+gv6DP1oXT67R1Wc}wz%rUEnw3|@Z9Zqc
zhe6YozlgM@1uTqKf{({Uu)r%B+ru|O`0+}b@xBmtbu5PBPwmu5GM--C@R062yqO-!
ztz)w1-J)6h1EA*ePB{1|fh1;e;DN+U+_~lw_LkYZ3yGc!Jj5wa+EAeuq9K9NO3Lh7w(@7VEN$i31
zBqn4pv2^mM^ByOIU|k^Ne0M!e7)+*>K56FO<3)o}{#zvS$xrsnvnx1c&XnFr;d}9QR-c-N?u1>{!tQj^+=U{U0Z+7RKBD__hOqC`NgVw?wL_a;2d~!QT
zY!Z3&<&_G&5U7vst%M%+41vhfBF0WL4t9PPL9dvZyld*$sKM(sD6TRG?W3BozkHOu
z*!T;s!)c3_AJ!-y!`6o*^ooWNU152F?g-yW<3tv-A8$IE
zTfAGrTkBG8el*np2XZH%-P>poTXvEjeZGTgPRk|_6i?zFUKu&NBAf|-qe)I5eowxq
zj1$Y^7-}LR3IZD^gXN4*L}NuMI{%u3lMbK2{YQR5NQV%HNL?TorWugyX@E3f
z0kCxpHC$wPy`ArvmtJlFC-h;{Ze>ooo-k#<
zbW#=B(B^};JgV5=Ru|0$f+q1yqyyn~Tpx_MkI`G4xtLih!CByY9{-Rn^}5wru%tN<
z`m_pgqUC+^@$_b3bo-$^uA1)Zw?&r;<>XGU5Lqes5c?a_kV%z<#qzSiai~VKp7qrJ
z_e!wlvq1MiY3Ojw2J0OQsf7tYiBy!w&*J%{W*YKiTD17+*mi*1jWoKb7t5T5<;T6Zp3*W~m^z=U4
zz3C=dRToRn>YOFZ@)Teo&xY{!h0%)3LfHB#4I^c3@ugn@>V0V-5iK!n$z}`qV0{_9
zC7MAbbpXtg+`*$}CGmZ2!8re%Ln6B`)$^tgQ31&gFxjj_g61$VaC!yIUzJN@#RurK
zpZl2rqaWsu@1y9Ou-ft5qX_%0{*ZU-vb-a;C)jUc#yCrHK3?k&MgOWROvTlOxG{1L
zd6~4FW0V~Mc{67*DO*_*qiD>QvZ?f5^C@B-(nki{%yD!_G}#@RhezyZkqAoX9foO=#xt!YpHJNLy|ur!U_@9E?FDN<
zTxuhga2P;GP9@2qGl6qPA7A$Ow;c)A$jV)fv0m}1ROs&;o?9wTt0OdbY{3<<~0HKK(!9OUe!gS?|r!Ma*L-li>Ek&(Ox{B3bmv%CjdPAq~Xix=cz@7DSz
z^)9%%tqz)vR>3yOs~{_E0cj^HsX_8WEMOng&akKTmx~KgG)x{2=ijCuH@|=pmvm;O
zTsrdU&cbDTPQ(8A035C~!SzcQfzzRvbcx|^nr^L%5gq_l3xa_SIYzp^WZ|EiC#jda
z0hoh2E;l{|J2uXN`>$8QzTr5u^azL7G0k|+@CQWkO*U;W*$;<*XyOrtCYqWt4IGNq
zaCC&9#QfxgLo>v2RbR?23T-SLT7^7c-PzvK%|U9DwJ`^I-J)E#k~K1w`LOqsx!a
z5c1U&azZ4qckfP|sPTflUb`7g+qy}T7>_A7>roGG
zGyL`&L9XUIrO$s8v!}xKgBNe(g6VR2Iqe=6z#3%gy>O3Erp9*{AFeI@VtJ
zbMv?`W`ze?7}ZK&eC5%in}JNO-$W#9gg99qi^;>+wM=KfI7!M%g2SZ|?9}ypz?T!Urkd(*7vlYnI^ykj{KSDNF|6f9;^&fN5{)f4YYvOhPWiF9iy?+SV>Lqp)AVx}!
z<}~Eu%XW9b=F=qS@Ju-9_!Kk_G?3rYjv$ghk6*A<6h)1{B(m5v@w
zxbmDSmdmF(Lmq6w3rpnA_UF#|KLTa+AA!>U&kcir+%UA~&ix+(B_G8#`iDR*66WK(
z*Pw<2jn9eO$;+s8_BhG7tO(Ocs&yn)eh9KZ(ik{;^s_QhHi1dGf
z_f;LN!B;H1;M`aLO!?%)ana8f_>n!C=lM_s3=+4J%?Zy~
zNBOJ&xagR=ohmL}M&-qx(ks5pvHSRF{jAC)
zv=B%mhSodLZpl3Kah?YYgY-b7AR0AhM56P^PMRPd4zDl|zSKXii~aK$zaKWifJzB?
z_v$M&rb}_&4Oaj@AAxm~_;@?^yK!zj-iW!ey{LEh#<;pvfwS}eRd9aG!%aHTu={N;
znhpNKw^>Iqz+H+asknj7yL)7Ga2>vmkibaab`p)pm<G~10v)_YOm&aoG
zWI1vrR%|?o)M@Uqm_2}x|trH
zs0)47Rv2?O4t`1%f@#okVlZUP**i}f0$vXhzr{{W-*_U@E^9G0IC_%~6v}hbT4ONY
zV;D0ZCDZ=+Q$(aeh<8Unj5dGQN4a@XbiTkRGLn#jN8DY`|Yp|i)8T}!W%zoXHpYF~@BWfyVE
zXBo5_(t=Z;k04!8ikA;*vQAUO%?(H`z63ce7W+v=^OHeZ{1V1p;1;*3X{PN!Ohff(%VH@bE67$K5Vc+h<~Ecgh*RgoeS}8)u-|@-JB(^92?;
z^FgC-D6>bm8Pc4aNL(Es=V7`g%}g4FL(|Wqtn@tMuze-OMQYOfAKGcwTLaAf*~nO^
z@IxYbUgxOH$2pa8m9?~vW0f>}D1Fk!Za*N*am{k3PepdX=+0g;Sgp@_5cP^ieih(#
zt9}5TjxegKy%^^lS`0sTz9X)dA4$pHGiVg7g_%bdW0nZRnJ835UY=ddyBhHm)6cBI
z)yt*mrNj~tmMw(!7L9D?#0b1PuI9h~O`3P}X*znZy+bx^Z^xUj-jUj;-stenmHMB`
zp*8uLASsrIVjm(d<
zMUa<}L<8KH!t%*w5O5KIT?3%OSE`kf3U
zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E
z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V
zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya
z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb
z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G)
zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3
zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK!
zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b
zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh
z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv
zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs
zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X
zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp
zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq
zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU
zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c
z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+?
z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1
zPxfwXV@}rHMQAtco_`tOlK+s$
z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a
z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX
zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q#
zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C
zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj
z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX
zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2
z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH
zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT
z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n
z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S
z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O
literal 0
HcmV?d00001
diff --git a/examples/sgnn/model1.pth b/examples/sgnn/test_backend/model1.pth
similarity index 100%
rename from examples/sgnn/model1.pth
rename to examples/sgnn/test_backend/model1.pth
diff --git a/examples/sgnn/mse_testing.xvg b/examples/sgnn/test_backend/mse_testing.xvg
similarity index 100%
rename from examples/sgnn/mse_testing.xvg
rename to examples/sgnn/test_backend/mse_testing.xvg
diff --git a/examples/sgnn/test_backend/peg4.pdb b/examples/sgnn/test_backend/peg4.pdb
new file mode 100644
index 000000000..2c11081d1
--- /dev/null
+++ b/examples/sgnn/test_backend/peg4.pdb
@@ -0,0 +1,63 @@
+REMARK
+CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER 1 -2.962 3.637 -1.170
+ATOM 2 H01 TER 1 -2.608 4.142 -0.296
+ATOM 3 H02 TER 1 -4.032 3.635 -1.171
+ATOM 4 O03 TER 1 -2.484 2.289 -1.168
+ATOM 5 C04 TER 1 -2.961 1.615 0.000
+ATOM 6 H05 TER 1 -2.604 0.606 0.000
+ATOM 7 H06 TER 1 -2.604 2.119 0.874
+ATOM 8 H07 TER 1 -4.031 1.615 0.000
+ATOM 9 C00 INT 2 -2.449 6.384 -3.596
+ATOM 10 H01 INT 2 -2.804 5.879 -4.470
+ATOM 11 H02 INT 2 -1.379 6.386 -3.595
+ATOM 12 O03 INT 2 -2.927 5.710 -2.429
+ATOM 13 C04 INT 2 -2.448 4.362 -2.427
+ATOM 14 H05 INT 2 -2.803 3.856 -3.301
+ATOM 15 H06 INT 2 -1.378 4.364 -2.425
+ATOM 16 C00 INT 3 -2.966 9.857 -4.767
+ATOM 17 H01 INT 3 -2.612 10.363 -3.893
+ATOM 18 H02 INT 3 -4.036 9.855 -4.768
+ATOM 19 O03 INT 3 -2.488 8.509 -4.765
+ATOM 20 C04 INT 3 -2.965 7.835 -3.597
+ATOM 21 H05 INT 3 -2.610 8.340 -2.724
+ATOM 22 H06 INT 3 -4.035 7.833 -3.599
+ATOM 23 C00 TER 4 -2.452 10.582 -6.024
+ATOM 24 H01 TER 4 -2.807 10.077 -6.898
+ATOM 25 H02 TER 4 -1.382 10.584 -6.022
+ATOM 26 O03 TER 4 -2.931 11.930 -6.026
+ATOM 27 C04 TER 4 -2.453 12.604 -7.193
+ATOM 28 H05 TER 4 -2.808 12.099 -8.067
+ATOM 29 H06 TER 4 -2.812 13.613 -7.194
+ATOM 30 H07 TER 4 -1.383 12.606 -7.192
+TER
+CONECT 5 6
+CONECT 5 7
+CONECT 5 8
+CONECT 5 4
+CONECT 4 1
+CONECT 1 2
+CONECT 1 3
+CONECT 1 13
+CONECT 13 14
+CONECT 13 15
+CONECT 13 12
+CONECT 12 9
+CONECT 9 10
+CONECT 9 11
+CONECT 9 20
+CONECT 20 21
+CONECT 20 22
+CONECT 20 19
+CONECT 19 16
+CONECT 16 17
+CONECT 16 18
+CONECT 16 23
+CONECT 23 24
+CONECT 23 25
+CONECT 23 26
+CONECT 26 27
+CONECT 27 28
+CONECT 27 29
+CONECT 27 30
+END
diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/test_backend/pth2pickle.py
similarity index 100%
rename from examples/sgnn/pth2pickle.py
rename to examples/sgnn/test_backend/pth2pickle.py
diff --git a/examples/sgnn/test_backend/ref_out b/examples/sgnn/test_backend/ref_out
new file mode 100644
index 000000000..96bc1e62f
--- /dev/null
+++ b/examples/sgnn/test_backend/ref_out
@@ -0,0 +1,37 @@
+Energy: -21.588394
+Force
+[[ 90.02814 2.0374336 35.38877 ]
+ [ -98.410095 -1.6865425 -30.066338 ]
+ [ 48.29245 31.675808 -43.390694 ]
+ [ 59.717484 -35.94304 50.599678 ]
+ [ -24.63767 218.36092 168.47194 ]
+ [ 43.258293 81.24294 -87.22882 ]
+ [ -67.66767 -17.780457 -5.6038494 ]
+ [ -22.928284 -302.96246 -123.14815 ]
+ [ 306.24683 -21.33866 -156.95491 ]
+ [ -4.715515 13.664352 -23.222527 ]
+ [-258.61304 -26.577957 85.58963 ]
+ [ -10.179474 106.21161 64.846924 ]
+ [-210.20566 -52.107193 58.04005 ]
+ [ 118.68472 -8.033836 -81.18109 ]
+ [ 44.02272 -34.508667 46.852356 ]
+ [-214.84206 115.90286 -227.59117 ]
+ [ 44.243336 -7.151741 26.06369 ]
+ [ 87.46674 38.574554 192.17757 ]
+ [ 27.345726 -58.87986 -44.685863 ]
+ [ -83.354774 -29.714098 214.93097 ]
+ [ -71.111305 34.880676 -77.53289 ]
+ [ 141.12836 49.28147 -97.597305 ]
+ [-220.25613 -134.58449 -23.567059 ]
+ [ 75.2593 58.432755 -63.99505 ]
+ [ 123.56466 -82.0066 94.63971 ]
+ [ 57.822285 17.07631 -53.788273 ]
+ [ -73.37115 0.50865555 16.240654 ]
+ [ 54.86133 97.53715 73.672806 ]
+ [ -23.997787 -73.92179 -13.749107 ]
+ [ 62.348286 21.809956 25.78839 ]]
+Batched Energies:
+[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164
+ -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154
+ -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443
+ -72.8221 -28.33014 ]
diff --git a/examples/sgnn/test_backend/run.py b/examples/sgnn/test_backend/run.py
new file mode 100755
index 000000000..f87c887b5
--- /dev/null
+++ b/examples/sgnn/test_backend/run.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import jax.numpy as jnp
+import jax.lax as lax
+from jax import vmap, value_and_grad
+import dmff
+from dmff.sgnn.gnn import MolGNNForce
+from dmff.utils import jit_condition
+from dmff.sgnn.graph import MAX_VALENCE
+from dmff.sgnn.graph import TopGraph, from_pdb
+import pickle
+import re
+from collections import OrderedDict
+from functools import partial
+
+
+if __name__ == '__main__':
+ # params = load_params('benchmark/model1.pickle')
+ G = from_pdb('peg4.pdb')
+ model = MolGNNForce(G, nn=1)
+ model.load_params('model1.pickle')
+ E = model.get_energy(G.positions, G.box, model.params)
+
+ with open('set_test_lowT.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ # pos = jnp.array(data['positions'][0:100])
+ # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1))
+ pos = jnp.array(data['positions'][0])
+ box = jnp.eye(3) * 50
+
+ # energies = model.batch_forward(pos, box, model.params)
+ E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params)
+ F = -F
+ print('Energy:', E)
+ print('Force')
+ print(F)
+
+ # test batch processing
+ pos = jnp.array(data['positions'][:20])
+ box = jnp.tile(box, (20, 1, 1))
+ E = model.batch_forward(pos, box, model.params)
+ print('Batched Energies:')
+ print(E)
diff --git a/examples/sgnn/set_test.pickle b/examples/sgnn/test_backend/set_test.pickle
similarity index 100%
rename from examples/sgnn/set_test.pickle
rename to examples/sgnn/test_backend/set_test.pickle
diff --git a/examples/sgnn/set_test_lowT.pickle b/examples/sgnn/test_backend/set_test_lowT.pickle
similarity index 100%
rename from examples/sgnn/set_test_lowT.pickle
rename to examples/sgnn/test_backend/set_test_lowT.pickle
diff --git a/examples/sgnn/test.py b/examples/sgnn/test_backend/test.py
similarity index 100%
rename from examples/sgnn/test.py
rename to examples/sgnn/test_backend/test.py
diff --git a/examples/sgnn/test_data.xvg b/examples/sgnn/test_backend/test_data.xvg
similarity index 100%
rename from examples/sgnn/test_data.xvg
rename to examples/sgnn/test_backend/test_data.xvg
diff --git a/examples/sgnn/train.py b/examples/sgnn/test_backend/train.py
similarity index 100%
rename from examples/sgnn/train.py
rename to examples/sgnn/test_backend/train.py
diff --git a/examples/water_fullpol/monopole_nonpol/run.py b/examples/water_fullpol/monopole_nonpol/run.py
index 3b1b4799f..617e96e76 100755
--- a/examples/water_fullpol/monopole_nonpol/run.py
+++ b/examples/water_fullpol/monopole_nonpol/run.py
@@ -13,18 +13,19 @@
H = Hamiltonian('forcefield.xml')
pdb = app.PDBFile("pair.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce']
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
+ pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params)
diff --git a/examples/water_fullpol/monopole_polarizable/run.py b/examples/water_fullpol/monopole_polarizable/run.py
index 808ee5801..560cd8da0 100755
--- a/examples/water_fullpol/monopole_polarizable/run.py
+++ b/examples/water_fullpol/monopole_polarizable/run.py
@@ -13,18 +13,19 @@
H = Hamiltonian('forcefield.xml')
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce']
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
+ pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params)
diff --git a/examples/water_fullpol/quadrupole_nonpol/run.py b/examples/water_fullpol/quadrupole_nonpol/run.py
index b408792aa..0b6fe6394 100755
--- a/examples/water_fullpol/quadrupole_nonpol/run.py
+++ b/examples/water_fullpol/quadrupole_nonpol/run.py
@@ -14,20 +14,20 @@
H = Hamiltonian('forcefield.xml')
app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom)
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
pot_disp = pots.dmff_potentials['ADMPDispForce']
pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta["cov_map"])
nbl.allocate(positions)
diff --git a/examples/water_fullpol/run.py b/examples/water_fullpol/run.py
index dccf92921..f024a8d66 100755
--- a/examples/water_fullpol/run.py
+++ b/examples/water_fullpol/run.py
@@ -13,18 +13,18 @@
H = Hamiltonian('forcefield.xml')
app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
disp_generator, pme_generator = H.getGenerators()
- pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4)
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4)
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
diff --git a/tests/data/admp_mono.xml b/tests/data/admp_mono.xml
new file mode 100644
index 000000000..3970ff522
--- /dev/null
+++ b/tests/data/admp_mono.xml
@@ -0,0 +1,34 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/admp_nonpol.xml b/tests/data/admp_nonpol.xml
new file mode 100644
index 000000000..7cc1b4653
--- /dev/null
+++ b/tests/data/admp_nonpol.xml
@@ -0,0 +1,40 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/peg4.pdb b/tests/data/peg4.pdb
new file mode 100644
index 000000000..eee06d7da
--- /dev/null
+++ b/tests/data/peg4.pdb
@@ -0,0 +1,64 @@
+HEADER
+TITLE MDANALYSIS FRAME 0: Created by PDBWriter
+CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER X 1 47.381 10.286 49.808 0.00 1.00 SYST
+ATOM 2 H01 TER X 1 47.251 11.255 50.307 0.00 1.00 SYST
+ATOM 3 H02 TER X 1 46.907 9.487 50.425 0.00 1.00 SYST
+ATOM 4 O03 TER X 1 48.814 10.202 49.785 0.00 1.00 SYST
+ATOM 5 C04 TER X 1 49.336 9.203 50.665 0.00 1.00 SYST
+ATOM 6 H05 TER X 1 50.344 9.329 51.054 0.00 1.00 SYST
+ATOM 7 H06 TER X 1 48.796 9.176 51.611 0.00 1.00 SYST
+ATOM 8 H07 TER X 1 49.296 8.320 50.177 0.00 1.00 SYST
+ATOM 9 C00 INT X 2 46.552 8.760 46.601 0.00 1.00 SYST
+ATOM 10 H01 INT X 2 46.737 9.609 45.939 0.00 1.00 SYST
+ATOM 11 H02 INT X 2 45.532 8.628 46.649 0.00 1.00 SYST
+ATOM 12 O03 INT X 2 47.247 8.976 47.799 0.00 1.00 SYST
+ATOM 13 C04 INT X 2 46.919 10.250 48.371 0.00 1.00 SYST
+ATOM 14 H05 INT X 2 47.190 11.176 47.880 0.00 1.00 SYST
+ATOM 15 H06 INT X 2 45.801 10.369 48.307 0.00 1.00 SYST
+ATOM 16 C00 INT X 3 46.760 5.982 44.153 0.00 1.00 SYST
+ATOM 17 H01 INT X 3 47.759 6.173 43.770 0.00 1.00 SYST
+ATOM 18 H02 INT X 3 46.121 5.894 43.168 0.00 1.00 SYST
+ATOM 19 O03 INT X 3 46.268 7.098 44.918 0.00 1.00 SYST
+ATOM 20 C04 INT X 3 47.139 7.493 45.949 0.00 1.00 SYST
+ATOM 21 H05 INT X 3 47.292 6.726 46.769 0.00 1.00 SYST
+ATOM 22 H06 INT X 3 48.124 7.662 45.625 0.00 1.00 SYST
+ATOM 23 C00 TER X 4 46.610 4.692 44.880 0.00 1.00 SYST
+ATOM 24 H01 TER X 4 45.686 4.613 45.520 0.00 1.00 SYST
+ATOM 25 H02 TER X 4 47.444 4.603 45.516 0.00 1.00 SYST
+ATOM 26 O03 TER X 4 46.501 3.674 43.869 0.00 1.00 SYST
+ATOM 27 C04 TER X 4 45.802 2.493 44.226 0.00 1.00 SYST
+ATOM 28 H05 TER X 4 45.959 1.651 43.497 0.00 1.00 SYST
+ATOM 29 H06 TER X 4 46.125 2.280 45.251 0.00 1.00 SYST
+ATOM 30 H07 TER X 4 44.695 2.638 44.209 0.00 1.00 SYST
+CONECT 1 2 3 4 13
+CONECT 2 1
+CONECT 3 1
+CONECT 4 1 5
+CONECT 5 4 6 7 8
+CONECT 6 5
+CONECT 7 5
+CONECT 8 5
+CONECT 9 10 11 12 20
+CONECT 10 9
+CONECT 11 9
+CONECT 12 9 13
+CONECT 13 1 12 14 15
+CONECT 14 13
+CONECT 15 13
+CONECT 16 17 18 19 23
+CONECT 17 16
+CONECT 18 16
+CONECT 19 16 20
+CONECT 20 9 19 21 22
+CONECT 21 20
+CONECT 22 20
+CONECT 23 16 24 25 26
+CONECT 24 23
+CONECT 25 23
+CONECT 26 23 27
+CONECT 27 26 28 29 30
+CONECT 28 27
+CONECT 29 27
+CONECT 30 27
+END
diff --git a/tests/data/peg_sgnn.xml b/tests/data/peg_sgnn.xml
new file mode 100644
index 000000000..206326d1e
--- /dev/null
+++ b/tests/data/peg_sgnn.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/sgnn_model.pickle b/tests/data/sgnn_model.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..0c3959cd9d0ef4fac155676861aa6743acb96c99
GIT binary patch
literal 17100
zcmYhjc|28J^gmAKAyY_*1~L^=63$*Xr4mVLpj0v>4d|LmrOb2Y$XF7RREWaa>!!>p
zB}p1I7fmWEY4~}b&*%Ame)o_2y03HhzH6O(&RXyNUhBP2h=7}$&z?Qo-TZg@c>9Ul
z`MPiS-R^F=)6HL;%co<{<=1xP=i}qs$DQEj9pJS$NZ-xJce}n!ym~N``>}?{y}@U
zi*v;tCyZP1n9r54;j`h7=1SUgCu-XW{A-)xO08P8%KGp4>)$sUK7X#XYwQw11rJ#c
zSH_*^@^&ulkCeM9|y+js8p^ykWX3V8l&dy?n4VR?J5!as(!ZRSUD6+I`$E)JQ-
z$G1T+fcBgerYQfCJnLyCU5n1L30H2haA*$%7`q}YKMGyJIdm3Jif)~eM88DUuz!uc
zNNAHD+&Oj~KmEz4?{+8A_PLsLs9c+veQ%*5X%DISqnE5w(ZYq^-y~7^*j0K*{~jG?
zo0*RZ6-;YuG?Mo3Bsg^!h6>seBNG)o+1p0szEs1!-xHuICy&g3U4{qL>*3;vItT?J
z++%M;@YfPJ%(@T_&OuNe=EM21Ps#JaKr1#@`jJY_vKV!_Lxq4MtdG?`%+
za8xiyTN-Hc7!zA9hPRsjl2rp@OquL=HvHWcni(O?s%SVeJ!8|Tdy+V;@zlZZItF#`
z%);yGiaS(6Hjg^4&&Fz%Od?jim7Q>^1;4KifrysX5WY~0RQ(pitW0lmgrOgJDv<
zx&`%6#crm}t;cD_+kE9m
zyJaiMx=bnTIHUm~7pMV&TLd7zWFD1n4<*z7PJ_7zwouFJ1hY32#o>|pN%G5Q5k|
z)2AF`t9s<9UuG0de00TJbnYBDGh3Z7k-hZs@j1l9PY80l=aC@6e8RtBElO;gi%lnz
zuxr&8(x@wmTG!W-w{ca3(Vj^A&W;iO{7hW^CyG8@zlRtsmu6ZXpJ19w3Sq~myQq2O
z8Qu-6!4wH&tedwSS#?EJ4^#sqrVF>X#BesXDG>gBmUul!hey?BQ(kU8=X2jl8l5!2
zR+g_uA%_esv$n-6-LpV@K$;LmX`I(oMLexz@J50j{T`G_RWpPjZPPN=PBH{}mUd`x
zTL9E;&B&l!CE6qz(y4t>#5H~u8!A{#{(fiC>X{DjOieSLANP#5&5Q-V`IjjFfz9yf
z$piAlV?Fv@ngg>BJ~CfE$qM|9U((AWPpQZUF^;c(F3c&Nj6EfH$?WTva6mB;YG)nc
zOmqo=zn@m4b^3W;@Qx?&!E7}wkbVcf`U^QK)@oQSeTL|zFM`)rQ*j>-lY|K@b%AU$
z@2LdL4eTNZ?uo#$KkMqxDTdeYds9rVE$C){Mx0`HSxCZijVX}*_(}b#vHOJgS()-b
zyGe~729a$u)lltlmARK%2T^cd1&7u@VTE?sQnf7;cn0}47zdj8;8isp-W>*;cLk%1
zya%<2n^;${Bn01D%>kur5qL6rE$(Z%3x$b^)ZZ!s?^>sz+?U@hi{nr_HpG*=u=RVOp
zdJsa6>}RH5+=>P(LuvnX5q!~TOTWj6a~jHO;K$-A7%a4xhanx<_1POlqFYfRWF}8P
zZ34&j`+KswYX$5WJPaN$1Yl)w5BYRIf|;qc0CIgJ$m#cH^zKIhdu=b8-RI8SJMoC*
z*iI$F&zI8TE8;Nq_ZDd0B1x|FYMCn)CeTygmy_yJKDZ*R!g@T)lDN#=yPgj+svk*}fH3U2z>uE9@#Kc)1{B|2i`(J{>8(fah*rWf
zRKBeQ=24$;rCI@WytE?S7r&DC1N+hT9FQf&44q}zLhHKXNxvz}rXI{9DK)!j6tO~q
zM~~>-kKa*njE{aToQQ8q_^HA-P4aXhAFr{$j+QL-Ak7-BWYDr0SAUsI_Pn`8j8<3!
z&p8A?2lX%_Q7xbvkVK4<=itfH3-E&4T-+5Q#0k6l2RD^jgT479W|0;@_!u@(6?;{h
z))x#B!7`la+FuaI`N&Q$D8-wN(>P(9*WswLOZ}T2a&%}}EciU%4oki}l8M}9c>Vh-
z5*D}shnx4{!KsrWf8R9TnS|G@*h)>BN~CCT!b)6YTFc&=mO%Apm{75vt)za|HD>v!
zC}?n0!D-bVb~I85RWE0O^!aL9q4=J(ttM%T$FI9
zHm8qLts7p{MA8!rO+)G5w*%%sQv5K;;0)fGbc>l_kOG%qhfxw10}8&GwDG|ic1Qb7
zeDi%f7OXYq%ovHp)%ANxcA_fE!a5k7kxaHu^r8(7;;i`8-B?^_$~e*p;%wK;4i@gg
z`DYiBA$H1QP6~oRJ{SN#DxZSv
zXcF$4vJ-AwhB1v|h4t#^X3$ubNZNJf0<5AVDAoCk=pQm68~!K)_FRF|*YddBRRSNk
zU5BHxe@RM}HZ9v0Mt6CBqlOwNsGFUc;z3$gY3epTdv6ky_NhJ)`
zPgTMpy8=A=YZgZ%kVnrR%p>BwLVDO`5e77bfL(_=3H@r0i>vy`^MKzp^mQGT%U7dM
zmidrhq87Atwh+YECe@em?!v3nlg$IypP-YRi@|WoZ16nGkLHU?i00}MD2$edqwXp=
zN#!B2*)C1i8koX|3VRf7n@DE3d}en~(4}TpgCsbmoVNEHf%3H&-hMk95dOH1$bOzc
zWA+Q-hU`2#!`l}`H$I{^D}G|_r@y#&!w)K%K0qINzopr|S*&bg7;|qCKO_}duyUd+
z=}^y2GV`4pY@CJc%JUOBULRk4M`W9pZTD5}PtCNA7TPi9}#3USF@k`7l~da#;of7u_UUsg2a?L=*k3
z^oTC78HVVo&5Y@rdsNPGPrYjSUe<7#4edrQhNc`Nn|IHK$Key<+6+yi-=F}(g{xq7
z*9J2BDS~Xzox+;6y}$r+9jz5>QE`$2uhvfwcnOYFv&|I)Ya%gy>Q0&)k`&&tCME!u6xaNaryRg6er7R3VNL!~N)4WX!W@sD;d^
z+tkDB9QX=4)<-v8$8Qduu$+ou<+J_ds#gIvPHScNeba>OUDEiva5`I|xgX_i^FaTI
zFj{#>BgRVBe|;|iTBjsI>7fICK4mK~-M``c_ul%ki>ACfbSyL=_Nv|9()z3d?`tw!j>-|xwYc|AMFZXbJK+ZVF%w-&u5yM~4t
zj%NVhEQZ&fKq}hG;e(OI}YD{$|CnGX5j6lzck|cVqCJ|9C@_65)$o3F^6vwt@2q;_C*@OnSrbD
z+g$^0O#XpUg%Z45FAjj6lss?BuglC^&1iGMpU*j4f<@4ILN{~XK>{4EiD6`I8tiRY
ziEm35(w8^H$?LPC)I+~@JRTEBe{wkqoxGFgTplE`wh7E%whs%;%s54E{Io>qDzJK|
z>DYx8G;}N&cKW5z_wp}DO0_J_kL2LF4TiMW(TI2{*n_Qe03A-H*c6b6MmKEWo~j9H
zaC$-iTo0mq7cgYUReid#(V9tZ$|UgzRiIatpNu35!-9R!!A#>MbEHolhYJnqK2#-@
zxBif$FEn`ZhtxU!;U1iq%BPs3?^^7g?k#lh=UZf|*)mkMe!$@61ZtP}kOT^*kU1}!
zXvW$<#38kr=HK1O>g$vc!RHOsb+#UjRnKP<*9e33N)LwL$%$cW6d2!<2TXd4De26;
zf}RG+bz;XKlW4;%qF38TgT1Qhr)B5a(16X1Vs;I&35=!u6P3t4F@ADbLXs3zXp*Zp
z+F1q7WE9@Ufy$VDXq+g>ep_5kYnGJ2;O4(%)B6j&-R|m~m#tH=(fk^W-O=NnkC)<1
z%~9Yi1AltLY6CnzWX7?Zo(lP8UrGL?$!N}<1h&7-q4Uuwxv{x`b!_^>Y#F^qS28bI
zWy!~eq1C{VL9~K
zEnVhAXCC>;|C+pABtTa^6+zh_iSWbxCHt%9CoSgHk(NG1dN{0xr~l^^6nqt
z46<-KG=qdxCc{12P9la*g6xxgG}5!CmKIUuqk;fVzbMEVK@Z-*Cm9GfQRZx#Yen@^
zyUFb7*NC@#D|Cpxqvxjk(`+MkI*)FJX08cd&bSQPWfLGIcsaaM*no1I`CxuhE*+>j
z#6A%Tqz8R}na5-%(9r3RneM%(q5OUUQNJ=mS9n&!iaF^-&?}FLeq8_qazpfWa}qo;
zJcb!=??{i@9pY~6g)T0uaGk+pV*EXZGC=KXEdAXJpz4R;3<2UD_Ob0_Uqi-)V>
zUGOgH6F%42MT~n^AZNcX)%)H?ecj(t?T-Vf=Pzw0{$?-8KKu@8HEwWPScfzDls->x
zavWX9-#~8t%piZ{1!<>vviZ|1?wDLBLN1Q#!3Fh?jQKokP#Os$zmEySj0XlVA~i~l
zgC7$e;a0fsK9xy3UQc%Tl@oo>bVklr2n3?}sB_sLa>*ecK5fwC%b6H8~C^^H89FHU~zOIdKIJ;C1E##1uEe@pm2gJ3favq|ApK
zW7|MIWDajykvG07noPba$?*JJLb2k`8zL(%MNXm*Jr?ST#b;EB(3k~O&ydF9+H>r-
z)h1Lv>N&IeKtA2VS4+g#6_fiNc?_92lbn{?0M!s$(XHn_NM=?d9m?Pb&nrCh
zg0;isbGIS=D?Ae&n;Xfo-W9NUX$iSmxS!5W5ykZndud|MWL!5Zf`-5eh~X~4z~OVy
z$z%Y3^ffj^$_*D%IdaLOgi$r2$6fwqR{K{AYI{doE_ONN=+XY5rx1e^2&J;jYtb)
z7bypmI|jAHcFBBZvfW;cQyienO6Sw2+)(@yzl9cXB=F6_B$Sh>W8@wvW5f?t*tm8v
z?HX((l~Je3tOKRgN=k}`Ri%&~lV@;l^BpK^i-C1rxgg;f34Pydq5s@&T%3#-9!~R@4M6)bbfkW&}s;a3=UgXx1lDu}>IinQ!
zpQwN#!60fSW`oye9wO1}x(V$sf-f^pQ+bJryv5#6p}2b+QBt4J{xZxUS7lR4ThUSW
zOmjQY-g6cXNxmnEqRVKPeK80J?1tGl*Fdt-d|tpDK5VgUM(d9C`0L>%`Y}odeZO4C
zgSq!fvaCPkJ;(u#Gy~3pqaB#F^axK{*$MXZf1|RSuaZgbF~rHsg#Pv>WW|PSJ}%l-LpfKr%eFk4Zlu#mmAH
zdRgx*7?+e2$CKH_Wd1i4%U;49-5Sbx$7Vv~WLc2=I3HZVE$o
z{pk0E##s!qM;`1Y{abZ$_hmu6+ZRUGHqXNC*9?fFi3Dbi4zgvv&S-V_JZ^EwqIHji
zIGHwDP;r3aBpw_Fk$4df^JXVZjs1?E8=qpv<|Q1Pq3>kNnpn>3X>z>lMv1&Y<;kC9!P2_h-G0i?F1N$7L;cMSUxGKm2(W2QfqBX#r)fgd0RtrI6
zeJU+A?|qV!K~o{{q$IFgsfs%eRlmnWyfz>j{c-672LE|v$Q=?^gDPA}uG
zA!>H=S1gH{lZrh}S#BwM^?UD`)W`C~
zd){>L6rGEH$Is%l>RlN2X%@6?ih);SZXiy-?~#64H6hB{5+*LX2TUKtC!42!VvudR+Kc
z1SNyQ@T0a0>|6N?Hmtgg_iqod;XbKk_WcN0{;C>V@~7~0f)XIraUUl%CmK5a%5cq;
zbQqo%%~tFQLC*wpl4KT6e+_3rtm1erWIZ3a#sy)Kfe{$K^P&cW2@qPNip6pQ*dRWU
z7?y@$^rFRNefM#){p1aNFZ95Ct=~W_f{xkiyP6nmk6vO(Z
z`pl88rRILVbHV<5IaN_NWw#ueNor60zy!zBRN&M_rpDzkR%8G&m&36!;WiBFtz#y%
zB;%iLL9lMK9u)a6hpczvr2nG~Z*#c`&urr{n&M~8I~UXg=U)nO3Z58pL@#PnVXs=#
zoq0!@nT6Hl?ze4FxM~J$IdhgozRMuT7B7WT!$6WR76$D~v#IL0qqH+f1WG?2CbPGR
z!OX~$w8MZ4vf*1<$bCuLjds*WUdtg-eN$l3R2Xje&*UtP3&&Gy_kqWfJ{Z!R
zOfKG$!<|Q3;Njj?@HpfLUGuCEKUcaFAGhhO-GOuD`1n1(R<8qhU-?PT&-+Zyk5!T9
z*UOmCXU{Vgf4>OXw&0ab_{K-BAuB
zS^4N~9tTc$yvfKZey|nqqhF1G)xUDuk5?iTNqLepgnTKby!G>mwMiEd^(#RX?LybM
zh0tND3AjIyv`L-DhL0>%m$zf9lpuDD4AkA3p$p2~cw(ZI#mqes0>Mw**x9rE>9vG7
z>TN8Ce?&ql-@|J9cK99HrX4_qLM9>qrbNg{tsvf3g~Yc$oWj||L~`H$c
zsm2-0Qt1Todg?U9(5LeW6qYHH#!t
zO7MJb4|!o{ie~i;=&Cs42PG+Lm7Gb#2c@ZwiZNaPJrW9b`LN9`TPPknPkYAyOfCt+{NpB|)R%_Z}
zONI!poiI!%jYUJ$;RvWYnNEq_OrBzBBOJMZ7v-f4c;AwuA>PQ8qaY$j$2Wco%Ty*qi)$NPyEy@#
z$!)GJ)6=9AJNiLLK8eNqXUtU&zMz|)>cgSE3cP(?spOQ&QM6n?LLX_2(YYn@7__ty
z&RFHMnm??d`97DYr!4{jj3G~RTRew1I)$gzBLLs&R(eD~$z1LI43r*>B_FHK5eFF~
zYI&}lMksaBq}mo5)8vFc9}URomEyR*MiO7>vSjetdv-5GlgPt2>x;TB(k_j#`nG|I
zczU-2-aos6)$y+<(?U`8$T5OEAa`#~M2NRsUzXg^T?ZXq5bs#6L2A^kN0tcE6
zcde6QZJQ_hZaNDGj15TdHCayHgG-S1jSK7gqCqf7lQSwmiC*?cvRG*%uP39Md{8yQ
z@T%9auA?0Ngr}q5uE~s`N*^)FYQw7e7&2vAH
zzweMSdp3p8=MqGBcK`$WN|?AZ1y0(Ua6D7QI0ea1VS>Q}xIOhQs5J5*CukmWEI+}@
z;>o-@;sWF%dxHd?j|Lr)3Y3t{fjP?c=*vxmv5Y+!Zm^W+^u`9#|6HY~>rH?=*`KTA
z>L~vIvdQ>plB9Z#8)|c#(82d`|dhH8YJ(a&XEj
z8KUKnG4ju+(@m?Y$*I4>F!(!={yn*#+%Yew^ZW(bPH#^Tc{0o-z1zYXnx>N4tUv7L
zRnzg+)t#`!J%Eg5En}5ac9G*|zsVF%7L`BP%5c^Ex$6HX$e8<|AY%laanIBf)!NgP%<>o9FLVXQ{SDY7~&i0t&(fMy;(u=-Lo8tt&g
zPgA4_+hM?4w)Ys8o+(1U_2UgbyJ)N&N&$~%Lx@}vkISFkBwe+!;PrYZbdMKpy~ONr
zW{qci*G6Iudx*JR69TB>9ndv%Wv7hlZaPLD5!K*fRG9Ot&f{jen%^kq3jhA3e#R
zo#n9O&pXEcNj&oJiz0Kj_d=DC7tGQfgk)|sELr&pgAd4%CT%%((m^jU&pi%XeB4n}
zJ(NiFSHhyDiqutHmCi^ng?T#mXr4)r
zrgHF%_#dMEGo3DNs3nUZ7D2`2*QDsF4(@&{#o0}raWu{gddjk~X!HOi8jKLfVpUL7
z$%Gq!Ht?x}
z_bvx;Funs9ngv02%_W+gv6DP1oXT67R1Wc}wz%rUEnw3|@Z9Zqc
zhe6YozlgM@1uTqKf{({Uu)r%B+ru|O`0+}b@xBmtbu5PBPwmu5GM--C@R062yqO-!
ztz)w1-J)6h1EA*ePB{1|fh1;e;DN+U+_~lw_LkYZ3yGc!Jj5wa+EAeuq9K9NO3Lh7w(@7VEN$i31
zBqn4pv2^mM^ByOIU|k^Ne0M!e7)+*>K56FO<3)o}{#zvS$xrsnvnx1c&XnFr;d}9QR-c-N?u1>{!tQj^+=U{U0Z+7RKBD__hOqC`NgVw?wL_a;2d~!QT
zY!Z3&<&_G&5U7vst%M%+41vhfBF0WL4t9PPL9dvZyld*$sKM(sD6TRG?W3BozkHOu
z*!T;s!)c3_AJ!-y!`6o*^ooWNU152F?g-yW<3tv-A8$IE
zTfAGrTkBG8el*np2XZH%-P>poTXvEjeZGTgPRk|_6i?zFUKu&NBAf|-qe)I5eowxq
zj1$Y^7-}LR3IZD^gXN4*L}NuMI{%u3lMbK2{YQR5NQV%HNL?TorWugyX@E3f
z0kCxpHC$wPy`ArvmtJlFC-h;{Ze>ooo-k#<
zbW#=B(B^};JgV5=Ru|0$f+q1yqyyn~Tpx_MkI`G4xtLih!CByY9{-Rn^}5wru%tN<
z`m_pgqUC+^@$_b3bo-$^uA1)Zw?&r;<>XGU5Lqes5c?a_kV%z<#qzSiai~VKp7qrJ
z_e!wlvq1MiY3Ojw2J0OQsf7tYiBy!w&*J%{W*YKiTD17+*mi*1jWoKb7t5T5<;T6Zp3*W~m^z=U4
zz3C=dRToRn>YOFZ@)Teo&xY{!h0%)3LfHB#4I^c3@ugn@>V0V-5iK!n$z}`qV0{_9
zC7MAbbpXtg+`*$}CGmZ2!8re%Ln6B`)$^tgQ31&gFxjj_g61$VaC!yIUzJN@#RurK
zpZl2rqaWsu@1y9Ou-ft5qX_%0{*ZU-vb-a;C)jUc#yCrHK3?k&MgOWROvTlOxG{1L
zd6~4FW0V~Mc{67*DO*_*qiD>QvZ?f5^C@B-(nki{%yD!_G}#@RhezyZkqAoX9foO=#xt!YpHJNLy|ur!U_@9E?FDN<
zTxuhga2P;GP9@2qGl6qPA7A$Ow;c)A$jV)fv0m}1ROs&;o?9wTt0OdbY{3<<~0HKK(!9OUe!gS?|r!Ma*L-li>Ek&(Ox{B3bmv%CjdPAq~Xix=cz@7DSz
z^)9%%tqz)vR>3yOs~{_E0cj^HsX_8WEMOng&akKTmx~KgG)x{2=ijCuH@|=pmvm;O
zTsrdU&cbDTPQ(8A035C~!SzcQfzzRvbcx|^nr^L%5gq_l3xa_SIYzp^WZ|EiC#jda
z0hoh2E;l{|J2uXN`>$8QzTr5u^azL7G0k|+@CQWkO*U;W*$;<*XyOrtCYqWt4IGNq
zaCC&9#QfxgLo>v2RbR?23T-SLT7^7c-PzvK%|U9DwJ`^I-J)E#k~K1w`LOqsx!a
z5c1U&azZ4qckfP|sPTflUb`7g+qy}T7>_A7>roGG
zGyL`&L9XUIrO$s8v!}xKgBNe(g6VR2Iqe=6z#3%gy>O3Erp9*{AFeI@VtJ
zbMv?`W`ze?7}ZK&eC5%in}JNO-$W#9gg99qi^;>+wM=KfI7!M%g2SZ|?9}ypz?T!Urkd(*7vlYnI^ykj{KSDNF|6f9;^&fN5{)f4YYvOhPWiF9iy?+SV>Lqp)AVx}!
z<}~Eu%XW9b=F=qS@Ju-9_!Kk_G?3rYjv$ghk6*A<6h)1{B(m5v@w
zxbmDSmdmF(Lmq6w3rpnA_UF#|KLTa+AA!>U&kcir+%UA~&ix+(B_G8#`iDR*66WK(
z*Pw<2jn9eO$;+s8_BhG7tO(Ocs&yn)eh9KZ(ik{;^s_QhHi1dGf
z_f;LN!B;H1;M`aLO!?%)ana8f_>n!C=lM_s3=+4J%?Zy~
zNBOJ&xagR=ohmL}M&-qx(ks5pvHSRF{jAC)
zv=B%mhSodLZpl3Kah?YYgY-b7AR0AhM56P^PMRPd4zDl|zSKXii~aK$zaKWifJzB?
z_v$M&rb}_&4Oaj@AAxm~_;@?^yK!zj-iW!ey{LEh#<;pvfwS}eRd9aG!%aHTu={N;
znhpNKw^>Iqz+H+asknj7yL)7Ga2>vmkibaab`p)pm<G~10v)_YOm&aoG
zWI1vrR%|?o)M@Uqm_2}x|trH
zs0)47Rv2?O4t`1%f@#okVlZUP**i}f0$vXhzr{{W-*_U@E^9G0IC_%~6v}hbT4ONY
zV;D0ZCDZ=+Q$(aeh<8Unj5dGQN4a@XbiTkRGLn#jN8DY`|Yp|i)8T}!W%zoXHpYF~@BWfyVE
zXBo5_(t=Z;k04!8ikA;*vQAUO%?(H`z63ce7W+v=^OHeZ{1V1p;1;*3X{PN!Ohff(%VH@bE67$K5Vc+h<~Ecgh*RgoeS}8)u-|@-JB(^92?;
z^FgC-D6>bm8Pc4aNL(Es=V7`g%}g4FL(|Wqtn@tMuze-OMQYOfAKGcwTLaAf*~nO^
z@IxYbUgxOH$2pa8m9?~vW0f>}D1Fk!Za*N*am{k3PepdX=+0g;Sgp@_5cP^ieih(#
zt9}5TjxegKy%^^lS`0sTz9X)dA4$pHGiVg7g_%bdW0nZRnJ835UY=ddyBhHm)6cBI
z)yt*mrNj~tmMw(!7L9D?#0b1PuI9h~O`3P}X*znZy+bx^Z^xUj-jUj;-stenmHMB`
zp*8uLASsrIVjm(d<
zMUa<}L<8KH!t%*w5O5KIT?3%OSE`kf3U
zIGlkGu@_KEf`Oxnp_uVnl1TC%GA}a6%zj%H!>(6dq-MhhcIZJEsyr}cyDn$qw>C}E
z$e&1L85`_76#?tMN@Iy$I$gKF1|Li~$SxO@!8KnvbgKtTWUUTyN*1f3l0goAy~6;V
zENYmyI*PcaD;Y{7V^Foin}%k4V_UB(I-DuT?u2orZ;l>t=lOHz{~rJv{SN>d|A#ya
z{y`oSd+x&j1whk(0nmHuI=25@KKfm>9%u6z@V7r6ini>9BcrF;)sKZ?^&Tg1e07mb
z^>l&r$4?@wY{vF&ssVn#c-r~#IjQywMNAu|d-6}frm@HoeG9G)
zr^&aQ*XZtDo56~H$Iyts`4R2ppDY;E9$tm9^G@Izcu3
zCa26<_Ur>YRka09SjeF9rf+m&;YJiVwvX8+o{69D6`|XLg4c*_hc}m&R9DjK!
zj>SbyUY2D6298%J-$*Ti0}sao!{<0zfgM!iWGGlXeNLBzZAF)B2g#cY*|=!2Jes>b
zCO`Td;5SoCFRh8E%RlR&;lNN`RyP9$<6HhT7g)fBrK&J@z979`n?cjeT2R$R3tdKh
z$zkES;CkE{u`L=>rkqCUSqrc`{x>~eI}ImlS>vBDC$MX2q@_O?GUM+Xy2^>t*{ybv
zZh4q2d}>C`)@5O``&}@IwSxJ(-=gEnPat%U;MasB97bc??5UX2GF8i2n`1zrm*f`;ZiaB{B)8SP`Z6Nb>eIFr5~si3Q{0-OXs
zzyzDCD7G#eJ`~(wURe7Or%Wlhu`~|5pCyyNW=VQSm7fT&y@-F#yAhpdMu3X
zY5Q#e-6sri!lN}LciAd%Jg@||ycXx^5-HSw)5zHJ-ce0XHJQ}*jifnp;IF<2dX)Yp
zULo@J!=+m|Et#c^Rq|5sm*U~`bqny;K_BwXRsp805#@-M_0YI4uB5Xt0?d|QXHss>K$*rq
zS|bLq@O?JCk3sjx+e{LTJ-ahr&4*h_V9>Z73LW>P<9M0m|}acOHm>MO0nM9xEU
zU3V>doiM~9ts{8Gd=VU3qmTCTy`(E%81~LpgzYhWH1kC>H0vuuZSF2SzqgRYn987c
z;1=?$w;Bqvl`(#@H9p82ueO5#>d2p_5fL-tu6`y_6L!XWt$1=XDFJ-nD&xoAajB+?
z4`%KhCs41?(2BeW8rk5E>l|tswQZ~LXl)leO&|%&P7tCSpo&RHggK8p1jqk79Z-C1
zPxfwXV@}rHMQAtco_`tOlK+s$
z@*m_`YR_HvzW})WUjXzpUj~u)$M-aT&?8glOXG(E1=KY=N}lId&_j;K%a
z`Nxfw_FUV4+z_`JXMn5z0l;Yje0=(|v!TFzKS}W|po;1$ko09g&?TXms8fpbHO^wX
zj~|ZRc~6ea5~P~L&sqJKTu|Uoq)l^Ile^^+bm475B4DD6@_%jN=s9P2^(h6;JR8Q#
zA0kj?_G?-4Jib>)`K31DP_+HDl_phVZ0xiMnc@}+Gc7r{C
zyNq@!-m5=zL!X?wca{E_&PR4{wucI?7vYqQzb{7mZDcJq1z4y*It)cj>+G@YWj
z&-{QzBpbT>3M*a`O1syK0jY>$Z55=b%8JF*_1s+YSkj10a}r>k?**H`5K^muaj&cX
zWUC^xYISp+UE?WcG5;S%L&KupJ#q>afAQGd`fM0$_E?IQ6Sz{ZRI`Gtm;J$NAJt%2
z-fmzuofont*KUtrzl|8IZwN}3_2if5V)o(YX;2i{PP$WmQ|a9c=!((^vYgvV;`}vH
zc6cMXQs@Te$5vv~_M_&1iqGPHl`?Aeu$K64+>GLH))YI8w}bu|BMep8Z<(#QHxG}#`nt-K3zNxjVa=K{E(Cm2uo)H3tdbm7qTR%)HT
z4j0F{ApLw59FE1K{Um8fSfoxHvUOqIi!J0;&P?j@S7`m?rfXb+<(9%;%4n
z3tbO~QKt#D-w*SNUqyIzc1)8$Cz^6{khv+1A@68j$-5UoxQgEySTgfyKr}S
z{`0H{Pi@EXqkP<*!Q5SrBIBO|xBL2yKNf6%x9IRkelHVByl@Bs~CLXT<9O
literal 0
HcmV?d00001
diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py
index 02d81ead4..be4b9d99d 100644
--- a/tests/test_admp/test_compute.py
+++ b/tests/test_admp/test_compute.py
@@ -24,13 +24,17 @@ def test_init(self):
"""
rc = 4.0
H = Hamiltonian('tests/data/admp.xml')
+ H1 = Hamiltonian('tests/data/admp_mono.xml')
+ H2 = Hamiltonian('tests/data/admp_nonpol.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+ potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+ potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
- yield potential, H.paramset
+ yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset
def test_ADMPPmeForce(self, pot_prm):
- potential, paramset = pot_prm
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
@@ -51,7 +55,7 @@ def test_ADMPPmeForce(self, pot_prm):
def test_ADMPPmeForce_jit(self, pot_prm):
- potential, paramset = pot_prm
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
@@ -67,5 +71,47 @@ def test_ADMPPmeForce_jit(self, pot_prm):
pot = potential.getPotentialFunc(names=["ADMPPmeForce"])
j_pot_pme = jit(value_and_grad(pot))
energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters)
- print(energy)
+ print('hahahah', energy)
np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1)
+
+
+ def test_ADMPPmeForce_mono(self, pot_prm):
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/water_dimer.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+
+ covalent_map = potential1.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential1.getPotentialFunc(names=["ADMPPmeForce"])
+ energy = pot(positions, box, pairs, paramset1)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -66.55921382, decimal=2)
+
+
+ def test_ADMPPmeForce_nonpol(self, pot_prm):
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/water_dimer.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+
+ covalent_map = potential2.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential2.getPotentialFunc(names=["ADMPPmeForce"])
+ energy = pot(positions, box, pairs, paramset2)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -31.69025446, decimal=2)
diff --git a/tests/test_sgnn/test_energy.py b/tests/test_sgnn/test_energy.py
new file mode 100644
index 000000000..a771f4508
--- /dev/null
+++ b/tests/test_sgnn/test_energy.py
@@ -0,0 +1,51 @@
+import openmm.app as app
+import openmm.unit as unit
+import numpy as np
+import jax.numpy as jnp
+import numpy.testing as npt
+import pytest
+from dmff import Hamiltonian, NeighborList
+from jax import jit, value_and_grad
+
+class TestADMPAPI:
+
+ """ Test sGNN related generators
+ """
+
+ @pytest.fixture(scope='class', name='pot_prm')
+ def test_init(self):
+ """load generators from XML file
+
+ Yields:
+ Tuple: (
+ ADMPDispForce,
+ ADMPPmeForce, # polarized
+ )
+ """
+ rc = 4.0
+ H = Hamiltonian('tests/data/peg_sgnn.xml')
+ pdb = app.PDBFile('tests/data/peg4.pdb')
+ potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+
+ yield potential, H.paramset
+
+ def test_sGNN_energy(self, pot_prm):
+ potential, paramset = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/peg4.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+ covalent_map = potential.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential.getPotentialFunc(names=["SGNNForce"])
+ energy = pot(positions, box, pairs, paramset)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -21.81780787, decimal=2)
+
+
From 5acacff429ad104f4edf4ceada2ebd4bef7b0906 Mon Sep 17 00:00:00 2001
From: Wang Xinyan
Date: Sun, 22 Oct 2023 16:41:39 +0800
Subject: [PATCH 3/3] Modified QEQ potential and add JIT support
---
dmff/admp/multipole.py | 2 +-
dmff/admp/pairwise.py | 4 +-
dmff/admp/parser.py | 2 +-
dmff/admp/pme.py | 16 +-
dmff/admp/qeq.py | 430 ++++++++++++++++++++------------
dmff/generators/QeqGenerator.py | 135 ----------
dmff/generators/__init__.py | 1 +
dmff/generators/classical.py | 7 +-
dmff/generators/qeq.py | 214 ++++++++++++++++
tests/data/qeq.xml | 44 +---
10 files changed, 502 insertions(+), 353 deletions(-)
delete mode 100644 dmff/generators/QeqGenerator.py
create mode 100644 dmff/generators/qeq.py
diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py
index e863f9e72..1fb8e197f 100644
--- a/dmff/admp/multipole.py
+++ b/dmff/admp/multipole.py
@@ -1,7 +1,7 @@
from functools import partial
import jax.numpy as jnp
-from dmff.utils import jit_condition
+from ..utils import jit_condition
from jax import vmap
# This module deals with the transformations and rotations of multipoles
diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py
index e1510ffb3..8fbc1b856 100755
--- a/dmff/admp/pairwise.py
+++ b/dmff/admp/pairwise.py
@@ -1,8 +1,8 @@
from functools import partial
import jax.numpy as jnp
-from dmff.admp.spatial import v_pbc_shift
-from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
+from .spatial import v_pbc_shift
+from ..utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import vmap
DIELECTRIC = 1389.35455846
diff --git a/dmff/admp/parser.py b/dmff/admp/parser.py
index 44e83a0b3..5a9efc857 100644
--- a/dmff/admp/parser.py
+++ b/dmff/admp/parser.py
@@ -4,7 +4,7 @@
import warnings
from collections import defaultdict
import jax.numpy as jnp
-from dmff.admp.multipole import convert_cart2harm
+from .multipole import convert_cart2harm
def read_atom_line(line_full):
"""
diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py
index 98947d0b2..17cdcd617 100755
--- a/dmff/admp/pme.py
+++ b/dmff/admp/pme.py
@@ -7,24 +7,24 @@
from jax import grad, value_and_grad, vmap, jit
from jax.scipy.special import erf, erfc
-from dmff.settings import DO_JIT
-from dmff.common.constants import DIELECTRIC
-from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
-from dmff.admp.settings import POL_CONV, MAX_N_POL
-from dmff.admp.recip import generate_pme_recip, Ck_1
-from dmff.admp.multipole import (
+from ..settings import DO_JIT
+from ..common.constants import DIELECTRIC
+from ..utils import jit_condition, regularize_pairs, pair_buffer_scales
+from .settings import POL_CONV, MAX_N_POL
+from .recip import generate_pme_recip, Ck_1
+from .multipole import (
C1_c2h,
convert_cart2harm,
rot_ind_global2local,
rot_global2local,
rot_local2global
)
-from dmff.admp.spatial import (
+from .spatial import (
v_pbc_shift,
generate_construct_local_frames,
build_quasi_internal
)
-from dmff.admp.pairwise import (
+from .pairwise import (
distribute_scalar,
distribute_v3,
distribute_multipoles,
diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py
index 0c3e0cf77..4de5c0be4 100644
--- a/dmff/admp/qeq.py
+++ b/dmff/admp/qeq.py
@@ -1,213 +1,311 @@
-#!/usr/bin/env python
-import sys
-import absl
-import numpy as np
+import numpy as np
import jax.numpy as jnp
-import openmm.app as app
-import openmm.unit as unit
-from dmff.settings import DO_JIT
-from dmff.common.constants import DIELECTRIC
-from dmff.common import nblist
-from jax_md import space, partition
-from jax import grad, value_and_grad, vmap, jit
-from jaxopt import OptaxSolver
-from itertools import combinations
-import jaxopt
+from ..common.constants import DIELECTRIC
+from jax import grad, vmap
+from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
+from typing import Tuple, List
+from ..settings import PRECISION
+
+if PRECISION == "double":
+ CONST_0 = jnp.array(0, dtype=jnp.float64)
+ CONST_1 = jnp.array(1, dtype=jnp.float64)
+else:
+ CONST_0 = jnp.array(0, dtype=jnp.float32)
+ CONST_1 = jnp.array(1, dtype=jnp.float32)
+
+try:
+ import jaxopt
+except ImportError:
+ print("jaxopt not found, QEQ cannot be used.")
import jax
-import scipy
-import pickle
from jax.scipy.special import erf, erfc
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
-jax.config.update("jax_enable_x64", True)
+@jit_condition()
+def group_sum(val_list, indices):
+ max_idx = indices.max()
+ exceed = jnp.piecewise(
+ indices,
+ [indices < max_idx, indices >= max_idx],
+ [lambda x: CONST_1, lambda x: CONST_0],
+ )
+ return jnp.sum(val_list[indices] * exceed)
-class ADMPQeqForce:
- def __init__(self, q, lagmt, damp_mod=3, neutral_flag=True, slab_flag=False, constQ=True, pbc_flag = True):
-
- self.damp_mod = damp_mod
- self.neutral_flag = neutral_flag
- self.slab_flag = slab_flag
- self.constQ = constQ
- self.pbc_flag = pbc_flag
- self.q = q
- self.lagmt = lagmt
- return
+group_sum_vmap = jax.vmap(group_sum, in_axes=(None, 0))
- def generate_get_energy(self):
- # q = self.q
- damp_mod = self.damp_mod
- neutral_flag = self.neutral_flag
- constQ = self.constQ
- pbc_flag = self.pbc_flag
- # lagmt = self.lagmt
-
- if eval(constQ) is True:
- e_constraint = E_constQ
- else:
- e_constraint = E_constP
- self.e_constraint = e_constraint
- if eval(damp_mod) is False:
- e_sr = E_sr0
- e_site = E_site
- elif eval(damp_mod) == 2:
- e_sr = E_sr2
- e_site = E_site2
- elif eval(damp_mod) == 3:
- e_sr = E_sr3
- e_site = E_site3
-
- # if pbc_flag is False:
- # e_coul = E_CoulNocutoff
- # else:
- # e_coul = E_coul
- def get_energy(positions, box, pairs, q, lagmt, eta, chi, J, const_list, const_vals,pme_generator):
-
- pos = positions
- ds = ds_pairs(pos, box, pairs, pbc_flag)
- buffer_scales = pair_buffer_scales(pairs)
- kappa = pme_generator.coulforce.kappa
- def E_full(q, lagmt, const_vals, chi, J, pos, box, pairs, eta, ds, buffer_scales):
- e1 = e_constraint(q, lagmt, const_list, const_vals)
- e2 = e_sr(pos*10, box*10 ,pairs , q , eta, ds*10, buffer_scales)
- e3 = e_site( chi, J , q)
- e4 = pme_generator.coulenergy(pos, box ,pairs, q, pme_generator.mscales_coul)
- e5 = E_corr(pos*10, box*10, pairs, q, kappa/10, neutral_flag)
- return e1 + e2 + e3 + e4 + e5
- @jit
- def E_grads(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales):
- n_const = len(const_vals)
- q = b_value[:-n_const]
- lagmt = b_value[-n_const:]
- g1,g2 = grad(E_full,argnums=(0,1))(q, lagmt, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales)
- g = jnp.concatenate((g1,g2))
- return g
-
- def Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales):
- rf=jaxopt.ScipyRootFinding(optimality_fun=E_grads,method='hybr',jit=False,tol=1e-10)
- q0,state1 = rf.run(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales)
- return q0,state1
-
- def get_chgs():
- n_const = len(self.lagmt)
- b_value = jnp.concatenate((self.q,self.lagmt))
- q0,state1 = Q_equi(b_value, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales)
- self.q = q0[:-n_const]
- self.lagmt = q0[-n_const:]
- return q0,state1
-
- q0,state1 = get_chgs()
- self.q0 = q0
- self.state1 = state1
- energy = E_full(self.q, self.lagmt, const_vals, chi, J, positions, box, pairs, eta, ds , buffer_scales)
- self.e_grads = E_grads(q0, const_vals, chi, J, positions, box, pairs, eta, ds, buffer_scales)
- self.e_full = E_full
- return energy
+# @jit_condition
+def padding_consts(const_list, max_idx):
+ max_length = max([len(i) for i in const_list])
+ new_const_list = np.zeros((len(const_list), max_length)) + max_idx
+ for ncl, cl in enumerate(const_list):
+ for nitem, item in enumerate(cl):
+ new_const_list[ncl, nitem] = item
+ return jnp.array(new_const_list, dtype=int)
- return get_energy
- def update_env(self, attr, val):
- '''
- Update the environment of the calculator
- '''
- setattr(self, attr, val)
- self.refresh_calculators()
-
-
- def refresh_calculators(self):
- '''
- refresh the energy and force calculators according to the current environment
- '''
- # generate the force calculator
- self.get_energy = self.generate_get_energy()
- self.get_forces = value_and_grad(self.get_energy)
- return
-
+
+@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
- constraint = (jnp.sum(q[const_list], axis=1) - const_vals) * lagmt
- return np.sum(constraint)
+ constraint = (group_sum_vmap(q, const_list) - const_vals) * lagmt
+ return jnp.sum(constraint)
+
+
+@jit_condition()
def E_constP(q, lagmt, const_list, const_vals):
- constraint = jnp.sum(q[const_list], axis=1) * const_vals
- return np.sum(constraint)
-
-def E_sr(pos, box, pairs, q, eta, ds, buffer_scales ):
- return 0
-def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales ):
- etasqrt = jnp.sqrt( 2 * ( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2))
- pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC
- pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
- e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales
+ constraint = group_sum_vmap(q, const_list) * const_vals
+ return jnp.sum(constraint)
+
+
+@vmap
+@jit_condition()
+def mask_to_zero(v, mask):
+ return jnp.piecewise(
+ v, [mask < 1e-5, mask >= 1e-5], [lambda x: CONST_0, lambda x: v]
+ )
+
+
+@jit_condition()
+def E_sr(pos, box, pairs, q, eta, ds, buffer_scales):
+ return 0.0
+
+
+@jit_condition()
+def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
+ etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2))
+ pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC
+ pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
+ e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds * buffer_scales
+ e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales)
e_sr_self = pre_self * q * q
e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self)
return e_sr
-def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales ):
- etasqrt = jnp.sqrt( jnp.array(eta)[pairs[:,0]] **2 + jnp.array(eta)[pairs[:,1]] **2 )
- pre_pair = - eta_piecewise(etasqrt,ds) * DIELECTRIC
- pre_self = etainv_piecewise(eta) /( jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
- e_sr_pair = pre_pair * q[pairs[:,0]] * q[pairs[:,1]] /ds * buffer_scales
+
+
+@jit_condition()
+def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
+ etasqrt = jnp.sqrt(eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2)
+ epiece = eta_piecewise(etasqrt, ds)
+ pre_pair = -epiece * DIELECTRIC
+ pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
+ e_sr_pair = pre_pair * q[pairs[:, 0]] * q[pairs[:, 1]] / ds
+ e_sr_pair = mask_to_zero(e_sr_pair, buffer_scales)
e_sr_self = pre_self * q * q
e_sr = jnp.sum(e_sr_pair) + jnp.sum(e_sr_self)
return e_sr
-def E_site(chi, J , q ):
- return 0
-def E_site2(chi, J , q ):
- ene = (chi * q + 0.5 * J * q **2 ) * 96.4869
- return np.sum(ene)
-def E_site3(chi, J , q ):
- ene = chi * q *4.184 + J * q **2 *DIELECTRIC * 2 * jnp.pi
- return np.sum(ene)
-
-def E_corr(pos, box, pairs, q, kappa, neutral_flag = True):
- # def E_corr():
+
+@jit_condition()
+def E_site(chi, J, q):
+ return 0.0
+
+
+@jit_condition()
+def E_site2(chi, J, q):
+ ene = (chi * q + 0.5 * J * q**2) * 96.4869
+ return jnp.sum(ene)
+
+
+@jit_condition()
+def E_site3(chi, J, q):
+ ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
+ return jnp.sum(ene)
+
+
+@jit_condition(static_argnums=[5])
+def E_corr(pos, box, pairs, q, kappa, neutral_flag=True):
+ # def E_corr():
V = jnp.linalg.det(box)
pre_corr = 2 * jnp.pi / V * DIELECTRIC
- Mz = jnp.sum(q * pos[:,2])
+ Mz = jnp.sum(q * pos[:, 2])
Q_tot = jnp.sum(q)
Lz = jnp.linalg.norm(box[3])
- e_corr = pre_corr * (Mz **2 - Q_tot * (jnp.sum(q * pos[:,2] **2)) - Q_tot **2 * Lz **2 /12)
- if eval(neutral_flag) is True:
- # kappa = pme_potential.pme_force.kappa
- pre_corr_non = - jnp.pi / (2 * V * kappa **2) * DIELECTRIC
- e_corr_non = pre_corr_non * Q_tot **2
+ e_corr = pre_corr * (
+ Mz**2
+ - Q_tot * (jnp.sum(q * pos[:, 2] ** 2))
+ - jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12
+ )
+ if neutral_flag:
+ # kappa = pme_potential.pme_force.kappa
+ pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC
+ e_corr_non = pre_corr_non * Q_tot**2
e_corr += e_corr_non
- return np.sum( e_corr)
+ return jnp.sum(e_corr)
+
+@jit_condition
def E_CoulNocutoff(pos, box, pairs, q, ds):
- e = q[pairs[:,0]] * q[pairs[:,1]] /ds * DIELECTRIC
+ e = q[pairs[:, 0]] * q[pairs[:, 1]] / ds * DIELECTRIC
return jnp.sum(e)
+
+@jit_condition
def E_Coul(pos, box, pairs, q, ds):
- return 0
+ return 0.0
-@jit_condition(static_argnums=(3))
+
+@jit_condition(static_argnums=[3])
def ds_pairs(positions, box, pairs, pbc_flag):
- pos1 = positions[pairs[:,0].astype(int)]
- pos2 = positions[pairs[:,1].astype(int)]
+ pos1 = positions[pairs[:, 0]]
+ pos2 = positions[pairs[:, 1]]
if pbc_flag is False:
dr = pos1 - pos2
else:
box_inv = jnp.linalg.inv(box)
dpos = pos1 - pos2
dpos = dpos.dot(box_inv)
- dpos -= jnp.floor(dpos+0.5)
+ dpos -= jnp.floor(dpos + 0.5)
dr = dpos.dot(box)
- ds = jnp.linalg.norm(dr,axis=1)
+ ds = jnp.linalg.norm(dr, axis=1)
return ds
+
@jit_condition()
-@vmap
-def eta_piecewise(eta,ds):
- return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4),
- (lambda x: jnp.array(erfc( ds / eta)), lambda x:jnp.array(0)))
-
+def eta_piecewise(eta, ds):
+ return jnp.piecewise(
+ eta,
+ (eta > 1e-4, eta <= 1e-4),
+ (lambda x: erfc(ds / x), lambda x: x - x),
+ )
+
+
+eta_piecewise = jax.vmap(eta_piecewise, in_axes=(0, 0))
+
+
@jit_condition()
-@vmap
def etainv_piecewise(eta):
- return jnp.piecewise(eta, (eta > 1e-4, eta <= 1e-4),
- (lambda x: jnp.array(1/eta), lambda x:jnp.array(0)))
-
+ return jnp.piecewise(
+ eta,
+ (eta > 1e-4, eta <= 1e-4),
+ (lambda x: 1 / x, lambda x: x - x),
+ )
+
+etainv_piecewise = jax.vmap(etainv_piecewise, in_axes=0)
+
+
+class ADMPQeqForce:
+ def __init__(
+ self,
+ init_q,
+ r_cut: float,
+ kappa: float,
+ K: Tuple[int, int, int],
+ damp_mod: int = 3,
+ const_list: List = [],
+ const_vals: List = [],
+ neutral_flag: bool = True,
+ slab_flag: bool = False,
+ constQ: bool = True,
+ pbc_flag: bool = True,
+ ):
+ if not isinstance(const_vals, jnp.ndarray):
+ self.const_vals = jnp.array(const_vals)
+ else:
+ self.const_vals = const_vals
+ assert len(const_list) == len(
+ const_vals
+ ), "const_list and const_vals must have the same length"
+ n_atoms = len(init_q)
+ self.const_list = padding_consts(const_list, n_atoms)
+ self.init_q = jnp.array(init_q)
+ self.init_lagmt = jnp.ones((len(const_list),))
+
+ self.damp_mod = damp_mod
+ self.neutral_flag = neutral_flag
+ self.slab_flag = slab_flag
+ self.constQ = constQ
+ self.pbc_flag = pbc_flag
+
+ if constQ:
+ e_constraint = E_constQ
+ else:
+ e_constraint = E_constP
+ self.e_constraint = e_constraint
+
+ if damp_mod == 1:
+ self.e_sr = E_sr
+ self.e_site = E_site
+ elif damp_mod == 2:
+ self.e_sr = E_sr2
+ self.e_site = E_site2
+ elif damp_mod == 3:
+ self.e_sr = E_sr3
+ self.e_site = E_site3
+ else:
+ raise ValueError("damp_mod must be 1, 2 or 3")
+
+ if pbc_flag:
+ force = CoulombPMEForce(r_cut, kappa, K)
+ self.kappa = kappa
+ else:
+ force = CoulNoCutoffForce()
+ self.kappa = 1.0
+ self.coul_energy = force.generate_get_energy()
+
+ def generate_get_energy(self):
+ @jit_condition()
+ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
+ e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
+ e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
+ e3 = self.e_site(chi, J, q)
+ e4 = self.coul_energy(pos, box, pairs, q, mscales)
+ e5 = E_corr(
+ pos * 10.0, box * 10.0, pairs, q, self.kappa / 10, self.neutral_flag
+ )
+ return e1 + e2 + e3 + e4 + e5
+
+ grad_E_full = grad(E_full, argnums=(0, 1))
+
+ @jit_condition()
+ def E_grads(
+ b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
+ ):
+ n_const = len(self.const_vals)
+ q = b_value[:-n_const]
+ lagmt = b_value[-n_const:]
+
+ g1, g2 = grad_E_full(
+ q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
+ )
+ g = jnp.concatenate((g1, g2))
+ return g
+
+ def get_energy(positions, box, pairs, mscales, eta, chi, J):
+ pos = positions
+ ds = ds_pairs(pos, box, pairs, self.pbc_flag)
+ buffer_scales = pair_buffer_scales(pairs)
+
+ n_const = len(self.init_lagmt)
+ b_value = jnp.concatenate((self.init_q, self.init_lagmt))
+ rf = jaxopt.ScipyRootFinding(
+ optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
+ )
+ b_0, _ = rf.run(
+ b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
+ )
+ b_0 = jax.lax.stop_gradient(b_0)
+ q_0 = b_0[:-n_const]
+ lagmt_0 = b_0[-n_const:]
+ print("Q:", q_0)
+ print("Lagrange_multi:", lagmt_0)
+
+ energy = E_full(
+ q_0,
+ lagmt_0,
+ chi,
+ J,
+ positions,
+ box,
+ pairs,
+ eta,
+ ds,
+ buffer_scales,
+ mscales,
+ )
+ return energy
+
+ return get_energy
diff --git a/dmff/generators/QeqGenerator.py b/dmff/generators/QeqGenerator.py
deleted file mode 100644
index aba48d0f8..000000000
--- a/dmff/generators/QeqGenerator.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#!/usr/bin/env python
-
-import openmm.app as app
-import openmm.unit as unit
-from typing import Tuple
-import numpy as np
-import jax.numpy as jnp
-import jax
-from dmff.api.topology import DMFFTopology
-from dmff.api.paramset import ParamSet
-from dmff.api.xmlio import XMLIO
-from dmff.api.hamiltonian import _DMFFGenerators
-from dmff.utils import DMFFException, isinstance_jnp
-from dmff.admp.qeq import ADMPQeqForce
-from dmff.generators.classical import CoulombGenerator
-from dmff.admp import qeq
-
-
-class ADMPQeqGenerator:
- def __init__(self, ffinfo:dict, paramset: ParamSet):
-
- self.name = 'ADMPQeqForce'
- self.ffinfo = ffinfo
- paramset.addField(self.name)
- self.key_type = None
- keys , params = [], []
- for node in self.ffinfo["Forces"][self.name]["node"]:
- attribs = node["attrib"]
-
- if self.key_type is None and "type" in attribs:
- self.key_type = "type"
- elif self.key_type is None and "class" in attribs:
- self.key_type = "class"
- elif self.key_type is not None and f"{self.key_type}" not in attribs:
- raise ValueError("Keyword 'class' or 'type' cannot be used together.")
- elif self.key_type is not None and f"{self.key_type}" in attribs:
- pass
- else:
- raise ValueError("Cannot find key type for ADMPQeqForce.")
- key = attribs[self.key_type]
- keys.append(key)
-
- chi0 = float(attribs["chi"])
- J0 = float(attribs["J"])
- eta0 = float(attribs["eta"])
-
- params.append([chi0, J0, eta0])
-
- self.keys = keys
- chi = jnp.array([i[0] for i in params])
- J = jnp.array([i[1] for i in params])
- eta = jnp.array([i[2] for i in params])
-
- paramset.addParameter(chi, "chi", field=self.name)
- paramset.addParameter(J, "J", field=self.name)
- paramset.addParameter(eta, "eta", field=self.name)
- # default params
- self._jaxPotential = None
- self.damp_mod = self.ffinfo["Forces"][self.name]["meta"]["DampMod"]
- self.neutral_flag = self.ffinfo["Forces"][self.name]["meta"]["NeutralFlag"]
- self.slab_flag = self.ffinfo["Forces"][self.name]["meta"]["SlabFlag"]
- self.constQ = self.ffinfo["Forces"][self.name]["meta"]["ConstQFlag"]
- self.pbc_flag = self.ffinfo["Forces"][self.name]["meta"]["PbcFlag"]
-
- self.pme_generator = CoulombGenerator(ffinfo, paramset)
-
- def getName(self) -> str:
- """
- Returns the name of the force field.
-
- Returns:
- --------
- str
- The name of the force field.
- """
- return self.name
-
- def overwrite(self, paramset:ParamSet) -> None:
-
- node_indices = [ i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "QeqAtom"]
- chi = paramset[self.name]["chi"]
- J = paramset[self.name]["J"]
- eta = paramset[self.name]["eta"]
- for nnode, key in enumerate(self.keys):
- self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"] = {}
- self.ffinfo["Forces"][self.name]["node"][node_indices[nnode]]["attrib"][f"{self.key_type}"] = key
- chi0 = chi[nnode]
- J0 = J[nnode]
- eta0 = eta[nnode]
- self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["chi"] = str(chi0)
- self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["J"] = str(J0)
- self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"]["eta"] = str(eta0)
-
-
- def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, charges, const_list, const_vals, map_atomtype):
-
- n_atoms = topdata._numAtoms
- n_residues = topdata._numResidues
-
- q = jnp.array(charges)
- lagmt = np.ones(n_residues)
- b_value = jnp.concatenate((q,lagmt))
- qeq_force = ADMPQeqForce(q, lagmt,self.damp_mod, self.neutral_flag,
- self.slab_flag, self.constQ, self.pbc_flag)
- self.qeq_force = qeq_force
- qeq_energy = qeq_force.generate_get_energy()
-
- self.pme_potential = self.pme_generator.createPotential(topdata, app.PME, nonbondedCutoff )
- def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet) -> jnp.ndarray:
-
- n_atoms = len(positions)
- # map_atomtype = np.zeros(n_atoms)
- eta = np.array(params[self.name]["eta"])[map_atomtype]
- chi = np.array(params[self.name]["chi"])[map_atomtype]
- J = np.array(params[self.name]["J"])[map_atomtype]
- self.eta = jnp.array(eta)
- self.chi = jnp.array(chi)
- self.J = jnp.array(J)
- # coulenergy = self.pme_generator.coulenergy
- # pme_energy = pme_potential(positions, box, pairs, params)
- damp_mod = self.damp_mod
- neutral_flag = self.neutral_flag
- constQ = self.constQ
- pbc_flag = self.pbc_flag
-
- qeq_energy0 = qeq_energy(positions, box, pairs, q, lagmt,
- eta, chi, J,const_list,
- const_vals, self.pme_generator)
- # return pme_energy + qeq_energy0
- return qeq_energy0
-
- self._jaxPotential = potential_fn
- return potential_fn
-
-_DMFFGenerators["ADMPQeqForce"] = ADMPQeqGenerator
diff --git a/dmff/generators/__init__.py b/dmff/generators/__init__.py
index 6f37cf7f0..163f6c12a 100644
--- a/dmff/generators/__init__.py
+++ b/dmff/generators/__init__.py
@@ -1,3 +1,4 @@
from .classical import *
from .admp import *
from .ml import *
+from .qeq import *
\ No newline at end of file
diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py
index e6456d94f..8e9272d24 100644
--- a/dmff/generators/classical.py
+++ b/dmff/generators/classical.py
@@ -1030,10 +1030,14 @@ def overwrite(self, paramset):
# paramset to ffinfo
if self._use_bcc:
bcc_now = paramset[self.name]["bcc"]
+ mask_list = paramset.mask[self.name]["bcc"]
nbcc = 0
for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]):
if node["name"] == "BondChargeCorrection":
+ mask = mask_list[nbcc]
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc]
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true"
nbcc += 1
def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
@@ -1076,8 +1080,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
if nonbondedMethod is app.PME:
cell = topdata.getPeriodicBoxVectors()
box = jnp.array(cell)
- # self.ethresh = kwargs.get("ethresh", 1e-6)
- self.ethresh = kwargs.get("ethresh", 5e-4) #for qeq calculation
+ self.ethresh = kwargs.get("ethresh", 1e-5)
self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
diff --git a/dmff/generators/qeq.py b/dmff/generators/qeq.py
new file mode 100644
index 000000000..afc99e2b1
--- /dev/null
+++ b/dmff/generators/qeq.py
@@ -0,0 +1,214 @@
+import openmm.app as app
+import openmm.unit as unit
+from typing import Tuple
+import numpy as np
+import jax.numpy as jnp
+from ..api.topology import DMFFTopology
+from ..api.paramset import ParamSet
+from ..api.xmlio import XMLIO
+from ..api.hamiltonian import _DMFFGenerators
+from ..utils import DMFFException, isinstance_jnp
+from ..admp.qeq import ADMPQeqForce
+from ..generators.classical import CoulombGenerator
+from ..admp.qeq import ADMPQeqForce
+from ..admp.pme import setup_ewald_parameters
+
+
+class ADMPQeqGenerator:
+ def __init__(self, ffinfo: dict, paramset: ParamSet):
+ self.name = "ADMPQeqForce"
+ self.ffinfo = ffinfo
+ paramset.addField(self.name)
+ self.coulomb14scale = float(
+ self.ffinfo["Forces"][self.name]["meta"]["coulomb14scale"])
+
+ self.key_type = None
+ keys, params = [], []
+ qeq_mask = []
+ for node in self.ffinfo["Forces"][self.name]["node"]:
+ attribs = node["attrib"]
+
+ if self.key_type is None and "type" in attribs:
+ self.key_type = "type"
+ elif self.key_type is None and "class" in attribs:
+ self.key_type = "class"
+ elif self.key_type is not None and f"{self.key_type}" not in attribs:
+ raise ValueError("Keyword 'class' or 'type' cannot be used together.")
+ elif self.key_type is not None and f"{self.key_type}" in attribs:
+ pass
+ else:
+ raise ValueError("Cannot find key type for ADMPQeqForce.")
+ key = attribs[self.key_type]
+ keys.append(key)
+
+ chi0 = float(attribs["chi"])
+ J0 = float(attribs["J"])
+ eta0 = float(attribs["eta"])
+
+ if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
+ qeq_mask.append(0.0)
+ else:
+ qeq_mask.append(1.0)
+
+ params.append([chi0, J0, eta0])
+
+ self.atom_keys = keys
+ qeq_mask = jnp.array(qeq_mask)
+ chi = jnp.array([i[0] for i in params])
+ J = jnp.array([i[1] for i in params])
+ eta = jnp.array([i[2] for i in params])
+
+ paramset.addParameter(chi, "chi", field=self.name, mask=qeq_mask)
+ paramset.addParameter(J, "J", field=self.name, mask=qeq_mask)
+ paramset.addParameter(eta, "eta", field=self.name, mask=qeq_mask)
+ # default params
+ self._jaxPotential = None
+ meta = self.ffinfo["Forces"][self.name]["meta"]
+ if "DampMod" in meta:
+ self.damp_mod = int(meta["DampMod"])
+ else:
+ self.damp_mod = 3
+
+ def getName(self) -> str:
+ """
+ Returns the name of the force field.
+
+ Returns:
+ --------
+ str
+ The name of the force field.
+ """
+ return self.name
+
+ def overwrite(self, paramset: ParamSet) -> None:
+ node_indices = [
+ i
+ for i in range(len(self.ffinfo["Forces"][self.name]["node"]))
+ if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"
+ ]
+ chi = paramset[self.name]["chi"]
+ J = paramset[self.name]["J"]
+ eta = paramset[self.name]["eta"]
+ atom_mask = paramset[self.name]["mask"]
+ for nidx, idx in enumerate(node_indices):
+ chi0 = chi[nidx]
+ J0 = J[nidx]
+ eta0 = eta[nidx]
+ mask = atom_mask[nidx]
+ self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["chi"] = str(chi0)
+ self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["J"] = str(J0)
+ self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["eta"] = str(eta0)
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["mask"] = "true"
+
+ def _find_atype_key_index(self, atype: str):
+ for n, i in enumerate(self.atom_keys):
+ if i == atype:
+ return n
+ return None
+
+ def createPotential(
+ self,
+ topdata: DMFFTopology,
+ nonbondedMethod,
+ nonbondedCutoff,
+ **kwargs
+ ):
+
+ methodMap = {
+ app.NoCutoff: "NoCutoff",
+ app.PME: "PME",
+ }
+ if nonbondedMethod not in methodMap:
+ raise DMFFException("Illegal nonbonded method for NonbondedForce")
+
+ # setting for coul force
+ isNoCut = False
+ if nonbondedMethod is app.NoCutoff:
+ isNoCut = True
+
+ mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0,
+ 1.0]) # mscale for PME
+ mscales_coul = mscales_coul.at[2].set(self.coulomb14scale)
+ self.mscales_coul = mscales_coul # for qeq calculation
+
+ if unit.is_quantity(nonbondedCutoff):
+ r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
+ else:
+ r_cut = nonbondedCutoff
+
+ if not isNoCut:
+ cell = topdata.getPeriodicBoxVectors()
+ box = jnp.array(cell)
+ self.ethresh = kwargs.get("ethresh", 1e-5)
+ self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
+ self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
+ kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
+ box,
+ self.fourier_spacing,
+ self.coeff_method)
+ else:
+ kappa, K1, K2, K3 = 1.0, 1, 1, 1
+ K = (K1, K2, K3)
+
+ neutral_flag = kwargs.get("neutral", True)
+ slab_flag = kwargs.get("slab", False)
+ constQ = kwargs.get("constQ", True)
+
+ # top info
+ n_atoms = topdata.getNumAtoms()
+ atoms = [a for a in topdata.atoms()]
+ residues = [r for r in topdata.residues()]
+ n_residues = len(residues)
+ init_q = np.array([a.meta["charge"] for a in atoms])
+ map_idx = []
+ for natom, atom in enumerate(atoms):
+ atype = atom.meta[self.key_type]
+ map_idx.append(self._find_atype_key_index(atype))
+ map_idx = jnp.array(map_idx)
+
+ if "const_list" in kwargs and "const_vals" in kwargs:
+ const_list = kwargs["const_list"]
+ const_vals = kwargs["const_vals"]
+ else:
+ const_list = []
+ const_vals = []
+ for r in residues:
+ aidx = [a.index for a in r.atoms()]
+ const_list.append(aidx)
+ const_vals.append(sum(init_q[aidx]))
+
+ qeq_force = ADMPQeqForce(
+ init_q, r_cut, kappa, K, damp_mod=self.damp_mod,
+ const_list=const_list, const_vals=const_vals,
+ neutral_flag=neutral_flag, slab_flag=slab_flag,
+ constQ=constQ, pbc_flag=(not isNoCut)
+ )
+ qeq_energy = qeq_force.generate_get_energy()
+
+ mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0,
+ 1.0]) # mscale for PME
+ mscales_coul = mscales_coul.at[2].set(self.coulomb14scale)
+
+ def potential_fn(
+ positions: jnp.ndarray,
+ box: jnp.ndarray,
+ pairs: jnp.ndarray,
+ params: ParamSet,
+ ) -> jnp.ndarray:
+ # map_atomtype = np.zeros(n_atoms)
+ eta = params[self.name]["eta"][map_idx]
+ chi = params[self.name]["chi"][map_idx]
+ J = params[self.name]["J"][map_idx]
+
+ qeq_energy0 = qeq_energy(
+ positions, box, pairs, mscales_coul, eta, chi, J
+ )
+ # return pme_energy + qeq_energy0
+ return qeq_energy0
+
+ self._jaxPotential = potential_fn
+ return potential_fn
+
+
+_DMFFGenerators["ADMPQeqForce"] = ADMPQeqGenerator
diff --git a/tests/data/qeq.xml b/tests/data/qeq.xml
index 664609905..a017ea93f 100644
--- a/tests/data/qeq.xml
+++ b/tests/data/qeq.xml
@@ -155,43 +155,11 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+