Skip to content

Commit

Permalink
update version of black and rerun
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyil1230 committed Sep 10, 2024
1 parent 7720c1a commit bfadb76
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 58 deletions.
10 changes: 4 additions & 6 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def finetune(
{"global_step": global_step},
commit=False,
)
run.log(
utils.prefix_keys(metrics_dict, "train_step"), commit=False
)
run.log(utils.prefix_keys(metrics_dict, "train_step"), commit=False)
# Log learning rates.
run.log(
{
Expand Down Expand Up @@ -227,9 +225,9 @@ def run(args):
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"lr_scheduler_state_dict": lr_scheduler.state_dict()
if lr_scheduler
else None,
"lr_scheduler_state_dict": (
lr_scheduler.state_dict() if lr_scheduler else None
),
}
torch.save(
checkpoint,
Expand Down
46 changes: 2 additions & 44 deletions orb_models/forcefield/atomic_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import List, Optional

import ase
import torch
Expand All @@ -8,8 +8,7 @@

from orb_models.forcefield import featurization_utilities
from orb_models.forcefield.base import AtomGraphs
from orb_models.forcefield.property_definitions import (PROPERTIES,
PropertyDefinition)
from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition


@dataclass
Expand All @@ -27,47 +26,6 @@ class SystemConfig:
use_timestep_0: bool = True


class PropertyConfig:
"""Defines which properties should be calculated and stored on the AtomGraphs batch.
These are numerical physical properties that can be used as features/targets for a model.
"""

def __init__(
self,
node_names: Optional[List[str]] = None,
edge_names: Optional[List[str]] = None,
graph_names: Optional[List[str]] = None,
**kwargs,
) -> None:
"""Initialize PropertyConfig.
Args:
node_names: List of node property names in PROPERTIES
edge_names: List of edge property names in PROPERTIES
graph_names: List of graph property names in PROPERTIES
**kwargs: Additional keyword arguments
"""
if node_names is not None:
self.node_properties: Optional[Dict[str, PropertyDefinition]] = {
name: PROPERTIES[name] for name in node_names
}
else:
self.node_properties = None
if edge_names is not None:
self.edge_properties: Optional[Dict[str, PropertyDefinition]] = {
name: PROPERTIES[name] for name in edge_names
}
else:
self.edge_properties = None
if graph_names is not None:
self.graph_properties: Optional[Dict[str, PropertyDefinition]] = {
name: PROPERTIES[name] for name in graph_names
}
else:
self.graph_properties = None


def atom_graphs_to_ase_atoms(
graphs: AtomGraphs,
energy: Optional[torch.Tensor] = None,
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from collections import defaultdict
from copy import deepcopy
from typing import (Any, Dict, List, Mapping, NamedTuple, Optional, Sequence,
Union)
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union

import torch
import tree
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
from ase.calculators.calculator import Calculator, all_changes

from orb_models.forcefield.atomic_system import (SystemConfig,
ase_atoms_to_atom_graphs)
from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs
from orb_models.forcefield.graph_regressor import GraphRegressor


Expand Down
1 change: 1 addition & 0 deletions orb_models/forcefield/featurization_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch

# TODO(Mark): Make pynanoflann optional
from pynanoflann import KDTree as NanoKDTree
from scipy.spatial import KDTree as SciKDTree
Expand Down
3 changes: 1 addition & 2 deletions orb_models/forcefield/graph_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from orb_models.forcefield import base, segment_ops
from orb_models.forcefield.gns import _KEY, MoleculeGNS
from orb_models.forcefield.nn_util import build_mlp
from orb_models.forcefield.property_definitions import (PROPERTIES,
PropertyDefinition)
from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

global HAS_WARNED_FOR_TF32_MATMUL
Expand Down
8 changes: 6 additions & 2 deletions orb_models/forcefield/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

from orb_models.forcefield.featurization_utilities import get_device
from orb_models.forcefield.gns import MoleculeGNS
from orb_models.forcefield.graph_regressor import (EnergyHead, GraphHead,
GraphRegressor, NodeHead)
from orb_models.forcefield.graph_regressor import (
EnergyHead,
GraphHead,
GraphRegressor,
NodeHead,
)
from orb_models.forcefield.rbf import ExpNormalSmearing

global HAS_MESSAGED_FOR_TF32_MATMUL
Expand Down

0 comments on commit bfadb76

Please sign in to comment.