Skip to content

Commit

Permalink
Merge TF and PT CLI (#3187)
Browse files Browse the repository at this point in the history
Just merge in form. Several options or subcommands are only supported by
TensorFlow or PyTorch.

Also, avoid import from `deepmd.tf` in `deepmd.utils.argcheck`.

```
Use --tf or --pt to choose the backend:
    dp --tf train input.json
    dp --pt train input.json
```

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 27, 2024
1 parent 2631ce2 commit 484bdc3
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 138 deletions.
2 changes: 1 addition & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __new__(cls, model_file: str, *args, **kwargs):

return super().__new__(DeepPotTF)
elif backend == DPBackend.PyTorch:
from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT
from deepmd.pt.infer.deep_eval import DeepPot as DeepPotPT

return super().__new__(DeepPotPT)
else:
Expand Down
152 changes: 129 additions & 23 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import argparse
import logging
import os
import textwrap
from typing import (
List,
Expand Down Expand Up @@ -45,6 +46,21 @@ class RawTextArgumentDefaultsHelpFormatter(
"""This formatter is used to print multile-line help message with default value."""


BACKEND_TABLE = {
"tensorflow": "tensorflow",
"tf": "tensorflow",
"pytorch": "pytorch",
"pt": "pytorch",
}


class BackendOption(argparse.Action):
"""Map backend alias to unique name."""

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, BACKEND_TABLE[values])


def main_parser() -> argparse.ArgumentParser:
"""DeePMD-Kit commandline options argument parser.
Expand All @@ -56,8 +72,51 @@ def main_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="DeePMD-kit: A deep learning package for many-body potential energy"
" representation and molecular dynamics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Use --tf or --pt to choose the backend:
dp --tf train input.json
dp --pt train input.json
"""
),
)

# default backend is TF for compatibility
default_backend = os.environ.get("DP_BACKEND", "tensorflow").lower()
if default_backend not in BACKEND_TABLE.keys():
raise ValueError(
f"Unknown backend {default_backend}. "
"Please set DP_BACKEND to either tensorflow or pytorch."
)

parser_backend = parser.add_mutually_exclusive_group()
parser_backend.add_argument(
"-b",
"--backend",
choices=list(BACKEND_TABLE.keys()),
action=BackendOption,
default=default_backend,
help=(
"The backend of the model. Default can be set by environment variable "
"DP_BACKEND."
),
)
parser_backend.add_argument(
"--tf",
action="store_const",
dest="backend",
const="tensorflow",
help="Alias for --backend tensorflow",
)
parser_backend.add_argument(
"--pt",
action="store_const",
dest="backend",
const="pytorch",
help="Alias for --backend pytorch",
)

subparsers = parser.add_subparsers(title="Valid subcommands", dest="command")

# * logging options parser *********************************************************
Expand Down Expand Up @@ -98,7 +157,9 @@ def main_parser() -> argparse.ArgumentParser:

# * transfer script ****************************************************************
parser_transfer = subparsers.add_parser(
"transfer", parents=[parser_log], help="pass parameters to another model"
"transfer",
parents=[parser_log],
help="(Supported backend: TensorFlow) pass parameters to another model",
)
parser_transfer.add_argument(
"-r",
Expand Down Expand Up @@ -160,7 +221,7 @@ def main_parser() -> argparse.ArgumentParser:
"--init-frz-model",
type=str,
default=None,
help="Initialize the training from the frozen model.",
help="(Supported backend: TensorFlow) Initialize the training from the frozen model.",
)
parser_train_subgroup.add_argument(
"-t",
Expand All @@ -174,12 +235,24 @@ def main_parser() -> argparse.ArgumentParser:
"--output",
type=str,
default="out.json",
help="The output file of the parameters used in training.",
help="(Supported backend: TensorFlow) The output file of the parameters used in training.",
)
parser_train.add_argument(
"--skip-neighbor-stat",
action="store_true",
help="Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
help="(Supported backend: TensorFlow) Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
)
parser_train.add_argument(
# -m has been used by mpi-log
"--model-branch",
type=str,
default="",
help="(Supported backend: PyTorch) Model branch chosen for fine-tuning if multi-task. If not specified, it will re-init the fitting net.",
)
parser_train.add_argument(
"--force-load",
action="store_true",
help="(Supported backend: PyTorch) Force load from ckpt, other missing tensors will init from scratch",
)

