diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..e8f121c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +.pdb filter=lfs diff=lfs merge=lfs -text +*.pdb filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index b61ebc3..0eea8d6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ checkpoints results logs *.traj +*.pdb experimental # Byte-compiled / optimized / DLL files @@ -111,3 +112,21 @@ Local # VS Code .vscode/ + +# testfiles +testfiles/ +testfiles_old/ +*mae + +# Schrodinger stuff +maestro_package/ + +# Jupyter Notebook +.ipynb_checkpoints/ +*ipynb + +# Misc +notes.txt +*prof + + diff --git a/electrolytes/.gitattributes b/electrolytes/.gitattributes new file mode 100644 index 0000000..515f1b0 --- /dev/null +++ b/electrolytes/.gitattributes @@ -0,0 +1 @@ +*pdb filter=lfs diff=lfs merge=lfs -text diff --git a/electrolytes/README.md b/electrolytes/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/electrolytes/run_extraction.sh b/electrolytes/run_extraction.sh new file mode 100755 index 0000000..28cb49d --- /dev/null +++ b/electrolytes/run_extraction.sh @@ -0,0 +1,8 @@ +#!/bin/bash +#TODO: can we automatically extract the names of all the solute atoms from the PDB file so we don't have to re-run this command for each solute? 
# Limits that used to be hard-coded inside extract_solvation_shells; kept as
# defaults so existing callers see exactly the same behaviour.
_DEFAULT_MAX_FRAMES = 100
_DEFAULT_MAX_CANDIDATE_SHELLS = 1000


def _split_species_by_role(metadata: dict) -> "tuple[dict, dict]":
    """Return ``(solutes, solvents)``, each mapping species name -> residue name."""
    solutes = {}
    solvents = {}
    for res, species, spec_type in zip(
        metadata["residue"], metadata["species"], metadata["solute_or_solvent"]
    ):
        if spec_type == "solute":
            solutes[species] = res
        elif spec_type == "solvent":
            solvents[species] = res
    return solutes, solvents


def _shell_filename(spec_type: str, group_idx: int, shell_idx: int, charge: int) -> str:
    """Build the xyz file name for one saved shell (solute names carry the group)."""
    if spec_type == "solute":
        return f"group_{group_idx}_shell_{shell_idx}_{charge}.xyz"
    return f"shell_{shell_idx}_{charge}.xyz"


