Fix device issue #17

Merged (1 commit, Sep 10, 2024)
README.md (7 changes: 4 additions & 3 deletions)

@@ -42,16 +42,17 @@ from orb_models.forcefield import pretrained
 from orb_models.forcefield import atomic_system
 from orb_models.forcefield.base import batch_graphs

-orbff = pretrained.orb_v1()
+device = "cpu"  # or device="cuda"
+orbff = pretrained.orb_v1(device=device)
 atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
-graph = atomic_system.ase_atoms_to_atom_graphs(atoms)
+graph = atomic_system.ase_atoms_to_atom_graphs(atoms, device=device)

 # Optionally, batch graphs for faster inference
 # graph = batch_graphs([graph, graph, ...])

 result = orbff.predict(graph)

-# Convert to ASE atoms (this will also unbatch the results)
+# Convert to ASE atoms (unbatches the results and transfers to cpu if necessary)
 atoms = atomic_system.atom_graphs_to_ase_atoms(
     graph,
     energy=result["graph_pred"],
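As a usage note (a sketch, not part of the PR): the updated README flow can also pick the device dynamically instead of hard-coding "cpu". This assumes only the orb_v1 and ase_atoms_to_atom_graphs signatures shown in this diff and a working orb-models install.

    import torch
    from ase.build import bulk
    from orb_models.forcefield import atomic_system, pretrained

    # Prefer the GPU when one is available; otherwise stay on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    orbff = pretrained.orb_v1(device=device)  # model weights placed on `device`
    atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)

    # Graph tensors are created on the same device, so predict() runs without extra transfers.
    graph = atomic_system.ase_atoms_to_atom_graphs(atoms, device=device)
    result = orbff.predict(graph)

Converting back with atom_graphs_to_ase_atoms then moves the results to CPU if needed, per the updated comment in the README hunk above.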
orb_models/forcefield/atomic_system.py (5 changes: 4 additions & 1 deletion)

@@ -96,6 +96,7 @@ def ase_atoms_to_atom_graphs(
     ),
     system_id: Optional[int] = None,
     brute_force_knn: Optional[bool] = None,
+    device: Optional[torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
 ) -> AtomGraphs:
     """Generate AtomGraphs from an ase.Atoms object.

@@ -107,6 +108,7 @@
             Defaults to None, in which case brute_force is used if a GPU is available (2-6x faster),
             but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU,
             so it is recommended to set it to False in that case.
+        device: device to put the tensors on.

     Returns:
         AtomGraphs object

@@ -133,7 +135,7 @@
     )

     num_atoms = len(node_feats["positions"])  # type: ignore
-    return AtomGraphs(
+    atom_graph = AtomGraphs(
         senders=senders,
         receivers=receivers,
         n_node=torch.tensor([num_atoms]),

@@ -147,6 +149,7 @@
         radius=system_config.radius,
         max_num_neighbors=system_config.max_num_neighbors,
     )
+    return atom_graph.to(device)


 def _get_edge_feats(
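To illustrate the change (a sketch, not part of the PR): after this patch the constructed graph is explicitly moved with atom_graph.to(device), and the default device resolves to CUDA when available, otherwise CPU. The calls below assume only the signature shown in the diff above.

    import torch
    from ase.build import bulk
    from orb_models.forcefield import atomic_system

    atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)

    # Explicit device: the returned AtomGraphs tensors live on CPU.
    cpu_graph = atomic_system.ase_atoms_to_atom_graphs(atoms, device=torch.device("cpu"))

    # Default device: per the new default argument, this is CUDA when torch.cuda.is_available(),
    # otherwise CPU. For very large systems the docstring recommends brute_force_knn=False
    # to avoid GPU out-of-memory during neighbor construction.
    default_graph = atomic_system.ase_atoms_to_atom_graphs(atoms)

Keeping the device decision inside ase_atoms_to_atom_graphs means downstream code, such as the README example, can pass a single device value for both the model and the graph.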