Merge pull request #69 from lettercode/lettercode/support-bidirectional-layer

Support saving and loading of models
mlech26l authored Aug 14, 2024
2 parents 3c442c4 + 5a1c4d9 commit 71e370e
Showing 9 changed files with 911 additions and 177 deletions.
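For context, the change can be exercised as below, a minimal sketch assuming the public ncps.keras.CfC layer; the input shape, layer sizes, and the cfc_model.keras path are illustrative, not taken from the diff:

    import keras
    from ncps.keras import CfC

    # CfC is registered with @keras.utils.register_keras_serializable (see below),
    # so a model containing it should round-trip through the native .keras format.
    model = keras.Sequential([
        keras.layers.InputLayer(shape=(None, 8)),  # (time, features), hypothetical
        CfC(32, return_sequences=True),
        keras.layers.Dense(1),
    ])
    model.save("cfc_model.keras")                           # serialized via the new get_config()
    restored = keras.models.load_model("cfc_model.keras")  # rebuilt via from_config()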
4 changes: 2 additions & 2 deletions .github/workflows/python-test.yml
@@ -44,7 +44,7 @@ jobs:
runs-on: ubuntu-latest

container:
-image: tensorflow/tensorflow:2.16.1
+image: tensorflow/tensorflow:2.17.0
env:
KERAS_BACKEND: tensorflow
volumes:
@@ -56,7 +56,7 @@
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install flake8 pytest
+pip install flake8 pytest networkx
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
68 changes: 50 additions & 18 deletions ncps/keras/cfc.py
@@ -23,21 +23,23 @@
@keras.utils.register_keras_serializable(package="ncps", name="CfC")
class CfC(keras.layers.RNN):
def __init__(
-self,
-units: Union[int, ncps.wirings.Wiring],
-mixed_memory: bool = False,
-mode: str = "default",
-activation: str = "lecun_tanh",
-backbone_units: int = None,
-backbone_layers: int = None,
-backbone_dropout: float = None,
-return_sequences: bool = False,
-return_state: bool = False,
-go_backwards: bool = False,
-stateful: bool = False,
-unroll: bool = False,
-time_major: bool = False,
-**kwargs,
+self,
+units: Union[int, ncps.wirings.Wiring],
+mixed_memory: bool = False,
+mode: str = "default",
+activation: str = "lecun_tanh",
+backbone_units: int = None,
+backbone_layers: int = None,
+backbone_dropout: float = None,
+sparsity_mask: keras.layers.Layer = None,
+fully_recurrent: bool = True,
+return_sequences: bool = False,
+return_state: bool = False,
+go_backwards: bool = False,
+stateful: bool = False,
+unroll: bool = False,
+zero_output_for_mask: bool = False,
+**kwargs,
):
"""Applies a `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ RNN to an input sequence.
@@ -56,12 +58,18 @@ def __init__(
:param backbone_units: Number of hidden units in the backbone layer (default 128)
:param backbone_layers: Number of backbone layers (default 1)
:param backbone_dropout: Dropout rate in the backbone layers (default 0)
+:param sparsity_mask: Optional sparsity mask applied to the cell's input and recurrent weights (default None)
+:param fully_recurrent: Whether to apply a fully-connected sparsity_mask or use the adjacency_matrix. Evaluated only for WiredCfCCell. (default True)
:param return_sequences: Whether to return the full sequence or just the last output (default False)
:param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
:param go_backwards: If True, the input sequence will be processed from back to front (default False)
:param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
:param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
-:param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False)
+:param zero_output_for_mask: Whether the output should use zeros for the masked timesteps. (default False)
+Note that this field is only used when `return_sequences` is `True` and `mask` is provided.
+It can be useful if you want to reuse the raw output sequence of
+the RNN without interference from the masked timesteps, e.g.,
+merging bidirectional RNNs.
:param kwargs:
"""

@@ -72,7 +80,7 @@ def __init__(
raise ValueError(f"Cannot use backbone_layers in wired mode")
if backbone_dropout is not None:
raise ValueError(f"Cannot use backbone_dropout in wired mode")
-cell = WiredCfCCell(units, mode=mode, activation=activation)
+cell = WiredCfCCell(units, mode=mode, activation=activation, fully_recurrent=fully_recurrent)
else:
backbone_units = 128 if backbone_units is None else backbone_units
backbone_layers = 1 if backbone_layers is None else backbone_layers
@@ -84,6 +92,7 @@ def __init__(
backbone_units=backbone_units,
backbone_layers=backbone_layers,
backbone_dropout=backbone_dropout,
+sparsity_mask=sparsity_mask,
)
if mixed_memory:
cell = MixedMemoryRNN(cell)
@@ -94,6 +103,29 @@ def __init__(
go_backwards,
stateful,
unroll,
-time_major,
+zero_output_for_mask,
**kwargs,
)

+def get_config(self):
+is_mixed_memory = isinstance(self.cell, MixedMemoryRNN)
+cell: CfCCell | WiredCfCCell = self.cell.rnn_cell if is_mixed_memory else self.cell
+cell_config = cell.get_config()
+config = super(CfC, self).get_config()
+config["units"] = cell.wiring if isinstance(cell, WiredCfCCell) else cell.units
+config["mixed_memory"] = is_mixed_memory
+config["fully_recurrent"] = cell.fully_recurrent if isinstance(cell, WiredCfCCell) else True  # If not WiredCfc it's ignored
+return {**cell_config, **config}
+
+@classmethod
+def from_config(cls, config, custom_objects=None):
+# The following parameters are recreated by the constructor
+del config["cell"]
+if "wiring" in config:
+wiring_class = getattr(ncps.wirings, config["units"]["class_name"])
+units = wiring_class.from_config(config["units"]["config"])
+del config["wiring"]
+else:
+units = config["units"]
+del config["units"]
+return cls(units, **config)
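A hedged sketch of the new constructor arguments in use; the bidirectional case is suggested by the branch name and the zero_output_for_mask docstring, while the shapes, the AutoNCP wiring sizes, and the Masking setup are assumptions for illustration:

    import keras
    from ncps.keras import CfC
    from ncps.wirings import AutoNCP

    inputs = keras.Input(shape=(None, 8))  # variable-length sequences, 8 features
    x = keras.layers.Masking()(inputs)     # flag zero-padded timesteps
    # zero_output_for_mask=True zeros the masked steps of the returned sequence,
    # so the forward and backward halves can be merged without padding artifacts.
    x = keras.layers.Bidirectional(
        CfC(AutoNCP(28, 4), return_sequences=True, zero_output_for_mask=True)
    )(x)
    model = keras.Model(inputs, x)

Bidirectional clones its wrapped layer through get_config()/from_config(), which is plausibly why the serialization work in this file also unlocks wrapping CfC in a bidirectional layer.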
124 changes: 55 additions & 69 deletions ncps/keras/cfc_cell.py
@@ -12,27 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
-import numpy as np


# LeCun improved tanh activation
# http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
+@keras.utils.register_keras_serializable(package="", name="lecun_tanh")
def lecun_tanh(x):
return 1.7159 * keras.activations.tanh(0.666 * x)


+# Register the custom activation function
+from keras.src.activations import ALL_OBJECTS_DICT
+ALL_OBJECTS_DICT["lecun_tanh"] = lecun_tanh


@keras.utils.register_keras_serializable(package="ncps", name="CfCCell")
class CfCCell(keras.layers.Layer):
def __init__(
self,
units,
-input_sparsity=None,
-recurrent_sparsity=None,
mode="default",
activation="lecun_tanh",
backbone_units=128,
backbone_layers=1,
backbone_dropout=0.1,
+sparsity_mask=None,
**kwargs,
):
"""A `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ cell.
@@ -55,35 +59,18 @@ def __init__(
"""
super().__init__(**kwargs)
self.units = units
-self.sparsity_mask = None
-if input_sparsity is not None or recurrent_sparsity is not None:
+self.sparsity_mask = sparsity_mask
+if sparsity_mask is not None:
# No backbone is allowed
if backbone_units > 0:
-raise ValueError(
-"If sparsity of a Cfc cell is set, then no backbone is allowed"
-)
-# Both need to be set
-if input_sparsity is None or recurrent_sparsity is None:
-raise ValueError(
-"If sparsity of a Cfc cell is set, then both input and recurrent sparsity needs to be defined"
-)
-self.sparsity_mask = keras.ops.convert_to_tensor(
-np.concatenate([input_sparsity, recurrent_sparsity], axis=0),
-dtype="float32",
-)
+raise ValueError("If sparsity of a CfC cell is set, then no backbone is allowed")

allowed_modes = ["default", "pure", "no_gate"]
if mode not in allowed_modes:
-raise ValueError(
-"Unknown mode '{}', valid options are {}".format(
-mode, str(allowed_modes)
-)
-)
+raise ValueError(f"Unknown mode '{mode}', valid options are {str(allowed_modes)}")
self.mode = mode
self.backbone_fn = None
-if activation == "lecun_tanh":
-activation = lecun_tanh
-self._activation = activation
+self._activation = keras.activations.get(activation)
self._backbone_units = backbone_units
self._backbone_layers = backbone_layers
self._backbone_dropout = backbone_dropout
@@ -94,40 +81,36 @@ def state_size(self):
return self.units

def build(self, input_shape):
-if isinstance(input_shape[0], tuple) or isinstance(
-input_shape[0], keras.KerasTensor
-):
+if isinstance(input_shape[0], tuple) or isinstance(input_shape[0], keras.KerasTensor):
# Nested tuple -> First item represent feature dimension
input_dim = input_shape[0][-1]
else:
input_dim = input_shape[-1]

-backbone_layers = []
-for i in range(self._backbone_layers):
-backbone_layers.append(
-keras.layers.Dense(
-self._backbone_units, self._activation, name=f"backbone{i}"
-)
-)
-backbone_layers.append(keras.layers.Dropout(self._backbone_dropout))
+if self._backbone_layers > 0:
+backbone_layers = []
+for i in range(self._backbone_layers):
+backbone_layers.append(keras.layers.Dense(self._backbone_units, self._activation, name=f"backbone{i}"))
+backbone_layers.append(keras.layers.Dropout(self._backbone_dropout))
+
+self.backbone_fn = keras.models.Sequential(backbone_layers)
+self.backbone_fn.build((None, self.state_size + input_dim))
+cat_shape = int(self._backbone_units)
+else:
+cat_shape = int(self.state_size + input_dim)

-cat_shape = int(
-self.state_size + input_dim
-if self._backbone_layers == 0
-else self._backbone_units
-)
+self.ff1_kernel = self.add_weight(
+shape=(cat_shape, self.state_size),
+initializer="glorot_uniform",
+name="ff1_weight",
+)
+self.ff1_bias = self.add_weight(
+shape=(self.state_size,),
+initializer="zeros",
+name="ff1_bias",
+)

if self.mode == "pure":
-self.ff1_kernel = self.add_weight(
-shape=(cat_shape, self.state_size),
-initializer="glorot_uniform",
-name="ff1_weight",
-)
-self.ff1_bias = self.add_weight(
-shape=(self.state_size,),
-initializer="zeros",
-name="ff1_bias",
-)
self.w_tau = self.add_weight(
shape=(1, self.state_size),
initializer=keras.initializers.Zeros(),
@@ -139,16 +122,6 @@ def build(self, input_shape):
name="A",
)
else:
-self.ff1_kernel = self.add_weight(
-shape=(cat_shape, self.state_size),
-initializer="glorot_uniform",
-name="ff1_weight",
-)
-self.ff1_bias = self.add_weight(
-shape=(self.state_size,),
-initializer="zeros",
-name="ff1_bias",
-)
self.ff2_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
@@ -160,17 +133,13 @@ def build(self, input_shape):
name="ff2_bias",
)

-# = keras.layers.Dense(
-# , self._activation, name=f"{self.name}/ff1"
-# )
-# self.ff2 = keras.layers.Dense(
-# self.state_size, self._activation, name=f"{self.name}/ff2"
-# )
-# if self.sparsity_mask is not None:
-# self.ff1.build((None,))
-# self.ff2.build((None, self.sparsity_mask.shape[0]))
self.time_a = keras.layers.Dense(self.state_size, name="time_a")
self.time_b = keras.layers.Dense(self.state_size, name="time_b")
+input_shape = (None, self.state_size + input_dim)
+if self._backbone_layers > 0:
+input_shape = self.backbone_fn.output_shape
+self.time_a.build(input_shape)
+self.time_b.build(input_shape)
self.built = True

def call(self, inputs, states, **kwargs):
@@ -213,3 +182,20 @@ def call(self, inputs, states, **kwargs):
new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2

return new_hidden, [new_hidden]

+def get_config(self):
+config = {
+"units": self.units,
+"mode": self.mode,
+"activation": self._activation,
+"backbone_units": self._backbone_units,
+"backbone_layers": self._backbone_layers,
+"backbone_dropout": self._backbone_dropout,
+"sparsity_mask": self.sparsity_mask,
+}
+base_config = super().get_config()
+return {**base_config, **config}
+
+@classmethod
+def from_config(cls, config, custom_objects=None):
+return cls(**config)
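With the registration above, "lecun_tanh" resolves by name like any built-in activation. A small sketch; the Dense layer and sample values are illustrative, and importing ncps.keras is assumed to trigger the module-level registration:

    import keras
    import numpy as np
    import ncps.keras  # noqa: F401 -- import runs the lecun_tanh registration

    fn = keras.activations.get("lecun_tanh")  # no custom_objects needed
    x = np.linspace(-2.0, 2.0, 5).astype("float32")
    print(fn(x))                              # 1.7159 * tanh(0.666 * x)

    layer = keras.layers.Dense(16, activation="lecun_tanh")  # works by name too

Note also that CfCCell now takes a single, pre-combined sparsity_mask in place of the separate input_sparsity and recurrent_sparsity arrays of the previous API.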
