Implement a per-element version of SOAP
WardLT committed Dec 25, 2023
1 parent 35a23ac commit a01614c
Showing 3 changed files with 522 additions and 105 deletions.
89 changes: 68 additions & 21 deletions jitterbug/model/dscribe/local.py
@@ -1,6 +1,7 @@
"""Create a PyTorch-based model which uses features for each atom"""
from typing import Union, Optional, Callable

+import ase
from ase.calculators.calculator import Calculator, all_changes
from dscribe.descriptors.descriptorlocal import DescriptorLocal
from torch.utils.data import TensorDataset, DataLoader
@@ -43,33 +44,69 @@ def forward(self, x) -> torch.Tensor:
        return torch.tensordot(self.alpha, esd, dims=([0], [0]))


-def make_gpr_model(train_descriptors: np.ndarray, num_inducing_points: int, use_ard_kernel: bool = False) -> InducedKernelGPR:
+class PerElementModule(torch.nn.Module):
+    """Fit a different model for each element
+    Args:
+        models: Map of atomic number to the model to use for that element
+    """
+
+    def __init__(self, models: dict[int, torch.nn.Module]):
+        super().__init__()
+        self.models = torch.nn.ModuleDict(
+            dict((str(k), v) for k, v in models.items())
+        )
+
+    def forward(self, element: torch.IntTensor, desc: torch.Tensor) -> torch.Tensor:
+        output = torch.empty((desc.shape[0],), dtype=desc.dtype)
+        for elem, model in self.models.items():
+            elem_id = int(elem)
+            mask = element == elem_id
+            output[mask] = model(desc[mask, :])
+        return output
+
+
+def make_gpr_model(elements: np.ndarray,
+                   descriptors: np.ndarray,
+                   num_inducing_points: int,
+                   use_ard_kernel: bool = False) -> PerElementModule:
"""Make the GPR model for a certain atomic system
Assumes that the descriptors have already been scaled
Args:
train_descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
num_inducing_points: Number of inducing points to use in the kernel. More points, more complex model
elements: Element for each atom in a structure (num_atoms,)
descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
num_inducing_points: Number of inducing points to use in the kernel for each model. More points, more complex model
use_ard_kernel: Whether to use a different length scale parameter for each descriptor
Returns:
Model which can predict energy given descriptors for a single configuration
"""

-    # Select a set of inducing points randomly
-    descriptors = np.reshape(train_descriptors, (-1, train_descriptors.shape[-1]))
-    num_inducing_points = min(num_inducing_points, descriptors.shape[0])
-    inducing_inds = np.random.choice(descriptors.shape[0], size=(num_inducing_points,), replace=False)
-    inducing_points = descriptors[inducing_inds, :]
+    # Make a model for each element type
+    models: dict[int, InducedKernelGPR] = {}
+    element_types = np.unique(elements)
+
+    for element in element_types:
+        # Select a set of inducing points from records of each atom
+        mask = elements == element
+        masked_descriptors = descriptors[:, mask, :]
+        masked_descriptors = np.reshape(masked_descriptors, (-1, masked_descriptors.shape[-1]))
+        num_inducing_points = min(num_inducing_points, masked_descriptors.shape[0])
+        inducing_inds = np.random.choice(masked_descriptors.shape[0], size=(num_inducing_points,), replace=False)
+        inducing_points = masked_descriptors[inducing_inds, :]
+
+        # Make the model
+        models[element] = InducedKernelGPR(
+            inducing_x=torch.from_numpy(inducing_points),
+            use_ard=use_ard_kernel,
+        )

-    # Make the model
-    return InducedKernelGPR(
-        inducing_x=torch.from_numpy(inducing_points),
-        use_ard=use_ard_kernel,
-    )
+    return PerElementModule(models)
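
For orientation, a minimal usage sketch of the new per-element pieces (not part of the commit). The water-molecule element list, array shapes, and random descriptors are invented for illustration, and the functions are assumed importable from jitterbug/model/dscribe/local.py:

import numpy as np
import torch

from jitterbug.model.dscribe.local import make_gpr_model

# Hypothetical training set: water (one O, two H), 10 configurations, 16 descriptors per atom
elements = np.array([8, 1, 1])
descriptors = np.random.randn(10, 3, 16)

# One InducedKernelGPR is fit per unique element, then wrapped for dispatch
model = make_gpr_model(elements, descriptors, num_inducing_points=4)

# forward() routes each descriptor row to the submodel registered for that atom's element
desc = torch.from_numpy(descriptors[0])  # descriptors for a single configuration
per_atom_energies = model(torch.from_numpy(elements), desc)
print(per_atom_energies.shape)  # torch.Size([3]), one prediction per atom

Keying the ModuleDict by str(atomic_number) is what lets the submodels be registered at all, since ModuleDict requires string keys; forward() converts the keys back to integers before masking.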


def train_model(model: torch.nn.Module,
+                train_e: np.ndarray,
                train_x: np.ndarray,
                train_y: np.ndarray,
                steps: int,
@@ -83,6 +120,7 @@ def train_model(model: torch.nn.Module,
    Args:
        model: Model to be trained
+        train_e: Elements for each atom in a configuration (num_atoms,)
        train_x: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
        train_y: Energies for each training point
        steps: Number of iterations over all training points
@@ -94,10 +132,14 @@ def train_model(model: torch.nn.Module,
        Mean loss over each iteration
    """

-    # Convert the data to Torches
+    # Convert the data to Tensors
    n_conf, n_atoms = train_x.shape[:2]
    train_x = torch.from_numpy(train_x)
    train_y = torch.from_numpy(train_y)
+    train_e = torch.from_numpy(train_e)
+
+    # Duplicate the elements per batch size
+    train_e = train_e.repeat(batch_size)

    # Make the data loader
    dataset = TensorDataset(train_x, train_y)
@@ -125,7 +167,7 @@ def train_model(model: torch.nn.Module,

            # Predict on all configurations
            batch_x = torch.reshape(batch_x, (-1, batch_x.shape[-1]))  # Flatten from (n_confs, n_atoms, n_desc) -> (n_confs * n_atoms, n_desc)
-            pred_y_per_atoms_flat = model(batch_x)
+            pred_y_per_atoms_flat = model(train_e, batch_x)

            # Get the mean sum for each atom
            pred_y_per_atoms = torch.reshape(pred_y_per_atoms_flat, (batch_size, n_atoms))
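
A small sketch (sizes invented) of why the train_e.repeat(batch_size) above lines up with the flattened descriptor rows: torch.reshape flattens configuration-major, so the atoms of each configuration stay contiguous, and tiling the per-atom element list once per configuration reproduces exactly that order:

import torch

batch_size, n_atoms, n_desc = 2, 3, 4
elements = torch.tensor([8, 1, 1])               # per-atom atomic numbers, invented

train_e = elements.repeat(batch_size)            # tensor([8, 1, 1, 8, 1, 1])
batch_x = torch.randn(batch_size, n_atoms, n_desc)
flat_x = torch.reshape(batch_x, (-1, n_desc))    # shape (6, 4): row i belongs to atom i % n_atoms
assert train_e.shape[0] == flat_x.shape[0]

Note this pairing assumes every batch holds exactly batch_size configurations, e.g. a loader with drop_last=True or a dataset size divisible by the batch size.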
@@ -151,7 +193,7 @@ class DScribeLocalCalculator(Calculator):
    """Calculator which uses descriptors for each atom and PyTorch to compute energy
    Keyword Args:
-        model (torch.nn.Module): A machine learning model which takes descriptors as inputs
+        model (PerElementModule): A machine learning model which takes atomic numbers and descriptors as inputs
        desc (DescriptorLocal): Tool used to compute the descriptors
        desc_scaling (tuple[np.ndarray, np.ndarray]): An offset and factor with which to adjust the energy per atom predictions,
            which are typically the mean and standard deviation of energy per atom across the training set.
@@ -160,6 +202,8 @@ class DScribeLocalCalculator(Calculator):
        device (str | torch.device): Device to use for inference
    """

+    # TODO (wardlt): Have scaling for descriptors and energies be per-element
+
    implemented_properties = ['energy', 'forces', 'energies']
    default_parameters = {
        'model': None,
@@ -169,7 +213,7 @@ class DScribeLocalCalculator(Calculator):
        'device': 'cpu'
    }

-    def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'),
+    def calculate(self, atoms: ase.Atoms = None, properties=('energy', 'forces', 'energies'),
                  system_changes=all_changes):
        # Compute the descriptors for the atoms
        d_desc_d_pos, desc = self.parameters['desc'].derivatives(atoms, attach=True)
@@ -180,9 +224,10 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'),
        d_desc_d_pos /= scale

        # Convert to pytorch
-        desc = torch.from_numpy(desc.astype(np.float32))
+        # TODO (wardlt): Make it possible to convert to float32 or lower
+        desc = torch.from_numpy(desc)
        desc.requires_grad = True
-        d_desc_d_pos = torch.from_numpy(d_desc_d_pos.astype(np.float32))
+        d_desc_d_pos = torch.from_numpy(d_desc_d_pos)

        # Move the model to device if need be
        model: torch.nn.Module = self.parameters['model']
@@ -191,9 +236,11 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'),

        # Run inference
        offset, scale = self.parameters['energy_scaling']
+        elements = torch.from_numpy(atoms.get_atomic_numbers())
        model.eval()  # Ensure we're in eval mode
+        elements = elements.to(device)
        desc = desc.to(device)
-        pred_energies_dist = model(desc)
+        pred_energies_dist = model(elements, desc)
        pred_energies = pred_energies_dist * scale + offset
        pred_energy = torch.sum(pred_energies)
        self.results['energy'] = pred_energy.item()
@@ -262,7 +309,7 @@ def train_calculator(self, data: list[Atoms]) -> Calculator:

        # Make then train the model
        model = self.model_fn(train_x)
-        train_model(model, train_x, train_y, device=self.device, **self.train_options)
+        train_model(model, data[0].get_atomic_numbers(), train_x, train_y, device=self.device, **self.train_options)

        # Make the calculator
        return DScribeLocalCalculator(
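
To close the loop, a rough sketch of how the retrained per-element calculator might be used (not part of the commit). The SOAP settings, inducing-point count, and identity scalings are placeholders; real use would supply statistics from the training set:

import numpy as np
from ase.build import molecule
from dscribe.descriptors import SOAP

from jitterbug.model.dscribe.local import DScribeLocalCalculator, make_gpr_model

atoms = molecule('H2O')
soap = SOAP(species=['H', 'O'], r_cut=4.0, n_max=4, l_max=4)  # placeholder settings

# Build an (untrained) per-element model from descriptors of one configuration
elements = atoms.get_atomic_numbers()
train_desc = soap.create(atoms)[None, :, :]  # (1, num_atoms, num_descriptors)
model = make_gpr_model(elements, train_desc, num_inducing_points=3)

calc = DScribeLocalCalculator(
    model=model,
    desc=soap,
    desc_scaling=(0., 1.),    # identity offset and scale, placeholders
    energy_scaling=(0., 1.),
    device='cpu',
)
atoms.calc = calc
print(atoms.get_potential_energy(), atoms.get_forces().shape)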