Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap by default #41

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions orb_models/forcefield/atomic_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,46 +90,57 @@ def atom_graphs_to_ase_atoms(

def ase_atoms_to_atom_graphs(
atoms: ase.Atoms,
system_config: SystemConfig = SystemConfig(
radius=10.0, max_num_neighbors=20, use_timestep_0=True
),
system_id: Optional[int] = None,
*,
wrap: bool = True,
brute_force_knn: Optional[bool] = None,
device: Optional[torch.device] = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
device: Optional[torch.device] = None,
system_config: Optional[SystemConfig] = None,
system_id: Optional[int] = None,
) -> AtomGraphs:
"""Generate AtomGraphs from an ase.Atoms object.

Args:
atoms: ase.Atoms object
system_config: SystemConfig object
system_id: Optional system_id
wrap: whether to wrap atomic positions into the central unit cell (if there is one).
NOTE: there can be numerical differences from ase's .wrap() method when an atom is near a cell boundary.
brute_force_knn: whether to use a 'brute force' knn approach with torch.cdist for kdtree construction.
Defaults to None, in which case brute_force is used if a GPU is available (2-6x faster),
but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU,
so it is recommended to set to False in that case.
device: device to put the tensors on.
device: device to put the tensors on. By default, uses the GPU if available.
system_config: SystemConfig object, specifying the max radius and max num_neighbors
used in the k-nearest neighbors graph construction.
system_id: Optional index, for tracking the identity of a datapoint.

Returns:
AtomGraphs object
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if system_config is None:
system_config = SystemConfig(radius=10.0, max_num_neighbors=20)

atomic_numbers = torch.from_numpy(atoms.numbers).to(torch.long)
atom_type_embedding = torch.nn.functional.one_hot(
atomic_numbers, num_classes=118
).type(torch.float32)

positions = torch.from_numpy(atoms.positions).to(torch.float32)
cell = torch.from_numpy(atoms.cell.array).to(torch.float32)
if wrap and torch.any(cell != 0):
positions = featurization_utilities.map_to_pbc_cell(positions, cell)

node_feats = {
"atomic_numbers": atomic_numbers.to(torch.int64),
"atomic_numbers_embedding": atom_type_embedding.to(torch.float32),
# NOTE: positions are stored as features on the AtomGraphs,
# but not actually used as input features to the model.
"positions": torch.from_numpy(atoms.positions).to(torch.float32),
"positions": positions,
}
system_feats = {"cell": torch.Tensor(atoms.cell.array[None, ...]).to(torch.float)}
system_feats = {"cell": cell.unsqueeze(0)}
edge_feats, senders, receivers = _get_edge_feats(
node_feats["positions"], # type: ignore
system_feats["cell"][0],
positions,
cell,
system_config.radius,
system_config.max_num_neighbors,
brute_force=brute_force_knn,
Expand Down Expand Up @@ -159,10 +170,8 @@ def _get_edge_feats(
cell: torch.Tensor,
radius: float,
max_num_neighbours: int,
brute_force: Optional[bool] = None,
device: Optional[torch.device] = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
brute_force: Optional[bool],
device: torch.device,
):
"""Get edge features.

Expand Down
24 changes: 24 additions & 0 deletions orb_models/forcefield/featurization_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,30 @@ def compute_pbc_radius_graph(
return torch.stack((senders_torch, receivers), dim=0), vectors


def map_to_pbc_cell(
    positions: torch.Tensor,
    periodic_boundary_conditions: torch.Tensor,
) -> torch.Tensor:
    """Wrap cartesian positions back into the periodic cell.

    Args:
        positions (torch.Tensor): Cartesian positions to wrap. Shape [num_particles, 3].
        periodic_boundary_conditions (torch.Tensor): The cell matrix. Shape 3x3.

    Returns:
        torch.Tensor: The wrapped positions, as a float32 tensor.
    """
    # Solving the linear system is numerically more trustworthy in float64,
    # so everything is promoted here and only the final result is cast back.
    cell_f64 = periodic_boundary_conditions.to(torch.float64)
    cartesian_f64 = positions.to(torch.float64)

    # Express each position in fractional (cell) coordinates by solving
    # cell^T @ frac^T = pos^T, i.e. frac @ cell = pos.
    fractional = torch.linalg.solve(cell_f64.T, cartesian_f64.T).T

    # Keep only the fractional part, which lands every coordinate in [0, 1).
    wrapped_fractional = torch.remainder(fractional, 1.0)

    # Convert back to cartesian coordinates and return in single precision.
    wrapped_cartesian = torch.matmul(wrapped_fractional, cell_f64)
    return wrapped_cartesian.to(torch.float32)


def batch_map_to_pbc_cell(
positions: torch.Tensor,
periodic_boundary_conditions: torch.Tensor,
Expand Down
98 changes: 98 additions & 0 deletions tests/fixtures/AFI.cif
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# AFI zeolite
data_SiO2
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 13.86655914
_cell_length_b 13.86655914
_cell_length_c 8.60047456
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 120.00000000
_symmetry_Int_Tables_number 1
_chemical_formula_structural SiO2
_chemical_formula_sum 'Si24 O48'
_cell_volume 1432.15645170
_cell_formula_units_Z 24
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
O O0 1 0.457007 0.334333 0.000000 1
O O1 1 0.665666 0.122673 0.000000 1
O O2 1 0.877327 0.542993 0.000000 1
O O3 1 0.542993 0.665667 0.000000 1
O O4 1 0.334334 0.877327 0.000000 1
O O5 1 0.122674 0.457007 0.000000 1
O O6 1 0.334333 0.457007 0.499999 1
O O7 1 0.122674 0.665667 0.499999 1
O O8 1 0.542993 0.877327 0.499999 1
O O9 1 0.665666 0.542993 0.499999 1
O O10 1 0.877326 0.334333 0.499999 1
O O11 1 0.457006 0.122673 0.499999 1
O O12 1 0.367814 0.367814 0.250000 1
O O13 1 0.632186 0.000000 0.250000 1
O O14 1 0.000000 0.632186 0.250000 1
O O15 1 0.632186 0.632186 0.250000 1
O O16 1 0.367813 0.000000 0.250000 1
O O17 1 0.999999 0.367814 0.250000 1
O O18 1 0.632186 0.632186 0.750000 1
O O19 1 0.367813 0.000000 0.750000 1
O O20 1 0.999999 0.367814 0.750000 1
O O21 1 0.367814 0.367814 0.750000 1
O O22 1 0.632186 0.000000 0.750000 1
O O23 1 0.000000 0.632186 0.750000 1
O O24 1 0.417358 0.208679 0.250000 1
O O25 1 0.791321 0.208679 0.250000 1
O O26 1 0.791321 0.582642 0.250000 1
O O27 1 0.582642 0.791321 0.250000 1
O O28 1 0.208678 0.791321 0.250000 1
O O29 1 0.208679 0.417358 0.250000 1
O O30 1 0.582642 0.791321 0.750000 1
O O31 1 0.208678 0.791321 0.750000 1
O O32 1 0.208679 0.417358 0.750000 1
O O33 1 0.417358 0.208679 0.750000 1
O O34 1 0.791321 0.208679 0.750000 1
O O35 1 0.791321 0.582642 0.750000 1
O O36 1 0.581212 0.418789 0.250000 1
O O37 1 0.581212 0.162424 0.250000 1
O O38 1 0.837576 0.418789 0.250000 1
O O39 1 0.418788 0.581211 0.250000 1
O O40 1 0.418788 0.837576 0.250000 1
O O41 1 0.162423 0.581211 0.250000 1
O O42 1 0.418788 0.581211 0.750000 1
O O43 1 0.418788 0.837576 0.750000 1
O O44 1 0.162423 0.581211 0.750000 1
O O45 1 0.581212 0.418789 0.750000 1
O O46 1 0.581212 0.162424 0.750000 1
O O47 1 0.837576 0.418789 0.750000 1
Si Si48 1 0.456260 0.332886 0.187394 1
Si Si49 1 0.667114 0.123373 0.187394 1
Si Si50 1 0.876626 0.543741 0.187394 1
Si Si51 1 0.543740 0.667114 0.187394 1
Si Si52 1 0.332886 0.876626 0.187394 1
Si Si53 1 0.123375 0.456259 0.187394 1
Si Si54 1 0.332885 0.456259 0.312606 1
Si Si55 1 0.123373 0.667114 0.312606 1
Si Si56 1 0.543740 0.876626 0.312606 1
Si Si57 1 0.667115 0.543741 0.312606 1
Si Si58 1 0.876626 0.332886 0.312606 1
Si Si59 1 0.456260 0.123373 0.312606 1
Si Si60 1 0.543740 0.667114 0.812605 1
Si Si61 1 0.332886 0.876626 0.812605 1
Si Si62 1 0.123375 0.456259 0.812605 1
Si Si63 1 0.456260 0.332886 0.812605 1
Si Si64 1 0.667114 0.123373 0.812605 1
Si Si65 1 0.876626 0.543741 0.812605 1
Si Si66 1 0.667115 0.543741 0.687394 1
Si Si67 1 0.876626 0.332886 0.687394 1
Si Si68 1 0.456260 0.123373 0.687394 1
Si Si69 1 0.332885 0.456259 0.687394 1
Si Si70 1 0.123373 0.667114 0.687394 1
Si Si71 1 0.543740 0.876626 0.687394 1
48 changes: 48 additions & 0 deletions tests/test_atomic_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ase.io
import numpy as np
import torch
from orb_models.forcefield.base import batch_graphs
from orb_models.forcefield.atomic_system import (
atom_graphs_to_ase_atoms,
ase_atoms_to_atom_graphs,
)


def test_atoms_to_atom_graphs_invertibility(fixtures_path):
    """Check that ase.Atoms -> AtomGraphs -> ase.Atoms is a lossless round trip."""
    original = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))

    # wrap=False is passed so the stored positions are compared against the
    # input positions exactly as they were read from the fixture.
    graphs = ase_atoms_to_atom_graphs(original, wrap=False)
    round_tripped = atom_graphs_to_ase_atoms(graphs)[0]

    assert np.allclose(round_tripped.positions, original.positions)
    assert np.allclose(round_tripped.cell, original.cell)
    assert (round_tripped.numbers == original.numbers).all()


def test_atom_graphs_to_ase_atoms_debatches(fixtures_path):
    """Check that a batch of graphs is split back into one ase.Atoms per graph."""
    source_atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))

    # Build four graphs from the same structure and batch them together.
    graph_copies = []
    for _ in range(4):
        graph_copies.append(ase_atoms_to_atom_graphs(source_atoms, wrap=False))
    batched = batch_graphs(graph_copies)

    debatched = atom_graphs_to_ase_atoms(batched)
    assert len(debatched) == 4

    # Every copy came from the same input, so the first two must agree.
    first, second = debatched[0], debatched[1]
    assert (first.positions == second.positions).all()
    assert (first.get_tags() == second.get_tags()).all()


