diff --git a/orb_models/forcefield/atomic_system.py b/orb_models/forcefield/atomic_system.py index c3be191..85debf6 100644 --- a/orb_models/forcefield/atomic_system.py +++ b/orb_models/forcefield/atomic_system.py @@ -90,46 +90,57 @@ def atom_graphs_to_ase_atoms( def ase_atoms_to_atom_graphs( atoms: ase.Atoms, - system_config: SystemConfig = SystemConfig( - radius=10.0, max_num_neighbors=20, use_timestep_0=True - ), - system_id: Optional[int] = None, + *, + wrap: bool = True, brute_force_knn: Optional[bool] = None, - device: Optional[torch.device] = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), + device: Optional[torch.device] = None, + system_config: Optional[SystemConfig] = None, + system_id: Optional[int] = None, ) -> AtomGraphs: """Generate AtomGraphs from an ase.Atoms object. Args: atoms: ase.Atoms object - system_config: SystemConfig object - system_id: Optional system_id + wrap: whether to wrap atomic positions into the central unit cell (if there is one). + NOTE: there can be numerical differences from ase's .wrap() method when an atom is near a cell boundary. brute_force_knn: whether to use a 'brute force' knn approach with torch.cdist for kdtree construction. Defaults to None, in which case brute_force is used if we a GPU is avaiable (2-6x faster), but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU, so it is recommended to set to False in that case. - device: device to put the tensors on. + device: device to put the tensors on. By default, uses the GPU if available. + system_config: SystemConfig object, specifying the max radius and max num_neighbors + used in the k-nearest neighbors graph construction. + system_id: Optional index, for tracking the identity of a datapoint. Returns: AtomGraphs object """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if system_config is None: + system_config = SystemConfig(radius=10.0, max_num_neighbors=20) + atomic_numbers = torch.from_numpy(atoms.numbers).to(torch.long) atom_type_embedding = torch.nn.functional.one_hot( atomic_numbers, num_classes=118 ).type(torch.float32) + positions = torch.from_numpy(atoms.positions).to(torch.float32) + cell = torch.from_numpy(atoms.cell.array).to(torch.float32) + if wrap and torch.any(cell != 0): + positions = featurization_utilities.map_to_pbc_cell(positions, cell) + node_feats = { "atomic_numbers": atomic_numbers.to(torch.int64), "atomic_numbers_embedding": atom_type_embedding.to(torch.float32), # NOTE: positions are stored as features on the AtomGraphs, # but not actually used as input features to the model. - "positions": torch.from_numpy(atoms.positions).to(torch.float32), + "positions": positions, } - system_feats = {"cell": torch.Tensor(atoms.cell.array[None, ...]).to(torch.float)} + system_feats = {"cell": cell.unsqueeze(0)} edge_feats, senders, receivers = _get_edge_feats( - node_feats["positions"], # type: ignore - system_feats["cell"][0], + positions, + cell, system_config.radius, system_config.max_num_neighbors, brute_force=brute_force_knn, @@ -159,10 +170,8 @@ def _get_edge_feats( cell: torch.Tensor, radius: float, max_num_neighbours: int, - brute_force: Optional[bool] = None, - device: Optional[torch.device] = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), + brute_force: Optional[bool], + device: torch.device, ): """Get edge features. diff --git a/orb_models/forcefield/featurization_utilities.py b/orb_models/forcefield/featurization_utilities.py index fb6eb6a..eba609d 100644 --- a/orb_models/forcefield/featurization_utilities.py +++ b/orb_models/forcefield/featurization_utilities.py @@ -365,6 +365,30 @@ def compute_pbc_radius_graph( return torch.stack((senders_torch, receivers), dim=0), vectors +def map_to_pbc_cell( + positions: torch.Tensor, + periodic_boundary_conditions: torch.Tensor, +) -> torch.Tensor: + """Maps positions to within a periodic boundary cell. + + Args: + positions (torch.Tensor): The positions to be mapped. Shape [num_particles, 3] + periodic_boundary_conditions (torch.Tensor): The periodic boundary conditions. Shape 3x3. + + Returns: + torch.Tensor: Positions mapped to within a periodic boundary cell. + """ + # Inverses are a lot more reliable in double precision, so we'll do the whole + # thing in double then go back to single. + positions = positions.double() + periodic_boundary_conditions = periodic_boundary_conditions.double() + # The strategy here is to map our positions to fractional or internal coordinates. + # Then we take the modulo, then map back to euclidian co-ordinates. + fractional_pos = torch.linalg.solve(periodic_boundary_conditions.T, positions.T).T + fractional_pos = fractional_pos % 1.0 + return (fractional_pos @ periodic_boundary_conditions).float() + + def batch_map_to_pbc_cell( positions: torch.Tensor, periodic_boundary_conditions: torch.Tensor, diff --git a/tests/fixtures/AFI.cif b/tests/fixtures/AFI.cif new file mode 100644 index 0000000..412b4ac --- /dev/null +++ b/tests/fixtures/AFI.cif @@ -0,0 +1,98 @@ +# AFI zeolite +data_SiO2 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 13.86655914 +_cell_length_b 13.86655914 +_cell_length_c 8.60047456 +_cell_angle_alpha 90.00000000 +_cell_angle_beta 90.00000000 +_cell_angle_gamma 120.00000000 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural SiO2 +_chemical_formula_sum 'Si24 O48' +_cell_volume 1432.15645170 +_cell_formula_units_Z 24 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + O O0 1 0.457007 0.334333 0.000000 1 + O O1 1 0.665666 0.122673 0.000000 1 + O O2 1 0.877327 0.542993 0.000000 1 + O O3 1 0.542993 0.665667 0.000000 1 + O O4 1 0.334334 0.877327 0.000000 1 + O O5 1 0.122674 0.457007 0.000000 1 + O O6 1 0.334333 0.457007 0.499999 1 + O O7 1 0.122674 0.665667 0.499999 1 + O O8 1 0.542993 0.877327 0.499999 1 + O O9 1 0.665666 0.542993 0.499999 1 + O O10 1 0.877326 0.334333 0.499999 1 + O O11 1 0.457006 0.122673 0.499999 1 + O O12 1 0.367814 0.367814 0.250000 1 + O O13 1 0.632186 0.000000 0.250000 1 + O O14 1 0.000000 0.632186 0.250000 1 + O O15 1 0.632186 0.632186 0.250000 1 + O O16 1 0.367813 0.000000 0.250000 1 + O O17 1 0.999999 0.367814 0.250000 1 + O O18 1 0.632186 0.632186 0.750000 1 + O O19 1 0.367813 0.000000 0.750000 1 + O O20 1 0.999999 0.367814 0.750000 1 + O O21 1 0.367814 0.367814 0.750000 1 + O O22 1 0.632186 0.000000 0.750000 1 + O O23 1 0.000000 0.632186 0.750000 1 + O O24 1 0.417358 0.208679 0.250000 1 + O O25 1 0.791321 0.208679 0.250000 1 + O O26 1 0.791321 0.582642 0.250000 1 + O O27 1 0.582642 0.791321 0.250000 1 + O O28 1 0.208678 0.791321 0.250000 1 + O O29 1 0.208679 0.417358 0.250000 1 + O O30 1 0.582642 0.791321 0.750000 1 + O O31 1 0.208678 0.791321 0.750000 1 + O O32 1 0.208679 0.417358 0.750000 1 + O O33 1 0.417358 0.208679 0.750000 1 + O O34 1 0.791321 0.208679 0.750000 1 + O O35 1 0.791321 0.582642 0.750000 1 + O O36 1 0.581212 0.418789 0.250000 1 + O O37 1 0.581212 0.162424 0.250000 1 + O O38 1 0.837576 0.418789 0.250000 1 + O O39 1 0.418788 0.581211 0.250000 1 + O O40 1 0.418788 0.837576 0.250000 1 + O O41 1 0.162423 0.581211 0.250000 1 + O O42 1 0.418788 0.581211 0.750000 1 + O O43 1 0.418788 0.837576 0.750000 1 + O O44 1 0.162423 0.581211 0.750000 1 + O O45 1 0.581212 0.418789 0.750000 1 + O O46 1 0.581212 0.162424 0.750000 1 + O O47 1 0.837576 0.418789 0.750000 1 + Si Si48 1 0.456260 0.332886 0.187394 1 + Si Si49 1 0.667114 0.123373 0.187394 1 + Si Si50 1 0.876626 0.543741 0.187394 1 + Si Si51 1 0.543740 0.667114 0.187394 1 + Si Si52 1 0.332886 0.876626 0.187394 1 + Si Si53 1 0.123375 0.456259 0.187394 1 + Si Si54 1 0.332885 0.456259 0.312606 1 + Si Si55 1 0.123373 0.667114 0.312606 1 + Si Si56 1 0.543740 0.876626 0.312606 1 + Si Si57 1 0.667115 0.543741 0.312606 1 + Si Si58 1 0.876626 0.332886 0.312606 1 + Si Si59 1 0.456260 0.123373 0.312606 1 + Si Si60 1 0.543740 0.667114 0.812605 1 + Si Si61 1 0.332886 0.876626 0.812605 1 + Si Si62 1 0.123375 0.456259 0.812605 1 + Si Si63 1 0.456260 0.332886 0.812605 1 + Si Si64 1 0.667114 0.123373 0.812605 1 + Si Si65 1 0.876626 0.543741 0.812605 1 + Si Si66 1 0.667115 0.543741 0.687394 1 + Si Si67 1 0.876626 0.332886 0.687394 1 + Si Si68 1 0.456260 0.123373 0.687394 1 + Si Si69 1 0.332885 0.456259 0.687394 1 + Si Si70 1 0.123373 0.667114 0.687394 1 + Si Si71 1 0.543740 0.876626 0.687394 1 \ No newline at end of file diff --git a/tests/test_atomic_system.py b/tests/test_atomic_system.py new file mode 100644 index 0000000..0e1407c --- /dev/null +++ b/tests/test_atomic_system.py @@ -0,0 +1,48 @@ +import ase.io +import numpy as np +import torch +from orb_models.forcefield.base import batch_graphs +from orb_models.forcefield.atomic_system import ( + atom_graphs_to_ase_atoms, + ase_atoms_to_atom_graphs, +) + + +def test_atoms_to_atom_graphs_invertibility(fixtures_path): + atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + + atom_graphs = ase_atoms_to_atom_graphs(atoms, wrap=False) + recovered_atoms = atom_graphs_to_ase_atoms(atom_graphs)[0] + + assert np.allclose(recovered_atoms.positions, atoms.positions) + assert np.allclose(recovered_atoms.cell, atoms.cell) + assert (recovered_atoms.numbers == atoms.numbers).all() + + +def test_atom_graphs_to_ase_atoms_debatches(fixtures_path): + atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + graphs = [ase_atoms_to_atom_graphs(atoms, wrap=False) for _ in range(4)] + batch = batch_graphs(graphs) + atoms_list = atom_graphs_to_ase_atoms(batch) + assert len(atoms_list) == 4 + assert (atoms_list[0].positions == atoms_list[1].positions).all() + assert (atoms_list[0].get_tags() == atoms_list[1].get_tags()).all() + + +def test_ase_atoms_to_atom_graphs_wraps(fixtures_path): + atoms_unwrapped = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + atoms_unwrapped.positions[:10] += 2.0 * atoms_unwrapped.cell.array.max() + atoms_wrapped = atoms_unwrapped.copy() + atoms_wrapped.wrap() + assert not np.allclose(atoms_wrapped.positions, atoms_unwrapped.positions) + + atom_graphs = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=False) + assert np.allclose(atom_graphs.positions.numpy(), atoms_unwrapped.positions) + + # Note: this test is slightly indirect. We can't test that wrap=True yields the same + # results as ase's .wrap(), because of slight numerical differences at the boundaries. + # Instead, we test that wrap=True for an unwrapped system yields the same results + # as wrap=True for an ase-wrapped system. + atom_graphs1 = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=True) + atom_graphs2 = ase_atoms_to_atom_graphs(atoms_wrapped, wrap=True) + assert torch.allclose(atom_graphs1.positions, atom_graphs2.positions)