Refactor for SOAP too
WardLT committed Oct 16, 2023
1 parent a6318ab commit 40dd787
Showing 2 changed files with 88 additions and 9 deletions.
71 changes: 63 additions & 8 deletions jitterbug/model/dscribe/local.py
@@ -1,15 +1,17 @@
-"""Create a Gaussian-process-regression based model which uses features for each atom
-
-Builds the model using PyTorch so that one can come derivatives analytically."""
-from typing import Union
+"""Create a PyTorch-based model which uses features for each atom"""
+from typing import Union, Optional, Callable
 
 from ase.calculators.calculator import Calculator, all_changes
 from dscribe.descriptors.descriptorlocal import DescriptorLocal
 from torch.utils.data import TensorDataset, DataLoader
+from ase import Atoms
 from tqdm import tqdm
 import pandas as pd
 import numpy as np
 import torch
 
+from jitterbug.model.base import ASEEnergyModel
+
 
 class InducedKernelGPR(torch.nn.Module):
     """Gaussian process regression model with an induced kernel
@@ -66,7 +68,7 @@ def make_gpr_model(train_descriptors: np.ndarray, num_inducing_points: int, use_
     )
 
 
-def train_model(model: InducedKernelGPR,
+def train_model(model: torch.nn.Module,
                 train_x: np.ndarray,
                 train_y: np.ndarray,
                 steps: int,
@@ -145,10 +147,10 @@ def train_model(model: InducedKernelGPR,
 
 
 class DScribeLocalCalculator(Calculator):
-    """Calculator which uses a GPR model trained using SOAP descriptors
+    """Calculator which uses descriptors for each atom and PyTorch to compute energy
 
     Keyword Args:
-        model (InducedKernelGPR): A machine learning model which takes descriptors as inputs
+        model (torch.nn.Module): A machine learning model which takes descriptors as inputs
         desc (DescriptorLocal): Tool used to compute the descriptors
         desc_scaling (tuple[np.ndarray, np.ndarray]): An offset and factor with which to adjust the energy per atom predictions,
             which are typically the mean and standard deviation of energy per atom across the training set.
@@ -182,7 +184,7 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'),
             d_desc_d_pos = torch.from_numpy(d_desc_d_pos.astype(np.float32))
 
         # Move the model to device if need be
-        model: InducedKernelGPR = self.parameters['model']
+        model: torch.nn.Module = self.parameters['model']
         device = self.parameters['device']
         model.to(device)
@@ -216,3 +218,56 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'),
 
         # Move the model back to CPU memory
         model.to('cpu')
+
+
+class DScribeLocalEnergyModel(ASEEnergyModel):
+    """Energy model based on DScribe atom-level descriptors
+
+    Trains an energy model using PyTorch
+
+    Args:
+        reference: Reference structure at which we compute the Hessian
+        descriptors: Tool used to compute descriptors
+        model_fn: Function used to create the model given descriptors for the training set
+        num_calculators: Number of models to use in ensemble
+        device: Device used for training
+        train_options: Options passed to the training function
+    """
+
+    def __init__(self,
+                 reference: Atoms,
+                 descriptors: DescriptorLocal,
+                 model_fn: Callable[[np.ndarray], torch.nn.Module],
+                 num_calculators: int,
+                 device: str = 'cpu',
+                 train_options: Optional[dict] = None):
+        super().__init__(reference, num_calculators)
+        self.descriptors = descriptors
+        self.model_fn = model_fn
+        self.device = device
+        self.train_options = train_options or {'steps': 4}
+
+    def train_calculator(self, data: list[Atoms]) -> Calculator:
+        # Train it using the user-provided options
+        train_x = self.descriptors.create(data)
+        offset_x = train_x.mean(axis=(0, 1))
+        scale_x = np.clip(train_x.std(axis=(0, 1)), a_min=1e-6, a_max=None)
+        train_x -= offset_x
+        train_x /= scale_x
+
+        train_y = np.array([a.get_potential_energy() for a in data])
+        scale_y, offset_y = np.std(train_y), np.mean(train_y)
+        train_y = (train_y - offset_y) / scale_y
+
+        # Make, then train the model
+        model = self.model_fn(train_x)
+        train_model(model, train_x, train_y, device=self.device, **self.train_options)
+
+        # Make the calculator
+        return DScribeLocalCalculator(
+            model=model,
+            desc=self.descriptors,
+            desc_scaling=(offset_x, scale_x),
+            energy_scaling=(offset_y, scale_y),
+            device=self.device
+        )
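
Note: in train_calculator above, descriptor channels with near-zero variance would make the z-score division blow up, which is why scale_x is clipped to at least 1e-6. A minimal self-contained sketch of the same forward/inverse standardization, simplified to a 2-D array (the numbers and axis choice are illustrative; the committed code standardizes over the structure and atom axes):

import numpy as np

x = np.array([[1.0, 5.0], [1.0, 7.0], [1.0, 6.0]])  # first channel is constant across samples
offset = x.mean(axis=0)
scale = np.clip(x.std(axis=0), a_min=1e-6, a_max=None)  # clipping keeps the constant channel finite
z = (x - offset) / scale  # forward transform, as in train_calculator

recovered = z * scale + offset  # inverse transform, as a calculator must apply to predictions
assert np.allclose(recovered, x)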
26 changes: 25 additions & 1 deletion tests/models/test_soap.py
@@ -3,7 +3,7 @@
 from dscribe.descriptors.soap import SOAP
 from pytest import mark, fixture
 
-from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator
+from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator, DScribeLocalEnergyModel
 
 
 @fixture
@@ -86,3 +86,27 @@ def test_calculator(descriptors, soap, train_set):
     numerical_forces = calc.calculate_numerical_forces(atoms, d=1e-4)
     assert np.isclose(forces[:, :2], numerical_forces[:, :2], rtol=5e-1).all()  # Make them agree w/i 50% (PES is not smooth)
     assert np.std(energies) > 1e-6
+
+
+def test_model(soap, train_set):
+    # Assemble the model
+    model = DScribeLocalEnergyModel(
+        reference=train_set[0],
+        descriptors=soap,
+        model_fn=lambda x: make_gpr_model(x, num_inducing_points=32),
+        num_calculators=4,
+    )
+
+    # Run the fitting
+    calcs = model.train(train_set)
+
+    # Test the mean hessian function
+    mean_hess = model.mean_hessian(calcs)
+    assert mean_hess.shape == (9, 9), 'Wrong shape'
+    assert np.isclose(mean_hess, mean_hess.T).all(), 'Not symmetric'
+
+    # Test the sampling
+    sampled_hess = model.sample_hessians(calcs, 128)
+    assert all(np.isclose(hess, hess.T).all() for hess in sampled_hess)
+    mean_sampled_hess = np.mean(sampled_hess, 0)
+    assert np.isclose(np.diag(mean_sampled_hess), np.diag(mean_hess), atol=5).mean() > 0.5  # Make sure most agree

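Note: the new test above doubles as a usage reference. A minimal sketch of driving DScribeLocalEnergyModel outside the test suite, assuming a water molecule; the SOAP settings and the load_training_structures helper are illustrative assumptions, not part of this commit:

from ase.build import molecule
from dscribe.descriptors.soap import SOAP

from jitterbug.model.dscribe.local import DScribeLocalEnergyModel, make_gpr_model

reference = molecule('H2O')
train_set = load_training_structures()  # hypothetical helper returning Atoms objects with energies attached

soap = SOAP(species=['H', 'O'], r_cut=4.0, n_max=4, l_max=4)  # illustrative descriptor settings
model = DScribeLocalEnergyModel(
    reference=reference,
    descriptors=soap,
    model_fn=lambda x: make_gpr_model(x, num_inducing_points=32),
    num_calculators=4,
    train_options={'steps': 128},
)

calcs = model.train(train_set)  # one trained calculator per ensemble member
mean_hess = model.mean_hessian(calcs)  # (3N, 3N) ensemble-mean Hessian at the reference structure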