From 39841c4b65f84c86af0fbe28246de87d57e23880 Mon Sep 17 00:00:00 2001 From: ben rhodes Date: Tue, 3 Sep 2024 14:22:54 +0100 Subject: [PATCH 1/2] Various fixes --- orb_models/forcefield/graph_regressor.py | 2 +- orb_models/forcefield/pretrained.py | 51 ++++++++++++++++-------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/orb_models/forcefield/graph_regressor.py b/orb_models/forcefield/graph_regressor.py index 4f328f9..e105e4d 100644 --- a/orb_models/forcefield/graph_regressor.py +++ b/orb_models/forcefield/graph_regressor.py @@ -27,7 +27,7 @@ def warn_for_tf32_matmul(): ): print( "Warning! You are using a model on the GPU without enabling tensorfloat matmuls." - "This is 2x slower than enabling this flag." + "This can be up to 2x slower than enabling this flag." "Enable it with torch.set_float32_matmul_precision('high')" ) HAS_WARNED_FOR_TF32_MATMUL = True diff --git a/orb_models/forcefield/pretrained.py b/orb_models/forcefield/pretrained.py index 15433fc..c5f1e97 100644 --- a/orb_models/forcefield/pretrained.py +++ b/orb_models/forcefield/pretrained.py @@ -1,6 +1,8 @@ # flake8: noqa: E501 +from typing import Union import torch from cached_path import cached_path +from orb_models.forcefield.featurization_utilities import get_device from orb_models.forcefield.graph_regressor import ( EnergyHead, NodeHead, @@ -11,6 +13,9 @@ from orb_models.forcefield.rbf import ExpNormalSmearing +torch.set_float32_matmul_precision('high') + + def get_base( latent_dim: int = 256, mlp_hidden_dim: int = 512, @@ -33,9 +38,28 @@ def get_base( ) +def load_model_for_inference( + model: torch.nn.Module, + weights_path: str, + device: Union[torch.device, str] = None, +) -> torch.nn.Module: + """Load a pretrained model in inference mode, using GPU if available.""" + local_path = cached_path(weights_path) + state_dict = torch.load(local_path, map_location="cpu") + + model.load_state_dict(state_dict, strict=True) + model = model.to(get_device(device)) + + model = model.eval() + for param in model.parameters(): + param.requires_grad = False + + return model + + def orb_v1( weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orbff-v1-20240827.ckpt", # noqa: E501 - # NOTE: Use https scheme for weights so that folks can download without gcloud auth. + device: Union[torch.device, str] = None, ): """Load ORB v1.""" base = get_base() @@ -67,15 +91,14 @@ def orb_v1( model=base, ) - local_path = cached_path(weights_path) - state_dict = torch.load(local_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + model = load_model_for_inference(model, weights_path, device) return model def orb_d3_v1( weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-v1-20240902.ckpt", + device: Union[torch.device, str] = None, ): """ORB v1 with D3 corrections.""" base = get_base() @@ -107,15 +130,14 @@ def orb_d3_v1( model=base, ) - local_path = cached_path(weights_path) - state_dict = torch.load(local_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + model = load_model_for_inference(model, weights_path, device) return model def orb_d3_sm_v1( weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-sm-v1-20240902.ckpt", + device: Union[torch.device, str] = None, ): """A 10 layer model pretrained on bulk data.""" base = get_base(num_message_passing_steps=10) @@ -147,15 +169,14 @@ def orb_d3_sm_v1( model=base, ) - local_path = cached_path(weights_path) - state_dict = torch.load(local_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + model = load_model_for_inference(model, weights_path, device) return model def orb_d3_xs_v1( weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-xs-v1-20240902.ckpt", + device: Union[torch.device, str] = None, ): """A 5 layer model pretrained on bulk data.""" base = get_base(num_message_passing_steps=5) @@ -186,14 +207,14 @@ def orb_d3_xs_v1( model=base, ) - local_path = cached_path(weights_path) - state_dict = torch.load(local_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + model = load_model_for_inference(model, weights_path, device) + return model def orb_v1_mptraj_only( weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orbff-mptraj-only-v1-20240827.ckpt", + device: Union[torch.device, str] = None, ): """A 10 layer model pretrained on bulk data.""" base = get_base() @@ -225,9 +246,7 @@ def orb_v1_mptraj_only( model=base, ) - local_path = cached_path(weights_path) - state_dict = torch.load(local_path, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + model = load_model_for_inference(model, weights_path, device) return model From 88f2980f86f1de49ce38f43216ceb540178d6b6f Mon Sep 17 00:00:00 2001 From: ben rhodes Date: Tue, 3 Sep 2024 14:32:57 +0100 Subject: [PATCH 2/2] Deduplicate and relax ase version --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9acc48d..a882a79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,10 @@ classifiers = [ dependencies = [ "cached_path>=1.6.2", - "ase==3.22.1", + "ase>=3.22.1", "numpy<2.0.0", "scipy>=1.13.1", "torch==2.2.0", - "ase==3.22.1", "dm-tree>=0.1.8", ]