diff --git a/tests/fixtures/AFI.cif b/tests/fixtures/AFI.cif new file mode 100644 index 0000000..412b4ac --- /dev/null +++ b/tests/fixtures/AFI.cif @@ -0,0 +1,98 @@ +# AFI zeolite +data_SiO2 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 13.86655914 +_cell_length_b 13.86655914 +_cell_length_c 8.60047456 +_cell_angle_alpha 90.00000000 +_cell_angle_beta 90.00000000 +_cell_angle_gamma 120.00000000 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural SiO2 +_chemical_formula_sum 'Si24 O48' +_cell_volume 1432.15645170 +_cell_formula_units_Z 24 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + O O0 1 0.457007 0.334333 0.000000 1 + O O1 1 0.665666 0.122673 0.000000 1 + O O2 1 0.877327 0.542993 0.000000 1 + O O3 1 0.542993 0.665667 0.000000 1 + O O4 1 0.334334 0.877327 0.000000 1 + O O5 1 0.122674 0.457007 0.000000 1 + O O6 1 0.334333 0.457007 0.499999 1 + O O7 1 0.122674 0.665667 0.499999 1 + O O8 1 0.542993 0.877327 0.499999 1 + O O9 1 0.665666 0.542993 0.499999 1 + O O10 1 0.877326 0.334333 0.499999 1 + O O11 1 0.457006 0.122673 0.499999 1 + O O12 1 0.367814 0.367814 0.250000 1 + O O13 1 0.632186 0.000000 0.250000 1 + O O14 1 0.000000 0.632186 0.250000 1 + O O15 1 0.632186 0.632186 0.250000 1 + O O16 1 0.367813 0.000000 0.250000 1 + O O17 1 0.999999 0.367814 0.250000 1 + O O18 1 0.632186 0.632186 0.750000 1 + O O19 1 0.367813 0.000000 0.750000 1 + O O20 1 0.999999 0.367814 0.750000 1 + O O21 1 0.367814 0.367814 0.750000 1 + O O22 1 0.632186 0.000000 0.750000 1 + O O23 1 0.000000 0.632186 0.750000 1 + O O24 1 0.417358 0.208679 0.250000 1 + O O25 1 0.791321 0.208679 0.250000 1 + O O26 1 0.791321 0.582642 0.250000 1 + O O27 1 0.582642 0.791321 0.250000 1 + O O28 1 0.208678 0.791321 0.250000 1 + O O29 1 0.208679 0.417358 0.250000 1 + O O30 1 0.582642 0.791321 0.750000 1 + O O31 1 0.208678 0.791321 0.750000 1 + O O32 1 0.208679 0.417358 0.750000 1 + O O33 1 0.417358 0.208679 0.750000 1 + O O34 1 0.791321 0.208679 0.750000 1 + O O35 1 0.791321 0.582642 0.750000 1 + O O36 1 0.581212 0.418789 0.250000 1 + O O37 1 0.581212 0.162424 0.250000 1 + O O38 1 0.837576 0.418789 0.250000 1 + O O39 1 0.418788 0.581211 0.250000 1 + O O40 1 0.418788 0.837576 0.250000 1 + O O41 1 0.162423 0.581211 0.250000 1 + O O42 1 0.418788 0.581211 0.750000 1 + O O43 1 0.418788 0.837576 0.750000 1 + O O44 1 0.162423 0.581211 0.750000 1 + O O45 1 0.581212 0.418789 0.750000 1 + O O46 1 0.581212 0.162424 0.750000 1 + O O47 1 0.837576 0.418789 0.750000 1 + Si Si48 1 0.456260 0.332886 0.187394 1 + Si Si49 1 0.667114 0.123373 0.187394 1 + Si Si50 1 0.876626 0.543741 0.187394 1 + Si Si51 1 0.543740 0.667114 0.187394 1 + Si Si52 1 0.332886 0.876626 0.187394 1 + Si Si53 1 0.123375 0.456259 0.187394 1 + Si Si54 1 0.332885 0.456259 0.312606 1 + Si Si55 1 0.123373 0.667114 0.312606 1 + Si Si56 1 0.543740 0.876626 0.312606 1 + Si Si57 1 0.667115 0.543741 0.312606 1 + Si Si58 1 0.876626 0.332886 0.312606 1 + Si Si59 1 0.456260 0.123373 0.312606 1 + Si Si60 1 0.543740 0.667114 0.812605 1 + Si Si61 1 0.332886 0.876626 0.812605 1 + Si Si62 1 0.123375 0.456259 0.812605 1 + Si Si63 1 0.456260 0.332886 0.812605 1 + Si Si64 1 0.667114 0.123373 0.812605 1 + Si Si65 1 0.876626 0.543741 0.812605 1 + Si Si66 1 0.667115 0.543741 0.687394 1 + Si Si67 1 0.876626 0.332886 0.687394 1 + Si Si68 1 0.456260 0.123373 0.687394 1 + Si Si69 1 0.332885 0.456259 0.687394 1 + Si Si70 1 0.123373 0.667114 0.687394 1 + Si Si71 1 0.543740 0.876626 0.687394 1 \ No newline at end of file diff --git a/tests/test_atomic_system.py b/tests/test_atomic_system.py new file mode 100644 index 0000000..0e1407c --- /dev/null +++ b/tests/test_atomic_system.py @@ -0,0 +1,48 @@ +import ase.io +import numpy as np +import torch +from orb_models.forcefield.base import batch_graphs +from orb_models.forcefield.atomic_system import ( + atom_graphs_to_ase_atoms, + ase_atoms_to_atom_graphs, +) + + +def test_atoms_to_atom_graphs_invertibility(fixtures_path): + atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + + atom_graphs = ase_atoms_to_atom_graphs(atoms, wrap=False) + recovered_atoms = atom_graphs_to_ase_atoms(atom_graphs)[0] + + assert np.allclose(recovered_atoms.positions, atoms.positions) + assert np.allclose(recovered_atoms.cell, atoms.cell) + assert (recovered_atoms.numbers == atoms.numbers).all() + + +def test_atom_graphs_to_ase_atoms_debatches(fixtures_path): + atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + graphs = [ase_atoms_to_atom_graphs(atoms, wrap=False) for _ in range(4)] + batch = batch_graphs(graphs) + atoms_list = atom_graphs_to_ase_atoms(batch) + assert len(atoms_list) == 4 + assert (atoms_list[0].positions == atoms_list[1].positions).all() + assert (atoms_list[0].get_tags() == atoms_list[1].get_tags()).all() + + +def test_ase_atoms_to_atom_graphs_wraps(fixtures_path): + atoms_unwrapped = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif")) + atoms_unwrapped.positions[:10] += 2.0 * atoms_unwrapped.cell.array.max() + atoms_wrapped = atoms_unwrapped.copy() + atoms_wrapped.wrap() + assert not np.allclose(atoms_wrapped.positions, atoms_unwrapped.positions) + + atom_graphs = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=False) + assert np.allclose(atom_graphs.positions.numpy(), atoms_unwrapped.positions) + + # Note: this test is slightly indirect. We can't test that wrap=True yields the same + # results as ase's .wrap(), because of slight numerical differences at the boundaries. + # Instead, we test that wrap=True for an unwrapped system yields the same results + # as wrap=True for an ase-wrapped system. + atom_graphs1 = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=True) + atom_graphs2 = ase_atoms_to_atom_graphs(atoms_wrapped, wrap=True) + assert torch.allclose(atom_graphs1.positions, atom_graphs2.positions)