Skip to content

Commit

Permalink
add empty dicts for all missing datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Jul 18, 2024
1 parent 3ea5608 commit d50944b
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,26 @@ def __init__(
)

# Define datasets
# or {} in cases where "dataset": None is explicitly defined
if isinstance(dataset, list):
if len(dataset) > 0:
self.config["dataset"] = dataset[0] or {}
self.config["dataset"] = dataset[0]
if len(dataset) > 1:
self.config["val_dataset"] = dataset[1] or {}
self.config["val_dataset"] = dataset[1]
if len(dataset) > 2:
self.config["test_dataset"] = dataset[2] or {}
self.config["test_dataset"] = dataset[2]
elif isinstance(dataset, dict):
# or {} in cases where "dataset": None is explicitly defined
self.config["dataset"] = dataset.get("train", {}) or {}
self.config["val_dataset"] = dataset.get("val", {}) or {}
self.config["test_dataset"] = dataset.get("test", {}) or {}
self.config["relax_dataset"] = dataset.get("relax", {}) or {}
else:
self.config["dataset"] = dataset or {}
self.config["val_dataset"], self.config["test_dataset"] = {}, {}

# add empty dicts for missing datasets
for dataset_name in ("val_dataset", "test_dataset", "relax_dataset"):
if dataset_name not in self.config:
self.config[dataset_name] = {}

if not is_debug and distutils.is_master():
os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
Expand Down

0 comments on commit d50944b

Please sign in to comment.