From e21b2a255db7314174712f2238b55d9b3d2e1115 Mon Sep 17 00:00:00 2001
From: lettercode <59030475+lettercode@users.noreply.github.com>
Date: Sat, 6 Apr 2024 22:10:23 +0200
Subject: [PATCH 1/4] Add Keras V3 implementation

---
 ncps/keras/__init__.py | 44 +++++
 ncps/keras/cfc.py | 99 ++++++++++
 ncps/keras/cfc_cell.py | 215 ++++++++++++++++++++++
 ncps/keras/ltc.py | 103 +++++++++++
 ncps/keras/ltc_cell.py | 346 +++++++++++++++++++++++++++++++++++
 ncps/keras/mm_rnn.py | 103 +++++++++++
 ncps/keras/wired_cfc_cell.py | 167 +++++++++++++++++
 ncps/tests/test_keras.py | 305 ++++++++++++++++++++++++++++++
 requirements.txt | 7 +-
 9 files changed, 1385 insertions(+), 4 deletions(-)
 create mode 100644 ncps/keras/__init__.py
 create mode 100644 ncps/keras/cfc.py
 create mode 100644 ncps/keras/cfc_cell.py
 create mode 100644 ncps/keras/ltc.py
 create mode 100644 ncps/keras/ltc_cell.py
 create mode 100644 ncps/keras/mm_rnn.py
 create mode 100644 ncps/keras/wired_cfc_cell.py
 create mode 100644 ncps/tests/test_keras.py

diff --git a/ncps/keras/__init__.py b/ncps/keras/__init__.py
new file mode 100644
index 00000000..69fb9c66
--- /dev/null
+++ b/ncps/keras/__init__.py
@@ -0,0 +1,44 @@
+# Copyright 2020-2021 Mathias Lechner
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import absolute_import
+
+from .ltc_cell import LTCCell
+from .mm_rnn import MixedMemoryRNN
+from .cfc_cell import CfCCell
+from .wired_cfc_cell import WiredCfCCell
+from .cfc import CfC
+from .ltc import LTC
+from packaging.version import parse
+
+try:
+    import keras
+except ImportError:
+    raise ImportWarning(
+        "It seems like the Keras package is not installed\n"
+        "Please run "
+        "`$ pip install keras`.\n",
+    )
+
+if parse(keras.__version__) < parse("3.0.0"):
+    raise ImportError(
+        "The Keras package version needs to be at least 3.0.0 \n"
+        "for ncps-keras to run. Currently, your Keras version is \n"
+        "{version}. Please upgrade with \n"
+        "`$ pip install --upgrade keras`. \n"
+        "You can use `pip freeze` to check afterwards that everything is "
+        "ok.".format(version=keras.__version__)
+    )
+__all__ = ["CfC", "CfCCell", "LTC", "LTCCell", "MixedMemoryRNN", "WiredCfCCell"]
diff --git a/ncps/keras/cfc.py b/ncps/keras/cfc.py
new file mode 100644
index 00000000..ac4ec8e8
--- /dev/null
+++ b/ncps/keras/cfc.py
@@ -0,0 +1,99 @@
+# Copyright 2022 Mathias Lechner and Ramin Hasani
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+import keras
+
+import ncps
+from . import CfCCell, MixedMemoryRNN, WiredCfCCell
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="CfC")
+class CfC(keras.layers.RNN):
+    def __init__(
+        self,
+        units: Union[int, ncps.wirings.Wiring],
+        mixed_memory: bool = False,
+        mode: str = "default",
+        activation: str = "lecun_tanh",
+        backbone_units: int = None,
+        backbone_layers: int = None,
+        backbone_dropout: float = None,
+        return_sequences: bool = False,
+        return_state: bool = False,
+        go_backwards: bool = False,
+        stateful: bool = False,
+        unroll: bool = False,
+        time_major: bool = False,
+        **kwargs,
+    ):
+        """Applies a `Closed-form Continuous-time `_ RNN to an input sequence.
+
+        Examples::
+
+            >>> from ncps.tf import CfC
+            >>>
+            >>> rnn = CfC(50)
+            >>> x = keras.random.uniform((2, 10, 20)) # (B,L,C)
+            >>> y = keras.layers.RNN(x)
+
+        :param units: Number of hidden units
+        :param mixed_memory: Whether to augment the RNN by a `memory-cell `_ to help learn long-term dependencies in the data (default False)
+        :param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate). (default "default")
+        :param activation: Activation function used in the backbone layers (default "lecun_tanh")
+        :param backbone_units: Number of hidden units in the backbone layer (default 128)
+        :param backbone_layers: Number of backbone layers (default 1)
+        :param backbone_dropout: Dropout rate in the backbone layers (default 0)
+        :param return_sequences: Whether to return the full sequence or just the last output (default False)
+        :param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
+        :param go_backwards: If True, the input sequence will be processed from back to front (default False)
+        :param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
+        :param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
+        :param time_major: Whether the time or batch dimension is the first (0-th) dimension; only False is supported with Keras 3 (default False)
+        :param kwargs:
+        """
+
+        if isinstance(units, ncps.wirings.Wiring):
+            if backbone_units is not None:
+                raise ValueError("Cannot use backbone_units in wired mode")
+            if backbone_layers is not None:
+                raise ValueError("Cannot use backbone_layers in wired mode")
+            if backbone_dropout is not None:
+                raise ValueError("Cannot use backbone_dropout in wired mode")
+            cell = WiredCfCCell(units, mode=mode, activation=activation)
+        else:
+            backbone_units = 128 if backbone_units is None else backbone_units
+            backbone_layers = 1 if backbone_layers is None else backbone_layers
+            backbone_dropout = 0.0 if backbone_dropout is None else backbone_dropout
+            cell = CfCCell(
+                units,
+                mode=mode,
+                activation=activation,
+                backbone_units=backbone_units,
+                backbone_layers=backbone_layers,
+                backbone_dropout=backbone_dropout,
+            )
+        if mixed_memory:
+            cell = MixedMemoryRNN(cell)
+        if time_major:
+            raise ValueError(
+                "time_major=True is not supported, the Keras 3 RNN layer only accepts batch-major inputs"
+            )
+        # Keras 3 removed the time_major argument of keras.layers.RNN, so the
+        # remaining flags are passed by keyword
+        super(CfC, self).__init__(
+            cell,
+            return_sequences=return_sequences,
+            return_state=return_state,
+            go_backwards=go_backwards,
+            stateful=stateful,
+            unroll=unroll,
+            **kwargs,
+        )
diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py
new file mode 100644
index 00000000..18b7a1c6
--- /dev/null
+++ b/ncps/keras/cfc_cell.py
@@ -0,0 +1,215 @@
+# Copyright 2022 Mathias Lechner and Ramin Hasani
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+import numpy as np
+
+
+# LeCun improved tanh activation
+# http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
+def lecun_tanh(x):
+    return 1.7159 * keras.activations.tanh(0.666 * x)
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="CfCCell")
+class CfCCell(keras.layers.Layer):
+    def __init__(
+        self,
+        units,
+        input_sparsity=None,
+        recurrent_sparsity=None,
+        mode="default",
+        activation="lecun_tanh",
+        backbone_units=128,
+        backbone_layers=1,
+        backbone_dropout=0.1,
+        **kwargs,
+    ):
+        """A `Closed-form Continuous-time `_ cell.
+
+        .. Note::
+            This is an RNNCell that process single time-steps.
+            To get a full RNN that can process sequences,
+            see `ncps.tf.CfC` or wrap the cell with a `keras.layers.RNN `_.
+
+
+        :param units: Number of hidden units
+        :param input_sparsity:
+        :param recurrent_sparsity:
+        :param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate).
+        :param activation: Activation function used in the backbone layers
+        :param backbone_units: Number of hidden units in the backbone layer (default 128)
+        :param backbone_layers: Number of backbone layers (default 1)
+        :param backbone_dropout: Dropout rate in the backbone layers (default 0.1)
+        :param kwargs:
+        """
+        super().__init__(**kwargs)
+        self.units = units
+        self.sparsity_mask = None
+        if input_sparsity is not None or recurrent_sparsity is not None:
+            # No backbone is allowed
+            if backbone_units > 0:
+                raise ValueError(
+                    "If sparsity of a CfC cell is set, then no backbone is allowed"
+                )
+            # Both need to be set
+            if input_sparsity is None or recurrent_sparsity is None:
+                raise ValueError(
+                    "If sparsity of a CfC cell is set, then both input and recurrent sparsity need to be defined"
+                )
+            self.sparsity_mask = keras.ops.convert_to_tensor(
+                np.concatenate([input_sparsity, recurrent_sparsity], axis=0),
+                dtype="float32",
+            )
+
+        allowed_modes = ["default", "pure", "no_gate"]
+        if mode not in allowed_modes:
+            raise ValueError(
+                "Unknown mode '{}', valid options are {}".format(
+                    mode, str(allowed_modes)
+                )
+            )
+        self.mode = mode
+        self.backbone_fn = None
+        if activation == "lecun_tanh":
+            activation = lecun_tanh
+        self._activation = activation
+        self._backbone_units = backbone_units
+        self._backbone_layers = backbone_layers
+        self._backbone_dropout = backbone_dropout
+        self._cfc_layers = []
+
+    @property
+    def state_size(self):
+        return self.units
+
+    def build(self, input_shape):
+        if isinstance(input_shape[0], tuple) or isinstance(
+            input_shape[0], keras.KerasTensor
+        ):
+            # Nested tuple -> First item represent feature dimension
+            input_dim = input_shape[0][-1]
+        else:
+            input_dim = input_shape[-1]
+
+        backbone_layers = []
+        for i in range(self._backbone_layers):
+            backbone_layers.append(
+                keras.layers.Dense(
+                    self._backbone_units, self._activation, name=f"backbone{i}"
+                )
+            )
+            backbone_layers.append(keras.layers.Dropout(self._backbone_dropout))
+        self.backbone_fn = keras.models.Sequential(backbone_layers)
+
+        cat_shape = int(
+            self.state_size + input_dim
+            if self._backbone_layers == 0
+            else self._backbone_units
+        )
+        if self.mode == "pure":
+            self.ff1_kernel = self.add_weight(
+                shape=(cat_shape, self.state_size),
+                initializer="glorot_uniform",
+                name="ff1_weight",
+            )
+            self.ff1_bias = self.add_weight(
+                shape=(self.state_size,),
+                initializer="zeros",
+                name="ff1_bias",
+            )
+            self.w_tau = self.add_weight(
+                shape=(1, self.state_size),
+                initializer=keras.initializers.Zeros(),
+                name="w_tau",
+            )
+            self.A = self.add_weight(
+                shape=(1, self.state_size),
+                initializer=keras.initializers.Ones(),
+                name="A",
+            )
+        else:
+            self.ff1_kernel = self.add_weight(
+                shape=(cat_shape, self.state_size),
+                initializer="glorot_uniform",
+                name="ff1_weight",
+            )
+            self.ff1_bias = self.add_weight(
+                shape=(self.state_size,),
+                initializer="zeros",
+                name="ff1_bias",
+            )
+            self.ff2_kernel = self.add_weight(
+                shape=(cat_shape, self.state_size),
+                initializer="glorot_uniform",
+                name="ff2_weight",
+            )
+            self.ff2_bias = self.add_weight(
+                shape=(self.state_size,),
+                initializer="zeros",
+                name="ff2_bias",
+            )
+
+        self.time_a = keras.layers.Dense(self.state_size, name="time_a")
+        self.time_b = keras.layers.Dense(self.state_size, name="time_b")
+        self.built = True
+
+    def call(self, inputs, states, **kwargs):
+        if isinstance(inputs, (tuple, list)):
+            # Irregularly sampled mode
+            inputs, t = inputs
+            t = keras.ops.reshape(t, [-1, 1])
+        else:
+            # Regularly sampled mode (elapsed time = 1 second)
+            t = kwargs.get("time") or 1.0
+        x = keras.layers.Concatenate()([inputs, states[0]])
+        if self._backbone_layers > 0:
+            x = self.backbone_fn(x)
+        if self.sparsity_mask is not None:
+            ff1_kernel = self.ff1_kernel * self.sparsity_mask
+            ff1 = keras.ops.matmul(x, ff1_kernel) + self.ff1_bias
+        else:
+            ff1 = keras.ops.matmul(x, self.ff1_kernel) + self.ff1_bias
+        if self.mode == "pure":
+            # Solution
+            new_hidden = (
+                -self.A
+                * keras.ops.exp(-t * (keras.ops.abs(self.w_tau) + keras.ops.abs(ff1)))
+                * ff1
+                + self.A
+            )
+        else:
+            # Cfc
+            if self.sparsity_mask is not None:
+                ff2_kernel = self.ff2_kernel * self.sparsity_mask
+                ff2 = keras.ops.matmul(x, ff2_kernel) + self.ff2_bias
+            else:
+                ff2 = keras.ops.matmul(x, self.ff2_kernel) + self.ff2_bias
+            t_a = self.time_a(x)
+            t_b = self.time_b(x)
+            t_interp = keras.activations.sigmoid(-t_a * t + t_b)
+            if self.mode == "no_gate":
+                new_hidden = ff1 + t_interp * ff2
+            else:
+                new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2
+
+        return new_hidden, [new_hidden]
diff --git a/ncps/keras/ltc.py b/ncps/keras/ltc.py
new file mode 100644
index 00000000..bd02699a
--- /dev/null
+++ b/ncps/keras/ltc.py
@@ -0,0 +1,103 @@
+# Copyright 2022 Mathias Lechner and Ramin Hasani
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ncps
+import keras
+from . import LTCCell, MixedMemoryRNN
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="LTC")
+class LTC(keras.layers.RNN):
+    def __init__(
+        self,
+        units,
+        mixed_memory: bool = False,
+        input_mapping="affine",
+        output_mapping="affine",
+        ode_unfolds=6,
+        epsilon=1e-8,
+        initialization_ranges=None,
+        return_sequences: bool = False,
+        return_state: bool = False,
+        go_backwards: bool = False,
+        stateful: bool = False,
+        unroll: bool = False,
+        time_major: bool = False,
+        **kwargs,
+    ):
+        """Applies a `Liquid time-constant (LTC) `_ RNN to an input sequence.
+
+        Examples::
+
+            >>> from ncps.keras import LTC
+            >>>
+            >>> rnn = LTC(50)
+            >>> x = tf.random.uniform((2, 10, 20)) # (B,L,C)
+            >>> y = rnn(x)
+
+        .. Note::
+            For creating a wired `Neural circuit policy (NCP) `_ you can pass a `ncps.wirings.NCP` object instead of the number of units
+
+        Examples::
+
+            >>> from ncps.tf import LTC
+            >>> from ncps.wirings import NCP
+            >>>
+            >>> wiring = NCP(10, 10, 8, 6, 6, 4, 4)
+            >>> rnn = LTC(wiring)
+            >>> x = tf.random.uniform((2, 10, 20)) # (B,L,C)
+            >>> y = rnn(x)
+
+        :param units: Wiring (ncps.wirings.Wiring instance) or integer representing the number of (fully-connected) hidden units
+        :param mixed_memory: Whether to augment the RNN by a `memory-cell `_ to help learn long-term dependencies in the data
+        :param input_mapping: Mapping applied to the sensory neurons. Possible values None, "linear", "affine" (default "affine")
+        :param output_mapping: Mapping applied to the motor neurons. Possible values None, "linear", "affine" (default "affine")
+        :param ode_unfolds: Number of ODE-solver steps per time-step (default 6)
+        :param epsilon: Auxiliary value to avoid dividing by 0 (default 1e-8)
+        :param initialization_ranges: A dictionary for overwriting the range of the uniform weight initialization (default None)
+        :param return_sequences: Whether to return the full sequence or just the last output (default False)
+        :param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
+        :param go_backwards: If True, the input sequence will be processed from back to front (default False)
+        :param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
+        :param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
+        :param time_major: Whether the time or batch dimension is the first (0-th) dimension; only False is supported with Keras 3 (default False)
+        :param kwargs:
+        """
+
+        if isinstance(units, ncps.wirings.Wiring):
+            wiring = units
+        else:
+            wiring = ncps.wirings.FullyConnected(units)
+
+        cell = LTCCell(
+            wiring=wiring,
+            input_mapping=input_mapping,
+            output_mapping=output_mapping,
+            ode_unfolds=ode_unfolds,
+            epsilon=epsilon,
+            initialization_ranges=initialization_ranges,
+            **kwargs,
+        )
+        if mixed_memory:
+            cell = MixedMemoryRNN(cell)
+        if time_major:
+            raise ValueError(
+                "time_major=True is not supported, the Keras 3 RNN layer only accepts batch-major inputs"
+            )
+        # Keras 3 removed the time_major argument of keras.layers.RNN, so the
+        # remaining flags are passed by keyword
+        super(LTC, self).__init__(
+            cell,
+            return_sequences=return_sequences,
+            return_state=return_state,
+            go_backwards=go_backwards,
+            stateful=stateful,
+            unroll=unroll,
+            **kwargs,
+        )
diff --git a/ncps/keras/ltc_cell.py b/ncps/keras/ltc_cell.py
new file mode 100644
index 00000000..47a90bb5
--- /dev/null
+++ b/ncps/keras/ltc_cell.py
@@ -0,0 +1,346 @@
+# Copyright 2022 Mathias Lechner and Ramin Hasani
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ncps import wirings
+import numpy as np
+import keras
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="LTCCell")
+class LTCCell(keras.layers.Layer):
+    name = "LTC-Cell"
+
+    def __init__(
+        self,
+        wiring,
+        input_mapping="affine",
+        output_mapping="affine",
+        ode_unfolds=6,
+        epsilon=1e-8,
+        initialization_ranges=None,
+        **kwargs
+    ):
+        """A `Liquid time-constant (LTC) `_ cell.
+
+        . Note::
+            This is an RNNCell that process single time-steps.
+            To get a full RNN that can process sequences,
+            see `ncps.tf.LTC` or wrap the cell with a `keras.layers.RNN `_.
+
+        Examples::
+
+            >>> import ncps
+            >>> from ncps.tf import LTCCell
+            >>>
+            >>> wiring = ncps.wirings.Random(16, output_dim=2, sparsity_level=0.5)
+            >>> cell = LTCCell(wiring)
+            >>> rnn = keras.layers.RNN(cell)
+            >>> x = keras.random.uniform((1,4)) # (batch, features)
+            >>> h0 = keras.ops.zeros((1, 16))
+            >>> y, hx = cell(x, [h0])
+            >>>
+            >>> x_seq = keras.random.uniform((1,20,4)) # (batch, time, features)
+            >>> y_seq = rnn(x_seq)
+
+        :param wiring:
+        :param input_mapping:
+        :param output_mapping:
+        :param ode_unfolds:
+        :param epsilon:
+        :param initialization_ranges:
+        :param kwargs:
+        """
+
+        super().__init__(**kwargs)
+        self._init_ranges = {
+            "gleak": (0.001, 1.0),
+            "vleak": (-0.2, 0.2),
+            "cm": (0.4, 0.6),
+            "w": (0.001, 1.0),
+            "sigma": (3, 8),
+            "mu": (0.3, 0.8),
+            "sensory_w": (0.001, 1.0),
+            "sensory_sigma": (3, 8),
+            "sensory_mu": (0.3, 0.8),
+        }
+        if initialization_ranges is not None:
+            for k, v in initialization_ranges.items():
+                if k not in self._init_ranges.keys():
+                    raise ValueError(
+                        "Unknown parameter '{}' in initialization range dictionary! (Expected only {})".format(
+                            k, str(list(self._init_ranges.keys()))
+                        )
+                    )
+                if k in ["gleak", "cm", "w", "sensory_w"] and v[0] < 0:
+                    raise ValueError(
+                        "Initialization range of parameter '{}' must be non-negative!".format(
+                            k
+                        )
+                    )
+                if v[0] > v[1]:
+                    raise ValueError(
+                        "Initialization range of parameter '{}' is not a valid range".format(
+                            k
+                        )
+                    )
+                self._init_ranges[k] = v
+
+        self._wiring = wiring
+        self._input_mapping = input_mapping
+        self._output_mapping = output_mapping
+        self._ode_unfolds = ode_unfolds
+        self._epsilon = epsilon
+
+    @property
+    def state_size(self):
+        return self._wiring.units
+
+    @property
+    def sensory_size(self):
+        return self._wiring.input_dim
+
+    @property
+    def motor_size(self):
+        return self._wiring.output_dim
+
+    @property
+    def output_size(self):
+        return self.motor_size
+
+    def _get_initializer(self, param_name):
+        minval, maxval = self._init_ranges[param_name]
+        if minval == maxval:
+            return keras.initializers.Constant(minval)
+        else:
+            return keras.initializers.RandomUniform(minval, maxval)
+
+    def build(self, input_shape):
+
+        # Check if input_shape is nested tuple/list
+        if isinstance(input_shape[0], tuple) or isinstance(input_shape[0], keras.KerasTensor):
+            # Nested tuple -> First item represent feature dimension
+            input_dim = input_shape[0][-1]
+        else:
+            input_dim = input_shape[-1]
+
+        self._wiring.build(input_dim)
+
+        self._params = {}
+        self._params["gleak"] = self.add_weight(
+            name="gleak",
+            shape=(self.state_size,),
+            dtype="float32",
+            constraint=keras.constraints.NonNeg(),
+            initializer=self._get_initializer("gleak"),
+        )
+        self._params["vleak"] = self.add_weight(
+            name="vleak",
+            shape=(self.state_size,),
+            dtype="float32",
+            initializer=self._get_initializer("vleak"),
+        )
+        self._params["cm"] = self.add_weight(
+            name="cm",
+            shape=(self.state_size,),
+            dtype="float32",
+            constraint=keras.constraints.NonNeg(),
+            initializer=self._get_initializer("cm"),
+        )
+        self._params["sigma"] = self.add_weight(
+            name="sigma",
+            shape=(self.state_size, self.state_size),
+            dtype="float32",
+            initializer=self._get_initializer("sigma"),
+        )
+        self._params["mu"] = self.add_weight(
+            name="mu",
+            shape=(self.state_size, self.state_size),
+            dtype="float32",
+            initializer=self._get_initializer("mu"),
+        )
+        self._params["w"] = self.add_weight(
+            name="w",
+            shape=(self.state_size, self.state_size),
+            dtype="float32",
+            constraint=keras.constraints.NonNeg(),
+            initializer=self._get_initializer("w"),
+        )
+        self._params["erev"] = self.add_weight(
+            name="erev",
+            shape=(self.state_size, self.state_size),
+            dtype="float32",
+            initializer=self._wiring.erev_initializer,
+        )
+
+        self._params["sensory_sigma"] = self.add_weight(
+            name="sensory_sigma",
+            shape=(self.sensory_size, self.state_size),
+            dtype="float32",
+            initializer=self._get_initializer("sensory_sigma"),
+        )
+        self._params["sensory_mu"] = self.add_weight(
+            name="sensory_mu",
+            shape=(self.sensory_size, self.state_size),
+            dtype="float32",
+            initializer=self._get_initializer("sensory_mu"),
+        )
+        self._params["sensory_w"] = self.add_weight(
+            name="sensory_w",
+            shape=(self.sensory_size, self.state_size),
+            dtype="float32",
+            constraint=keras.constraints.NonNeg(),
+            initializer=self._get_initializer("sensory_w"),
+        )
+        self._params["sensory_erev"] = self.add_weight(
+            name="sensory_erev",
+            shape=(self.sensory_size, self.state_size),
+            dtype="float32",
+            initializer=self._wiring.sensory_erev_initializer,
+        )
+
+        self._params["sparsity_mask"] = keras.ops.convert_to_tensor(
+            np.abs(self._wiring.adjacency_matrix), dtype="float32"
+        )
+        self._params["sensory_sparsity_mask"] = keras.ops.convert_to_tensor(
+            np.abs(self._wiring.sensory_adjacency_matrix), dtype="float32"
+        )
+
+        if self._input_mapping in ["affine", "linear"]:
+            self._params["input_w"] = self.add_weight(
+                name="input_w",
+                shape=(self.sensory_size,),
+                dtype="float32",
+                initializer=keras.initializers.Constant(1),
+            )
+        if self._input_mapping == "affine":
+            self._params["input_b"] = self.add_weight(
+                name="input_b",
+                shape=(self.sensory_size,),
+                dtype="float32",
+                initializer=keras.initializers.Constant(0),
+            )
+
+        if self._output_mapping in ["affine", "linear"]:
+            self._params["output_w"] = self.add_weight(
+                name="output_w",
+                shape=(self.motor_size,),
+                dtype="float32",
+                initializer=keras.initializers.Constant(1),
+            )
+        if self._output_mapping == "affine":
+            self._params["output_b"] = self.add_weight(
+                name="output_b",
+                shape=(self.motor_size,),
+                dtype="float32",
+                initializer=keras.initializers.Constant(0),
+            )
+        self.built = True
+
+    def _sigmoid(self, v_pre, mu, sigma):
+        v_pre = keras.ops.expand_dims(v_pre, axis=-1) # For broadcasting
+        mues = v_pre - mu
+        x = sigma * mues
+        return keras.activations.sigmoid(x)
+
+    def _ode_solver(self, inputs, state, elapsed_time):
+        v_pre = state
+
+        # We can pre-compute the effects of the sensory neurons here
+        sensory_w_activation = self._params["sensory_w"] * self._sigmoid(
+            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
+        )
+        sensory_w_activation *= self._params["sensory_sparsity_mask"]
+
+        sensory_rev_activation = sensory_w_activation * self._params["sensory_erev"]
+
+        # Reduce over dimension 1 (=source sensory neurons)
+        w_numerator_sensory = keras.ops.sum(sensory_rev_activation, axis=1)
+        w_denominator_sensory = keras.ops.sum(sensory_w_activation, axis=1)
+
+        # cm/t is loop invariant
+        cm_t = self._params["cm"] / keras.ops.cast(
+            elapsed_time / self._ode_unfolds, dtype="float32"
+        )
+
+        # Unfold multiple ODE solver steps into one RNN step
+        for t in range(self._ode_unfolds):
+            w_activation = self._params["w"] * self._sigmoid(
+                v_pre, self._params["mu"], self._params["sigma"]
+            )
+
+            w_activation *= self._params["sparsity_mask"]
+
+            rev_activation = w_activation * self._params["erev"]
+
+            # Reduce over dimension 1 (=source neurons)
+            w_numerator = keras.ops.sum(rev_activation, axis=1) + w_numerator_sensory
+            w_denominator = keras.ops.sum(w_activation, axis=1) + w_denominator_sensory
+
+            numerator = (
+                cm_t * v_pre
+                + self._params["gleak"] * self._params["vleak"]
+                + w_numerator
+            )
+            denominator = cm_t + self._params["gleak"] + w_denominator
+
+            # Avoid dividing by 0
+            v_pre = numerator / (denominator + self._epsilon)
+
+        return v_pre
+
+    def _map_inputs(self, inputs):
+        if self._input_mapping in ["affine", "linear"]:
+            inputs = inputs * self._params["input_w"]
+        if self._input_mapping == "affine":
+            inputs = inputs + self._params["input_b"]
+        return inputs
+
+    def _map_outputs(self, state):
+        output = state
+        if self.motor_size < self.state_size:
+            output = output[:, 0: self.motor_size]
+
+        if self._output_mapping in ["affine", "linear"]:
+            output = output * self._params["output_w"]
+        if self._output_mapping == "affine":
+            output = output + self._params["output_b"]
+        return output
+
+    def call(self, sequence, states, training=False):
+        if isinstance(sequence, (tuple, list)):
+            # Irregularly sampled mode
+            inputs, elapsed_time = sequence
+        else:
+            # Regularly sampled mode (elapsed time = 1 second)
+            elapsed_time = 1.0
+            inputs = sequence
+        # Map the unpacked inputs, not the raw sequence, so that the
+        # irregularly sampled (inputs, elapsed_time) mode also works
+        inputs = self._map_inputs(inputs)
+
+        next_state = self._ode_solver(inputs, states[0], elapsed_time)
+
+        outputs = self._map_outputs(next_state)
+
+        return outputs, [next_state]
+
+    def get_config(self):
+        serialized = self._wiring.get_config()
+        serialized["input_mapping"] = self._input_mapping
+        serialized["output_mapping"] = self._output_mapping
+        serialized["ode_unfolds"] = self._ode_unfolds
+        serialized["epsilon"] = self._epsilon
+        return serialized
+
+    @classmethod
+    def from_config(cls, config):
+        wiring = wirings.Wiring.from_config(config)
+        return cls(wiring=wiring, **config)
diff --git a/ncps/keras/mm_rnn.py b/ncps/keras/mm_rnn.py
new file mode 100644
index 00000000..183e6a28
--- /dev/null
+++ b/ncps/keras/mm_rnn.py
@@ -0,0 +1,103 @@
+# Copyright 2022 Mathias Lechner and Ramin Hasani
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import keras
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="MixedMemoryRNN")
+class MixedMemoryRNN(keras.layers.Layer):
+    def __init__(self, rnn_cell, forget_gate_bias=1.0, **kwargs):
+        super().__init__(**kwargs)
+
+        self.rnn_cell = rnn_cell
+        self.forget_gate_bias = forget_gate_bias
+
+    @property
+    def state_size(self):
+        return [self.flat_size, self.rnn_cell.state_size]
+
+    @property
+    def flat_size(self):
+        if isinstance(self.rnn_cell.state_size, int):
+            return self.rnn_cell.state_size
+        return sum(self.rnn_cell.state_size)
+
+    def build(self, sequences_shape, initial_state_shape=None):
+        input_dim = sequences_shape[-1]
+        if isinstance(sequences_shape[0], tuple) or isinstance(sequences_shape[0], keras.KerasTensor):
+            # Nested tuple
+            input_dim = sequences_shape[0][-1]
+
+        self.rnn_cell.build((None, self.flat_size))
+        self.input_kernel = self.add_weight(
+            shape=(input_dim, 4 * self.flat_size),
+            initializer="glorot_uniform",
+            name="input_kernel",
+        )
+        self.recurrent_kernel = self.add_weight(
+            shape=(self.flat_size, 4 * self.flat_size),
+            initializer="orthogonal",
+            name="recurrent_kernel",
+        )
+        self.bias = self.add_weight(
+            shape=(4 * self.flat_size,),
+            initializer=keras.initializers.Zeros(),
+            name="bias",
+        )
+
+        self.built = True
+
+    def call(self, sequences, initial_state=None, mask=None, training=False, **kwargs):
+        memory_state, ct_state = initial_state
+        flat_ct_state = keras.ops.concatenate([ct_state], axis=-1)
+        z = (
+            keras.ops.matmul(sequences, self.input_kernel)
+            + keras.ops.matmul(flat_ct_state, self.recurrent_kernel)
+            + self.bias
+        )
+
+        i, ig, fg, og = keras.ops.split(z, 4, axis=-1)
+
+        input_activation = keras.activations.tanh(i)
+        input_gate = keras.activations.sigmoid(ig)
+        forget_gate = keras.activations.sigmoid(fg + self.forget_gate_bias)
+        output_gate = keras.activations.sigmoid(og)
+
+        new_memory_state = memory_state * forget_gate + input_activation * input_gate
+        ct_input = keras.activations.tanh(new_memory_state) * output_gate # LSTM output = ODE input
+
+        if (isinstance(sequences, tuple) or isinstance(sequences, list)) and len(sequences) > 1:
+            # Input is a tuple -> Ct cell input should also be a tuple
+            ct_input = (ct_input,) + sequences[1:]
+
+        # Implementation choice on how to parametrize ODE component
+        if (not isinstance(ct_state, tuple)) and (not isinstance(ct_state, list)):
+            ct_state = [ct_state]
+
+        ct_output, new_ct_state = self.rnn_cell(ct_input, ct_state, **kwargs)
+
+        return ct_output, [new_memory_state, new_ct_state]
+
+    def get_config(self):
+        serialized = {
+            # Serialize the wrapped cell so that from_config can rebuild it
+            # via keras.layers.deserialize
+            "rnn_cell": keras.layers.serialize(self.rnn_cell),
+            "forget_gate_bias": self.forget_gate_bias,
+        }
+        return serialized
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        # Pop the cell config so it is not passed to __init__ twice
+        rnn_cell = keras.layers.deserialize(config.pop("rnn_cell"))
+        return cls(rnn_cell=rnn_cell, **config)
diff --git a/ncps/keras/wired_cfc_cell.py b/ncps/keras/wired_cfc_cell.py
new file mode 100644
index 00000000..89ed84f9
--- /dev/null
+++ b/ncps/keras/wired_cfc_cell.py
@@ -0,0 +1,167 @@
+# Copyright 2022 Mathias Lechner. All rights reserved
+from .cfc_cell import lecun_tanh, CfCCell
+
+import keras
+from ncps.wirings import wirings
+import numpy as np
+
+
+def split_tensor(input_tensor, num_or_size_splits, axis=0):
+    """
+    Splits the input tensor along the specified axis into multiple sub-tensors.
+
+    Args:
+        input_tensor (Tensor): The input tensor to be split.
+        num_or_size_splits (int or list/tuple): If an integer, the number of equal splits along the axis.
+            If a list/tuple, the sizes of each output tensor along the axis.
+        axis (int): The axis along which to split the tensor. Default is 0.
+
+    Returns:
+        A list of tensors resulting from splitting the input tensor.
+    """
+    input_shape = keras.ops.shape(input_tensor)
+    # Normalize negative axes so the slicing below indexes the right dimension
+    if axis < 0:
+        axis += len(input_shape)
+
+    if isinstance(num_or_size_splits, int):
+        split_sizes = [input_shape[axis] // num_or_size_splits] * num_or_size_splits
+    else:
+        split_sizes = num_or_size_splits
+
+    split_tensors = []
+    start = 0
+    for size in split_sizes:
+        end = start + size
+        # Slice along the requested axis and keep every other axis whole
+        index = [slice(None)] * len(input_shape)
+        index[axis] = slice(start, end)
+        split_tensors.append(input_tensor[tuple(index)])
+        start = end
+
+    return split_tensors
+
+
+@keras.utils.register_keras_serializable(package="ncps", name="WiredCfCCell")
+class WiredCfCCell(keras.layers.Layer):
+    def __init__(
+        self,
+        wiring,
+        fully_recurrent=True,
+        mode="default",
+        activation="lecun_tanh",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._wiring = wiring
+        allowed_modes = ["default", "pure", "no_gate"]
+        if mode not in allowed_modes:
+            raise ValueError(
+                "Unknown mode '{}', valid options are {}".format(
+                    mode, str(allowed_modes)
+                )
+            )
+        self.mode = mode
+        self.fully_recurrent = fully_recurrent
+        if activation == "lecun_tanh":
+            activation = lecun_tanh
+        self._activation = activation
+        self._cfc_layers = []
+
+    @property
+    def state_size(self):
+        return self._wiring.units
+
+    @property
+    def input_size(self):
+        return self._wiring.input_dim
+
+    def build(self, input_shape):
+        if isinstance(input_shape[0], tuple):
+            # Nested tuple -> First item represent feature dimension
+            input_dim = input_shape[0][-1]
+        else:
+            input_dim = input_shape[-1]
+
+        self._wiring.build(input_dim)
+        for i in range(self._wiring.num_layers):
+            layer_i_neurons = self._wiring.get_neurons_of_layer(i)
+            if i == 0:
+                input_sparsity = self._wiring.sensory_adjacency_matrix[
+                    :, layer_i_neurons
+                ]
+            else:
+                prev_layer_neurons = self._wiring.get_neurons_of_layer(i - 1)
+                input_sparsity = self._wiring.adjacency_matrix[:, layer_i_neurons]
+                input_sparsity = input_sparsity[prev_layer_neurons, :]
+            if self.fully_recurrent:
+                recurrent_sparsity = np.ones(
+                    (len(layer_i_neurons), len(layer_i_neurons)), dtype=np.int32
+                )
+            else:
+                # np.ix_ selects the full submatrix; plain fancy indexing with
+                # two lists would only pick out the diagonal entries
+                recurrent_sparsity = self._wiring.adjacency_matrix[
+                    np.ix_(layer_i_neurons, layer_i_neurons)
+                ]
+            cell = CfCCell(
+                len(layer_i_neurons),
+                input_sparsity,
+                recurrent_sparsity,
+                mode=self.mode,
+                activation=self._activation,
+                backbone_units=0,
+                backbone_layers=0,
+                backbone_dropout=0,
+            )
+
+            self._cfc_layers.append(cell)
+
+        self._layer_sizes = [l.units for l in self._cfc_layers]
+        self.built = True
+
+    def call(self, inputs, states, **kwargs):
+        if isinstance(inputs, (tuple, list)):
+            # Irregularly sampled mode
+            inputs, t = inputs
+            t = keras.ops.reshape(t, [-1, 1])
+            irregularly_sampled = True
+        else:
+            # Regularly sampled mode (elapsed time = 1 second)
+            t = 1.0
+            irregularly_sampled = False
+
+        states = split_tensor(states[0], self._layer_sizes, axis=-1)
+        assert len(states) == self._wiring.num_layers, \
+            f'Incompatible num of states [{len(states)}] and wiring layers [{self._wiring.num_layers}]'
+        new_hiddens = []
+        for i, cfc_layer in enumerate(self._cfc_layers):
+            if irregularly_sampled:
+                # `t` is a tensor here, so branch on the sampling mode instead
+                # of comparing the tensor against 1.0
+                cfc_layer._allow_non_tensor_positional_args = True
+                output, new_hidden = cfc_layer((inputs, t), [states[i]])
+            else:
+                output, new_hidden = cfc_layer(inputs, [states[i]], time=t)
+            new_hiddens.append(new_hidden[0])
+            inputs = output
+
+        assert len(new_hiddens) == self._wiring.num_layers, \
+            f'Internal error new_hiddens [{new_hiddens}] != num_layers [{self._wiring.num_layers}]'
+        if self._wiring.output_dim != output.shape[-1]:
+            output = output[:, 0: self._wiring.output_dim]
+
+        new_hiddens = keras.ops.concatenate(new_hiddens, axis=-1)
+        return output, new_hiddens
+
+    def get_config(self):
+        serialized = self._wiring.get_config()
+        serialized["mode"] = self.mode
+        serialized["activation"] = self._activation
+        serialized["backbone_units"] = None
+        serialized["backbone_layers"] = None
+        serialized["backbone_dropout"] = None
+        return serialized
+
+    @classmethod
+    def from_config(cls, config):
+        wiring = wirings.Wiring.from_config(config)
+        return cls(wiring=wiring, **config)
diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py
new file mode 100644
index 00000000..f342783c
--- /dev/null
+++ b/ncps/tests/test_keras.py
@@ -0,0 +1,305 @@
+# Copyright 2022 Mathias Lechner
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Run on CPU +os.environ["KERAS_BACKEND"] = "torch" +# os.environ["KERAS_BACKEND"] = "tensorflow" + +import keras +import numpy as np +import pytest +from ncps.keras import CfC, LTCCell, LTC +from ncps import wirings + + +def test_fc(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + print("data_y.shape: ", str(data_y.shape)) + fc_wiring = wirings.FullyConnected(8, 1) # 8 units, 1 of which is a motor neuron + ltc_cell = LTCCell(fc_wiring) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.RNN(ltc_cell, return_sequences=True), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_random(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + arch = wirings.Random(32, 1, sparsity_level=0.5) # 32 units, 1 motor neuron + ltc_cell = LTCCell(arch) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.RNN(ltc_cell, return_sequences=True), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_ncp(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ncp_wiring = wirings.NCP( + inter_neurons=20, # Number of inter neurons + command_neurons=10, # Number of command neurons + motor_neurons=1, # Number of motor neurons + sensory_fanout=4, # How many outgoing synapses has each sensory neuron + inter_fanout=5, # How many outgoing synapses has each inter neuron + recurrent_command_synapses=6, # Now many recurrent synapses are in the + # command neuron layer + motor_fanin=4, # How many incoming synapses has each motor neuron + ) + ltc_cell = LTCCell(ncp_wiring) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.RNN(ltc_cell, return_sequences=True), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_fit(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = 
np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + print("data_y.shape: ", str(data_y.shape)) + rnn = CfC(8, return_sequences=True) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + rnn, + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_mm_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + print("data_y.shape: ", str(data_y.shape)) + rnn = CfC(8, return_sequences=True, mixed_memory=True) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + rnn, + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_ncp_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ncp_wiring = wirings.NCP( + inter_neurons=20, # Number of inter neurons + command_neurons=10, # Number of command neurons + motor_neurons=1, # Number of motor neurons + sensory_fanout=4, # How many outgoing synapses has each sensory neuron + inter_fanout=5, # How many outgoing synapses has each inter neuron + recurrent_command_synapses=6, # Now many recurrent synapses are in the + # command neuron layer + motor_fanin=4, # How many incoming synapses has each motor neuron + ) + ltc = LTC(ncp_wiring, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + ltc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_auto_ncp_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ncp_wiring = wirings.AutoNCP(28, 1) + ltc = LTC(ncp_wiring, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + ltc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + assert ncp_wiring.synapse_count > 0 + assert ncp_wiring.sensory_synapse_count > 0 + + +def 
test_random_cfc(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + arch = wirings.Random(32, 1, sparsity_level=0.5) # 32 units, 1 motor neuron + cfc = CfC(arch, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + cfc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_ncp_cfc_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ncp_wiring = wirings.NCP( + inter_neurons=20, # Number of inter neurons + command_neurons=10, # Number of command neurons + motor_neurons=1, # Number of motor neurons + sensory_fanout=4, # How many outgoing synapses has each sensory neuron + inter_fanout=5, # How many outgoing synapses has each inter neuron + recurrent_command_synapses=6, # Now many recurrent synapses are in the + # command neuron layer + motor_fanin=4, # How many incoming synapses has each motor neuron + ) + ltc = CfC(ncp_wiring, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + ltc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_auto_ncp_cfc_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ncp_wiring = wirings.AutoNCP(32, 1) + ltc = CfC(ncp_wiring, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + ltc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_ltc_rnn(): + N = 48 # Length of the time-series + # Input feature is a sine and a cosine wave + data_x = np.stack( + [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], + axis=1, + ) + data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension + # Target output is a sine with double the frequency of the input signal + data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + ltc = LTC(32, return_sequences=True) + + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + ltc, + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), 
loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) diff --git a/requirements.txt b/requirements.txt index ad0d1924..3620399b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -numpy -tensorflow -torch -pytest +numpy~=1.26.4 +keras~=3.3.3 +pytest~=8.2.2 From c5cab479d2f1faea12a0fc5c8db7bb7a5f4e602c Mon Sep 17 00:00:00 2001 From: lettercode <59030475+lettercode@users.noreply.github.com> Date: Sat, 15 Jun 2024 21:05:16 +0200 Subject: [PATCH 2/4] Add some converted torch tests --- ncps/keras/cfc_cell.py | 2 +- ncps/tests/test_keras.py | 149 +++++++++++++++------------------------ ncps/wirings/wirings.py | 4 +- 3 files changed, 61 insertions(+), 94 deletions(-) diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py index 18b7a1c6..cb4851ba 100644 --- a/ncps/keras/cfc_cell.py +++ b/ncps/keras/cfc_cell.py @@ -40,7 +40,7 @@ def __init__( .. Note:: This is an RNNCell that process single time-steps. To get a full RNN that can process sequences, - see `ncps.tf.CfC` or wrap the cell with a `keras.layers.RNN `_. + see `ncps.keras.CfC` or wrap the cell with a `keras.layers.RNN `_. :param units: Number of hidden units diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py index f342783c..4df99d6d 100644 --- a/ncps/tests/test_keras.py +++ b/ncps/tests/test_keras.py @@ -19,20 +19,13 @@ import keras import numpy as np import pytest +import ncps from ncps.keras import CfC, LTCCell, LTC from ncps import wirings def test_fc(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() print("data_y.shape: ", str(data_y.shape)) fc_wiring = wirings.FullyConnected(8, 1) # 8 units, 1 of which is a motor neuron ltc_cell = LTCCell(fc_wiring) @@ -47,7 +40,7 @@ def test_fc(): model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) -def test_random(): +def prepare_test_data(): N = 48 # Length of the time-series # Input feature is a sine and a cosine wave data_x = np.stack( @@ -57,6 +50,11 @@ def test_random(): data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension # Target output is a sine with double the frequency of the input signal data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + return data_x, data_y + + +def test_random(): + data_x, data_y = prepare_test_data() arch = wirings.Random(32, 1, sparsity_level=0.5) # 32 units, 1 motor neuron ltc_cell = LTCCell(arch) @@ -71,15 +69,7 @@ def test_random(): def test_ncp(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() ncp_wiring = wirings.NCP( inter_neurons=20, # Number of inter neurons command_neurons=10, # Number of command neurons @@ -103,15 +93,7 @@ def test_ncp(): def test_fit(): - N = 48 # Length of the 
time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() print("data_y.shape: ", str(data_y.shape)) rnn = CfC(8, return_sequences=True) model = keras.models.Sequential( @@ -126,15 +108,7 @@ def test_fit(): def test_mm_rnn(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() print("data_y.shape: ", str(data_y.shape)) rnn = CfC(8, return_sequences=True, mixed_memory=True) model = keras.models.Sequential( @@ -149,15 +123,7 @@ def test_mm_rnn(): def test_ncp_rnn(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() ncp_wiring = wirings.NCP( inter_neurons=20, # Number of inter neurons command_neurons=10, # Number of command neurons @@ -181,15 +147,7 @@ def test_ncp_rnn(): def test_auto_ncp_rnn(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() ncp_wiring = wirings.AutoNCP(28, 1) ltc = LTC(ncp_wiring, return_sequences=True) @@ -206,15 +164,7 @@ def test_auto_ncp_rnn(): def test_random_cfc(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a sine with double the frequency of the input signal - data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) + data_x, data_y = prepare_test_data() arch = wirings.Random(32, 1, sparsity_level=0.5) # 32 units, 1 motor neuron cfc = CfC(arch, return_sequences=True) @@ -229,15 +179,7 @@ def test_random_cfc(): def test_ncp_cfc_rnn(): - N = 48 # Length of the time-series - # Input feature is a sine and a cosine wave - data_x = np.stack( - [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], - axis=1, - ) - data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension - # Target output is a 
sine with double the frequency of the input signal
-    data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
+    data_x, data_y = prepare_test_data()
     ncp_wiring = wirings.NCP(
         inter_neurons=20, # Number of inter neurons
         command_neurons=10, # Number of command neurons
@@ -261,15 +203,7 @@ def test_ncp_cfc_rnn():
 
 
 def test_auto_ncp_cfc_rnn():
-    N = 48 # Length of the time-series
-    # Input feature is a sine and a cosine wave
-    data_x = np.stack(
-        [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))],
-        axis=1,
-    )
-    data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension
-    # Target output is a sine with double the frequency of the input signal
-    data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
+    data_x, data_y = prepare_test_data()
     ncp_wiring = wirings.AutoNCP(32, 1)
     ltc = CfC(ncp_wiring, return_sequences=True)
@@ -284,15 +218,7 @@ def test_auto_ncp_cfc_rnn():
 
 
 def test_ltc_rnn():
-    N = 48 # Length of the time-series
-    # Input feature is a sine and a cosine wave
-    data_x = np.stack(
-        [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))],
-        axis=1,
-    )
-    data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension
-    # Target output is a sine with double the frequency of the input signal
-    data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
+    data_x, data_y = prepare_test_data()
     ltc = LTC(32, return_sequences=True)
 
     model = keras.models.Sequential(
@@ -303,3 +229,44 @@ def test_ltc_rnn():
     )
     model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
     model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)
+
+def test_ncps():
+    input_size = 8
+
+    wiring = ncps.wirings.FullyConnected(8, 4) # 8 units, 4 motor neurons
+    ltc_cell = LTCCell(wiring)
+    data = keras.random.normal([3, input_size])
+    hx = keras.ops.zeros([3, wiring.units])
+    output, hx = ltc_cell(data, hx)
+    assert output.size() == (3, 4)
+    assert hx[0].size() == (3, wiring.units)
+
+def test_ncp_sizes():
+    wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
+    rnn = LTC(wiring)
+    data = keras.random.normal([5, 3, 8])
+    output = rnn(data)
+    assert wiring.synapse_count > 0
+    assert wiring.sensory_synapse_count > 0
+    assert output.size() == (5, 8)
+
+def test_auto_ncp():
+    wiring = ncps.wirings.AutoNCP(16, 4)
+    rnn = LTC(wiring)
+    data = keras.random.normal([5, 3, 8])
+    output = rnn(data)
+    assert output.size() == (5, 4)
+
+def test_ncp_cfc():
+    wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
+    rnn = CfC(wiring)
+    data = keras.random.normal([5, 3, 8])
+    output = rnn(data)
+    assert output.size() == (5, 8)
+
+def test_auto_ncp_cfc():
+    wiring = ncps.wirings.AutoNCP(28, 10)
+    rnn = CfC(wiring)
+    data = keras.random.normal([5, 3, 8])
+    output = rnn(data)
+    assert output.size() == (5, 10)
diff --git a/ncps/wirings/wirings.py b/ncps/wirings/wirings.py
index 1025c9d2..4ebbebd7 100644
--- a/ncps/wirings/wirings.py
+++ b/ncps/wirings/wirings.py
@@ -532,7 +532,7 @@ def _build_command__to_motor_layer(self):
             polarity = self._rng.choice([-1, 1])
             self.add_synapse(src, dest, polarity)
 
-        # If it happens that some commandneurons are not connected, connect them now
+        # If it happens that some command neurons are not connected, connect them now
         mean_command_fanout = int(
             self._num_motor_neurons * self._motor_fanin / self._num_command_neurons
         )
@@ -597,4 +597,4 @@ def __init__(
             recurrent_command_synapses,
             motor_fanin,
             seed=seed,
-        )
\ No newline at end of file
+        )
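The two patches above leave CfC with two constructor modes: an integer unit count builds a dense CfCCell (optionally with backbone layers), while a Wiring object builds a WiredCfCCell and rejects the backbone arguments; mixed_memory=True wraps either cell in the MixedMemoryRNN memory. A minimal usage sketch of both modes (the AutoNCP(28, 4) wiring, the layer sizes, and the shapes are illustrative, not taken from the tests):

import keras
from ncps.keras import CfC
from ncps.wirings import AutoNCP

x = keras.random.uniform((2, 10, 20))  # (batch, time, features)

# Dense mode: integer units enable the backbone arguments
dense_rnn = CfC(50, backbone_units=64, backbone_layers=2, mixed_memory=True)
y1 = dense_rnn(x)  # (2, 50): last time-step only

# Wired mode: backbone arguments must stay unset
wired_rnn = CfC(AutoNCP(28, 4), return_sequences=True)
y2 = wired_rnn(x)  # (2, 10, 4): one output per motor neuron and time-step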
From 1335cb892e2cde2c28f6d99cb70fb40172ad0995 Mon Sep 17 00:00:00 2001
From: lettercode <59030475+lettercode@users.noreply.github.com>
Date: Sat, 15 Jun 2024 21:10:25 +0200
Subject: [PATCH 3/4] Update GitHub Actions workflow

Also fix the runs for the TensorFlow backend
---
 .github/workflows/python-test.yml | 48 +++++++++++++++++++++++++------
 ncps/tests/test_keras.py          | 21 ++++++++------
 2 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
index bc1ef9cd..fa54bb2c 100644
--- a/.github/workflows/python-test.yml
+++ b/.github/workflows/python-test.yml
@@ -9,16 +9,19 @@ permissions:
   contents: read
 
 jobs:
-  build:
-
+  build_pytorch_backend:
     runs-on: ubuntu-latest
+    container:
+      image: pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
+      env:
+        KERAS_BACKEND: torch
+      volumes:
+        - my_docker_volume:/volume_mount
+
 
     steps:
     - uses: actions/checkout@v3
-    - name: Set up Python 3.10
-      uses: actions/setup-python@v3
-      with:
-        python-version: "3.10"
+
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
@@ -35,6 +38,35 @@ jobs:
         echo "PYTHONPATH=." >> $GITHUB_ENV
     - name: Test with pytest
       run: |
-        pytest ncps/tests/test_tf.py
-        pytest ncps/tests/test_torch.py
+        pytest ncps/tests/test_keras.py
+
+  build_tensorflow_backend:
+    runs-on: ubuntu-latest
+
+    container:
+      image: tensorflow/tensorflow:2.16.1
+      env:
+        KERAS_BACKEND: tensorflow
+      volumes:
+        - my_docker_volume:/volume_mount
+
+    steps:
+    - uses: actions/checkout@v3
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: set pythonpath
+      run: |
+        echo "PYTHONPATH=." >> $GITHUB_ENV
+    - name: Test with pytest
+      run: |
+        pytest ncps/tests/test_keras.py
diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py
index 4df99d6d..0520d3d1 100644
--- a/ncps/tests/test_keras.py
+++ b/ncps/tests/test_keras.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
+# import os
 # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Run on CPU
-os.environ["KERAS_BACKEND"] = "torch"
+# os.environ["KERAS_BACKEND"] = "torch"
 # os.environ["KERAS_BACKEND"] = "tensorflow"
 
 import keras
@@ -230,6 +230,7 @@ def test_ltc_rnn():
     model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
     model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)
 
+
 def test_ncps():
     input_size = 8
 
@@ -238,8 +239,9 @@ def test_ncps():
     data = keras.random.normal([3, input_size])
     hx = keras.ops.zeros([3, wiring.units])
     output, hx = ltc_cell(data, hx)
-    assert output.size() == (3, 4)
-    assert hx[0].size() == (3, wiring.units)
+    assert output.shape == (3, 4)
+    assert hx[0].shape == (3, wiring.units)
+
 
 def test_ncp_sizes():
     wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
@@ -248,25 +250,28 @@ def test_ncp_sizes():
     output = rnn(data)
     assert wiring.synapse_count > 0
     assert wiring.sensory_synapse_count > 0
-    assert output.size() == (5, 8)
+    assert output.shape == (5, 8)
+
 
 def test_auto_ncp():
     wiring = ncps.wirings.AutoNCP(16, 4)
     rnn = LTC(wiring)
     data = keras.random.normal([5, 3, 8])
     output = rnn(data)
-    assert output.size() == (5, 4)
+    assert output.shape == (5, 4)
+
 
 def test_ncp_cfc():
     wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
     rnn = CfC(wiring)
     data = keras.random.normal([5, 3, 8])
     output = rnn(data)
-    assert output.size() == (5, 8)
+    assert output.shape == (5, 8)
+
 
 def test_auto_ncp_cfc():
     wiring = ncps.wirings.AutoNCP(28, 10)
     rnn = CfC(wiring)
     data = keras.random.normal([5, 3, 8])
     output = rnn(data)
-    assert output.size() == (5, 10)
+    assert output.shape == (5, 10)

From e39620ebb75062666f6d416574e785966f460f6b Mon Sep 17 00:00:00 2001
From: lettercode <59030475+lettercode@users.noreply.github.com>
Date: Sun, 16 Jun 2024 00:03:04 +0200
Subject: [PATCH 4/4] Update code doc

---
 ncps/keras/cfc.py      | 4 ++--
 ncps/keras/ltc.py      | 6 +++---
 ncps/keras/ltc_cell.py | 4 ++--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/ncps/keras/cfc.py b/ncps/keras/cfc.py
index ac4ec8e8..4ba7fe50 100644
--- a/ncps/keras/cfc.py
+++ b/ncps/keras/cfc.py
@@ -43,11 +43,11 @@ def __init__(
 
         Examples::
 
-        >>> from ncps.tf import CfC
+        >>> from ncps.keras import CfC
         >>>
         >>> rnn = CfC(50)
         >>> x = keras.random.uniform((2, 10, 20))  # (B,L,C)
-        >>> y = keras.layers.RNN(x)
+        >>> y = rnn(x)
 
         :param units: Number of hidden units
         :param mixed_memory: Whether to augment the RNN by a `memory-cell `_ to help learn long-term dependencies in the data (default False)
diff --git a/ncps/keras/ltc.py b/ncps/keras/ltc.py
index bd02699a..46caa472 100644
--- a/ncps/keras/ltc.py
+++ b/ncps/keras/ltc.py
@@ -43,7 +43,7 @@ def __init__(
         >>> from ncps.keras import LTC
         >>>
         >>> rnn = LTC(50)
-        >>> x = tf.random.uniform((2, 10, 20))  # (B,L,C)
+        >>> x = keras.random.uniform((2, 10, 20))  # (B,L,C)
         >>> y = rnn(x)
 
         .. Note::
@@ -51,12 +51,12 @@ def __init__(
 
         Examples::
 
-        >>> from ncps.tf import LTC
+        >>> from ncps.keras import LTC
         >>> from ncps.wirings import NCP
         >>>
         >>> wiring = NCP(10, 10, 8, 6, 6, 4, 4)
         >>> rnn = LTC(wiring)
-        >>> x = tf.random.uniform((2, 10, 20))  # (B,L,C)
+        >>> x = keras.random.uniform((2, 10, 20))  # (B,L,C)
         >>> y = rnn(x)
 
         :param units: Wiring (ncps.wirings.Wiring instance) or integer representing the number of (fully-connected) hidden units
diff --git a/ncps/keras/ltc_cell.py b/ncps/keras/ltc_cell.py
index 47a90bb5..268e3089 100644
--- a/ncps/keras/ltc_cell.py
+++ b/ncps/keras/ltc_cell.py
@@ -36,12 +36,12 @@ def __init__(
         .. Note::
             This is an RNNCell that process single time-steps.
             To get a full RNN that can process sequences,
-            see `ncps.tf.LTC` or wrap the cell with a `keras.layers.RNN `_.
+            see `ncps.keras.LTC` or wrap the cell with a `keras.layers.RNN `_.
 
         Examples::
 
         >>> import ncps
-        >>> from ncps.tf import LTCCell
+        >>> from ncps.keras import LTCCell
         >>>
         >>> wiring = ncps.wirings.Random(16, output_dim=2, sparsity_level=0.5)
        >>> cell = LTCCell(wiring)
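
The corrected LTCCell docstring points at keras.layers.RNN for processing full sequences. A minimal usage sketch of that pattern, combining the docstring's own example with the RNN wrapper (batch, time, and feature sizes are chosen purely for illustration; the output feature dimension follows the wiring's output_dim):

    import keras
    import ncps
    from ncps.keras import LTCCell

    # LTCCell handles a single time-step; keras.layers.RNN unrolls it over time,
    # as the updated docstring suggests.
    wiring = ncps.wirings.Random(16, output_dim=2, sparsity_level=0.5)
    cell = LTCCell(wiring)
    rnn = keras.layers.RNN(cell, return_sequences=True)

    x = keras.random.uniform((2, 10, 20))  # (batch, time, features)
    y = rnn(x)  # -> (2, 10, 2): one output per step, output_dim features each
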