# * freeze script ******************************************************************
Expand All @@ -199,36 +272,43 @@ def main_parser() -> argparse.ArgumentParser:
parser_frz.add_argument(
"-c",
"--checkpoint-folder",
"--checkpoint",
type=str,
default=".",
help="path to checkpoint folder",
help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing model.pt, or a pt file",
)
parser_frz.add_argument(
"-o",
"--output",
type=str,
default="frozen_model.pb",
help="name of graph, will output to the checkpoint folder",
default="frozen_model",
help="Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_frz.add_argument(
"-n",
"--node-names",
type=str,
default=None,
help="the frozen nodes, if not set, determined from the model type",
help="(Supported backend: TensorFlow) the frozen nodes, if not set, determined from the model type",
)
parser_frz.add_argument(
"-w",
"--nvnmd-weight",
type=str,
default=None,
help="the name of weight file (.npy), if set, save the model's weight into the file",
help="(Supported backend: TensorFlow) the name of weight file (.npy), if set, save the model's weight into the file",
)
parser_frz.add_argument(
"--united-model",
action="store_true",
default=False,
help="When in multi-task mode, freeze all nodes into one united model",
help="(Supported backend: TensorFlow) When in multi-task mode, freeze all nodes into one united model",
)
parser_frz.add_argument(
"--head",
default=None,
type=str,
help="(Supported backend: PyTorch) Task head to freeze if in multi-task mode.",
)

# * test script ********************************************************************
Expand All @@ -247,9 +327,9 @@ def main_parser() -> argparse.ArgumentParser:
parser_tst.add_argument(
"-m",
"--model",
default="frozen_model.pb",
default="frozen_model",
type=str,
help="Frozen model file to import",
help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt",
)
parser_tst_subgroup = parser_tst.add_mutually_exclusive_group()
parser_tst_subgroup.add_argument(
Expand All @@ -267,7 +347,11 @@ def main_parser() -> argparse.ArgumentParser:
help="The path to file of test list.",
)
parser_tst.add_argument(
"-S", "--set-prefix", default="set", type=str, help="The set prefix"
"-S",
"--set-prefix",
default="set",
type=str,
help="(Supported backend: TensorFlow) The set prefix",
)
parser_tst.add_argument(
"-n",
Expand All @@ -277,7 +361,11 @@ def main_parser() -> argparse.ArgumentParser:
help="The number of data for test. 0 means all data.",
)
parser_tst.add_argument(
"-r", "--rand-seed", type=int, default=None, help="The random seed"
"-r",
"--rand-seed",
type=int,
default=None,
help="(Supported backend: TensorFlow) The random seed",
)
parser_tst.add_argument(
"--shuffle-test", action="store_true", default=False, help="Shuffle test data"
Expand All @@ -294,7 +382,19 @@ def main_parser() -> argparse.ArgumentParser:
"--atomic",
action="store_true",
default=False,
help="Test the accuracy of atomic label, i.e. energy / tensor (dipole, polar)",
help="(Supported backend: TensorFlow) Test the accuracy of atomic label, i.e. energy / tensor (dipole, polar)",
)
parser_tst.add_argument(
"-i",
"--input_script",
type=str,
help="(Supported backend: PyTorch) The input script of the model",
)
parser_tst.add_argument(
"--head",
default=None,
type=str,
help="(Supported backend: PyTorch) Task head to test if in multi-task mode.",
)

# * compress model *****************************************************************
Expand All @@ -308,7 +408,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_compress = subparsers.add_parser(
"compress",
parents=[parser_log, parser_mpi_log],
help="compress a model",
help="(Supported backend: TensorFlow) compress a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Expand Down Expand Up @@ -409,10 +509,10 @@ def main_parser() -> argparse.ArgumentParser:
parser_model_devi.add_argument(
"-m",
"--models",
default=["graph.000.pb", "graph.001.pb", "graph.002.pb", "graph.003.pb"],
default=["graph.000", "graph.001", "graph.002", "graph.003"],
nargs="+",
type=str,
help="Frozen models file to import",
help="Frozen models file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt.",
)
parser_model_devi.add_argument(
"-s",
Expand Down Expand Up @@ -465,7 +565,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_transform = subparsers.add_parser(
"convert-from",
parents=[parser_log],
help="convert lower model version to supported version",
help="(Supported backend: TensorFlow) convert lower model version to supported version",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Expand Down Expand Up @@ -503,7 +603,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_neighbor_stat = subparsers.add_parser(
"neighbor-stat",
parents=[parser_log],
help="Calculate neighbor statistics",
help="(Supported backend: TensorFlow) Calculate neighbor statistics",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Expand Down Expand Up @@ -550,7 +650,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_train_nvnmd = subparsers.add_parser(
"train-nvnmd",
parents=[parser_log],
help="train nvnmd model",
help="(Supported backend: TensorFlow) train nvnmd model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Expand Down Expand Up @@ -651,6 +751,12 @@ def main():
if no command was input
"""
args = parse_args()
from deepmd.tf.entrypoints.main import main as deepmd_main

if args.backend == "tensorflow":
from deepmd.tf.entrypoints.main import main as deepmd_main
elif args.backend == "pytorch":
from deepmd.pt.entrypoints.main import main as deepmd_main
else:
raise ValueError(f"Unknown backend {args.backend}")

deepmd_main(args)
Loading

0 comments on commit 484bdc3

Please sign in to comment.