Skip to content

Commit

Permalink
Wrap by default (#41)
Browse files Browse the repository at this point in the history
* Wrap by default

* Improve defaults

* Fix cell shape

* Add tests

---------

Co-authored-by: ben rhodes <[email protected]>
  • Loading branch information
benrhodes26 and ben rhodes authored Dec 19, 2024
1 parent f40b78e commit a4686b3
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 18 deletions.
45 changes: 27 additions & 18 deletions orb_models/forcefield/atomic_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,46 +90,57 @@ def atom_graphs_to_ase_atoms(

def ase_atoms_to_atom_graphs(
atoms: ase.Atoms,
system_config: SystemConfig = SystemConfig(
radius=10.0, max_num_neighbors=20, use_timestep_0=True
),
system_id: Optional[int] = None,
*,
wrap: bool = True,
brute_force_knn: Optional[bool] = None,
device: Optional[torch.device] = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
device: Optional[torch.device] = None,
system_config: Optional[SystemConfig] = None,
system_id: Optional[int] = None,
) -> AtomGraphs:
"""Generate AtomGraphs from an ase.Atoms object.
Args:
atoms: ase.Atoms object
system_config: SystemConfig object
system_id: Optional system_id
wrap: whether to wrap atomic positions into the central unit cell (if there is one).
NOTE: there can be numerical differences from ase's .wrap() method when an atom is near a cell boundary.
brute_force_knn: whether to use a 'brute force' knn approach with torch.cdist for kdtree construction.
Defaults to None, in which case brute_force is used if a GPU is available (2-6x faster),
but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU,
so it is recommended to set to False in that case.
device: device to put the tensors on.
device: device to put the tensors on. By default, uses the GPU if available.
system_config: SystemConfig object, specifying the max radius and max num_neighbors
used in the k-nearest neighbors graph construction.
system_id: Optional index, for tracking the identity of a datapoint.
Returns:
AtomGraphs object
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if system_config is None:
system_config = SystemConfig(radius=10.0, max_num_neighbors=20)

atomic_numbers = torch.from_numpy(atoms.numbers).to(torch.long)
atom_type_embedding = torch.nn.functional.one_hot(
atomic_numbers, num_classes=118
).type(torch.float32)

positions = torch.from_numpy(atoms.positions).to(torch.float32)
cell = torch.from_numpy(atoms.cell.array).to(torch.float32)
if wrap and torch.any(cell != 0):
positions = featurization_utilities.map_to_pbc_cell(positions, cell)

node_feats = {
"atomic_numbers": atomic_numbers.to(torch.int64),
"atomic_numbers_embedding": atom_type_embedding.to(torch.float32),
# NOTE: positions are stored as features on the AtomGraphs,
# but not actually used as input features to the model.
"positions": torch.from_numpy(atoms.positions).to(torch.float32),
"positions": positions,
}
system_feats = {"cell": torch.Tensor(atoms.cell.array[None, ...]).to(torch.float)}
system_feats = {"cell": cell.unsqueeze(0)}
edge_feats, senders, receivers = _get_edge_feats(
node_feats["positions"], # type: ignore
system_feats["cell"][0],
positions,
cell,
system_config.radius,
system_config.max_num_neighbors,
brute_force=brute_force_knn,
Expand Down Expand Up @@ -159,10 +170,8 @@ def _get_edge_feats(
cell: torch.Tensor,
radius: float,
max_num_neighbours: int,
brute_force: Optional[bool] = None,
device: Optional[torch.device] = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
brute_force: Optional[bool],
device: torch.device,
):
"""Get edge features.
Expand Down
24 changes: 24 additions & 0 deletions orb_models/forcefield/featurization_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,30 @@ def compute_pbc_radius_graph(
return torch.stack((senders_torch, receivers), dim=0), vectors


def map_to_pbc_cell(
    positions: torch.Tensor,
    periodic_boundary_conditions: torch.Tensor,
) -> torch.Tensor:
    """Wrap positions into the cell given by the periodic boundary conditions.

    Args:
        positions (torch.Tensor): The positions to wrap. Shape [num_particles, 3].
        periodic_boundary_conditions (torch.Tensor): The cell matrix. Shape 3x3.

    Returns:
        torch.Tensor: The wrapped positions, cast back to float32.
    """
    # Solving the linear system is far better conditioned in float64, so all
    # intermediate work is done in double precision and only the final result
    # is cast back down to single precision.
    cell_f64 = periodic_boundary_conditions.to(torch.float64)
    cartesian_f64 = positions.to(torch.float64)

    # Cartesian and fractional coordinates are related by
    #     cartesian = fractional @ cell
    # which is solved here in its transposed form: cell.T @ fractional.T = cartesian.T
    fractional = torch.linalg.solve(cell_f64.T, cartesian_f64.T).T

    # Keep only the fractional part of each coordinate, then map the result
    # back to Cartesian space.
    wrapped_fractional = torch.remainder(fractional, 1.0)
    return torch.matmul(wrapped_fractional, cell_f64).to(torch.float32)


def batch_map_to_pbc_cell(
positions: torch.Tensor,
periodic_boundary_conditions: torch.Tensor,
Expand Down
98 changes: 98 additions & 0 deletions tests/fixtures/AFI.cif
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# AFI zeolite
data_SiO2
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 13.86655914
_cell_length_b 13.86655914
_cell_length_c 8.60047456
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 120.00000000
_symmetry_Int_Tables_number 1
_chemical_formula_structural SiO2
_chemical_formula_sum 'Si24 O48'
_cell_volume 1432.15645170
_cell_formula_units_Z 24
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
O O0 1 0.457007 0.334333 0.000000 1
O O1 1 0.665666 0.122673 0.000000 1
O O2 1 0.877327 0.542993 0.000000 1
O O3 1 0.542993 0.665667 0.000000 1
O O4 1 0.334334 0.877327 0.000000 1
O O5 1 0.122674 0.457007 0.000000 1
O O6 1 0.334333 0.457007 0.499999 1
O O7 1 0.122674 0.665667 0.499999 1
O O8 1 0.542993 0.877327 0.499999 1
O O9 1 0.665666 0.542993 0.499999 1
O O10 1 0.877326 0.334333 0.499999 1
O O11 1 0.457006 0.122673 0.499999 1
O O12 1 0.367814 0.367814 0.250000 1
O O13 1 0.632186 0.000000 0.250000 1
O O14 1 0.000000 0.632186 0.250000 1
O O15 1 0.632186 0.632186 0.250000 1
O O16 1 0.367813 0.000000 0.250000 1
O O17 1 0.999999 0.367814 0.250000 1
O O18 1 0.632186 0.632186 0.750000 1
O O19 1 0.367813 0.000000 0.750000 1
O O20 1 0.999999 0.367814 0.750000 1
O O21 1 0.367814 0.367814 0.750000 1
O O22 1 0.632186 0.000000 0.750000 1
O O23 1 0.000000 0.632186 0.750000 1
O O24 1 0.417358 0.208679 0.250000 1
O O25 1 0.791321 0.208679 0.250000 1
O O26 1 0.791321 0.582642 0.250000 1
O O27 1 0.582642 0.791321 0.250000 1
O O28 1 0.208678 0.791321 0.250000 1
O O29 1 0.208679 0.417358 0.250000 1
O O30 1 0.582642 0.791321 0.750000 1
O O31 1 0.208678 0.791321 0.750000 1
O O32 1 0.208679 0.417358 0.750000 1
O O33 1 0.417358 0.208679 0.750000 1
O O34 1 0.791321 0.208679 0.750000 1
O O35 1 0.791321 0.582642 0.750000 1
O O36 1 0.581212 0.418789 0.250000 1
O O37 1 0.581212 0.162424 0.250000 1
O O38 1 0.837576 0.418789 0.250000 1
O O39 1 0.418788 0.581211 0.250000 1
O O40 1 0.418788 0.837576 0.250000 1
O O41 1 0.162423 0.581211 0.250000 1
O O42 1 0.418788 0.581211 0.750000 1
O O43 1 0.418788 0.837576 0.750000 1
O O44 1 0.162423 0.581211 0.750000 1
O O45 1 0.581212 0.418789 0.750000 1
O O46 1 0.581212 0.162424 0.750000 1
O O47 1 0.837576 0.418789 0.750000 1
Si Si48 1 0.456260 0.332886 0.187394 1
Si Si49 1 0.667114 0.123373 0.187394 1
Si Si50 1 0.876626 0.543741 0.187394 1
Si Si51 1 0.543740 0.667114 0.187394 1
Si Si52 1 0.332886 0.876626 0.187394 1
Si Si53 1 0.123375 0.456259 0.187394 1
Si Si54 1 0.332885 0.456259 0.312606 1
Si Si55 1 0.123373 0.667114 0.312606 1
Si Si56 1 0.543740 0.876626 0.312606 1
Si Si57 1 0.667115 0.543741 0.312606 1
Si Si58 1 0.876626 0.332886 0.312606 1
Si Si59 1 0.456260 0.123373 0.312606 1
Si Si60 1 0.543740 0.667114 0.812605 1
Si Si61 1 0.332886 0.876626 0.812605 1
Si Si62 1 0.123375 0.456259 0.812605 1
Si Si63 1 0.456260 0.332886 0.812605 1
Si Si64 1 0.667114 0.123373 0.812605 1
Si Si65 1 0.876626 0.543741 0.812605 1
Si Si66 1 0.667115 0.543741 0.687394 1
Si Si67 1 0.876626 0.332886 0.687394 1
Si Si68 1 0.456260 0.123373 0.687394 1
Si Si69 1 0.332885 0.456259 0.687394 1
Si Si70 1 0.123373 0.667114 0.687394 1
Si Si71 1 0.543740 0.876626 0.687394 1
48 changes: 48 additions & 0 deletions tests/test_atomic_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ase.io
import numpy as np
import torch
from orb_models.forcefield.base import batch_graphs
from orb_models.forcefield.atomic_system import (
atom_graphs_to_ase_atoms,
ase_atoms_to_atom_graphs,
)


def test_atoms_to_atom_graphs_invertibility(fixtures_path):
    """Converting ase.Atoms -> AtomGraphs -> ase.Atoms (without wrapping) is lossless."""
    cif_path = fixtures_path / "AFI.cif"
    atoms = ase.Atoms(ase.io.read(cif_path))

    # wrap=False so that positions are expected to round-trip unchanged.
    graphs = ase_atoms_to_atom_graphs(atoms, wrap=False)
    recovered_list = atom_graphs_to_ase_atoms(graphs)
    recovered_atoms = recovered_list[0]

    positions_match = np.allclose(recovered_atoms.positions, atoms.positions)
    assert positions_match
    cells_match = np.allclose(recovered_atoms.cell, atoms.cell)
    assert cells_match
    numbers_match = (recovered_atoms.numbers == atoms.numbers).all()
    assert numbers_match


def test_atom_graphs_to_ase_atoms_debatches(fixtures_path):
    """A batch of four identical graphs is split back into four identical Atoms."""
    num_copies = 4
    atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))

    graphs = []
    for _ in range(num_copies):
        graphs.append(ase_atoms_to_atom_graphs(atoms, wrap=False))

    atoms_list = atom_graphs_to_ase_atoms(batch_graphs(graphs))
    assert len(atoms_list) == 4

    # Every copy came from the same structure, so the first two must agree.
    first, second = atoms_list[0], atoms_list[1]
    assert (first.positions == second.positions).all()
    assert (first.get_tags() == second.get_tags()).all()


def test_ase_atoms_to_atom_graphs_wraps(fixtures_path):
    """Check that wrap=False leaves positions untouched and wrap=True wraps them."""
    atoms_unwrapped = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))
    # Push the first ten atoms well outside the cell so wrapping has a visible effect.
    atoms_unwrapped.positions[:10] += 2.0 * atoms_unwrapped.cell.array.max()
    atoms_wrapped = atoms_unwrapped.copy()
    atoms_wrapped.wrap()
    assert not np.allclose(atoms_wrapped.positions, atoms_unwrapped.positions)

    atom_graphs = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=False)
    # ase_atoms_to_atom_graphs defaults to the GPU when one is available, and
    # Tensor.numpy() only works on CPU tensors, so move to the CPU first.
    assert np.allclose(atom_graphs.positions.cpu().numpy(), atoms_unwrapped.positions)

    # Note: this test is slightly indirect. We can't test that wrap=True yields the same
    # results as ase's .wrap(), because of slight numerical differences at the boundaries.
    # Instead, we test that wrap=True for an unwrapped system yields the same results
    # as wrap=True for an ase-wrapped system.
    atom_graphs1 = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=True)
    atom_graphs2 = ase_atoms_to_atom_graphs(atoms_wrapped, wrap=True)
    assert torch.allclose(atom_graphs1.positions, atom_graphs2.positions)

0 comments on commit a4686b3

Please sign in to comment.