Skip to content

Commit

Permalink
Finish CLI support
Browse files: browse the repository at this point in the history
  • Loading branch information
WardLT committed Dec 11, 2023
1 parent 3cd8d0a commit 4b08b53
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 33 deletions.
28 changes: 23 additions & 5 deletions jitterbug/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from colmena.task_server import ParslTaskServer
from parsl import Config, HighThroughputExecutor

from jitterbug.model.dscribe import make_global_mbtr_model
from jitterbug.parsl import get_energy, load_configuration
from jitterbug.thinkers.exact import ExactHessianThinker
from jitterbug.thinkers.static import ApproximateHessianThinker
from jitterbug.utils import make_calculator
from jitterbug.sampler import methods
from jitterbug.sampler import methods, UniformSampler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,8 +51,7 @@ def main(args: Optional[list[str]] = None):

# Make the run directory
method, basis = (x.lower() for x in args.method)
compute_name = args.approach
run_dir = Path('run') / xyz_name / f'{method}_{basis}_{compute_name}'
run_dir = Path('run') / xyz_name / f'{method}_{basis}_{args.approach}'
run_dir.mkdir(parents=True, exist_ok=True)

# Start logging
Expand Down Expand Up @@ -99,14 +100,31 @@ def main(args: Optional[list[str]] = None):

# Create a thinker
queues = PipeQueues(topics=['simulation'])
if args.exact:
functions = [] # Additional functions needed by thinker
if args.approach == 'exact':
thinker = ExactHessianThinker(
queues=queues,
atoms=atoms,
run_dir=run_dir,
num_workers=num_workers,
)
functions = [] # No other functions to run
elif args.approach == 'static':
# Determine the number to run
exact_needs = len(atoms) * 3 * 2 + (len(atoms) * 3) * (len(atoms) * 3 - 1) * 2 + 1
if args.amount_to_run < 1:
num_to_run = int(args.amount_to_run * exact_needs)
else:
num_to_run = int(args.amount_to_run)
logger.info(f'Running {num_to_run} energies out of {exact_needs} required for exact Hessian')
thinker = ApproximateHessianThinker(
queues=queues,
atoms=atoms,
run_dir=run_dir,
num_workers=num_workers,
num_to_run=num_to_run,
sampler=UniformSampler(), # TODO (wardlt): Make this configurable
model=make_global_mbtr_model(atoms) # TODO (wardlt): Make this configurable
)
else:
raise NotImplementedError()

Expand Down
40 changes: 40 additions & 0 deletions jitterbug/model/dscribe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,41 @@
"""Energy models using `DScribe <https://singroup.github.io/dscribe/latest/index.html>`_"""
import ase
from dscribe.descriptors.mbtr import MBTR
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.kernel_ridge import KernelRidge
import numpy as np

from .globald import DScribeGlobalEnergyModel


def make_global_mbtr_model(ref_atoms: ase.Atoms, n_points: int = 8, cutoff: float = 6.) -> DScribeGlobalEnergyModel:
    """Make an MBTR model using scikit-learn

    The descriptor is an angle-based MBTR evaluated on a grid spanning 0 to 180
    with ``n_points`` points and a broadening of ``180 / n_points / 2``.
    The regressor is a pipeline which standardizes the descriptors and then fits
    an RBF kernel ridge regression whose ``gamma`` is chosen by a grid search
    over 32 log-spaced values between 1e-5 and 1e5.

    Args:
        ref_atoms: Reference atoms to use for the model
        n_points: Number of points to include in the MBTR grid. Must be at least 1
        cutoff: Cutoff distance for the descriptors (units: Angstrom)
    Returns:
        Energy model, ready to be trained
    Raises:
        ValueError: If ``n_points`` is less than 1
    """
    # Fail early with a clear message rather than a ZeroDivisionError (or a negative
    #  broadening) when computing the grid's sigma below
    if n_points < 1:
        raise ValueError(f'n_points must be at least 1, got {n_points}')

    # Sort the unique symbols so the species list does not depend on the
    #  per-process ordering of a set of strings
    species = sorted(set(ref_atoms.get_chemical_symbols()))
    desc = MBTR(
        species=species,
        geometry={"function": "angle"},
        grid={"min": 0., "max": 180, "n": n_points, "sigma": 180. / n_points / 2.},
        weighting={"function": "smooth_cutoff", "r_cut": cutoff, "threshold": 1e-3},
        periodic=False,
    )

    # Scale the descriptors, then fit a kernel ridge model with a tuned kernel width
    model = Pipeline(
        [('scale', StandardScaler()),
         ('krr', GridSearchCV(KernelRidge(kernel='rbf', alpha=1e-10),
                              {'gamma': np.logspace(-5, 5, 32)}))]
    )
    return DScribeGlobalEnergyModel(
        reference=ref_atoms,
        model=model,
        descriptors=desc,
        num_calculators=2
    )
