From 50b00090e1ac2777faf84446f5d4e04bcef55188 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 May 2024 09:06:22 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/main.py | 6 ++++-- deepmd/pt/entrypoints/main.py | 3 +-- deepmd/pt/train/training.py | 4 +++- deepmd/pt/utils/finetune.py | 1 - 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index e6f4de6359..0a429e3f9e 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -770,7 +770,9 @@ def main_parser() -> argparse.ArgumentParser: """ ), ) - parser_list_model_branch.add_argument("INPUT", help="The input multi-task pre-trained model file") + parser_list_model_branch.add_argument( + "INPUT", help="The input multi-task pre-trained model file" + ) return parser @@ -829,7 +831,7 @@ def main(): "compress", "convert-from", "train-nvnmd", - "list-model-branch" + "list-model-branch", ): deepmd_main = BACKENDS[args.backend]().entry_point_hook elif args.command is None: diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b47f07f847..1144c09748 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -74,7 +74,6 @@ from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter log = logging.getLogger(__name__) -from IPython import embed def get_trainer( @@ -346,7 +345,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): # Pretrained model must be multitask mode assert finetune_from_multi_task, "When using --list-model-branch, the pretrained model must be multitask model" model_branch = list(model_params["model_dict"].keys()) - log.info(f"Available model branches are {model_branch}") + log.info(f"Available model branches are {model_branch}") else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 2ba7789821..22eba6c700 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -83,7 +83,9 @@ ) log = logging.getLogger(__name__) -from IPython import embed +from IPython import ( + embed, +) class Trainer: diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index b8f089ca6a..2de4214070 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -5,7 +5,6 @@ ) import torch -from IPython import embed from deepmd.pt.utils import ( env,