def extract_solvation_shells(
    input_dir: str,
    save_dir: str,
    system_name: str,
    solute_radii: list[float],
    skip_solvent_centered_shells: bool,
    solvent_radii: list[float],
    shells_per_frame: int,
    max_shell_size: int,
    top_n: int,
    max_frames: "int | None" = _DEFAULT_MAX_FRAMES,
    max_candidate_shells: "int | None" = _DEFAULT_MAX_CANDIDATE_SHELLS,
):
    """
    Given a MD trajectory in a PDB file, extract solvation shells around every
    solute species (and, optionally, every solvent species), group them by
    topology and save the most diverse members of each group as xyz files.

    Files are written to ``<save_dir>/<system_name>/<species>/radius_<radius>/``.

    Args:
        input_dir: Directory containing 1) the PDB trajectory (system_output.pdb)
            and 2) the metadata file (metadata_system.json).
        save_dir: Directory in which to save extracted solvation shells.
        system_name: Name of the system - used for naming the save directory.
        solute_radii: List of shell radii to extract around solutes.
        skip_solvent_centered_shells: Skip extracting solvent-centered shells.
        solvent_radii: List of shell radii to extract around solvents.
        shells_per_frame: Number of solutes or solvents per MD simulation frame
            from which to extract candidate shells (<= 0 means all of them).
        max_shell_size: Maximum size (in atoms) of saved shells.
        top_n: Number of snapshots to extract per topology.
        max_frames: Only the first ``max_frames`` frames of the trajectory are
            read; ``None`` reads the whole trajectory.
        max_candidate_shells: Size of the random subset of candidate shells that
            is grouped and filtered; ``None`` keeps every candidate.

    Raises:
        ValueError: if the metadata is invalid or the number of partial charges
            does not match the number of atoms in a frame.
    """
    # Function-scope import so the module's import block does not need to change.
    from itertools import islice

    logging.info("Reading structure and metadata files")

    with open(os.path.join(input_dir, "metadata_system.json")) as f:
        metadata = json.load(f)
    validate_metadata_file(metadata)

    partial_charges = np.array(metadata["partial_charges"])
    solutes, solvents = _split_species_by_role(metadata)
    spec_dicts = {"solute": solutes, "solvent": solvents}
    # Each species type is extracted with its own list of radii.
    radii_by_type = {"solute": solute_radii, "solvent": solvent_radii}
    solute_resnames = set(solutes.values())
    solvent_resnames = set(solvents.values())

    # Stop reading as soon as we have enough frames instead of loading the
    # whole trajectory and slicing afterwards.
    trajectory = StructureReader(os.path.join(input_dir, "system_output.pdb"))
    structures = list(islice(trajectory, max_frames))

    logging.info("Assigning partial charges to atoms")
    for st in tqdm(structures):
        if st.atom_total != len(partial_charges):
            # zip() would silently truncate and leave atoms with stale charges,
            # which ends up in the charge encoded in the output file names.
            raise ValueError(
                f"Metadata lists {len(partial_charges)} partial charges but the "
                f"structure has {st.atom_total} atoms"
            )
        for at, charge in zip(st.atom, partial_charges):
            at.partial_charge = charge

    # For each species: extract shells of some heuristic radii, bin them by
    # composition/topology and choose the N most diverse in each bin.
    spec_types = ["solute"]
    if not skip_solvent_centered_shells:
        spec_types.append("solvent")

    for spec_type in spec_types:
        for species, residue in spec_dicts[spec_type].items():
            logging.info(f"Extracting solvation shells around {species}")
            for radius in radii_by_type[spec_type]:
                logging.info(f"Radius = {radius} A")
                extracted_shells = []
                for st in tqdm(structures):  # loop over timesteps
                    extracted_shells.extend(
                        extract_residue_from_structure(
                            st,
                            radius,
                            residue,
                            spec_type,
                            solute_resnames,
                            solvent_resnames,
                            shells_per_frame,
                            max_shell_size,
                        )
                    )

                if not extracted_shells:
                    # Nothing usable at this radius (e.g. no solute-free shell
                    # around a solvent): report it and carry on with the rest.
                    logging.warning(
                        f"No shells found around {species} ({spec_type}) "
                        f"at radius {radius}; skipping"
                    )
                    continue

                # Choose a random subset of shells
                random.shuffle(extracted_shells)
                if max_candidate_shells is not None:
                    extracted_shells = extracted_shells[:max_candidate_shells]

                grouped_shells = group_shells(extracted_shells, spec_type)

                # Ensure that topologically related atoms are equivalently
                # numbered (up to molecular symmetry)
                grouped_shells = [
                    renumber_molecules_to_match(items) for items in grouped_shells
                ]

                logging.info(
                    f"Extracting top {top_n} most diverse shells from each group"
                )
                final_shells = []
                for group_idx, shell_group in tqdm(
                    enumerate(grouped_shells), total=len(grouped_shells)
                ):
                    final_shells.extend(
                        (group_idx, shell)
                        for shell in filter_by_rmsd(shell_group, n=top_n)
                    )

                logging.info("Saving final shells")
                save_path = os.path.join(
                    save_dir, system_name, species, f"radius_{radius}"
                )
                os.makedirs(save_path, exist_ok=True)
                for i, (group_idx, shell) in enumerate(final_shells):
                    charge = get_structure_charge(shell)
                    fname = os.path.join(
                        save_path, _shell_filename(spec_type, group_idx, i, charge)
                    )
                    shell.write(fname)


