Skip to content

Commit

Permalink
clean up logic in load_datasets
Browse files (browse the repository at this point in the history)
  • Loading branch information
zulissimeta committed Jul 12, 2024
1 parent 6675b80 commit 14cfab1
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -159,10 +159,10 @@ def __init__(
if len(dataset) > 2:
self.config["test_dataset"] = dataset[2]
elif isinstance(dataset, dict):
self.config["dataset"] = dataset.get("train", None)
self.config["val_dataset"] = dataset.get("val", None)
self.config["test_dataset"] = dataset.get("test", None)
self.config["relax_dataset"] = dataset.get("relax", None)
self.config["dataset"] = dataset.get("train", {})
self.config["val_dataset"] = dataset.get("val", {})
self.config["test_dataset"] = dataset.get("test", {})
self.config["relax_dataset"] = dataset.get("relax", {})
else:
self.config["dataset"] = dataset

Expand Down Expand Up @@ -277,8 +277,19 @@ def load_datasets(self) -> None:
self.val_loader = None
self.test_loader = None

# Default all of the dataset portions to {} if
# they don't exist, or are null
if self.config.get("dataset", None):
self.config["dataset"] = {}
if self.config.get("val_dataset", None):
self.config["val_dataset"] = {}
if self.config.get("test_dataset", None):
self.config["test_dataset"] = {}
if self.config.get("relax_dataset", None):
self.config["relax_dataset"] = {}

# load train, val, test datasets
if self.config.get("dataset", None) and self.config["dataset"].get("src", None):
if self.config["dataset"] and self.config["dataset"].get("src", None):
logging.info(
f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}"
)
Expand All @@ -295,10 +306,8 @@ def load_datasets(self) -> None:
self.train_dataset,
self.train_sampler,
)
elif self.config.get("dataset", None) is None:
self.config["dataset"] = {}

if self.config.get("val_dataset", None):
if self.config.get["val_dataset"]:
if self.config["val_dataset"].get("use_train_settings", True):
val_config = self.config["dataset"].copy()
val_config.update(self.config["val_dataset"])
Expand All @@ -319,10 +328,8 @@ def load_datasets(self) -> None:
self.val_dataset,
self.val_sampler,
)
else:
self.config["val_dataset"] = {}

if self.config.get("test_dataset", None):
if self.config["test_dataset"]:
if (
self.config["test_dataset"].get("use_train_settings", True)
and self.config[
Expand All @@ -348,10 +355,8 @@ def load_datasets(self) -> None:
self.test_dataset,
self.test_sampler,
)
else:
self.config["test_dataset"] = {}

if self.config.get("relax_dataset", None):
if self.config["relax_dataset"]:
if self.config["relax_dataset"].get("use_train_settings", True):
relax_config = self.config["dataset"].copy()
relax_config.update(self.config["relax_dataset"])
Expand All @@ -371,7 +376,7 @@ def load_datasets(self) -> None:
self.relax_loader = self.get_dataloader(
self.relax_dataset,
self.relax_sampler,
)
)

def load_task(self):
# Normalizer for the dataset.
Expand Down

0 comments on commit 14cfab1

Please sign in to comment.