diff --git a/mofa/assemble.py b/mofa/assemble.py
index 0ab61eac..25e7bf33 100644
--- a/mofa/assemble.py
+++ b/mofa/assemble.py
@@ -1,19 +1,18 @@
 """Functions for assembling a MOF structure"""
-import ase
+from typing import Sequence
+from .model import NodeDescription, LigandDescription, MOFRecord
 
 
-def assemble_mof(
-        node: object,
-        linker: str,
-        topology: object
-) -> ase.Atoms:
-    """Generate a MOF structure from a recipe
+
+def assemble_mof(nodes: Sequence[NodeDescription], ligands: Sequence[LigandDescription], topology: str) -> MOFRecord:
+    """Generate a new MOF from the description of the nodes, ligands and topology
 
     Args:
-        node: Atomic structure of the nodes
-        linker: SMILES string defining the linker object
-        topology: Description of the network structure
+        nodes: Descriptions of each node
+        ligands: Description of the ligands
+        topology: Name of the topology
+
     Returns:
-        Description of the 3D structure of the MOF
+        A new MOF record
     """
     raise NotImplementedError()
diff --git a/mofa/generator.py b/mofa/generator.py
index 9b46823d..075069bc 100644
--- a/mofa/generator.py
+++ b/mofa/generator.py
@@ -1,9 +1,10 @@
 """Functions pertaining to training and running the generative model"""
 from pathlib import Path
 
+from ase.io import write
 import ase
 
-from mofa.model import MOFRecord
+from mofa.model import MOFRecord, LigandDescription
 
 
 def train_generator(
@@ -25,15 +26,25 @@ def train_generator(
 
 def run_generator(
         model: str | Path,
+        fragment_template: LigandDescription,
         molecule_sizes: list[int],
-        num_samples: int
+        num_samples: int,
+        fragment_spacing: float | None = None,
 ) -> list[ase.Atoms]:
     """
     Args:
         model: Path to the starting weights
-        molecule_sizes: Number of heavy atoms in the linker molecules to generate
+        fragment_template: Template to be filled with linker atoms
+        molecule_sizes: Number of heavy atoms in the linker to generate
         num_samples: Number of samples of molecules to generate
+        fragment_spacing: Starting distance between the fragments
     Returns:
         3D geometries of the generated linkers
     """
+
+    # Create the template input
+    blank_template = fragment_template.generate_template(spacing_distance=fragment_spacing)
+    write('test.sdf', blank_template)  # NOTE(review): always writes 'test.sdf' into the current working directory; looks like a debugging aid - confirm
+
+    # Run the generator
     raise NotImplementedError()
diff --git a/mofa/model.py b/mofa/model.py
index 8e57fea6..4d140bd1 100644
--- a/mofa/model.py
+++ b/mofa/model.py
@@ -47,6 +47,17 @@ def linker_atoms(self) -> list[int]:
         """All atoms which are not part of a fragment"""
         raise NotImplementedError()
 
+    def generate_template(self, spacing_distance: float | None = None) -> ase.Atoms:
+        """Generate a version of the ligand with only the fragments at the end
+
+        Args:
+            spacing_distance: Distance to enforce between the fragments. Set to ``None``
+                to keep the current distance
+        Returns:
+            The template with the desired spacing
+        """
+        raise NotImplementedError()
+
 
 @dataclass
 class MOFRecord:
diff --git a/mofa/simulation/lammps.py b/mofa/simulation/lammps.py
index 5831ec2a..a066e7c9 100644
--- a/mofa/simulation/lammps.py
+++ b/mofa/simulation/lammps.py
@@ -1,6 +1,8 @@
 """Simulation operations that involve LAMMPS"""
 import ase
 
+from mofa.model import MOFRecord
+
 
 class LAMMPSRunner:
     """Interface for running pre-defined LAMMPS workflows
@@ -12,11 +14,11 @@ class LAMMPSRunner:
     def __init__(self, lammps_command: str):
         self.lammps_command: str = lammps_command
 
