Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use correct device, precision and inference mode #4

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion orb_models/forcefield/graph_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def warn_for_tf32_matmul():
):
print(
"Warning! You are using a model on the GPU without enabling tensorfloat matmuls."
"This is 2x slower than enabling this flag."
"This can be up to 2x slower than enabling this flag."
"Enable it with torch.set_float32_matmul_precision('high')"
)
HAS_WARNED_FOR_TF32_MATMUL = True
Expand Down
51 changes: 35 additions & 16 deletions orb_models/forcefield/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# flake8: noqa: E501
from typing import Union
import torch
from cached_path import cached_path
from orb_models.forcefield.featurization_utilities import get_device
from orb_models.forcefield.graph_regressor import (
EnergyHead,
NodeHead,
Expand All @@ -11,6 +13,9 @@
from orb_models.forcefield.rbf import ExpNormalSmearing


torch.set_float32_matmul_precision('high')


def get_base(
latent_dim: int = 256,
mlp_hidden_dim: int = 512,
Expand All @@ -33,9 +38,28 @@ def get_base(
)


def load_model_for_inference(
    model: torch.nn.Module,
    weights_path: str,
    device: Union[torch.device, str, None] = None,
) -> torch.nn.Module:
    """Load pretrained weights into ``model`` and prepare it for inference.

    Resolves (and caches) the checkpoint at ``weights_path``, loads the
    state dict, moves the model to the chosen device, puts it in eval mode
    and freezes all parameters.

    Args:
        model: The model architecture to populate with the checkpoint weights.
        weights_path: URL or local path of the checkpoint file.
        device: Target device. When ``None``, ``get_device`` selects one
            (GPU if available, per this function's original contract).

    Returns:
        The model on the target device, in eval mode, with gradients disabled.
    """
    local_path = cached_path(weights_path)
    # Load onto CPU first so GPU-saved checkpoints also load on CPU-only hosts;
    # the model is moved to the real target device below.
    state_dict = torch.load(local_path, map_location="cpu")

    model.load_state_dict(state_dict, strict=True)
    model = model.to(get_device(device))

    # Inference-only usage: eval mode, and freeze every parameter so no
    # autograd graph is built for the weights.
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False

    return model


def orb_v1(
weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orbff-v1-20240827.ckpt", # noqa: E501
# NOTE: Use https scheme for weights so that folks can download without gcloud auth.
device: Union[torch.device, str] = None,
):
"""Load ORB v1."""
base = get_base()
Expand Down Expand Up @@ -67,15 +91,14 @@ def orb_v1(
model=base,
)

local_path = cached_path(weights_path)
state_dict = torch.load(local_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = load_model_for_inference(model, weights_path, device)

return model


def orb_d3_v1(
weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-v1-20240902.ckpt",
device: Union[torch.device, str] = None,
):
"""ORB v1 with D3 corrections."""
base = get_base()
Expand Down Expand Up @@ -107,15 +130,14 @@ def orb_d3_v1(
model=base,
)

local_path = cached_path(weights_path)
state_dict = torch.load(local_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = load_model_for_inference(model, weights_path, device)

return model


def orb_d3_sm_v1(
weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-sm-v1-20240902.ckpt",
device: Union[torch.device, str] = None,
):
"""A 10 layer model pretrained on bulk data."""
base = get_base(num_message_passing_steps=10)
Expand Down Expand Up @@ -147,15 +169,14 @@ def orb_d3_sm_v1(
model=base,
)

local_path = cached_path(weights_path)
state_dict = torch.load(local_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = load_model_for_inference(model, weights_path, device)

return model


def orb_d3_xs_v1(
weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orb-d3-xs-v1-20240902.ckpt",
device: Union[torch.device, str] = None,
):
"""A 5 layer model pretrained on bulk data."""
base = get_base(num_message_passing_steps=5)
Expand Down Expand Up @@ -186,14 +207,14 @@ def orb_d3_xs_v1(
model=base,
)

local_path = cached_path(weights_path)
state_dict = torch.load(local_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = load_model_for_inference(model, weights_path, device)

return model


def orb_v1_mptraj_only(
weights_path: str = "https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/orbff-mptraj-only-v1-20240827.ckpt",
device: Union[torch.device, str] = None,
):
"""A 10 layer model pretrained on bulk data."""
base = get_base()
Expand Down Expand Up @@ -225,9 +246,7 @@ def orb_v1_mptraj_only(
model=base,
)

local_path = cached_path(weights_path)
state_dict = torch.load(local_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model = load_model_for_inference(model, weights_path, device)

return model

Expand Down