From 14cfab1d987d892540d8735709d9aa452f74ad5f Mon Sep 17 00:00:00 2001 From: Zack Ulissi Date: Fri, 12 Jul 2024 19:58:26 +0000 Subject: [PATCH] clean up logic in load_datasets --- src/fairchem/core/trainers/base_trainer.py | 35 ++++++++++++---------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index c06e956e5..e773d7667 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -159,10 +159,10 @@ def __init__( if len(dataset) > 2: self.config["test_dataset"] = dataset[2] elif isinstance(dataset, dict): - self.config["dataset"] = dataset.get("train", None) - self.config["val_dataset"] = dataset.get("val", None) - self.config["test_dataset"] = dataset.get("test", None) - self.config["relax_dataset"] = dataset.get("relax", None) + self.config["dataset"] = dataset.get("train", {}) + self.config["val_dataset"] = dataset.get("val", {}) + self.config["test_dataset"] = dataset.get("test", {}) + self.config["relax_dataset"] = dataset.get("relax", {}) else: self.config["dataset"] = dataset @@ -277,8 +277,19 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None + # Default all of the dataset portions to {} if + # they don't exist, or are null + if self.config.get("dataset", None): + self.config["dataset"] = {} + if self.config.get("val_dataset", None): + self.config["val_dataset"] = {} + if self.config.get("test_dataset", None): + self.config["test_dataset"] = {} + if self.config.get("relax_dataset", None): + self.config["relax_dataset"] = {} + # load train, val, test datasets - if self.config.get("dataset", None) and self.config["dataset"].get("src", None): + if self.config["dataset"] and self.config["dataset"].get("src", None): logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) @@ -295,10 +306,8 @@ def load_datasets(self) -> None: self.train_dataset, self.train_sampler, ) - elif self.config.get("dataset", None) is None: - self.config["dataset"] = {} - if self.config.get("val_dataset", None): + if self.config.get["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() val_config.update(self.config["val_dataset"]) @@ -319,10 +328,8 @@ def load_datasets(self) -> None: self.val_dataset, self.val_sampler, ) - else: - self.config["val_dataset"] = {} - if self.config.get("test_dataset", None): + if self.config["test_dataset"]: if ( self.config["test_dataset"].get("use_train_settings", True) and self.config[ @@ -348,10 +355,8 @@ def load_datasets(self) -> None: self.test_dataset, self.test_sampler, ) - else: - self.config["test_dataset"] = {} - if self.config.get("relax_dataset", None): + if self.config["relax_dataset"]: if self.config["relax_dataset"].get("use_train_settings", True): relax_config = self.config["dataset"].copy() relax_config.update(self.config["relax_dataset"]) @@ -371,7 +376,7 @@ def load_datasets(self) -> None: self.relax_loader = self.get_dataloader( self.relax_dataset, self.relax_sampler, - ) + ) def load_task(self): # Normalizer for the dataset.