From 7666064c38e3f6b484a6df13448d065cc5d9a79a Mon Sep 17 00:00:00 2001
From: rgao
Date: Thu, 5 Dec 2024 17:45:07 +0000
Subject: [PATCH] temp switch to dcp instead

---
 src/fairchem/core/tasks/task.py            |  4 +---
 src/fairchem/core/trainers/base_trainer.py | 19 +++++++++++--------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/fairchem/core/tasks/task.py b/src/fairchem/core/tasks/task.py
index 265220114..0048c8ac1 100644
--- a/src/fairchem/core/tasks/task.py
+++ b/src/fairchem/core/tasks/task.py
@@ -22,9 +22,7 @@ def setup(self, trainer) -> None:
         self.trainer = trainer
 
         # TODO: make checkpoint.pt a constant so we don't pass this string around everywhere
-        self.chkpt_path = os.path.join(
-            self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
-        )
+        self.chkpt_path = os.path.join(self.trainer.config["cmd"]["checkpoint_dir"])
 
         # if the supplied checkpoint exists, then load that, i.e. when the user specifies the --checkpoint option
         # OR if a job was preempted correctly and the submitit checkpoint function was called
diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 5c21a743a..81e8599ed 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -323,11 +323,14 @@ def get_sampler(
             seed=self.config["cmd"]["seed"],
         )
 
-    def get_dataloader(self, dataset, sampler) -> DataLoader:
+    def get_dataloader(self, dataset, sampler, workers=None) -> DataLoader:
+        num_workers = (
+            self.config["optim"]["num_workers"] if workers is None else workers
+        )
         return DataLoader(
             dataset,
             collate_fn=self.collater,
-            num_workers=self.config["optim"]["num_workers"],
+            num_workers=num_workers,
             pin_memory=True,
             batch_sampler=sampler,
         )
@@ -528,15 +531,15 @@ def load_task(self):
             ][target_name].get("level", "system")
             if "train_on_free_atoms" not in self.output_targets[subtarget]:
                 self.output_targets[subtarget]["train_on_free_atoms"] = (
-                    self.config[
-                        "outputs"
-                    ][target_name].get("train_on_free_atoms", True)
+                    self.config["outputs"][target_name].get(
+                        "train_on_free_atoms", True
+                    )
                 )
             if "eval_on_free_atoms" not in self.output_targets[subtarget]:
                 self.output_targets[subtarget]["eval_on_free_atoms"] = (
-                    self.config[
-                        "outputs"
-                    ][target_name].get("eval_on_free_atoms", True)
+                    self.config["outputs"][target_name].get(
+                        "eval_on_free_atoms", True
+                    )
                 )
 
         # TODO: Assert that all targets, loss fn, metrics defined are consistent
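
Note on the change: "dcp" in the subject presumably refers to
torch.distributed.checkpoint (DCP), which writes a sharded checkpoint into a
directory rather than a single file; that would explain why chkpt_path above
drops the "checkpoint.pt" filename and keeps only the checkpoint directory.
A minimal sketch of the directory-based save/load pattern, assuming PyTorch
>= 2.2 (dcp.save/dcp.load with checkpoint_id); the Linear module and the
"checkpoint_dir" path are illustrative placeholders, not taken from fairchem:

    import torch
    import torch.distributed.checkpoint as dcp

    model = torch.nn.Linear(4, 2)  # stand-in for the trainer's model

    # DCP writes shard files plus a .metadata file into the target
    # directory, so the checkpoint path is a directory, not a .pt file.
    dcp.save({"model": model.state_dict()}, checkpoint_id="checkpoint_dir")

    # Loading fills the provided state_dict in place from that directory;
    # the tensors are then copied back into the module.
    state = {"model": model.state_dict()}
    dcp.load(state, checkpoint_id="checkpoint_dir")
    model.load_state_dict(state["model"])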