Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/numpy-2.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque authored Dec 17, 2024
2 parents 3c2b761 + 6efc99b commit 3c1b8dc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- if: ${{ matrix.python_version == '3.12' }}
name: codecov-report
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
fail_ci_if_error: false # optional (default = false)
files: ./coverage.xml
Expand Down
10 changes: 10 additions & 0 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,3 +1478,13 @@ def get_weight_table(model: torch.nn.Module) -> tuple[list, list]:
row_grad = [None] * len(row_weight)
data.append([param_name] + [params.shape] + row_weight + row_grad) # noqa
return columns, data


def get_checkpoint_format(config: dict) -> str:
    """Retrieve the checkpoint format from a (possibly old-style) config.

    Temporary shim: old configs may not carry ``optim.checkpoint_format``,
    in which case the legacy single-file PyTorch format is assumed.

    Args:
        config: trainer configuration dictionary; the format is read from
            ``config["optim"]["checkpoint_format"]`` when present.

    Returns:
        Either ``"pt"`` (single ``checkpoint.pt`` file, the default) or
        ``"dcp"`` (torch distributed checkpoint directory).

    Raises:
        ValueError: if the configured format is neither ``"pt"`` nor
            ``"dcp"``. (An explicit raise rather than ``assert`` so the
            check survives ``python -O``.)
    """
    # Renamed from `format` to avoid shadowing the builtin.
    fmt = config.get("optim", {}).get("checkpoint_format", "pt")
    if fmt not in ("pt", "dcp"):
        raise ValueError(f"checkpoint format can only be pt or dcp, found {fmt}")
    return fmt
17 changes: 12 additions & 5 deletions src/fairchem/core/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import get_checkpoint_format
from fairchem.core.trainers import OCPTrainer


Expand All @@ -21,10 +22,13 @@ def __init__(self, config) -> None:
def setup(self, trainer) -> None:
self.trainer = trainer

# TODO: make checkpoint.pt a constant so we don't pass this string around everywhere
self.chkpt_path = os.path.join(
self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
)
format = get_checkpoint_format(self.config)
if format == "pt":
self.chkpt_path = os.path.join(
self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
)
else:
self.chkpt_path = self.trainer.config["cmd"]["checkpoint_dir"]

# if the supplied checkpoint exists, then load that, ie: when user specifies the --checkpoint option
# OR if the a job was preempted correctly and the submitit checkpoint function was called
Expand All @@ -38,7 +42,10 @@ def setup(self, trainer) -> None:
# if the supplied checkpoint doesn't exist and there exists a previous checkpoint in the checkpoint path, this
# means that the previous job didn't terminate "nicely" (due to node failures, crashes etc), then attempt
# to load the last found checkpoint
elif os.path.exists(self.chkpt_path):
elif (
os.path.isfile(self.chkpt_path)
or (os.path.isdir(self.chkpt_path) and len(os.listdir(self.chkpt_path))) > 0
):
logging.info(
f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkpoint"
)
Expand Down
7 changes: 5 additions & 2 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,14 @@ def get_sampler(
seed=self.config["cmd"]["seed"],
)

def get_dataloader(self, dataset, sampler, workers=None) -> DataLoader:
    """Build a ``DataLoader`` over *dataset*, batched by *sampler*.

    Args:
        dataset: dataset to draw samples from.
        sampler: batch sampler yielding lists of indices per batch.
        workers: optional override for the number of worker processes.
            When ``None``, falls back to ``self.config["optim"]["num_workers"]``.
            (``0`` is a valid, meaningful override — hence the explicit
            ``is None`` check rather than truthiness.)

    Returns:
        A ``DataLoader`` that collates batches with ``self.collater`` and
        pins host memory for faster device transfer.
    """
    if workers is None:
        num_workers = self.config["optim"]["num_workers"]
    else:
        num_workers = workers
    return DataLoader(
        dataset,
        collate_fn=self.collater,
        num_workers=num_workers,
        pin_memory=True,
        batch_sampler=sampler,
    )
Expand Down

0 comments on commit 3c1b8dc

Please sign in to comment.