Skip to content

Commit

Permalink
Add eps when calc box_inv
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Dec 7, 2023
1 parent 7a7a7e8 commit 80eee20
Show file tree
Hide file tree
Showing 14 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def disp_pme_real(positions, box, pairs,
# pairs = pairs[pairs[:, 0] < pairs[:, 1]]
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))

box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)

ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
Expand Down
2 changes: 1 addition & 1 deletion dmff/admp/mbpol_intra.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@

## compute intra
def onebodyenergy(positions, box):
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
O = positions[::3]
H1 = positions[1::3]
H2 = positions[2::3]
Expand Down
2 changes: 1 addition & 1 deletion dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
buffer_scales = pair_buffer_scales(pairs)
mscales = mscales * buffer_scales
# mscales = mScales[nbonds-1]
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
dr = ri - rj
dr = v_pbc_shift(dr, box, box_inv)
dr = jnp.linalg.norm(dr, axis=1)
Expand Down
2 changes: 1 addition & 1 deletion dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ def pme_real(
"""
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
buffer_scales = pair_buffer_scales(pairs[:, :2])
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
r1 = distribute_v3(positions, pairs[:, 0])
r2 = distribute_v3(positions, pairs[:, 1])
Q_extendi = distribute_multipoles(Q_global, pairs[:, 0])
Expand Down
2 changes: 1 addition & 1 deletion dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def ds_pairs(positions, box, pairs, pbc_flag):
if pbc_flag is False:
dr = pos1 - pos2
else:
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
dpos = pos1 - pos2
dpos = dpos.dot(box_inv)
dpos -= jnp.floor(dpos + 0.5)
Expand Down
4 changes: 2 additions & 2 deletions dmff/admp/recip.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_recip_vectors(N, box):
3 x 3 matrix, the first index denotes reciprocal lattice vector, the second index is the component xyz.
(lattice vectors arranged in rows)
"""
Nj_Aji_star = (N.reshape((1, 3)) * jnp.linalg.inv(box)).T
Nj_Aji_star = (N.reshape((1, 3)) * jnp.linalg.inv(box + jnp.eye(3) * 1e-36)).T
return Nj_Aji_star


Expand Down Expand Up @@ -396,7 +396,7 @@ def setup_kpts(box, kpts_int):
4 * K, K=K1*K2*K3, contains kx, ky, kz, k^2 for each kpoint
'''
# in this array, a*, b*, c* (without 2*pi) are arranged in column
box_inv = jnp.linalg.inv(box).T
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36).T
# K * 3, coordinate in reciprocal space
kpts = 2 * jnp.pi * kpts_int.dot(box_inv)
ksq = jnp.sum(kpts**2, axis=1)
Expand Down
4 changes: 2 additions & 2 deletions dmff/admp/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def normalize(matrix, axis=1, ord=2):
'''
Normalise a matrix along one dimension
'''
normalised = matrix / jnp.linalg.norm(matrix, axis=axis, keepdims=True, ord=ord)
normalised = matrix / jnp.linalg.norm(matrix + 1e-36, axis=axis, keepdims=True, ord=ord)
return normalised


Expand Down Expand Up @@ -93,7 +93,7 @@ def construct_local_frames(positions, box):

positions = jnp.array(positions)
n_sites = positions.shape[0]
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)

### Process the x, y, z vectors according to local axis rules
vec_z = pbc_shift(positions[z_atoms] - positions, box, box_inv)
Expand Down
4 changes: 2 additions & 2 deletions dmff/classical/fep.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales, l
eps_scale = eps * mscale_pair

if self.ifPBC:
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36))

dr_norm = jnp.linalg.norm(dr_vec, axis=1)

Expand Down Expand Up @@ -281,7 +281,7 @@ def get_energy(positions, box, pairs, charges, mscales, lambda_):
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
bufScales = pair_buffer_scales(pairs[:, :2])
dr_vec = positions[pairs[:, 0]] - positions[pairs[:, 1]]
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36))
dr_norm = jnp.linalg.norm(dr_vec, axis=1)

atomCharges = charges[self.map_prm[np.arange(positions.shape[0])]]
Expand Down
4 changes: 2 additions & 2 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
def generate_get_energy(self):
def get_LJ_energy(dr_vec, sig, eps, box):
if self.ifPBC:
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36))
dr_norm = jnp.linalg.norm(dr_vec, axis=1)

dr_inv = 1.0 / dr_norm
Expand Down Expand Up @@ -224,7 +224,7 @@ def __init__(
def generate_get_energy(self):
def get_rf_energy(dr_vec, chrgprod, box):
if self.ifPBC:
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36))
dr_norm = jnp.linalg.norm(dr_vec, axis=1)

dr_inv = 1.0 / dr_norm
Expand Down
2 changes: 1 addition & 1 deletion dmff/eann/eann.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_energy(positions, box, pairs, params):
buffer_scales = pair_buffer_scales(pairs)

# get distances
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
dr = rj - ri
Expand Down
6 changes: 3 additions & 3 deletions dmff/sgnn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, list_atom_elems, bonds, positions=None, box=None):
self.set_internal_coords_indices()
self.box = box
if box is not None:
self.box_inv = jnp.linalg.inv(box)
self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
else:
self.box_inv = None
return
Expand All @@ -109,7 +109,7 @@ def set_box(self, box):
3 * 3: the box array, pbc vectors arranged in rows
'''
self.box = box
self.box_inv = jnp.linalg.inv(box)
self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
if hasattr(self, 'subgraphs'):
self._propagate_attr('box')
self._propagate_attr('box_inv')
Expand Down Expand Up @@ -426,7 +426,7 @@ def calc_internal_coords_features(positions, box):
All these variables should be "static" throughout NVE/NVT/NPT simulations
'''

box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)

@jit_condition(static_argnums=())
@partial(vmap, in_axes=(0, None, 0), out_axes=(0))
Expand Down
2 changes: 1 addition & 1 deletion examples/eann/eann.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_energy(positions, box, pairs, params):
buffer_scales = pair_buffer_scales(pairs)

# get distances
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
dr = rj - ri
Expand Down
2 changes: 1 addition & 1 deletion examples/fluctuated_leading_term_waterff/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def compute_leading_terms(positions,box):
n_atoms = len(positions)
c0 = jnp.zeros(n_atoms)
c6_list = jnp.zeros(n_atoms)
box_inv = jnp.linalg.inv(box)
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
O = positions[::3]
H1 = positions[1::3]
H2 = positions[2::3]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_frontend/test_inter_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def dist_pbc(vi, vj, box):
box_inv = np.linalg.inv(box)
box_inv = np.linalg.inv(box + jnp.eye(3) * 1e-36)
drvec = (vi - vj).reshape((1, 3))
unshifted_dsvecs = drvec.dot(box_inv)
dsvecs = unshifted_dsvecs - np.floor(unshifted_dsvecs + 0.5)
Expand Down

0 comments on commit 80eee20

Please sign in to comment.