Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Dec 19, 2024
1 parent 0d32718 commit 4fac198
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
10 changes: 8 additions & 2 deletions src/fairchem/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,17 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
# shuffle all datasets by default to avoid biasing the sampling in concat dataset
# TODO only shuffle if split is train
max_index = sample_n
indices = indices[randperm(len(indices), generator=g)]
indices = (
indices
if len(indices) == 1
else indices[randperm(len(indices), generator=g)]
)
else:
max_index = len(indices)
indices = (
indices if no_shuffle else indices[randperm(len(indices), generator=g)]
indices
if (no_shuffle or len(indices) == 1)
else indices[randperm(len(indices), generator=g)]
)

if max_index > len(indices):
Expand Down
50 changes: 48 additions & 2 deletions tests/core/datasets/test_create_dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
from __future__ import annotations

import os
import tempfile

import numpy as np
import pytest

from fairchem.core.datasets import LMDBDatabase, create_dataset
from fairchem.core.datasets.base_dataset import BaseDataset
import tempfile
from fairchem.core.trainers.base_trainer import BaseTrainer


@pytest.fixture()
def lmdb_database(structures):
    """Yield the path of a temporary ASE LMDB database built from ``structures``.

    Every structure is written to the database together with its ``info``
    dict. A ``metadata.npz`` file is saved next to the database holding, per
    structure, the atom count (``natoms``) and that count modulo 2 (``mod2``)
    and modulo 3 (``mod3``).

    The temporary directory stays alive until the consuming test finishes,
    because the ``yield`` happens inside the ``TemporaryDirectory`` context.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        num_atoms = []
        mod2 = []
        mod3 = []
        asedb_fn = f"{tmpdirname}/asedb.lmdb"
        with LMDBDatabase(asedb_fn) as database:
            for atoms in structures:
                database.write(atoms, data=atoms.info)
                n_atoms = len(atoms)
                num_atoms.append(n_atoms)
                mod2.append(n_atoms % 2)
                mod3.append(n_atoms % 3)
        # Write the metadata once, after the database is closed; a separate
        # natoms-only write to the same path would just be overwritten.
        np.savez(
            f"{tmpdirname}/metadata.npz",
            natoms=num_atoms,
            mod2=mod2,
            mod3=mod3,
        )
        yield asedb_fn


Expand Down Expand Up @@ -76,6 +88,40 @@ def get_dataloader(self, *args, **kwargs):
assert len(t.val_dataset) == 3


def test_subset_to(structures, lmdb_database):
    """Check that ``subset_to`` keeps only the entries passing every filter.

    Three cases are covered, all using the ``abs_le`` op on metadata stored
    by the ``lmdb_database`` fixture:

    * a threshold no entry exceeds (nothing is dropped),
    * a single filter selecting ``mod2 == 0``,
    * two filters combined, selecting ``mod2 == 0`` and ``mod3 == 0``.
    """
    base_config = {
        "format": "ase_db",
        "src": str(lmdb_database),
    }
    even = [s for s in structures if len(s) % 2 == 0]
    even_and_div3 = [s for s in even if len(s) % 3 == 0]

    # Guard against a vacuous test: at least one structure must pass mod2==0.
    assert len(even) > 0

    # mod2 is always 0 or 1, so |mod2| <= 10 keeps every entry
    config = {
        **base_config,
        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 10}],
    }
    assert len(create_dataset(config, split="train")) == len(structures)

    # only select those that have mod2==0
    config = {
        **base_config,
        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 0}],
    }
    assert len(create_dataset(config, split="train")) == len(even)

    # only select those that have mod2==0 and mod3==0
    # NOTE(review): assumes create_dataset tolerates an empty selection, or
    # that the `structures` fixture contains a size divisible by 6 — confirm.
    config = {
        **base_config,
        "subset_to": [
            {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
            {"op": "abs_le", "metadata_key": "mod3", "rhv": 0},
        ],
    }
    assert len(create_dataset(config, split="train")) == len(even_and_div3)


@pytest.mark.parametrize("max_atoms", [3, None])
@pytest.mark.parametrize(
"key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)]
Expand Down

0 comments on commit 4fac198

Please sign in to comment.