Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated solvation shell extraction workflow #8

Draft
wants to merge 45 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
c431bea
automated solvation shell extraction workflow
Apr 22, 2024
9c75284
begin addressing Daniel comments
May 15, 2024
ef3e241
operate on pairwise distances, account for PBCs
May 15, 2024
f092257
switch to kabsch rmsd
May 15, 2024
14749ae
add pdb file example
May 15, 2024
4280670
add dependencies to setup
May 15, 2024
09031e2
rename from rmse to rmsd
May 15, 2024
70de4e1
random seed for rmsd set instead of always starting at 0, remove reor…
May 15, 2024
3868637
start adding support for multiatom solutes, not working yet
May 17, 2024
6b8e9a7
fix star import, add al_cl04 test file
May 17, 2024
88832da
update extraction script to be the non-working Al_Cl04 one
May 17, 2024
0af4830
correct charge assignment
May 20, 2024
30ed02c
trying to handle multi-atom solutes, committing before migrating to s…
Jun 27, 2024
1c7fd1a
start schrodinger workflow
Jul 1, 2024
d65860f
can extract shells, now need to loop over solutes, apply pbcs and add…
Jul 1, 2024
ed22d93
loop over solutes
Jul 1, 2024
82fff88
fix loop nesting
Jul 1, 2024
1ef6e1b
add partial charge assignment
Jul 1, 2024
4a55633
fix pbc wrapping index error
Jul 1, 2024
66b4afd
append central atom to shell if not already present, have fully runni…
Jul 2, 2024
ac25b01
cleanup files
Jul 2, 2024
46aab0d
ignore readme
Jul 2, 2024
af2245b
begin solvent-solvent stuff, not working
Jul 2, 2024
e0479be
revert to working solvent-solvent interactions, need to implement PBCs
Jul 2, 2024
9bc0a70
uncomment solvent part
Jul 2, 2024
e31bf22
switch to superimpose version of rmsd
Jul 2, 2024
3ee9aa0
make size checking caps hard
Jul 2, 2024
8e2a1e0
address some more pr comments - more accurate comments, restructuring…
Jul 2, 2024
b2b836d
option to skip solvent centered shells
Jul 2, 2024
cb1139c
remove equals from dir name
Jul 2, 2024
f7bd798
assign partial charges once outside loop
Jul 2, 2024
68dabb2
partial charge assignment once up front
Jul 2, 2024
ca95322
remove size filtering function
Jul 2, 2024
eb12dd8
more explicit version of expand shells function
Jul 2, 2024
f432765
fix central solute molecule bug
Jul 2, 2024
53b6101
simplify filter solute function
Jul 2, 2024
aa7d024
turn off shell expansion if there are no solvents
Jul 4, 2024
34d2f21
lognormal sampling of max size
Jul 4, 2024
a51ea9b
option to subsample solutes/solvents per frame
Jul 4, 2024
3e65d58
speedups for are_conformers
levineds Jul 9, 2024
0e54a94
cleanup
levineds Jul 9, 2024
3966421
some logging, restrict to 100 frames for now
levineds Jul 10, 2024
c877a8b
basic docstrings
levineds Jul 10, 2024
2963480
refactor into separate functions, re-use code
levineds Jul 10, 2024
f225f37
refactor bug
levineds Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ checkpoints
results
logs
*.traj
*.pdb
experimental

# Byte-compiled / optimized / DLL files
Expand Down
19 changes: 19 additions & 0 deletions electrolytes/run_extraction.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash
#TODO: can we automatically extract the names of all the solute atoms from the PDB file so we don't have to re-run this command for each solute?

# Run these scripts from om-data/electrolytes
python solvation_shell_extract.py --pdb_file_path 'testfiles/water_nacl_example.pdb' \
--save_dir 'results' \
--system_name 'NaCl_Water' \
--solute_atom 'NA0' \
--min_coord 2 \
--max_coord 5 \
--top_n 20

python solvation_shell_extract.py --pdb_file_path 'testfiles/water_nacl_example.pdb' \
--save_dir 'results' \
--system_name 'NaCl_Water' \
--solute_atom 'CL0' \
--min_coord 2 \
--max_coord 5 \
--top_n 20
179 changes: 179 additions & 0 deletions electrolytes/solvation_shell_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import logging
import argparse
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import MDAnalysis as mda
import nglview as nv
from solvation_analysis.solute import Solute
from solvation_analysis._column_names import *
jeevster marked this conversation as resolved.
Show resolved Hide resolved
from pymatgen.core.structure import Molecule
from solvation_shell_utils import filter_by_rmse, wrap_positions

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))


