Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ben rhodes authored and ben rhodes committed Dec 19, 2024
1 parent 49d107b commit 0497cdf
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
98 changes: 98 additions & 0 deletions tests/fixtures/AFI.cif
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# AFI zeolite
data_SiO2
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 13.86655914
_cell_length_b 13.86655914
_cell_length_c 8.60047456
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 120.00000000
_symmetry_Int_Tables_number 1
_chemical_formula_structural SiO2
_chemical_formula_sum 'Si24 O48'
_cell_volume 1432.15645170
_cell_formula_units_Z 24
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
O O0 1 0.457007 0.334333 0.000000 1
O O1 1 0.665666 0.122673 0.000000 1
O O2 1 0.877327 0.542993 0.000000 1
O O3 1 0.542993 0.665667 0.000000 1
O O4 1 0.334334 0.877327 0.000000 1
O O5 1 0.122674 0.457007 0.000000 1
O O6 1 0.334333 0.457007 0.499999 1
O O7 1 0.122674 0.665667 0.499999 1
O O8 1 0.542993 0.877327 0.499999 1
O O9 1 0.665666 0.542993 0.499999 1
O O10 1 0.877326 0.334333 0.499999 1
O O11 1 0.457006 0.122673 0.499999 1
O O12 1 0.367814 0.367814 0.250000 1
O O13 1 0.632186 0.000000 0.250000 1
O O14 1 0.000000 0.632186 0.250000 1
O O15 1 0.632186 0.632186 0.250000 1
O O16 1 0.367813 0.000000 0.250000 1
O O17 1 0.999999 0.367814 0.250000 1
O O18 1 0.632186 0.632186 0.750000 1
O O19 1 0.367813 0.000000 0.750000 1
O O20 1 0.999999 0.367814 0.750000 1
O O21 1 0.367814 0.367814 0.750000 1
O O22 1 0.632186 0.000000 0.750000 1
O O23 1 0.000000 0.632186 0.750000 1
O O24 1 0.417358 0.208679 0.250000 1
O O25 1 0.791321 0.208679 0.250000 1
O O26 1 0.791321 0.582642 0.250000 1
O O27 1 0.582642 0.791321 0.250000 1
O O28 1 0.208678 0.791321 0.250000 1
O O29 1 0.208679 0.417358 0.250000 1
O O30 1 0.582642 0.791321 0.750000 1
O O31 1 0.208678 0.791321 0.750000 1
O O32 1 0.208679 0.417358 0.750000 1
O O33 1 0.417358 0.208679 0.750000 1
O O34 1 0.791321 0.208679 0.750000 1
O O35 1 0.791321 0.582642 0.750000 1
O O36 1 0.581212 0.418789 0.250000 1
O O37 1 0.581212 0.162424 0.250000 1
O O38 1 0.837576 0.418789 0.250000 1
O O39 1 0.418788 0.581211 0.250000 1
O O40 1 0.418788 0.837576 0.250000 1
O O41 1 0.162423 0.581211 0.250000 1
O O42 1 0.418788 0.581211 0.750000 1
O O43 1 0.418788 0.837576 0.750000 1
O O44 1 0.162423 0.581211 0.750000 1
O O45 1 0.581212 0.418789 0.750000 1
O O46 1 0.581212 0.162424 0.750000 1
O O47 1 0.837576 0.418789 0.750000 1
Si Si48 1 0.456260 0.332886 0.187394 1
Si Si49 1 0.667114 0.123373 0.187394 1
Si Si50 1 0.876626 0.543741 0.187394 1
Si Si51 1 0.543740 0.667114 0.187394 1
Si Si52 1 0.332886 0.876626 0.187394 1
Si Si53 1 0.123375 0.456259 0.187394 1
Si Si54 1 0.332885 0.456259 0.312606 1
Si Si55 1 0.123373 0.667114 0.312606 1
Si Si56 1 0.543740 0.876626 0.312606 1
Si Si57 1 0.667115 0.543741 0.312606 1
Si Si58 1 0.876626 0.332886 0.312606 1
Si Si59 1 0.456260 0.123373 0.312606 1
Si Si60 1 0.543740 0.667114 0.812605 1
Si Si61 1 0.332886 0.876626 0.812605 1
Si Si62 1 0.123375 0.456259 0.812605 1
Si Si63 1 0.456260 0.332886 0.812605 1
Si Si64 1 0.667114 0.123373 0.812605 1
Si Si65 1 0.876626 0.543741 0.812605 1
Si Si66 1 0.667115 0.543741 0.687394 1
Si Si67 1 0.876626 0.332886 0.687394 1
Si Si68 1 0.456260 0.123373 0.687394 1
Si Si69 1 0.332885 0.456259 0.687394 1
Si Si70 1 0.123373 0.667114 0.687394 1
Si Si71 1 0.543740 0.876626 0.687394 1
48 changes: 48 additions & 0 deletions tests/test_atomic_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ase.io
import numpy as np
import torch
from orb_models.forcefield.base import batch_graphs
from orb_models.forcefield.atomic_system import (
atom_graphs_to_ase_atoms,
ase_atoms_to_atom_graphs,
)


def test_atoms_to_atom_graphs_invertibility(fixtures_path):
atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))

atom_graphs = ase_atoms_to_atom_graphs(atoms, wrap=False)
recovered_atoms = atom_graphs_to_ase_atoms(atom_graphs)[0]

assert np.allclose(recovered_atoms.positions, atoms.positions)
assert np.allclose(recovered_atoms.cell, atoms.cell)
assert (recovered_atoms.numbers == atoms.numbers).all()


def test_atom_graphs_to_ase_atoms_debatches(fixtures_path):
atoms = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))
graphs = [ase_atoms_to_atom_graphs(atoms, wrap=False) for _ in range(4)]
batch = batch_graphs(graphs)
atoms_list = atom_graphs_to_ase_atoms(batch)
assert len(atoms_list) == 4
assert (atoms_list[0].positions == atoms_list[1].positions).all()
assert (atoms_list[0].get_tags() == atoms_list[1].get_tags()).all()


def test_ase_atoms_to_atom_graphs_wraps(fixtures_path):
atoms_unwrapped = ase.Atoms(ase.io.read(fixtures_path / "AFI.cif"))
atoms_unwrapped.positions[:10] += 2.0 * atoms_unwrapped.cell.array.max()
atoms_wrapped = atoms_unwrapped.copy()
atoms_wrapped.wrap()
assert not np.allclose(atoms_wrapped.positions, atoms_unwrapped.positions)

atom_graphs = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=False)
assert np.allclose(atom_graphs.positions.numpy(), atoms_unwrapped.positions)

# Note: this test is slightly indirect. We can't test that wrap=True yields the same
# results as ase's .wrap(), because of slight numerical differences at the boundaries.
# Instead, we test that wrap=True for an unwrapped system yields the same results
# as wrap=True for an ase-wrapped system.
atom_graphs1 = ase_atoms_to_atom_graphs(atoms_unwrapped, wrap=True)
atom_graphs2 = ase_atoms_to_atom_graphs(atoms_wrapped, wrap=True)
assert torch.allclose(atom_graphs1.positions, atom_graphs2.positions)

0 comments on commit 0497cdf

Please sign in to comment.