Bug fix: Only scale forces once
I scaled the energy, so I don't need to scale the derivatives
of energy wrt descriptors
WardLT committed Dec 27, 2023
1 parent fd1a049 commit 9e9ce03
Showing 2 changed files with 72 additions and 23 deletions.
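The chain of reasoning behind the fix: forces come from F = -(dE/d desc) · (d desc/d pos), and because pred_energy is rescaled to physical units before autograd runs, the gradient with respect to the descriptors already carries the energy scale. The standalone sketch below (a toy linear model with invented shapes, not the repository's classes) shows why multiplying the resulting forces by the scale a second time would double-count it.

# Standalone sketch: the scale factor is already inside d(energy)/d(descriptor)
import torch

eng_scale = 2.5
desc = torch.randn(4, 8, requires_grad=True)   # hypothetical per-center descriptors
weights = torch.randn(8)                       # stand-in for a trained model

pred_energies = (desc @ weights) * eng_scale   # per-center energies, rescaled to physical units
pred_energy = pred_energies.sum()

d_energy_d_desc = torch.autograd.grad(pred_energy, desc)[0]
# The gradient already contains eng_scale, so the forces derived from it
# must not be multiplied by the scale again (the bug this commit removes)
assert torch.allclose(d_energy_d_desc, eng_scale * weights.expand(4, 8))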
61 changes: 52 additions & 9 deletions jitterbug/model/dscribe/local.py
@@ -1,5 +1,5 @@
"""Create a PyTorch-based model which uses features for each atom"""
from typing import Union, Optional, Callable
from typing import Union, Optional, Callable, Sequence

import ase
from ase.calculators.calculator import Calculator, all_changes
@@ -41,7 +41,7 @@ def forward(self, x) -> torch.Tensor:
esd = torch.exp(-diff)

# Return the sum
return torch.tensordot(self.alpha, esd, dims=([0], [0]))
return torch.tensordot(self.alpha, esd, dims=([0], [0]))[:, None]


class PerElementModule(torch.nn.Module):
@@ -62,10 +62,52 @@ def forward(self, element: torch.IntTensor, desc: torch.Tensor) -> torch.Tensor:
for elem, model in self.models.items():
elem_id = int(elem)
mask = element == elem_id
output[mask] = model(desc[mask, :])
output[mask] = model(desc[mask, :])[:, 0]
return output


def make_nn_model(
elements: np.ndarray,
descriptors: np.ndarray,
hidden_layers: Sequence[int] = (),
activation: torch.nn.Module = torch.nn.Sigmoid()
) -> PerElementModule:
"""Make a neural network model for a certain atomic system
Assumes that the descriptors have already been scaled
Args:
elements: Element for each atom in a structure (num_atoms,)
descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
hidden_layers: Number of units in the hidden layers
activation: Activation function used for the hidden layers
"""

# Detect the dtype
dtype = torch.from_numpy(descriptors[0, 0, :1]).dtype

# Make a model for each element type
models: dict[int, torch.nn.Sequential] = {}
element_types = np.unique(elements)

for element in element_types:
# Make the neural network
nn_layers = []
input_size = descriptors.shape[2]
for hidden_size in hidden_layers:
nn_layers.extend([
torch.nn.Linear(input_size, hidden_size, dtype=dtype),
activation
])
input_size = hidden_size

# Make the last layer
nn_layers.append(torch.nn.Linear(input_size, 1, dtype=dtype))
models[element] = torch.nn.Sequential(*nn_layers)

return PerElementModule(models)


def make_gpr_model(elements: np.ndarray,
descriptors: np.ndarray,
num_inducing_points: int,
@@ -237,9 +279,9 @@ def calculate(self, atoms: ase.Atoms = None, properties=('energy', 'forces', 'en
d_desc_d_pos, desc = self.parameters['desc'].derivatives(atoms, attach=True)

# Scale the descriptors
offset, scale = self.parameters['desc_scaling']
desc = (desc - offset) / scale
d_desc_d_pos /= scale
desc_offset, desc_scale = self.parameters['desc_scaling']
desc = (desc - desc_offset) / desc_scale
d_desc_d_pos /= desc_scale

# Convert to pytorch
# TODO (wardlt): Make it possible to convert to float32 or lower
@@ -256,13 +298,13 @@ def calculate(self, atoms: ase.Atoms = None, properties=('energy', 'forces', 'en
model.to(device)

# Run inference
offset, scale = self.parameters['energy_scaling']
eng_offset, eng_scale = self.parameters['energy_scaling']
elements = torch.from_numpy(atoms.get_atomic_numbers())
model.eval() # Ensure we're in eval mode
elements = elements.to(device)
desc = desc.to(device)
pred_energies_dist = model(elements, desc)
pred_energies = pred_energies_dist * scale + offset
pred_energies = pred_energies_dist * eng_scale + eng_offset
pred_energy = torch.sum(pred_energies)
self.results['energy'] = pred_energy.item()
self.results['energies'] = pred_energies.detach().cpu().numpy()
@@ -273,14 +315,15 @@ def calculate(self, atoms: ase.Atoms = None, properties=('energy', 'forces', 'en
# Derivatives for the descriptors are for each center (which is the input to the model) with respect to each atomic coordinate changing.
# Energy is summed over the contributions from each center.
# The total force is therefore a sum over the effect of an atom moving on all centers
# Note: Forces are scaled because pred_energy was scaled
d_energy_d_desc = torch.autograd.grad(
outputs=pred_energy,
inputs=desc,
grad_outputs=torch.ones_like(pred_energy),
)[0] # Derivative of the energy with respect to the descriptors for each center
d_desc_d_pos = d_desc_d_pos.to(device)
d_energy_d_center_d_pos = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc) # Derivative for each center with respect to each atom
pred_forces = -d_energy_d_center_d_pos.sum(dim=0) * scale # Total effect on each center from each atom
pred_forces = -d_energy_d_center_d_pos.sum(dim=0) # Total effect on each center from each atom

# Store the results
self.results['forces'] = pred_forces.detach().cpu().numpy()
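For the contraction itself, the einsum indices above read as i = center, j = atom, k = Cartesian component, l = descriptor feature; a shape-only sketch with invented dimensions:

# Shape-only sketch of the force contraction (dimensions invented)
import torch

n_centers, n_atoms, n_feat = 3, 3, 6
d_desc_d_pos = torch.randn(n_centers, n_atoms, 3, n_feat)   # d(descriptor)/d(position)
d_energy_d_desc = torch.randn(n_centers, n_feat)            # d(energy)/d(descriptor)

per_center = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc)  # (centers, atoms, 3)
forces = -per_center.sum(dim=0)                                           # summed over centers: (atoms, 3)
assert forces.shape == (n_atoms, 3)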
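A possible way to exercise the new make_nn_model helper (values and shapes invented; the descriptors are assumed to already be scaled, as the docstring requires):

# Hypothetical usage of make_nn_model; descriptors are assumed pre-scaled
import numpy as np
import torch
from jitterbug.model.dscribe.local import make_nn_model

elements = np.array([8, 1, 1])             # e.g. a single water molecule
descriptors = np.random.rand(10, 3, 6)     # (configurations, atoms, features)

model = make_nn_model(elements, descriptors, hidden_layers=(16, 16))
model.eval()
pred_energies = model(
    torch.from_numpy(elements),
    torch.from_numpy(descriptors[0, :, :]),  # one configuration at a time
)
assert pred_energies.shape == (3,)           # one contribution per atom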
34 changes: 20 additions & 14 deletions tests/models/test_soap.py
@@ -1,9 +1,9 @@
import numpy as np
import torch
from dscribe.descriptors.soap import SOAP
from pytest import mark, fixture
from pytest import fixture

from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator, DScribeLocalEnergyModel
from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator, DScribeLocalEnergyModel, make_nn_model


@fixture
@@ -27,10 +27,19 @@ def elements(train_set):
return train_set[0].get_atomic_numbers()


@mark.parametrize('use_adr', [True, False])
def test_make_model(use_adr, elements, descriptors, train_set):
model = make_gpr_model(elements, descriptors, 4, use_ard_kernel=use_adr)
@fixture(params=['gpr-ard', 'gpr', 'nn'])
def model(elements, descriptors, request):
if request.param == 'gpr':
return make_gpr_model(elements, descriptors, 4, use_ard_kernel=False)
elif request.param == 'gpr-ard':
return make_gpr_model(elements, descriptors, 4, use_ard_kernel=True)
elif request.param == 'nn':
return make_nn_model(elements, descriptors, (16, 16))
else:
raise NotImplementedError()


def test_make_model(model, elements, descriptors, train_set):
# Evaluate on a single point
model.eval()
pred_y = model(
@@ -40,14 +49,10 @@ def test_make_model(use_adr, elements, descriptors, train_set):
assert pred_y.shape == (3,) # 3 Atoms


@mark.parametrize('use_adr', [True, False])
def test_train(elements, descriptors, train_set, use_adr):
def test_train(model, elements, descriptors, train_set):
# Make the model and the training set
train_y = np.array([a.get_potential_energy() for a in train_set])
train_y -= train_y.min()
model = make_gpr_model(elements, descriptors, 4, use_ard_kernel=use_adr)
for submodel in model.models.values():
submodel.inducing_x.requires_grad = False

# Evaluate the untrained model
model.eval()
@@ -61,8 +66,8 @@ def test_train(elements, descriptors, train_set, use_adr):
mae_untrained = np.abs(error_y).mean()

# Train
losses = train_model(model, elements, descriptors, train_y, 64)
assert len(losses) == 64
losses = train_model(model, elements, descriptors, train_y, learning_rate=0.001, steps=8)
assert len(losses) == 8

# Run the evaluation
model.eval()
@@ -73,7 +78,7 @@ def test_train(elements, descriptors, train_set):
pred_y = torch.reshape(pred_y, [-1, elements.shape[0]])
error_y = pred_y.sum(axis=-1).detach().numpy() - train_y
mae_trained = np.abs(error_y).mean()
assert mae_trained < mae_untrained
assert mae_trained < mae_untrained * 1.1


def test_calculator(elements, descriptors, soap, train_set):
@@ -101,7 +106,8 @@ def test_calculator(elements, descriptors, soap, train_set):
forces = atoms.get_forces()
energies.append(atoms.get_potential_energy())
numerical_forces = calc.calculate_numerical_forces(atoms, d=1e-4)
assert np.isclose(forces[:, :2], numerical_forces[:, :2], rtol=5e-1).all() # Make them agree w/i 50% (PES is not smooth)
force_mask = np.abs(numerical_forces) > 0.1
assert np.isclose(forces[force_mask], numerical_forces[force_mask], rtol=0.1).all() # Agree w/i 10%
assert np.std(energies) > 1e-6


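The revised check in test_calculator compares analytic and numerical forces only where the numerical force is non-negligible, since a relative tolerance is meaningless for near-zero components; a self-contained illustration with invented numbers:

# Illustration of the masked force comparison (numbers invented, not from the test)
import numpy as np

forces = np.array([[0.95, 0.01, 0.00],
                   [-1.02, 0.00, 0.03]])
numerical_forces = np.array([[1.00, 0.02, 0.00],
                             [-1.00, 0.00, 0.05]])

# Compare only components with a sizable numerical force, at 10% relative tolerance
force_mask = np.abs(numerical_forces) > 0.1
assert np.isclose(forces[force_mask], numerical_forces[force_mask], rtol=0.1).all()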
