From 1708164a9af5c8bf8b0c96e3be1283e42377296f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 13 May 2024 05:56:25 +0800 Subject: [PATCH] fix: lcurve header wrong when no validation data (#3774) --- deepmd/pt/train/training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6934a9d8fe..655b729f8c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -909,7 +909,7 @@ def log_loss_valid(_task_key="Default"): learning_rate=cur_lr, ) ) - if valid_results is not None and valid_results[_key]: + if valid_results[_key]: log.info( format_training_message_per_task( batch=_step_id, @@ -1118,7 +1118,7 @@ def print_header(self, fout, train_results, valid_results): print_str = "" print_str += "# %5s" % "step" if not self.multi_task: - if valid_results is not None: + if valid_results: prop_fmt = " %11s %11s" for k in train_keys: print_str += prop_fmt % (k + "_val", k + "_trn") @@ -1128,7 +1128,7 @@ def print_header(self, fout, train_results, valid_results): print_str += prop_fmt % (k + "_trn") else: for model_key in self.model_keys: - if valid_results[model_key] is not None: + if valid_results[model_key]: prop_fmt = " %11s %11s" for k in sorted(train_results[model_key].keys()): print_str += prop_fmt % (