def test_ase_atoms_to_atom_graphs_wraps(fixtures_path):
    """Check the effect of the ``wrap`` flag of ``ase_atoms_to_atom_graphs``.

    With ``wrap=False`` the stored positions must equal the input positions.
    With ``wrap=True`` an unwrapped system and its ase-wrapped copy must yield
    the same stored positions.
    """
    atoms_unwrapped = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))
    # Shift the first ten atoms by twice the largest cell-matrix entry so that
    # wrapping visibly changes their positions.
    atoms_unwrapped.positions[:10] += 2.0 * atoms_unwrapped.cell.array.max()
    atoms_wrapped = atoms_unwrapped.copy()
    atoms_wrapped.wrap()
    assert not np.allclose(atoms_wrapped.positions, atoms_unwrapped.positions)

    atom_graphs = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=False)
    # ase_atoms_to_atom_graphs picks CUDA by default when it is available, so
    # the positions tensor may live on the GPU. Tensor.numpy() only works on
    # CPU tensors, hence the explicit .cpu() (a no-op for CPU tensors).
    assert np.allclose(
        atom_graphs.positions.cpu().numpy(), atoms_unwrapped.positions
    )

    # Note: this test is slightly indirect. We can't test that wrap=True yields the same
    # results as ase's .wrap(), because of slight numerical differences at the boundaries.
    # Instead, we test that wrap=True for an unwrapped system yields the same results
    # as wrap=True for an ase-wrapped system.
    atom_graphs1 = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=True)
    atom_graphs2 = ase_atoms_to_atom_graphs(atoms_wrapped, wrap=True)
    assert torch.allclose(atom_graphs1.positions, atom_graphs2.positions)
Loading