-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CLI support for approximate Hessians (#35)
* Add initial sampler
  We'll probably use it in our bigger tests anyway
* Initial draft of the static thinker
* Finish CLI support
* Make a log statement less verbose
- Loading branch information
Showing
11 changed files
with
382 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,41 @@ | ||
"""Energy models using `DScribe <https://singroup.github.io/dscribe/latest/index.html>`_""" | ||
import ase | ||
from dscribe.descriptors.mbtr import MBTR | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.model_selection import GridSearchCV | ||
from sklearn.kernel_ridge import KernelRidge | ||
import numpy as np | ||
|
||
from .globald import DScribeGlobalEnergyModel | ||
|
||
|
||
def make_global_mbtr_model(ref_atoms: ase.Atoms, n_points: int = 8, cutoff: float = 6.) -> DScribeGlobalEnergyModel:
    """Assemble an untrained energy model built on MBTR descriptors and scikit-learn

    The descriptors are passed through a standard scaler and then a kernel ridge
    regressor (RBF kernel) whose ``gamma`` is selected by a grid search.

    Args:
        ref_atoms: Reference atoms to use for the model
        n_points: Number of points to include in the MBTR grid
        cutoff: Cutoff distance for the descriptors (units: Angstrom)
    Returns:
        Energy model, ready to be trained
    """
    # Unique chemical symbols found in the reference structure
    species = list(set(ref_atoms.get_chemical_symbols()))

    # Settings for the "angle" descriptor: a grid running from 0 to 180 where
    #  the broadening is tied to the number of grid points
    grid_settings = {"min": 0., "max": 180, "n": n_points, "sigma": 180. / n_points / 2.}
    weight_settings = {"function": "smooth_cutoff", "r_cut": cutoff, "threshold": 1e-3}
    desc = MBTR(
        species=species,
        geometry={"function": "angle"},
        grid=grid_settings,
        weighting=weight_settings,
        periodic=False,
    )

    # Regressor: KRR with the kernel width chosen from 32 log-spaced candidates
    gamma_search = GridSearchCV(
        KernelRidge(kernel='rbf', alpha=1e-10),
        {'gamma': np.logspace(-5, 5, 32)},
    )
    model = Pipeline([
        ('scale', StandardScaler()),
        ('krr', gamma_search),
    ])

    return DScribeGlobalEnergyModel(
        reference=ref_atoms,
        model=model,
        descriptors=desc,
        num_calculators=2
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Functions to sample atomic configurations""" | ||
from typing import Type | ||
|
||
from .base import StructureSampler | ||
from .random import UniformSampler | ||
|
||
# Lookup table from the name of a sampling strategy to the class implementing it
methods: dict[str, Type[StructureSampler]] = dict(uniform=UniformSampler)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import ase | ||
|
||
|
||
class StructureSampler:
    """Interface for strategies which generate the structures used to train a Hessian model

    Implementations should accept their options through the initializer.
    """

    @property
    def name(self) -> str:
        """Label for the sampling strategy"""
        raise NotImplementedError

    def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]:
        """Generate a set of training structures

        Args:
            atoms: Unperturbed geometry
            count: Number of structures to produce
            seed: Random seed
        Returns:
            List of structures to be evaluated
        Raises:
            NotImplementedError: Always, subclasses must override this method
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""Simple, friendly, random sampling""" | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import ase | ||
|
||
from jitterbug.sampler.base import StructureSampler | ||
|
||
|
||
@dataclass
class UniformSampler(StructureSampler):
    """Sample randomly-chosen directions

    Perturbs each atom in each direction a random amount drawn uniformly
    between -:attr:`step_size` and :attr:`step_size`.
    """

    step_size: float = 0.005
    """Maximum amount to displace the atoms along each direction (units: Angstrom)"""

    @property
    def name(self) -> str:
        """Name of the strategy, which includes the step size"""
        return f'uniform_{self.step_size:.3e}'

    def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]:
        """Generate structures by randomly displacing every atom

        Args:
            atoms: Unperturbed geometry
            count: Number of structures to produce
            seed: Random seed. The number of atoms is added to it before seeding the generator
        Returns:
            List of ``count`` perturbed copies of ``atoms``
        """
        # Make the RNG
        n_atoms = len(atoms)
        rng = np.random.RandomState(seed + n_atoms)

        output = []
        for _ in range(count):
            # Sample a perturbation which is bounded by the step size in every direction.
            #  A normal distribution centered on -step_size was used previously, which
            #  was neither bounded nor symmetric about the unperturbed geometry
            disp = rng.uniform(-self.step_size, self.step_size, size=(n_atoms, 3))

            # Make the new atoms
            new_atoms = atoms.copy()
            new_atoms.positions += disp
            output.append(new_atoms)
        return output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from pathlib import Path | ||
|
||
import ase | ||
import numpy as np | ||
from colmena.queue import ColmenaQueues | ||
from colmena.thinker import BaseThinker, ResourceCounter | ||
|
||
|
||
class HessianThinker(BaseThinker):
    """Shared foundation for the thinkers in this package

    All implementations must write their simulation data to the same location
    """

    atoms: ase.Atoms
    """Unperturbed atomic structure"""

    run_dir: Path
    """Path to the run directory"""
    result_file: Path
    """Path to file in which to store result records"""

    def __init__(self, queues: ColmenaQueues, rec: ResourceCounter, run_dir: Path, atoms: ase.Atoms):
        """
        Args:
            queues: Queues handed to the base thinker
            rec: Resource counter handed to the base thinker
            run_dir: Directory in which to write outputs. Created if it does not exist yet
            atoms: Unperturbed atomic structure
        """
        super().__init__(queues, rec)
        self.atoms, self.run_dir = atoms, run_dir

        # Make sure the output directory exists, then define where results are logged
        run_dir.mkdir(exist_ok=True)
        self.result_file = self.run_dir / 'simulation-results.json'

    def compute_hessian(self) -> np.ndarray:
        """Compute the Hessian from the data gathered by the thinker

        Returns:
            Hessian in the 2D form
        Raises:
            (ValueError) If there is missing data
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Approach which uses a static set of structures to compute Hessian""" | ||
from pathlib import Path | ||
|
||
import ase | ||
import numpy as np | ||
from ase.db import connect | ||
from colmena.models import Result | ||
from colmena.queue import ColmenaQueues | ||
from colmena.thinker import ResourceCounter, agent, result_processor | ||
|
||
from .base import HessianThinker | ||
from jitterbug.sampler.base import StructureSampler | ||
from ..model.base import EnergyModel | ||
from ..utils import read_from_string | ||
|
||
|
||
class ApproximateHessianThinker(HessianThinker):
    """Approach which approximates a Hessian by computing it from a forcefield fit to structures

    Saves structures to an ASE db and labels them with the name of the sampler and the index
    """

    def __init__(self,
                 queues: ColmenaQueues,
                 num_workers: int,
                 atoms: ase.Atoms,
                 run_dir: Path,
                 sampler: StructureSampler,
                 num_to_run: int,
                 model: EnergyModel,
                 step_size: float = 0.005):
        """
        Args:
            queues: Queues used to communicate with the task server
            num_workers: Number of slots in the resource counter, which limits
                how many tasks are in flight at once
            atoms: Unperturbed atomic structure
            run_dir: Directory in which to store results and the ASE database
            sampler: Strategy used to generate the structures to evaluate
            num_to_run: Number of structures to generate and evaluate
            model: Energy model to fit to the evaluated structures
            step_size: Stored on the thinker. Not used by the methods of this class
        """
        super().__init__(queues, ResourceCounter(num_workers), run_dir, atoms)
        self.step_size = step_size
        self.sampler = sampler
        self.num_to_run = num_to_run
        self.model = model

        # Generate the structures to be sampled
        self.to_sample = self.sampler.produce_structures(atoms, num_to_run)
        sampler_name = self.sampler.name
        self.logger.info(f'Generated {len(self.to_sample)} structures with strategy: {sampler_name}')

        # Find how many we've done already
        self.db_path = self.run_dir / 'atoms.db'
        self.completed: set[int] = set()
        with connect(self.db_path) as db:
            for row in db.select(f'index<{self.num_to_run}', sampler=sampler_name):
                # Use a name other than `atoms` so the unperturbed structure is not shadowed.
                #  The key-value pairs are requested because the index is read from them below
                stored = row.toatoms(True)
                ind = stored.info['key_value_pairs']['index']

                # A structure from a previous run must match what the sampler produces now
                matches = np.isclose(stored.positions, self.to_sample[ind].positions).all()
                assert matches, f'Structure {ind} in the DB and generated structure are inconsistent'
                self.completed.add(ind)
        num_remaining = self.num_to_run - len(self.completed)
        self.logger.info(f'Completed {len(self.completed)} structures already. Need to run {num_remaining} more')

    @agent()
    def submit_tasks(self):
        """Submit all required tasks then start the shutdown process by exiting"""

        for ind, atoms in enumerate(self.to_sample):
            # Skip structures which we've done already
            if ind in self.completed:
                continue

            # Wait for a free slot, then submit it
            self.rec.acquire(None, 1)
            self.queues.send_inputs(
                atoms,
                method='get_energy',
                task_info={'index': ind}
            )

    @result_processor
    def store_energy(self, result: Result):
        """Store the energy in the appropriate files

        Args:
            result: Completed task. Its record is appended to the result file and,
                if the task succeeded, the structure is written to the ASE database
        """
        # Give back the slot taken in `submit_tasks`, naming the same pool and count explicitly
        self.rec.release(None, 1)

        # Store the result object to disk
        with self.result_file.open('a') as fp:
            print(result.json(exclude={'inputs'}), file=fp)

        if not result.success:
            self.logger.warning(f'Calculation failed due to {result.failure_info.exception}')
            return

        # Store the result into the ASE database
        sampler_name = self.sampler.name
        index = result.task_info['index']
        # NOTE(review): presumably the task returns the structure as an extxyz string - confirm
        atoms = read_from_string(result.value, 'extxyz')
        assert np.isclose(result.args[0].positions, atoms.positions).all()
        self.completed.add(index)
        with connect(self.db_path) as db:
            db.write(atoms, sampler=sampler_name, index=index)
        self.logger.info(f'Saved completed structure. Progress: {len(self.completed)}/{self.num_to_run}'
                         f' ({len(self.completed) / self.num_to_run * 100:.2f}%)')

    def compute_hessian(self) -> np.ndarray:
        """Fit the energy model to the stored structures and compute a Hessian from it

        Returns:
            Output of the ``mean_hessian`` method of the energy model
        """
        # Load the structures written by this sampler
        atoms = []
        with connect(self.db_path) as db:
            for row in db.select(f'index<{self.num_to_run}', sampler=self.sampler.name):
                atoms.append(row.toatoms())
        self.logger.info(f'Pulled {len(atoms)} atoms for a training set')

        # Fit the model
        model = self.model.train(atoms)
        self.logger.info('Completed model fitting')

        return self.model.mean_hessian(model)
Oops, something went wrong.