diff --git a/kimm/layers/rep_conv2d_test.py b/kimm/layers/rep_conv2d_test.py index 1ee7c64..ca6c8df 100644 --- a/kimm/layers/rep_conv2d_test.py +++ b/kimm/layers/rep_conv2d_test.py @@ -1,5 +1,6 @@ import pytest from absl.testing import parameterized +from keras import backend from keras import random from keras.src import testing @@ -63,6 +64,14 @@ def test_rep_conv2d_basic( num_trainable_weights, num_non_trainable_weights, ): + if ( + backend.backend() == "tensorflow" + and data_format == "channels_first" + ): + self.skipTest( + "Conv2D in tensorflow backend with 'channels_first' is limited " + "to be supported" + ) self.run_layer_test( RepConv2D, init_kwargs={ @@ -91,6 +100,14 @@ def test_rep_conv2d_get_reparameterized_weights( num_trainable_weights, num_non_trainable_weights, ): + if ( + backend.backend() == "tensorflow" + and data_format == "channels_first" + ): + self.skipTest( + "Conv2D in tensorflow backend with 'channels_first' is limited " + "to be supported" + ) layer = RepConv2D( filters=filters, kernel_size=kernel_size, diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index 89d2184..d0ed2f7 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -348,6 +348,7 @@ ] +@pytest.mark.requires_trainable_backend # numpy is too slow to test class ModelTest(testing.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls):