From bffcb108a3582010e7dcc7e3561000690b1ca3f3 Mon Sep 17 00:00:00 2001 From: Misko Date: Wed, 18 Dec 2024 23:12:16 +0000 Subject: [PATCH] fix up datasetmetadata --- src/fairchem/core/datasets/base_dataset.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py index 7a2164f96..9f8aac3d0 100644 --- a/src/fairchem/core/datasets/base_dataset.py +++ b/src/fairchem/core/datasets/base_dataset.py @@ -34,8 +34,11 @@ T_co = TypeVar("T_co", covariant=True) -class DatasetMetadata(NamedTuple): - natoms: ArrayLike | None = None +class DatasetMetadata: + def __init__(self, natoms: ArrayLike | None = None, **kwargs): + self.natoms = natoms + for key, value in kwargs.items(): + setattr(self, key, value) class UnsupportedDatasetError(ValueError): @@ -106,9 +109,10 @@ def _metadata(self) -> DatasetMetadata: metadata = DatasetMetadata( **{ field: np.concatenate([metadata[field] for metadata in metadata_npzs]) - for field in DatasetMetadata._fields + for field in metadata_npzs[0].keys() } ) + assert np.issubdtype( metadata.natoms.dtype, np.integer ), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}"