Skip to content

Commit

Permalink
Flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Sep 11, 2024
1 parent e47e195 commit a74c715
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions cascade/learning/mace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Interface to the """
"""Interface to the higher-order equivariant neural networks
of `Batatia et al. <https://arxiv.org/abs/2206.07697>`_"""

import logging
from io import BytesIO

Expand All @@ -7,8 +9,7 @@
import numpy as np
import pandas as pd
from ase import Atoms, data
from ignite.engine import Engine, create_supervised_evaluator, Events
from ignite.metrics import Loss
from ignite.engine import Engine, Events
from mace.data import AtomicData
from mace.data.utils import config_from_atoms
from mace.modules import WeightedHuberEnergyForcesStressLoss, ScaleShiftMACE
Expand All @@ -35,6 +36,7 @@ def atoms_to_loader(atoms: list[Atoms], batch_size: int, z_table: AtomicNumberTa
z_table: Map between atom ID in mace and periodic table
r_max: Cutoff distance
"""

def _prepare_atoms(my_atoms: Atoms):
"""MACE expects the training outputs to be stored in `info` and `arrays`"""
my_atoms.info = {
Expand All @@ -55,8 +57,8 @@ def _prepare_atoms(my_atoms: Atoms):
)


class MACEInterface(BaseLearnableForcefield):
""""""
class MACEInterface(BaseLearnableForcefield[MACEState]):
"""Interface to the `MACE library <https://github.com/ACEsuit/mace>`_"""

def evaluate(self,
model_msg: bytes | State,
Expand Down Expand Up @@ -156,6 +158,7 @@ def train(self,

# Prepare the training engine
train_losses = []

def get_loss_stats(b, y):
"""Compute the losses"""
na = batch.ptr[1:] - batch.ptr[:-1]
Expand Down Expand Up @@ -212,7 +215,6 @@ def validation_process(engine):
detailed_loss['total_loss'] = loss.item()
valid_losses.append(detailed_loss)


logger.info('Started training')
trainer.run(train_loader, max_epochs=num_epochs)
logger.info('Finished training')
Expand Down

0 comments on commit a74c715

Please sign in to comment.