Start implementation of SchNet learner
WardLT committed Nov 13, 2024
1 parent c67aebf commit 83cb3a6
Showing 5 changed files with 285 additions and 5 deletions.
219 changes: 219 additions & 0 deletions cascade/learning/spk.py
@@ -0,0 +1,219 @@
"""Utilities for using models based on SchNet"""
from tempfile import TemporaryDirectory, NamedTemporaryFile
from typing import List, Dict
from pathlib import Path
import os

from ase.calculators.calculator import Calculator
from more_itertools import batched
from schnetpack.data import AtomsLoader, ASEAtomsData

from schnetpack import transform as trn
import schnetpack as spk
from torch import optim
import pandas as pd
import numpy as np
import torch
import ase

from .base import BaseLearnableForcefield, State


def ase_to_spkdata(atoms: List[ase.Atoms], path: Path) -> ASEAtomsData:
"""Add a list of Atoms objects to a SchNetPack database
Args:
atoms: List of Atoms objects
path: Path to the database file
Returns:
A link to the database
"""

_props = ['energy', 'forces', 'stress']
    if Path(path).exists():
        raise ValueError(f'Database already exists: {path}')
    # SchNetPack creates new databases via ``ASEAtomsData.create``; the unit
    # strings below are an assumption (ASE-style eV / Angstrom conventions)
    db = ASEAtomsData.create(
        str(path),
        distance_unit='Ang',
        property_unit_dict={'energy': 'eV', 'forces': 'eV/Ang', 'stress': 'eV/Ang/Ang/Ang'}
    )

# Get the properties as dictionaries
prop_lst = []
for a in atoms:
props = {}
# If we have the property, store it
if a.calc is not None:
calc = a.calc
for k in _props:
if k in calc.results:
props[k] = np.atleast_1d(calc.results[k])
        else:
            # If not, store a placeholder for each property
            props.update((k, np.atleast_1d([])) for k in _props)
prop_lst.append(props)
db.add_systems(prop_lst, atoms)
return db


class SchnetPackInterface(BaseLearnableForcefield):
"""Forcefield based on the SchNetPack implementation of SchNet"""

    def __init__(self, scratch_dir: Path | None = None, timeout: float | None = None):
"""
Args:
scratch_dir: Directory in which to cache converted data
timeout: Maximum training time
"""
super().__init__(scratch_dir)
self.timeout = timeout

def evaluate(self,
model_msg: bytes | State,
atoms: list[ase.Atoms],
batch_size: int = 64,
                 device: str = 'cpu') -> tuple[np.ndarray, list[np.ndarray], np.ndarray]:
# Get the message
model_msg = self.get_model(model_msg)

        # Iterate over chunks, converting as we go
converter = spk.interfaces.AtomsConverter(
neighbor_list=trn.MatScipyNeighborList(cutoff=5.0), dtype=torch.float32, device=device
)
energies = []
forces = []
stresses = []
for batch in batched(atoms, batch_size):
# Push the batch to the device
inputs = converter(list(batch))
pred = model_msg(inputs)

# Extract data
energies.extend(pred['energy'].detach().cpu().numpy().tolist())
batch_f = pred['forces'].detach().cpu().numpy()
forces.extend(np.array_split(batch_f, np.cumsum([len(a) for a in batch]))[:-1])
            stresses.append(pred['stress'].detach().cpu().numpy())

return np.array(energies), forces, np.concatenate(stresses)

def train(self,
model_msg: bytes | State,
train_data: list[ase.Atoms],
valid_data: list[ase.Atoms],
num_epochs: int,
device: str = 'cpu',
batch_size: int = 32,
learning_rate: float = 1e-3,
huber_deltas: tuple[float, float, float] = (0.5, 1, 1),
force_weight: float = 10,
stress_weight: float = 100,
reset_weights: bool = False,
**kwargs) -> tuple[bytes, pd.DataFrame]:

# Make sure the models are converted to Torch models
model_msg = self.get_model(model_msg)

# If desired, re-initialize weights
if reset_weights:
for module in model_msg.modules():
if hasattr(module, 'reset_parameters'):
module.reset_parameters()

# Start the training process
with TemporaryDirectory(dir=self.scratch_dir, prefix='spk') as td:
# Save the data to an ASE Atoms database
train_file = Path(td) / 'train_data.db'
train_db = ase_to_spkdata(train_data, train_file)
train_loader = AtomsLoader(train_db, batch_size=batch_size, shuffle=True, num_workers=8,
pin_memory=device != "cpu")

valid_file = Path(td) / 'valid_data.db'
            valid_db = ase_to_spkdata(valid_data, valid_file)
valid_loader = AtomsLoader(valid_db, batch_size=batch_size, num_workers=8, pin_memory=device != "cpu")

# Make the trainer
opt = optim.Adam(model_msg.parameters(), lr=learning_rate)

            # Trade-off between the energy and force terms in the loss
            # (force_weight and stress_weight are not yet wired in)
            rho_tradeoff = 0.9

# loss function
if huber_deltas is None:
# Use mean-squared loss
def loss(batch, result):
# compute the mean squared error on the energies
diff_energy = batch['energy'] - result['energy']
err_sq_energy = torch.mean(diff_energy ** 2)

# compute the mean squared error on the forces
diff_forces = batch['forces'] - result['forces']
err_sq_forces = torch.mean(diff_forces ** 2)

# build the combined loss function
err_sq = rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces

return err_sq
            else:
                # Use Huber loss; a stress delta is accepted but not yet used
                delta_energy, delta_force, _delta_stress = huber_deltas

                def loss(batch: Dict[str, torch.Tensor], result):
                    # compute the Huber loss on the per-atom energies
                    n_atoms = batch['_atom_mask'].sum(axis=1)
                    err_sq_energy = torch.nn.functional.huber_loss(batch['energy'] / n_atoms,
                                                                   result['energy'].float() / n_atoms,
                                                                   delta=delta_energy)

                    # compute the Huber loss on the forces
                    err_sq_forces = torch.nn.functional.huber_loss(batch['forces'], result['forces'],
                                                                   delta=delta_force)

# build the combined loss function
err_sq = rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces

return err_sq

            metrics = [
                spk.train.metrics.MeanAbsoluteError('energy'),
                spk.train.metrics.MeanAbsoluteError('forces')
            ]

            hooks = [
                spk.train.CSVHook(log_path=td, metrics=metrics),
            ]

            # Note: the Trainer utilities live under ``spk.train``, not the
            # ``trn`` (schnetpack.transform) alias imported above
            trainer = spk.train.Trainer(
                model_path=td,
                model=model_msg,
                hooks=hooks,
                loss_fn=loss,
                optimizer=opt,
                train_loader=train_loader,
                validation_loader=valid_loader,
                checkpoint_interval=num_epochs + 1  # Turns off checkpointing
            )

trainer.train(device, n_epochs=num_epochs)

# Load in the best model
model_msg = torch.load(os.path.join(td, 'best_model'), map_location='cpu')

# Load in the training results
train_results = pd.read_csv(os.path.join(td, 'log.csv'))

return self.serialize_model(model_msg), train_results

def make_calculator(self, model_msg: bytes | State, device: str) -> Calculator:
# Write model to disk
with NamedTemporaryFile(suffix='.pt') as tf:
tf.close()
tf_path = Path(tf.name)
tf_path.write_bytes(self.serialize_model(model_msg))

return spk.interfaces.SpkCalculator(
model_file=str(tf_path),
neighbor_list=spk.transform.SkinNeighborList(
cutoff_skin=2.0,
neighbor_list=spk.transform.ASENeighborList(cutoff=5.)
),
energy_unit='eV',
stress_key='stress',
device=device
)
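As a rough usage sketch of the new interface (not part of this commit: schnet stands for a SchNetPack NeuralNetworkPotential such as the one built in tests/learning/test_spk.py below, and train_data/valid_data are placeholder lists of ASE Atoms carrying energy and force results):

from cascade.learning.spk import SchnetPackInterface

mi = SchnetPackInterface()

# Fit the model; returns the serialized weights plus a log of per-epoch metrics
model_msg, log = mi.train(schnet, train_data, valid_data,
                          num_epochs=8, batch_size=32, device='cpu')

# Batched inference: per-structure energies, per-atom forces, 3x3 stresses
energies, forces, stresses = mi.evaluate(model_msg, valid_data, batch_size=64)

# Or hand the model to ASE as a calculator
atoms = valid_data[0]
atoms.calc = mi.make_calculator(model_msg, device='cpu')
print(atoms.get_potential_energy())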
8 changes: 5 additions & 3 deletions environment.yml
@@ -2,19 +2,21 @@ name: cascade
 channels:
   - defaults
   - conda-forge
+  - pytorch
 dependencies:
   # Core dependencies
-  - python==3.11
+  - python==3.10
   - matplotlib
   - scikit-learn>=1
   - jupyterlab
+  - pytorch==2.4.1
   - pandas
   - pytest
   - flake8
   - pip
 
   # Computational chemistry
-  - packmol
+  # - packmol
 
   # For nersc's jupyterlab
   - ipykernel
@@ -23,7 +23,7 @@ dependencies:
   - pip:
     - git+https://gitlab.com/ase/ase.git
     - git+https://github.com/ACEsuit/mace.git
-    - torch
+    - schnetpack
     - mlflow
     - pytorch-ignite
     - python-git-info
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -49,4 +49,8 @@ chgnet = [
 mace = [
     'mace-torch',
     'ignite'
 ]
+schnet = [
+    'schnetpack',
+    'torch<2.5'
+]
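With this extra in place, a SchNet-ready environment should be installable from a checkout with something like pip install -e '.[schnet]'; the torch<2.5 pin mirrors the pytorch==2.4.1 pin added to environment.yml above.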
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -20,8 +20,8 @@ def example_cell() -> Atoms:
 
 @fixture
 def example_data() -> list[Atoms]:
-    atoms_1 = Atoms(symbols=['H', 'He'], positions=np.zeros((2, 3)), cell=[5., 5., 5.], pbc=True)
-    atoms_2 = Atoms(symbols=['He', 'He'], positions=np.zeros((2, 3)), cell=[5., 5., 5.], pbc=True)
+    atoms_1 = Atoms(symbols=['H', 'He'], positions=np.zeros((2, 3)), cell=[2., 2., 2.], pbc=True)
+    atoms_2 = Atoms(symbols=['He', 'He'], positions=np.zeros((2, 3)), cell=[2., 2., 2.], pbc=True)
 
     atoms_1.positions[0, 0] = 3.
     atoms_2.positions[0, 0] = 3.
55 changes: 55 additions & 0 deletions tests/learning/test_spk.py
@@ -0,0 +1,55 @@
from pytest import fixture
import schnetpack as spk
import numpy as np

from cascade.learning.spk import SchnetPackInterface

@fixture
def schnet():
    # Make the input representation
cutoff = 5
pairwise_distance = spk.atomistic.PairwiseDistances() # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
schnet = spk.representation.SchNet(
n_atom_basis=32,
n_interactions=3,
radial_basis=radial_basis,
cutoff_fn=spk.nn.CosineCutoff(cutoff)
)

# Output layers
pred_energy = spk.atomistic.Atomwise(n_in=32, output_key='energy')
pred_forces = spk.atomistic.Forces(calc_stress=True)

model = spk.model.NeuralNetworkPotential(
representation=schnet,
input_modules=[spk.atomistic.Strain(), pairwise_distance],
output_modules=[pred_energy, pred_forces],
)
    return model


def test_inference(schnet, example_data):
# Delete any previous results from the example data
for atoms in example_data:
atoms.calc = None

mi = SchnetPackInterface()
energy, forces, stresses = mi.evaluate(schnet, example_data)

assert energy.shape == (2,)
for atoms, f in zip(example_data, forces):
assert f.shape == (len(atoms), 3)
assert stresses.shape == (2, 3, 3)

# Test the calculator interface
calc = mi.make_calculator(schnet, 'cpu')
atoms = example_data[0]
atoms.calc = calc
assert np.isclose(atoms.get_potential_energy(), energy[0], atol=1e-4).all()
assert np.allclose(atoms.get_forces(), forces[0], atol=1e-3)
assert np.allclose(atoms.get_stress(voigt=False), stresses[0], atol=1e-3)
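A quick note on the huber_deltas option in SchnetPackInterface.train: torch.nn.functional.huber_loss is quadratic for residuals smaller than delta and linear beyond it, which damps the influence of outlier energies and forces. A small standalone illustration (not part of the commit):

import torch
from torch.nn.functional import huber_loss

target = torch.tensor([0.0])
# Residual below delta: quadratic regime, 0.5 * 0.1 ** 2 = 0.005
print(huber_loss(torch.tensor([0.1]), target, delta=0.5))
# Residual above delta: linear regime, 0.5 * (2.0 - 0.5 * 0.5) = 0.875
print(huber_loss(torch.tensor([2.0]), target, delta=0.5))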
