-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
We'll probably use it in our bigger tests anyway
- Loading branch information
Showing
4 changed files
with
83 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Functions to sample atomic configurations""" | ||
from typing import Type | ||
|
||
from .base import StructureSampler | ||
from .random import RandomSampler | ||
|
||
methods: dict[str, Type[StructureSampler]] = { | ||
'simple': RandomSampler | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import ase | ||
|
||
|
||
class StructureSampler: | ||
"""Base class for generating structures used to train Hessian model | ||
Options for the sampler should be defined in the initializer. | ||
""" | ||
|
||
def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]: | ||
"""Generate a set of training structure | ||
Args: | ||
atoms: Unperturbed geometry | ||
count: Number of structure to produce | ||
seed: Random seed | ||
Returns: | ||
List of structures to be evaluated | ||
""" | ||
|
||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""Simple, friendly, random sampling""" | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import ase | ||
|
||
from jitterbug.sampler.base import StructureSampler | ||
|
||
|
||
@dataclass | ||
class RandomSampler(StructureSampler): | ||
"""Sample randomly-chosen directions | ||
Perturbs each atom in each direction a random amount between -:attr:`step_size` and :attr:`step_size`. | ||
""" | ||
|
||
step_size: float = 0.005 | ||
"""Amount to displace the atoms (units: Angstrom)""" | ||
|
||
def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]: | ||
# Make the RNG | ||
n_atoms = len(atoms) | ||
rng = np.random.RandomState(seed + n_atoms) | ||
|
||
output = [] | ||
for _ in range(count): | ||
# Sample a perturbation | ||
disp = rng.normal(-self.step_size, self.step_size, size=(n_atoms, 3)) | ||
|
||
# Make the new atoms | ||
new_atoms = atoms.copy() | ||
new_atoms.positions += disp | ||
output.append(new_atoms) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from pytest import mark | ||
from ase.io import read | ||
|
||
from jitterbug.sampler import methods | ||
|
||
|
||
@mark.parametrize('method', ['simple']) | ||
def test_random(method, xyz_path): | ||
"""Make sure we get the same structures each time""" | ||
|
||
atoms = read(xyz_path) | ||
sampler = methods[method]() | ||
|
||
# Generate two batches | ||
samples_1 = sampler.produce_structures(atoms, 4) | ||
samples_2 = sampler.produce_structures(atoms, 8) | ||
|
||
for a1, a2 in zip(samples_1, samples_2): | ||
assert a1 == a2 |