From a74c71536da24fa0c92b8e24b79c3ecc0a2891db Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Wed, 11 Sep 2024 10:02:01 -0400 Subject: [PATCH] Flake8 fixes --- cascade/learning/mace.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cascade/learning/mace.py b/cascade/learning/mace.py index 876927c..bb02413 100644 --- a/cascade/learning/mace.py +++ b/cascade/learning/mace.py @@ -1,4 +1,6 @@ -"""Interface to the """ +"""Interface to the higher-order equivariant neural networks +of `Batatia et al. `_""" + import logging from io import BytesIO @@ -7,8 +9,7 @@ import numpy as np import pandas as pd from ase import Atoms, data -from ignite.engine import Engine, create_supervised_evaluator, Events -from ignite.metrics import Loss +from ignite.engine import Engine, Events from mace.data import AtomicData from mace.data.utils import config_from_atoms from mace.modules import WeightedHuberEnergyForcesStressLoss, ScaleShiftMACE @@ -35,6 +36,7 @@ def atoms_to_loader(atoms: list[Atoms], batch_size: int, z_table: AtomicNumberTa z_table: Map between atom ID in mace and periodic table r_max: Cutoff distance """ + def _prepare_atoms(my_atoms: Atoms): """MACE expects the training outputs to be stored in `info` and `arrays`""" my_atoms.info = { @@ -55,8 +57,8 @@ def _prepare_atoms(my_atoms: Atoms): ) -class MACEInterface(BaseLearnableForcefield): - """""" +class MACEInterface(BaseLearnableForcefield[MACEState]): + """Interface to the `MACE library `_""" def evaluate(self, model_msg: bytes | State, @@ -156,6 +158,7 @@ def train(self, # Prepare the training engine train_losses = [] + def get_loss_stats(b, y): """Compute the losses""" na = batch.ptr[1:] - batch.ptr[:-1] @@ -212,7 +215,6 @@ def validation_process(engine): detailed_loss['total_loss'] = loss.item() valid_losses.append(detailed_loss) - logger.info('Started training') trainer.run(train_loader, max_epochs=num_epochs) logger.info('Finished training')