Support saving and loading of models #69

Merged
4 changes: 2 additions & 2 deletions .github/workflows/python-test.yml
@@ -44,7 +44,7 @@ jobs:
runs-on: ubuntu-latest

container:
- image: tensorflow/tensorflow:2.16.1
+ image: tensorflow/tensorflow:2.17.0
env:
KERAS_BACKEND: tensorflow
volumes:
@@ -56,7 +56,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install flake8 pytest
+ pip install flake8 pytest networkx
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
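The workflow bump pins the TensorFlow 2.17.0 image and adds `networkx` to the test dependencies (the wiring utilities appear to rely on it). Since the container exports `KERAS_BACKEND: tensorflow`, a minimal sanity test along these lines could confirm the setup; the test name and placement are hypothetical:

```python
import keras


def test_backend_is_tensorflow():
    # The CI container sets KERAS_BACKEND=tensorflow (see the workflow above).
    assert keras.backend.backend() == "tensorflow"
```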
68 changes: 50 additions & 18 deletions ncps/keras/cfc.py
@@ -23,21 +23,23 @@
@keras.utils.register_keras_serializable(package="ncps", name="CfC")
class CfC(keras.layers.RNN):
def __init__(
- self,
- units: Union[int, ncps.wirings.Wiring],
- mixed_memory: bool = False,
- mode: str = "default",
- activation: str = "lecun_tanh",
- backbone_units: int = None,
- backbone_layers: int = None,
- backbone_dropout: float = None,
- return_sequences: bool = False,
- return_state: bool = False,
- go_backwards: bool = False,
- stateful: bool = False,
- unroll: bool = False,
- time_major: bool = False,
- **kwargs,
+ self,
+ units: Union[int, ncps.wirings.Wiring],
+ mixed_memory: bool = False,
+ mode: str = "default",
+ activation: str = "lecun_tanh",
+ backbone_units: int = None,
+ backbone_layers: int = None,
+ backbone_dropout: float = None,
+ sparsity_mask: keras.layers.Layer = None,
+ fully_recurrent: bool = True,
+ return_sequences: bool = False,
+ return_state: bool = False,
+ go_backwards: bool = False,
+ stateful: bool = False,
+ unroll: bool = False,
+ zero_output_for_mask: bool = False,
+ **kwargs,
):
"""Applies a `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ RNN to an input sequence.

@@ -56,12 +58,18 @@ def __init__(
:param backbone_units: Number of hidden units in the backbone layer (default 128)
:param backbone_layers: Number of backbone layers (default 1)
:param backbone_dropout: Dropout rate in the backbone layers (default 0)
+ :param sparsity_mask:
+ :param fully_recurrent: Whether to apply a fully-connected sparsity_mask or use the adjacency_matrix. Evaluated only for WiredCfCCell. (default True)
:param return_sequences: Whether to return the full sequence or just the last output (default False)
:param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
:param go_backwards: If True, the input sequence will be processed from back to front (default False)
:param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
:param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
- :param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False)
+ :param zero_output_for_mask: Whether the output should use zeros for the masked timesteps. (default False)
+ Note that this field is only used when `return_sequences` is `True` and `mask` is provided.
+ It can be useful if you want to reuse the raw output sequence of
+ the RNN without interference from the masked timesteps, e.g.,
+ merging bidirectional RNNs.
:param kwargs:
"""

@@ -72,7 +80,7 @@ def __init__(
raise ValueError(f"Cannot use backbone_layers in wired mode")
if backbone_dropout is not None:
raise ValueError(f"Cannot use backbone_dropout in wired mode")
- cell = WiredCfCCell(units, mode=mode, activation=activation)
+ cell = WiredCfCCell(units, mode=mode, activation=activation, fully_recurrent=fully_recurrent)
else:
backbone_units = 128 if backbone_units is None else backbone_units
backbone_layers = 1 if backbone_layers is None else backbone_layers
@@ -84,6 +92,7 @@
backbone_units=backbone_units,
backbone_layers=backbone_layers,
backbone_dropout=backbone_dropout,
+ sparsity_mask=sparsity_mask,
)
if mixed_memory:
cell = MixedMemoryRNN(cell)
@@ -94,6 +103,29 @@
go_backwards,
stateful,
unroll,
- time_major,
+ zero_output_for_mask,
**kwargs,
)

+ def get_config(self):
+ is_mixed_memory = isinstance(self.cell, MixedMemoryRNN)
+ cell: CfCCell | WiredCfCCell = self.cell.rnn_cell if is_mixed_memory else self.cell
+ cell_config = cell.get_config()
+ config = super(CfC, self).get_config()
+ config["units"] = cell.wiring if isinstance(cell, WiredCfCCell) else cell.units
+ config["mixed_memory"] = is_mixed_memory
+ config["fully_recurrent"] = cell.fully_recurrent if isinstance(cell, WiredCfCCell) else True # If not WiredCfc it's ignored
+ return {**cell_config, **config}
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ # The following parameters are recreated by the constructor
+ del config["cell"]
+ if "wiring" in config:
+ wiring_class = getattr(ncps.wirings, config["units"]["class_name"])
+ units = wiring_class.from_config(config["units"]["config"])
+ del config["wiring"]
+ else:
+ units = config["units"]
+ del config["units"]
+ return cls(units, **config)
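With the `CfC` layer now registered as serializable and equipped with `get_config`/`from_config`, a wired model should survive a full save/load round trip without passing `custom_objects`. A minimal sketch, assuming the usual `ncps` exports and illustrative sizes:

```python
import keras
from ncps.keras import CfC
from ncps.wirings import AutoNCP

wiring = AutoNCP(16, 4)  # 16 neurons, 4 of them outputs (illustrative sizes)
inputs = keras.Input(shape=(None, 8))  # (time, features)
outputs = CfC(wiring, return_sequences=False)(inputs)
model = keras.Model(inputs, outputs)

model.save("cfc_model.keras")
restored = keras.models.load_model("cfc_model.keras")  # wiring rebuilt via from_config
```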
124 changes: 55 additions & 69 deletions ncps/keras/cfc_cell.py
@@ -12,27 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
- import numpy as np


# LeCun improved tanh activation
# http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
+ @keras.utils.register_keras_serializable(package="", name="lecun_tanh")
def lecun_tanh(x):
return 1.7159 * keras.activations.tanh(0.666 * x)


+ # Register the custom activation function
+ from keras.src.activations import ALL_OBJECTS_DICT
+ ALL_OBJECTS_DICT["lecun_tanh"] = lecun_tanh
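Registering `lecun_tanh` both as a serializable object and in Keras's activation registry means the string default `activation="lecun_tanh"` resolves through `keras.activations.get`, which the cell constructor below now relies on. A quick check, assuming importing the module triggers the registration:

```python
import keras
import ncps.keras  # noqa: F401 -- importing the module runs the registration above

act = keras.activations.get("lecun_tanh")
print(float(act(0.0)))  # 0.0, since 1.7159 * tanh(0.666 * 0) == 0
```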


+ @keras.utils.register_keras_serializable(package="ncps", name="CfCCell")
class CfCCell(keras.layers.Layer):
def __init__(
self,
units,
- input_sparsity=None,
- recurrent_sparsity=None,
mode="default",
activation="lecun_tanh",
backbone_units=128,
backbone_layers=1,
backbone_dropout=0.1,
+ sparsity_mask=None,
**kwargs,
):
"""A `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ cell.
@@ -55,35 +59,18 @@ def __init__(
"""
super().__init__(**kwargs)
self.units = units
- self.sparsity_mask = None
- if input_sparsity is not None or recurrent_sparsity is not None:
+ self.sparsity_mask = sparsity_mask
+ if sparsity_mask is not None:
# No backbone is allowed
if backbone_units > 0:
- raise ValueError(
- "If sparsity of a Cfc cell is set, then no backbone is allowed"
- )
- # Both need to be set
- if input_sparsity is None or recurrent_sparsity is None:
- raise ValueError(
- "If sparsity of a Cfc cell is set, then both input and recurrent sparsity needs to be defined"
- )
- self.sparsity_mask = keras.ops.convert_to_tensor(
- np.concatenate([input_sparsity, recurrent_sparsity], axis=0),
- dtype="float32",
- )
+ raise ValueError("If sparsity of a CfC cell is set, then no backbone is allowed")

allowed_modes = ["default", "pure", "no_gate"]
if mode not in allowed_modes:
- raise ValueError(
- "Unknown mode '{}', valid options are {}".format(
- mode, str(allowed_modes)
- )
- )
+ raise ValueError(f"Unknown mode '{mode}', valid options are {str(allowed_modes)}")
self.mode = mode
self.backbone_fn = None
- if activation == "lecun_tanh":
- activation = lecun_tanh
- self._activation = activation
+ self._activation = keras.activations.get(activation)
self._backbone_units = backbone_units
self._backbone_layers = backbone_layers
self._backbone_dropout = backbone_dropout
@@ -94,40 +81,36 @@ def state_size(self):
return self.units

def build(self, input_shape):
- if isinstance(input_shape[0], tuple) or isinstance(
- input_shape[0], keras.KerasTensor
- ):
+ if isinstance(input_shape[0], tuple) or isinstance(input_shape[0], keras.KerasTensor):
# Nested tuple -> First item represents the feature dimension
input_dim = input_shape[0][-1]
else:
input_dim = input_shape[-1]

- backbone_layers = []
- for i in range(self._backbone_layers):
- backbone_layers.append(
- keras.layers.Dense(
- self._backbone_units, self._activation, name=f"backbone{i}"
- )
- )
- backbone_layers.append(keras.layers.Dropout(self._backbone_dropout))
+ if self._backbone_layers > 0:
+ backbone_layers = []
+ for i in range(self._backbone_layers):
+ backbone_layers.append(keras.layers.Dense(self._backbone_units, self._activation, name=f"backbone{i}"))
+ backbone_layers.append(keras.layers.Dropout(self._backbone_dropout))
+
+ self.backbone_fn = keras.models.Sequential(backbone_layers)
+ self.backbone_fn.build((None, self.state_size + input_dim))
+ cat_shape = int(self._backbone_units)
+ else:
+ cat_shape = int(self.state_size + input_dim)

- cat_shape = int(
- self.state_size + input_dim
- if self._backbone_layers == 0
- else self._backbone_units
- )
+ self.ff1_kernel = self.add_weight(
+ shape=(cat_shape, self.state_size),
+ initializer="glorot_uniform",
+ name="ff1_weight",
+ )
+ self.ff1_bias = self.add_weight(
+ shape=(self.state_size,),
+ initializer="zeros",
+ name="ff1_bias",
+ )

if self.mode == "pure":
- self.ff1_kernel = self.add_weight(
- shape=(cat_shape, self.state_size),
- initializer="glorot_uniform",
- name="ff1_weight",
- )
- self.ff1_bias = self.add_weight(
- shape=(self.state_size,),
- initializer="zeros",
- name="ff1_bias",
- )
self.w_tau = self.add_weight(
shape=(1, self.state_size),
initializer=keras.initializers.Zeros(),
@@ -139,16 +122,6 @@ def build(self, input_shape):
name="A",
)
else:
- self.ff1_kernel = self.add_weight(
- shape=(cat_shape, self.state_size),
- initializer="glorot_uniform",
- name="ff1_weight",
- )
- self.ff1_bias = self.add_weight(
- shape=(self.state_size,),
- initializer="zeros",
- name="ff1_bias",
- )
self.ff2_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
@@ -160,17 +133,13 @@ def build(self, input_shape):
name="ff2_bias",
)

- # = keras.layers.Dense(
- # , self._activation, name=f"{self.name}/ff1"
- # )
- # self.ff2 = keras.layers.Dense(
- # self.state_size, self._activation, name=f"{self.name}/ff2"
- # )
- # if self.sparsity_mask is not None:
- # self.ff1.build((None,))
- # self.ff2.build((None, self.sparsity_mask.shape[0]))
self.time_a = keras.layers.Dense(self.state_size, name="time_a")
self.time_b = keras.layers.Dense(self.state_size, name="time_b")
+ input_shape = (None, self.state_size + input_dim)
+ if self._backbone_layers > 0:
+ input_shape = self.backbone_fn.output_shape
+ self.time_a.build(input_shape)
+ self.time_b.build(input_shape)
self.built = True

def call(self, inputs, states, **kwargs):
@@ -213,3 +182,20 @@ def call(self, inputs, states, **kwargs):
new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2

return new_hidden, [new_hidden]
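The return value above is the closed-form interpolation from the CfC paper; as a sketch of what `call` computes in `default` mode, with backbone features x, elapsed time t, and the `ff1`/`ff2`/`time_a`/`time_b` weights from `build`:

```math
t_{\mathrm{interp}} = \sigma\big(\mathrm{time\_a}(x)\,t + \mathrm{time\_b}(x)\big),\qquad
h_{\mathrm{new}} = (1 - t_{\mathrm{interp}}) \odot \mathrm{ff}_1 + t_{\mathrm{interp}} \odot \mathrm{ff}_2
```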

+ def get_config(self):
+ config = {
+ "units": self.units,
+ "mode": self.mode,
+ "activation": self._activation,
+ "backbone_units": self._backbone_units,
+ "backbone_layers": self._backbone_layers,
+ "backbone_dropout": self._backbone_dropout,
+ "sparsity_mask": self.sparsity_mask,
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
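The cell-level pair can be exercised directly with an in-memory config round trip; a small sketch with illustrative hyperparameters, assuming `CfCCell` is exported from `ncps.keras`:

```python
from ncps.keras import CfCCell

cell = CfCCell(32, backbone_units=64, backbone_layers=1, backbone_dropout=0.1)
clone = CfCCell.from_config(cell.get_config())
assert clone.units == cell.units and clone.mode == cell.mode
```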