Skip to content

Commit

Permalink
revert epoch saving
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyil1230 committed Sep 11, 2024
1 parent 6194b8c commit 213a66f
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def run(args):
wandb.run.log({"epoch": epoch}, commit=True)

# Save checkpoint from last epoch
if epoch == 1:
# cerate ckpts folder if it does not exist
if epoch == args.max_epochs - 1:
# create ckpts folder if it does not exist
if not os.path.exists(args.checkpoint_path):
os.makedirs(args.checkpoint_path)
torch.save(
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from collections import defaultdict
from copy import deepcopy
from typing import (Any, Dict, List, Mapping, NamedTuple, Optional, Sequence,
Union)
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union

import torch
import tree
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
from ase.calculators.calculator import Calculator, all_changes

from orb_models.forcefield.atomic_system import (SystemConfig,
ase_atoms_to_atom_graphs)
from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs
from orb_models.forcefield.graph_regressor import GraphRegressor


Expand Down
1 change: 1 addition & 0 deletions orb_models/forcefield/featurization_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch

# TODO(Mark): Make pynanoflann optional
from pynanoflann import KDTree as NanoKDTree
from scipy.spatial import KDTree as SciKDTree
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/graph_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from orb_models.forcefield import base, segment_ops
from orb_models.forcefield.gns import _KEY, MoleculeGNS
from orb_models.forcefield.nn_util import build_mlp
from orb_models.forcefield.property_definitions import (PROPERTIES,
PropertyDefinition)
from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

global HAS_WARNED_FOR_TF32_MATMUL
Expand Down
8 changes: 6 additions & 2 deletions orb_models/forcefield/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

from orb_models.forcefield.featurization_utilities import get_device
from orb_models.forcefield.gns import MoleculeGNS
from orb_models.forcefield.graph_regressor import (EnergyHead, GraphHead,
GraphRegressor, NodeHead)
from orb_models.forcefield.graph_regressor import (
EnergyHead,
GraphHead,
GraphRegressor,
NodeHead,
)
from orb_models.forcefield.rbf import ExpNormalSmearing

global HAS_MESSAGED_FOR_TF32_MATMUL
Expand Down

0 comments on commit 213a66f

Please sign in to comment.