Skip to content

Commit

Permalink
[Paddle Backend] Support GPU train/test/inference/LAMMPS with PaddleP…
Browse files Browse the repository at this point in the history
…addle backend for water se_e2_a(revert code format) (#3078)

Summary can be previewed in README.md

Co-authored-by: zhouwei25 <[email protected]>
Co-authored-by: JiabinYang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Zhanlue Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
6 people authored Dec 21, 2023
1 parent 6170356 commit 36e0082
Show file tree
Hide file tree
Showing 66 changed files with 6,966 additions and 1,998 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ repos:
- id: clang-format
exclude: ^source/3rdparty|source/lib/src/cuda/cudart/.+\.inc
# CSS
- repo: https://github.com/pre-commit/mirrors-csslint
rev: v1.0.5
hooks:
- id: csslint
# - repo: https://github.com/pre-commit/mirrors-csslint
# rev: v1.0.5
# hooks:
# - id: csslint
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.6.0-2
Expand Down
343 changes: 269 additions & 74 deletions README.md

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_PD_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
op_module,
paddle,
tf,
)
from deepmd.utils.path import (
Expand All @@ -50,11 +52,11 @@

# define constants
PRECISION_DICT = {
"default": GLOBAL_TF_FLOAT_PRECISION,
"float16": tf.float16,
"float32": tf.float32,
"float64": tf.float64,
"bfloat16": tf.bfloat16,
"default": GLOBAL_PD_FLOAT_PRECISION,
"float16": paddle.float16,
"float32": paddle.float32,
"float64": paddle.float64,
"bfloat16": paddle.bfloat16,
}


Expand Down Expand Up @@ -119,11 +121,11 @@ def gelu_wrapper(x):
data_requirement = {}

ACTIVATION_FN_DICT = {
"relu": tf.nn.relu,
"relu6": tf.nn.relu6,
"softplus": tf.nn.softplus,
"sigmoid": tf.sigmoid,
"tanh": tf.nn.tanh,
"relu": paddle.nn.functional.relu,
"relu6": paddle.nn.functional.relu6,
"softplus": paddle.nn.functional.softplus,
"sigmoid": paddle.nn.functional.sigmoid,
"tanh": paddle.nn.functional.tanh,
"gelu": gelu,
"gelu_tf": gelu_tf,
"None": None,
Expand Down
Loading

0 comments on commit 36e0082

Please sign in to comment.