Skip to content

Commit

Permalink
Merge pull request #163 from Ethan-Norch/devel
Browse files Browse the repository at this point in the history
Implemet CustomGBForce, CustomTorsionForce, Custom1_5BondForce generators in DMFF
  • Loading branch information
KuangYu authored Jan 16, 2024
2 parents eab3d26 + 62db961 commit 2126ba4
Show file tree
Hide file tree
Showing 12 changed files with 2,137 additions and 483 deletions.
64 changes: 64 additions & 0 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,67 @@ def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy


class CustomGBForce:
def __init__(
self,
map_charge,
map_radius,
map_scale,
epsilon_1=1.0,
epsilon_solv=78.3,
alpha=1,
beta=0.8,
gamma=4.85,
) -> None:
self.map_charge = map_charge
self.map_radius = map_radius
self.map_scale = map_scale
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.exp_solv = epsilon_solv
self.eps_1 = epsilon_1

def generate_get_energy(self):
@jax.jit
def get_energy(positions, box, pairs, Ipairs, charges, radius, scales):
def calI(posList, radMap, scalMap, rhoMap, pairMap):
I = jnp.array([])

for i in range(len(radMap)):
posj = posList[Ipairs[i]]
rhoj = rhoMap[Ipairs[i]]
scalj = scalMap[Ipairs[i]]
posi = posList[i]
rhoi = rhoMap[i]

r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1))
sr2 = rhoj * scalj
D = jnp.abs(r - sr2)
L = jnp.maximum(D, rhoi)
C = 2 * (1 / rhoi - 1 / L) * jnp.heaviside(sr2 - r - rhoi, 1)
U = r + sr2
I = jnp.append(I, jnp.sum(0.5 * jnp.heaviside(r + sr2 - rhoi, 1) * (
1 / L - 1 / U + 0.25 * (1 / U ** 2 - 1 / L ** 2) * (
r - sr2 ** 2 / r) + 0.5 * jnp.log(L / U) / r + C)))

return I

chargeMap = charges[self.map_charge]
radiusMap = radius[self.map_radius]
scalesMap = scales[self.map_scale]
rhoMap = radiusMap - 0.009

# effective radius
IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs)
psi = IList*rhoMap
rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap)
Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff)
dr_norm = jnp.linalg.norm(positions[pairs[:,0]] - positions[pairs[:,1]], axis=1)
chargepro = chargeMap[pairs[:, 0]] * chargeMap[pairs[:, 1]]
rEffpro = rEff[pairs[:, 0]] * rEff[pairs[:, 1]]
Egb = jnp.sum(-138.935456*(1/self.eps_1-1/self.exp_solv)*chargepro/jnp.sqrt(jnp.power(dr_norm, 2)+rEffpro*jnp.exp(-jnp.power(dr_norm,2)/(4*rEffpro))))
return Ese + Egb
return get_energy
75 changes: 75 additions & 0 deletions dmff/classical/intra.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,78 @@ def refresh_calculators(self):
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)


class Custom1_5BondJaxForce:
def __init__(self, p1idx, p2idx, prmidx):
self.p1idx = p1idx
self.p2idx = p2idx
self.prmidx = prmidx
self.refresh_calculators()

def generate_get_energy(self):
def get_energy(positions, box, pairs, k, length):
p1 = positions[self.p1idx,:]
p2 = positions[self.p2idx,:]
kprm = k[self.prmidx]
b0prm = length[self.prmidx]
dist = distance(p1, p2)
return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2))

return get_energy

def update_env(self, attr, val):
"""
Update the environment of the calculator
"""
setattr(self, attr, val)
self.refresh_calculators()

def refresh_calculators(self):
"""
refresh the energy and force calculators according to the current environment
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)


class CustomTorsionJaxForce:
def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order):
self.p1idx = p1idx
self.p2idx = p2idx
self.p3idx = p3idx
self.p4idx = p4idx
self.prmidx = prmidx
self.order = order
self.refresh_calculators()

def generate_get_energy(self):
if len(self.p1idx) == 0:
return lambda positions, box, pairs, k, psi, shift: 0.0
def get_energy(positions, box, pairs, k, psi, shift):
p1 = positions[self.p1idx, :]
p2 = positions[self.p2idx, :]
p3 = positions[self.p3idx, :]
p4 = positions[self.p4idx, :]
kp = k[self.prmidx]
psip = psi[self.prmidx]
shiftp = shift[self.prmidx]
dih = dihedral(p1, p2, p3, p4)
ener = kp * (jnp.cos(self.order * dih - psip)) + shiftp
return jnp.sum(ener)

return get_energy

def update_env(self, attr, val):
"""
Update the environment of the calculator
"""
setattr(self, attr, val)
self.refresh_calculators()

def refresh_calculators(self):
"""
refresh the energy and force calculators according to the current environment
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)
Loading

0 comments on commit 2126ba4

Please sign in to comment.