def extract_residue_from_structure(
    st: "Structure",
    radius: float,
    residue: str,
    spec_type: str,
    solute_resnames: list[str],
    solvent_resnames: list[str],
    shells_per_frame: int,
    max_shell_size: int,
) -> "list[Structure]":
    """
    Extract shells around molecules of a given residue type from a structure

    :param st: Structure to extract from
    :param radius: distance (in Angstrom) around residue to expand
                   (initially in the case of solutes)
    :param residue: name of residue to consider
    :param spec_type: type of species being extracted, either 'solute' or 'solvent'
    :param solute_resnames: names of solute residues
    :param solvent_resnames: names of solvent residues
    :param shells_per_frame: number of central molecules to sample from this
                             structure; values <= 0 use all of them
    :param max_shell_size: maximum number of atoms in a shell
    :return: list of extracted shell structures (possibly empty)
    :raises ValueError: if `spec_type` is neither 'solute' nor 'solvent'
    """
    if spec_type not in ("solute", "solvent"):
        raise ValueError(f"Unknown species type: {spec_type!r}")

    # extract all molecules of interest
    molecules = [res for res in st.residue if res.pdbres.strip() == residue]

    # Subsample a random set of k molecules; never ask for more than exist,
    # random.sample raises ValueError in that case.
    if shells_per_frame > 0:
        molecules = random.sample(molecules, min(shells_per_frame, len(molecules)))

    central_mol_nums = list({mol.molecule_number for mol in molecules})
    if not central_mol_nums:
        return []

    # Candidate shells: whole residues within `radius` of each central molecule
    shells = [
        set(analyze.evaluate_asl(st, f"fillres within {radius} mol {mol_num}"))
        for mol_num in central_mol_nums
    ]

    if spec_type == "solvent":
        # Only keep the shells that have no solute atoms and are below the
        # maximum size. Without solutes there is nothing to exclude.
        solute_atoms = set()
        if solute_resnames:
            solute_atoms = set(
                analyze.evaluate_asl(st, f'res {",".join(solute_resnames)}')
            )
        return [
            extract_contracted_shell(st, shell, central_mol)
            for shell, central_mol in zip(shells, central_mol_nums)
            if shell.isdisjoint(solute_atoms) and len(shell) <= max_shell_size
        ]

    extracted_shells = []
    for shell_ats, central_solute in zip(shells, central_mol_nums):
        if len(shell_ats) > max_shell_size:
            # Expansion can only grow a shell, so an oversized initial shell
            # can never satisfy the size limit: drop it, as for solvents.
            continue
        expanded_shell_ats = shell_ats
        # If we have a solvent-free system, don't expand shells around solutes,
        # because we'll always have solutes and will never terminate
        if solvent_resnames:
            # TODO: how to choose the mean/scale for sampling?
            # Should the mean be set to the number of atoms in the non-expanded shell?
            upper_bound = max(len(shell_ats), generate_lognormal_samples()[0])
            upper_bound = min(upper_bound, max_shell_size)
            expanded_shell_ats = expand_shell(
                st,
                shell_ats,
                central_solute,
                radius,
                solute_resnames,
                max_shell_size=upper_bound,
            )
        extracted_shells.append(
            extract_contracted_shell(st, expanded_shell_ats, central_solute)
        )
    return extracted_shells


def extract_contracted_shell(
    st: "Structure", at_list: list[int], central_mol: int
) -> "Structure":
    """
    Extract the shell from the structure

    :param st: structure to extract from
    :param at_list: atom indices that specify the shell
    :param central_mol: index of central molecule around which we should contract
                        with respect to PBC to get a non-PBC valid structure
    :return: extracted shell
    :raises ValueError: if the first atom of the central molecule is not part
                        of `at_list`
    """
    # Tag the first atom of the central molecule so it can be found again after
    # extraction renumbers the atoms.
    central_at = st.molecule[central_mol].atom[1]
    central_at.property["b_m_central"] = True
    try:
        extracted_shell = st.extract(at_list, copy_props=True)
    finally:
        # Never leave the temporary tag behind on the source structure.
        central_at.property.pop("b_m_central", None)

    central_atom_idx = None
    for at in extracted_shell.atom:
        if at.property.pop("b_m_central", False):
            central_atom_idx = at.index
            break
    if central_atom_idx is None:
        raise ValueError(
            f"Central molecule {central_mol} is not part of the extracted shell"
        )

    # contract everything to be centered on our molecule of interest
    # (this will also handle if a molecule is split across a PBC)
    clusterstruct.contract_structure(
        extracted_shell, contract_on_atoms=[central_atom_idx]
    )
    return extracted_shell


def group_shells(
    shell_list: "list[Structure]", spec_type: str
) -> "list[list[Structure]]":
    """
    Partition shells by conformers.

    This checks the topological similarity of the shells to group them. For solvents,
    we don't check this topology explicitly but assume it holds if the molecules are
    at least isomers of each other. Revise if we have solvent mixtures where the
    components are isomers

    :param shell_list: list of structures to be partitioned
    :param spec_type: type of species being grouped, either 'solute' or 'solvent'
    :return: members of `shell_list` grouped by conformers, all members of a given
             sublist are conformers (an empty input gives an empty list)
    """
    if not shell_list:
        return []

    logging.info("Grouping solvation shells into conformers")
    # TODO: speed this up
    grouped_shells = group_with_comparison(shell_list, are_isomeric_molecules)
    logging.info("Grouped into isomers")

    if spec_type == "solute":
        # Solute shells additionally get their connectivity compared.
        new_grouped_shells = []
        for isomer_group in tqdm(grouped_shells):
            new_grouped_shells.extend(groupby_molecules_are_conformers(isomer_group))
        grouped_shells = new_grouped_shells
    return grouped_shells