def extract_solvation_shells(
jeevster marked this conversation as resolved.
Show resolved Hide resolved
pdb_file_path: str,
save_dir: str,
system_name: str,
solute_atom: str,
min_coord: int,
max_coord: int,
top_n: int,
):
"""
Given a MD trajectory in a PDB file, perform a solvation analysis
on the specified solute to extract the first solvation shell. For each coordination number in the specified range,
extract and save the top_n most diverse snapshots based on a RMSD criterion.

Args:
pdb_file_path: Path to the PDB file containing the MD trajectory
save_dir: Directory in which to save extracted solvation shells
system_name: Name of the system - used for naming the save directory
solute_atom: Name (in the PDB file) of the solute atom type (e.g NA0) with which to perform the solvation analysis
min_coord: Minimum coordination number to consider
max_coord: Maximum coordination number to consider
top_n: Number of snapshots to extract per coordination number.
"""

# Create save directory
os.makedirs(os.path.join(save_dir, system_name, solute_atom), exist_ok=True)

# Initialize MDA Universe
universe = mda.Universe(pdb_file_path)

# Add PBC box
with open(pdb_file_path) as file:
dimension_lines = file.readlines()[1]
a = float(dimension_lines.split()[1])
b = float(dimension_lines.split()[2])
c = float(dimension_lines.split()[3])
universe.dimensions = [a, b, c, 90, 90, 90]

lattices = np.array([a, b, c])[None][None]

# Choose solute atom
solu = universe.select_atoms(f"name {solute_atom}")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I realized that this doesn't yet support multiple atoms in a single solute. I'll get on this.


logging.info("Translating atoms to solute center of mass")
for ts in tqdm(universe.trajectory):
ts.dimensions = universe.dimensions
solu_center = solu.center_of_mass(wrap=True)
dim = ts.triclinic_dimensions
box_center = np.sum(dim, axis=0) / 2
universe.atoms.translate(box_center - solu_center)

universe.atoms.unwrap()

solvent = universe.atoms - solu

solv_anal = Solute.from_atoms(solu, {"solvent": solvent}, solute_name=solute_atom)

# Identify the cutoff for the first solvation shell, based on the MD trajectory
logging.info("Running solvation analysis")
solv_anal.run()

# Plot the RDF
solv_anal.plot_solvation_radius("solute", "solvent")
plt.savefig(os.path.join(save_dir, system_name, solute_atom, "solvation_rdf.png"))
jeevster marked this conversation as resolved.
Show resolved Hide resolved

# There's probably a much faster way to do this
# But for now, we're prototyping, so slow is okay
shells = dict()
for j in solv_anal.speciation.speciation_fraction["solvent"]:
shells[j] = solv_anal.speciation.get_shells({"solvent": j})

# Now let's try getting the most diverse structures for each particular coordination number
# This is also a bit slow, particularly for the more common and/or larger solvent shells
for c in range(min_coord, max_coord + 1):
logging.info(f"Processing shells with coordination number {c}")
os.makedirs(
os.path.join(save_dir, system_name, solute_atom, f"coord={c}"),
exist_ok=True,
)
shell_species = []
shell_positions = []
for index, _ in tqdm(shells[c].iterrows()):
ts = universe.trajectory[index[0]]
universe.atoms.unwrap()
shell = solv_anal.solvation_data.xs(
(index[0], index[1]), level=(FRAME, SOLUTE_IX)
)
shell = solv_anal._df_to_atom_group(shell, solute_index=index[1])
shell = shell.copy()
if len(shell.atoms.elements) > len(shell_species):
shell_species = shell.atoms.elements

shell_positions.append(wrap_positions(shell.atoms.positions, lattices))
jeevster marked this conversation as resolved.
Show resolved Hide resolved

by_num_atoms = defaultdict(list)
for sps in shell_positions:
by_num_atoms[len(sps)].append(sps)

# filter by number of atoms per shell
selections_by_num_atoms = {
num_atoms: filter_by_rmse(shells_with_num_atoms, top_n)
for num_atoms, shells_with_num_atoms in by_num_atoms.items()
}

for (
shell_size,
shell_positions,
) in selections_by_num_atoms.items(): # loop over sizes
for idx, shell_pos in enumerate(shell_positions):
if shell_pos.shape[0] == shell_species.shape[0]:

# Save shell as xyz file
mol = Molecule(shell_species, shell_pos, charge=-1)
jeevster marked this conversation as resolved.
Show resolved Hide resolved
mol.to(
os.path.join(
save_dir,
system_name,
solute_atom,
f"coord={c}",
f"size{shell_size}_selection{idx}.xyz",
),
"xyz",
)


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--pdb_file_path", type=str, help="PDB trajectory file path")
parser.add_argument("--save_dir", type=str, help="Path to save xyz files")
parser.add_argument(
"--system_name", type=str, help="Name of system used for directory naming"
)
parser.add_argument(
"--solute_atom",
type=str,
help="Which solute atom to extract solvation shells for",
)
parser.add_argument(
"--min_coord", type=int, help="Minimum shell coordination number to extract"
)
parser.add_argument(
"--max_coord", type=int, help="Maximum shell coordination number to extract"
)
parser.add_argument(
"--top_n",
type=int,
default=20,
help="Number of most diverse shells to extract per coordination number",
)

args = parser.parse_args()

extract_solvation_shells(
args.pdb_file_path,
args.save_dir,
args.system_name,
args.solute_atom,
args.min_coord,
args.max_coord,
args.top_n,
)
103 changes: 103 additions & 0 deletions electrolytes/solvation_shell_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import copy
import itertools
import numpy as np


def rmse(a, b):
"""
Compute the root mean squared error between two sets of pairwise displacements
Args:
a: numpy array of pairwise displacements, shape [N_atoms, N_atoms, 3]
b: numpy array of pairwise displacements, shape [N_atoms, N_atoms, 3]
"""
return np.sqrt(np.mean(np.sum((b - a) ** 2, axis=2)))


def filter_by_rmse(coords, n=20):
"""
From a set of coordinates, determine the n most diverse, where "most diverse" means "most different, in terms of minimum RMSE.
We operate on pairwise distances so that the function is invariant to translation and rotation.
Note: The Max-Min Diversity Problem is in general NP-hard. This algorithm generates a candidate solution to MMDP for these coords
by assuming that the point 0 is actually in the MMDP set (which there's no reason a priori to assume). As a result, if we shuffled the order of coords, we would likely get a different result.
jeevster marked this conversation as resolved.
Show resolved Hide resolved

Args:
coords: list of np.ndarrays of atom coordinates. Must all have the same shape ([N_atoms, 3]), and must all reflect the same atom order!
Note that this latter requirement shouldn't be a problem, specifically when dealing with IonSolvR data.
n: number of most diverse coordinates to return
"""
pairwise_disps = [
coord[np.newaxis, :, :] - coord[:, np.newaxis, :] for coord in coords
]
states = {0}
min_rmsds = np.array(
[rmse(pairwise_disps[0], pairwise_disp) for pairwise_disp in pairwise_disps]
)
for _ in range(n - 1):
best = np.argmax(min_rmsds)
min_rmsds = np.minimum(
min_rmsds,
np.array(
[
rmse(pairwise_disps[best], pairwise_disp)
for pairwise_disp in pairwise_disps
]
),
)
states.add(best)

return [coords[i] for i in states]


def wrap_positions(positions, lattices):
jeevster marked this conversation as resolved.
Show resolved Hide resolved
"""
Wraps input positions based on periodic boundary conditions.
Args:
positions: numpy array of positions, shape [N_atoms, 3]
lattices: numpy array representing dimensions of simulation box, shape [1, 1, 3]
"""
displacements = positions[:, np.newaxis, :] - positions[np.newaxis, :, :]
idx = np.where(displacements > lattices / 2)[0]
dim = np.where(displacements > lattices / 2)[2]
if idx.shape[0] > 0:
positions[idx, dim] -= lattices[0, 0, dim]
return positions


def reorient(box_dimensions, coords, nsolute, solute_natoms, solvent_natoms):
"""
This function is not currently used in the pdb-file based solvation analysis
TODO: remove once Evan clarifies
jeevster marked this conversation as resolved.
Show resolved Hide resolved
"""

transforms = [
np.array(x) * box_dimensions
for x in itertools.product([0, -1, 1], [0, -1, 1], [0, -1, 1])
]

cog_solu = np.mean(coords[:solute_natoms], axis=0)

n_solvents = int((len(coords) - nsolute * solute_natoms) / solvent_natoms)

final_box = np.zeros(coords.shape)
final_box[: nsolute * solute_natoms] = coords[: nsolute * solute_natoms]

for i in range(n_solvents):
min_dist = np.inf
best_coords = np.zeros((solvent_natoms, 3))

start_index = nsolute * solute_natoms + i * solvent_natoms
coords_i = coords[start_index : start_index + solvent_natoms]
assert len(coords_i) == solvent_natoms

for transform in transforms:
coords_copy = copy.deepcopy(coords_i)
for i in range(solvent_natoms):
coords_copy[i] += transform

cog_solv = np.mean(coords_copy, axis=0)
dist = np.linalg.norm(cog_solv - cog_solu)
if dist < min_dist:
min_dist = dist
best_coords = coords_copy
final_box[start_index : start_index + solvent_natoms] = best_coords
return final_box