14 changes: 12 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,26 @@ def test_exact_solver(file_dir, xyz):
with open(devnull, 'w') as fo:
with redirect_stdout(fo):
main([
str(file_dir / xyz), '--exact', '--method', 'hf', 'sto-3g'
str(file_dir / xyz), '--approach', 'exact', '--method', 'hf', 'sto-3g'
])
assert (Path('run') / xyz_name / 'hf_sto-3g_exact' / 'hessian.npy').exists()


@mark.parametrize('amount', [32., 0.5])
def test_approx_solver(amount, xyz_path):
    """Run the CLI with the ``static`` approach and check that a Hessian file is written

    ``amount`` is forwarded to ``--amount-to-run``; the parametrization covers
    one value above one and one value below one.
    """
    cli_args = [
        str(xyz_path),
        '--approach', 'static',
        '--amount-to-run', str(amount),
        '--method', 'hf', 'sto-3g',
    ]

    # Send anything the CLI prints to the null device
    with open(devnull, 'w') as sink, redirect_stdout(sink):
        main(cli_args)

    # The expected run directory uses the input file name minus its last four characters
    hessian_path = Path('run') / xyz_path.name[:-4] / 'hf_sto-3g_static' / 'hessian.npy'
    assert hessian_path.exists()


def test_parsl_path(xyz_path, file_dir):
with open(devnull, 'w') as fo:
with redirect_stdout(fo):
main([
str(xyz_path), '--exact', '--method', 'pm7', 'None',
str(xyz_path), '--approach', 'exact', '--method', 'pm7', 'None',
'--parsl-config', str(file_dir / 'example_config.py')
])
assert (Path('run') / 'water' / 'pm7_none_exact' / 'hessian.npy').exists()
28 changes: 2 additions & 26 deletions tests/test_thinkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
from ase.vibrations import Vibrations
from colmena.queue.python import PipeQueues
from colmena.task_server.parsl import ParslTaskServer
from dscribe.descriptors import MBTR
from parsl import Config, HighThroughputExecutor
from pytest import fixture, mark
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from jitterbug.compare import compare_hessians
from jitterbug.model.dscribe.globald import DScribeGlobalEnergyModel
from jitterbug.model.dscribe import make_global_mbtr_model
from jitterbug.parsl import get_energy
from jitterbug.sampler import UniformSampler
from jitterbug.thinkers.exact import ExactHessianThinker
Expand Down Expand Up @@ -43,26 +38,7 @@ def ase_hessian(atoms, tmp_path) -> np.ndarray:

@fixture()
def mbtr(atoms):
n_points = 32
r_cutoff = 6.
desc = MBTR(
species=["H", "C", "N", "O"],
geometry={"function": "angle"},
grid={"min": 0., "max": 180, "n": n_points, "sigma": 180. / n_points / 2.},
weighting={"function": "smooth_cutoff", "r_cut": r_cutoff, "threshold": 1e-3},
periodic=False,
)
model = Pipeline(
[('scale', StandardScaler()),
('krr', GridSearchCV(KernelRidge(kernel='rbf', alpha=1e-10),
{'gamma': np.logspace(-5, 5, 32)}))]
)
return DScribeGlobalEnergyModel(
reference=atoms,
model=model,
descriptors=desc,
num_calculators=2
)
return make_global_mbtr_model(atoms)


@fixture(autouse=True)
Expand Down

0 comments on commit 4b08b53

Please sign in to comment.