Skip to content

Commit

Permalink
refactor save_checkpoint()
Browse files Browse the repository at this point in the history
potentially breaking by separating different target errors by underscore
  • Loading branch information
janosh committed Oct 30, 2023
1 parent 3673a84 commit b90b44a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
24 changes: 6 additions & 18 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,15 +508,11 @@ def save_checkpoint(
if fname.startswith("epoch"):
os.remove(os.path.join(save_dir, fname))

rounded_mae_e = round(mae_error["e"] * 1000)
rounded_mae_f = round(mae_error["f"] * 1000)
rounded_mae_s = round(mae_error["s"] * 1000) if "s" in mae_error else "NA"
rounded_mae_m = round(mae_error["m"] * 1000) if "m" in mae_error else "NA"
filename = os.path.join(
save_dir,
f"epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}"
f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar",
err_str = "_".join(
f"{key}{f'{mae_error[key] * 1000:.0f}' if key in mae_error else 'NA'}"
for key in "efsm"
)
filename = os.path.join(save_dir, f"epoch{epoch}_{err_str}.pth.tar")
self.save(filename=filename)

# save the model if it has minimal val energy error or val force error
Expand All @@ -527,23 +523,15 @@ def save_checkpoint(
os.remove(os.path.join(save_dir, fname))
shutil.copyfile(
filename,
os.path.join(
save_dir,
f"bestE_epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}"
f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar",
),
os.path.join(save_dir, f"bestE_epoch{epoch}_{err_str}.pth.tar"),
)
if mae_error["f"] == min(self.training_history["f"]["val"]):
for fname in os.listdir(save_dir):
if fname.startswith("bestF"):
os.remove(os.path.join(save_dir, fname))
shutil.copyfile(
filename,
os.path.join(
save_dir,
f"bestF_epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}"
f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar",
),
os.path.join(save_dir, f"bestF_epoch{epoch}_{err_str}.pth.tar"),
)

@classmethod
Expand Down
10 changes: 6 additions & 4 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def test_trainer(tmp_path) -> None:
assert param.requires_grad is False
assert test_dir.is_dir(), "Training dir was not created"

output_files = list(test_dir.iterdir())
for prefix in ("epoch", "bestE", "bestF"):
n_matches = sum(file.name.startswith(prefix) for file in output_files)
assert n_matches == 1
output_files = [file.name for file in test_dir.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
assert (
n_matches == 1
), f"Expected 1 {prefix} file, found {n_matches} in {output_files}"


def test_trainer_composition_model(tmp_path) -> None:
Expand Down

0 comments on commit b90b44a

Please sign in to comment.