Skip to content

Commit

Permalink
Add a dynamic workflow (#30)
Browse files Browse the repository at this point in the history
* Add the workflow function

* Add missing fixture

* Add a workflow for exact computations

* Compute the Hessian with numerical diff

* Use that the matrix is symmetric

* Add a CLI
  • Loading branch information
WardLT authored Dec 6, 2023
1 parent 18a0c6d commit 499685d
Show file tree
Hide file tree
Showing 11 changed files with 433 additions and 1 deletion.
90 changes: 90 additions & 0 deletions jitterbug/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Command line interface to running Jitterbug on a new molecule"""
from argparse import ArgumentParser
from functools import partial, update_wrapper
from pathlib import Path
from typing import Optional
import logging
import sys

import numpy as np
from ase.io import read
from colmena.queue import PipeQueues
from colmena.task_server import ParslTaskServer
from parsl import Config, HighThroughputExecutor

from jitterbug.parsl import get_energy
from jitterbug.thinkers.exact import ExactHessianThinker

logger = logging.getLogger(__name__)


def main(args: Optional[list[str]] = None):
"""Run Jitterbug"""

parser = ArgumentParser()
parser.add_argument('xyz', help='Path to the XYZ file')
parser.add_argument('--method', nargs=2, required=True,
help='Method to use to compute energies. Format: [method] [basis]. Example: B3LYP 6-31g*')
parser.add_argument('--exact', help='Compute Hessian using numerical derivatives', action='store_true')
args = parser.parse_args(args)

# Load the structure
xyz_path = Path(args.xyz)
atoms = read(args.xyz)
xyz_name = xyz_path.with_suffix('').name

# Make the run directory
method, basis = (x.lower() for x in args.method)
run_dir = Path('run') / xyz_name / f'{method}_{basis}'
run_dir.mkdir(parents=True, exist_ok=True)

# Start logging
handlers = [logging.StreamHandler(sys.stdout), logging.FileHandler(run_dir / 'run.log', mode='a')]
for handler in handlers:
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

for logger_name in ['jitterbug']:
my_logger = logging.getLogger(logger_name)
for handler in handlers:
my_logger.addHandler(handler)
my_logger.setLevel(logging.INFO)

# Write the XYZ file to the run directory
if (run_dir / xyz_path.name).exists() and (run_dir / xyz_path.name).read_text() != xyz_path.read_text():
raise ValueError('Run exists for a different structure with the same name.')
(run_dir / xyz_path.name).write_text(xyz_path.read_text())
logger.info(f'Started run for {xyz_name} at {method}/{basis}. Run directory: {run_dir.absolute()}')

# Make the function to compute energy
energy_fun = partial(get_energy, method=method, basis=basis)
update_wrapper(energy_fun, get_energy)

# Create a thinker
queues = PipeQueues(topics=['simulation'])
if args.exact:
thinker = ExactHessianThinker(
queues=queues,
atoms=atoms,
run_dir=run_dir,
num_workers=1,
)
functions = [] # No other functions to run
else:
raise NotImplementedError()

# Create the task server
config = Config(run_dir=str(run_dir / 'parsl-logs'), executors=[HighThroughputExecutor(max_workers=1)])
task_server = ParslTaskServer([energy_fun] + functions, queues, config)

# Run everything
try:
task_server.start()
thinker.run()
finally:
queues.send_kill_signal()

# Get the Hessian
hessian = thinker.compute_hessian()
hess_path = run_dir / 'hessian.npy'
np.save(hess_path, hessian)
logger.info(f'Wrote Hessian to {hess_path}')
23 changes: 23 additions & 0 deletions jitterbug/parsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Wrappers for functions compatible with the Parsl workflow engine"""
from typing import Optional

import ase

from jitterbug.utils import make_calculator


def get_energy(atoms: ase.Atoms, method: str, basis: Optional[str], **kwargs) -> float:
"""Compute the energy of an atomic structure
Keyword arguments are passed to :meth:`make_calculator`.
Args:
atoms: Structure to evaluate
method: Name of the method to use (e.g., B3LYP)
basis: Basis set to use (e.g., cc-PVTZ)
Returns:
Energy (units: eV)
"""

calc = make_calculator(method, basis, **kwargs)
return calc.get_potential_energy(atoms)
Empty file added jitterbug/thinkers/__init__.py
Empty file.
188 changes: 188 additions & 0 deletions jitterbug/thinkers/exact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Run an exact Hessian computation"""
from csv import reader, writer
from pathlib import Path
from typing import Optional

import ase
import numpy as np
from colmena.models import Result
from colmena.queue import ColmenaQueues
from colmena.thinker import BaseThinker, ResourceCounter, agent, result_processor


class ExactHessianThinker(BaseThinker):
"""Schedule the calculation of a complete set of numerical derivatives"""

def __init__(self, queues: ColmenaQueues, num_workers: int, atoms: ase.Atoms, run_dir: Path, step_size: float = 0.005):
super().__init__(queues, ResourceCounter(num_workers))
self.atoms = atoms

# Initialize storage for the energies
self.result_file = run_dir / 'simulation-results.json'
self.step_size = step_size
self.unperturbed_energy: Optional[float] = None
self.single_perturb = np.zeros((len(atoms), 3, 2)) * np.nan # Perturbation of a single direction. [atom_id, axis (xyz), dir_id (0 back, 1 forward)]
self.double_perturb = np.zeros((len(atoms), 3, 2, len(atoms), 3, 2)) * np.nan
# Perturbation of two directions [atom1_id, axis1, dir1_id, atom2_id, axis2 dir2_id]

# Load what has been run already
self.run_dir = run_dir
self.run_dir.mkdir(exist_ok=True)
self.energy_path = self.run_dir / 'unperturbed.energy'
self.single_path = self.run_dir / 'single_energies.csv'
self.double_path = self.run_dir / 'double_energies.csv'

if self.energy_path.exists():
self.unperturbed_energy = float(self.energy_path.read_text())

if self.single_path.exists():
with self.single_path.open() as fp:
count = 0
for row in reader(fp):
index = tuple(map(int, row[:-1]))
count += 1
self.single_perturb[index] = row[-1]
self.logger.info(f'Read {count} single perturbations out of {self.single_perturb.size}')

if self.double_path.exists():
with self.double_path.open() as fp:
count = 0
for row in reader(fp):
count += 1
# Get the index and its symmetric counterpart
index = tuple(map(int, row[:-1]))
sym_index = list(index)
sym_index[:3], sym_index[3:] = index[3:], index[:3]
self.double_perturb[index] = self.double_perturb[tuple(sym_index)] = row[-1]
self.logger.info(f'Read {count} double perturbations')

@agent()
def submit_tasks(self):
"""Submit all required tasks then start the shutdown process by exiting"""

# Start with the unperturbed energy
if self.unperturbed_energy is None:
self.queues.send_inputs(
self.atoms,
method='get_energy',
task_info={'type': 'unperturbed'}
)

# Submit the single perturbations
with np.nditer(self.single_perturb, flags=['multi_index']) as it:
count = 0
for x in it:
# Skip if done
if np.isfinite(x):
continue

# Submit if not done
self.rec.acquire(None, 1)
count += 1
atom_id, axis_id, dir_id = it.multi_index

new_atoms = self.atoms.copy()
new_atoms.positions[atom_id, axis_id] += self.step_size - 2 * self.step_size * dir_id
self.queues.send_inputs(
new_atoms,
method='get_energy',
task_info={'type': 'single', 'coord': it.multi_index}
)
self.logger.info(f'Finished submitting {count} single perturbations')

# Submit the double perturbations
with np.nditer(self.double_perturb, flags=['multi_index']) as it:
count = 0
for x in it:
# Skip if done
if np.isfinite(x):
continue

# Skip if perturbing the same direction twice, or if from the lower triangle
if it.multi_index[:2] == it.multi_index[3:5] or it.multi_index[:3] < it.multi_index[3:]:
continue

# Submit if not done
self.rec.acquire(None, 1)
count += 1

# Perturb two axes
new_atoms = self.atoms.copy()
for atom_id, axis_id, dir_id in [it.multi_index[:3], it.multi_index[3:]]:
new_atoms.positions[atom_id, axis_id] += self.step_size - 2 * self.step_size * dir_id

self.queues.send_inputs(
new_atoms,
method='get_energy',
task_info={'type': 'double', 'coord': it.multi_index}
)
self.logger.info(f'Finished submitting {count} double perturbations')

@result_processor
def store_energy(self, result: Result):
"""Store the energy in the appropriate files"""
self.rec.release()

# Store the result object to disk
with self.result_file.open('a') as fp:
print(result.json(exclude={'inputs'}), file=fp)

if not result.success:
self.logger.warning(f'Calculation failed due to {result.failure_info.exception}')
return

calc_type = result.task_info['type']
# Store unperturbed energy
if calc_type == 'unperturbed':
self.logger.info('Storing energy of unperturbed structure')
self.unperturbed_energy = result.value
self.energy_path.write_text(str(result.value))
return

# Store perturbed energy
coord = result.task_info['coord']
self.logger.info(f'Saving a {calc_type} perturbation: ({",".join(map(str, coord))})')
if calc_type == 'single':
energy_file = self.single_path
energies = self.single_perturb
else:
energy_file = self.double_path
energies = self.double_perturb

with energy_file.open('a') as fp:
csv_writer = writer(fp)
csv_writer.writerow(coord + [result.value])

energies[tuple(coord)] = result.value
if calc_type == 'double':
sym_coord = list(coord)
sym_coord[:3], sym_coord[3:] = coord[3:], coord[:3]
energies[tuple(sym_coord)] = result.value

def compute_hessian(self) -> np.ndarray:
"""Compute the Hessian using finite differences
Returns:
Hessian in the 2D form
Raises:
(ValueError) If there is missing data
"""

# Check that all data are available
n_atoms = len(self.atoms)
if not np.isfinite(self.single_perturb).all():
raise ValueError(f'Missing {np.isnan(self.single_perturb).sum()} single perturbations')
expected_double = self.double_perturb.size - (4 * n_atoms ** 2)
if not np.isfinite(self.double_perturb).sum() == expected_double:
raise ValueError(f'Missing {expected_double - np.isfinite(self.double_perturb).sum()} double perturbations')

# Flatten the arrays
single_flat = np.reshape(self.single_perturb, (n_atoms * 3, 2))
double_flat = np.reshape(self.double_perturb, (n_atoms * 3, 2, n_atoms * 3, 2))

# Compute the finite differences
# https://en.wikipedia.org/wiki/Finite_difference#Multivariate_finite_differences
diagonal = (single_flat.sum(axis=1) - self.unperturbed_energy * 2) / (self.step_size ** 2)
off_diagonal = (double_flat[:, 0, :, 0] + double_flat[:, 1, :, 1] - double_flat[:, 0, :, 1] - double_flat[:, 1, :, 0]) / (4 * self.step_size ** 2)
np.fill_diagonal(off_diagonal, 0)
return np.diag(diagonal) + off_diagonal
2 changes: 1 addition & 1 deletion jitterbug/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def make_calculator(method: str, basis: Optional[str], **kwargs) -> Calculator:
"""

if method in mopac_methods:
if not (basis is None or basis == "None"):
if not (basis is None or basis.lower() == "none"):
raise ValueError(f'Basis must be none for method: {method}')
return MOPAC(method=method, command='mopac PREFIX.mop > /dev/null')
elif method == 'xtb':
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ test = [
'pytest-timeout',
'pytest-cov',
]

[project.scripts]
jitterbug = "jitterbug.cli:main"
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pathlib import Path

from pytest import fixture

_file_dir = Path(__file__).parent / 'files'


@fixture()
def xyz_path():
return _file_dir / 'water.xyz'
5 changes: 5 additions & 0 deletions tests/files/water.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
3
Properties=species:S:1:pos:R:3:forces:R:3 energy=-2080.5785031368773 pbc="F F F"
O 2.53690888 -0.34072954 0.00000000 0.00346641 0.00474413 0.00000000
H 3.29702983 0.24793463 0.00000000 0.00004343 -0.00355994 -0.00000000
H 1.77686228 0.24789366 -0.00000000 -0.00350905 -0.00087620 0.00000000
14 changes: 14 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from contextlib import redirect_stdout
from os import devnull
from pathlib import Path

from jitterbug.cli import main


def test_exact_solver(xyz_path):
with open(devnull, 'w') as fo:
with redirect_stdout(fo):
main([
str(xyz_path), '--exact', '--method', 'pm7', 'None'
])
assert (Path('run') / 'water' / 'pm7_none' / 'hessian.npy').exists()
8 changes: 8 additions & 0 deletions tests/test_parsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ase.io import read

from jitterbug.parsl import get_energy


def test_energy(xyz_path):
atoms = read(xyz_path)
get_energy(atoms, 'pm7', None)
Loading

0 comments on commit 499685d

Please sign in to comment.