From 92ee6326aecec689882d49729f09b791309f9064 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Mar 2024 21:09:26 -0500 Subject: [PATCH] tf: remove freeze warning for optional nodes (#3381) Fix #3334. --------- Signed-off-by: Jinzhe Zeng --- deepmd/tf/entrypoints/freeze.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 228f8466cb..c7ab1023fa 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -359,13 +359,21 @@ def freeze_graph( output_node = _make_node_names( freeze_type, modifier, out_suffix=out_suffix, node_names=node_names ) + # see #3334 + optional_node = [ + "train_attr/min_nbor_dist", + "fitting_attr/aparam_nall", + "spin_attr/ntypes_spin", + ] different_set = set(output_node) - set(input_node) if different_set: - log.warning( - "The following nodes are not in the graph: %s. " - "Skip freezeing these nodes. You may be freezing " - "a checkpoint generated by an old version." % different_set - ) + different_set -= set(optional_node) + if different_set: + log.warning( + "The following nodes are not in the graph: %s. " + "Skip freezeing these nodes. You may be freezing " + "a checkpoint generated by an old version." % different_set + ) # use intersection as output list output_node = list(set(output_node) & set(input_node)) log.info(f"The following nodes will be frozen: {output_node}")