diff --git a/finetune.py b/finetune.py index af79ceb..767b1a7 100644 --- a/finetune.py +++ b/finetune.py @@ -266,8 +266,8 @@ def run(args): wandb.run.log({"epoch": epoch}, commit=True) # Save checkpoint from last epoch - if epoch == 1: - # cerate ckpts folder if it does not exist + if epoch == args.max_epochs - 1: + # create ckpts folder if it does not exist if not os.path.exists(args.checkpoint_path): os.makedirs(args.checkpoint_path) torch.save( diff --git a/orb_models/forcefield/base.py b/orb_models/forcefield/base.py index 499fc53..63b9422 100644 --- a/orb_models/forcefield/base.py +++ b/orb_models/forcefield/base.py @@ -2,8 +2,7 @@ from collections import defaultdict from copy import deepcopy -from typing import (Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, - Union) +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union import torch import tree diff --git a/orb_models/forcefield/calculator.py b/orb_models/forcefield/calculator.py index cab0ace..d785438 100644 --- a/orb_models/forcefield/calculator.py +++ b/orb_models/forcefield/calculator.py @@ -3,8 +3,7 @@ import torch from ase.calculators.calculator import Calculator, all_changes -from orb_models.forcefield.atomic_system import (SystemConfig, - ase_atoms_to_atom_graphs) +from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs from orb_models.forcefield.graph_regressor import GraphRegressor diff --git a/orb_models/forcefield/featurization_utilities.py b/orb_models/forcefield/featurization_utilities.py index a90812d..7ae4e6b 100644 --- a/orb_models/forcefield/featurization_utilities.py +++ b/orb_models/forcefield/featurization_utilities.py @@ -4,6 +4,7 @@ import numpy as np import torch + # TODO(Mark): Make pynanoflann optional from pynanoflann import KDTree as NanoKDTree from scipy.spatial import KDTree as SciKDTree diff --git a/orb_models/forcefield/graph_regressor.py b/orb_models/forcefield/graph_regressor.py index 73d83d2..167523e 100644 --- a/orb_models/forcefield/graph_regressor.py +++ b/orb_models/forcefield/graph_regressor.py @@ -7,8 +7,7 @@ from orb_models.forcefield import base, segment_ops from orb_models.forcefield.gns import _KEY, MoleculeGNS from orb_models.forcefield.nn_util import build_mlp -from orb_models.forcefield.property_definitions import (PROPERTIES, - PropertyDefinition) +from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES global HAS_WARNED_FOR_TF32_MATMUL diff --git a/orb_models/forcefield/pretrained.py b/orb_models/forcefield/pretrained.py index f8ddced..ecb826e 100644 --- a/orb_models/forcefield/pretrained.py +++ b/orb_models/forcefield/pretrained.py @@ -6,8 +6,12 @@ from orb_models.forcefield.featurization_utilities import get_device from orb_models.forcefield.gns import MoleculeGNS -from orb_models.forcefield.graph_regressor import (EnergyHead, GraphHead, - GraphRegressor, NodeHead) +from orb_models.forcefield.graph_regressor import ( + EnergyHead, + GraphHead, + GraphRegressor, + NodeHead, +) from orb_models.forcefield.rbf import ExpNormalSmearing global HAS_MESSAGED_FOR_TF32_MATMUL