diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index c8b14f856..dce509945 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -151,22 +151,26 @@ def __init__( ) # Define datasets - # or {} in cases where "dataset": None is explicitly defined if isinstance(dataset, list): if len(dataset) > 0: - self.config["dataset"] = dataset[0] or {} + self.config["dataset"] = dataset[0] if len(dataset) > 1: - self.config["val_dataset"] = dataset[1] or {} + self.config["val_dataset"] = dataset[1] if len(dataset) > 2: - self.config["test_dataset"] = dataset[2] or {} + self.config["test_dataset"] = dataset[2] elif isinstance(dataset, dict): + # or {} in cases where "dataset": None is explicitly defined self.config["dataset"] = dataset.get("train", {}) or {} self.config["val_dataset"] = dataset.get("val", {}) or {} self.config["test_dataset"] = dataset.get("test", {}) or {} self.config["relax_dataset"] = dataset.get("relax", {}) or {} else: self.config["dataset"] = dataset or {} - self.config["val_dataset"], self.config["test_dataset"] = {}, {} + + # add empty dicts for missing datasets + for dataset_name in ("val_dataset", "test_dataset", "relax_dataset"): + if dataset_name not in self.config: + self.config[dataset_name] = {} if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)