From 89c15c1857b7a091d04f33d7d16b85a4bd2c48e5 Mon Sep 17 00:00:00 2001 From: lettercode <59030475+lettercode@users.noreply.github.com> Date: Sun, 21 Jul 2024 01:44:02 +0200 Subject: [PATCH 1/4] Support Bidirectional layers --- ncps/keras/cfc.py | 67 +++++++++++++++++++++-------- ncps/keras/cfc_cell.py | 52 +++++++++++----------- ncps/keras/ltc.py | 64 ++++++++++++++++++--------- ncps/keras/ltc_cell.py | 19 ++++----- ncps/keras/wired_cfc_cell.py | 65 ++++++++++++++-------------- ncps/tests/test_keras.py | 83 ++++++++++++++++++++++++++++++++++++ ncps/wirings/wirings.py | 12 +++--- 7 files changed, 252 insertions(+), 110 deletions(-) diff --git a/ncps/keras/cfc.py b/ncps/keras/cfc.py index 4ba7fe50..335508b5 100644 --- a/ncps/keras/cfc.py +++ b/ncps/keras/cfc.py @@ -23,21 +23,23 @@ @keras.utils.register_keras_serializable(package="ncps", name="CfC") class CfC(keras.layers.RNN): def __init__( - self, - units: Union[int, ncps.wirings.Wiring], - mixed_memory: bool = False, - mode: str = "default", - activation: str = "lecun_tanh", - backbone_units: int = None, - backbone_layers: int = None, - backbone_dropout: float = None, - return_sequences: bool = False, - return_state: bool = False, - go_backwards: bool = False, - stateful: bool = False, - unroll: bool = False, - time_major: bool = False, - **kwargs, + self, + units: Union[int, ncps.wirings.Wiring], + mixed_memory: bool = False, + mode: str = "default", + activation: str = "lecun_tanh", + backbone_units: int = None, + backbone_layers: int = None, + backbone_dropout: float = None, + sparsity_mask: keras.layers.Layer = None, + fully_recurrent: bool = True, + return_sequences: bool = False, + return_state: bool = False, + go_backwards: bool = False, + stateful: bool = False, + unroll: bool = False, + zero_output_for_mask: bool = False, + **kwargs, ): """Applies a `Closed-form Continuous-time `_ RNN to an input sequence. @@ -56,12 +58,18 @@ def __init__( :param backbone_units: Number of hidden units in the backbone layer (default 128) :param backbone_layers: Number of backbone layers (default 1) :param backbone_dropout: Dropout rate in the backbone layers (default 0) + :param sparsity_mask: + :param fully_recurrent: Whether to apply a fully-connected sparsity_mask or use the adjacency_matrix. Evaluated only for WiredCfCCell. (default True) :param return_sequences: Whether to return the full sequence or just the last output (default False) :param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False) :param go_backwards: If True, the input sequence will be process from back to the front (default False) :param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False) :param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False) - :param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False) + :param zero_output_for_mask: Whether the output should use zeros for the masked timesteps. (default False) + Note that this field is only used when `return_sequences` is `True` and `mask` is provided. + It can be useful if you want to reuse the raw output sequence of + the RNN without interference from the masked timesteps, e.g., + merging bidirectional RNNs. 
:param kwargs: """ @@ -72,7 +80,7 @@ def __init__( raise ValueError(f"Cannot use backbone_layers in wired mode") if backbone_dropout is not None: raise ValueError(f"Cannot use backbone_dropout in wired mode") - cell = WiredCfCCell(units, mode=mode, activation=activation) + cell = WiredCfCCell(units, mode=mode, activation=activation, fully_recurrent=fully_recurrent) else: backbone_units = 128 if backbone_units is None else backbone_units backbone_layers = 1 if backbone_layers is None else backbone_layers @@ -84,6 +92,7 @@ def __init__( backbone_units=backbone_units, backbone_layers=backbone_layers, backbone_dropout=backbone_dropout, + sparsity_mask=sparsity_mask, ) if mixed_memory: cell = MixedMemoryRNN(cell) @@ -94,6 +103,28 @@ def __init__( go_backwards, stateful, unroll, - time_major, + zero_output_for_mask, **kwargs, ) + + def get_config(self): + is_mixed_memory = isinstance(self.cell, MixedMemoryRNN) + cell: CfCCell | WiredCfCCell = self.cell.rnn_cell if is_mixed_memory else self.cell + cell_config = cell.get_config() + config = super(CfC, self).get_config() + config["units"] = cell.wiring if isinstance(cell, WiredCfCCell) else cell.units + config["mixed_memory"] = is_mixed_memory + config["fully_recurrent"] = cell.fully_recurrent if isinstance(cell, WiredCfCCell) else True # If not WiredCfc it's ignored + return {**cell_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + # The following parameters are recreated by the LTC constructor + del config["cell"] + if "wiring" in config: + del config["wiring"] + units = ncps.wirings.Wiring.from_config(config["units"]["config"]) + else: + units = config["units"] + del config["units"] + return cls(units, **config) diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py index cb4851ba..0f8d912b 100644 --- a/ncps/keras/cfc_cell.py +++ b/ncps/keras/cfc_cell.py @@ -12,27 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import keras -import numpy as np # LeCun improved tanh activation # http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf +@keras.utils.register_keras_serializable(package="", name="lecun_tanh") def lecun_tanh(x): return 1.7159 * keras.activations.tanh(0.666 * x) +# Register the custom activation function +from keras.src.activations import ALL_OBJECTS_DICT +ALL_OBJECTS_DICT["lecun_tanh"] = lecun_tanh + + @keras.utils.register_keras_serializable(package="ncps", name="CfCCell") class CfCCell(keras.layers.Layer): def __init__( self, units, - input_sparsity=None, - recurrent_sparsity=None, mode="default", activation="lecun_tanh", backbone_units=128, backbone_layers=1, backbone_dropout=0.1, + sparsity_mask=None, **kwargs, ): """A `Closed-form Continuous-time `_ cell. 
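
Registering lecun_tanh in Keras' activation table is what lets the plain string name resolve through the standard lookup that CfCCell switches to below (keras.activations.get). A minimal sketch of the behavior this relies on — Keras 3's private registry and the module path are taken from this diff, so treat the details as assumptions:

import keras
from ncps.keras.cfc_cell import lecun_tanh  # importing also runs the registration above

# The registered name resolves to the function object itself, so layer configs
# can carry the plain string "lecun_tanh".
assert keras.activations.get("lecun_tanh") is lecun_tanh
y = lecun_tanh(keras.ops.convert_to_tensor([0.0, 1.0]))  # 1.7159 * tanh(0.666 * x)
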
@@ -55,35 +59,18 @@ def __init__( """ super().__init__(**kwargs) self.units = units - self.sparsity_mask = None - if input_sparsity is not None or recurrent_sparsity is not None: + self.sparsity_mask = sparsity_mask + if sparsity_mask is not None: # No backbone is allowed if backbone_units > 0: - raise ValueError( - "If sparsity of a Cfc cell is set, then no backbone is allowed" - ) - # Both need to be set - if input_sparsity is None or recurrent_sparsity is None: - raise ValueError( - "If sparsity of a Cfc cell is set, then both input and recurrent sparsity needs to be defined" - ) - self.sparsity_mask = keras.ops.convert_to_tensor( - np.concatenate([input_sparsity, recurrent_sparsity], axis=0), - dtype="float32", - ) + raise ValueError("If sparsity of a CfC cell is set, then no backbone is allowed") allowed_modes = ["default", "pure", "no_gate"] if mode not in allowed_modes: - raise ValueError( - "Unknown mode '{}', valid options are {}".format( - mode, str(allowed_modes) - ) - ) + raise ValueError(f"Unknown mode '{mode}', valid options are {str(allowed_modes)}") self.mode = mode self.backbone_fn = None - if activation == "lecun_tanh": - activation = lecun_tanh - self._activation = activation + self._activation = keras.activations.get(activation) self._backbone_units = backbone_units self._backbone_layers = backbone_layers self._backbone_dropout = backbone_dropout @@ -213,3 +200,18 @@ def call(self, inputs, states, **kwargs): new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2 return new_hidden, [new_hidden] + + def get_config(self): + config = super(CfCCell, self).get_config() + config["units"] = self.units + config["mode"] = self.mode + config["activation"] = self._activation + config["backbone_units"] = self._backbone_units + config["backbone_layers"] = self._backbone_layers + config["backbone_dropout"] = self._backbone_dropout + config["sparsity_mask"] = self.sparsity_mask + return config + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/ncps/keras/ltc.py b/ncps/keras/ltc.py index 46caa472..6598200d 100644 --- a/ncps/keras/ltc.py +++ b/ncps/keras/ltc.py @@ -19,22 +19,24 @@ @keras.utils.register_keras_serializable(package="ncps", name="LTC") class LTC(keras.layers.RNN): + name = "LTC" + def __init__( - self, - units, - mixed_memory: bool = False, - input_mapping="affine", - output_mapping="affine", - ode_unfolds=6, - epsilon=1e-8, - initialization_ranges=None, - return_sequences: bool = False, - return_state: bool = False, - go_backwards: bool = False, - stateful: bool = False, - unroll: bool = False, - time_major: bool = False, - **kwargs, + self, + units, + mixed_memory: bool = False, + input_mapping="affine", + output_mapping="affine", + ode_unfolds=6, + epsilon=1e-8, + initialization_ranges=None, + return_sequences: bool = False, + return_state: bool = False, + go_backwards: bool = False, + stateful: bool = False, + unroll: bool = False, + zero_output_for_mask: bool = False, + **kwargs, ): """Applies a `Liquid time-constant (LTC) `_ RNN to an input sequence. 
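
The separate input_sparsity/recurrent_sparsity arguments collapse into a single sparsity_mask whose rows stack the input block over the recurrent block, so it lines up with the concatenated [inputs, state] features the cell multiplies against its kernels. A small sketch of that shape contract, inferred from the WiredCfCCell.build changes further down in this patch:

import numpy as np
import keras
from ncps.keras import CfCCell

input_dim, units = 3, 4
input_sparsity = np.ones((input_dim, units), dtype=np.int32)  # sensory -> neuron synapses
recurrent_sparsity = np.eye(units, dtype=np.int32)            # neuron -> neuron synapses
sparsity_mask = keras.ops.convert_to_tensor(
    np.concatenate([input_sparsity, recurrent_sparsity], axis=0),  # (input_dim + units, units)
    dtype="float32",
)
# The backbone must be disabled when a mask is set, per the ValueError above.
cell = CfCCell(units, backbone_units=0, backbone_layers=0, backbone_dropout=0,
               sparsity_mask=sparsity_mask)
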
@@ -71,17 +73,21 @@ def __init__( :param go_backwards: If True, the input sequence will be process from back to the front (default False) :param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False) :param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False) - :param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False) + :param zero_output_for_mask: Whether the output should use zeros for the masked timesteps. (default False) + Note that this field is only used when `return_sequences` is `True` and `mask` is provided. + It can be useful if you want to reuse the raw output sequence of + the RNN without interference from the masked timesteps, e.g., + merging bidirectional RNNs. :param kwargs: """ if isinstance(units, ncps.wirings.Wiring): - wiring = units + self.wiring = units else: - wiring = ncps.wirings.FullyConnected(units) + self.wiring = ncps.wirings.FullyConnected(units) cell = LTCCell( - wiring=wiring, + wiring=self.wiring, input_mapping=input_mapping, output_mapping=output_mapping, ode_unfolds=ode_unfolds, @@ -98,6 +104,24 @@ def __init__( go_backwards, stateful, unroll, - time_major, + zero_output_for_mask, **kwargs, ) + + def get_config(self): + is_mixed_memory = isinstance(self.cell, MixedMemoryRNN) + cell: LTCCell = self.cell.rnn_cell if is_mixed_memory else self.cell + cell_config = cell.get_config() + config = super(LTC, self).get_config() + config["units"] = self.wiring + config["mixed_memory"] = is_mixed_memory + return {**cell_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + # The following parameters are recreated by the LTC constructor + del config["cell"] + del config["wiring"] + units = ncps.wirings.Wiring.from_config(config["units"]["config"]) + del config["units"] + return cls(units, **config) diff --git a/ncps/keras/ltc_cell.py b/ncps/keras/ltc_cell.py index 268e3089..b81631b8 100644 --- a/ncps/keras/ltc_cell.py +++ b/ncps/keras/ltc_cell.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
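
Dropping time_major in favor of zero_output_for_mask matches Keras 3, whose keras.layers.RNN no longer accepts time_major and takes zero_output_for_mask in that positional slot. The masked bidirectional-merge case the docstring mentions, as a hedged sketch (not part of this patch's test suite; assumes standard Keras masking semantics):

import numpy as np
import keras
from ncps.keras import CfC  # LTC accepts the same flag

x = np.random.randn(2, 5, 3).astype("float32")
x[1, 3:] = 0.0  # all-zero timesteps are masked out below

inputs = keras.Input((5, 3))
h = keras.layers.Masking(mask_value=0.0)(inputs)
# zero_output_for_mask=True writes exact zeros at masked steps instead of
# repeating the last valid output, so forward and backward sequences can be
# merged without stale values bleeding in.
out = keras.layers.Bidirectional(
    CfC(8, return_sequences=True, zero_output_for_mask=True)
)(h)
y = keras.Model(inputs, out)(x)  # y[1, 3:] should be all zeros
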
-from ncps import wirings -import numpy as np import keras +import numpy as np @keras.utils.register_keras_serializable(package="ncps", name="LTCCell") @@ -333,14 +332,14 @@ def call(self, sequence, states, training=False): return outputs, [next_state] def get_config(self): - seralized = self._wiring.get_config() - seralized["input_mapping"] = self._input_mapping - seralized["output_mapping"] = self._output_mapping - seralized["ode_unfolds"] = self._ode_unfolds - seralized["epsilon"] = self._epsilon - return seralized + config = super(LTCCell, self).get_config() + config["wiring"] = self._wiring.get_config() + config["input_mapping"] = self._input_mapping + config["output_mapping"] = self._output_mapping + config["ode_unfolds"] = self._ode_unfolds + config["epsilon"] = self._epsilon + return config @classmethod def from_config(cls, config): - wiring = wirings.Wiring.from_config(config) - return cls(wiring=wiring, **config) + return cls(**config) diff --git a/ncps/keras/wired_cfc_cell.py b/ncps/keras/wired_cfc_cell.py index 89ed84f9..7c311596 100644 --- a/ncps/keras/wired_cfc_cell.py +++ b/ncps/keras/wired_cfc_cell.py @@ -44,14 +44,14 @@ def split_tensor(input_tensor, num_or_size_splits, axis=0): class WiredCfCCell(keras.layers.Layer): def __init__( self, - wiring, + wiring: wirings.Wiring, fully_recurrent=True, mode="default", activation="lecun_tanh", **kwargs, ): super().__init__(**kwargs) - self._wiring = wiring + self.wiring = wiring allowed_modes = ["default", "pure", "no_gate"] if mode not in allowed_modes: raise ValueError( @@ -68,15 +68,15 @@ def __init__( @property def state_size(self): - return self._wiring.units - # return [ - # len(self._wiring.get_neurons_of_layer(i)) - # for i in range(self._wiring.num_layers) - # ] + return self.wiring.units @property def input_size(self): - return self._wiring.input_dim + return self.wiring.input_dim + + @property + def output_size(self): + return self.wiring.output_dim def build(self, input_shape): if isinstance(input_shape[0], tuple): @@ -85,34 +85,37 @@ def build(self, input_shape): else: input_dim = input_shape[-1] - self._wiring.build(input_dim) - for i in range(self._wiring.num_layers): - layer_i_neurons = self._wiring.get_neurons_of_layer(i) + self.wiring.build(input_dim) + for i in range(self.wiring.num_layers): + layer_i_neurons = self.wiring.get_neurons_of_layer(i) if i == 0: - input_sparsity = self._wiring.sensory_adjacency_matrix[ + input_sparsity = self.wiring.sensory_adjacency_matrix[ :, layer_i_neurons ] else: - prev_layer_neurons = self._wiring.get_neurons_of_layer(i - 1) - input_sparsity = self._wiring.adjacency_matrix[:, layer_i_neurons] + prev_layer_neurons = self.wiring.get_neurons_of_layer(i - 1) + input_sparsity = self.wiring.adjacency_matrix[:, layer_i_neurons] input_sparsity = input_sparsity[prev_layer_neurons, :] if self.fully_recurrent: recurrent_sparsity = np.ones( (len(layer_i_neurons), len(layer_i_neurons)), dtype=np.int32 ) else: - recurrent_sparsity = self._wiring.adjacency_matrix[ + recurrent_sparsity = self.wiring.adjacency_matrix[ layer_i_neurons, layer_i_neurons ] + sparsity_mask = keras.ops.convert_to_tensor( + np.concatenate([input_sparsity, recurrent_sparsity], axis=0), + dtype="float32", + ) cell = CfCCell( len(layer_i_neurons), - input_sparsity, - recurrent_sparsity, mode=self.mode, activation=self._activation, backbone_units=0, backbone_layers=0, backbone_dropout=0, + sparsity_mask=sparsity_mask, ) cell_in_shape = (None, input_sparsity.shape[0]) @@ -132,8 +135,8 @@ def call(self, inputs, states, 
**kwargs): t = 1.0 states = split_tensor(states[0], self._layer_sizes, axis=-1) - assert len(states) == self._wiring.num_layers, \ - f'Incompatible num of states [{len(states)}] and wiring layers [{self._wiring.num_layers}]' + assert len(states) == self.wiring.num_layers, \ + f'Incompatible num of states [{len(states)}] and wiring layers [{self.wiring.num_layers}]' new_hiddens = [] for i, cfc_layer in enumerate(self._cfc_layers): if t == 1.0: @@ -144,24 +147,22 @@ def call(self, inputs, states, **kwargs): new_hiddens.append(new_hidden[0]) inputs = output - assert len(new_hiddens) == self._wiring.num_layers, \ - f'Internal error new_hiddens [{new_hiddens}] != num_layers [{self._wiring.num_layers}]' - if self._wiring.output_dim != output.shape[-1]: - output = output[:, 0: self._wiring.output_dim] + assert len(new_hiddens) == self.wiring.num_layers, \ + f'Internal error new_hiddens [{new_hiddens}] != num_layers [{self.wiring.num_layers}]' + if self.wiring.output_dim != output.shape[-1]: + output = output[:, 0: self.wiring.output_dim] new_hiddens = keras.ops.concatenate(new_hiddens, axis=-1) return output, new_hiddens def get_config(self): - seralized = self._wiring.get_config() - seralized["mode"] = self.mode - seralized["activation"] = self._activation - seralized["backbone_units"] = None - seralized["backbone_layers"] = None - seralized["backbone_dropout"] = None - return seralized + config = super(WiredCfCCell, self).get_config() + config["wiring"] = self.wiring + config["fully_recurrent"] = self.fully_recurrent + config["mode"] = self.mode + config["activation"] = self._activation + return config @classmethod def from_config(cls, config): - wiring = wirings.Wiring.from_config(config) - return cls(wiring=wiring, **config) + return cls(**config) diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py index 0520d3d1..250a1015 100644 --- a/ncps/tests/test_keras.py +++ b/ncps/tests/test_keras.py @@ -275,3 +275,86 @@ def test_auto_ncp_cfc(): data = keras.random.normal([5, 3, 8]) output = rnn(data) assert output.shape == (5, 10) + + +def test_bidirectional_ltc(): + rnn = keras.layers.Bidirectional(LTC(28)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 28 * 2) + + +def test_bidirectional_ltc_mixed_memory(): + rnn = keras.layers.Bidirectional(LTC(28, mixed_memory=True)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 28 * 2) + + +def test_bidirectional_auto_ncp_ltc(): + wiring = ncps.wirings.AutoNCP(28, 10) + rnn = keras.layers.Bidirectional(LTC(wiring)) + data = keras.random.normal([5, 3, 8]) + + output = rnn(data) + assert output.shape == (5, 10 * 2) + + +def test_fit_bidirectional_auto_ncp_ltc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(wiring)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_bidirectional_cfc(): + rnn = keras.layers.Bidirectional(CfC(28)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 28 * 2) + + +def test_bidirectional_cfc_mixed_memory(): + rnn = keras.layers.Bidirectional(CfC(28, mixed_memory=True)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 28 * 2) + + 
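
The shape assertions above all use Bidirectional's default merge_mode="concat", which doubles the feature axis; the other merge modes come along for free once the layer is Bidirectional-compatible. A quick contract sketch on the same toy shapes (merge_mode="sum" needs equal forward/backward widths):

import keras
from ncps.keras import CfC

data = keras.random.normal([5, 3, 8])
concat = keras.layers.Bidirectional(CfC(28))(data)                       # (5, 56)
summed = keras.layers.Bidirectional(CfC(28), merge_mode="sum")(data)     # (5, 28)
seqs = keras.layers.Bidirectional(CfC(28, return_sequences=True))(data)  # (5, 3, 56)
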
+def test_bidirectional_auto_ncp_cfc(): + wiring = ncps.wirings.AutoNCP(28, 10) + rnn = keras.layers.Bidirectional(CfC(wiring)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 10 * 2) + + +def test_bidirectional_auto_ncp_cfc_mixed_memory(): + wiring = ncps.wirings.AutoNCP(28, 10) + rnn = keras.layers.Bidirectional(CfC(wiring, mixed_memory=True)) + data = keras.random.normal([5, 3, 8]) + output = rnn(data) + assert output.shape == (5, 10 * 2) + + +def test_fit_bidirectional_auto_ncp_cfc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) diff --git a/ncps/wirings/wirings.py b/ncps/wirings/wirings.py index 4ebbebd7..c9017fd4 100644 --- a/ncps/wirings/wirings.py +++ b/ncps/wirings/wirings.py @@ -111,19 +111,21 @@ def add_sensory_synapse(self, src, dest, polarity): def get_config(self): return { - "adjacency_matrix": self.adjacency_matrix, - "sensory_adjacency_matrix": self.sensory_adjacency_matrix, + "units": self.units, + "adjacency_matrix": self.adjacency_matrix.tolist() if self.adjacency_matrix is not None else None, + "sensory_adjacency_matrix": self.sensory_adjacency_matrix.tolist() if self.sensory_adjacency_matrix is not None else None, "input_dim": self.input_dim, "output_dim": self.output_dim, - "units": self.units, } @classmethod def from_config(cls, config): # There might be a cleaner solution but it will work wiring = Wiring(config["units"]) - wiring.adjacency_matrix = config["adjacency_matrix"] - wiring.sensory_adjacency_matrix = config["sensory_adjacency_matrix"] + if config["adjacency_matrix"] is not None: + wiring.adjacency_matrix = np.array(config["adjacency_matrix"]) + if config["sensory_adjacency_matrix"] is not None: + wiring.sensory_adjacency_matrix = np.array(config["sensory_adjacency_matrix"]) wiring.input_dim = config["input_dim"] wiring.output_dim = config["output_dim"] From 706f668e7319e85858eb0147429d51691f29ff35 Mon Sep 17 00:00:00 2001 From: lettercode <59030475+lettercode@users.noreply.github.com> Date: Sun, 21 Jul 2024 14:10:25 +0200 Subject: [PATCH 2/4] Fix serialization of wirings --- .github/workflows/python-test.yml | 2 +- ncps/keras/cfc.py | 3 +- ncps/keras/ltc.py | 11 +- ncps/keras/ltc_cell.py | 20 +- ncps/keras/mm_rnn.py | 7 + ncps/tests/test_keras.py | 312 ++++++++++++++++++++++++++++++ ncps/wirings/wirings.py | 57 ++++++ 7 files changed, 395 insertions(+), 17 deletions(-) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index fa54bb2c..97b33059 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -56,7 +56,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest networkx if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/ncps/keras/cfc.py b/ncps/keras/cfc.py index 335508b5..de364bb8 100644 --- a/ncps/keras/cfc.py +++ b/ncps/keras/cfc.py @@ -122,8 +122,9 @@ def from_config(cls, config, custom_objects=None): # The following parameters are recreated by the LTC constructor del config["cell"] if "wiring" in config: + wiring_class = 
getattr(ncps.wirings, config["units"]["class_name"]) + units = wiring_class.from_config(config["units"]["config"]) del config["wiring"] - units = ncps.wirings.Wiring.from_config(config["units"]["config"]) else: units = config["units"] del config["units"] diff --git a/ncps/keras/ltc.py b/ncps/keras/ltc.py index 6598200d..3dc5b405 100644 --- a/ncps/keras/ltc.py +++ b/ncps/keras/ltc.py @@ -82,12 +82,12 @@ def __init__( """ if isinstance(units, ncps.wirings.Wiring): - self.wiring = units + wiring = units else: - self.wiring = ncps.wirings.FullyConnected(units) + wiring = ncps.wirings.FullyConnected(units) cell = LTCCell( - wiring=self.wiring, + wiring=wiring, input_mapping=input_mapping, output_mapping=output_mapping, ode_unfolds=ode_unfolds, @@ -113,7 +113,7 @@ def get_config(self): cell: LTCCell = self.cell.rnn_cell if is_mixed_memory else self.cell cell_config = cell.get_config() config = super(LTC, self).get_config() - config["units"] = self.wiring + config["units"] = cell.wiring config["mixed_memory"] = is_mixed_memory return {**cell_config, **config} @@ -122,6 +122,7 @@ def from_config(cls, config, custom_objects=None): # The following parameters are recreated by the LTC constructor del config["cell"] del config["wiring"] - units = ncps.wirings.Wiring.from_config(config["units"]["config"]) + wiring_class = getattr(ncps.wirings, config["units"]["class_name"]) + units = wiring_class.from_config(config["units"]["config"]) del config["units"] return cls(units, **config) diff --git a/ncps/keras/ltc_cell.py b/ncps/keras/ltc_cell.py index b81631b8..3f4d43f8 100644 --- a/ncps/keras/ltc_cell.py +++ b/ncps/keras/ltc_cell.py @@ -95,7 +95,7 @@ def __init__( ) self._init_ranges[k] = v - self._wiring = wiring + self.wiring = wiring self._input_mapping = input_mapping self._output_mapping = output_mapping self._ode_unfolds = ode_unfolds @@ -103,15 +103,15 @@ def __init__( @property def state_size(self): - return self._wiring.units + return self.wiring.units @property def sensory_size(self): - return self._wiring.input_dim + return self.wiring.input_dim @property def motor_size(self): - return self._wiring.output_dim + return self.wiring.output_dim @property def output_size(self): @@ -133,7 +133,7 @@ def build(self, input_shape): else: input_dim = input_shape[-1] - self._wiring.build(input_dim) + self.wiring.build(input_dim) self._params = {} self._params["gleak"] = self.add_weight( @@ -179,7 +179,7 @@ def build(self, input_shape): name="erev", shape=(self.state_size, self.state_size), dtype="float32", - initializer=self._wiring.erev_initializer, + initializer=self.wiring.erev_initializer, ) self._params["sensory_sigma"] = self.add_weight( @@ -205,14 +205,14 @@ def build(self, input_shape): name="sensory_erev", shape=(self.sensory_size, self.state_size), dtype="float32", - initializer=self._wiring.sensory_erev_initializer, + initializer=self.wiring.sensory_erev_initializer, ) self._params["sparsity_mask"] = keras.ops.convert_to_tensor( - np.abs(self._wiring.adjacency_matrix), dtype="float32" + np.abs(self.wiring.adjacency_matrix), dtype="float32" ) self._params["sensory_sparsity_mask"] = keras.ops.convert_to_tensor( - np.abs(self._wiring.sensory_adjacency_matrix), dtype="float32" + np.abs(self.wiring.sensory_adjacency_matrix), dtype="float32" ) if self._input_mapping in ["affine", "linear"]: @@ -333,7 +333,7 @@ def call(self, sequence, states, training=False): def get_config(self): config = super(LTCCell, self).get_config() - config["wiring"] = self._wiring.get_config() + config["wiring"] = 
self.wiring.get_config() config["input_mapping"] = self._input_mapping config["output_mapping"] = self._output_mapping config["ode_unfolds"] = self._ode_unfolds diff --git a/ncps/keras/mm_rnn.py b/ncps/keras/mm_rnn.py index 183e6a28..ad567aab 100644 --- a/ncps/keras/mm_rnn.py +++ b/ncps/keras/mm_rnn.py @@ -28,6 +28,13 @@ def __init__(self, rnn_cell, forget_gate_bias=1.0, **kwargs): def state_size(self): return [self.flat_size, self.rnn_cell.state_size] + @property + def output_size(self): + if hasattr(self.rnn_cell, "output_size"): + return self.rnn_cell.output_size + else: + return self.rnn_cell.state_size + @property def flat_size(self): if isinstance(self.rnn_cell.state_size, int): diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py index 250a1015..2ff667c4 100644 --- a/ncps/tests/test_keras.py +++ b/ncps/tests/test_keras.py @@ -358,3 +358,315 @@ def test_fit_bidirectional_auto_ncp_cfc(): ) model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_fit_bidirectional_auto_ncp_ltc_mixed_memory(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(wiring, mixed_memory=True)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + +def test_wiring_graph_auto_ncp_ltc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + LTC(wiring), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + graph = wiring.get_graph() + assert len(graph) == (wiring.units + 2) + + +def test_wiring_graph_bidirectional_auto_ncp_ltc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + biLTC = keras.layers.Bidirectional(LTC(wiring)) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + biLTC, + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + assert wiring.input_dim is None # This happens because Bidirectional creates two new copies + assert isinstance(biLTC.forward_layer, LTC) + assert isinstance(biLTC.backward_layer, LTC) + + forward_graph = biLTC.forward_layer.cell.wiring.get_graph() + assert len(forward_graph) == (biLTC.forward_layer.cell.wiring.units + 2) + backward_graph = biLTC.backward_layer.cell.wiring.get_graph() + assert len(backward_graph) == (biLTC.backward_layer.cell.wiring.units + 2) + + +def test_bidirectional_equivalence_ltc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + model1 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(10, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(10, return_sequences=True), + backward_layer=LTC(10, return_sequences=True, go_backwards=True)), + keras.layers.Dense(1), + ] + ) + model1.compile(optimizer=keras.optimizers.Adam(0.01), 
loss="mean_squared_error") + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + bi_layer1 = model1.layers[0] + bi_layer2 = model2.layers[0] + assert isinstance(bi_layer1, keras.layers.Bidirectional) + assert isinstance(bi_layer2, keras.layers.Bidirectional) + + fw1_config = bi_layer1.forward_layer.get_config() + fw2_config = bi_layer2.forward_layer.get_config() + bw1_config = bi_layer1.backward_layer.get_config() + bw2_config = bi_layer2.backward_layer.get_config() + + def prune_details(config): + del config['name'] + del config['cell']['config']['name'] + config['units'] = config['units'].get_config() + + prune_details(fw1_config) + prune_details(fw2_config) + prune_details(bw1_config) + prune_details(bw2_config) + + assert fw1_config == fw2_config + assert bw1_config == bw2_config + + assert isinstance(bi_layer1.forward_layer.cell.wiring, ncps.wirings.FullyConnected) + assert isinstance(bi_layer1.backward_layer.cell.wiring, ncps.wirings.FullyConnected) + assert isinstance(bi_layer2.forward_layer.cell.wiring, ncps.wirings.FullyConnected) + assert isinstance(bi_layer2.backward_layer.cell.wiring, ncps.wirings.FullyConnected) + + +def test_bidirectional_equivalence_ltc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model1 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(wiring, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(LTC(wiring, return_sequences=True), + backward_layer=LTC(wiring, return_sequences=True, go_backwards=True)), + keras.layers.Dense(1), + ] + ) + model1.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + bi_layer1 = model1.layers[0] + bi_layer2 = model2.layers[0] + assert isinstance(bi_layer1, keras.layers.Bidirectional) + assert isinstance(bi_layer2, keras.layers.Bidirectional) + + fw1_config = bi_layer1.forward_layer.get_config() + fw2_config = bi_layer2.forward_layer.get_config() + bw1_config = bi_layer1.backward_layer.get_config() + bw2_config = bi_layer2.backward_layer.get_config() + + def prune_details(config): + del config['name'] + del config['cell']['config']['name'] + config['units'] = config['units'].get_config() + + prune_details(fw1_config) + prune_details(fw2_config) + prune_details(bw1_config) + prune_details(bw2_config) + + assert fw1_config == fw2_config + assert bw1_config == bw2_config + + assert isinstance(bi_layer1.forward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer1.backward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer2.forward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer2.backward_layer.cell.wiring, ncps.wirings.AutoNCP) + + +def test_bidirectional_equivalence_cfc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model1 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True), + backward_layer=CfC(wiring, 
return_sequences=True, go_backwards=True)), + keras.layers.Dense(1), + ] + ) + model1.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + bi_layer1 = model1.layers[0] + bi_layer2 = model2.layers[0] + assert isinstance(bi_layer1, keras.layers.Bidirectional) + assert isinstance(bi_layer2, keras.layers.Bidirectional) + + fw1_config = bi_layer1.forward_layer.get_config() + fw2_config = bi_layer2.forward_layer.get_config() + bw1_config = bi_layer1.backward_layer.get_config() + bw2_config = bi_layer2.backward_layer.get_config() + + def prune_details(config): + del config['name'] + del config['activation'] + del config['cell']['config']['name'] + config['units'] = config['units'].get_config() + config['wiring'] = config['wiring'].get_config() + + prune_details(fw1_config) + prune_details(fw2_config) + prune_details(bw1_config) + prune_details(bw2_config) + + assert fw1_config == fw2_config + assert bw1_config == bw2_config + + assert isinstance(bi_layer1.forward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer1.backward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer2.forward_layer.cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bi_layer2.backward_layer.cell.wiring, ncps.wirings.AutoNCP) + + +def test_bidirectional_equivalence_cfc_ncp_mixed_memory(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model1 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True, mixed_memory=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True, mixed_memory=True), + backward_layer=CfC(wiring, return_sequences=True, mixed_memory=True, + go_backwards=True)), + keras.layers.Dense(1), + ] + ) + model1.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + bi_layer1 = model1.layers[0] + bi_layer2 = model2.layers[0] + assert isinstance(bi_layer1, keras.layers.Bidirectional) + assert isinstance(bi_layer2, keras.layers.Bidirectional) + + fw1_mm_cell = bi_layer1.forward_layer.cell + fw2_mm_cell = bi_layer2.forward_layer.cell + bw1_mm_cell = bi_layer1.backward_layer.cell + bw2_mm_cell = bi_layer2.backward_layer.cell + assert isinstance(fw1_mm_cell, ncps.keras.MixedMemoryRNN) and isinstance(fw2_mm_cell, ncps.keras.MixedMemoryRNN) + assert isinstance(bw1_mm_cell, ncps.keras.MixedMemoryRNN) and isinstance(bw2_mm_cell, ncps.keras.MixedMemoryRNN) + + def prune_mm_details(config): + del config['rnn_cell']['name'] + del config['rnn_cell']['activation'] + del config['rnn_cell']['wiring'] # Checked below + return config + + assert prune_mm_details(fw1_mm_cell.get_config()) == prune_mm_details(fw2_mm_cell.get_config()) + assert prune_mm_details(bw1_mm_cell.get_config()) == prune_mm_details(bw2_mm_cell.get_config()) + + def prune_cfc_details(config): + del config['name'] + del config['activation'] + config['wiring'] = config['wiring'].get_config() + return config + + assert prune_cfc_details(fw1_mm_cell.rnn_cell.get_config()) == prune_cfc_details(fw2_mm_cell.rnn_cell.get_config()) + assert prune_cfc_details(bw1_mm_cell.rnn_cell.get_config()) == 
prune_cfc_details(bw2_mm_cell.rnn_cell.get_config()) + + assert isinstance(fw1_mm_cell.rnn_cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(fw2_mm_cell.rnn_cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bw1_mm_cell.rnn_cell.wiring, ncps.wirings.AutoNCP) + assert isinstance(bw2_mm_cell.rnn_cell.wiring, ncps.wirings.AutoNCP) + + +def test_bidirectional_equivalence_cfc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + model1 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(28, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(28, return_sequences=True), + backward_layer=CfC(28, return_sequences=True, go_backwards=True)), + keras.layers.Dense(1), + ] + ) + model1.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + bi_layer1 = model1.layers[0] + bi_layer2 = model2.layers[0] + assert isinstance(bi_layer1, keras.layers.Bidirectional) + assert isinstance(bi_layer2, keras.layers.Bidirectional) + + fw1_config = bi_layer1.forward_layer.get_config() + fw2_config = bi_layer2.forward_layer.get_config() + bw1_config = bi_layer1.backward_layer.get_config() + bw2_config = bi_layer2.backward_layer.get_config() + + def prune_details(config): + del config['name'] + del config['activation'] + del config['cell']['config']['name'] + + prune_details(fw1_config) + prune_details(fw2_config) + prune_details(bw1_config) + prune_details(bw2_config) + + assert fw1_config == fw2_config + assert bw1_config == bw2_config + + assert isinstance(bi_layer1.forward_layer.cell, ncps.keras.CfCCell) + assert isinstance(bi_layer1.backward_layer.cell, ncps.keras.CfCCell) + assert isinstance(bi_layer2.forward_layer.cell, ncps.keras.CfCCell) + assert isinstance(bi_layer2.backward_layer.cell, ncps.keras.CfCCell) diff --git a/ncps/wirings/wirings.py b/ncps/wirings/wirings.py index c9017fd4..c59c7b59 100644 --- a/ncps/wirings/wirings.py +++ b/ncps/wirings/wirings.py @@ -302,6 +302,7 @@ def __init__( self.self_connections = self_connections self.set_output_dim(output_dim) self._rng = np.random.default_rng(erev_init_seed) + self._erev_init_seed = erev_init_seed for src in range(self.units): for dest in range(self.units): if src == dest and not self_connections: @@ -316,6 +317,18 @@ def build(self, input_shape): polarity = self._rng.choice([-1, 1, 1]) self.add_sensory_synapse(src, dest, polarity) + def get_config(self): + return { + "units": self.units, + "output_dim": self.output_dim, + "erev_init_seed": self._erev_init_seed, + "self_connections": self.self_connections + } + + @classmethod + def from_config(cls, config): + return cls(**config) + class Random(Wiring): def __init__(self, units, output_dim=None, sparsity_level=0.0, random_seed=1111): @@ -332,6 +345,7 @@ def __init__(self, units, output_dim=None, sparsity_level=0.0, random_seed=1111) ) ) self._rng = np.random.default_rng(random_seed) + self._random_seed = random_seed number_of_synapses = int(np.round(units * units * (1 - sparsity_level))) all_synapses = [] @@ -365,6 +379,18 @@ def build(self, input_shape): polarity = self._rng.choice([-1, 1, 1]) self.add_sensory_synapse(src, dest, polarity) + def get_config(self): + return { + "units": self.units, + "output_dim": self.output_dim, + "sparsity_level": self.sparsity_level, + 
"random_seed": self._random_seed, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + class NCP(Wiring): def __init__( @@ -557,6 +583,22 @@ def build(self, input_shape): self._build_recurrent_command_layer() self._build_command__to_motor_layer() + def get_config(self): + return { + "inter_neurons": self._inter_neurons, + "command_neurons": self._command_neurons, + "motor_neurons": self._motor_neurons, + "sensory_fanout": self._sensory_fanout, + "inter_fanout": self._inter_fanout, + "recurrent_command_synapses": self._recurrent_command_synapses, + "motor_fanin": self._motor_fanin, + "seed": self._rng.seed(), + } + + @classmethod + def from_config(cls, config): + return cls(**config) + class AutoNCP(NCP): def __init__( @@ -573,6 +615,9 @@ def __init__( :param sparsity_level: A hyperparameter between 0.0 (very dense) and 0.9 (very sparse) NCP. :param seed: Random seed for generating the wiring """ + self._output_size = output_size + self._sparsity_level = sparsity_level + self._seed = seed if output_size >= units - 2: raise ValueError( f"Output size must be less than the number of units-2 (given {units} units, {output_size} output size)" @@ -600,3 +645,15 @@ def __init__( motor_fanin, seed=seed, ) + + def get_config(self): + return { + "units": self.units, + "output_size": self._output_size, + "sparsity_level": self._sparsity_level, + "seed": self._seed, + } + + @classmethod + def from_config(cls, config): + return cls(**config) From ad72bb6be3f87035fef0ee06bb459918cb588209 Mon Sep 17 00:00:00 2001 From: lettercode <59030475+lettercode@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:32:23 +0200 Subject: [PATCH 3/4] Support model saving and loading --- ncps/keras/cfc.py | 2 +- ncps/keras/cfc_cell.py | 32 +++--- ncps/keras/ltc.py | 2 +- ncps/keras/wired_cfc_cell.py | 28 ++--- ncps/tests/test_keras.py | 196 +++++++++++++++++++++++++++++++++++ 5 files changed, 225 insertions(+), 35 deletions(-) diff --git a/ncps/keras/cfc.py b/ncps/keras/cfc.py index de364bb8..a5aa3177 100644 --- a/ncps/keras/cfc.py +++ b/ncps/keras/cfc.py @@ -119,7 +119,7 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - # The following parameters are recreated by the LTC constructor + # The following parameters are recreated by the constructor del config["cell"] if "wiring" in config: wiring_class = getattr(ncps.wirings, config["units"]["class_name"]) diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py index 0f8d912b..d12c9c45 100644 --- a/ncps/keras/cfc_cell.py +++ b/ncps/keras/cfc_cell.py @@ -81,29 +81,24 @@ def state_size(self): return self.units def build(self, input_shape): - if isinstance(input_shape[0], tuple) or isinstance( - input_shape[0], keras.KerasTensor - ): + if isinstance(input_shape[0], tuple) or isinstance(input_shape[0], keras.KerasTensor): # Nested tuple -> First item represent feature dimension input_dim = input_shape[0][-1] else: input_dim = input_shape[-1] - backbone_layers = [] - for i in range(self._backbone_layers): - backbone_layers.append( - keras.layers.Dense( - self._backbone_units, self._activation, name=f"backbone{i}" - ) - ) - backbone_layers.append(keras.layers.Dropout(self._backbone_dropout)) + if self._backbone_layers > 0: + backbone_layers = [] + for i in range(self._backbone_layers): + backbone_layers.append(keras.layers.Dense(self._backbone_units, self._activation, name=f"backbone{i}")) + backbone_layers.append(keras.layers.Dropout(self._backbone_dropout)) + self.backbone_fn = 
keras.models.Sequential(backbone_layers) + self.backbone_fn.build((None, self.state_size + input_dim)) + cat_shape = int(self._backbone_units) + else: + cat_shape = int(self.state_size + input_dim) - cat_shape = int( - self.state_size + input_dim - if self._backbone_layers == 0 - else self._backbone_units - ) if self.mode == "pure": self.ff1_kernel = self.add_weight( shape=(cat_shape, self.state_size), @@ -158,6 +153,11 @@ def build(self, input_shape): # self.ff2.build((None, self.sparsity_mask.shape[0])) self.time_a = keras.layers.Dense(self.state_size, name="time_a") self.time_b = keras.layers.Dense(self.state_size, name="time_b") + input_shape = (None, self.state_size + input_dim) + if self._backbone_layers > 0: + input_shape = self.backbone_fn.output_shape + self.time_a.build(input_shape) + self.time_b.build(input_shape) self.built = True def call(self, inputs, states, **kwargs): diff --git a/ncps/keras/ltc.py b/ncps/keras/ltc.py index 3dc5b405..b82cc613 100644 --- a/ncps/keras/ltc.py +++ b/ncps/keras/ltc.py @@ -119,7 +119,7 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): - # The following parameters are recreated by the LTC constructor + # The following parameters are recreated by the constructor del config["cell"] del config["wiring"] wiring_class = getattr(ncps.wirings, config["units"]["class_name"]) diff --git a/ncps/keras/wired_cfc_cell.py b/ncps/keras/wired_cfc_cell.py index 7c311596..7fbea37c 100644 --- a/ncps/keras/wired_cfc_cell.py +++ b/ncps/keras/wired_cfc_cell.py @@ -22,7 +22,7 @@ def split_tensor(input_tensor, num_or_size_splits, axis=0): A list of tensors resulting from splitting the input tensor. """ input_shape = keras.ops.shape(input_tensor) - tensor_shape = input_shape[:axis] + (-1,) + input_shape[axis+1:] + tensor_shape = input_shape[:axis] + (-1,) + input_shape[axis + 1:] if isinstance(num_or_size_splits, int): split_sizes = [input_shape[axis] // num_or_size_splits] * num_or_size_splits @@ -43,12 +43,12 @@ def split_tensor(input_tensor, num_or_size_splits, axis=0): @keras.utils.register_keras_serializable(package="ncps", name="WiredCfCCell") class WiredCfCCell(keras.layers.Layer): def __init__( - self, - wiring: wirings.Wiring, - fully_recurrent=True, - mode="default", - activation="lecun_tanh", - **kwargs, + self, + wiring: wirings.Wiring, + fully_recurrent=True, + mode="default", + activation="lecun_tanh", + **kwargs, ): super().__init__(**kwargs) self.wiring = wiring @@ -89,21 +89,15 @@ def build(self, input_shape): for i in range(self.wiring.num_layers): layer_i_neurons = self.wiring.get_neurons_of_layer(i) if i == 0: - input_sparsity = self.wiring.sensory_adjacency_matrix[ - :, layer_i_neurons - ] + input_sparsity = self.wiring.sensory_adjacency_matrix[:, layer_i_neurons] else: prev_layer_neurons = self.wiring.get_neurons_of_layer(i - 1) input_sparsity = self.wiring.adjacency_matrix[:, layer_i_neurons] input_sparsity = input_sparsity[prev_layer_neurons, :] if self.fully_recurrent: - recurrent_sparsity = np.ones( - (len(layer_i_neurons), len(layer_i_neurons)), dtype=np.int32 - ) + recurrent_sparsity = np.ones((len(layer_i_neurons), len(layer_i_neurons)), dtype=np.int32) else: - recurrent_sparsity = self.wiring.adjacency_matrix[ - layer_i_neurons, layer_i_neurons - ] + recurrent_sparsity = self.wiring.adjacency_matrix[layer_i_neurons, layer_i_neurons] sparsity_mask = keras.ops.convert_to_tensor( np.concatenate([input_sparsity, recurrent_sparsity], axis=0), dtype="float32", @@ -119,7 +113,7 @@ def build(self, input_shape): 
) cell_in_shape = (None, input_sparsity.shape[0]) - # cell.build(cell_in_shape) + cell.build(cell_in_shape) self._cfc_layers.append(cell) self._layer_sizes = [l.units for l in self._cfc_layers] diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py index 2ff667c4..ace65a8b 100644 --- a/ncps/tests/test_keras.py +++ b/ncps/tests/test_keras.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect # import os + # os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Run on CPU # os.environ["KERAS_BACKEND"] = "torch" # os.environ["KERAS_BACKEND"] = "tensorflow" @@ -670,3 +672,197 @@ def prune_details(config): assert isinstance(bi_layer1.backward_layer.cell, ncps.keras.CfCCell) assert isinstance(bi_layer2.forward_layer.cell, ncps.keras.CfCCell) assert isinstance(bi_layer2.backward_layer.cell, ncps.keras.CfCCell) + + +def test_save_and_load_ltc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + LTC(28, return_sequences=True), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, keras.models.Sequential) + + def prune_details(config): + del config['cell']['config']['name'] + config['units'] = config['units'].get_config() + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_ltc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + LTC(wiring, return_sequences=True), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, keras.models.Sequential) + + def prune_details(config): + del config['cell']['config']['name'] + config['units'] = config['units'].get_config() + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_cfc(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + CfC(28, return_sequences=True), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, 
keras.models.Sequential) + + def prune_details(config): + del config['cell']['config']['name'] + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_cfc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + CfC(wiring, return_sequences=True), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, keras.models.Sequential) + + def prune_details(config): + del config['cell']['config']['name'] + del config['activation'] + del config['cell']['config']['activation'] + config['units'] = config['units'].get_config() + config['wiring'] = config['wiring'].get_config() + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_bidirectional_cfc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, keras.models.Sequential) + + def prune_details(config): + del config['backward_layer']['config']['cell']['config']['name'] + del config['layer']['config']['cell']['config']['name'] + config['backward_layer']['build_config']['input_shape'] = list(config['layer']['build_config']['input_shape']) + config['layer']['build_config']['input_shape'] = list(config['layer']['build_config']['input_shape']) + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_bidirectional_cfc_ncp_mixed_memory(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True, mixed_memory=True)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + loaded_model = keras.models.load_model(keras_file) + assert isinstance(loaded_model, keras.models.Sequential) + + def prune_details(config): + del 
config['backward_layer']['config']['cell']['config']['rnn_cell']['name'] + del config['layer']['config']['cell']['config']['rnn_cell']['name'] + config['backward_layer']['build_config']['input_shape'] = list(config['layer']['build_config']['input_shape']) + config['layer']['build_config']['input_shape'] = list(config['layer']['build_config']['input_shape']) + return config + + assert prune_details(loaded_model.layers[0].get_config()) == prune_details(model.layers[0].get_config()) + assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) + + +def test_save_and_load_weights_only_bidirectional_cfc_ncp(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + wiring = ncps.wirings.AutoNCP(28, 10) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model2 = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + keras_file = f"{inspect.currentframe().f_code.co_name}.keras" + model.save(keras_file) + model2.load_weights(keras_file) + + assert all([np.array_equal(l, m) for (l, m) in zip(model2.get_weights(), model.get_weights())]) From 5a1c4d95368bc5ccdb457242137a59e1bbe82f84 Mon Sep 17 00:00:00 2001 From: lettercode <59030475+lettercode@users.noreply.github.com> Date: Tue, 23 Jul 2024 21:15:07 +0200 Subject: [PATCH 4/4] Fix bug in MixedMemory when concatenating list of tensors --- .github/workflows/python-test.yml | 2 +- ncps/keras/cfc_cell.py | 62 ++++++++++++------------------- ncps/keras/mm_rnn.py | 5 ++- ncps/keras/wired_cfc_cell.py | 14 ++++--- ncps/tests/test_keras.py | 45 +++++++++++++++++++--- 5 files changed, 76 insertions(+), 52 deletions(-) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 97b33059..f85524a4 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -44,7 +44,7 @@ jobs: runs-on: ubuntu-latest container: - image: tensorflow/tensorflow:2.16.1 + image: tensorflow/tensorflow:2.17.0 env: KERAS_BACKEND: tensorflow volumes: diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py index d12c9c45..2871d61b 100644 --- a/ncps/keras/cfc_cell.py +++ b/ncps/keras/cfc_cell.py @@ -99,17 +99,18 @@ def build(self, input_shape): else: cat_shape = int(self.state_size + input_dim) + self.ff1_kernel = self.add_weight( + shape=(cat_shape, self.state_size), + initializer="glorot_uniform", + name="ff1_weight", + ) + self.ff1_bias = self.add_weight( + shape=(self.state_size,), + initializer="zeros", + name="ff1_bias", + ) + if self.mode == "pure": - self.ff1_kernel = self.add_weight( - shape=(cat_shape, self.state_size), - initializer="glorot_uniform", - name="ff1_weight", - ) - self.ff1_bias = self.add_weight( - shape=(self.state_size,), - initializer="zeros", - name="ff1_bias", - ) self.w_tau = self.add_weight( shape=(1, self.state_size), initializer=keras.initializers.Zeros(), @@ -121,16 +122,6 @@ def build(self, input_shape): name="A", ) else: - self.ff1_kernel = self.add_weight( - shape=(cat_shape, self.state_size), - initializer="glorot_uniform", - 
name="ff1_weight", - ) - self.ff1_bias = self.add_weight( - shape=(self.state_size,), - initializer="zeros", - name="ff1_bias", - ) self.ff2_kernel = self.add_weight( shape=(cat_shape, self.state_size), initializer="glorot_uniform", @@ -142,15 +133,6 @@ def build(self, input_shape): name="ff2_bias", ) - # = keras.layers.Dense( - # , self._activation, name=f"{self.name}/ff1" - # ) - # self.ff2 = keras.layers.Dense( - # self.state_size, self._activation, name=f"{self.name}/ff2" - # ) - # if self.sparsity_mask is not None: - # self.ff1.build((None,)) - # self.ff2.build((None, self.sparsity_mask.shape[0])) self.time_a = keras.layers.Dense(self.state_size, name="time_a") self.time_b = keras.layers.Dense(self.state_size, name="time_b") input_shape = (None, self.state_size + input_dim) @@ -202,16 +184,18 @@ def call(self, inputs, states, **kwargs): return new_hidden, [new_hidden] def get_config(self): - config = super(CfCCell, self).get_config() - config["units"] = self.units - config["mode"] = self.mode - config["activation"] = self._activation - config["backbone_units"] = self._backbone_units - config["backbone_layers"] = self._backbone_layers - config["backbone_dropout"] = self._backbone_dropout - config["sparsity_mask"] = self.sparsity_mask - return config + config = { + "units": self.units, + "mode": self.mode, + "activation": self._activation, + "backbone_units": self._backbone_units, + "backbone_layers": self._backbone_layers, + "backbone_dropout": self._backbone_dropout, + "sparsity_mask": self.sparsity_mask, + } + base_config = super().get_config() + return {**base_config, **config} @classmethod - def from_config(cls, config): + def from_config(cls, config, custom_objects=None): return cls(**config) diff --git a/ncps/keras/mm_rnn.py b/ncps/keras/mm_rnn.py index ad567aab..38b511ed 100644 --- a/ncps/keras/mm_rnn.py +++ b/ncps/keras/mm_rnn.py @@ -68,7 +68,10 @@ def build(self, sequences_shape, initial_state_shape=None): def call(self, sequences, initial_state=None, mask=None, training=False, **kwargs): memory_state, ct_state = initial_state - flat_ct_state = keras.ops.concatenate([ct_state], axis=-1) + if isinstance(ct_state, list): + flat_ct_state = keras.ops.concatenate(ct_state, axis=-1) + else: + flat_ct_state = ct_state z = ( keras.ops.matmul(sequences, self.input_kernel) + keras.ops.matmul(flat_ct_state, self.recurrent_kernel) diff --git a/ncps/keras/wired_cfc_cell.py b/ncps/keras/wired_cfc_cell.py index 7fbea37c..7eb7489f 100644 --- a/ncps/keras/wired_cfc_cell.py +++ b/ncps/keras/wired_cfc_cell.py @@ -150,12 +150,14 @@ def call(self, inputs, states, **kwargs): return output, new_hiddens def get_config(self): - config = super(WiredCfCCell, self).get_config() - config["wiring"] = self.wiring - config["fully_recurrent"] = self.fully_recurrent - config["mode"] = self.mode - config["activation"] = self._activation - return config + config = { + "wiring": self.wiring, + "fully_recurrent": self.fully_recurrent, + "mode": self.mode, + "activation": self._activation, + } + base_config = super().get_config() + return {**base_config, **config} @classmethod def from_config(cls, config): diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py index ace65a8b..a892268e 100644 --- a/ncps/tests/test_keras.py +++ b/ncps/tests/test_keras.py @@ -377,6 +377,42 @@ def test_fit_bidirectional_auto_ncp_ltc_mixed_memory(): model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) +def test_fit_cfc_mixed_memory_fix_batch_size_no_sequences(): + data_x, data_y = prepare_test_data() + data_x = 
np.resize(data_x, (2, 48, 2)) + data_y = np.resize(data_y, (2, 1, 2)) + print("data_y.shape: ", str(data_y.shape)) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(48, 2), batch_size=1), + CfC(28, + mixed_memory=True, + backbone_units=64, + backbone_dropout=0.3, + backbone_layers=2, + return_sequences=False), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=2, epochs=3) + + +def test_fit_bidirectional_cfc_with_sum(): + data_x, data_y = prepare_test_data() + print("data_y.shape: ", str(data_y.shape)) + model = keras.models.Sequential( + [ + keras.layers.InputLayer(input_shape=(None, 2)), + keras.layers.Bidirectional(CfC(28, return_sequences=False, unroll=True, mixed_memory=True), + merge_mode='sum'), + keras.layers.Dense(1), + ] + ) + model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") + model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) + + def test_wiring_graph_auto_ncp_ltc(): data_x, data_y = prepare_test_data() print("data_y.shape: ", str(data_y.shape)) @@ -840,29 +876,28 @@ def prune_details(config): assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())]) -def test_save_and_load_weights_only_bidirectional_cfc_ncp(): +def test_save_and_load_weights_only_bidirectional_cfc(): data_x, data_y = prepare_test_data() print("data_y.shape: ", str(data_y.shape)) - wiring = ncps.wirings.AutoNCP(28, 10) model = keras.models.Sequential( [ keras.layers.InputLayer(input_shape=(None, 2)), - keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Bidirectional(CfC(28, return_sequences=True)), keras.layers.Dense(1), ] ) model2 = keras.models.Sequential( [ keras.layers.InputLayer(input_shape=(None, 2)), - keras.layers.Bidirectional(CfC(wiring, return_sequences=True)), + keras.layers.Bidirectional(CfC(28, return_sequences=True)), keras.layers.Dense(1), ] ) model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") model.fit(x=data_x, y=data_y, batch_size=1, epochs=3) - model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") keras_file = f"{inspect.currentframe().f_code.co_name}.keras" model.save(keras_file) + model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error") model2.load_weights(keras_file) assert all([np.array_equal(l, m) for (l, m) in zip(model2.get_weights(), model.get_weights())])
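
Taken together, the four patches enable the full train/save/reload cycle for wired bidirectional models. A condensed end-to-end sketch mirroring test_save_and_load_bidirectional_cfc_ncp — the sine-wave toy data and the file name are illustrative stand-ins for the suite's prepare_test_data helper, not part of the patch:

import numpy as np
import keras
import ncps.wirings
from ncps.keras import CfC

t = np.linspace(0, 3 * np.pi, 48)
data_x = np.stack([np.sin(t), np.cos(t)], axis=-1)[None, ...].astype(np.float32)  # (1, 48, 2)
data_y = np.sin(2 * t)[None, :, None].astype(np.float32)                          # (1, 48, 1)

wiring = ncps.wirings.AutoNCP(28, 10)
model = keras.models.Sequential(
    [
        keras.layers.InputLayer(input_shape=(None, 2)),
        keras.layers.Bidirectional(CfC(wiring, return_sequences=True)),
        keras.layers.Dense(1),
    ]
)
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)

model.save("bidirectional_cfc_ncp.keras")
restored = keras.models.load_model("bidirectional_cfc_ncp.keras")
assert all(np.array_equal(a, b) for a, b in zip(restored.get_weights(), model.get_weights()))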