def get_structure_charge(st: "Structure") -> int:
    """
    Get the charge on the structure as the sum of the partial charges
    of the atoms

    :param st: Structure to get charge of
    :return: charge on structure, rounded to the nearest integer
    """
    charge = sum(at.partial_charge for at in st.atom)
    # float() first so numpy scalars also end up as a built-in int
    return int(round(float(charge)))


def are_isomeric_molecules(st1: "Structure", st2: "Structure") -> bool:
    """
    Determine if two structures have molecules which are isomers of each other.

    This is stronger than just ensuring that the structures are isomers and should
    be sufficient for cases of just solvents as there are no expected topological
    differences.

    The per-molecule element counts are compared as a multiset, so the number of
    molecules of each formula has to agree as well.
    """

    def molecular_formulas(st):
        return Counter(
            frozenset(Counter(at.atomic_number for at in mol.atom).items())
            for mol in st.molecule
        )

    # Cheap checks first
    if st1.atom_total != st2.atom_total or st1.mol_total != st2.mol_total:
        return False
    if Counter(at.atomic_number for at in st1.atom) != Counter(
        at.atomic_number for at in st2.atom
    ):
        return False
    return molecular_formulas(st1) == molecular_formulas(st2)


def groupby_molecules_are_conformers(
    st_list: "list[Structure]",
) -> "list[list[Structure]]":
    """
    Given a list of Structures which are assumed to have isomeric molecules,
    partition the structures by conformers.

    Each structure is represented by a set of ``(count, representative molecule)``
    pairs; two structures land in the same group when these pairs can be matched
    up with equal counts and conformeric representatives.

    :param st_list: structures to partition
    :return: list of groups of structures
    """

    def are_same_group_counts(group1, group2):
        # multiset comparison: {1, 1, 3} must not match {1, 3, 3}
        return Counter(count for count, _ in group1) == Counter(
            count for count, _ in group2
        )

    def are_groups_conformers(group1, group2):
        if len(group1) != len(group2):
            return False
        return all(
            any(
                count1 == count2 and are_conformers(st1, st2)
                for count2, st2 in group2
            )
            for count1, st1 in group1
        )

    if not st_list:
        return []
    if len(st_list) == 1:
        return [st_list]

    mol_to_st = {}
    # split structures up into lists of molecules
    for st in st_list:
        mol_list = [mol.extractStructure() for mol in st.molecule]
        mol_list.sort(key=lambda x: x.atom_total)
        # group those molecules by conformers
        grouped_mol_list = group_with_comparison(mol_list, are_conformers)
        # represent the structure as the counts of each type of molecule
        # and a representative structure
        st_mols = frozenset((len(grp), grp[0]) for grp in grouped_mol_list)
        mol_to_st[st_mols] = st

    # Group structures by if their counts of molecules are the same
    grouped_by_molecular_counts = group_with_comparison(
        list(mol_to_st), are_same_group_counts
    )
    conf_groups = []
    # Group structures by if their molecules (and their counts) are conformers
    for groups in grouped_by_molecular_counts:
        conf_groups.extend(group_with_comparison(groups, are_groups_conformers))
    return [[mol_to_st[grp] for grp in cgroup] for cgroup in conf_groups]


