From b90b44a6fe28ab15239799606f4cee5ee79eba1e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 30 Oct 2023 08:16:21 -0700 Subject: [PATCH] refactor save_checkpoint() potentially breaking by separating different target errors by underscore --- chgnet/trainer/trainer.py | 24 ++++++------------------ tests/test_trainer.py | 10 ++++++---- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index b1d9bc5a..609673ed 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -508,15 +508,11 @@ def save_checkpoint( if fname.startswith("epoch"): os.remove(os.path.join(save_dir, fname)) - rounded_mae_e = round(mae_error["e"] * 1000) - rounded_mae_f = round(mae_error["f"] * 1000) - rounded_mae_s = round(mae_error["s"] * 1000) if "s" in mae_error else "NA" - rounded_mae_m = round(mae_error["m"] * 1000) if "m" in mae_error else "NA" - filename = os.path.join( - save_dir, - f"epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}" - f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar", + err_str = "_".join( + f"{key}{f'{mae_error[key] * 1000:.0f}' if key in mae_error else 'NA'}" + for key in "efsm" ) + filename = os.path.join(save_dir, f"epoch{epoch}_{err_str}.pth.tar") self.save(filename=filename) # save the model if it has minimal val energy error or val force error @@ -527,11 +523,7 @@ def save_checkpoint( os.remove(os.path.join(save_dir, fname)) shutil.copyfile( filename, - os.path.join( - save_dir, - f"bestE_epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}" - f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar", - ), + os.path.join(save_dir, f"bestE_epoch{epoch}_{err_str}.pth.tar"), ) if mae_error["f"] == min(self.training_history["f"]["val"]): for fname in os.listdir(save_dir): @@ -539,11 +531,7 @@ def save_checkpoint( os.remove(os.path.join(save_dir, fname)) shutil.copyfile( filename, - os.path.join( - save_dir, - f"bestF_epoch{epoch}_e{rounded_mae_e}f{rounded_mae_f}" - f"s{rounded_mae_s}m{rounded_mae_m}.pth.tar", - ), + os.path.join(save_dir, f"bestF_epoch{epoch}_{err_str}.pth.tar"), ) @classmethod diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 6abac1e2..bd93c3c0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -51,10 +51,12 @@ def test_trainer(tmp_path) -> None: assert param.requires_grad is False assert test_dir.is_dir(), "Training dir was not created" - output_files = list(test_dir.iterdir()) - for prefix in ("epoch", "bestE", "bestF"): - n_matches = sum(file.name.startswith(prefix) for file in output_files) - assert n_matches == 1 + output_files = [file.name for file in test_dir.iterdir()] + for prefix in ("epoch", "bestE_", "bestF_"): + n_matches = sum(file.startswith(prefix) for file in output_files) + assert ( + n_matches == 1 + ), f"Expected 1 {prefix} file, found {n_matches} in {output_files}" def test_trainer_composition_model(tmp_path) -> None: