From 6d973ef593d0a7710d4f8fc9673ae00c3d9b021b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 24 Oct 2023 22:24:43 -0400 Subject: [PATCH] argcheck: restrict the type of elements in a list (#2945) Signed-off-by: Jinzhe Zeng --- deepmd/utils/argcheck.py | 136 ++++++++++++++++++++++++++------------- pyproject.toml | 2 +- 2 files changed, 94 insertions(+), 44 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ae446ef348..7104eb1de4 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -56,7 +56,7 @@ def type_embedding_args(): doc_trainable = "If the parameters in the embedding net are trainable" return [ - Argument("neuron", list, optional=True, default=[8], doc=doc_neuron), + Argument("neuron", List[int], optional=True, default=[8], doc=doc_neuron), Argument( "activation_function", str, @@ -77,9 +77,9 @@ def spin_args(): doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" return [ - Argument("use_spin", list, doc=doc_use_spin), - Argument("spin_norm", list, doc=doc_spin_norm), - Argument("virtual_len", list, doc=doc_virtual_len), + Argument("use_spin", List[bool], doc=doc_use_spin), + Argument("spin_norm", List[float], doc=doc_spin_norm), + Argument("virtual_len", List[float], doc=doc_virtual_len), ] @@ -159,10 +159,10 @@ def descrpt_local_frame_args(): - axis_rule[i*6+5]: index of the axis atom defining the second axis. Note that the neighbors with the same class and type are sorted according to their relative distance." return [ - Argument("sel_a", list, optional=False, doc=doc_sel_a), - Argument("sel_r", list, optional=False, doc=doc_sel_r), + Argument("sel_a", List[int], optional=False, doc=doc_sel_a), + Argument("sel_r", List[int], optional=False, doc=doc_sel_r), Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), - Argument("axis_rule", list, optional=False, doc=doc_axis_rule), + Argument("axis_rule", List[int], optional=False, doc=doc_axis_rule), ] @@ -185,10 +185,12 @@ def descrpt_se_a_args(): doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" return [ - Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel), + Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), - Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron), + Argument( + "neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), Argument( "axis_neuron", int, @@ -212,7 +214,11 @@ def descrpt_se_a_args(): Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), Argument("seed", [int, None], optional=True, doc=doc_seed), Argument( - "exclude_types", list, optional=True, default=[], doc=doc_exclude_types + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, ), Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero @@ -236,10 +242,12 @@ def descrpt_se_t_args(): doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" return [ - Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel), + Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), - Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron), + Argument( + "neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), Argument( "activation_function", str, @@ -289,10 +297,12 @@ def descrpt_se_r_args(): doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" return [ - Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel), + Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), - Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron), + Argument( + "neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), Argument( "activation_function", str, @@ -308,7 +318,11 @@ def descrpt_se_r_args(): Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), Argument("seed", [int, None], optional=True, doc=doc_seed), Argument( - "exclude_types", list, optional=True, default=[], doc=doc_exclude_types + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, ), Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero @@ -356,10 +370,14 @@ def descrpt_se_atten_common_args(): doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix" return [ - Argument("sel", [int, list, str], optional=True, default="auto", doc=doc_sel), + Argument( + "sel", [int, List[int], str], optional=True, default="auto", doc=doc_sel + ), Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), - Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron), + Argument( + "neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), Argument( "axis_neuron", int, @@ -383,7 +401,11 @@ def descrpt_se_atten_common_args(): Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), Argument("seed", [int, None], optional=True, doc=doc_seed), Argument( - "exclude_types", list, optional=True, default=[], doc=doc_exclude_types + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, ), Argument("attn", int, optional=True, default=128, doc=doc_attn), Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer), @@ -454,8 +476,10 @@ def descrpt_se_a_mask_args(): doc_seed = "Random seed for parameter initialization" return [ - Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel), - Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron), + Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), + Argument( + "neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), Argument( "axis_neuron", int, @@ -476,7 +500,11 @@ def descrpt_se_a_mask_args(): "type_one_side", bool, optional=True, default=False, doc=doc_type_one_side ), Argument( - "exclude_types", list, optional=True, default=[], doc=doc_exclude_types + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, ), Argument("precision", str, optional=True, default="default", doc=doc_precision), Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), @@ -525,7 +553,7 @@ def fitting_ener(): doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\ - bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ -- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of tihs list should be equal to len(`neuron`)+1." +- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1." doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." doc_seed = "Random seed for parameter initialization of the fitting net" doc_atom_ener = "Specify the atomic energy in vacuum for each type" @@ -547,7 +575,7 @@ def fitting_ener(): Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), Argument( "neuron", - list, + List[int], optional=True, default=[120, 120, 120], alias=["n_neuron"], @@ -563,14 +591,24 @@ def fitting_ener(): Argument("precision", str, optional=True, default="default", doc=doc_precision), Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), Argument( - "trainable", [list, bool], optional=True, default=True, doc=doc_trainable + "trainable", + [List[bool], bool], + optional=True, + default=True, + doc=doc_trainable, ), Argument( "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond ), Argument("seed", [int, None], optional=True, doc=doc_seed), - Argument("atom_ener", list, optional=True, default=[], doc=doc_atom_ener), - Argument("layer_name", list, optional=True, doc=doc_layer_name), + Argument( + "atom_ener", + List[Optional[float]], + optional=True, + default=[], + doc=doc_atom_ener, + ), + Argument("layer_name", List[str], optional=True, doc=doc_layer_name), Argument( "use_aparam_as_mask", bool, @@ -602,7 +640,7 @@ def fitting_dos(): Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), Argument( - "neuron", list, optional=True, default=[120, 120, 120], doc=doc_neuron + "neuron", List[int], optional=True, default=[120, 120, 120], doc=doc_neuron ), Argument( "activation_function", @@ -614,7 +652,11 @@ def fitting_dos(): Argument("precision", str, optional=True, default="float64", doc=doc_precision), Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), Argument( - "trainable", [list, bool], optional=True, default=True, doc=doc_trainable + "trainable", + [List[bool], bool], + optional=True, + default=True, + doc=doc_trainable, ), Argument( "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond @@ -642,7 +684,7 @@ def fitting_polar(): return [ Argument( "neuron", - list, + List[int], optional=True, default=[120, 120, 120], alias=["n_neuron"], @@ -658,12 +700,14 @@ def fitting_polar(): Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), Argument("precision", str, optional=True, default="default", doc=doc_precision), Argument("fit_diag", bool, optional=True, default=True, doc=doc_fit_diag), - Argument("scale", [list, float], optional=True, default=1.0, doc=doc_scale), + Argument( + "scale", [List[float], float], optional=True, default=1.0, doc=doc_scale + ), # Argument("diag_shift", [list,float], optional = True, default = 0.0, doc = doc_diag_shift), Argument("shift_diag", bool, optional=True, default=True, doc=doc_shift_diag), Argument( "sel_type", - [list, int, None], + [List[int], int, None], optional=True, alias=["pol_type"], doc=doc_sel_type, @@ -687,7 +731,7 @@ def fitting_dipole(): return [ Argument( "neuron", - list, + List[int], optional=True, default=[120, 120, 120], alias=["n_neuron"], @@ -704,7 +748,7 @@ def fitting_dipole(): Argument("precision", str, optional=True, default="default", doc=doc_precision), Argument( "sel_type", - [list, int, None], + [List[int], int, None], optional=True, alias=["dipole_type"], doc=doc_sel_type, @@ -740,8 +784,10 @@ def modifier_dipole_charge(): return [ Argument("model_name", str, optional=False, doc=doc_model_name), - Argument("model_charge_map", list, optional=False, doc=doc_model_charge_map), - Argument("sys_charge_map", list, optional=False, doc=doc_sys_charge_map), + Argument( + "model_charge_map", List[float], optional=False, doc=doc_model_charge_map + ), + Argument("sys_charge_map", List[float], optional=False, doc=doc_sys_charge_map), Argument("ewald_beta", float, optional=True, default=0.4, doc=doc_ewald_beta), Argument("ewald_h", float, optional=True, default=1.0, doc=doc_ewald_h), ] @@ -770,7 +816,7 @@ def model_compression(): return [ Argument("model_file", str, optional=False, doc=doc_model_file), - Argument("table_config", list, optional=False, doc=doc_table_config), + Argument("table_config", List[float], optional=False, doc=doc_table_config), Argument("min_nbor_dist", float, optional=False, doc=doc_min_nbor_dist), ] @@ -814,7 +860,7 @@ def model_args(exclude_hybrid=False): "model", dict, [ - Argument("type_map", list, optional=True, doc=doc_type_map), + Argument("type_map", List[str], optional=True, doc=doc_type_map), Argument( "data_stat_nbatch", int, @@ -1456,11 +1502,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data ) args = [ - Argument("systems", [list, str], optional=False, default=".", doc=doc_systems), + Argument( + "systems", [List[str], str], optional=False, default=".", doc=doc_systems + ), Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix), Argument( "batch_size", - [list, int, str], + [List[int], int, str], optional=True, default="auto", doc=doc_batch_size, @@ -1477,7 +1525,7 @@ def training_data_args(): # ! added by Ziyao: new specification style for data ), Argument( "sys_probs", - list, + List[float], optional=True, default=None, doc=doc_sys_probs, @@ -1521,11 +1569,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period." args = [ - Argument("systems", [list, str], optional=False, default=".", doc=doc_systems), + Argument( + "systems", [List[str], str], optional=False, default=".", doc=doc_systems + ), Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix), Argument( "batch_size", - [list, int, str], + [List[int], int, str], optional=True, default="auto", doc=doc_batch_size, @@ -1542,7 +1592,7 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat ), Argument( "sys_probs", - list, + List[float], optional=True, default=None, doc=doc_sys_probs, diff --git a/pyproject.toml b/pyproject.toml index 8c5267567b..35a11d2163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ 'numpy', 'scipy', 'pyyaml', - 'dargs >= 0.3.5', + 'dargs >= 0.4.1', 'python-hostlist >= 1.21', 'typing_extensions; python_version < "3.8"', 'importlib_metadata>=1.4; python_version < "3.8"',