From 4fac198c3affa9a4034515bb9608bfff1c6ce0cf Mon Sep 17 00:00:00 2001 From: Misko Date: Thu, 19 Dec 2024 00:04:53 +0000 Subject: [PATCH] add some tests --- src/fairchem/core/datasets/base_dataset.py | 10 ++++- tests/core/datasets/test_create_dataset.py | 50 +++++++++++++++++++++- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py index d0d789560..0a5db9a37 100644 --- a/src/fairchem/core/datasets/base_dataset.py +++ b/src/fairchem/core/datasets/base_dataset.py @@ -230,11 +230,17 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset: # shuffle all datasets by default to avoid biasing the sampling in concat dataset # TODO only shuffle if split is train max_index = sample_n - indices = indices[randperm(len(indices), generator=g)] + indices = ( + indices + if len(indices) == 1 + else indices[randperm(len(indices), generator=g)] + ) else: max_index = len(indices) indices = ( - indices if no_shuffle else indices[randperm(len(indices), generator=g)] + indices + if (no_shuffle or len(indices) == 1) + else indices[randperm(len(indices), generator=g)] ) if max_index > len(indices): diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py index 1dc17bcb3..02f89b593 100644 --- a/tests/core/datasets/test_create_dataset.py +++ b/tests/core/datasets/test_create_dataset.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import os +import tempfile + import numpy as np import pytest from fairchem.core.datasets import LMDBDatabase, create_dataset from fairchem.core.datasets.base_dataset import BaseDataset -import tempfile from fairchem.core.trainers.base_trainer import BaseTrainer @@ -12,12 +15,21 @@ def lmdb_database(structures): with tempfile.TemporaryDirectory() as tmpdirname: num_atoms = [] + mod2 = [] + mod3 = [] asedb_fn = f"{tmpdirname}/asedb.lmdb" with LMDBDatabase(asedb_fn) as database: for i, atoms in 
enumerate(structures): database.write(atoms, data=atoms.info) num_atoms.append(len(atoms)) - np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms) + mod2.append(len(atoms) % 2) + mod3.append(len(atoms) % 3) + np.savez( + f"{tmpdirname}/metadata.npz", + natoms=num_atoms, + mod2=mod2, + mod3=mod3, + ) yield asedb_fn @@ -76,6 +88,40 @@ def get_dataloader(self, *args, **kwargs): assert len(t.val_dataset) == 3 +def test_subset_to(structures, lmdb_database): + config = { + "format": "ase_db", + "src": str(lmdb_database), + "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 10}], + } + + assert len(create_dataset(config, split="train")) == len(structures) + + # only select those that have mod2==0 + config = { + "format": "ase_db", + "src": str(lmdb_database), + "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 0}], + } + assert len(create_dataset(config, split="train")) == len( + [s for s in structures if len(s) % 2 == 0] + ) + + # only select those that have mod2==0 and mod3==0 + config = { + "format": "ase_db", + "src": str(lmdb_database), + "subset_to": [ + {"op": "abs_le", "metadata_key": "mod2", "rhv": 0}, + {"op": "abs_le", "metadata_key": "mod3", "rhv": 0}, + ], + } + assert len(create_dataset(config, split="train")) == len( + [s for s in structures if len(s) % 2 == 0 and len(s) % 3 == 0] + ) + assert len([s for s in structures if len(s) % 2 == 0]) > 0 + + @pytest.mark.parametrize("max_atoms", [3, None]) @pytest.mark.parametrize( "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)]