Fix dataset logic #771

Merged
merged 8 commits on Jul 19, 2024
56 changes: 23 additions & 33 deletions src/fairchem/core/trainers/base_trainer.py
@@ -159,12 +159,18 @@ def __init__(
             if len(dataset) > 2:
                 self.config["test_dataset"] = dataset[2]
         elif isinstance(dataset, dict):
-            self.config["dataset"] = dataset.get("train", {})
-            self.config["val_dataset"] = dataset.get("val", {})
-            self.config["test_dataset"] = dataset.get("test", {})
-            self.config["relax_dataset"] = dataset.get("relax", {})
+            # or {} in cases where "dataset": None is explicitly defined
+            self.config["dataset"] = dataset.get("train", {}) or {}
+            self.config["val_dataset"] = dataset.get("val", {}) or {}
+            self.config["test_dataset"] = dataset.get("test", {}) or {}
+            self.config["relax_dataset"] = dataset.get("relax", {}) or {}
         else:
-            self.config["dataset"] = dataset
+            self.config["dataset"] = dataset or {}
+
+        # add empty dicts for missing datasets
+        for dataset_name in ("val_dataset", "test_dataset", "relax_dataset"):
+            if dataset_name not in self.config:
+                self.config[dataset_name] = {}
 
         if not is_debug and distutils.is_master():
             os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
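Note: dict.get(key, default) only falls back to the default when the key is missing entirely; a key explicitly set to None (e.g. "dataset: null" in a YAML config) is returned as-is, which is why the "or {}" coercion above is needed. A minimal sketch of the behaviour, using a made-up config literal:

    dataset = {"train": None, "val": {"src": "data/val.lmdb"}}  # hypothetical config
    print(dataset.get("train", {}))        # None -> the default {} is NOT used
    print(dataset.get("train", {}) or {})  # {}   -> falsy None coerced to an empty dict
    print(dataset.get("test", {}))         # {}   -> a truly missing key does use the default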
@@ -277,19 +283,8 @@ def load_datasets(self) -> None:
         self.val_loader = None
         self.test_loader = None
 
-        # Default all of the dataset portions to {} if
-        # they don't exist, or are null
-        if not self.config.get("dataset", None):
-            self.config["dataset"] = {}
-        if not self.config.get("val_dataset", None):
-            self.config["val_dataset"] = {}
-        if not self.config.get("test_dataset", None):
-            self.config["test_dataset"] = {}
-        if not self.config.get("relax_dataset", None):
-            self.config["relax_dataset"] = {}
-
         # load train, val, test datasets
-        if self.config["dataset"] and self.config["dataset"].get("src", None):
+        if "src" in self.config["dataset"]:
             logging.info(
                 f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}"
             )
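Note: the membership test is safe because, after the __init__ changes above, self.config["dataset"], "val_dataset", "test_dataset", and "relax_dataset" are always dicts (possibly empty), so the old combined truthiness-and-get check is no longer needed; the same "src" in ... test is applied to the val and test datasets in the hunks below. A rough sketch of the pattern, with hypothetical paths:

    # An empty dict simply skips loading; a dict with "src" triggers it.
    for cfg in ({}, {"src": "data/train.lmdb", "format": "lmdb"}):  # hypothetical configs
        if "src" in cfg:
            print(f"would load a {cfg.get('format', 'lmdb')} dataset from {cfg['src']}")
        else:
            print("no dataset configured; skipping")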
@@ -307,7 +302,7 @@ def load_datasets(self) -> None:
                 self.train_sampler,
             )
 
-        if self.config["val_dataset"]:
+        if "src" in self.config["val_dataset"]:
             if self.config["val_dataset"].get("use_train_settings", True):
                 val_config = self.config["dataset"].copy()
                 val_config.update(self.config["val_dataset"])
@@ -329,13 +324,8 @@ def load_datasets(self) -> None:
                 self.val_sampler,
             )
 
-        if self.config["test_dataset"]:
-            if (
-                self.config["test_dataset"].get("use_train_settings", True)
-                and self.config[
-                    "dataset"
-                ]  # if there's no training dataset, we have nothing to copy
-            ):
+        if "src" in self.config["test_dataset"]:
+            if self.config["test_dataset"].get("use_train_settings", True):
                 test_config = self.config["dataset"].copy()
                 test_config.update(self.config["test_dataset"])
             else:
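Note: the removed guard "and self.config["dataset"]" protected against calling .copy() on None when no training dataset was defined; now that every dataset entry is normalized to a dict, copying an empty train config and layering the test keys on top is safe. A small illustrative sketch with hypothetical keys:

    # Deriving a test config from a possibly empty train config.
    train_config = {}  # no training dataset was configured
    test_overrides = {"src": "data/test.lmdb", "use_train_settings": True}
    test_config = train_config.copy()   # never raises, even on an empty dict
    test_config.update(test_overrides)
    print(test_config)  # {'src': 'data/test.lmdb', 'use_train_settings': True}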
@@ -407,16 +397,16 @@ def load_task(self):
                             "outputs"
                         ][target_name].get("level", "system")
                     if "train_on_free_atoms" not in self.output_targets[subtarget]:
-                        self.output_targets[subtarget][
-                            "train_on_free_atoms"
-                        ] = self.config["outputs"][target_name].get(
-                            "train_on_free_atoms", True
+                        self.output_targets[subtarget]["train_on_free_atoms"] = (
+                            self.config[
+                                "outputs"
+                            ][target_name].get("train_on_free_atoms", True)
                         )
                     if "eval_on_free_atoms" not in self.output_targets[subtarget]:
-                        self.output_targets[subtarget][
-                            "eval_on_free_atoms"
-                        ] = self.config["outputs"][target_name].get(
-                            "eval_on_free_atoms", True
+                        self.output_targets[subtarget]["eval_on_free_atoms"] = (
+                            self.config[
+                                "outputs"
+                            ][target_name].get("eval_on_free_atoms", True)
                         )
 
         # TODO: Assert that all targets, loss fn, metrics defined are consistent
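Note: this last hunk is formatting-only; the fallback logic for train_on_free_atoms / eval_on_free_atoms is unchanged. In plain terms, each subtarget inherits the flag from its parent output config, defaulting to True. A sketch with hypothetical config values:

    outputs = {"forces": {"train_on_free_atoms": False}}  # hypothetical outputs config
    subtarget_cfg = {}                                     # subtarget did not set the flag
    if "train_on_free_atoms" not in subtarget_cfg:
        subtarget_cfg["train_on_free_atoms"] = outputs["forces"].get("train_on_free_atoms", True)
    print(subtarget_cfg)  # {'train_on_free_atoms': False}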