diff --git a/README.md b/README.md
index 1913deb..3acbdf5 100644
--- a/README.md
+++ b/README.md
@@ -38,8 +38,8 @@ For more information on the models, please see the [MODELS.md](MODELS.md) file.
 import ase
 from ase.build import bulk
-from orb_models.forcefield import pretrained
-from orb_models.forcefield import atomic_system
+
+from orb_models.forcefield import atomic_system, pretrained
 from orb_models.forcefield.base import batch_graphs
 
 orbff = pretrained.orb_v1()
@@ -65,10 +65,10 @@ atoms = atomic_system.atom_graphs_to_ase_atoms(
 ```python
 import ase
 from ase.build import bulk
+
 from orb_models.forcefield import pretrained
 from orb_models.forcefield.calculator import ORBCalculator
-
 device="cpu" # or device="cuda"
 orbff = pretrained.orb_v1(device=device) # or choose another model using ORB_PRETRAINED_MODELS[model_name]()
 calc = ORBCalculator(orbff, device=device)
diff --git a/finetune.py b/finetune.py
index 78b63fc..bf00c51 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1,30 +1,160 @@
 """Finetuning loop."""
-import os
-import logging
 import argparse
-import time
+import logging
+import os
+from typing import Optional, Union, cast
 
 import torch
-from orb_models.forcefield import pretrained
-from orb_models.finetune_utilities import experiment, optim
-from orb_models.dataset import data_loaders
-from orb_models.finetune_utilities import steps
+import tqdm
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+
+import wandb
+from orb_models.dataset import data_loaders
+from orb_models import utils
+from orb_models.forcefield import pretrained
+from wandb import wandb_run
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 
 
+def finetune(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    dataloader: DataLoader,
+    lr_scheduler: Optional[_LRScheduler] = None,
+    num_steps: Optional[int] = None,
+    clip_grad: Optional[float] = None,
+    log_freq: float = 10,
+    device: torch.device = torch.device("cpu"),
+    epoch: int = 0,
+):
+    """Train for a fixed number of steps.
+
+    Args:
+        model: The model to optimize.
+        optimizer: The optimizer for the model.
+        dataloader: A Pytorch Dataloader, which may be infinite if num_steps is passed.
+        lr_scheduler: Optional, a learning rate scheduler for modifying the learning rate.
+        num_steps: The number of training steps to take. This is required for distributed training,
+            because controlling parallelism is easier if all processes take exactly the same number
+            of steps (this particularly applies when using dynamic batching).
+        clip_grad: Optional, the gradient clipping threshold.
+        log_freq: The logging frequency for step metrics.
+        device: The device to use for training.
+        epoch: The number of epochs the model has already been finetuned for.
+
+    Returns:
+        A dictionary of metrics.
+    """
+    run: Optional[wandb_run.Run] = cast(Optional[wandb_run.Run], wandb.run)
+
+    if clip_grad is not None:
+        hook_handles = utils.gradient_clipping(model, clip_grad)
+
+    metrics = utils.ScalarMetricTracker()
+
+    # Set the model to "train" mode.
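+    # Layers such as dropout and batch/layer norm behave differently in train
+    # and eval mode, so this call matters even though no parameters change here.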
+    model.train()
+
+    # Get tqdm for the training batches
+    batch_generator = iter(dataloader)
+    num_training_batches: Union[int, float]
+    if num_steps is not None:
+        num_training_batches = num_steps
+    else:
+        try:
+            num_training_batches = len(dataloader)
+        except TypeError:
+            raise ValueError("Dataloader has no length, you must specify num_steps.")
+
+    batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches)
+
+    i = 0
+    batch_iterator = iter(batch_generator_tqdm)
+    while True:
+        if num_steps and i == num_steps:
+            break
+
+        optimizer.zero_grad(set_to_none=True)
+
+        step_metrics = {
+            "batch_size": 0.0,
+            "batch_num_edges": 0.0,
+            "batch_num_nodes": 0.0,
+        }
+
+        # Reset metrics so that raw values are reported for each step, while still
+        # averaging over the gradient accumulation.
+        if i % log_freq == 0:
+            metrics.reset()
+
+        batch = next(batch_iterator)
+        batch = batch.to(device)
+        step_metrics["batch_size"] += len(batch.n_node)
+        step_metrics["batch_num_edges"] += batch.n_edge.sum()
+        step_metrics["batch_num_nodes"] += batch.n_node.sum()
+
+        with torch.cuda.amp.autocast(enabled=False):
+            batch_outputs = model.loss(batch)
+            loss = batch_outputs.loss
+            metrics.update(batch_outputs.log)
+            if torch.isnan(loss):
+                raise ValueError("nan loss encountered")
+            loss.backward()
+
+        optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metrics.update(step_metrics)
+
+        if i != 0 and i % log_freq == 0:
+            metrics_dict = metrics.get_metrics()
+            if run is not None:
+                global_step = (epoch * num_training_batches) + i
+                if run.sweep_id is not None:
+                    run.log(
+                        {"loss": metrics_dict["loss"]},
+                        commit=False,
+                    )
+                run.log(
+                    {"global_step": global_step},
+                    commit=False,
+                )
+                run.log(
+                    utils.prefix_keys(metrics_dict, "train_step"), commit=False
+                )
+                # Log learning rates.
+                run.log(
+                    {
+                        f"pg_{idx}": group["lr"]
+                        for idx, group in enumerate(optimizer.param_groups)
+                    },
+                )
+
+        # Finished a single full step!
+        i += 1
+
+    if clip_grad is not None:
+        for h in hook_handles:
+            h.remove()
+
+    return metrics.get_metrics()
+
+
 def run(args):
     """Training Loop.
 
     Args:
        config (DictConfig): Config for training loop.
     """
-    device = experiment.init_device()
-    experiment.seed_everything(args.random_seed)
+    device = utils.init_device()
+    utils.seed_everything(args.random_seed)
 
     # Make sure to use this flag for matmuls on A100 and H100 GPUs.
     torch.set_float32_matmul_precision("high")
@@ -39,7 +169,7 @@ def run(args):
 
     # Move model to correct device.
     model.to(device=device)
     total_steps = args.max_epochs * args.num_steps
-    optimizer, lr_scheduler = optim.get_optim(args.lr, total_steps, model)
+    optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model)
 
     wandb_run = None
     # Logger instantiation/configuration
@@ -47,14 +177,13 @@ def run(args):
         import wandb
 
         logging.info("Instantiating WandbLogger.")
-        wandb_run = experiment.init_wandb_from_config(args, job_type="finetuning")
+        wandb_run = utils.init_wandb_from_config(job_type="finetuning")
 
         wandb.define_metric("global_step")
         wandb.define_metric("epochs")
         wandb.define_metric("train_step/*", step_metric="global_step")
         wandb.define_metric("learning_rates/*", step_metric="global_step")
         wandb.define_metric("finetune/*", step_metric="epochs")
-        wandb.define_metric("key-metrics/*", step_metric="epochs")
 
     loader_args = dict(
         dataset=args.dataset,
@@ -65,7 +194,7 @@ def run(args):
     )
     train_loader = data_loaders.build_train_loader(
         **loader_args,
-        augmentation=getattr(args, "augmentation", True),
+        augmentation=True,
     )
 
     logging.info("Starting training!")
@@ -75,8 +204,7 @@ def run(args):
     for epoch in range(start_epoch, args.max_epochs):
         print(f"Start epoch: {epoch} training...")
-        t1 = time.time()
-        avg_train_metrics = steps.fintune(
+        avg_train_metrics = finetune(
             model=model,
             optimizer=optimizer,
             dataloader=train_loader,
@@ -86,18 +214,10 @@ def run(args):
             num_steps=num_steps,
             epoch=epoch,
         )
-        t2 = time.time()
-        train_times = {}
-        train_times["avg_time_per_step"] = (t2 - t1) / num_steps
-        train_times["total_time"] = t2 - t1
 
         if args.wandb:
             wandb.run.log(
-                experiment.prefix_keys(avg_train_metrics, "finetune"), commit=False
-            )
-            wandb.run.log(
-                experiment.prefix_keys(train_times, "finetune", sep="-"),
-                commit=False,
+                utils.prefix_keys(avg_train_metrics, "finetune"), commit=False
             )
             wandb.run.log({"epoch": epoch}, commit=True)
 
@@ -147,7 +267,7 @@ def main():
         "--num_workers", default=8, type=int, help="Number of workers for data loader."
     )
     parser.add_argument(
-        "--batch_size", default=100, type=int, help="Batch size for finetuning."
+        "--batch_size", default=10, type=int, help="Batch size for finetuning."
     )
     parser.add_argument(
         "--gradient_clip_val", default=0.5, type=float, help="Gradient clip value."
diff --git a/internal/check.py b/internal/check.py
index 6754977..0dd313e 100644
--- a/internal/check.py
+++ b/internal/check.py
@@ -1,14 +1,13 @@
 """Integration tests to check compatibility of outputs with internal OM models."""
-import torch
-import ase
+import argparse
 
-from orb_models.forcefield import pretrained
-from orb_models.forcefield import atomic_system
-from core.models import load
+import ase
+import torch
 from core.dataset import atomic_system as core_atomic_system
+from core.models import load
 
-import argparse
+from orb_models.forcefield import atomic_system, pretrained
 
 
 def main(model: str, core_model: str):
diff --git a/orb_models/dataset/ase_dataset.py b/orb_models/dataset/ase_dataset.py
index aff99dd..e6aa831 100644
--- a/orb_models/dataset/ase_dataset.py
+++ b/orb_models/dataset/ase_dataset.py
@@ -38,7 +38,7 @@ def __init__(
         name: str,
         path: Union[str, Path],
         system_config: Optional[atomic_system.SystemConfig] = None,
-        target_config: Optional[atomic_system.PropertyConfig] = None,
+        target_config: Optional[Dict] = None,
         augmentation: Optional[bool] = True,
     ):
         super().__init__()
@@ -64,9 +64,20 @@ def __getitem__(self, idx) -> AtomGraphs:
         # Sqlite db is 1 indexed.
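+        # e.g. dataset index 0 maps to database row id 1.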
         row = self.db.get(idx + 1)
         atoms = row.toatoms()
-        extra_feats = self._get_row_properties(row, self.feature_config)
-        extra_targets = self._get_row_properties(row, self.target_config)
-
+        node_properties = property_definitions.get_property_from_row(
+            self.target_config["node"], row
+        )
+        graph_property_dict = {}
+        for target_property in self.target_config["graph"]:
+            system_properties = property_definitions.get_property_from_row(
+                target_property, row
+            )
+            graph_property_dict[target_property] = system_properties
+        extra_targets = {
+            "node": {"forces": node_properties},
+            "edge": {},
+            "graph": graph_property_dict,
+        }
         if self.augmentation:
             atoms, extra_targets = random_rotations_with_properties(atoms, extra_targets)  # type: ignore
@@ -75,9 +86,7 @@ def __getitem__(self, idx) -> AtomGraphs:
             system_id=idx,
             brute_force_knn=False,
         )
-        atom_graph = self._add_extra_feats_and_targets(
-            atom_graph, extra_feats, extra_targets
-        )
+        atom_graph = self._add_extra_targets(atom_graph, extra_targets)
 
         return atom_graph
 
@@ -91,10 +100,6 @@ def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]:
         row = self.db.get(idx + 1)
         return row.toatoms(), row.data
 
-    def get_idx_to_natoms(self) -> Dict[int, int]:
-        """Return a mapping between dataset index and number of atoms."""
-        return self.db.get_idx_to_natoms(zero_index=True)
-
     def __len__(self) -> int:
         """Return the dataset length."""
         return len(self.db)
@@ -103,70 +108,17 @@ def __repr__(self) -> str:
         """String representation of class."""
         return f"AseSqliteDataset({self.name=}, {self.path=})"
 
-    def _get_row_properties(
-        self,
-        row: ase.db.row.AtomsRow,
-        property_config: Optional[atomic_system.PropertyConfig] = None,
-    ) -> Dict:
-        """Extract numerical properties from the db as tensors, to be used as features/targets.
-
-        Applies extraction function (e.g. extract from metadata) and normalisation
-
-        Args:
-            row: Database row
-            property_config: The config specifying how to extract the property/target.
-
-        Returns:
-            ExtrinsicProperties containing the tensors for the row.
-        """
-        if property_config is None:
-            return {"node": {}, "edge": {}, "graph": {}}
-
-        def _get_properties(
-            property_definitions: Optional[
-                Dict[str, property_definitions.PropertyDefinition]
-            ],
-        ) -> Dict[str, torch.Tensor]:
-            kwargs = {}
-            if property_definitions is not None:
-                for key, definition in property_definitions.items():
-                    if definition.row_to_property_fn is not None:
-                        property_tensor = definition.row_to_property_fn(
-                            row=row, dataset=self.name
-                        )
-                        kwargs[key] = property_tensor
-            return kwargs
-
-        node_properties = _get_properties(property_config.node_properties)
-        edge_properties = _get_properties(property_config.edge_properties)
-        system_properties = _get_properties(property_config.graph_properties)
-        return {
-            "node": node_properties,
-            "edge": edge_properties,
-            "graph": system_properties,
-        }
-
-    def _add_extra_feats_and_targets(
+    def _add_extra_targets(
         self,
         atom_graph: AtomGraphs,
-        extra_feats: Dict[str, Dict],
         extra_targets: Dict[str, Dict],
     ):
         """Add extra features and targets to the AtomGraphs object.
 
         Args:
             atom_graph: AtomGraphs object to add extra features and targets to.
-            extra_feats: Dictionary of extra features with keys
             extra_targets: Dictionary of extra targets to add.
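+
+        Returns:
+            The AtomGraphs object with the extra targets merged into its
+            node, edge and system target dictionaries.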
""" - node_feats = {**atom_graph.node_features, **extra_feats["node"]} - edge_feats = {**atom_graph.edge_features, **extra_feats["edge"]} - - system_feats = ( - atom_graph.system_features if atom_graph.system_features is not None else {} - ) - system_feats = {**system_feats, **extra_feats["graph"]} - node_targets = ( atom_graph.node_targets if atom_graph.node_targets is not None else {} ) @@ -183,30 +135,12 @@ def _add_extra_feats_and_targets( system_targets = {**system_targets, **extra_targets["graph"]} return atom_graph._replace( - node_features=node_feats, - edge_features=edge_feats, - system_features=system_feats, node_targets=node_targets if node_targets != {} else None, edge_targets=edge_targets if edge_targets != {} else None, system_targets=system_targets if system_targets != {} else None, ) -def get_dataset( - path: Union[str, Path], - name: str, - system_config: atomic_system.SystemConfig, - target_config: atomic_system.PropertyConfig, -) -> AseSqliteDataset: - """Dataset factory function.""" - return AseSqliteDataset( - path=path, - name=name, - system_config=system_config, - target_config=target_config, - ) - - def random_rotations_with_properties( atoms: ase.Atoms, properties: dict ) -> Tuple[ase.Atoms, dict]: diff --git a/orb_models/dataset/data_loaders.py b/orb_models/dataset/data_loaders.py index 80e0cfe..4ac41da 100644 --- a/orb_models/dataset/data_loaders.py +++ b/orb_models/dataset/data_loaders.py @@ -8,9 +8,6 @@ from orb_models.dataset.ase_dataset import AseSqliteDataset from orb_models.forcefield import base -from orb_models.forcefield.atomic_system import make_property_definitions_from_config - -HAVE_PRINTED_WORKER_INFO = False def worker_init_fn(id: int): @@ -32,10 +29,6 @@ def worker_init_fn(id: int): ss = np.random.SeedSequence([uint64_seed]) np.random.seed(ss.generate_state(4)) random.seed(uint64_seed) - global HAVE_PRINTED_WORKER_INFO - if not HAVE_PRINTED_WORKER_INFO: - print(torch.utils.data.get_worker_info()) - HAVE_PRINTED_WORKER_INFO = True def build_train_loader( @@ -43,7 +36,7 @@ def build_train_loader( path: str, num_workers: int, batch_size: int, - augmentation: Optional[bool] = None, + augmentation: Optional[bool] = True, target_config: Optional[Any] = None, **kwargs, ) -> DataLoader: @@ -61,7 +54,6 @@ def build_train_loader( The train Dataloader. """ log_train = "Loading train datasets:\n" - target_config = make_property_definitions_from_config(target_config) dataset = AseSqliteDataset( dataset, path, target_config=target_config, augmentation=augmentation, **kwargs ) diff --git a/orb_models/finetune_utilities/optim.py b/orb_models/finetune_utilities/optim.py deleted file mode 100644 index e79a633..0000000 --- a/orb_models/finetune_utilities/optim.py +++ /dev/null @@ -1,132 +0,0 @@ -import logging -import re -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union - -import omegaconf -import torch - -Metric = Union[torch.Tensor, int, float] -MetricCollection = Union[Metric, Mapping[str, Metric]] -TensorDict = Mapping[str, Optional[torch.Tensor]] - - -OptimizerKwargs = Dict[str, Any] - - -class ParameterGroup(omegaconf.DictConfig): - """Param type for specifying Param groups.""" - - optimizer_kwargs: OptimizerKwargs - filter_string: str - - -ParameterGroups = List[ParameterGroup] - - -def make_parameter_groups( - module: torch.nn.Module, groups: ParameterGroups, verbose: bool = False -) -> List[Dict[str, Any]]: - """Construct parameter groups for model optimization. - - Args: - module: A torch Module with parameters to optimize. 
-        groups: A dictionary of regexes mapping to optimizer kwargs.
-            See below for more info on how the groups are constructed.
-
-    Takes a module and a parameter grouping (as specified below), and prepares them to be passed
-    to the `__init__` function of a `torch.Optimizer`. This means separating the parameters into
-    groups with the given regexes, and prepping whatever keyword arguments are given for those
-    regexes in `groups`.
-
-    Returns:
-        The parameter groups ready to be passed to an optimizer.
-
-    `groups` contains:
-    ```
-    {
-        "regex1": {"lr": 1e-3},
-        "regex2": {"lr": 1e-4}
-    }
-    ```
-    All of key-value pairs specified in each of these dictionaries will be passed to the optimizer.
-    If there are multiple groups specified, this is a list of dictionaries, where each
-    dict contains a "parameter group" and groups specific options, e.g., {'params': [list of
-    parameters], 'lr': 1e-3, ...}. Any config option not specified in the additional options (e.g.
-    for the default group) is inherited from the top level arguments given in the constructor. See:
-    https://pytorch.org/docs/stable/optim.html#per-parameter-options
-    """
-    parameter_groups: List[Dict[str, Any]] = [
-        {"params": [], **g["optimizer_kwargs"]} for g in groups
-    ]
-    # In addition to any parameters that match group specific regex,
-    # we also need a group for the remaining "default" group.
-    # Those will be included in the last entry of parameter_groups.
-    parameter_groups.append({"params": []})
-
-    regex_use_counts: Dict[str, int] = {}
-    parameter_group_names: List[set] = [set() for _ in range(len(groups) + 1)]
-
-    for name, param in module.named_parameters():
-        # Determine the group for this parameter.
-        group_index = None
-        regex_names = [g["filter_string"] for g in groups]
-        for k, regex in enumerate(regex_names):
-            if regex not in regex_use_counts:
-                regex_use_counts[regex] = 0
-            if re.search(regex, name):
-                if group_index is not None and group_index != k:
-                    raise ValueError(
-                        "{} was specified in two separate parameter groups".format(name)
-                    )
-                group_index = k
-                regex_use_counts[regex] += 1
-
-        if group_index is not None:
-            # we have a group
-            parameter_groups[group_index]["params"].append(param)
-            parameter_group_names[group_index].add(name)
-        else:
-            # the default group
-            parameter_groups[-1]["params"].append(param)
-            parameter_group_names[-1].add(name)
-
-    # log the remaining parameter groups
-    logging.info("Constructed parameter groups:")
-    for k in range(len(parameter_groups)):
-        group_options = {
-            key: val for key, val in parameter_groups[k].items() if key != "params"
-        }
-        logging.info("Group %s, options: %s", k, group_options)
-        if verbose:
-            logging.info("Parameters: ")
-            for p in list(parameter_group_names[k]):
-                logging.info(p)
-
-    # check for unused regex
-    for regex, count in regex_use_counts.items():
-        if count == 0:
-            logging.warning(
-                "Parameter group regex %s does not match any parameter name.",
-                regex,
-            )
-    return parameter_groups
-
-
-def get_optim(
-    lr: float, total_steps: int, model: torch.nn.Module
-) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler],]:
-    """Configure optimizers, LR schedulers and EMA."""
-    parameter_groups = [
-        {
-            "filter_string": "(.*bias|.*layer_norm.*|.*batch_norm.*)",
-            "optimizer_kwargs": {"weight_decay": 0.0},
-        }
-    ]
-    params = make_parameter_groups(model, parameter_groups)
-    opt = torch.optim.Adam(params, lr=lr)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        opt, max_lr=lr, total_steps=total_steps, pct_start=0.05
-    )
-
-    return opt, scheduler
diff --git a/orb_models/finetune_utilities/steps.py b/orb_models/finetune_utilities/steps.py
deleted file mode 100644
index bd4bf63..0000000
--- a/orb_models/finetune_utilities/steps.py
+++ /dev/null
@@ -1,168 +0,0 @@
-from typing import List, Optional, Union, cast
-
-import torch
-import tqdm
-import wandb
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.data import DataLoader
-from wandb import wandb_run
-
-from orb_models.finetune_utilities import experiment
-
-
-def gradient_clipping(
-    model: torch.nn.Module, clip_value: float
-) -> List[torch.utils.hooks.RemovableHandle]:
-    """Add gradient clipping hooks to a model.
-
-    This is the correct way to implement gradient clipping, because
-    gradients are clipped as gradients are computed, rather than after
-    all gradients are computed - this means expoding gradients are less likely,
-    because they are "caught" earlier.
-
-    Args:
-        model: The model to add hooks to.
-        clip_value: The upper and lower threshold to clip the gradients to.
-
-    Returns:
-        A list of handles to remove the hooks from the parameters.
-    """
-    handles = []
-
-    def _clip(grad):
-        if grad is None:
-            return grad
-        return grad.clamp(min=-clip_value, max=clip_value)
-
-    for parameter in model.parameters():
-        if parameter.requires_grad:
-            h = parameter.register_hook(lambda grad: _clip(grad))
-            handles.append(h)
-
-    return handles
-
-
-def fintune(
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer,
-    dataloader: DataLoader,
-    lr_scheduler: Optional[_LRScheduler] = None,
-    num_steps: Optional[int] = None,
-    clip_grad: Optional[float] = None,
-    log_freq: float = 10,
-    device: torch.device = torch.device("cpu"),
-    epoch: int = 0,
-):
-    """Train for a fixed number of steps.
-
-    Args:
-        model: The model to optimize.
-        optimizer: The optimizer for the model.
-        dataloader: A Pytorch Dataloader, which may be infinite if num_steps is passed.
-        lr_scheduler: Optional, a Learning rate scheduler for modifying the learning rate.
-        num_steps: The number of training steps to take. This is required for distributed training,
-            because controlling parallism is easier if all processes take exactly the same number of steps (
-            this particularly applies when using dynamic batching).
-        clip_grad: Optional, the gradient clipping threshold.
-        log_freq: The logging frequency for step metrics.
-        device: The device to use for training.
-        epoch: The number of epochs the model has been fintuned.
-
-    Returns
-        A dictionary of metrics.
-    """
-    run: Optional[wandb_run.Run] = cast(Optional[wandb_run.Run], wandb.run)
-
-    if clip_grad is not None:
-        hook_handles = gradient_clipping(model, clip_grad)
-
-    metrics = experiment.ScalarMetricTracker()
-
-    # Set the model to "train" mode.
-    model.train()
-
-    # Get tqdm for the training batches
-    batch_generator = iter(dataloader)
-    num_training_batches: Union[int, float]
-    if num_steps is not None:
-        num_training_batches = num_steps
-    else:
-        try:
-            num_training_batches = len(dataloader)
-        except TypeError:
-            raise ValueError("Dataloader has no length, you must specify num_steps.")
-
-    batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches)
-
-    i = 0
-    batch_iterator = iter(batch_generator_tqdm)
-    while True:
-        if num_steps and i == num_steps:
-            break
-
-        optimizer.zero_grad(set_to_none=True)
-
-        step_metrics = {
-            "batch_size": 0.0,
-            "batch_num_edges": 0.0,
-            "batch_num_nodes": 0.0,
-        }
-
-        # Reset metrics so that it reports raw values for each step but still do averages on
-        # the gradient accumulation.
-        if i % log_freq == 0:
-            metrics.reset()
-
-        batch = next(batch_iterator)
-        batch = batch.to(device)
-        step_metrics["batch_size"] += len(batch.n_node)
-        step_metrics["batch_num_edges"] += batch.n_edge.sum()
-        step_metrics["batch_num_nodes"] += batch.n_node.sum()
-
-        with torch.cuda.amp.autocast(enabled=False):
-            batch_outputs = model.loss(batch)
-            loss = batch_outputs.loss
-            metrics.update(batch_outputs.log)
-            if torch.isnan(loss):
-                raise ValueError("nan loss encountered")
-            loss.backward()
-
-        optimizer.step()
-
-        if lr_scheduler is not None:
-            lr_scheduler.step()
-
-        metrics.update(step_metrics)
-
-        if i != 0 and i % log_freq == 0:
-            metrics_dict = metrics.get_metrics()
-            if run is not None:
-                global_step = (epoch * num_training_batches) + i
-                if run.sweep_id is not None:
-                    run.log(
-                        {"loss": metrics_dict["loss"]},
-                        commit=False,
-                    )
-                run.log(
-                    {"global_step": global_step},
-                    commit=False,
-                )
-                run.log(
-                    experiment.prefix_keys(metrics_dict, "train_step"), commit=False
-                )
-                # Log learning rates.
-                run.log(
-                    {
-                        f"pg_{idx}": group["lr"]
-                        for idx, group in enumerate(optimizer.param_groups)
-                    },
-                )
-
-        # Finished a single full step!
-        i += 1
-
-    if clip_grad is not None:
-        for h in hook_handles:
-            h.remove()
-
-    return metrics.get_metrics()
diff --git a/orb_models/forcefield/atomic_system.py b/orb_models/forcefield/atomic_system.py
index 877801d..471ef17 100644
--- a/orb_models/forcefield/atomic_system.py
+++ b/orb_models/forcefield/atomic_system.py
@@ -8,7 +8,8 @@
 from orb_models.forcefield import featurization_utilities
 from orb_models.forcefield.base import AtomGraphs
-from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition
+from orb_models.forcefield.property_definitions import (PROPERTIES,
+                                                        PropertyDefinition)
 
 
 @dataclass
@@ -26,7 +27,6 @@ class SystemConfig:
     use_timestep_0: bool = True
 
 
-@dataclass
 class PropertyConfig:
     """Defines which properties should be calculated and stored on the AtomGraphs batch.
 
@@ -247,27 +247,3 @@ def ase_fix_atoms_to_tensor(atoms: ase.Atoms) -> Optional[torch.Tensor]:
     fixed_atoms = torch.zeros((len(atoms)), dtype=torch.bool)
     fixed_atoms[constraint.index] = True
     return fixed_atoms
-
-
-def make_property_definitions_from_config(
-    config: Optional[Dict] = None,
-) -> PropertyConfig:
-    """Get PropertyConfig object from config."""
-    if config is None:
-        return PropertyConfig()
-    assert all(
-        key in ["node", "edge", "graph"] for key in config
-    ), "Only node, edge and graph properties are supported."
-
-    node_properties = edge_properties = graph_properties = None
-    if config.get("node"):
-        node_properties = [name for name in config["node"]]
-    if config.get("edge"):
-        edge_properties = [name for name in config["edge"]]
-    if config.get("graph"):
-        graph_properties = [name for name in config["graph"]]
-    return PropertyConfig(
-        node_names=node_properties,
-        edge_names=edge_properties,
-        graph_names=graph_properties,
-    )
diff --git a/orb_models/forcefield/base.py b/orb_models/forcefield/base.py
index 63b9422..499fc53 100644
--- a/orb_models/forcefield/base.py
+++ b/orb_models/forcefield/base.py
@@ -2,7 +2,8 @@
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union
+from typing import (Any, Dict, List, Mapping, NamedTuple, Optional, Sequence,
+                    Union)
 
 import torch
 import tree
diff --git a/orb_models/forcefield/calculator.py b/orb_models/forcefield/calculator.py
index d785438..cab0ace 100644
--- a/orb_models/forcefield/calculator.py
+++ b/orb_models/forcefield/calculator.py
@@ -3,7 +3,8 @@
 import torch
 from ase.calculators.calculator import Calculator, all_changes
 
-from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs
+from orb_models.forcefield.atomic_system import (SystemConfig,
+                                                 ase_atoms_to_atom_graphs)
 from orb_models.forcefield.graph_regressor import GraphRegressor
diff --git a/orb_models/forcefield/featurization_utilities.py b/orb_models/forcefield/featurization_utilities.py
index 7ae4e6b..a90812d 100644
--- a/orb_models/forcefield/featurization_utilities.py
+++ b/orb_models/forcefield/featurization_utilities.py
@@ -4,7 +4,6 @@
 import numpy as np
 import torch
-
 # TODO(Mark): Make pynanoflann optional
 from pynanoflann import KDTree as NanoKDTree
 from scipy.spatial import KDTree as SciKDTree
diff --git a/orb_models/forcefield/graph_regressor.py b/orb_models/forcefield/graph_regressor.py
index 167523e..73d83d2 100644
--- a/orb_models/forcefield/graph_regressor.py
+++ b/orb_models/forcefield/graph_regressor.py
@@ -7,7 +7,8 @@
 from orb_models.forcefield import base, segment_ops
 from orb_models.forcefield.gns import _KEY, MoleculeGNS
 from orb_models.forcefield.nn_util import build_mlp
-from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition
+from orb_models.forcefield.property_definitions import (PROPERTIES,
+                                                        PropertyDefinition)
 from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES
 
 global HAS_WARNED_FOR_TF32_MATMUL
diff --git a/orb_models/forcefield/pretrained.py b/orb_models/forcefield/pretrained.py
index ecb826e..f8ddced 100644
--- a/orb_models/forcefield/pretrained.py
+++ b/orb_models/forcefield/pretrained.py
@@ -6,12 +6,8 @@
 from orb_models.forcefield.featurization_utilities import get_device
 from orb_models.forcefield.gns import MoleculeGNS
-from orb_models.forcefield.graph_regressor import (
-    EnergyHead,
-    GraphHead,
-    GraphRegressor,
-    NodeHead,
-)
+from orb_models.forcefield.graph_regressor import (EnergyHead, GraphHead,
+                                                   GraphRegressor, NodeHead)
 from orb_models.forcefield.rbf import ExpNormalSmearing
 
 global HAS_MESSAGED_FOR_TF32_MATMUL
diff --git a/orb_models/finetune_utilities/experiment.py b/orb_models/utils.py
similarity index 55%
rename from orb_models/finetune_utilities/experiment.py
rename to orb_models/utils.py
index 17a605e..3ac9322 100644
--- a/orb_models/finetune_utilities/experiment.py
+++ b/orb_models/utils.py
@@ -2,11 +2,10 @@
 import os
 import random
+import re
 from collections import defaultdict
-from pathlib import Path
-from typing import Dict, Mapping, TypeVar
+from typing import Dict, List, Mapping, Optional, Tuple, TypeVar
 
-import dotenv
 import numpy
 import torch
 import wandb
@@ -17,22 +16,6 @@
 T = TypeVar("T")
 
-_V = TypeVar("_V", int, float, torch.Tensor)
-
-dotenv.load_dotenv(override=True)
-PROJECT_ROOT: Path = Path(
-    os.environ.get("PROJECT_ROOT", str(Path(__file__).parent.parent))
-)
-assert (
-    PROJECT_ROOT.exists()
-), "You must configure the PROJECT_ROOT environment variable in a .env file!"
-
-DATA_ROOT: Path = Path(os.environ.get("DATA_ROOT", default=str(PROJECT_ROOT / "data")))
-WANDB_ROOT: Path = Path(
-    os.environ.get("WANDB_ROOT", default=str(PROJECT_ROOT / "wandb"))
-)
-
-
 def init_device() -> torch.device:
     """Initialize a device.
 
@@ -78,22 +61,13 @@ def seed_everything(seed: int, rank: int = 0) -> None:
     torch.manual_seed(seed + rank)
 
 
-def init_wandb_from_config(args, job_type: str) -> wandb_run.Run:
-    """Initialise wandb from config."""
-    if not hasattr(args, "wandb_name"):
-        run_name = f"{job_type}-test"
-    else:
-        run_name = args.name
-    if not hasattr(args, "wandb_project"):
-        project = "orb-experiment"
-    else:
-        project = args.project
-
+def init_wandb_from_config(job_type: str) -> wandb_run.Run:
+    """Initialise wandb."""
     wandb.init(  # type: ignore
         job_type=job_type,
         dir=os.path.join(os.getcwd(), "wandb"),
-        name=run_name,
-        project=project,
+        name=f"{job_type}-test",
+        project="orb-experiment",
         entity="orbitalmaterials",
         mode="online",
         sync_tensorboard=False,
@@ -126,3 +100,61 @@ def update(self, metrics: Mapping[str, base.Metric]) -> None:
 
     def get_metrics(self):
         """Get the metric values, possibly reducing across gpu processes."""
         return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()}
+
+
+def gradient_clipping(
+    model: torch.nn.Module, clip_value: float
+) -> List[torch.utils.hooks.RemovableHandle]:
+    """Add gradient clipping hooks to a model.
+
+    This is the correct way to implement gradient clipping, because
+    gradients are clipped as gradients are computed, rather than after
+    all gradients are computed - this means exploding gradients are less likely,
+    because they are "caught" earlier.
+
+    Args:
+        model: The model to add hooks to.
+        clip_value: The upper and lower threshold to clip the gradients to.
+
+    Returns:
+        A list of handles to remove the hooks from the parameters.
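+
+    Example (illustrative usage, mirroring how `finetune` uses these hooks):
+        >>> handles = gradient_clipping(model, clip_value=0.5)
+        >>> loss = model.loss(batch).loss
+        >>> loss.backward()  # gradients are clamped as they are computed
+        >>> for h in handles:
+        ...     h.remove()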
+    """
+    handles = []
+
+    def _clip(grad):
+        if grad is None:
+            return grad
+        return grad.clamp(min=-clip_value, max=clip_value)
+
+    for parameter in model.parameters():
+        if parameter.requires_grad:
+            h = parameter.register_hook(lambda grad: _clip(grad))
+            handles.append(h)
+
+    return handles
+
+
+def get_optim(
+    lr: float, total_steps: int, model: torch.nn.Module
+) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler]]:
+    """Configure optimizers, LR schedulers and EMA."""
+
+    # Initialize parameter groups
+    params = []
+
+    # Split parameters based on the regex
+    for name, param in model.named_parameters():
+        if re.search(r"(.*bias|.*layer_norm.*|.*batch_norm.*)", name):
+            params.append({"params": param, "weight_decay": 0.0})
+        else:
+            params.append({"params": param})
+
+    # Create the optimizer with the parameter groups
+    optimizer = torch.optim.Adam(params, lr=lr)
+
+    # Create the learning rate scheduler
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer, max_lr=lr, total_steps=total_steps, pct_start=0.05
+    )
+
+    return optimizer, scheduler
diff --git a/tests/conftest.py b/tests/conftest.py
index 6217930..6b42349 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
-import pytest
 from pathlib import Path
 
+import pytest
+
 
 @pytest.fixture(scope="module")
 def fixtures_path(request):
diff --git a/tests/test_base.py b/tests/test_base.py
index f48b7b0..1eb332f 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -1,6 +1,7 @@
 # type: ignore
 import pytest
 import torch
+
 from orb_models.forcefield import base
 from orb_models.forcefield.base import refeaturize_atomgraphs
diff --git a/tests/test_calculator.py b/tests/test_calculator.py
index a94da0c..c904e10 100644
--- a/tests/test_calculator.py
+++ b/tests/test_calculator.py
@@ -1,8 +1,8 @@
+import ase
 import numpy as np
+import pytest
 import torch
 import torch.nn as nn
-import ase
-import pytest
 
 from orb_models.forcefield import segment_ops
 from orb_models.forcefield.calculator import ORBCalculator
diff --git a/tests/test_featurization_utilities.py b/tests/test_featurization_utilities.py
index 7a33561..b0cc9ca 100644
--- a/tests/test_featurization_utilities.py
+++ b/tests/test_featurization_utilities.py
@@ -1,6 +1,7 @@
 """Tests featurization utilities."""
 import functools
+
 import ase
 import ase.io
 import ase.neighborlist
diff --git a/tests/test_segment_ops.py b/tests/test_segment_ops.py
index 16ca66e..d62ce02 100644
--- a/tests/test_segment_ops.py
+++ b/tests/test_segment_ops.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from orb_models.forcefield import segment_ops
 import pytest
 import torch
 
+from orb_models.forcefield import segment_ops
+
 
 @pytest.mark.parametrize(
     ("reduction, dtype"),