Skip to content

Commit

Permalink
add in info fields
Browse files Browse the repository at this point in the history
Add support for loading extra `info_fields` from ASE atoms objects, and for limiting datasets via `subset_to` metadata filters.
  • Loading branch information
misko committed Dec 18, 2024
1 parent 6efc99b commit 0124228
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/fairchem/core/datasets/ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def __init__(
self.ids = self._load_dataset_get_ids(config)
self.num_samples = len(self.ids)

self.info_fields = config.get("info_fields", [])

if len(self.ids) == 0:
raise ValueError(
rf"No valid ase data found! \n"
Expand Down Expand Up @@ -131,6 +133,10 @@ def __getitem__(self, idx):
data_object.fid = fid
data_object.natoms = len(atoms)

# load additional info from dataset
for info_field in self.info_fields:
setattr(data_object, info_field, atoms.info.get(info_field))

# apply linear reference
if self.a2g.r_energy is True and self.lin_ref is not None:
data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()])
Expand Down
19 changes: 19 additions & 0 deletions src/fairchem/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
g.manual_seed(seed)

dataset = dataset_cls(current_split_config)

# Get indices of the dataset
indices = dataset.indices
max_atoms = current_split_config.get("max_atoms", None)
Expand All @@ -191,6 +192,24 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
raise ValueError("Cannot use max_atoms without dataset metadata")
indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms]

for subset_to in current_split_config.get("subset_to", []):
if not dataset.metadata_hasattr(subset_to["metadata_key"]):
raise ValueError(
f"Cannot use {subset_to} without dataset metadata key {subset_to['metadata_key']}"
)
if subset_to["op"] == "abs_le":
indices = indices[
np.abs(dataset.get_metadata(subset_to["metadata_key"], indices))
<= subset_to["rhv"]
]
elif subset_to["op"] == "in":
indices = indices[
np.isin(
dataset.get_metadata(subset_to["metadata_key"], indices),
subset_to["rhv"],
)
]

# Apply dataset level transforms
# TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle?
first_n = current_split_config.get("first_n")
Expand Down

0 comments on commit 0124228

Please sign in to comment.