[Paddle Backend] Add spin energy example(revert code format) (#3082)
HydrogenSulfate authored Dec 25, 2023
1 parent 36e0082 commit bb28e11
Showing 15 changed files with 763 additions and 358 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -1,7 +1,11 @@
# DeePMD-kit (PaddlePaddle backend)

> [!IMPORTANT]
> This project is the PaddlePaddle port of DeePMD-kit. Part of the code has been modified so that it runs on PaddlePaddle; the supported functionality covers four parts for the water_se_e2_a example: single-GPU training, single-GPU evaluation, static-graph model export, and LAMMPS (GPU) inference.
> This project is the PaddlePaddle port of DeePMD-kit. Part of the code has been modified so that training, evaluation, model export, and LAMMPS inference run with PaddlePaddle (GPU) as the backend. Example support is listed in the table below.
> | example | Train | Test | Export | LAMMPS |
> | :-----: | :---: | :--: | :----: | :----: |
> | water/se_e2_a | ✅ | ✅ | ✅ | ✅ |
> | spin/se_e2_a | ✅ | ✅ | ✅ | TODO |
## 1. Environment setup

@@ -20,9 +24,8 @@

3. Install deepmd-kit


``` sh
git clone https://github.com/HydrogenSulfate/deepmd-kit.git -b add_ddle_backend_polish_ver
git clone https://github.com/deepmodeling/deepmd-kit.git -b paddle2
cd deepmd-kit
# Install in editable mode for easier debugging
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
53 changes: 33 additions & 20 deletions deepmd/descriptor/se_a.py
@@ -203,9 +203,17 @@ def __init__(
self.useBN = False
self.dstd = None
self.davg = None
self.avg_zero = paddle.zeros([self.ntypes, self.ndescrpt], dtype="float32")
self.std_ones = paddle.ones([self.ntypes, self.ndescrpt], dtype="float32")

# self.compress = False
# self.embedding_net_variables = None
# self.mixed_prec = None
# self.place_holders = {}
# self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a)
self.avg_zero = paddle.zeros(
[self.ntypes, self.ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION
)
self.std_ones = paddle.ones(
[self.ntypes, self.ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION
)
nets = []
for type_input in range(self.ntypes):
layer = []
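The hard-coded `"float32"` dtypes above give way to the package-wide precision constant. A minimal sketch of the pattern, assuming `GLOBAL_PD_FLOAT_PRECISION` resolves to `"float64"` (DeePMD-kit's default double precision); the sizes are toy values:

```python
import paddle

# Assumed value; the real constant is defined package-wide and defaults to double precision.
GLOBAL_PD_FLOAT_PRECISION = "float64"

ntypes, ndescrpt = 2, 552  # toy sizes: 2 atom types, 4 * (46 + 92) descriptor entries
avg_zero = paddle.zeros([ntypes, ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION)
std_ones = paddle.ones([ntypes, ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION)
assert avg_zero.dtype == paddle.float64  # tracks the constant instead of a fixed float32
```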
@@ -242,11 +250,19 @@ def __init__(
}

self.t_rcut = paddle.to_tensor(
np.max([self.rcut_r, self.rcut_a]), dtype="float32"
np.max([self.rcut_r, self.rcut_a]), dtype=GLOBAL_PD_FLOAT_PRECISION
)
self.register_buffer("buffer_sel", paddle.to_tensor(self.sel_a, dtype="int32"))
self.register_buffer(
"buffer_ndescrpt", paddle.to_tensor(self.ndescrpt, dtype="int32")
)
self.register_buffer(
"buffer_original_sel",
paddle.to_tensor(
self.original_sel if self.original_sel is not None else self.sel_a,
dtype="int32",
),
)
self.t_ntypes = paddle.to_tensor(self.ntypes, dtype="int32")
self.t_ndescrpt = paddle.to_tensor(self.ndescrpt, dtype="int32")
self.t_sel = paddle.to_tensor(self.sel_a, dtype="int32")

t_avg = paddle.to_tensor(
np.zeros([self.ntypes, self.ndescrpt]), dtype="float64"
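`register_buffer` folds these constants into the layer's persistable state, so they survive `paddle.jit.save` (the freeze step below skips program pruning for the same reason). A minimal sketch of the mechanism, with toy names and values:

```python
import paddle

class Descrpt(paddle.nn.Layer):
    """Toy layer showing buffer registration; not the project's real class."""

    def __init__(self, sel_a, ndescrpt):
        super().__init__()
        # Buffers are saved alongside parameters but receive no gradients.
        self.register_buffer("buffer_sel", paddle.to_tensor(sel_a, dtype="int32"))
        self.register_buffer("buffer_ndescrpt", paddle.to_tensor(ndescrpt, dtype="int32"))

    def forward(self, x):
        return x

layer = Descrpt(sel_a=[46, 92], ndescrpt=552)
for name, buf in layer.named_buffers():
    print(name, buf.numpy())  # buffer_sel [46 92]; buffer_ndescrpt 552
```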
@@ -539,6 +555,7 @@ def forward(
coord = paddle.reshape(coord_, [-1, natoms[1] * 3])
box = paddle.reshape(box_, [-1, 9])
atype = paddle.reshape(atype_, [-1, natoms[1]])

(
self.descrpt,
self.descrpt_deriv,
@@ -669,7 +686,7 @@ def _pass_filter(
[0, start_index, 0],
[
inputs.shape[0],
start_index + natoms[2 + type_i],
start_index + natoms[2 + type_i].item(),
inputs.shape[2],
],
)
@@ -697,7 +714,7 @@
)
output.append(layer)
output_qmat.append(qmat)
start_index += natoms[2 + type_i]
start_index += natoms[2 + type_i].item()
else:
raise NotImplementedError()
# This branch will not be executed currently
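`natoms` is an int32 tensor here, so `natoms[2 + type_i]` is itself a (0-D) tensor; `.item()` unwraps it to a plain Python int so the slice-bound arithmetic stays in native integers. A minimal sketch with a toy `natoms` vector:

```python
import paddle

# Toy natoms vector: [nloc, nall, n_type_0, n_type_1]
natoms = paddle.to_tensor([192, 192, 64, 128], dtype="int32")

start_index = 0
for type_i in range(2):
    n = natoms[2 + type_i].item()  # 0-D tensor -> plain Python int
    assert isinstance(n, int)
    start_index += n               # native integer arithmetic for slice bounds
print(start_index)  # 192
```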
@@ -747,13 +764,11 @@ def _compute_dstats_sys_smth(
self, data_coord, data_box, data_atype, natoms_vec, mesh
):
input_dict = {}
input_dict["coord"] = paddle.to_tensor(data_coord, dtype="float32")
input_dict["box"] = paddle.to_tensor(data_box, dtype="float32")
input_dict["type"] = paddle.to_tensor(data_atype, dtype="int32")
input_dict["natoms_vec"] = paddle.to_tensor(
natoms_vec, dtype="int32", place="cpu"
)
input_dict["default_mesh"] = paddle.to_tensor(mesh, dtype="int32")
input_dict["coord"] = paddle.to_tensor(data_coord, GLOBAL_PD_FLOAT_PRECISION)
input_dict["box"] = paddle.to_tensor(data_box, GLOBAL_PD_FLOAT_PRECISION)
input_dict["type"] = paddle.to_tensor(data_atype, "int32")
input_dict["natoms_vec"] = paddle.to_tensor(natoms_vec, "int32", place="cpu")
input_dict["default_mesh"] = paddle.to_tensor(mesh, "int32")

self.stat_descrpt, descrpt_deriv, rij, nlist = op_module.prod_env_mat_a(
input_dict["coord"],
@@ -949,10 +964,8 @@ def _filter_lower(
# natom x 4 x outputs_size

return paddle.matmul(
paddle.reshape(
inputs_i, [natom, shape_i[1] // 4, 4]
), # [natom, nei_type_i, 4]
xyz_scatter_out, # [natom, nei_type_i, 100]
paddle.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
xyz_scatter_out,
transpose_x=True,
)
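This matmul contracts the neighbor axis: transposing `[natom, nei, 4]` against `[natom, nei, M]` produces the per-atom `[natom, 4, M]` feature matrix. A minimal shape check with toy sizes:

```python
import paddle

natom, nei, M = 64, 138, 100
inputs_i = paddle.randn([natom, nei, 4])         # per-neighbor environment rows
xyz_scatter_out = paddle.randn([natom, nei, M])  # per-neighbor embedding outputs

# transpose_x turns the batched product into [natom, 4, nei] @ [natom, nei, M].
out = paddle.matmul(inputs_i, xyz_scatter_out, transpose_x=True)
print(out.shape)  # [64, 4, 100]: the neighbor axis is summed away
```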

6 changes: 3 additions & 3 deletions deepmd/entrypoints/freeze.py
@@ -350,7 +350,7 @@ def freeze_graph(
input_spec=[
InputSpec(shape=[None], dtype="float64"), # coord_
InputSpec(shape=[None], dtype="int32"), # atype_
InputSpec(shape=[4], dtype="int32"), # natoms
InputSpec(shape=[2 + dp.model.descrpt.ntypes], dtype="int32"), # natoms
InputSpec(shape=[None], dtype="float64"), # box
InputSpec(shape=[6], dtype="int32"), # mesh
{
@@ -362,9 +362,9 @@
)
for name, param in st_model.named_buffers():
print(
f"[{name}, {param.shape}] generated name in static_model is: {param.name}"
f"[{name}, {param.dtype}, {param.shape}] generated name in static_model is: {param.name}"
)
# skip pruning for program so as to keep buffers into files
skip_prune_program = True
print(f"==>> Set skip_prune_program = {skip_prune_program}")
paddle.jit.save(st_model, output, skip_prune_program=skip_prune_program)
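The `natoms` InputSpec above is generalized from the hard-coded `shape=[4]` (which only fits two atom types) to `2 + ntypes` entries. A minimal sketch of the spec layout, with a toy `ntypes`:

```python
from paddle.static import InputSpec

ntypes = 3  # toy value; the real code reads it from dp.model.descrpt.ntypes

# natoms layout: [nloc, nall, n_type_0, ..., n_type_{ntypes-1}]
input_spec = [
    InputSpec(shape=[None], dtype="float64"),      # coord_
    InputSpec(shape=[None], dtype="int32"),        # atype_
    InputSpec(shape=[2 + ntypes], dtype="int32"),  # natoms
    InputSpec(shape=[None], dtype="float64"),      # box
    InputSpec(shape=[6], dtype="int32"),           # mesh
]
print([spec.shape for spec in input_spec])
```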
49 changes: 27 additions & 22 deletions deepmd/fit/ener.py
@@ -182,6 +182,8 @@ def __init__(
self.atom_ener.append(None)
self.useBN = False
self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64)
ntypes_atom = self.ntypes - self.ntypes_spin
self.bias_atom_e = self.bias_atom_e[:ntypes_atom]
self.register_buffer(
"t_bias_atom_e",
paddle.to_tensor(self.bias_atom_e),
@@ -259,7 +261,6 @@ def __init__(
1,
activation_fn=None,
precision=self.fitting_precision,
bavg=self.bias_atom_e,
name=layer_suffix,
seed=self.seed,
trainable=self.trainable[-1],
@@ -321,6 +322,26 @@ def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None
self.bias_atom_e = self._compute_output_stats(
all_stat, rcond=self.rcond, mixed_type=mixed_type
)
ntypes_atom = self.ntypes - self.ntypes_spin
if self.spin is not None:
for type_i in range(ntypes_atom):
if self.bias_atom_e.shape[0] != self.ntypes:
self.bias_atom_e = np.pad(
self.bias_atom_e,
(0, self.ntypes_spin),
"constant",
constant_values=(0, 0),
)
bias_atom_e = self.bias_atom_e
if self.spin.use_spin[type_i]:
self.bias_atom_e[type_i] = (
self.bias_atom_e[type_i]
+ self.bias_atom_e[type_i + ntypes_atom]
)
else:
self.bias_atom_e[type_i] = self.bias_atom_e[type_i]
self.bias_atom_e = self.bias_atom_e[:ntypes_atom]

paddle.assign(self.bias_atom_e, self.t_bias_atom_e)

def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
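With spin enabled, every spin-carrying atom type gets a companion virtual type, so the bias vector is padded to `ntypes` entries, each virtual type's bias is folded into its real counterpart, and the vector is truncated back to `ntypes_atom`. A worked toy example of the folding (numbers are illustrative):

```python
import numpy as np

ntypes, ntypes_spin = 3, 1          # 2 real atom types plus 1 virtual spin type
ntypes_atom = ntypes - ntypes_spin  # 2
use_spin = [True, False]            # only the first real type carries spin

bias_atom_e = np.array([-93.6, -187.2])  # toy per-type energy biases
if bias_atom_e.shape[0] != ntypes:
    # pad zero-bias slots for the virtual spin types
    bias_atom_e = np.pad(bias_atom_e, (0, ntypes_spin), "constant")
for type_i in range(ntypes_atom):
    if use_spin[type_i]:
        # fold the virtual type's bias into its real counterpart
        bias_atom_e[type_i] += bias_atom_e[type_i + ntypes_atom]
bias_atom_e = bias_atom_e[:ntypes_atom]  # keep only the real atom types
print(bias_atom_e)  # [-93.6, -187.2]: the padded virtual bias is zero here
```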
@@ -525,26 +546,10 @@ def forward(
self.aparam_inv_std = 1.0

ntypes_atom = self.ntypes - self.ntypes_spin
if self.spin is not None:
for type_i in range(ntypes_atom):
if self.bias_atom_e.shape[0] != self.ntypes:
self.bias_atom_e = np.pad(
self.bias_atom_e,
(0, self.ntypes_spin),
"constant",
constant_values=(0, 0),
)
bias_atom_e = self.bias_atom_e
if self.spin.use_spin[type_i]:
self.bias_atom_e[type_i] = (
self.bias_atom_e[type_i]
+ self.bias_atom_e[type_i + ntypes_atom]
)
else:
self.bias_atom_e[type_i] = self.bias_atom_e[type_i]
self.bias_atom_e = self.bias_atom_e[:ntypes_atom]

inputs = paddle.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
inputs = paddle.reshape(
inputs, [-1, natoms[0], self.dim_descrpt]
) # [1, all_atoms, M1*M2]
if len(self.atom_ener):
# only for atom_ener
nframes = input_dict.get("nframes")
@@ -558,7 +563,7 @@
inputs_zero = paddle.zeros_like(inputs, dtype=GLOBAL_PD_FLOAT_PRECISION)

if bias_atom_e is not None:
assert len(bias_atom_e) == self.ntypes
assert len(bias_atom_e) == self.ntypes - self.ntypes_spin

fparam = None
if self.numb_fparam > 0:
@@ -590,7 +595,7 @@
atype_nall,
[0, 1],
[0, 0],
[-1, paddle.sum(natoms[2 : 2 + ntypes_atom]).item()],
[atype_nall.shape[0], paddle.sum(natoms[2 : 2 + ntypes_atom]).item()],
)
atype_filter = paddle.cast(self.atype_nloc >= 0, GLOBAL_PD_FLOAT_PRECISION)
self.atype_nloc = paddle.reshape(self.atype_nloc, [-1])
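`paddle.slice` takes explicit `axes`, `starts`, and `ends`; replacing the `-1` sentinel with the actual first-dimension size keeps the sliced shape concrete for static-graph export. A minimal sketch with toy shapes:

```python
import paddle

# Toy atom-type map: 1 frame, 320 atoms (192 local + 128 ghost)
atype_nall = paddle.randint(0, 3, shape=[1, 320], dtype="int32")
nloc = 192  # paddle.sum(natoms[2 : 2 + ntypes_atom]).item() in the real code

atype_nloc = paddle.slice(
    atype_nall,
    [0, 1],                       # axes
    [0, 0],                       # starts
    [atype_nall.shape[0], nloc],  # explicit ends instead of the -1 sentinel
)
print(atype_nloc.shape)  # [1, 192]
```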