Skip to content

Commit

Permalink
Add initial sampler
Browse files Browse the repository at this point in the history
We'll probably use it in our bigger tests anyway
  • Loading branch information
WardLT committed Dec 11, 2023
1 parent 5e178af commit 7026d68
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 0 deletions.
9 changes: 9 additions & 0 deletions jitterbug/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Functions to sample atomic configurations"""
from typing import Type

from .base import StructureSampler
from .random import RandomSampler

methods: dict[str, Type[StructureSampler]] = {
'simple': RandomSampler
}
21 changes: 21 additions & 0 deletions jitterbug/sampler/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ase


class StructureSampler:
"""Base class for generating structures used to train Hessian model
Options for the sampler should be defined in the initializer.
"""

def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]:
"""Generate a set of training structure
Args:
atoms: Unperturbed geometry
count: Number of structure to produce
seed: Random seed
Returns:
List of structures to be evaluated
"""

raise NotImplementedError()
34 changes: 34 additions & 0 deletions jitterbug/sampler/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Simple, friendly, random sampling"""
from dataclasses import dataclass

import numpy as np
import ase

from jitterbug.sampler.base import StructureSampler


@dataclass
class RandomSampler(StructureSampler):
"""Sample randomly-chosen directions
Perturbs each atom in each direction a random amount between -:attr:`step_size` and :attr:`step_size`.
"""

step_size: float = 0.005
"""Amount to displace the atoms (units: Angstrom)"""

def produce_structures(self, atoms: ase.Atoms, count: int, seed: int = 1) -> list[ase.Atoms]:
# Make the RNG
n_atoms = len(atoms)
rng = np.random.RandomState(seed + n_atoms)

output = []
for _ in range(count):
# Sample a perturbation
disp = rng.normal(-self.step_size, self.step_size, size=(n_atoms, 3))

# Make the new atoms
new_atoms = atoms.copy()
new_atoms.positions += disp
output.append(new_atoms)
return output
19 changes: 19 additions & 0 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pytest import mark
from ase.io import read

from jitterbug.sampler import methods


@mark.parametrize('method', ['simple'])
def test_random(method, xyz_path):
"""Make sure we get the same structures each time"""

atoms = read(xyz_path)
sampler = methods[method]()

# Generate two batches
samples_1 = sampler.produce_structures(atoms, 4)
samples_2 = sampler.produce_structures(atoms, 8)

for a1, a2 in zip(samples_1, samples_2):
assert a1 == a2

0 comments on commit 7026d68

Please sign in to comment.