Skip to content

Commit

Permalink
check if the model require grad by output def. fix doc str. fix type …
Browse files Browse the repository at this point in the history
…hint
  • Loading branch information
Han Wang committed Jan 22, 2024
1 parent 947308d commit 165267a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 15 deletions.
26 changes: 25 additions & 1 deletion deepmd_pt/model/model/atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
Optional,
Dict,
List,
)
from deepmd_utils.model_format import FittingOutputDef
from deepmd_pt.model.task import Fitting
Expand All @@ -25,7 +26,7 @@ def get_rcut(self)->float:
raise NotImplementedError

@abstractmethod
def get_sel(self)->int:
def get_sel(self)->List[int]:
raise NotImplementedError

@abstractmethod
Expand All @@ -43,4 +44,27 @@ def forward_atomic(
) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def do_grad(
self,
var_name: Optional[str] = None,
)->bool:
"""Tell if the output variable `var_name` is differentiable.
if var_name is None, returns if any of the variable is differentiable.
"""
odef = self.get_fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv))
return any(require)
else:
return self.do_grad_(var_name)

def do_grad_(
self,
var_name: str,
)->bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return self.get_fitting_output_def()[var_name].differentiable
21 changes: 10 additions & 11 deletions deepmd_pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ def __init__(
sampled=None,
**kwargs,
):
"""Based on components, construct a DPA-1 model for energy.
Args:
- model_params: The Dict-like configuration with model options.
- sampled: The sampled dataset for stat.
"""
super().__init__()
# Descriptor + Type Embedding Net (Optional)
ntypes = len(type_map)
Expand Down Expand Up @@ -113,27 +107,32 @@ def __init__(
self.coord_denoise_net = DenoiseNet(self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb)


def get_fitting_net(self):
def get_fitting_net(self)->Fitting:
"""Get the fitting net."""
return (
self.fitting_net
if self.fitting_net is not None
else self.coord_denoise_net
)

def get_fitting_output_def(self)->FittingOutputDef:
"""Get the output def of the fitting net."""
return (
self.fitting_net.output_def()
if self.fitting_net is not None
else self.coord_denoise_net.output_def()
)

def get_rcut(self):
def get_rcut(self)->float:
"""Get the cut-off radius."""
return self.rcut

def get_sel(self):
def get_sel(self)->List[int]:
"""Get the neighbor selection."""
return self.sel

def distinguish_types(self):
def distinguish_types(self)->bool:
"""If distinguish different types by sorting."""
return self.type_split


Expand Down Expand Up @@ -168,7 +167,7 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.grad_force:
if self.do_grad():
extended_coord.requires_grad_(True)
descriptor, env_mat, diff, rot_mat, sw = \
self.descriptor(
Expand Down
4 changes: 2 additions & 2 deletions deepmd_pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.grad_force:
if self.do_grad("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["atomic_virial"] = model_ret["energy_derv_c"].squeeze(-3)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-3)
Expand All @@ -69,7 +69,7 @@ def forward_lower(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.grad_force:
if self.do_grad("energy"):
model_predict['extended_force'] = model_ret['energy_derv_r'].squeeze(-2)
model_predict['extended_virial'] = model_ret['energy_derv_c'].squeeze(-3)
else:
Expand Down
1 change: 1 addition & 0 deletions deepmd_pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_model_output_def(self):
self.get_fitting_output_def()
)

# cannot use the name forward. torch script does not work
def forward_common(
self,
coord,
Expand Down
2 changes: 1 addition & 1 deletion deepmd_pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self):
"""
super(BaseModel, self).__init__()

def forward(self, **kwargs):
def forward(self, *args, **kwargs):
"""Model output.
"""
raise NotImplementedError
Expand Down

0 comments on commit 165267a

Please sign in to comment.