def _parse_radius(text: str):
    """argparse type for a shell radius: a positive, finite number.

    Whole numbers are returned as ``int`` so that the default output directory
    name (``radius_3``) does not change.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid radius: {text!r}") from None
    if not 0 < value < float("inf"):
        raise argparse.ArgumentTypeError(f"radius must be positive: {text!r}")
    return int(value) if value.is_integer() else value


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the shell extraction script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=10, help="Random seed")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Path containing the PDB trajectory (system_output.pdb) "
        "and metadata file (metadata_system.json)",
    )
    parser.add_argument(
        "--save_dir", type=str, required=True, help="Path to save xyz files"
    )
    parser.add_argument(
        "--system_name",
        type=str,
        required=True,
        help="Name of system used for directory naming",
    )
    # type=list would split "3.5" into ['3', '.', '5']; take one or more numbers.
    parser.add_argument(
        "--solute_radii",
        type=_parse_radius,
        nargs="+",
        default=[3],
        help="List of shell radii to extract around solutes",
    )
    parser.add_argument(
        "--skip_solvent_centered_shells",
        action="store_true",
        help="Skip extracting solvent-centered shells",
    )
    parser.add_argument(
        "--solvent_radii",
        type=_parse_radius,
        nargs="+",
        default=[3],
        help="List of shell radii to extract around solvents",
    )
    parser.add_argument(
        "--shells_per_frame",
        type=int,
        default=-1,
        help="Number of solutes or solvents per MD simulation frame from which to extract candidate shells",
    )
    parser.add_argument(
        "--max_shell_size",
        type=int,
        default=200,
        help="Maximum size (in atoms) of the saved shells",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=20,
        help="Number of most diverse shells to extract per topology",
    )
    return parser


def main(argv=None):
    """Parse the command line, seed the random generators and run the extraction."""
    args = build_arg_parser().parse_args(argv)

    # basicConfig does nothing once the root logger has handlers (the module
    # already configures logging on import), so set the level explicitly too.
    level = os.environ.get("LOGLEVEL", "INFO")
    logging.basicConfig(level=level)
    logging.getLogger().setLevel(level)

    # Shell expansion draws from numpy's global generator, subsampling from
    # `random`: seed both so that runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    extract_solvation_shells(
        args.input_dir,
        args.save_dir,
        args.system_name,
        args.solute_radii,
        args.skip_solvent_centered_shells,
        args.solvent_radii,
        args.shells_per_frame,
        args.max_shell_size,
        args.top_n,
    )


if __name__ == "__main__":
    main()
def generate_lognormal_samples(loc=75, sigma=0.45, size=1):
    """
    Generate random samples from a lognormal distribution.

    Samples are drawn from numpy's global random generator, so seed it with
    ``np.random.seed`` for reproducible results.

    Parameters:
    - loc: float, median of the distribution (the underlying normal distribution
      has mean ``log(loc)``); must be positive
    - sigma: float, standard deviation of the log of the distribution
    - size: int, number of samples to generate (default is 1)

    Returns:
    - samples: numpy array, random samples from the lognormal distribution

    Raises:
    - ValueError: if `loc` is not positive
    """
    if loc <= 0:
        # log(loc) would be NaN/-inf and silently poison every sample
        raise ValueError(f"loc must be positive, got {loc}")
    return np.random.lognormal(mean=np.log(loc), sigma=sigma, size=size)


def expand_shell(
    st: "Structure",
    shell_ats: set[int],
    central_solute: int,
    radius: float,
    solute_res_names: list[str],
    max_shell_size: int = 200,
) -> set[int]:
    """
    Expands a solvation shell. If there are any (non-central) solutes present in the shell,
    repeatedly include shells around those solutes.
    First, gets the molecule numbers of solute molecules that are within the radius
    and not already expanded around. Then, continuously expand around them as long as
    we don't hit an atom limit. An expansion step that would exceed the limit is
    discarded as a whole.

    Args:
        st: Entire structure from the PDB file
        shell_ats: Set of atom indices (of `st`) in a shell (1-indexed)
        central_solute: Molecule index (of `st`) of the central solute in the shell
        radius: Solvation radius (Angstroms) to consider
        solute_res_names: Residue names that correspond to solute atoms in the simulation
        max_shell_size: Maximum size (in atoms) of the expanded shell
    Returns:
        Set of atom indices (of `st`) of the expanded shell (1-indexed)
    """
    solute_names = set(solute_res_names)
    solutes_included = {central_solute}
    shell_ats = set(shell_ats)

    def find_new_solutes(atom_indices):
        """Molecule numbers of solutes among `atom_indices` not yet expanded around."""
        found = set()
        for idx in atom_indices:
            atom = st.atom[idx]
            mol_num = atom.molecule_number
            if mol_num in solutes_included or mol_num in found:
                continue
            if atom.getResidue().pdbres.strip() in solute_names:
                found.add(mol_num)
        return found

    new_solutes = find_new_solutes(shell_ats)
    while new_solutes:
        # add entire residues within solvation shell radius of any extra solute atoms
        grown_shell_ats = shell_ats.union(
            analyze.evaluate_asl(
                st,
                f'fillres within {radius} mol {",".join([str(i) for i in new_solutes])}',
            )
        )
        if len(grown_shell_ats) > max_shell_size:
            break
        # Every solute among the old atoms has been expanded around already,
        # so only the atoms added in this step can reveal new solutes.
        added_ats = grown_shell_ats - shell_ats
        shell_ats = grown_shell_ats
        solutes_included.update(new_solutes)
        new_solutes = find_new_solutes(added_ats)

    return shell_ats


def filter_by_rmsd(
    shells: "list[Structure]", n: int = 20, distance_fn=None
) -> "list[Structure]":
    """
    From a set of shell coordinates, determine the n most diverse shells, where "most diverse"
    means "most different, in terms of minimum RMSD".
    Note: The Max-Min Diversity Problem (MMDP) is in general NP-hard. This algorithm generates
    a candidate solution to MMDP for these coords by assuming that the random seed point is
    actually in the MMDP set (which there's no reason a priori to assume). As a result, if we
    ran this function multiple times, we would get different results.

    Args:
        shells: List of Schrodinger structure objects containing solvation shells
        n: number of most diverse shells to return
        distance_fn: optional callable ``(shell_a, shell_b) -> float`` used instead
            of `rmsd_wrapper` to measure how different two shells are
    Returns:
        List of at most n shells, in the order in which they were selected. If there
        are no more than n shells, all of them are returned without computing any
        distances.
    """
    if distance_fn is None:
        distance_fn = rmsd_wrapper

    shells = list(shells)
    if n < 1 or not shells:
        return []
    if len(shells) <= n:
        return shells

    def distances_from(idx):
        return np.array(
            [distance_fn(shells[idx], shell) for shell in shells], dtype=float
        )

    seed_point = random.randint(0, len(shells) - 1)
    selected = [seed_point]
    # min_dists[i]: distance from shell i to the closest selected shell;
    # selected shells are masked with -inf so they cannot be picked again
    min_dists = distances_from(seed_point)
    min_dists[seed_point] = -np.inf
    while len(selected) < n:
        best = int(np.argmax(min_dists))
        selected.append(best)
        min_dists = np.minimum(min_dists, distances_from(best))
        min_dists[best] = -np.inf
    return [shells[i] for i in selected]


def rmsd_wrapper(st1: "Structure", st2: "Structure") -> float:
    """
    Wrapper around Schrodinger's RMSD calculation function.

    `st2` is superimposed as a copy, so neither input is modified.

    :raises ValueError: if the structures differ in their number of atoms
    """
    if st1.atom_total != st2.atom_total:
        raise ValueError(
            "Structures must have the same number of atoms for RMSD calculation"
        )
    if st1 is st2 or st1 == st2:
        return 0.0
    at_list = list(range(1, st1.atom_total + 1))
    return rmsd.superimpose(st1, at_list, st2.copy(), at_list, use_symmetry=True)


def renumber_molecules_to_match(mol_list):
    """
    Ensure that topologically equivalent sites are equivalently numbered

    The first structure is used as the reference; every other structure is
    reordered to match it. An empty input gives an empty list.
    """
    mol_list = list(mol_list)
    if not mol_list:
        return []
    reference = mol_list[0]
    mapper = ConnectivityAtomMapper(use_chirality=False)
    atlist = range(1, reference.atom_total + 1)
    renumbered_mols = [reference]
    for mol in mol_list[1:]:
        _, r_mol = mapper.reorder_structures(reference, atlist, mol, atlist)
        renumbered_mols.append(r_mol)
    return renumbered_mols


def validate_metadata_file(metadata: dict):
    """
    Validates the metadata file to ensure that it contains the necessary fields
    and that the per-species fields describe the same number of species.

    Args:
        metadata: Dictionary containing metadata for the system.
    Raises:
        ValueError: if a required field is missing or the per-species lists
            ("residue", "species", "solute_or_solvent") differ in length.
    """
    required_fields = [
        "residue",
        "species",
        "solute_or_solvent",
        "partial_charges",
    ]
    for field in required_fields:
        if field not in metadata:
            raise ValueError(f"Metadata file is missing required field: {field}")

    # These three lists are zipped together by the consumer; a length mismatch
    # would silently drop species.
    per_species_fields = ["residue", "species", "solute_or_solvent"]
    lengths = {field: len(metadata[field]) for field in per_species_fields}
    if len(set(lengths.values())) > 1:
        raise ValueError(
            f"Metadata fields must have the same number of entries: {lengths}"
        )