-    def run_molecular_dynamics(self, mof: ase.Atoms, timesteps: int, report_frequency: int) -> list[ase.Atoms]:
+    def run_molecular_dynamics(self, mof: MOFRecord, timesteps: int, report_frequency: int) -> list[ase.Atoms]:
         """Run a molecular dynamics trajectory
 
         Args:
-            mof: Starting structure
+            mof: Record describing the MOF. Includes the structure in CIF format, which includes the bonding information used by UFF
             timesteps: Number of timesteps to run
             report_frequency: How often to report structures
         Returns:
diff --git a/run_serial_workflow.py b/run_serial_workflow.py
new file mode 100644
index 00000000..d3168605
--- /dev/null
+++ b/run_serial_workflow.py
@@ -0,0 +1,69 @@
+"""An example of the workflow which runs on a single node"""
+import logging
+import sys
+from argparse import ArgumentParser
+
+import torch
+
+from mofa.generator import run_generator
+from mofa.model import MOFRecord
+from mofa.scoring.geometry import MinimumDistance
+from mofa.simulation.lammps import LAMMPSRunner
+
+if __name__ == "__main__":
+    # Make the argument parser
+    parser = ArgumentParser()
+
+    group = parser.add_argument_group(title='MOF Settings', description='Options related to the MOF type being generated')
+    group.add_argument('--mof-template', default=None, help='Path to a MOF we are going to be altering')
+
+    group = parser.add_argument_group(title='Generator Settings', description='Options related to how the generation is performed')
+    group.add_argument('--generator-path', default=None, help='Path to the PyTorch files describing model architecture and weights')
+    group.add_argument('--molecule-sizes', nargs='+', type=int, default=(10, 11, 12), help='Sizes of molecules we should generate')
+    group.add_argument('--num-samples', type=int, default=16, help='Number of molecules to generate at each size')
+
+    args = parser.parse_args()
+
+    # TODO: Make a run directory
+
+    # Turn on logging
+    logger = logging.getLogger('main')
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
+
+    # Load a pretrained generator from disk and use it to create ligands
+    template_mof = MOFRecord.from_file(cif_path=args.mof_template)
+    model = torch.load(args.generator_path)  # NOTE(review): run_generator annotates `model` as `str | Path` (path to weights); confirm a loaded object is intended here
+    generated_ligand_xyzs = run_generator(
+        model,
+        fragment_template=template_mof.ligands[0],
+        molecule_sizes=args.molecule_sizes,
+        num_samples=args.num_samples
+    )
+    logger.info(f'Generated {len(generated_ligand_xyzs)} ligands')
+
+    # Initial quality checks and post-processing on the generated ligands
+    validated_ligand_xyzs = []
+    for generated_xyz in generated_ligand_xyzs:
+        if False:  # TODO (wardlt): Add checks for detecting fragmented molecules, valency checks, ...
+            pass
+        validated_ligand_xyzs.append(add_hydrogens_to_ligand(generated_xyz))  # TODO (wardlt): Add a function which adds H's to the XYZ file
+    logger.info(f'Screened generated ligands. {len(validated_ligand_xyzs)} pass quality checks')
+
+    # Combine them with the template MOF to create new MOFs
+    new_mofs = []
+    for new_ligand in validated_ligand_xyzs:
+        new_mof = template_mof.replace_ligand(new_ligand)
+        new_mofs.append(new_mof)
+    logger.info(f'Generated {len(new_mofs)} new MOFs')
+
+    # Score the MOFs
+    scorer = MinimumDistance()  # TODO (wardlt): Add or replace with a CGCNN that predicts absorption
+    scores = [scorer.score_mof(new_mof) for new_mof in new_mofs]
+    logger.info(f'Scored all {len(new_mofs)} MOFs')
+
+    # Run LAMMPS on the top MOF
+    ranked_mofs: list[tuple[float, MOFRecord]] = sorted(zip(scores, new_mofs), key=lambda pair: pair[0])  # Sort on the score only so tied scores never compare MOFRecord objects
+    LAMMPSRunner('lmp_serial').run_molecular_dynamics(ranked_mofs[-1][1], 100, 1)