Skip to content

Commit

Permalink
Add a starting point for workflow script (#19)
Browse files Browse the repository at this point in the history
* Clarified a few points of the data models

* De-duplicate the assemble function

Whoops, forgot I had a whole 'assemble' module

* Change the interface to pass the whole MOF record

We need the CIF, which holds the bonding structure

* Add a non-functional, but nearly-done workflow script

* Flake8 fix
  • Loading branch information
WardLT authored Oct 3, 2023
1 parent 3d1b620 commit 8893488
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 16 deletions.
21 changes: 10 additions & 11 deletions mofa/assemble.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
"""Functions for assembling a MOF structure"""
import ase
from typing import Sequence

from .model import NodeDescription, LigandDescription, MOFRecord

def assemble_mof(
node: object,
linker: str,
topology: object
) -> ase.Atoms:
"""Generate a MOF structure from a recipe

def assemble_mof(nodes: Sequence[NodeDescription], ligands: Sequence[LigandDescription], topology: str) -> MOFRecord:
"""Generate a new MOF from the description of the nodes, ligands and topology
Args:
node: Atomic structure of the nodes
linker: SMILES string defining the linker object
topology: Description of the network structure
nodes: Descriptions of each node
ligands: Description of the ligands
topology: Name of the topology
Returns:
Description of the 3D structure of the MOF
A new MOF record
"""
raise NotImplementedError()
17 changes: 14 additions & 3 deletions mofa/generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Functions pertaining to training and running the generative model"""
from pathlib import Path

from ase.io import write
import ase

from mofa.model import MOFRecord
from mofa.model import MOFRecord, LigandDescription


def train_generator(
Expand All @@ -25,15 +26,25 @@ def train_generator(

def run_generator(
model: str | Path,
fragment_template: LigandDescription,
molecule_sizes: list[int],
num_samples: int
num_samples: int,
fragment_spacing: float | None = None,
) -> list[ase.Atoms]:
"""
Args:
model: Path to the starting weights
molecule_sizes: Number of heavy atoms in the linker molecules to generate
fragment_template: Template to be filled with linker atoms
molecule_sizes: Number of heavy atoms in the linker to generate
num_samples: Number of samples of molecules to generate
fragment_spacing: Starting distance between the fragments
Returns:
3D geometries of the generated linkers
"""

# Create the template input
blank_template = fragment_template.generate_template(spacing_distance=fragment_spacing)
write('test.sdf', blank_template)

# Run the generator
raise NotImplementedError()
11 changes: 11 additions & 0 deletions mofa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def linker_atoms(self) -> list[int]:
"""All atoms which are not part of a fragment"""
raise NotImplementedError()

def generate_template(self, spacing_distance: float | None = None) -> ase.Atoms:
    """Build a version of this ligand that retains only the fragments at its ends

    Args:
        spacing_distance: Separation to enforce between the fragments.
            Pass ``None`` to leave the current separation untouched
    Returns:
        The template with the requested spacing
    Raises:
        NotImplementedError: Always, because this method is still a stub
    """
    raise NotImplementedError()


@dataclass
class MOFRecord:
Expand Down
6 changes: 4 additions & 2 deletions mofa/simulation/lammps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Simulation operations that involve LAMMPS"""
import ase

from mofa.model import MOFRecord


class LAMMPSRunner:
"""Interface for running pre-defined LAMMPS workflows
Expand All @@ -12,11 +14,11 @@ class LAMMPSRunner:
def __init__(self, lammps_command: str):
self.lammps_command: str = lammps_command

def run_molecular_dynamics(self, mof: ase.Atoms, timesteps: int, report_frequency: int) -> list[ase.Atoms]:
def run_molecular_dynamics(self, mof: MOFRecord, timesteps: int, report_frequency: int) -> list[ase.Atoms]:
"""Run a molecular dynamics trajectory
Args:
mof: Starting structure
mof: Record describing the MOF. Includes the structure in CIF format, which includes the bonding information used by UFF
timesteps: Number of timesteps to run
report_frequency: How often to report structures
Returns:
Expand Down
69 changes: 69 additions & 0 deletions run_serial_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""An example of the workflow which runs on a single node"""
import logging
import sys
from argparse import ArgumentParser

import torch

from mofa.generator import run_generator
from mofa.model import MOFRecord
from mofa.scoring.geometry import MinimumDistance
from mofa.simulation.lammps import LAMMPSRunner

if __name__ == "__main__":
    # Serial, single-node example of the workflow:
    #   1. generate ligands from a template MOF with a pretrained generator,
    #   2. post-process the ligands and swap them into the template MOF,
    #   3. score the resulting MOFs,
    #   4. run LAMMPS molecular dynamics on the best-scoring MOF.

    # Make the argument parser
    parser = ArgumentParser()

    group = parser.add_argument_group(title='MOF Settings', description='Options related to the MOF type being generatored')
    group.add_argument('--mof-template', default=None, help='Path to a MOF we are going to be altering')

    group = parser.add_argument_group(title='Generator Settings', description='Options related to how the generation is performed')
    group.add_argument('--generator-path', default=None, help='Path to the PyTorch files describing model architecture and weights')
    # argparse converts command-line strings through ``type=``; ``dtype=`` is not a
    # recognized keyword and makes ``add_argument`` raise a TypeError
    group.add_argument('--molecule-sizes', nargs='+', type=int, default=(10, 11, 12), help='Sizes of molecules we should generate')
    group.add_argument('--num-samples', type=int, default=16, help='Number of molecules to generate at each size')

    args = parser.parse_args()

    # TODO: Make a run directory

    # Turn on logging
    logger = logging.getLogger('main')
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    # Load a pretrained generator from disk and use it to create ligands
    template_mof = MOFRecord.from_file(cif_path=args.mof_template)
    # NOTE(review): ``run_generator`` documents ``model`` as a path to the weights (``str | Path``),
    #  yet the object returned by ``torch.load`` is passed below. Confirm which one is intended.
    # NOTE(review): ``torch.load`` deserializes with pickle by default, so only point
    #  ``--generator-path`` at checkpoints from a trusted source.
    model = torch.load(args.generator_path)
    generated_ligand_xyzs = run_generator(
        model,
        fragment_template=template_mof.ligands[0],
        molecule_sizes=args.molecule_sizes,
        num_samples=args.num_samples
    )
    logger.info(f'Generated {len(generated_ligand_xyzs)} ligands')

    # Initial quality checks and post-processing on the generated ligands
    # TODO (wardlt): Add checks for detecting fragmented molecules, valency checks, ...
    # TODO (wardlt): Add a function which adds H's to the XYZ file
    #  (``add_hydrogens_to_ligand`` is not defined or imported in this script yet)
    validated_ligand_xyzs = [add_hydrogens_to_ligand(generated_xyz) for generated_xyz in generated_ligand_xyzs]
    # Report the number of surviving ligands, not the ligand objects themselves
    logger.info(f'Screened generated ligands. {len(validated_ligand_xyzs)} pass quality checks')

    # Combine them with the template MOF to create new MOFs
    new_mofs = [template_mof.replace_ligand(new_ligand) for new_ligand in validated_ligand_xyzs]
    logger.info(f'Generated {len(new_mofs)} new MOFs')

    # Nothing to score or simulate: stop with a clear message instead of an IndexError further down
    if not new_mofs:
        logger.error('No MOFs were generated, so there is nothing to score or simulate')
        sys.exit(1)

    # Score the MOFs
    scorer = MinimumDistance()  # TODO (wardlt): Add or replace with a CGCNN that predicts absorption
    scores = [scorer.score_mof(new_mof) for new_mof in new_mofs]
    logger.info(f'Scored all {len(new_mofs)} MOFs')

    # Run LAMMPS on the top MOF
    # Sort on the score alone: without a key, tied scores would make ``sorted``
    # compare the MOF records themselves
    ranked_mofs: list[tuple[float, MOFRecord]] = sorted(zip(scores, new_mofs), key=lambda pair: pair[0])
    LAMMPSRunner('lmp_serial').run_molecular_dynamics(ranked_mofs[-1][1], 100, 1)

0 comments on commit 8893488

Please sign in to comment.