Skip to content

Commit

Permalink
pt: apply argcheck to pt (#3342)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 29, 2024
1 parent 581dea3 commit 17bd1ec
Show file tree
Hide file tree
Showing 16 changed files with 506 additions and 77 deletions.
11 changes: 11 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.path import (
DPPath,
)
Expand All @@ -67,6 +73,11 @@ def get_trainer(
force_load=False,
init_frz_model=None,
):
# argcheck
if "model_dict" not in config.get("model", {}):
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

# Initialize DDP
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
Expand Down
26 changes: 24 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,38 @@ def __init__(
post_ln=True,
ffn=False,
ffn_embed_dim=1024,
activation="tanh",
activation_function="tanh",
scaling_factor=1.0,
head_num=1,
normalize=True,
temperature=None,
return_rot=False,
concat_output_tebd: bool = True,
type: Optional[str] = None,
# not implemented
resnet_dt: bool = False,
type_one_side: bool = True,
precision: str = "default",
trainable: bool = True,
exclude_types: Optional[List[List[int]]] = None,
stripped_type_embedding: bool = False,
smooth_type_embdding: bool = False,
):
super().__init__()
if resnet_dt:
raise NotImplementedError("resnet_dt is not supported.")
if not type_one_side:
raise NotImplementedError("type_one_side is not supported.")
if precision != "default" and precision != "float64":
raise NotImplementedError("precison is not supported.")
if not trainable:
raise NotImplementedError("trainable == False is not supported.")
if exclude_types is not None and exclude_types != []:
raise NotImplementedError("exclude_types is not supported.")
if stripped_type_embedding:
raise NotImplementedError("stripped_type_embedding is not supported.")
if smooth_type_embdding:
raise NotImplementedError("smooth_type_embdding is not supported.")
del type
self.se_atten = DescrptBlockSeAtten(
rcut,
Expand All @@ -71,7 +93,7 @@ def __init__(
post_ln=post_ln,
ffn=ffn,
ffn_embed_dim=ffn_embed_dim,
activation=activation,
activation_function=activation_function,
scaling_factor=scaling_factor,
head_num=head_num,
normalize=normalize,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(
tebd_input_mode="concat",
# tebd_input_mode='dot_residual_s',
set_davg_zero=repinit_set_davg_zero,
activation=repinit_activation,
activation_function=repinit_activation,
)
self.repformers = DescrptBlockRepformers(
repformer_rcut,
Expand All @@ -223,7 +223,7 @@ def __init__(
attn2_hidden=repformer_attn2_hidden,
attn2_nhead=repformer_attn2_nhead,
attn2_has_gate=repformer_attn2_has_gate,
activation=repformer_activation,
activation_function=repformer_activation,
update_style=repformer_update_style,
set_davg_zero=repformer_set_davg_zero,
smooth=True,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def __init__(
attn2_hidden: int = 16,
attn2_nhead: int = 4,
attn2_has_gate: bool = False,
activation: str = "tanh",
activation_function: str = "tanh",
update_style: str = "res_avg",
set_davg_zero: bool = True, # TODO
smooth: bool = True,
Expand All @@ -332,7 +332,7 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.do_bn_mode = do_bn_mode
self.bn_momentum = bn_momentum
self.act = get_activation_fn(activation)
self.act = get_activation_fn(activation_function)
self.update_g1_has_grrg = update_g1_has_grrg
self.update_g1_has_drrd = update_g1_has_drrd
self.update_g1_has_conv = update_g1_has_conv
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
attn2_hidden: int = 16,
attn2_nhead: int = 4,
attn2_has_gate: bool = False,
activation: str = "tanh",
activation_function: str = "tanh",
update_style: str = "res_avg",
set_davg_zero: bool = True, # TODO
smooth: bool = True,
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.g1_dim = g1_dim
self.g2_dim = g2_dim
self.act = get_activation_fn(activation)
self.act = get_activation_fn(activation_function)
self.direct_dist = direct_dist
self.add_type_ebd_to_seq = add_type_ebd_to_seq

Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
attn2_has_gate=attn2_has_gate,
attn2_hidden=attn2_hidden,
attn2_nhead=attn2_nhead,
activation=activation,
activation_function=activation_function,
update_style=update_style,
smooth=smooth,
)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
post_ln=True,
ffn=False,
ffn_embed_dim=1024,
activation="tanh",
activation_function="tanh",
scaling_factor=1.0,
head_num=1,
normalize=True,
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
self.post_ln = post_ln
self.ffn = ffn
self.ffn_embed_dim = ffn_embed_dim
self.activation = activation
self.activation = activation_function
# TODO: To be fixed: precision should be given from inputs
self.prec = torch.float64
self.scaling_factor = scaling_factor
Expand Down
28 changes: 28 additions & 0 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,41 @@ def __init__(
multi_task: bool = False,
stripped_type_embedding: bool = False,
smooth_type_embdding: bool = False,
# not implemented
post_ln=True,
ffn=False,
ffn_embed_dim=1024,
scaling_factor=1.0,
head_num=1,
normalize=True,
temperature=None,
return_rot=False,
concat_output_tebd: bool = True,
**kwargs,
) -> None:
if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding):
warnings.warn(
"Set 'set_davg_zero' False in descriptor 'se_atten' "
"may cause unexpected incontinuity during model inference!"
)
if not post_ln:
raise NotImplementedError("post_ln is not supported.")
if ffn:
raise NotImplementedError("ffn is not supported.")
if ffn_embed_dim != 1024:
raise NotImplementedError("ffn_embed_dim is not supported.")
if scaling_factor != 1.0:
raise NotImplementedError("scaling_factor is not supported.")
if head_num != 1:
raise NotImplementedError("head_num is not supported.")
if not normalize:
raise NotImplementedError("normalize is not supported.")
if temperature is not None:
raise NotImplementedError("temperature is not supported.")
if return_rot:
raise NotImplementedError("return_rot is not supported.")
if not concat_output_tebd:
raise NotImplementedError("concat_output_tebd is not supported.")
DescrptSeA.__init__(
self,
rcut,
Expand Down
Loading

0 comments on commit 17bd1ec

Please sign in to comment.