From a80d4d4d69e033f4f33d4a9a218a5f2e32b33a86 Mon Sep 17 00:00:00 2001 From: zhiyili1230 Date: Sat, 7 Sep 2024 18:41:44 +0000 Subject: [PATCH] remove ema --- finetune.py | 7 +- orb_models/dataset/ase_dataset.py | 20 +- orb_models/dataset/data_loaders.py | 20 +- orb_models/finetune_utilities/ema.py | 221 ------------------ orb_models/finetune_utilities/experiment.py | 85 ++++++- orb_models/finetune_utilities/optim.py | 13 +- orb_models/finetune_utilities/steps.py | 15 +- orb_models/forcefield/atomic_system.py | 6 +- orb_models/forcefield/base.py | 12 +- orb_models/forcefield/calculator.py | 9 +- .../forcefield/featurization_utilities.py | 3 +- orb_models/forcefield/gns.py | 7 +- orb_models/forcefield/graph_regressor.py | 15 +- orb_models/forcefield/nn_util.py | 3 +- orb_models/forcefield/pretrained.py | 6 +- orb_models/forcefield/property_definitions.py | 2 +- orb_models/forcefield/rbf.py | 1 + orb_models/forcefield/reference_energies.py | 1 + orb_models/forcefield/segment_ops.py | 3 +- orb_models/utils.py | 98 -------- pyproject.toml | 3 +- 21 files changed, 130 insertions(+), 420 deletions(-) delete mode 100644 orb_models/finetune_utilities/ema.py delete mode 100644 orb_models/utils.py diff --git a/finetune.py b/finetune.py index 32de6e4..6ef9999 100644 --- a/finetune.py +++ b/finetune.py @@ -10,7 +10,6 @@ from orb_models.finetune_utilities import experiment, optim from orb_models.dataset import data_loaders from orb_models.finetune_utilities import steps -from orb_models import utils logging.basicConfig( @@ -24,7 +23,7 @@ def run(args): Args: config (DictConfig): Config for training loop. """ - device = utils.init_device() + device = experiment.init_device() experiment.seed_everything(args.random_seed) # Make sure to use this flag for matmuls on A100 and H100 GPUs. @@ -39,7 +38,7 @@ def run(args): # Move model to correct device. model.to(device=device) - optimizer, lr_scheduler, ema = optim.get_optim(args.lr, args.max_epochs, model) + optimizer, lr_scheduler = optim.get_optim(args.lr, args.max_epochs, model) wandb_run = None # Logger instantiation/configuration @@ -80,7 +79,6 @@ def run(args): model=model, optimizer=optimizer, dataloader=train_loader, - ema=ema, lr_scheduler=lr_scheduler, clip_grad=args.gradient_clip_val, device=device, @@ -110,7 +108,6 @@ def run(args): "lr_scheduler_state_dict": lr_scheduler.state_dict() if lr_scheduler else None, - "ema_state_dict": ema.state_dict() if ema else None, } torch.save( checkpoint, diff --git a/orb_models/dataset/ase_dataset.py b/orb_models/dataset/ase_dataset.py index 4ee66b5..aff99dd 100644 --- a/orb_models/dataset/ase_dataset.py +++ b/orb_models/dataset/ase_dataset.py @@ -1,20 +1,16 @@ from pathlib import Path -from typing import Dict, Literal, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import ase import ase.db import ase.db.row +import numpy as np import torch from ase.stress import voigt_6_to_full_3x3_stress -import numpy as np from e3nn import o3 - - -from orb_models.forcefield import ( - atomic_system, - property_definitions, -) from torch.utils.data import Dataset + +from orb_models.forcefield import atomic_system, property_definitions from orb_models.forcefield.base import AtomGraphs @@ -31,11 +27,7 @@ class AseSqliteDataset(Dataset): of the dataset. system_config: A config for controlling how an atomic system is represented. target_config: A config for regression/classification targets. - evaluation: Three modes: "eval_with_noise", "eval_no_noise", "train". augmentation: If random rotation augmentation is used. - limit_size: Limit the size of the dataset to this many samples. Useful for debugging. - masking_args: Arguments for masking function. - filter_indices_path: Path to a file containing a list of indices to include in the dataset. Returns: An AseSqliteDataset. @@ -44,7 +36,7 @@ class AseSqliteDataset(Dataset): def __init__( self, name: str, - path: str, + path: Union[str, Path], system_config: Optional[atomic_system.SystemConfig] = None, target_config: Optional[atomic_system.PropertyConfig] = None, augmentation: Optional[bool] = True, @@ -205,7 +197,6 @@ def get_dataset( name: str, system_config: atomic_system.SystemConfig, target_config: atomic_system.PropertyConfig, - evaluation: Literal["eval_with_noise", "eval_no_noise", "train"] = "train", ) -> AseSqliteDataset: """Dataset factory function.""" return AseSqliteDataset( @@ -213,7 +204,6 @@ def get_dataset( name=name, system_config=system_config, target_config=target_config, - evaluation=evaluation, ) diff --git a/orb_models/dataset/data_loaders.py b/orb_models/dataset/data_loaders.py index a2950e3..80e0cfe 100644 --- a/orb_models/dataset/data_loaders.py +++ b/orb_models/dataset/data_loaders.py @@ -1,17 +1,13 @@ -import random import logging -from typing import Dict, List, Optional +import random +from typing import Any, Optional import numpy as np import torch -from torch.utils.data import ( - BatchSampler, - DataLoader, - RandomSampler, -) +from torch.utils.data import BatchSampler, DataLoader, RandomSampler -from orb_models.forcefield import base from orb_models.dataset.ase_dataset import AseSqliteDataset +from orb_models.forcefield import base from orb_models.forcefield.atomic_system import make_property_definitions_from_config HAVE_PRINTED_WORKER_INFO = False @@ -47,19 +43,19 @@ def build_train_loader( path: str, num_workers: int, batch_size: int, - augmentation: Optional[List[str]] = None, - target_config: Optional[Dict] = None, + augmentation: Optional[bool] = None, + target_config: Optional[Any] = None, **kwargs, ) -> DataLoader: """Builds the train dataloader from a config file. Args: dataset: The dataset name. + path: Dataset path. num_workers: The number of workers for each dataset. batch_size: The batch_size config for each dataset. - temperature: The temperature for temperature sampling. - Default is None for using random sampler. augmentation: If rotation augmentation is used. + target_config: The target config. Returns: The train Dataloader. diff --git a/orb_models/finetune_utilities/ema.py b/orb_models/finetune_utilities/ema.py deleted file mode 100644 index 0603e63..0000000 --- a/orb_models/finetune_utilities/ema.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -Copied (and improved) from. - -- https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py (MIT license). -- https://github.com/Open-Catalyst-Project/ocp/blob/main/ocpmodels/modules/exponential_moving_average.py - (MIT license) -""" - -from __future__ import division, unicode_literals - -import copy -import weakref -from typing import Iterable, Optional - -import torch - - -# Partially based on: -# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py -class ExponentialMovingAverage: - """ - Maintains (exponential) moving average of a set of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter` (typically from - `model.parameters()`). - decay: The exponential decay. - use_num_updates: Whether to use number of updates when computing - averages. - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - decay: float, - use_num_updates: bool = False, - ): - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - self.decay = decay - self.num_updates = 0 if use_num_updates else None - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - self.collected_params: Iterable[torch.Tensor] = [] - # By maintaining only a weakref to each parameter, - # we maintain the old GC behaviour of ExponentialMovingAverage: - # if the model goes out of scope but the ExponentialMovingAverage - # is kept, no references to the model or its parameters will be - # maintained, and the model will be cleaned up. - self._params_refs = [weakref.ref(p) for p in parameters] - - def _get_parameters( - self, parameters: Optional[Iterable[torch.nn.Parameter]] - ) -> Iterable[torch.nn.Parameter]: - if parameters is None: - parameters = [p() for p in self._params_refs] # type: ignore - if any(p is None for p in parameters): - raise ValueError( - "(One of) the parameters with which this " - "ExponentialMovingAverage " - "was initialized no longer exists (was garbage collected);" - " please either provide `parameters` explicitly or keep " - "the model to which they belong from being garbage " - "collected." - ) - return parameters - else: - return [p for p in parameters if p.requires_grad] - - def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: - """ - Update currently maintained parameters. - - Call this every time the parameters are updated, such as the result of - the `optimizer.step()` call. - - Args: - parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = self._get_parameters(parameters) - decay = self.decay - if self.num_updates is not None: - self.num_updates += 1 - decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) - one_minus_decay = 1.0 - decay - with torch.no_grad(): - for s_param, param in zip(self.shadow_params, parameters): - tmp = param - s_param - s_param.add_(tmp, alpha=one_minus_decay) - - def copy_to( - self, parameters: Optional[Iterable[torch.nn.Parameter]] = None - ) -> None: - """ - Copy current parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = self._get_parameters(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: - """ - Save the current parameters for restoring later. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. If `None`, the parameters of with which this - `ExponentialMovingAverage` was initialized will be used. - """ - parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() for param in parameters] - - def restore( - self, parameters: Optional[Iterable[torch.nn.Parameter]] = None - ) -> None: - """ - Restore the parameters stored with the `store` method. - - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = self._get_parameters(parameters) - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) - - def state_dict(self) -> dict: - r"""Returns the state of the ExponentialMovingAverage as a dict.""" - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "num_updates": self.num_updates, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r"""Loads the ExponentialMovingAverage state. - - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.num_updates = state_dict["num_updates"] - assert self.num_updates is None or isinstance( - self.num_updates, int - ), "Invalid num_updates" - - assert isinstance( - state_dict["shadow_params"], list - ), "shadow_params must be a list" - self.shadow_params = [ - p.to(self.shadow_params[i].device) - for i, p in enumerate(state_dict["shadow_params"]) - ] - assert all( - isinstance(p, torch.Tensor) for p in self.shadow_params - ), "shadow_params must all be Tensors" - - assert isinstance( - state_dict["collected_params"], list - ), "collected_params must be a list" - # collected_params is empty at initialization, - # so use shadow_params for device instead - self.collected_params = [ - p.to(self.shadow_params[i].device) - for i, p in enumerate(state_dict["collected_params"]) - ] - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" - - -class EMAContextManager: - """Context manager for ExponentialMovingAverage. - - Args: - ema: The ExponentialMovingAverage object. - """ - - def __init__(self, ema: Optional[ExponentialMovingAverage]): - self.ema = ema - - def __enter__(self): - """Store the current parameters for restoring later.""" - if self.ema is not None: - # Save current parameters to collected_params - self.ema.store() - # Copy shadow params to parameters - self.ema.copy_to() - - def __exit__(self, exc_type, exc_val, exc_tb): - """Restore the parameters stored with the `store` method.""" - if self.ema is not None: - # Copy collected_params to parameters - self.ema.restore() diff --git a/orb_models/finetune_utilities/experiment.py b/orb_models/finetune_utilities/experiment.py index 161502b..17a605e 100644 --- a/orb_models/finetune_utilities/experiment.py +++ b/orb_models/finetune_utilities/experiment.py @@ -1,26 +1,67 @@ """Experiment utilities.""" import os -import dataclasses import random -from typing import Dict, TypeVar +from collections import defaultdict +from pathlib import Path +from typing import Dict, Mapping, TypeVar +import dotenv import numpy import torch import wandb from wandb import wandb_run +from orb_models.forcefield import base + T = TypeVar("T") -@dataclasses.dataclass -class WandbArtifactTypes: - """Artifact types for wandb.""" +_V = TypeVar("_V", int, float, torch.Tensor) + +dotenv.load_dotenv(override=True) +PROJECT_ROOT: Path = Path( + os.environ.get("PROJECT_ROOT", str(Path(__file__).parent.parent)) +) +assert ( + PROJECT_ROOT.exists() +), "You must configure the PROJECT_ROOT environment variable in a .env file!" + +DATA_ROOT: Path = Path(os.environ.get("DATA_ROOT", default=str(PROJECT_ROOT / "data"))) +WANDB_ROOT: Path = Path( + os.environ.get("WANDB_ROOT", default=str(PROJECT_ROOT / "wandb")) +) + + +def init_device() -> torch.device: + """Initialize a device. + + Initializes a device, making sure to also + initialize the process group in a distributed + setting. + """ + rank = 0 + if torch.cuda.is_available(): + device = f"cuda:{rank}" + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + else: + device = "cpu" + return torch.device(device) + + +def ensure_detached(x: base.Metric) -> base.Metric: + """Ensure that the tensor is detached and on the CPU.""" + if isinstance(x, torch.Tensor): + return x.detach() + return x + - MODEL = "model" - CONFIG = "config" - DATASET = "dataset" - SCREENING = "screening" +def to_item(x: base.Metric) -> base.Metric: + """Convert a tensor to a python scalar.""" + if isinstance(x, torch.Tensor): + return x.cpu().item() + return x def prefix_keys( @@ -59,3 +100,29 @@ def init_wandb_from_config(args, job_type: str) -> wandb_run.Run: ) assert wandb.run is not None return wandb.run + + +class ScalarMetricTracker: + """Keep track of average scalar metric values.""" + + def __init__(self): + self.reset() + + def reset(self): + """Reset the AverageMetrics.""" + self.sums = defaultdict(float) + self.counts = defaultdict(int) + + def update(self, metrics: Mapping[str, base.Metric]) -> None: + """Update the metric counts with new values.""" + for k, v in metrics.items(): + if isinstance(v, torch.Tensor) and v.nelement() > 1: + continue # only track scalar metrics + if isinstance(v, torch.Tensor) and v.isnan().any(): + continue + self.sums[k] += ensure_detached(v) + self.counts[k] += 1 + + def get_metrics(self): + """Get the metric values, possibly reducing across gpu processes.""" + return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()} diff --git a/orb_models/finetune_utilities/optim.py b/orb_models/finetune_utilities/optim.py index 9f995fb..ad042a1 100644 --- a/orb_models/finetune_utilities/optim.py +++ b/orb_models/finetune_utilities/optim.py @@ -1,10 +1,9 @@ +import logging import re from typing import Any, Dict, List, Mapping, Optional, Tuple, Union -import logging import omegaconf import torch -from orb_models.finetune_utilities.ema import ExponentialMovingAverage as EMA Metric = Union[torch.Tensor, int, float] MetricCollection = Union[Metric, Mapping[str, Metric]] @@ -115,11 +114,7 @@ def make_parameter_groups( def get_optim( lr: float, max_epoch: int, model: torch.nn.Module -) -> Tuple[ - torch.optim.Optimizer, - Optional[torch.optim.lr_scheduler._LRScheduler], - Optional[EMA], -]: +) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler],]: """Configure optimizers, LR schedulers and EMA.""" parameter_groups = [ { @@ -131,7 +126,5 @@ def get_optim( opt = torch.optim.Adam(params, lr=lr) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max_epoch) - ema_decay = 0.999 - ema = EMA(model.parameters(), ema_decay) - return opt, scheduler, ema + return opt, scheduler diff --git a/orb_models/finetune_utilities/steps.py b/orb_models/finetune_utilities/steps.py index 93410b2..bd4bf63 100644 --- a/orb_models/finetune_utilities/steps.py +++ b/orb_models/finetune_utilities/steps.py @@ -2,15 +2,12 @@ import torch import tqdm +import wandb from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader +from wandb import wandb_run -import wandb from orb_models.finetune_utilities import experiment -from orb_models.finetune_utilities.ema import EMAContextManager as EMA -from orb_models.utils import ScalarMetricTracker - -from wandb import wandb_run def gradient_clipping( @@ -49,7 +46,6 @@ def fintune( model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: DataLoader, - ema: Optional[EMA] = None, lr_scheduler: Optional[_LRScheduler] = None, num_steps: Optional[int] = None, clip_grad: Optional[float] = None, @@ -63,7 +59,6 @@ def fintune( model: The model to optimize. optimizer: The optimizer for the model. dataloader: A Pytorch Dataloader, which may be infinite if num_steps is passed. - ema: Optional, an Exponential Moving Average tracker for saving averaged model weights. lr_scheduler: Optional, a Learning rate scheduler for modifying the learning rate. num_steps: The number of training steps to take. This is required for distributed training, because controlling parallism is easier if all processes take exactly the same number of steps ( @@ -81,7 +76,7 @@ def fintune( if clip_grad is not None: hook_handles = gradient_clipping(model, clip_grad) - metrics = ScalarMetricTracker() + metrics = experiment.ScalarMetricTracker() # Set the model to "train" mode. model.train() @@ -137,10 +132,6 @@ def fintune( if lr_scheduler is not None: lr_scheduler.step() - # Update moving averages - if ema is not None: - ema.update() - metrics.update(step_metrics) if i != 0 and i % log_freq == 0: diff --git a/orb_models/forcefield/atomic_system.py b/orb_models/forcefield/atomic_system.py index 3f45c82..877801d 100644 --- a/orb_models/forcefield/atomic_system.py +++ b/orb_models/forcefield/atomic_system.py @@ -1,14 +1,14 @@ -from typing import Optional, List, Dict from dataclasses import dataclass +from typing import Dict, List, Optional import ase +import torch from ase import constraints from ase.calculators.singlepoint import SinglePointCalculator -from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition from orb_models.forcefield import featurization_utilities from orb_models.forcefield.base import AtomGraphs -import torch +from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition @dataclass diff --git a/orb_models/forcefield/base.py b/orb_models/forcefield/base.py index 0d28135..63b9422 100644 --- a/orb_models/forcefield/base.py +++ b/orb_models/forcefield/base.py @@ -2,16 +2,8 @@ from collections import defaultdict from copy import deepcopy -from typing import ( - Any, - Dict, - Union, - NamedTuple, - Mapping, - Optional, - List, - Sequence, -) +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union + import torch import tree diff --git a/orb_models/forcefield/calculator.py b/orb_models/forcefield/calculator.py index 0ea3b09..d785438 100644 --- a/orb_models/forcefield/calculator.py +++ b/orb_models/forcefield/calculator.py @@ -1,10 +1,9 @@ -from ase.calculators.calculator import Calculator, all_changes from typing import Optional + import torch -from orb_models.forcefield.atomic_system import ( - SystemConfig, - ase_atoms_to_atom_graphs, -) +from ase.calculators.calculator import Calculator, all_changes + +from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs from orb_models.forcefield.graph_regressor import GraphRegressor diff --git a/orb_models/forcefield/featurization_utilities.py b/orb_models/forcefield/featurization_utilities.py index ca3de1b..7ae4e6b 100644 --- a/orb_models/forcefield/featurization_utilities.py +++ b/orb_models/forcefield/featurization_utilities.py @@ -1,6 +1,6 @@ """Featurization utilities for molecular models.""" -from typing import Callable, Tuple, Union, Optional +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -9,7 +9,6 @@ from pynanoflann import KDTree as NanoKDTree from scipy.spatial import KDTree as SciKDTree - DistanceFeaturizer = Callable[[torch.Tensor], torch.Tensor] diff --git a/orb_models/forcefield/gns.py b/orb_models/forcefield/gns.py index bb3cd12..11adcb6 100644 --- a/orb_models/forcefield/gns.py +++ b/orb_models/forcefield/gns.py @@ -2,12 +2,13 @@ from collections import OrderedDict from typing import List, Literal + +import numpy as np import torch from torch import nn -import numpy as np -from orb_models.forcefield import base + +from orb_models.forcefield import base, segment_ops from orb_models.forcefield.nn_util import build_mlp -from orb_models.forcefield import segment_ops _KEY = "feat" diff --git a/orb_models/forcefield/graph_regressor.py b/orb_models/forcefield/graph_regressor.py index 26a93e0..f4c1a0d 100644 --- a/orb_models/forcefield/graph_regressor.py +++ b/orb_models/forcefield/graph_regressor.py @@ -1,17 +1,14 @@ -from typing import Literal, Optional, Dict, Tuple, Union +from typing import Dict, Literal, Optional, Tuple, Union + +import numpy import torch import torch.nn as nn -import numpy -from orb_models.forcefield.property_definitions import ( - PROPERTIES, - PropertyDefinition, -) -from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES -from orb_models.forcefield import base +from orb_models.forcefield import base, segment_ops from orb_models.forcefield.gns import _KEY, MoleculeGNS from orb_models.forcefield.nn_util import build_mlp -from orb_models.forcefield import segment_ops +from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition +from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES global HAS_WARNED_FOR_TF32_MATMUL HAS_WARNED_FOR_TF32_MATMUL = False diff --git a/orb_models/forcefield/nn_util.py b/orb_models/forcefield/nn_util.py index 48cde63..6b24cc1 100644 --- a/orb_models/forcefield/nn_util.py +++ b/orb_models/forcefield/nn_util.py @@ -1,8 +1,9 @@ """Shared neural net utility functions.""" +from typing import List, Optional, Type + import torch import torch.nn.functional as F -from typing import List, Optional, Type from torch import nn from torch.utils.checkpoint import checkpoint_sequential diff --git a/orb_models/forcefield/pretrained.py b/orb_models/forcefield/pretrained.py index 28f4c7a..ecb826e 100644 --- a/orb_models/forcefield/pretrained.py +++ b/orb_models/forcefield/pretrained.py @@ -1,15 +1,17 @@ # flake8: noqa: E501 from typing import Union + import torch from cached_path import cached_path + from orb_models.forcefield.featurization_utilities import get_device +from orb_models.forcefield.gns import MoleculeGNS from orb_models.forcefield.graph_regressor import ( EnergyHead, - NodeHead, GraphHead, GraphRegressor, + NodeHead, ) -from orb_models.forcefield.gns import MoleculeGNS from orb_models.forcefield.rbf import ExpNormalSmearing global HAS_MESSAGED_FOR_TF32_MATMUL diff --git a/orb_models/forcefield/property_definitions.py b/orb_models/forcefield/property_definitions.py index 57f9927..0fd8bc5 100644 --- a/orb_models/forcefield/property_definitions.py +++ b/orb_models/forcefield/property_definitions.py @@ -1,7 +1,7 @@ """Classes that define prediction targets.""" from dataclasses import dataclass -from typing import Any, Callable, Dict, Literal, Tuple, Union, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import ase.data import ase.db diff --git a/orb_models/forcefield/rbf.py b/orb_models/forcefield/rbf.py index 61e7b40..266bc3b 100644 --- a/orb_models/forcefield/rbf.py +++ b/orb_models/forcefield/rbf.py @@ -1,4 +1,5 @@ import math + import torch diff --git a/orb_models/forcefield/reference_energies.py b/orb_models/forcefield/reference_energies.py index c3837d3..f3a75c7 100644 --- a/orb_models/forcefield/reference_energies.py +++ b/orb_models/forcefield/reference_energies.py @@ -1,4 +1,5 @@ from typing import NamedTuple + import numpy diff --git a/orb_models/forcefield/segment_ops.py b/orb_models/forcefield/segment_ops.py index 4af5ae9..c25efbc 100644 --- a/orb_models/forcefield/segment_ops.py +++ b/orb_models/forcefield/segment_ops.py @@ -1,6 +1,7 @@ -import torch from typing import Optional +import torch + TORCHINT = [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8] diff --git a/orb_models/utils.py b/orb_models/utils.py deleted file mode 100644 index 91ecdf0..0000000 --- a/orb_models/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -from collections import defaultdict -from pathlib import Path - -import dotenv - -from typing import TypeVar, Union, Dict, Mapping -import torch -from orb_models.forcefield import base - -_V = TypeVar("_V", int, float, torch.Tensor) - -dotenv.load_dotenv(override=True) -PROJECT_ROOT: Path = Path( - os.environ.get("PROJECT_ROOT", str(Path(__file__).parent.parent)) -) -assert ( - PROJECT_ROOT.exists() -), "You must configure the PROJECT_ROOT environment variable in a .env file!" - -DATA_ROOT: Path = Path(os.environ.get("DATA_ROOT", default=str(PROJECT_ROOT / "data"))) -WANDB_ROOT: Path = Path( - os.environ.get("WANDB_ROOT", default=str(PROJECT_ROOT / "wandb")) -) - - -def int_to_device(device: Union[int, torch.device]) -> torch.device: - """Converts an integer to a torch device.""" - if isinstance(device, torch.device): - return device - if device < 0: - return torch.device("cpu") - return torch.device(device) - - -def init_device() -> torch.device: - """Initialize a device. - - Initializes a device, making sure to also - initialize the process group in a distributed - setting. - """ - rank = 0 - if torch.cuda.is_available(): - device = f"cuda:{rank}" - torch.cuda.set_device(rank) - torch.cuda.empty_cache() - else: - device = "cpu" - return torch.device(device) - - -def tqdm_desc_from_metrics(metrics: Dict[str, float]) -> str: - """Create a tqdm progress bar description from a dict of metrics.""" - return ( - ", ".join(["%s: %.4f" % (name, value) for name, value in metrics.items()]) - + " ||" - ) - - -class ScalarMetricTracker: - """Keep track of average scalar metric values.""" - - def __init__(self): - self.reset() - - def reset(self): - """Reset the AverageMetrics.""" - self.sums = defaultdict(float) - self.counts = defaultdict(int) - - def update(self, metrics: Mapping[str, base.Metric]) -> None: - """Update the metric counts with new values.""" - for k, v in metrics.items(): - if isinstance(v, torch.Tensor) and v.nelement() > 1: - continue # only track scalar metrics - if isinstance(v, torch.Tensor) and v.isnan().any(): - continue - self.sums[k] += ensure_detached(v) - self.counts[k] += 1 - - def get_metrics(self): - """Get the metric values, possibly reducing across gpu processes.""" - return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()} - - -def ensure_detached(x: base.Metric) -> base.Metric: - """Ensure that the tensor is detached and on the CPU.""" - if isinstance(x, torch.Tensor): - return x.detach() - return x - - -def to_item(x: base.Metric) -> base.Metric: - """Convert a tensor to a python scalar.""" - if isinstance(x, torch.Tensor): - return x.cpu().item() - return x diff --git a/pyproject.toml b/pyproject.toml index e544395..02ddc48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,8 @@ dependencies = [ "scipy>=1.13.1", "torch==2.2.0", "dm-tree>=0.1.8", - "e3nn==0.4.4" + "e3nn==0.4.4", + "tqdm>=4.66.5", ] [build-system]