add activation_function and resnet arguments and NumPy implementation to NativeLayer (#3109)

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Jan 5, 2024
1 parent 61ee4f2 commit db22812
Showing 2 changed files with 129 additions and 17 deletions.
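
For orientation before reading the diff, here is a minimal usage sketch of the layer API this commit introduces. The weights, bias, and input below are made-up illustration values; only the argument names and the call() behaviour come from the change itself.

import numpy as np
from deepmd_utils.model_format import NativeLayer

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4))  # square weights, so the resnet skip connection applies
b = np.zeros(4)

layer = NativeLayer(w=w, b=b, activation_function="tanh", resnet=True)
x = rng.normal(size=(2, 4))  # a batch of two inputs
y = layer.call(x)            # tanh(x @ w + b) + x, because resnet is set and w is square
print(y.shape)               # (2, 4)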
94 changes: 90 additions & 4 deletions deepmd_utils/model_format.py
@@ -4,6 +4,9 @@
See issue #2982 for more information.
"""
import json
from abc import (
ABC,
)
from typing import (
List,
Optional,
@@ -121,7 +124,15 @@ def load_dp_model(filename: str) -> dict:
return model_dict


class NativeLayer:
class NativeOP(ABC):
"""The unit operation of a native model."""

def call(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
raise NotImplementedError


class NativeLayer(NativeOP):
"""Native representation of a layer.
Parameters
@@ -132,17 +143,25 @@ class NativeLayer:
The biases of the layer.
idt : np.ndarray, optional
The identity matrix of the layer.
activation_function : str, optional
The activation function of the layer.
resnet : bool, optional
Whether the layer is a residual layer.
"""

def __init__(
self,
w: Optional[np.ndarray] = None,
b: Optional[np.ndarray] = None,
idt: Optional[np.ndarray] = None,
activation_function: Optional[str] = None,
resnet: bool = False,
) -> None:
self.w = w
self.b = b
self.idt = idt
self.activation_function = activation_function
self.resnet = resnet

def serialize(self) -> dict:
"""Serialize the layer to a dict.
@@ -158,7 +177,11 @@ def serialize(self) -> dict:
}
if self.idt is not None:
data["idt"] = self.idt
return data
return {
"activation_function": self.activation_function,
"resnet": self.resnet,
"@variables": data,
}

@classmethod
def deserialize(cls, data: dict) -> "NativeLayer":
@@ -169,7 +192,13 @@ def deserialize(cls, data: dict) -> "NativeLayer":
data : dict
The dict to deserialize from.
"""
return cls(data["w"], data["b"], data.get("idt", None))
return cls(
w=data["@variables"]["w"],
b=data["@variables"]["b"],
idt=data["@variables"].get("idt", None),
activation_function=data["activation_function"],
resnet=data.get("resnet", False),
)

def __setitem__(self, key, value):
if key in ("w", "matrix"):
@@ -178,6 +207,10 @@ def __setitem__(self, key, value):
self.b = value
elif key == "idt":
self.idt = value
elif key == "activation_function":
self.activation_function = value
elif key == "resnet":
self.resnet = value
else:
raise KeyError(key)

@@ -188,11 +221,47 @@ def __getitem__(self, key):
return self.b
elif key == "idt":
return self.idt
elif key == "activation_function":
return self.activation_function
elif key == "resnet":
return self.resnet
else:
raise KeyError(key)

def call(self, x: np.ndarray) -> np.ndarray:
"""Forward pass.
Parameters
----------
x : np.ndarray
The input.
Returns
-------
np.ndarray
The output.
"""
if self.w is None or self.b is None or self.activation_function is None:
raise ValueError("w, b, and activation_function must be set")
if self.activation_function == "tanh":
fn = np.tanh
elif self.activation_function.lower() == "none":

def fn(x):
return x
else:
raise NotImplementedError(self.activation_function)
y = fn(np.matmul(x, self.w) + self.b)
if self.idt is not None:
y *= self.idt
if self.resnet and self.w.shape[1] == self.w.shape[0]:
y += x
elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
y += np.concatenate([x, x], axis=1)
return y


class NativeNet:
class NativeNet(NativeOP):
"""Native representation of a neural network.
Parameters
@@ -238,3 +307,20 @@ def __setitem__(self, key, value):
if len(self.layers) <= key:
self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1))
self.layers[key] = value

def call(self, x: np.ndarray) -> np.ndarray:
"""Forward pass.
Parameters
----------
x : np.ndarray
The input.
Returns
-------
np.ndarray
The output.
"""
for layer in self.layers:
x = layer.call(x)
return x
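
Before the tests, a sketch of how the new serialized layout fits together: per-layer arrays now live under "@variables" next to activation_function and resnet, and NativeNet.call chains the layers in NumPy. The identity weights and the round-trip check below are illustrative, not taken from the commit.

import numpy as np
from deepmd_utils.model_format import NativeNet

w, b = np.eye(3), np.zeros(3)
net = NativeNet.deserialize(
    {
        "layers": [
            {
                "activation_function": "tanh",
                "resnet": True,
                "@variables": {"w": w, "b": b},
            },
            {
                "activation_function": "none",
                "resnet": False,
                "@variables": {"w": w, "b": b},
            },
        ]
    }
)
x = np.ones((1, 3))
y = net.call(x)  # layer 1: tanh(x) + x (resnet, square w); layer 2: identity activation
reloaded = NativeNet.deserialize(net.serialize())
np.testing.assert_allclose(reloaded.call(x), y)  # serialization round trip preserves the output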
52 changes: 39 additions & 13 deletions source/tests/test_model_format_utils.py
@@ -25,25 +25,45 @@ def test_serialize(self):
network[1]["b"] = self.b
network[0]["w"] = self.w
network[0]["b"] = self.b
network[1]["activation_function"] = "tanh"
network[0]["activation_function"] = "tanh"
network[1]["resnet"] = True
network[0]["resnet"] = True
jdata = network.serialize()
np.testing.assert_array_equal(jdata["layers"][0]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][0]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][1]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][1]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][0]["activation_function"], "tanh")
np.testing.assert_array_equal(jdata["layers"][1]["activation_function"], "tanh")
np.testing.assert_array_equal(jdata["layers"][0]["resnet"], True)
np.testing.assert_array_equal(jdata["layers"][1]["resnet"], True)

def test_deserialize(self):
network = NativeNet.deserialize(
{
"layers": [
{"w": self.w, "b": self.b},
{"w": self.w, "b": self.b},
]
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
},
],
}
)
np.testing.assert_array_equal(network[0]["w"], self.w)
np.testing.assert_array_equal(network[0]["b"], self.b)
np.testing.assert_array_equal(network[1]["w"], self.w)
np.testing.assert_array_equal(network[1]["b"], self.b)
np.testing.assert_array_equal(network[0]["activation_function"], "tanh")
np.testing.assert_array_equal(network[1]["activation_function"], "tanh")
np.testing.assert_array_equal(network[0]["resnet"], True)
np.testing.assert_array_equal(network[1]["resnet"], True)


class TestDPModel(unittest.TestCase):
@@ -52,12 +72,18 @@ def setUp(self) -> None:
self.b = np.full((3,), 4.0)
self.model_dict = {
"type": "some_type",
"@variables": {
"layers": [
{"w": self.w, "b": self.b},
{"w": self.w, "b": self.b},
]
},
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
},
],
}
self.filename = "test_dp_model_format.dp"

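One detail of the new call() that the tests above do not exercise is the width-doubling resnet branch in model_format.py: when w has shape (n, 2n), the skip connection concatenates the input with itself. A small illustrative sketch, assuming zero weights so the skip term is easy to see:

import numpy as np
from deepmd_utils.model_format import NativeLayer

n = 2
w = np.zeros((n, 2 * n))  # output width is twice the input width
b = np.zeros(2 * n)
layer = NativeLayer(w=w, b=b, activation_function="none", resnet=True)

x = np.array([[1.0, 2.0]])
y = layer.call(x)  # x @ w + b is all zeros, then y += np.concatenate([x, x], axis=1)
print(y)           # [[1. 2. 1. 2.]]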
